## DDPM on CIFAR10

This is just me trying to learn how to implement simple diffusion models (DDPM) from scratch and train them on CIFAR-10.

paper|official code
-|-
[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | [code](https://github.com/hojonathanho/diffusion)
[Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672) | [code](https://github.com/openai/improved-diffusion)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233) | [code](https://github.com/openai/guided-diffusion)

Helpful resources:

- [lucidrains full implementation in one file](https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py)
- [unet parts](https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py)


## What's next?

- maybe text diffusion? https://github.com/madaan/minimal-text-diffusion

In [None]:
# pip install wandb easydict tqdm torch==2 torchvision -qqq

In [None]:
def prefix_dict(prefix, dic):
    """Return a new dict whose keys are rewritten as ``"<prefix>.<key>"``.

    Values are carried over as-is (not copied); *dic* itself is left untouched.
    """
    renamed = {}
    for original_key, value in dic.items():
        renamed["{}.{}".format(prefix, original_key)] = value
    return renamed


def flatten_dict(d, parent_key='', sep='.'):
    """Flatten nested dicts into a single level, joining key paths with *sep*.

    Example: ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a.b": 1, "c": 2}``.
    A key is only prefixed when its parent key is truthy, so keys at the top
    level (with the default empty *parent_key*) keep their original objects.
    Nested empty dicts contribute no entries.
    """
    flat = {}

    def _walk(mapping, prefix):
        # Depth-first walk in insertion order; later duplicates overwrite earlier ones.
        for key, value in mapping.items():
            path = f"{prefix}{sep}{key}" if prefix else key
            if isinstance(value, dict):
                _walk(value, path)
            else:
                flat[path] = value

    _walk(d, parent_key)
    return flat


import time


class Timer:
    """Context manager measuring wall-clock time spent inside a ``with`` block.

    Usage::

        with Timer("forward") as t:
            ...
        print(t.elapsed_time)  # seconds

    Note: CUDA kernels launch asynchronously, so timing GPU work this way
    measures launch/host time unless the caller synchronizes first.

    Args:
        message: Label used when reporting the timing.
        verbose: If True, print ``"<message> time: <seconds>sec"`` on exit.
            Defaults to False, i.e. silent as before.
    """

    def __init__(self, message='', verbose=False):
        self.message = message
        self.verbose = verbose
        self.start_time = None
        self.end_time = None
        self.elapsed_time = None

    def __enter__(self):
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.verbose:
            print(f"{self.message} time: {self.elapsed_time:.2f}sec")
        # Implicitly returns None, so exceptions raised in the block propagate.


In [None]:
from tqdm import tqdm
import os
import torch
import torchvision
import torchvision.transforms as transforms

import itertools
import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Later cells move the model and batches to the GPU with `.cuda()`, so stop
# here with a clear failure instead of an obscure one later.
assert device == 'cuda'

# ToTensor yields values in [0, 1]; DDPM works on images scaled to [-1, 1].
# (t * 2) - 1 is exactly Normalize(mean=0.5, std=0.5) per channel, so a
# separate Normalize transform is not needed.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])

batch_size = 512
data_root = os.path.expanduser("~/data/")
# 19 workers was hard-coded; cap it so smaller machines are not oversubscribed.
num_workers = min(19, os.cpu_count() or 1)

train_dataset = torchvision.datasets.CIFAR10(data_root, download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(data_root, download=True, train=False, transform=transform_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, prefetch_factor=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, prefetch_factor=2)

# Throughput sanity check: time how long the loader takes to deliver the first
# `n_benchmark_batches` training batches (worker start-up included).
n_benchmark_batches = 50
tic = time.time()
for _ in tqdm(itertools.islice(train_loader, n_benchmark_batches), total=n_benchmark_batches):
    pass
duration = time.time() - tic

print(f'dataloader completed {n_benchmark_batches} batches in {duration:.2f}sec')
assert duration < 3, f'dataloader too slow: {duration:.2f}sec for {n_benchmark_batches} batches'


In [None]:

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Display a channels-first image tensor in the model's [-1, 1] range.

    Args:
        img: Tensor laid out as (C, H, W) (transposed to (H, W, C) for
            matplotlib), on any device, possibly requiring grad or in float16.
            Values outside [-1, 1] (e.g. a difference image) are clipped for
            display.
    """
    img = (img + 1) / 2     # unnormalize [-1, 1] -> [0, 1]
    # detach/cpu so GPU tensors and tensors that require grad can be shown;
    # astype(float) also upcasts float16 produced under autocast.
    npimg = img.detach().cpu().numpy().astype(float)
    # Clip explicitly: matplotlib would otherwise clip out-of-range RGB floats
    # itself and emit a warning on every call.
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0))
    plt.show()


imshow(train_dataset[0][0])

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from typing import *


# Environment check for this cell: report the installed PyTorch build and
# require a CUDA device, since the training cells call `.cuda()`.
cuda_available = torch.cuda.is_available()
print(torch.__version__)

assert cuda_available

class DoubleConv(nn.Module):
    """Two (conv -> BatchNorm -> ReLU) stages that keep the spatial size.

    The first conv maps ``in_channels`` to ``out_channels // 2`` and the second
    maps that up to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced; should be >= 2 so the intermediate
            width ``out_channels // 2`` is non-zero.
        kernel_size: Convolution kernel size, int or (kh, kw).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]]=3,
        ):
        super().__init__()
        mid_channels = out_channels // 2
        # nn.Sequential (not nn.ModuleList) so the stack is callable; submodule
        # indices, and therefore state_dict keys, are the same as before.
        # padding="same" preserves H and W for any kernel size at stride 1,
        # whereas the old hard-coded padding=1 was only right for 3x3 kernels.
        # Convs have no bias because the following BatchNorm adds its own shift.
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size, padding="same", bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size, padding="same", bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, H, W)."""
        return self.convs(x)

class Down(nn.Module):
    """Halve the spatial resolution with a single stride-2 convolution.

    For odd kernel sizes an (H, W) input becomes (ceil(H/2), ceil(W/2)).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced.
        kernel_size: Convolution kernel size, int or (kh, kw).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]]=3,
    ):
        super().__init__()
        # Half-kernel padding, per spatial dim when a tuple is given
        # (reduces to the previous padding=1 for the default 3x3 kernel).
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        # The original referenced an undefined `dim`, raising NameError on
        # construction; use the declared channel arguments instead.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=padding)

    def forward(self, x):
        """Map (N, in_channels, H, W) to the downsampled (N, out_channels, H', W')."""
        return self.conv(x)

