In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.sedd import SEDD
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet, GigaUNet
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import wandb

# SECURITY: never hardcode the W&B API key in the notebook. With no explicit
# `key`, wandb.login() picks up the credential from the WANDB_API_KEY
# environment variable or ~/.netrc, and otherwise prompts interactively.
# NOTE(review): the key that used to be committed here should be revoked and
# regenerated in the W&B account settings.
wandb.login()


In [None]:
from nets import get_model_setup
from data import get_dataloaders
from omegaconf import OmegaConf
import data

##### Configuration: base YAML plus notebook-local overrides
cfg = OmegaConf.load('configs/basic.yaml')

# Training overrides (the YAML batch size is replaced by a small one).
cfg.train.batch_size = 8

# Network size overrides.
cfg.architecture.nn_params.n_layers = 32
cfg.architecture.nn_params.time_embed_dim = 512

# Data / diffusion-model overrides.
cfg.data.N = 256
cfg.model.model = 'ScheduleCondition'
cfg.model.forward_kwargs.type = 'uniform'
# Optional extra overrides, currently disabled:
# cfg.model.forward_kwargs.bandwidth = 0.05
# cfg.model.schedule_type = 'linear'

##### Load data
# Seed torch before the dataloaders are built.
torch.manual_seed(cfg.model.seed)
train_dataloader, test_dataloader = data.get_dataloaders(cfg)
# Some dataloaders expose a tokenizer attribute; fall back to None otherwise.
tokenizer = getattr(train_dataloader, "tokenizer", None)

##### Setup x0_model
# Building a model from scratch is disabled; the model is restored from a
# checkpoint below instead. The from-scratch recipe is kept for reference:
# x0_model_class, nn_params = get_model_setup(cfg, tokenizer)
# model = SEDD(
#     x0_model_class,
#     nn_params,
#     num_classes=len(tokenizer) if tokenizer else cfg.data.N,
#     hybrid_loss_coeff=cfg.model.hybrid_loss_coeff,
#     gamma=cfg.model.gamma,
#     forward_kwargs=OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True),
#     schedule_type=cfg.model.schedule_type,
#     logistic_pars=cfg.model.logistic_pars,
#     fix_x_t_bias=cfg.model.fix_x_t_bias,
#     n_T=cfg.model.n_T,
#     t_max=cfg.model.t_max,
#     seed=cfg.model.seed,
#     sedd_param=cfg.model.sedd_param,
#     eff_num_classes=cfg.model.eff_num_classes,
#     input_logits=cfg.model.input_logits,
#     tokenizer=tokenizer if cfg.data.data != 'uniref50' else Tokenizer(),
#     **OmegaConf.to_container(cfg.train, resolve=True),
# )

##### Restore the model from a checkpoint
# `use_ema` toggles the EMA callback in the trainer cells further down.
use_ema = True
# Other checkpoint previously used here:
#   '/home/nvg7279/d3pm/checkpoints/woven-dawn-229/epoch=775-step=303416.ckpt'
ckpt_path = '/scratch/nvg7279/scud_scum/epoch=1386-step=542317.ckpt'
model = ScheduleCondition.load_from_checkpoint(ckpt_path)

In [None]:
# Override the model's `gen_trans_step` attribute with 2**10 (= 1024).
# NOTE(review): presumably the number of transition steps used when the model
# generates samples — confirm against ScheduleCondition.
model.gen_trans_step = 2**10

In [None]:
# Run the model's pre-configuration hook on the training dataloader.
# NOTE(review): presumably this derives data-dependent quantities (such as the
# `model.p0` plotted in the following cell) — confirm in the model class.
model.pre_configure_model(train_dataloader)

In [None]:
# Quick visual check of the model's `p0` attribute.
# NOTE(review): presumably a distribution over the cfg.data.N token/pixel
# classes — confirm; no axis labels are set here.
plt.plot(model.p0)

In [None]:
# Display the forward-process kwargs from the config as a plain Python
# container, with interpolations resolved (resolve=True).
OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True)

In [None]:
# Evaluate the model's schedule functions on 1000 evenly spaced points in [0, 1).
steps = torch.arange(1000, dtype=torch.float32) / 1000
betas = model.beta(steps)
mla = model.log_alpha(steps)
alpha_bar = torch.exp(mla)

# Matrix returned by utils.get_inf_gen for the configured forward process and
# cfg.data.N (shown in the last panel).
L = utils.get_inf_gen(
    OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True),
    cfg.data.N,
)

fig, ax = plt.subplots(1, 4, figsize=[12, 3])

# (axes, plotting method, series, legend label) for the three line panels;
# the first two use a log-scaled y axis.
line_panels = (
    (ax[0], "semilogy", betas, "Hazard"),
    (ax[1], "semilogy", -mla, "E[S]"),
    (ax[2], "plot", alpha_bar, "p(unmut)"),
)
for panel, draw_method, series, legend_label in line_panels:
    getattr(panel, draw_method)(series, label=legend_label, color='black')
    panel.legend()

