In [1]:
# Print the GPU model, driver/CUDA versions and current memory usage of this runtime.
!nvidia-smi

Fri Nov 11 12:45:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0    31W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torchvision.transforms as T
from torchvision.utils import save_image, make_grid

from simple_ae import SimpleAE
from tqdm import tqdm
import math

In [3]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU instead of crashing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [5]:
# Folder the MNIST files are stored in; create it (and any missing parents) up front.
# NOTE(review): absolute path tied to the container layout — consider making it configurable.
data_dir = Path('/app/data/mnist')
data_dir.mkdir(parents=True, exist_ok=True)

In [6]:
# Image pre-processing: convert to a float tensor, then normalise with mean 0 / std 1,
# which leaves the pixel values unchanged. Both statistics are passed as one-element
# tuples (one entry per channel); the original std was a bare float because of a
# missing trailing comma.
tf = T.Compose([
    T.ToTensor(),
    T.Normalize((0.0,), (1.0,)),
])

# MNIST training split, downloaded into data_dir on first use.
dataset = MNIST(root=data_dir, train=True, download=True, transform=tf)

In [7]:
# Shuffled mini-batches of 128 images loaded by 4 worker processes; the last,
# incomplete batch of every epoch is dropped so all batches have the same size.
train_dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [8]:
class DDPM(nn.Module):
    """Denoising diffusion model built around a ``SimpleAE`` noise predictor.

    The noise schedule is linear: ``betas`` goes from 1e-4 to 0.02 over ``T`` steps,
    ``a = 1 - betas`` and ``a_hat`` is the cumulative product of ``a``.

    The schedule tensors are registered as non-persistent buffers, so they follow
    the module through ``.to()`` / ``.cuda()`` without adding keys to ``state_dict``
    (checkpoints saved when they were plain attributes still load).

    Args:
        T: number of diffusion steps.
    """

    def __init__(self, T = 1000):
        super(DDPM, self).__init__()

        # Network that predicts the noise from (noisy image, normalised step).
        self.diffuser = SimpleAE(in_channels=1, filters=128)

        self.T = T
        betas = torch.linspace(1e-4, 0.02, T)
        a = 1. - betas
        self.register_buffer('betas', betas, persistent=False)
        self.register_buffer('a', a, persistent=False)
        self.register_buffer('a_hat', torch.cumprod(a, dim=0), persistent=False)

    def forward(self, x0):
        """Run one training step on a batch of clean images.

        A schedule index ``t`` in ``[0, T-1]`` is drawn per sample, ``x0`` is noised
        to that level, and the network is asked to predict the noise, conditioned
        on ``t / T``.

        Args:
            x0: batch of clean images; indexed as a 4-D tensor (the schedule value
                is broadcast over three trailing dimensions).

        Returns:
            Whatever ``self.diffuser.loss_function(pred_eps, eps)`` returns
            (the training loop reads its ``'loss'`` entry).
        """
        t = torch.randint(self.T, (x0.size(0),), device=x0.device)
        eps = torch.randn_like(x0)

        # Per-sample cumulative alpha, shaped for broadcasting against x0.
        a_hat_t = self.a_hat.to(x0.device)[t][:, None, None, None]

        noisy = torch.sqrt(a_hat_t) * x0 + torch.sqrt(1 - a_hat_t) * eps
        pred_eps = self.diffuser(noisy, t.unsqueeze(1).float() / self.T)
        losses_dict = self.diffuser.loss_function(pred_eps, eps)
        return losses_dict

    @torch.no_grad()
    def sample(self, n_samples, data_size):
        """Generate images by running the reverse process from pure noise.

        Args:
            n_samples: number of images to generate.
            data_size: shape of one image without the batch dimension.

        Returns:
            Tensor of shape ``(n_samples, *data_size)`` on the module's device.
        """
        dev = self.a_hat.device
        xT = torch.randn([n_samples, *data_size], device=dev)
        for t in tqdm(range(self.T, 0, -1)):
            idx = t - 1  # schedule index of this step
            # No noise is added on the final step.
            z = torch.randn_like(xT) if t > 1 else 0
            k = (1 - self.a[idx]) / torch.sqrt(1 - self.a_hat[idx])
            # Condition on idx / T: forward() pairs schedule index j with j / T,
            # so sampling must use the same value for the same noise level.
            t_tensor = torch.full((n_samples, 1), float(idx), device=dev) / self.T
            sigma = torch.sqrt(1 - self.a[idx])
            xT = (xT - k * self.diffuser(xT, t_tensor)) / torch.sqrt(self.a[idx]) + z * sigma
        return xT
    