class Up(nn.Module):
    """Upsample a feature map, refine it with DoubleConv, optionally add a time embedding.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        kernel_size: Kernel size of the DoubleConv stage.
        scale_factor: Nearest-neighbour upsampling factor applied first;
            the default of 2 mirrors the 2x reduction done by ``Down``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]]=3,
        scale_factor: int = 2,
    ):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.convs = DoubleConv(in_channels, out_channels, kernel_size)
        # Maps one scalar timestep per sample to one additive value per output
        # channel. nn.Sequential (not nn.ModuleList) so it can be called.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, t=None):
        """Upsample and convolve ``x``; if ``t`` is given, add its embedding.

        Args:
            x: Feature map of shape (N, in_channels, H, W).
            t: Optional timestep tensor: shape (N,) or (N, 1), or a 0-d tensor
                shared by the whole batch.

        Returns:
            Tensor with ``out_channels`` channels, spatially upsampled by
            ``scale_factor`` (assuming DoubleConv preserves spatial size).
        """
        x = self.convs(self.upsample(x))
        if t is not None:
            # Linear expects a trailing feature dim of size 1; match x's dtype
            # so integer timesteps and AMP float16 activations both work.
            t = t.reshape(-1, 1).to(x.dtype)
            # Broadcast the (N, out_channels) embedding over H and W. The old
            # code concatenated along the batch dimension instead.
            x = x + self.time_mlp(t)[:, :, None, None]
        return x


class Unet(nn.Module):
    """Three-level convolutional encoder/decoder with additive skip connections.

    The forward pass takes only an image (no timestep). Channels go
    3 -> 128 -> 512 -> 2048 and back, and the output passes through tanh.

    Args:
        bias: Whether every convolution, including ``final_conv``, has a bias.
        kernel_size: Odd kernel size shared by all (transposed) convolutions.
        stride: Stride of the encoder convs and the decoder transposed convs.
        final_conv: If True, apply a 1x1 conv before the output tanh.
        **kwargs: Accepted and ignored, so extra config keys do not break
            construction.
    """

    def __init__(self, bias=True, kernel_size=3, stride=1, final_conv=True, **kwargs):
        super().__init__()
        # Half-kernel padding keeps the spatial size at stride 1 for any odd
        # kernel; the old hard-coded padding=1 was only correct for 3x3.
        layer_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=bias)
        self.conv1 = nn.Conv2d(3, 128, **layer_kwargs)
        self.conv2 = nn.Conv2d(128, 512, **layer_kwargs)
        self.conv3 = nn.Conv2d(512, 2048, **layer_kwargs)
        self.deconv1 = nn.ConvTranspose2d(2048, 512, **layer_kwargs)
        self.deconv2 = nn.ConvTranspose2d(512, 128, **layer_kwargs)
        self.deconv3 = nn.ConvTranspose2d(128, 3, **layer_kwargs)

        self.final_conv = None
        if final_conv:
            self.final_conv = nn.Conv2d(3, 3, 1, bias=bias)

    def forward(self, x):
        """Map an (N, 3, H, W) image to an (N, 3, H, W) output in (-1, 1)."""
        x1 = self.conv1(x)
        x2 = self.conv2(F.relu(x1))
        x3 = self.conv3(F.relu(x2))

        # output_size makes each transposed conv recover exactly the spatial
        # size of the matching encoder feature map. With stride > 1 that size
        # is ambiguous (e.g. 9 and 10 both downsample to 5 at stride 2) and
        # the skip additions would otherwise fail on a shape mismatch.
        x4 = self.deconv1(F.relu(x3), output_size=x2.shape[-2:]) + x2
        x5 = self.deconv2(F.relu(x4), output_size=x1.shape[-2:]) + x1
        x6 = self.deconv3(F.relu(x5), output_size=x.shape[-2:]) + x

        if self.final_conv is not None:
            x6 = self.final_conv(x6)
        return torch.tanh(x6)


# Architecture hyperparameters: unpacked into Unet(...) here and stored under
# "model" in the run config.
model_config = dict(
    bias=True,
    kernel_size=3,
    stride=1,
    final_conv=False,
)

# Smoke test: one forward pass must return a tensor shaped like its input.
net = Unet(**model_config)
inp = torch.ones(1, 3, 10, 10)
output = net(inp)
assert tuple(output.shape) == tuple(inp.shape)
del net
print('network tested successfully')


"""
lessons:

