In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt

from d3pm_sc.ct_sched_cond import ScheduleCondition
from d3pm_sc.masking_diffusion import MaskingDiffusion
from d3pm_sc.d3pm_classic import D3PMClassic
from d3pm_sc.unet import KingmaUNet, UNet, SimpleUNet
from d3pm_sc.dit import DiT_Llama
from d3pm_sc import utils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

import wandb
# SECURITY: never commit an API key inside the notebook. Without an explicit
# key, wandb.login() resolves credentials from the WANDB_API_KEY environment
# variable or a previously stored login, and otherwise prompts interactively.
# The key that used to be hardcoded here should be revoked and rotated.
wandb.login()


In [None]:
# --- Experiment configuration: everything a reader might tune lives here. ---
N = 256  # number of classes for discretized state per pixel
n_channel = 3
gamma = 0
hybrid_loss_coeff = 0.01
logistic_pars = False
fix_x_t_bias = False
lr = 2e-4
grad_clip_val = 1

s_dim = 4
# When False, collate_fn zeroes out the labels and the network is built with a
# single class (see nn_params["num_classes"]).
conditional = False
# Forward-process settings; passed to the model and to utils.get_inf_gens.
forward_kwargs = {"type":"gaussian",
                  "normalized": True,
                  "bandwidth":1 / 7}

batch_size = 16
# Epoch budget scales with the number of CUDA devices. Clamp the multiplier to
# at least 1 so a machine without CUDA still trains for 14 epochs instead of 0.
n_epoch = 14 * max(1, torch.cuda.device_count())

# Keyword arguments for the x0-prediction network; handed to the diffusion
# model together with the network class chosen below.
nn_params = {"n_channel": n_channel,
             "N": N,
             "n_T": 500,
             "schedule_conditioning": True,
             # Use the config value rather than a duplicated literal so that
             # changing `s_dim` in the config actually reaches the network.
             "s_dim": s_dim,
             "num_classes":10 if conditional else 1,
             "inc_attn": False,
             "time_embed_dim": 128,
             # "n_transformers": 12,
             # "n_heads": 12,
             # "ch": 256,
            }
x0_model_class = KingmaUNet

# x0_model_class = DiT_Llama
# nn_params['dim'] = 1024

##### Pick model
# Schedule conditioning: build the diffusion model around the chosen network
# class, forwarding every hyperparameter from the config section by keyword.
model = ScheduleCondition(
    x0_model_class,
    nn_params,
    num_classes=N,
    hybrid_loss_coeff=hybrid_loss_coeff,
    gamma=gamma,
    forward_kwargs=forward_kwargs,
    logistic_pars=logistic_pars,
    fix_x_t_bias=fix_x_t_bias,
    lr=lr,
    grad_clip_val=grad_clip_val,
)

# # # Masking
# nn_params["N"] += 1
# nn_params["schedule_conditioning"] = False
# model = MaskingDiffusion(x0_model_class, nn_params, num_classes=N, hybrid_loss_coeff=0.01).cuda()


##### Load data
# CIFAR-10 training split, downloaded into ./data when not already present.
# Each image gets a random horizontal flip and is then converted to a tensor.
# NOTE(review): the validation subset is later carved out of this same dataset
# object, so the random flip is applied to validation images too — confirm
# that this is intended.
dataset = CIFAR10(
    "./data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]),
)
def collate_fn(batch):
    """Collate (image, label) pairs into a discretized image batch and labels.

    Images are stacked, scaled by ``N - 1``, rounded, cast to ``long`` and
    clamped to ``[0, N - 1]``. Labels are multiplied by the global
    ``conditional`` flag, so they all become 0 when it is False.

    Assumes each image is a float tensor with values in [0, 1] — TODO confirm
    against the dataset transform.

    Returns:
        tuple: ``(x, cond)`` — the integer image batch and the label tensor.
    """
    images, labels = zip(*batch)
    pixels = torch.stack(images)
    # Map continuous intensities onto the N discrete states.
    pixels = (pixels * (N - 1)).round().long().clamp(0, N - 1)
    # Multiplying by the flag zeroes the labels for unconditional training.
    labels = torch.tensor(labels) * conditional
    return pixels, labels
# Hold out 10% of the images for validation. The split uses a fixed-seed
# generator so the same images land in each subset on every run; without it
# the train/validation partition changes on each kernel restart and validation
# metrics are not comparable between runs.
train_size = int(len(dataset) * 0.9)
dataset, test_dataset = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(0),
)
# Only the training loader shuffles; both loaders discretize via collate_fn.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=15, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=15, collate_fn=collate_fn)


In [None]:
# Uniform grid of 1001 time points covering [0, 1] in double precision.
steps = torch.arange(1000 + 1, dtype=torch.float64) / 1000

fig, ax = plt.subplots(1, 3, figsize=[9, 3])
hazard_ax, unmut_ax, gen_ax = ax

# Left panel: model.beta over the grid on a log y-axis (labelled "Hazard").
hazard_ax.semilogy(model.beta(steps), label="Hazard", color='black')
hazard_ax.legend()

# Middle panel: exp(log_alpha) over the grid (labelled "p(unmut)").
alpha_bar = torch.exp(model.log_alpha(steps))
unmut_ax.plot(alpha_bar, label="p(unmut)", color='black')
unmut_ax.legend()

# Right panel: matrix returned by utils.get_inf_gens for the configured
# forward process, drawn with a diverging colormap clipped to [-0.1, 0.1].
L = utils.get_inf_gens(forward_kwargs, N)
gen_ax.imshow(L, vmin=-0.1, vmax=0.1, cmap='bwr')

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Start the W&B run in the same project the Lightning logger is configured
# for. A bare wandb.init() opens a run with no project, and a WandbLogger
# created while a run is already active reuses that run, so the logger's
# `project` would otherwise be ignored.
wandb.init(project="debugging")
wandb_logger = WandbLogger(project="debugging")
lightning_model = model
torch.set_float32_matmul_precision('high')

# from pytorch_lightning.profilers import PyTorchProfiler
# profiler = PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
#     schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=2)
# )
# Use every CUDA device when there are any. Fall back to a single device
# otherwise: devices=0 is rejected by Lightning, so the original expression
# failed on machines without CUDA.
trainer = Trainer(max_epochs=n_epoch, accelerator='auto',
                  devices=max(1, torch.cuda.device_count()), logger=wandb_logger)#, profiler=profiler)
trainer.fit(lightning_model, dataloader, test_dataloader)

In [None]:

import wandb
# Close the active W&B run so a later wandb.init() starts a fresh one.
wandb.finish()

In [None]:
# Collect the shape of every parameter tensor in the first module of the
# network's first down block, one numpy array per parameter.
first_down_module = model.x0_model.down_blocks[0][0]
a = [np.array(param.shape) for param in first_down_module.parameters()]

In [None]:
# Total number of scalar parameters in that module: the product of each shape,
# summed over all parameter tensors. The raw shapes are printed afterwards.
total_params = sum(np.prod(shape) for shape in a)
print(total_params)
print(a)

In [None]:
# Back-of-the-envelope arithmetic on hard-coded counts.
# NOTE(review): these look like parameter counts (a total minus a fraction of
# two sub-totals, with 48/64 possibly a width ratio) — confirm where the
# numbers come from and replace them with named values.
31470288 - (17884416 + 11563008) * (1-48/64)