In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import numpy as np

from tqdm import tqdm
from torch.optim import Adam

import math

In [None]:
# Pick the compute device once for the whole notebook.
# Apple's MPS backend (the notebook's original target) is preferred, then CUDA; the CPU is the
# fallback so that "Restart Kernel -> Run All" works on any machine instead of failing at `.to("mps")`.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old PyTorch builds
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Intro:

In this Jupyter notebook, we are going to code the crucial parts of the diffusion process and train a diffusion model on the MNIST dataset.

We will mainly base our code on the [original paper](https://arxiv.org/abs/2006.11239).


# Load Dataset

In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch.utils.data as torch_data

# `ToTensor` is the only preprocessing step; it is kept inside a `Compose` pipeline.
transform_to_tensor = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the validation split, with the same root and transform.
train_dataset, val_dataset = (
    MNIST("./MNIST_dataset/", train=is_train_split, transform=transform_to_tensor, download=True)
    for is_train_split in (True, False)
)


In [None]:
# Both splits are served in shuffled mini-batches of 32 images.
train_dataloader, val_dataloader = (
    torch_data.DataLoader(dataset=split, batch_size=32, shuffle=True)
    for split in (train_dataset, val_dataset)
)

# Check MNIST data

What does the data look like? Check its size, shape, type, etc.
Please plot a few figures so that we can later compare them with the generated images.

# Model (not the most important part)

I shamelessly copied the code from the Internet. It is a simple Unet.

## UNET BLOCK

This is not a really conventional block: usually, there is only one convolution per block.

In [None]:
class UNetBlock(nn.Module):
    """Optional layer normalisation followed by two convolutions, each passed through an activation.

    Args:
        shape: normalised shape handed to ``nn.LayerNorm``.
        in_c: number of input channels of the first convolution.
        out_c: number of output channels of both convolutions.
        kernel_size, stride, padding: forwarded unchanged to both ``nn.Conv2d`` layers.
        activation: module applied after each convolution; ``nn.SiLU()`` when ``None``.
        normalize: when ``False`` the layer normalisation is skipped in ``forward``.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        # Both convolutions share the same geometry.
        conv_geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        # The LayerNorm is always created, even if `normalize` is False.
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, **conv_geometry)
        self.conv2 = nn.Conv2d(out_c, out_c, **conv_geometry)
        self.activation = activation if activation is not None else nn.SiLU()
        self.normalize = normalize

    def forward(self, x):
        """Return ``act(conv2(act(conv1(h))))`` where ``h`` is ``ln(x)`` or ``x`` itself."""
        hidden = self.ln(x) if self.normalize else x
        for conv in (self.conv1, self.conv2):
            hidden = self.activation(conv(hidden))
        return hidden
    


## Positional Embedding of time step

Exactly as in a transformer, we encode the time step into an embedding. This helps the model learn the temporal dependency during both training and generation. We use the usual sinusoidal positional embedding, but any other positional embedding could be used.

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar time steps.

    Args:
        dim: number of features of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed a 1-D tensor of time steps.

        Args:
            x: tensor of shape ``(N,)`` (integer or floating point).

        Returns:
            Tensor of shape ``(N, dim)``: the first ``dim // 2`` columns are sines, the next
            ``dim // 2`` columns are cosines, at geometrically spaced frequencies from 1 down
            to 1/10000. When ``dim`` is odd, a final zero column pads the result to ``dim``.
        """
        half_dim = self.dim // 2
        # `half_dim - 1` is 0 for dim in {2, 3}; clamp the denominator to avoid a ZeroDivisionError
        # (a single frequency, exp(0) = 1, is then used).
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only produce an even number of features; pad so that the output width
            # always equals `dim`, which is what the downstream `nn.Linear(dim, ...)` layers expect.
            emb = F.pad(emb, (0, 1))
        return emb

## Model definition

This is the final model. We use the UNet block and the positional embedding to define it. Please note that in diffusion, the main interest is not the model architecture but the diffusion process: the model is just a tool that helps us learn the diffusion process. Any other model should work fine.

In [None]:
class MyUNet(nn.Module):
    """Small time-conditioned U-Net for single-channel 28x28 images.

    The network has three down-sampling stages (28 -> 14 -> 7 -> 3 pixels), a bottleneck and
    three up-sampling stages with skip connections (channel concatenation). Before every stage,
    a per-stage MLP projects the sinusoidal time embedding to the stage's input channel count
    and the result is added to the feature map (broadcast over height and width).

    The spatial sizes are hard-coded through the ``nn.LayerNorm`` shapes of the blocks, so the
    input must be ``(N, 1, 28, 28)``.

    Args:
        time_emb_dim: width of the sinusoidal time embedding fed to every per-stage MLP.
    """

    def __init__(self, time_emb_dim=100):
        super(MyUNet, self).__init__()

        # Sinusoidal embedding
        self.time_embed = SinusoidalPosEmb(time_emb_dim)
        
        # First half
        # Each `te*` MLP outputs as many features as the channels of the tensor it is added to.
        self.te1 = mlp_time_embedding(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            UNetBlock((1, 28, 28), 1, 10),
            UNetBlock((10, 28, 28), 10, 10),
            UNetBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)  # 28x28 -> 14x14

        self.te2 = mlp_time_embedding(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            UNetBlock((10, 14, 14), 10, 20),
            UNetBlock((20, 14, 14), 20, 20),
            UNetBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)  # 14x14 -> 7x7

        self.te3 = mlp_time_embedding(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            UNetBlock((20, 7, 7), 20, 40),
            UNetBlock((40, 7, 7), 40, 40),
            UNetBlock((40, 7, 7), 40, 40)
        )
        # 7 is odd, so the down-sampling is done in two steps: 7x7 -> 6x6 -> 3x3
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = mlp_time_embedding(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            UNetBlock((40, 3, 3), 40, 20),
            UNetBlock((20, 3, 3), 20, 20),
            UNetBlock((20, 3, 3), 20, 40)
        )

        # Second half
        # Mirror of `down3`: 3x3 -> 6x6 -> 7x7
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        # The decoder blocks take the concatenation of the skip connection and the up-sampled
        # features, hence the doubled input channel counts (80, 40, 20).
        self.te4 = mlp_time_embedding(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            UNetBlock((80, 7, 7), 80, 40),
            UNetBlock((40, 7, 7), 40, 20),
            UNetBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)  # 7x7 -> 14x14
        self.te5 = mlp_time_embedding(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            UNetBlock((40, 14, 14), 40, 20),
            UNetBlock((20, 14, 14), 20, 10),
            UNetBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)  # 14x14 -> 28x28
        self.te_out = mlp_time_embedding(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            UNetBlock((20, 28, 28), 20, 10),
            UNetBlock((10, 28, 28), 10, 10),
            UNetBlock((10, 28, 28), 10, 10, normalize=False)
        )

        # Final projection back to a single channel
        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the U-Net.

        Args:
            x: image batch of shape ``(N, 1, 28, 28)``.
            t: one time step per sample, shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)``.
        """
        # x is (N, 1, 28, 28); t is (N,) and becomes the sinusoidal time features below.
        # The time information is ADDED to the feature maps (not stacked on the channel dimension):
        # every `te*` output is reshaped to (N, C, 1, 1) and broadcast over height and width.
        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)

        out = self.conv_out(out)  # (N, 1, 28, 28)

        return out

def mlp_time_embedding(dim_in: int, dim_out:int)-> nn.Sequential:
    """Build the two-layer MLP (Linear -> SiLU -> Linear) that projects a time embedding.

    Args:
        dim_in: number of input features (width of the time embedding).
        dim_out: number of features of the hidden layer and of the output.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
    """
    layers = [
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out),
    ]
    return nn.Sequential(*layers)


How many trainable parameters does the model have?

# Diffusion module

We chose to use an `nn.Module` object for the diffusion process. This is the main challenge of the notebook.
In a diffusion process, we have to take care of:
- the beta variance schedule (we will use the simple case where $\beta_t$ is linear in $t$, as in the original paper, but you can try other schedules, for instance [a cosine schedule](https://arxiv.org/abs/2102.09672))
- the coefficients $\alpha_t := 1-\beta_t$ and $\bar{\alpha}_t := \prod_{i=1}^t \alpha_i$

<div>
<img src="images/parametrization_trick.png" width="700"/>
</div>

- The `make_noisy` function which, given images at time 0 (no noise), adds noise up to step $t$.
- The forward function (to train the model)
- The sampling function, that will be used to generate images


<div>
<img src="images/training_sampling.png" width="700"/>
</div>

In [None]:
class Diffusion(nn.Module):
    """DDPM forward (noising) and reverse (sampling) process around a noise-prediction network.

    Implements the linear-beta DDPM of Ho et al. (2020): Algorithm 1 (training) in ``forward``
    and Algorithm 2 (sampling) in ``sample``.

    Time steps are stored 0-based: index ``t`` in ``[0, n_times - 1]`` corresponds to step
    ``t + 1`` of the paper.

    Args:
        model: network called as ``model(x_t, t)`` that returns the predicted noise, with the same
            shape as ``x_t``. ``t`` is an int64 tensor of shape ``(N,)``.
        n_times: number of diffusion steps ``T`` (must be >= 1).
        beta_minmax: ``(beta_1, beta_T)``, the end points of the linear variance schedule.
        device: kept as ``self.device`` for callers; the schedule tensors are registered as
            buffers, so they follow the module when ``.to(...)`` is called.
    """

    def __init__(self, model: nn.Module, n_times: int=1000, beta_minmax:tuple[float, float]=(1e-4, 2e-2), device: str='mps'):
    
        super(Diffusion, self).__init__()

        if n_times < 1:
            raise ValueError(f"n_times must be >= 1, got {n_times}")
    
        self.n_times = n_times

        self.model = model
        
        # define betas variance schedule (linear in t, float32 because MPS has no float64)
        beta_start, beta_end = beta_minmax
        betas = torch.linspace(beta_start, beta_end, n_times, dtype=torch.float32)

        # define alpha for forward diffusion kernel
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Buffers (not parameters): moved by `.to(device)`, never trained. They are derived from
        # the hyper-parameters, so they are kept out of the state_dict (`persistent=False`).
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)
        self.register_buffer("sqrt_alpha_bars", alpha_bars.sqrt(), persistent=False)
        self.register_buffer("sqrt_one_minus_alpha_bars", (1.0 - alpha_bars).sqrt(), persistent=False)
        
        self.device = device
    
    def extract(self, a: torch.Tensor, t:torch.Tensor, x_shape:tuple[int, ...])-> torch.Tensor:
        """
            This function will be used to extract alphas and betas at time t.
            Basically, it gets the value of t, select the corresponding value of the tensor a
            at the index t, and reshape it to match the value of x.
            This tensor's broadcasting helps for hadamard product.
            from lucidrains' implementation
                https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L376        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    
    def scale_to_minus_one_to_one(self, x: torch.Tensor)->torch.Tensor:
        # according to the DDPMs paper, normalization seems to be crucial to train reverse process network
        return x * 2 - 1
    
    def reverse_scale_to_zero_to_one(self, x:torch.Tensor)->torch.Tensor:
        # reverse the normalization
        return (x + 1) * 0.5
    
    def make_noisy(self, x_zeros: torch.Tensor, t: torch.Tensor)->tuple[torch.Tensor, torch.Tensor]:
        """Sample ``x_t ~ q(x_t | x_0)`` in closed form.

        ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon`` with
        ``epsilon ~ N(0, I)``.

        Args:
            x_zeros: clean samples ``x_0`` (already scaled to ``[-1, 1]``), shape ``(N, ...)``.
            t: int64 tensor of shape ``(N,)`` with 0-based time indices.

        Returns:
            ``(x_t, epsilon)``, both with the shape of ``x_zeros``.
        """
        epsilon = torch.randn_like(x_zeros)
        signal_scale = self.extract(self.sqrt_alpha_bars, t, x_zeros.shape)
        noise_scale = self.extract(self.sqrt_one_minus_alpha_bars, t, x_zeros.shape)
        noisy_sample = signal_scale * x_zeros + noise_scale * epsilon
        # The noisy input is only data for the network: no gradient should flow through it.
        return noisy_sample.detach(), epsilon
    
    
    def forward(self, x_zeros:torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """One training-time pass (Algorithm 1 of the DDPM paper, without the loss).

        Args:
            x_zeros: clean samples in ``[0, 1]``, shape ``(N, ...)``.

        Returns:
            ``(x_t, epsilon, pred_epsilon)``: the noisy samples (in the ``[-1, 1]`` scale), the
            noise that was added, and the noise predicted by the model.
        """
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)
        batch = x_zeros.shape[0]
        # t ~ Uniform over all steps (0-based indices 0 .. n_times - 1)
        t = torch.randint(0, self.n_times, (batch,), device=x_zeros.device, dtype=torch.long)
        noisy_sample, epsilon = self.make_noisy(x_zeros, t)
        pred_epsilon = self.model(noisy_sample, t)
        return noisy_sample, epsilon, pred_epsilon
    
    def denoise_at_t(self, x_t:torch.Tensor, timestep:torch.Tensor, t:int)-> torch.Tensor:
        """One reverse step: sample ``x_{t-1}`` from ``p_theta(x_{t-1} | x_t)``.

        ``x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
        + sqrt(beta_t) * z``, with ``z ~ N(0, I)``, except at the last step where ``z = 0``.

        Args:
            x_t: current noisy samples, shape ``(N, ...)``.
            timestep: int64 tensor of shape ``(N,)`` filled with ``t``.
            t: the same 0-based time index as a Python int.

        Returns:
            The samples at the previous time step, same shape as ``x_t``.
        """
        # No noise is added on the very last step (index 0), as in Algorithm 2.
        z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)

        pred_epsilon = self.model(x_t, timestep)

        beta = self.extract(self.betas, timestep, x_t.shape)
        alpha = self.extract(self.alphas, timestep, x_t.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, timestep, x_t.shape)

        posterior_mean = (x_t - beta / sqrt_one_minus_alpha_bar * pred_epsilon) / alpha.sqrt()
        # sigma_t^2 = beta_t (the first of the two choices discussed in the paper)
        return posterior_mean + beta.sqrt() * z

    @torch.no_grad()
    def sample(self, batch_size:int, img_shape:tuple[int, ...]=(1, 28, 28))-> torch.Tensor:
        """Generate images by running the full reverse chain (Algorithm 2 of the DDPM paper).

        Args:
            batch_size: number of images to generate.
            img_shape: per-image shape ``(C, H, W)``; defaults to MNIST's ``(1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch_size, *img_shape)`` with values in ``[0, 1]``.
        """
        device = self.betas.device
        # start from random noise vector, x_T
        x_t = torch.randn(batch_size, *img_shape, device=device)

        for t in range(self.n_times - 1, -1, -1):
            timestep = torch.full((batch_size,), t, dtype=torch.long, device=device)
            x_t = self.denoise_at_t(x_t, timestep, t)

        # Back to the image range; clamp because the chain is not guaranteed to stay in [-1, 1].
        return self.reverse_scale_to_zero_to_one(x_t).clamp(0.0, 1.0)


