In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.ct_sched_cond_sparse_k import ScheduleConditionSparseK
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet
# from d3pm_sc.dit import DiT_Llama
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import wandb
# wandb.login()  # API key redacted: supply it via the WANDB_API_KEY environment variable, never in the notebook


In [None]:
from nets import get_model_setup
from data import get_dataloaders
from omegaconf import OmegaConf
import data

# Read the experiment configuration from disk.
cfg = OmegaConf.load('configs/basic_language.yaml')

##### Load data
train_dataloader, test_dataloader = data.get_dataloaders(cfg)
# The dataloader may or may not carry a tokenizer; use None when it does not.
tokenizer = getattr(train_dataloader, "tokenizer", None)

##### Setup x0_model
x0_model_class, nn_params = get_model_setup(cfg, tokenizer)

##### Pick model
# Hyperparameters read from the `model` section of the config (and the class
# count, which comes from the tokenizer when there is one).
model_cfg = cfg.model
model_hparams = dict(
    num_classes=len(tokenizer) if tokenizer else cfg.data.N,
    hybrid_loss_coeff=model_cfg.hybrid_loss_coeff,
    gamma=model_cfg.gamma,
    forward_kwargs=OmegaConf.to_container(model_cfg.forward_kwargs, resolve=True),
    schedule_type=model_cfg.schedule_type,
    logistic_pars=model_cfg.logistic_pars,
    fix_x_t_bias=model_cfg.fix_x_t_bias,
    n_T=model_cfg.n_T,
    t_max=model_cfg.t_max,
    seed=model_cfg.seed,
    input_logits=model_cfg.input_logits,
)
# Every entry of the `train` section is forwarded as a keyword argument too.
train_hparams = OmegaConf.to_container(cfg.train, resolve=True)
model = ScheduleConditionSparseK(
    x0_model_class,
    nn_params,
    **model_hparams,
    **train_hparams,
)


In [None]:
# Run the model's pre_configure_model step on the training dataloader.
# NOTE(review): what this step computes is defined in d3pm_sc and is not
# visible in this notebook — confirm before relying on its side effects.
model.pre_configure_model(train_dataloader)

In [None]:
import gc

# Three back-to-back rounds of "release the CUDA cache, then run the Python
# garbage collector".
for _round in range(3):
    torch.cuda.empty_cache()
    n_unreachable = gc.collect()
# Cell output: the count returned by the final gc.collect() call.
n_unreachable

In [None]:
# Evaluation grid: 1001 evenly spaced values covering [0, 1].
steps = torch.arange(1000 + 1, dtype=torch.float32) / 1000

# One row of three panels, each showing a different view of the schedule.
fig, ax = plt.subplots(1, 3, figsize=[9, 3])
hazard_ax, unmut_ax, events_ax = ax

# Left: model.beta on a logarithmic y-axis.
hazard_ax.semilogy(model.beta(steps), label="Hazard", color='black')
hazard_ax.legend()

# Middle: exp(log_alpha).
alpha_bar = model.log_alpha(steps).exp()
unmut_ax.plot(alpha_bar, label="p(unmut)", color='black')
unmut_ax.legend()

# Right: -log_alpha. Note that `alpha_bar` is rebound to the log value here.
alpha_bar = model.log_alpha(steps)
events_ax.plot(-alpha_bar, label="E[S]", color='black')
events_ax.legend()

# L = utils.get_inf_gens(forward_kwargs, N)
# ax[2].imshow(L, vmin=-0.1, vmax=0.1, cmap='bwr')

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Start a Weights & Biases run and create the Lightning logger for it.
wandb.init()
wandb_logger = WandbLogger(project="debugging")
lightning_model = model
torch.set_float32_matmul_precision('high')
# PyTorch has no `bfloat32` dtype (looking it up raises AttributeError); the
# brain-float dtype it provides is the 16-bit `torch.bfloat16`.
# NOTE(review): confirm that casting all parameters to bf16 is intended here.
model.to(torch.bfloat16)

# from pytorch_lightning.profilers import PyTorchProfiler
# profiler = PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
#     schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=2)
# )

# Use every visible CUDA device, but never ask for zero devices: on a machine
# without GPUs device_count() is 0, which the Trainer rejects.
n_devices = max(torch.cuda.device_count(), 1)
trainer = Trainer(max_epochs=1, accelerator='auto',
                  devices=n_devices, logger=wandb_logger)#, profiler=profiler)
# Train for one epoch, using the test dataloader for validation.
trainer.fit(lightning_model, train_dataloader, test_dataloader)

In [None]:
import wandb
# Mark the current Weights & Biases run as finished.
wandb.finish()

In [None]:
# Build a fresh iterator over the training dataloader and take its first batch
# for interactive inspection.
datum = next(iter(train_dataloader))

In [None]:
# Pass the batch's 'input_ids' through model.sample_point.
# NOTE(review): presumably this draws a noised version of the batch at a
# sampled time — confirm against the d3pm_sc implementation.
datum_t = model.sample_point(datum['input_ids'])

In [None]:
# Sum the attention mask over its last axis.
# Presumably this is the number of non-padding tokens per sequence — TODO confirm.
datum['attention_mask'].sum(-1)

