In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.ct_sched_cond_sparse_k import ScheduleConditionSparseK
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet
# from d3pm_sc.dit import DiT_Llama
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import wandb
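# Authenticate with `wandb login` or the WANDB_API_KEY environment variable; never commit API keys. NOTE(review): an API key was previously committed on this line — revoke/rotate it.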


In [None]:
from nets import get_model_setup
from data import get_dataloaders
from omegaconf import OmegaConf
from evodiff.utils import Tokenizer
import data

# Experiment configuration; the batch size is overridden for this notebook run.
cfg = OmegaConf.load('configs/basic_protein.yaml')
cfg.train.batch_size = 128

# Number of discrete classes, passed to the model as `num_classes`.
num_classes = cfg.data.N

##### Load data
train_dataloader, test_dataloader = data.get_dataloaders(cfg)
# Not every dataloader exposes a tokenizer; fall back to None when absent.
tokenizer = getattr(train_dataloader, "tokenizer", None)

##### Setup x0_model
x0_model_class, nn_params = get_model_setup(cfg, tokenizer)

##### Pick model
forward_kwargs = OmegaConf.to_container(cfg.model.forward_kwargs, resolve=True)
# Every entry of cfg.train is forwarded as a keyword argument; a key that
# duplicates one of the explicit arguments below raises TypeError.
train_kwargs = OmegaConf.to_container(cfg.train, resolve=True)
model = ScheduleCondition(
    x0_model_class,
    nn_params,
    num_classes=num_classes,
    hybrid_loss_coeff=cfg.model.hybrid_loss_coeff,
    gamma=cfg.model.gamma,
    forward_kwargs=forward_kwargs,
    schedule_type=cfg.model.schedule_type,
    logistic_pars=cfg.model.logistic_pars,
    fix_x_t_bias=cfg.model.fix_x_t_bias,
    n_T=cfg.model.n_T,
    t_max=cfg.model.t_max,
    seed=cfg.model.seed,
    input_logits=cfg.model.input_logits,
    sedd_param=cfg.model.sedd_param,
    eff_num_classes=cfg.model.eff_num_classes,
    # Integer literal (was the float 2e5): presumably a sample count — confirm.
    n_stat_samples=200_000,
    **train_kwargs,
)


In [None]:
# Let the model inspect the training data before fitting.
# NOTE(review): presumably estimates dataset statistics (cf. n_stat_samples passed to the constructor) — confirm.
model.pre_configure_model(train_dataloader)

In [None]:
# Plot the schedule quantities exposed by the model on a uniform grid of t in [0, 1].
N_PLOT_STEPS = 1000
steps = torch.arange(N_PLOT_STEPS + 1, dtype=torch.float32) / N_PLOT_STEPS

# Evaluate without autograd so the tensors can be handed to matplotlib even
# if the schedule depends on learnable parameters.
with torch.no_grad():
    hazard = model.beta(steps)
    log_alpha = model.log_alpha(steps)
p_unmutated = torch.exp(log_alpha)  # probability, plotted as "p(unmut)"
neg_log_alpha = -log_alpha          # plotted as "E[S]"

fig, ax = plt.subplots(1, 4, figsize=[12, 3])

ax[0].semilogy(steps, hazard, color='black')
ax[0].set(title=r"Hazard $\beta(t)$", xlabel="t")

ax[1].plot(steps, p_unmutated, color='black')
ax[1].set(title=r"p(unmut) $= \exp\log\alpha(t)$", xlabel="t")

ax[2].plot(steps, neg_log_alpha, color='black')
ax[2].set(title=r"E[S] $= -\log\alpha(t)$", xlabel="t")

# NOTE(review): get_inf_gen presumably returns the num_classes x num_classes
# generator (rate) matrix — confirm.
L = utils.get_inf_gen(forward_kwargs, num_classes)
im = ax[3].imshow(L, vmin=-0.1, vmax=0.1, cmap='bwr')
ax[3].set_title("get_inf_gen")
fig.colorbar(im, ax=ax[3])

fig.tight_layout()
plt.show()

In [None]:

import gc

from pytorch_lightning.callbacks import ModelCheckpoint

# Release memory from earlier runs before fitting: gc.collect() frees
# unreachable Python objects first, so empty_cache() can then return their
# cached CUDA blocks (a no-op if CUDA was never initialised).
gc.collect()
torch.cuda.empty_cache()

# Checkpoints are written after validation (not per step, not at train-epoch
# end); with no `monitor`, save_top_k=1 keeps only the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='latest-model',
    save_last=True,
    save_top_k=1,
    every_n_train_steps=None,  # Disable saving based on training steps
    save_on_train_epoch_end=False,  # Disable saving at the end of training epoch
)

# WandbLogger reuses any already-active wandb run and then ignores its own
# `project`. Finish a leftover run (e.g. from re-executing this cell) and let
# the logger create the run itself so it lands in the "debugging" project.
wandb.finish()
wandb_logger = WandbLogger(project="debugging", log_model='all')
lightning_model = model
torch.set_float32_matmul_precision('high')
model.to(torch.float32)

# Validate on a fixed budget of ~VAL_SAMPLES examples, and run validation
# every 10 * limit_val_batches training batches.
VAL_SAMPLES = 210_000
limit_val_batches = VAL_SAMPLES // cfg.train.batch_size
val_check_interval = 10 * limit_val_batches

trainer = Trainer(
    max_epochs=1,
    accelerator='auto',
    # device_count() is 0 on CPU-only machines, which Lightning rejects.
    devices=torch.cuda.device_count() or "auto",
    logger=wandb_logger,
    limit_val_batches=limit_val_batches,
    callbacks=[checkpoint_callback],
    val_check_interval=val_check_interval,
)
trainer.fit(lightning_model, train_dataloader, test_dataloader)

In [None]:
# Diagnostic: names of parameters that receive no gradient from one training step.
# NOTE(review): check_unused_parameters is defined in a later cell, so this cell raises
# NameError on a fresh kernel (Restart & Run All) — move the definition cell above this one.
# Requires a CUDA device: the model is moved with model.cuda() before the check.
p = check_unused_parameters(model.cuda(), train_dataloader)

In [None]:
# Display the parameter names returned by check_unused_parameters (their .grad stayed None).
p

In [None]:
def check_unused_parameters(model, dataloader):
    """Return names of trainable parameters that get no gradient from one training step.

    Runs ``model.training_step`` on the first batch of ``dataloader``, backpropagates
    the loss, and reports every parameter with ``requires_grad=True`` whose ``.grad``
    is still ``None`` afterwards. Frozen parameters are never reported.

    Args:
        model: module exposing ``training_step(batch, batch_idx)`` that returns a
            scalar loss tensor or a dict with a ``"loss"`` entry.
        dataloader: iterable whose first item is indexable with at least two
            tensors; the first two are moved to the device of the model's
            parameters and passed to ``training_step`` as a tuple.

    Returns:
        List of parameter names, in ``model.named_parameters()`` order.

    Side effects:
        Clears existing gradients and leaves this step's gradients on the model.
    """
    first_param = next(model.parameters(), None)
    if first_param is None:
        # A model without parameters cannot have unused ones.
        return []
    device = first_param.device

    # Reset gradients so only this step's backward pass is observed.
    model.zero_grad(set_to_none=True)

    batch = next(iter(dataloader))
    inputs = (batch[0].to(device), batch[1].to(device))
    output = model.training_step(inputs, 0)
    # Lightning allows training_step to return either the loss or {"loss": ...}.
    loss = output["loss"] if isinstance(output, dict) else output
    loss.backward()

    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]