In [1]:
!nvidia-smi

Thu Nov 10 22:01:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0    31W /  70W |   1450MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torchvision.transforms as T

from vq_vae import VQVAE


In [3]:
device = 'cuda:0'

In [4]:
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed = 42
seed_everything(seed)

In [5]:
data_dir = Path('/app/data/mnist')
data_dir.mkdir(exist_ok=True, parents=True)

In [6]:
to_tensor = T.ToTensor()

dataset = MNIST(root=data_dir, train=True, download=True, transform=to_tensor)

In [7]:
train_dataloader = DataLoader(dataset, batch_size=32, num_workers=2)

In [8]:
class DDPM(nn.Module):
    def __init__(self, T = 1000):
        super(DDPM, self).__init__()
        
        self.diffuser = VQVAE(in_channels=1, out_channels=1, embedding_dim=16, num_embeddings=10, img_size=27)

        self.T = T
        self.betas = self.var_scale(torch.linspace(0, 0.999, self.T)).to(device)
    
    def noise(self, x_t, b_t):
        res = []
        for i_b_t in b_t:
            i_noise = torch.empty(x_t[:1].shape).normal_(mean=torch.sqrt(1.-i_b_t), std=i_b_t).to(device)
            res.append(i_noise)
        noise = torch.cat(res, dim=0)
        return x_t, noise
    
    def var_scale(self, t):
        s = 0.008
        return torch.cos((t+s)/(1+s)*(torch.pi/2))**2
    
    def forward(self, x0):
        t = torch.randint(self.T, (x0.size(0),)).to(x0.device)
        b_t = self.betas[t]
        
        x0, noise = self.noise(x0, b_t)
        pred_noise, x0, vq_loss = self.diffuser(x0+noise, b_t.unsqueeze(1).float())
        losses_dict = self.diffuser.loss_function(pred_noise, noise, vq_loss)
        return losses_dict

In [9]:
model = DDPM().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-5)

In [10]:
import wandb
wandb.init(project='mnist_diffusion', name=f'vqvae_diffusion={seed}')

[34m[1mwandb[0m: Currently logged in as: [33mnerlfield[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
num_epoch = 10000
loss_vis_freq = 200

i = 0
for epoch in range(num_epoch):
    for batch in train_dataloader:
        batch = batch[0].to(device) #only images
        losses_dict = model(batch)
        
        loss = losses_dict['loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if wandb.run is not None:
            wandb.log({'loss': loss.item(),
                       'recon_loss': losses_dict['Reconstruction_Loss'].item(),
                       'vq_loss': losses_dict['VQ_Loss'].item(),
                       'i': i})
        
        if i % loss_vis_freq == 0:
                print(loss.item())
        
        i += 1

1.1304610967636108
0.5282307267189026
0.546099841594696
0.599320113658905
0.40492522716522217
0.6079970598220825
0.8302903175354004
1.6795475482940674
1.807058334350586
1.4012625217437744
2.4030938148498535
3.387909412384033
2.9166598320007324
2.6926119327545166
3.0163447856903076
2.925408363342285
2.458179473876953
5.917365550994873
5.343830108642578
8.873188972473145
2.745845317840576
9.836091995239258
9.176648139953613
10.125917434692383
9.596650123596191
9.661373138427734
9.355484008789062
10.855280876159668
6.041369438171387
6.056549549102783
11.93562126159668
7.195795059204102
7.9415974617004395
9.18093490600586
7.571775913238525
8.079236030578613
6.742697238922119
8.51246452331543
5.114672660827637
6.351847171783447
6.649860382080078
9.15014362335205
10.007893562316895
10.135489463806152
8.225232124328613
8.1424560546875
7.328952789306641
5.532806873321533
6.407562255859375
7.756834983825684
5.8097991943359375
6.059876441955566
6.794969081878662
5.373486042022705
5.1550464630126