In [2]:
import os

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
import datasets
import dataloader
import diffusion
import utils

import json
import mauve

# Custom OmegaConf interpolation resolvers for the Hydra configs composed below
# (presumably referenced as ${cwd:}, ${device_count:}, ${eval:...} and
# ${div_up:...} in configs/*.yaml -- confirm).
#
# replace=True keeps this cell idempotent: OmegaConf raises ValueError when a
# resolver name is already registered, so without it re-running the cell in the
# same kernel fails.
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd, replace=True)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count, replace=True)
# SECURITY: `eval` executes arbitrary Python taken from config strings; only
# compose configs from trusted sources.
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval, replace=True)
# Ceiling division: ceil(x / y) for integer x and positive integer y.
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y, replace=True)

In [3]:
# Run settings. Later cells forward these to Hydra overrides.
#
# os.path.expanduser("~") replaces os.environ.get('HOME'): the latter returns
# None when HOME is unset (e.g. on Windows), which silently produced paths like
# "None/Git/remdm/...".
HOME = os.path.expanduser("~")
REMDM_OUTPUT_DIR = f"{HOME}/Git/remdm/outputs"

checkpoint_path = f"{REMDM_OUTPUT_DIR}/checkpoints/mdlm.ckpt"
T = 0                   # forwarded as `T=`; presumably 0 = continuous time -- confirm
sampling_steps = 128    # forwarded as sampling.steps; full run used 1024
p = 0.9                 # nucleus (top-p) threshold, forwarded as sampling.nucleus_p
num_sample_batches = 1  # forwarded as sampling.num_sample_batches; full run used 5000
global_batch_size = 512
devices = 1
generated_seqs_path = f"{REMDM_OUTPUT_DIR}/mdlm_T-{sampling_steps}_topp-{p}.json"


In [4]:
# Compose the Hydra config in-process with the Compose API (instead of
# @hydra.main). Interpolated values come from the settings cell; the literal
# ones are fixed for this evaluation run. Override order is kept as-is.
hydra_overrides = [
    "data=openwebtext-split",
    f"eval.checkpoint_path={checkpoint_path}",
    "time_conditioning=false",
    f"T={T}",
    f"loader.global_batch_size={global_batch_size}",
    f"sampling.steps={sampling_steps}",
    "seed=1",
    # Per-device batch sizes of 1. On a single device this makes
    # accumulate_grad_batches equal global_batch_size (512 in the logged output).
    "loader.batch_size=1",
    "loader.eval_batch_size=1",
    "eval.perplexity_batch_size=1",
    f"sampling.num_sample_batches={num_sample_batches}",
    f"sampling.generated_seqs_path={generated_seqs_path}",
    f"sampling.nucleus_p={p}",
    "sampling.sampler=mdlm",
    f"trainer.devices={devices}",
    f"data.cache_dir={HOME}/Git/remdm/outputs/data",
]
with hydra.initialize(config_path="configs/", version_base=None):
    config = hydra.compose(config_name="config", overrides=hydra_overrides)

tokenizer = dataloader.get_tokenizer(config)
# Only the validation loader is kept. skip_train=True presumably skips building
# the train loader (the discarded first element), and valid_seed presumably
# fixes the validation ordering to config.seed -- TODO confirm in dataloader.py.
# Note: the logged output shows the full train split is still downloaded.
_, valid_loader = dataloader.get_dataloaders(
    config, tokenizer, valid_seed=config.seed, skip_train=True)




Using 1 GPUs for training.
Using 1 batch size and 1 eval batch size. Global batch size is 512.
num_nodes: 1,        accumulate_grad_batches: 512


Downloading data:   0%|          | 0/21 [00:00<?, ?files/s]

subsets/urlsf_subset00.tar:   0%|          | 0.00/633M [00:00<?, ?B/s]

subsets/urlsf_subset01.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

subsets/urlsf_subset02.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

subsets/urlsf_subset03.tar:   0%|          | 0.00/628M [00:00<?, ?B/s]

subsets/urlsf_subset04.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

subsets/urlsf_subset05.tar:   0%|          | 0.00/630M [00:00<?, ?B/s]

subsets/urlsf_subset06.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

subsets/urlsf_subset07.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

subsets/urlsf_subset08.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

subsets/urlsf_subset09.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

subsets/urlsf_subset10.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

subsets/urlsf_subset11.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

subsets/urlsf_subset12.tar:   0%|          | 0.00/624M [00:00<?, ?B/s]

subsets/urlsf_subset13.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

subsets/urlsf_subset14.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

subsets/urlsf_subset15.tar:   0%|          | 0.00/621M [00:00<?, ?B/s]

subsets/urlsf_subset16.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

subsets/urlsf_subset17.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

subsets/urlsf_subset18.tar:   0%|          | 0.00/618M [00:00<?, ?B/s]

subsets/urlsf_subset19.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

subsets/urlsf_subset20.tar:   0%|          | 0.00/377M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8013769 [00:00<?, ? examples/s]

Tokenizing (num_proc=8):   0%|          | 0/100000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1239 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3475 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1180 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1384 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1255 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Grouping (num_proc=8):   0%|          | 0/100000 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/110466 [00:00<?, ? examples/s]

In [7]:
# Fetch the first validation batch; the next cell inspects its "input_ids".
x = next(iter(valid_loader))

In [9]:
# Display the token-id tensor's dimensions (.size() returns the same torch.Size as .shape).
x["input_ids"].size()

torch.Size([1, 1024])