In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.sedd import SEDD
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet, GigaUNet
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import wandb
wandb.login(key="6a47f093d2a55e4f4e85b33767423f2db66355b8")


In [None]:
from nets import get_model_setup
from data import get_dataloaders
from omegaconf import OmegaConf
import data

# Start from the basic config and override a few fields for this run.
cfg = OmegaConf.load('configs/basic.yaml')
# Halve the configured batch size (presumably to fit this interactive run in memory — TODO confirm).
cfg.train.batch_size = cfg.train.batch_size // 2
cfg.model.model = 'SEDD'

##### Load data
torch.manual_seed(cfg.model.seed)
train_dataloader, test_dataloader = data.get_dataloaders(cfg)
# Some dataloaders carry a `tokenizer` attribute; otherwise use None, in which case
# the class count below comes from cfg.data.N.
tokenizer = getattr(train_dataloader, "tokenizer", None)

##### Setup x0_model
x0_model_class, nn_params = get_model_setup(cfg, tokenizer)

#### Pick model
if cfg.data.data == 'uniref50':
    # FIXME: `Tokenizer` is not imported anywhere in this notebook, so this branch raises
    # NameError unless the class is defined elsewhere — add the import before using uniref50.
    model_tokenizer = Tokenizer()
else:
    model_tokenizer = tokenizer

model = SEDD(
    x0_model_class,
    nn_params,
    num_classes=len(tokenizer) if tokenizer else cfg.data.N,
    hybrid_loss_coeff=cfg.model.hybrid_loss_coeff,
    gamma=cfg.model.gamma,
    forward_kwargs=OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True),
    schedule_type=cfg.model.schedule_type,
    logistic_pars=cfg.model.logistic_pars,
    fix_x_t_bias=cfg.model.fix_x_t_bias,
    n_T=cfg.model.n_T,
    t_max=cfg.model.t_max,
    seed=cfg.model.seed,
    sedd_param=cfg.model.sedd_param,
    eff_num_classes=cfg.model.eff_num_classes,
    input_logits=cfg.model.input_logits,
    tokenizer=model_tokenizer,
    **OmegaConf.to_container(cfg.train, resolve=True),
)
# Alternative: restore a previously trained model instead of building a fresh one, e.g.
# model = ScheduleCondition.load_from_checkpoint('checkpoints/prime-sweep-1/epoch=33-step=11968.ckpt')

In [None]:
# Hand the training dataloader to the model before fitting. What this configures is defined
# in d3pm_sc (not visible here) — presumably data-dependent setup; TODO confirm.
model.pre_configure_model(train_dataloader)

In [None]:
def _to_numpy(x):
    """Return `x` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Sanity-check the forward noise schedule on a uniform grid of times in [0, 1).
N_GRID = 1000  # number of time points at which the schedule is evaluated
steps = torch.arange(N_GRID, dtype=torch.float32) / N_GRID
betas = model.beta(steps)
mla = model.log_alpha(steps)
alpha_bar = torch.exp(mla)
# NOTE(review): uses cfg.data.N even when the model's num_classes came from a tokenizer.
L = utils.get_inf_gen(OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True), cfg.data.N)

t = _to_numpy(steps)
fig, ax = plt.subplots(1, 4, figsize=[12, 3])

ax[0].semilogy(t, _to_numpy(betas), label="Hazard", color='black')
ax[0].set_xlabel("t")
ax[0].legend()

ax[1].semilogy(t, _to_numpy(-mla), label="E[S]", color='black')
ax[1].set_xlabel("t")
ax[1].legend()

ax[2].plot(t, _to_numpy(alpha_bar), label="p(unmut)", color='black')
ax[2].set_xlabel("t")
ax[2].legend()

im = ax[3].imshow(_to_numpy(L), vmin=-0.1, vmax=0.1, cmap='bwr')
ax[3].set_title("get_inf_gen(...)")
fig.colorbar(im, ax=ax[3])

fig.tight_layout()
plt.show()

In [None]:
# Train for a single epoch, logging to the W&B "debugging" project.
# Trainer and WandbLogger are imported at the top of the notebook.
wandb_logger = WandbLogger(project="debugging")
lightning_model = model
# Allow faster, slightly lower-precision float32 matmuls where the hardware supports it.
torch.set_float32_matmul_precision('high')

# torch.cuda.device_count() is 0 on a machine without CUDA, and Lightning rejects
# devices=0; fall back to "auto" so the run still works without a GPU.
num_gpus = torch.cuda.device_count()
trainer = Trainer(
    max_epochs=1,
    accelerator='auto',
    devices=num_gpus if num_gpus > 0 else 'auto',
    logger=wandb_logger,
)
trainer.fit(lightning_model, train_dataloader, test_dataloader)

In [None]:

# Close the active W&B run so its logs are flushed; `wandb` is imported at the top of the notebook.
wandb.finish()