# Diffusion Model from Scratch in PyTorch

Based on [this TowardsDataScience article](https://towardsdatascience.com/diffusion-model-from-scratch-in-pytorch-ddpm-9d9760528946).

## Packages and Presets

In [4]:
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange 
from typing import List
import random
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader 
from timm.utils import ModelEmaV3 
from tqdm import tqdm
import matplotlib.pyplot as plt 
import torch.optim as optim
import numpy as np
%load_ext blackcellmagic

## About the Model

### Diffusion Model Training
**Input:** Training data x

**Output:** Model parameters $\phi_t$

**repeat**

> for $i\in \mathcal{B}$ do # for every training index in batch
>> $t \sim \mathcal{U}(\{1, ..., T\})$ # sample a time step

>> $\varepsilon \sim \mathcal{N}(0, I)$ # sample noise  

>> $\ell_i = \|g_t(\sqrt{\alpha_t} x_i + \sqrt{1-\alpha_t}\varepsilon, \phi_t) - \varepsilon \|^2$ # individual loss

> Accumulate the losses over the batch and take a gradient step

**until converged**
 

### Sampling from Diffusion Model
**Input:** Model $g_t(\cdot, \phi_t)$

**Output:** Sample $x$

$z_T \sim \mathcal{N}(0, I)$ # sample last latent variable

for $t = T, ..., 2$ do
> $\hat{z}_{t-1} = \dfrac{1}{\sqrt{1-\beta_t}}z_t - \dfrac{\beta_t}{\sqrt{1-\alpha_t}\sqrt{1-\beta_t}}g_t(z_t, \phi_t)$ # predict previous latent variable

> $\varepsilon \sim \mathcal{N}(0, I)$

> $z_{t-1} = \hat{z}_{t-1} + \sqrt{\sigma_t}\varepsilon$ # add noise to previous latent variable

$x = \dfrac{1}{\sqrt{1-\beta_1}}z_1 - \dfrac{\beta_1}{\sqrt{1-\alpha_1}\sqrt{1-\beta_1}}g_1(z_1, \phi_1)$ # sample from z1 without noise

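After the estimated noise has been subtracted in each intermediate step ($t = T, ..., 2$), a small amount of fresh noise is added back. This keeps the process stochastic and stable and helps prevent mode collapse. The final step ($t = 1$) omits this noise, so the output is the clean sample.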

### UNET

To predict the noise in the reverse diffusion process, a special UNET is used. It applies attention at the 16x16 resolution and adds sinusoidal (transformer-style) time embeddings in every residual block. The sinusoidal embeddings encode the time step for which the model is predicting the noise.

## UNET Implementation

### Sinusoidal Positional Encoding

In [2]:
class SinusoidalEmbeddings(nn.Module):
    """Fixed lookup table of sinusoidal time-step embeddings.

    Row ``t`` of the table holds ``sin(t * w_k)`` in the even columns and
    ``cos(t * w_k)`` in the odd columns, with frequencies
    ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        time_steps: number of table rows; valid indices are ``0 .. time_steps - 1``.
        embed_dim: embedding width (odd widths are supported).
    """

    def __init__(self, time_steps: int, embed_dim: int):
        super().__init__()
        position = torch.arange(time_steps).unsqueeze(1).float()
        # frequency scaling factors w_k, one per even column
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        embeddings = torch.zeros(time_steps, embed_dim)
        # sin for even indices
        embeddings[:, 0::2] = torch.sin(position * div)
        # cos for odd indices; for odd embed_dim there is one fewer odd column
        # than frequencies, so only the first embed_dim // 2 are used
        embeddings[:, 1::2] = torch.cos(position * div[: embed_dim // 2])
        # A buffer follows module.to(device), so indexing with an on-device `t`
        # works. persistent=False keeps the state_dict identical to before
        # (the table is deterministic and needs no saving).
        self.register_buffer("embeddings", embeddings, persistent=False)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the embeddings for time steps ``t`` as ``(len(t), embed_dim, 1, 1)``.

        ``x`` is only used to choose the output device. The trailing singleton
        dims let the result broadcast over a ``(b, c, h, w)`` feature map.
        """
        embeds = self.embeddings[t].to(x.device)
        return embeds[:, :, None, None]

### Residual Block

In [None]:
class ResidualBlock(nn.Module):
    """Channel-preserving residual block conditioned on a time embedding.

    The first ``C`` channels of ``embeddings`` are added to ``x``. The sum ``h``
    then passes through GroupNorm -> ReLU -> 3x3 conv -> dropout -> GroupNorm ->
    ReLU -> 3x3 conv, and that branch output is added back onto ``h``.

    Args:
        C: number of input and output channels.
        num_groups: group count for both GroupNorm layers.
        dropout_prob: dropout probability applied after the first convolution.
    """

    def __init__(self, C: int, num_groups: int, dropout_prob: float) -> None:
        super().__init__()
        # Names and creation order are kept fixed: the names are state_dict keys
        # and the order determines the random weight initialization.
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1, self.gnorm2 = (
            nn.GroupNorm(num_groups=num_groups, num_channels=C) for _ in range(2)
        )
        self.conv1, self.conv2 = (
            nn.Conv2d(C, C, kernel_size=3, padding=1) for _ in range(2)
        )
        self.dropout = nn.Dropout(p=dropout_prob, inplace=True)

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        channels = x.shape[1]
        # inject the time embedding (only as many channels as the feature map has)
        h = x + embeddings[:, :channels, :, :]
        branch = self.conv1(self.relu(self.gnorm1(h)))
        branch = self.dropout(branch)
        branch = self.relu(self.gnorm2(branch))
        branch = self.conv2(branch)
        # skip connection around the two convolutions
        return h + branch

### Attention Mechanism

In [6]:
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    A ``(b, C, h, w)`` input is flattened into ``h * w`` tokens of width ``C``,
    projected to queries/keys/values, attended with ``num_heads`` heads,
    projected back and reshaped to ``(b, C, h, w)``.

    Args:
        C: channel count; must be divisible by ``num_heads`` for the head split.
        num_heads: number of attention heads.
        dropout_prob: attention-weight dropout, applied only in training mode.
    """

    def __init__(self, C: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.proj1 = nn.Linear(C, C * 3)
        self.proj2 = nn.Linear(C, C)
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.proj1(x)
        # Split the 3*C projected features into K=3 (q, k, v) x H heads; the
        # pattern's inner "C" is the per-head width, not the constructor's C.
        x = rearrange(x, "b L (C H K) -> K b H L C", K=3, H=self.num_heads)
        q, k, v = x[0], x[1], x[2]
        # SDPA's keyword is `dropout_p`, and it applies dropout whenever
        # dropout_p > 0 regardless of train/eval mode, so gate it explicitly.
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )
        x = rearrange(x, "b H (h w) C -> b h w (C H)", h=h, w=w)
        x = self.proj2(x)
        return rearrange(x, "b h w C -> b C h w")

### UNET Layer

In [None]:
class UNETLayer(nn.Module):
    """One UNET stage: two residual blocks, an optional self-attention between
    them, and a final stride-2 resampling convolution.

    Args:
        upscale: if True, end with a transposed convolution (kernel 4, stride 2,
            padding 1) that doubles the spatial size and maps C -> C // 2
            channels. Otherwise end with a convolution (kernel 3, stride 2,
            padding 1) that halves the spatial size (rounding up) and maps
            C -> 2 * C channels.
        attention: if True, apply ``Attention`` between the two residual blocks.
        num_groups: GroupNorm group count for the residual blocks.
        dropout_prob: dropout probability for residual blocks and attention.
        num_heads: number of attention heads (used only when ``attention``).
        C: number of input channels.
    """

    def __init__(self,
            upscale: bool,
            attention: bool,
            num_groups: int,
            dropout_prob: float,
            num_heads: int,
            C: int):
        super().__init__()
        self.resblock1 = ResidualBlock(C, num_groups, dropout_prob)
        self.resblock2 = ResidualBlock(C, num_groups, dropout_prob)

        if upscale:
            # torch.nn has ConvTranspose2d (there is no "Conv2dTranspose2d")
            self.conv = nn.ConvTranspose2d(C, C // 2, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(C, C * 2, kernel_size=3, stride=2, padding=1)
        if attention:
            # Registered only when requested; forward checks for its presence.
            self.attention = Attention(C, num_heads, dropout_prob)

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Run both residual blocks (each given ``embeddings``), then resample."""
        x = self.resblock1(x, embeddings)
        if hasattr(self, "attention"):
            x = self.attention(x)
        x = self.resblock2(x, embeddings)
        x = self.conv(x)
        return x