def log_images(t, i):
    """Save a grid of images to disk and, when a wandb run is active, log it there too.

    Args:
        t: batch of images to visualise (passed straight to ``make_grid``).
        i: global training step; used in the file name and logged as ``'i'``.
    """
    # save_image does not create missing folders, so make sure the target exists.
    Path("log_images").mkdir(parents=True, exist_ok=True)

    # Use the batch that was passed in rather than a global `samples` variable.
    # NOTE(review): value_range is presumably only honoured together with
    # normalize=True — confirm against the torchvision docs before relying on it.
    img_grid = make_grid(t, nrow=4, value_range=(-1, 1))
    save_image(img_grid, f"log_images/{i}_sample.png")
    if wandb.run is not None:
        images = wandb.Image(img_grid, caption="Random sampled images")

        wandb.log({"Samples": images, 'i': i})

In [9]:
import wandb

# Start a Weights & Biases run; the training loop and log_images() log to it while it is active.
# The run name is a plain string (the original f-string had no placeholders).
wandb.init(project='mnist_diffusion', name='simple_ae_diffusion')

[34m[1mwandb[0m: Currently logged in as: [33mnerlfield[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Build the diffusion model, move it to the selected device, and optimise all of its
# parameters with AdamW.
model = DDPM()
model = model.to(device)
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=2e-4,
    weight_decay=1e-5,
)

In [11]:
num_epoch = 500       # total number of passes over the training set
loss_vis_freq = 200   # print the loss every this many steps when wandb is inactive

i = 0       # global optimisation step counter
epoch = 0   # first epoch to run; the loop starts from this value

for epoch in range(epoch, num_epoch):
    # --- one epoch of training ---
    model.train()
    for batch in train_dataloader:
        batch = batch[0].to(device)  # keep the images, drop the labels

        optimizer.zero_grad()
        losses_dict = model(batch)
        loss = losses_dict['loss']
        loss.backward()
        optimizer.step()

        # Report the loss: to wandb when a run is active, otherwise to stdout
        # every `loss_vis_freq` steps.
        if wandb.run is None:
            if i % loss_vis_freq == 0:
                print(loss.item())
        else:
            wandb.log({'loss': loss.item(),
                       'i': i})

        i += 1

    # --- end of epoch: draw 32 samples and log them as an image grid ---
    model.eval()
    samples = model.sample(32, (1, 28, 28))
    log_images(samples, i)


100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.75it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.45it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.35it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.70it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.43it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.50it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 77.23it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 100

KeyboardInterrupt: 

Error in callback <function _WandbInit._pause_backend at 0x7f0cbb728820> (for post_run_cell):


Exception: The wandb backend process has shutdown

# Overfit a single batch (sanity check)

In [None]:
# for batch in train_dataloader:
#     batch = batch[0].to(device)
#     break

In [None]:
# num_epoch = 1000
# loss_vis_freq = 200

# i = 0
# for epoch in range(num_epoch):
#     model.train()
#     for wsefd in range(1000):
#         losses_dict = model(batch)

#         loss = losses_dict['loss']

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if wandb.run is not None:
#             wandb.log({'loss': loss.item(),
#                        'i': i})
#         elif i % loss_vis_freq == 0:
#             print(loss.item())

#         i += 1
        
#     model.eval()
#     samples = model.sample(32, (1, 28, 28))
#     log_images(samples, i)