In [None]:
from transformers import BertTokenizer
# Load the pretrained "bert-base-uncased" tokenizer.
# This rebinds `tokenizer`, replacing the value taken from train_dataloader
# when the model was set up.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



In [None]:
import sys

# Evict the cached module so that the import below executes its source again.
stale_module_key = 'd3pm_sc.ct_sched_cond_sparse_k'
sys.modules.pop(stale_module_key)  # raises KeyError if it was never imported
import d3pm_sc.ct_sched_cond_sparse_k

In [None]:
from importlib import reload
import d3pm_sc.ct_sched_cond_sparse_k
# Re-execute the module's source in place and display the reloaded module.
# NOTE(review): names bound earlier via `from ... import ...` and objects
# already instantiated from the old classes are not updated by a reload.
reload(d3pm_sc.ct_sched_cond_sparse_k)

In [None]:
# Number of sequences to generate in this cell.
batch_size = 1
# Reference batch on the GPU; only its device and its non-batch shape are used.
# It must be bound as `sample_x`, because every later line in this cell reads
# that name (binding only `x_sample` leaves `sample_x` undefined).
sample_x = datum['input_ids'].cuda()
x_sample = sample_x  # keep the previous name bound as well
model = model.cuda()

# Conditioning values 0..batch_size-1, taken modulo 10.
cond = torch.arange(0, batch_size).to(sample_x.device) % 10
# Draw every position of the initial state i.i.d. from the distribution
# returned by model.get_stationary().
p = model.get_stationary()
samples = torch.multinomial(p, num_samples=batch_size*sample_x.shape[1:].numel(), replacement=True)
init_noise = samples.reshape((batch_size,)+sample_x.shape[1:]).to(sample_x.device)

# Run the model's sampler starting from the noise drawn above.
images = model.sample_with_image_sequence(
        init_noise, cond, stride=3, n_T=100,
    )

In [None]:
# Display what model.sample_with_image_sequence returned.
images

In [None]:
# Same arguments as the sample_with_image_sequence call, but routed through
# sample_with_image_sequence2; the result overwrites `images`.
# NOTE(review): how the two samplers differ is not visible here — confirm in d3pm_sc.
images = model.sample_with_image_sequence2(
        init_noise, cond, stride=3, n_T=100,
    )

In [None]:
# Start the manual sampler from a copy of the initial noise. The sampling loop
# writes into `x` element by element, so binding `x` directly to `init_noise`
# would overwrite the noise and make re-running these cells start from an
# already-denoised state.
x = init_noise.clone()
# A snapshot of `x` is stored every `stride` loop iterations.
stride = 3
# Target number of loop iterations; the per-iteration step size is derived
# from the total transition count divided by this value.
n_T = 1000

In [None]:
from d3pm_sc.schedule_sample import sample_n_transitions_cont
from tqdm import tqdm

# Sample a transition count for every element of `x` at time t_max.
t = model.t_max * torch.ones(x.shape[0], device=x.device)
S = sample_n_transitions_cont(model.log_alpha, x[0].flatten().shape[0], t)
S = S.swapaxes(0, 1).reshape(*x.shape).long()
steps = 0
images = []
# Largest per-sample total of remaining transitions in the batch. Flattening
# all non-batch axes first keeps this correct for any input rank (2-D token
# batches as well as 4-D image batches).
n_steps = S.flatten(start_dim=1).sum(-1).max().item()
if n_steps > 1e6:
    print("n_steps:", n_steps)
# Number of transitions reverted per sample in each iteration. It must be at
# least 1: with n_steps < n_T the integer division gives 0, nothing is ever
# reverted, and the loop below would never terminate.
trans_step = max(1, n_steps // n_T)
pbar = tqdm(total=n_steps, unit="iteration",
            position=0, leave=True)
# no_grad: this is pure sampling, so no autograd graph should be recorded.
# Using the bar as a context manager closes it even if an iteration raises.
with torch.no_grad(), pbar:
    while S.sum() > 0:
        # predict what comes next
        x_next = model.p_sample(
            x, t, cond, torch.rand((*x.shape, model.num_classes), device=x.device), S
        )
        for b in range(len(x)):
            # Positions in sample b that still have transitions left, shuffled.
            trans_indices = torch.argwhere(S[b] > 0)
            trans_indices = trans_indices[torch.randperm(len(trans_indices))]
            chosen = trans_indices[:trans_step]
            if len(chosen) > 0:
                # randomly transition: one index tensor per non-batch axis, so
                # each row of `chosen` addresses exactly one element whatever
                # the rank of `x` is.
                idx = (b, *chosen.unbind(-1))
                x[idx] = x_next[idx]
                S[idx] -= 1
        pbar.update(trans_step)
        steps += 1
        if steps % stride == 0:
            images.append(torch.clone(x))
# if last step is not divisible by stride, we add the last image (as a copy,
# like every other stored snapshot).
if steps % stride != 0:
    images.append(torch.clone(x))


In [None]:
# Decode batch element 0 of the last stored snapshot back to text and show
# the first 1000 characters.
tokenizer.decode(images[-1][0])[:1000]