In [None]:
# Shell escape: run `nvidia-smi` so the GPU/driver state of this machine is
# captured in the notebook output before any training starts.
!nvidia-smi

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torchvision.transforms as T

from simple_ae import SimpleAE


In [None]:
# Target device for the model and the batches. Prefer the first CUDA GPU, but
# fall back to the CPU so the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Only influences interpreters started after this point (e.g. worker
    # subprocesses); the hash seed of the running kernel is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one. Safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats `deterministic = True`; it has to be off for reproducibility.
    torch.backends.cudnn.benchmark = False

seed = 42
seed_everything(seed)

In [None]:
# Directory handed to the MNIST dataset as its root. Missing parent
# directories are created, and an already existing directory is not an error.
data_dir = Path('/app/data') / 'mnist'
data_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Per-sample transform applied by the dataset to every image it returns.
to_tensor = T.ToTensor()

# Training split of MNIST rooted at `data_dir`, with downloading enabled.
dataset = MNIST(
    root=data_dir,
    train=True,
    transform=to_tensor,
    download=True,
)

In [None]:
# Shuffle the training set every epoch; without it each epoch replays the
# samples in exactly the same order. The permutation is drawn from the
# PyTorch RNG.
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

In [None]:
class DDPM(nn.Module):
    """Noise-prediction training wrapper around a ``SimpleAE`` network.

    For every sample a random step ``t`` is drawn, a per-sample scale
    ``b_t = betas[t]`` is looked up, noise is generated from that scale, and
    the wrapped network is asked to predict the noise from ``x0 + noise``
    conditioned on ``b_t``.

    NOTE(review): the corruption implemented here is additive,
    ``x0 + n`` with ``n ~ N(sqrt(1 - b_t), b_t ** 2)``. That is not the
    closed-form forward process of the DDPM paper
    (``x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps``). The distribution
    is preserved exactly as originally written; confirm it is intended.

    Args:
        T: Number of discrete steps in the schedule table.
    """

    def __init__(self, T = 1000):
        super().__init__()

        self.diffuser = SimpleAE(1, 32)

        self.T = T
        # A buffer follows `model.to(...)` / `.cuda()` instead of being pinned
        # to a notebook-level global device. It is non-persistent so the
        # state_dict keys stay the same as before this became a buffer.
        self.register_buffer(
            'betas',
            self.var_scale(torch.linspace(0, 0.999, self.T)),
            persistent=False,
        )

    def noise(self, x_t, b_t):
        """Sample per-sample Gaussian noise parameterised by ``b_t``.

        Sample ``k`` of the result is drawn from
        ``N(mean=sqrt(1 - b_t[k]), std=b_t[k])`` independently per element.

        Args:
            x_t: Input batch; only its non-batch dimensions and device are used.
            b_t: 1-D tensor with one scale per sample.

        Returns:
            Tuple ``(x_t, noise)`` where ``x_t`` is the unchanged input and
            ``noise`` has shape ``(len(b_t), *x_t.shape[1:])`` in the default
            floating dtype, on ``x_t``'s device.
        """
        # Draw everything in one call on the input's device instead of one
        # CPU sample per item followed by a host-to-device copy.
        eps = torch.randn((b_t.shape[0], *x_t.shape[1:]), device=x_t.device)
        # One scale per sample, broadcast over all remaining dimensions.
        std = b_t.to(eps).reshape(-1, *([1] * (eps.dim() - 1)))
        mean = torch.sqrt(1. - std)
        noise = eps * std + mean
        return x_t, noise

    def var_scale(self, t):
        """Evaluate ``cos(((t + s) / (1 + s)) * pi / 2) ** 2`` with ``s = 0.008``.

        Args:
            t: Tensor of schedule positions (the constructor passes values
                from ``linspace(0, 0.999, T)``).

        Returns:
            Tensor of the same shape as ``t``.
        """
        s = 0.008
        return torch.cos((t+s)/(1+s)*(torch.pi/2))**2

    def forward(self, x0):
        """Run one noise-prediction training step on a clean batch.

        Args:
            x0: Batch of clean inputs; the first dimension is the batch size.

        Returns:
            Whatever ``self.diffuser.loss_function(pred_noise, noise)``
            returns (the training loop reads its ``'loss'`` entry).
        """
        # Sample the step indices directly where the lookup table lives, then
        # move the looked-up scales to the input's device.
        t = torch.randint(self.T, (x0.size(0),), device=self.betas.device)
        b_t = self.betas[t].to(x0.device)

        x0, noise = self.noise(x0, b_t)
        pred_noise = self.diffuser(x0+noise, b_t.unsqueeze(1).float())
        losses_dict = self.diffuser.loss_function(pred_noise, noise)
        return losses_dict

In [None]:
# Build the model and place it on the target device before its parameters are
# handed to the optimizer.
model = DDPM()
model = model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=1e-5,
    lr=2e-5,
)

In [None]:
import wandb

# Start a Weights & Biases run; the seed is embedded in the run name.
wandb.init(
    project='mnist_diffusion',
    name=f'simple_ae_diffusion={seed}',
)

In [None]:
num_epoch = 10000
loss_vis_freq = 200

# `i` counts optimisation steps across all epochs; it is logged next to the
# loss and drives the periodic console print.
i = 0
for epoch in range(num_epoch):
    for batch in train_dataloader:
        # Keep only element 0 of the dataloader item (the images).
        batch = batch[0].to(device)

        losses_dict = model(batch)
        loss = losses_dict['loss']

        # Clear old gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only log when a W&B run is active.
        wandb_active = wandb.run is not None
        if wandb_active:
            wandb.log({'loss': loss.item(), 'i': i})

        # Print the loss every `loss_vis_freq` steps, including step 0.
        is_report_step = (i % loss_vis_freq == 0)
        if is_report_step:
            print(loss.item())

        i += 1