In [None]:
# Training objective: mean squared error between two tensors (used on the noise below).
denoising_loss = nn.MSELoss()

# Denoising network wrapped by the diffusion process, then moved to the selected device.
model = MyUNet(time_emb_dim=256)
diffusion = Diffusion(model=model, n_times=1000, device=DEVICE)
diffusion = diffusion.to(DEVICE)

# The optimizer sees every parameter registered under `diffusion` (i.e. those of `model`).
optimizer = Adam(params=diffusion.parameters(), lr=3e-4)

In [None]:
# Sanity check of the forward (noising) process on a single validation batch.
model.eval()

# Only the images are inspected here, so no autograd graph is needed.
with torch.no_grad():
    x, _ = next(iter(val_dataloader))  # first batch only
    x = x.to(DEVICE)
    perturbed_images, epsilon, pred_epsilon = diffusion(x)
    # Bring the noisy samples back to the [0, 1] image scale for plotting.
    perturbed_images = diffusion.reverse_scale_to_zero_to_one(perturbed_images)

In [None]:
# Show the single channel of the 4th noisy image of the batch.
plt.imshow(perturbed_images[3, 0].cpu())

In [None]:
print("Start training DDPMs...")
model.train()
epochs = 10
n_batches = len(train_dataloader)
for epoch in range(epochs):
    noise_prediction_loss = 0.0
    for x, _ in tqdm(train_dataloader, total=n_batches):
        optimizer.zero_grad()

        x = x.to(DEVICE)

        # The diffusion module noises the batch at random time steps and predicts that noise.
        noisy_input, epsilon, pred_epsilon = diffusion(x)
        loss = denoising_loss(pred_epsilon, epsilon)

        noise_prediction_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Average over the number of batches. The last batch index is `n_batches - 1`, so dividing
    # by it over-estimates the mean and fails with a single batch.
    print("\tEpoch", epoch + 1, "complete!", "\tDenoising Loss: ", noise_prediction_loss / n_batches)

print("Finish!!")

In [None]:
model.eval()

# `Diffusion.sample` names its argument `batch_size`; calling it with `N=32` raises a TypeError.
with torch.no_grad():
    generated_images = diffusion.sample(batch_size=32)


In [None]:
# Show the single channel of the 10th generated image.
plt.imshow(generated_images[9, 0].cpu())