padding will change the size

"""


In [None]:
import wandb
import pprint
from easydict import EasyDict

# Batch counter, incremented once per training batch by the training loop.
global_step = 0

config = EasyDict(
    {
        "use_amp": True,
        "epochs": 5,
        "optimizer": "Adam",
        "optimizer_kwargs": {
            "lr": 3e-4,
        },
        "loss_fn": "MSELoss",
        "eval_step": 20,
        "batch_size": batch_size,
        "model": model_config,
    }
)

# Model, optimizer, AMP grad scaler and loss are all looked up by the names in `config`.
net = Unet(**model_config).cuda()
optimizer_cls = getattr(torch.optim, config.optimizer)
opt = optimizer_cls(net.parameters(), **config.optimizer_kwargs)
scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
loss_cls = getattr(nn, config.loss_fn)
loss_fn = loss_cls().to(device)

pprint.pprint(config)
# Finish any run still open from an earlier execution before starting a fresh one.
wandb.finish()
wandb.init(
    project='DDPM from scratch',
    name='3layer unet identity function',
    config=flatten_dict(config),
    resume=False,
)


In [None]:

# wandb.init(project='DDPM from scratch', name='3layer unet identity function', config=flatten_dict(config), resume=True)


In [None]:

torch.backends.cudnn.benchmark = True


def _log_reconstruction_example(step):
    """Reconstruct the first test image, show it inline and log it to wandb.

    The panels are concatenated along the width axis: the original, the
    reconstruction, and their difference.

    Args:
        step: Training step the wandb entry is attached to, so it lines up
            with the logged loss.
    """
    # Visualisation only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        val_sample = test_dataset[0][0].unsqueeze(0).to(device)
        val_output = net(val_sample)
    img_sample = torch.concat((val_sample[0], val_output[0], val_sample[0] - val_output[0]), dim=2)
    imshow(img_sample)
    image = wandb.Image(
        img_sample,
        caption="Left: original, Middle: reconstructed, Right: difference",
    )
    wandb.log({"examples": image}, step=step)


for epoch in range(config.epochs):
    pbar = tqdm(train_loader, desc=f"epoch {epoch}")
    # Class labels are unused: the network is trained to reproduce its input.
    for images, _labels in pbar:
        global_step += 1
        images = images.cuda()

        with torch.autocast(device_type=device, dtype=torch.float16, enabled=config.use_amp):
            with Timer("forward"):
                output = net(images)
            with Timer("loss_fn"):
                loss = loss_fn(output, images)

        # Sync the loss to the host once per step and reuse the value. Passing
        # step= keeps wandb's x-axis equal to global_step even on steps that
        # also log an image.
        loss_value = loss.item()
        wandb.log({'loss': loss_value}, step=global_step)
        pbar.set_postfix(loss=loss_value)

        with Timer("backward"):
            scaler.scale(loss).backward()
        with Timer("step"):
            scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)

        # Visualise after the update, once the training graph has been freed.
        if global_step % config.eval_step == config.eval_step - 1:
            _log_reconstruction_example(global_step)