# Last panel: heat-map of L on a diverging colour map clipped to [-0.1, 0.1].
ax[3].imshow(L, vmin=-0.1, vmax=0.1, cmap='bwr')

In [None]:
# Override the model's `eps` attribute with 1e-7.
# NOTE(review): presumably a numerical-stability epsilon used inside the
# model's loss/sampling code — confirm against ScheduleCondition.
model.eps = 1e-7

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from ema import EMA

torch.set_float32_matmul_precision('high')

# Log this run to the "debugging" W&B project.
# wandb.init()
wandb_logger = WandbLogger(project="debugging")
lightning_model = model

# Other checkpoint previously used for resuming:
# ckpt_path = '/home/nvg7279/d3pm/checkpoints/woven-dawn-229/epoch=775-step=303416.ckpt'

# Optional profiling, currently disabled:
# from pytorch_lightning.profilers import PyTorchProfiler
# profiler = PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
#     schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=2)
# )

# Multiplying the one-element list by the boolean `use_ema` yields either
# [EMA(1)] or an empty callback list.
ema_callbacks = [EMA(1)] * use_ema

trainer_options = dict(
    max_epochs=10000,
    accelerator='auto',
    devices=1,  # alternatively torch.cuda.device_count()
    val_check_interval=1,
    limit_val_batches=1,
    logger=wandb_logger,
    callbacks=ema_callbacks,
    # profiler=profiler,
)
trainer = Trainer(**trainer_options)

# lightning_model.eval()  # Set model to evaluation mode
# with torch.no_grad():  # Disable gradient computation
# Fit, resuming trainer state from the checkpoint at `ckpt_path`.
trainer.fit(
    lightning_model,
    train_dataloader,
    test_dataloader,
    ckpt_path=ckpt_path,
)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from ema import EMA

# Fresh Trainer with the same settings as the training cell; the EMA callback
# is attached only when `use_ema` is truthy.
trainer = Trainer(max_epochs=10000, accelerator='auto',
                  devices=1, # torch.cuda.device_count(),
                  val_check_interval=1,
                  limit_val_batches=1,
                  logger=wandb_logger,
                  callbacks=[EMA(1)] if use_ema else [])#, profiler=profiler)

# `Trainer.lightning_module` is a getter-only property that is None until the
# trainer has been attached to a module, so it can be neither assigned to nor
# used to call `load_from_checkpoint` on a brand-new Trainer. Reload the
# weights through the model class instead, exactly as the setup cell does.
# NOTE(review): the reloaded module does not carry the attribute overrides
# applied to `model` in earlier cells (gen_trans_step, eps,
# pre_configure_model); `model` itself is left untouched here.
lightning_model = ScheduleCondition.load_from_checkpoint(ckpt_path)


In [None]:
# Number of sequences/images to generate and number of generation steps.
batch_size = 32
gen_trans_step = 1000

import torch, gc
# Collect Python garbage first so freed tensors go back to the CUDA caching
# allocator, then release the cached blocks.
gc.collect(); torch.cuda.empty_cache()

# Take one training batch to obtain the per-example shape plus any
# conditioning / attention mask.
batch = next(iter(train_dataloader))
# PyTorch's default collate function turns tuple samples into a *list*, so
# accept both sequence types here.
if isinstance(batch, (tuple, list)): #image datasets
    sample_x, cond = batch
    sample_x = sample_x.cuda()
    attn_mask = None
    if cond is not None:
        if cond.dim() == sample_x.dim(): #protein datasets
            # Second element has the same rank as the data: treat it as a mask.
            attn_mask = cond.cuda()
            cond = None
        else:
            cond = cond.cuda()
elif isinstance(batch, dict): #text datasets
    sample_x = batch['input_ids'].cuda()
    attn_mask = batch['attention_mask'].cuda()
    cond = batch['cond'].cuda() if 'cond' in batch else None
else:
    raise TypeError(f"Unsupported batch type: {type(batch).__name__}")

# Draw the initial state i.i.d. from the model's stationary distribution.
p = model.get_stationary().cuda()
samples = torch.multinomial(p, num_samples=batch_size*sample_x.shape[1:].numel(), replacement=True)
init_noise = samples.reshape((batch_size,)+sample_x.shape[1:]).to(sample_x.device)

model.eval().cuda()
# NOTE(review): `cond` / `attn_mask` come from a dataloader batch whose size is
# cfg.train.batch_size, while `init_noise` has `batch_size` rows — confirm that
# sample_sequence tolerates the mismatch.
# `cond` is already on the GPU (or None), so it is passed through unchanged;
# calling .cuda() on it would crash for unconditional datasets.
images = model.sample_sequence(
    init_noise, cond, attn_mask, stride=3, n_T=gen_trans_step,
)

In [None]:

# Close the active Weights & Biases run.
import wandb
wandb.finish()