In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.ct_sched_cond_sparse_k import ScheduleConditionSparseK
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet
# from d3pm_sc.dit import DiT_Llama
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import os

import wandb

# Never hardcode the W&B API key in the notebook: read it from the environment.
# With key=None, wandb.login falls back to its own lookup (netrc / interactive prompt).
# NOTE(review): a key was previously committed in this cell — revoke and rotate it.
wandb.login(key=os.environ.get("WANDB_API_KEY"))


In [None]:
from nets import get_model_setup
from data import get_dataloaders
from omegaconf import OmegaConf
import data

# All hyperparameters below come from this YAML config.
CONFIG_PATH = 'configs/basic_language.yaml'
cfg = OmegaConf.load(CONFIG_PATH)

##### Load data
train_dataloader, test_dataloader = get_dataloaders(cfg)
# Use the loader's tokenizer when it exposes one; otherwise there is none.
tokenizer = getattr(train_dataloader, "tokenizer", None)

##### Setup x0_model
x0_model_class, nn_params = get_model_setup(cfg, tokenizer)

# Vocabulary size comes from the tokenizer when present, else from the config.
# Explicit `is not None` avoids relying on the tokenizer's truthiness (__len__/__bool__).
num_classes = len(tokenizer) if tokenizer is not None else cfg.data.N

##### Pick model
model = ScheduleConditionSparseK(
    x0_model_class,
    nn_params,
    num_classes=num_classes,
    hybrid_loss_coeff=cfg.model.hybrid_loss_coeff,
    gamma=cfg.model.gamma,
    forward_kwargs=OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True),
    schedule_type=cfg.model.schedule_type,
    logistic_pars=cfg.model.logistic_pars,
    fix_x_t_bias=cfg.model.fix_x_t_bias,
    n_T=cfg.model.n_T,
    t_max=cfg.model.t_max,
    seed=cfg.model.seed,
    input_logits=cfg.model.input_logits,
    # Remaining training options are forwarded verbatim from the `train` config section.
    **OmegaConf.to_container(cfg.train, resolve=True),
)


In [None]:
# Run the model's pre-configuration hook on the training loader before fitting.
# NOTE(review): presumably this inspects the training data to set up model state — confirm in ScheduleConditionSparseK.
model.pre_configure_model(train_dataloader)

In [None]:
import gc

# Collect unreachable Python objects first so any CUDA tensors they hold are released,
# then return the now-unused cached blocks from PyTorch's caching allocator.
# (Calling empty_cache before gc.collect leaves memory freed by the GC still cached.)
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Evaluate the forward-process schedule on a uniform grid of times t in [0, 1].
N_PLOT_STEPS = 1000
steps = torch.arange(N_PLOT_STEPS + 1, dtype=torch.float32) / N_PLOT_STEPS

# Inspection only: no_grad guarantees matplotlib never receives tensors that require grad
# (which it cannot convert to numpy).
with torch.no_grad():
    hazard = model.beta(steps)
    log_alpha = model.log_alpha(steps)

fig, ax = plt.subplots(1, 3, figsize=[9, 3])

# Plot against t (not the array index) so the x-axis is meaningful.
ax[0].semilogy(steps, hazard, label="Hazard", color='black')
ax[1].plot(steps, torch.exp(log_alpha), label="p(unmut)", color='black')
ax[2].plot(steps, -log_alpha, label="E[S]", color='black')

for axis in ax:
    axis.set_xlabel("t")
    axis.legend()
fig.tight_layout()

In [None]:
# K_sq = (model.K_T @ model.K_T).cuda()
model.K_T = model.K_T.cuda()

g = torch.randn([16000, model.num_classes]).float().cuda()
S = 2 * torch.ones([16000]).long().cuda()

# %timeit K_sq @ g
%timeit model.K_T_power_mult(S, g)

In [None]:
torch.set_float32_matmul_precision('high')

# Let WandbLogger create the run itself. A bare wandb.init() here would first start a run
# in the default project, and WandbLogger would reuse that run, ignoring project="debugging".
wandb_logger = WandbLogger(project="debugging")

# devices='auto' uses the available GPUs and still works on CPU-only machines,
# where torch.cuda.device_count() == 0 is not a valid device count for Trainer.
trainer = Trainer(max_epochs=1, accelerator='auto', devices='auto', logger=wandb_logger)
trainer.fit(model, train_dataloader, test_dataloader)

In [None]:

import wandb  # duplicates the top-of-notebook import; harmless, lets this cell run on its own
# Mark the active W&B run (used by the training cell above) as finished.
wandb.finish()

In [None]:
# Record, as numpy arrays, the shape of every parameter of down_blocks[0][0] in the x0 model.
a = []
for param in model.x0_model.down_blocks[0][0].parameters():
    a.append(np.array(param.shape))

In [None]:
# Total number of scalar entries across the collected parameter shapes.
# Each product is cast to int: np.prod of a 0-d parameter's empty shape array is the
# float 1.0, which would otherwise turn the whole sum into a float.
print(sum(int(np.prod(shape)) for shape in a))
print(a)

In [None]:
# NOTE(review): the constants appear to be parameter counts copied from earlier output — confirm.
# Scales the bracketed sum by 1 - 48/64 (= 0.25) and subtracts it from the first constant.
31470288 - (1 - 48 / 64) * (17884416 + 11563008)