In [1]:
import os

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
import datasets
import dataloader
import diffusion
import utils

import json
import mauve

# Custom OmegaConf interpolation resolvers (usable in configs as `${name:args}`).
# replace=True makes this cell safe to re-run in the same kernel: with the
# default replace=False, OmegaConf raises ValueError for a name that is
# already registered.
# NOTE(review): the 'eval' resolver runs Python's eval() on config text —
# only use it with trusted config files.
_CONFIG_RESOLVERS = {
  'cwd': os.getcwd,
  'device_count': torch.cuda.device_count,
  'eval': eval,
  # integer ceiling division: (x + y - 1) // y == ceil(x / y) for positive y
  'div_up': lambda x, y: (x + y - 1) // y,
}
for _resolver_name, _resolver_fn in _CONFIG_RESOLVERS.items():
  omegaconf.OmegaConf.register_new_resolver(
    _resolver_name, _resolver_fn, replace=True)

In [2]:
# --- Run configuration --------------------------------------------------------
# Fall back to the user's home directory when $HOME is unset; otherwise the
# f-strings below would silently build paths that start with "None/".
HOME = os.environ.get('HOME') or os.path.expanduser('~')
# Root for the checkpoint and the generated-samples file below.
REMDM_OUTPUTS = f"{HOME}/Git/remdm/outputs"

checkpoint_path = f"{REMDM_OUTPUTS}/checkpoints/mdlm.ckpt"
T = 0                    # forwarded as the `T=` Hydra override (presumably 0 = continuous time; confirm)
sampling_steps = 128     # original comment noted 1024 — presumably the full-run value
p = 0.9                  # nucleus (top-p) threshold, forwarded as `sampling.nucleus_p`
num_sample_batches = 1   # original comment noted 5000 — presumably the full-run value
global_batch_size = 512
devices = 1
# Output file name encodes the sampling-step count and top-p value.
generated_seqs_path = f"{REMDM_OUTPUTS}/mdlm_T-{sampling_steps}_topp-{p}.json"


In [5]:
# Compose the config through Hydra rather than OmegaConf.load() on the raw
# yaml, so Hydra's `defaults` list and the command-line-style overrides below
# are applied (plain OmegaConf.load does not process `defaults`).
with hydra.initialize(config_path="configs/", version_base=None):
    config = hydra.compose(config_name="config", overrides=[
        "data=openwebtext-split",
        f"eval.checkpoint_path={checkpoint_path}",
        "time_conditioning=false",
        f"T={T}",
        # With loader.batch_size=1 this implies gradient accumulation of
        # global_batch_size steps (see the printed output); that only matters
        # for training, not for this evaluation notebook.
        f"loader.global_batch_size={global_batch_size}",
        f"sampling.steps={sampling_steps}",
        "seed=1",
        "loader.batch_size=1",
        "loader.eval_batch_size=1",
        "eval.perplexity_batch_size=1",
        f"sampling.num_sample_batches={num_sample_batches}",
        f"sampling.generated_seqs_path={generated_seqs_path}",
        f"sampling.nucleus_p={p}",
        "sampling.sampler=mdlm",
        f"trainer.devices={devices}",
        f"data.cache_dir={HOME}/Git/remdm/outputs/data",
    ])

# config.seed is otherwise only forwarded to the validation loader; seed the
# global Python / NumPy / torch RNGs as well so later stochastic cells repeat.
L.seed_everything(config.seed)

tokenizer = dataloader.get_tokenizer(config)
_, valid_loader = dataloader.get_dataloaders(
    config, tokenizer, valid_seed=config.seed, skip_train=True)


Using 1 GPUs for training.
Using 1 batch size and 1 eval batch size. Global batch size is 512.
num_nodes: 1,        accumulate_grad_batches: 512


In [7]:
x = next(iter(valid_loader))  # one validation batch; first pull pays loader start-up cost ("warmup")

In [9]:
x["input_ids"].size()

torch.Size([1, 1024])