# Denoising Diffusion Probabilistic Models (DDPM) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/diffusion_ddpm.ipynb)

## 1. Mathematical Foundations

Diffusion models are latent variable models of the form $p_\theta(x_0) := \int p_\theta(x_{0:T}) dx_{1:T}$, where $x_1, \dots, x_T$ are latents of the same dimensionality as the data $x_0$.

### Forward Process ($q$)
The forward process (diffusion) is a fixed Markov chain that gradually adds Gaussian noise to the data according to a variance schedule $\beta_1, \dots, \beta_T$:

$$ q(x_{1:T}|x_0) := \prod_{t=1}^T q(x_t|x_{t-1}) $$

$$ q(x_t|x_{t-1}) := \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}) $$

A nice property is that we can sample $x_t$ at any arbitrary time step $t$ in closed form. Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$:

$$ q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t)\mathbf{I}) $$

### Reverse Process ($p_\theta$)
The reverse process is a learned Markov chain defined as:

$$ p_\theta(x_{0:T}) := p(x_T) \prod_{t=1}^T p_\theta(x_{t-1}|x_t) $$

$$ p_\theta(x_{t-1}|x_t) := \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) $$

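We fix the reverse-process covariance to $\Sigma_\theta(x_t, t) = \beta_t \mathbf{I}$ and, rather than regressing the posterior mean $\mu_\theta$ directly, we train a network $\epsilon_\theta(x_t, t)$ to predict the noise that was added to $x_0$. The mean is then recovered from the predicted noise:

$$ \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) $$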

### Training Objective (Simplified ELBO)
Ho et al. (2020) showed that the variational lower bound (ELBO) reduces to a weighted sum of squared errors between the actual noise $\epsilon$ and the predicted noise $\epsilon_\theta$, and that dropping the per-timestep weights yields a simpler objective that works well in practice:

$$ L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, t) \|^2 \right] $$

In [None]:
!pip install torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Noise Schedule & Forward Process Implementation

In [None]:
class Diffusion:
    """Linear beta schedule together with the closed-form forward process q(x_t | x_0).

    Attributes:
        T: Number of entries in the schedule.
        img_size: Stored as given; not used by any method of this class.
        device: Device that holds the schedule tensors and sampled timesteps.
        beta: `T` values spaced linearly from `beta_start` to `beta_end`.
        alpha: `1 - beta`.
        alpha_hat: Cumulative product of `alpha` along the schedule.
    """

    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, img_size=28, device=device):
        self.T, self.img_size, self.device = T, img_size, device

        schedule = torch.linspace(beta_start, beta_end, T).to(device)
        self.beta = schedule
        self.alpha = 1. - schedule
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_images(self, x, t):
        """Noise `x` to the levels selected by `t` and return `(x_t, epsilon)`.

        `t` is a 1-D tensor of indices into the schedule, one per batch item.
        The per-item scales are reshaped to (B, 1, 1, 1) so they broadcast
        over a 4-D image batch.
        """
        alpha_hat_t = self.alpha_hat[t]
        signal_scale = alpha_hat_t.sqrt()[:, None, None, None]
        noise_scale = (1 - alpha_hat_t).sqrt()[:, None, None, None]

        epsilon = torch.randn_like(x)
        x_t = signal_scale * x + noise_scale * epsilon
        return x_t, epsilon

    def sample_timesteps(self, n):
        """Return `n` integer schedule indices drawn uniformly from [1, T - 1]."""
        shape = (n,)
        return torch.randint(low=1, high=self.T, size=shape, device=self.device)

### Visualization: Forward Process

In [None]:
def plot_forward_process(diffusion):
    """Show one MNIST digit at ten evenly spaced noise levels of q(x_t | x_0).

    Args:
        diffusion: Object exposing `T`, `device` and `noise_images(x, t)`
            (for example a `Diffusion` instance).
    """
    # Use the diffusion object's own device so the image and the index tensor
    # `t` live where its schedule tensors do, independent of the
    # notebook-global `device`.
    target_device = diffusion.device

    # Load one image
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
    x0, _ = dataset[0]
    x0 = x0.unsqueeze(0).to(target_device)  # add a batch dimension

    plt.figure(figsize=(15, 3))
    num_steps = 10
    step_size = diffusion.T // num_steps

    for i in range(num_steps):
        # noise_images indexes the schedule with a 1-D tensor, hence the list.
        t = torch.tensor([i * step_size], device=target_device)
        xt, _ = diffusion.noise_images(x0, t)

        plt.subplot(1, num_steps, i + 1)
        plt.imshow(xt.cpu().squeeze(), cmap='gray')
        plt.title(f"t={t.item()}")
        plt.axis('off')
    plt.suptitle("Forward Diffusion Process: $q(x_t|x_0)$", fontsize=16)
    plt.show()


# Shared schedule object; the training and sampling cells reuse it.
curr_diffusion = Diffusion(device=device)
plot_forward_process(curr_diffusion)

## 3. Simple U-Net (Noise Predictor)

In [None]:
class SimpleUNet(nn.Module):
    """Small two-level U-Net that predicts the noise in a 1-channel image.

    The encoder halves the spatial size twice (1 -> 32 -> 64 channels), a
    time embedding is added at the bottleneck, and the decoder upsamples
    twice with one skip connection from the first encoder stage. Height and
    width must be divisible by 4 for the skip concatenation and the output
    to match the input size (28x28 MNIST satisfies this).

    Args:
        max_timestep: Divisor applied to the integer timesteps before they
            enter the embedding MLP, so the MLP sees values in roughly
            [0, 1] instead of raw indices in the hundreds. Should match the
            schedule length; the default equals the default `T` of
            `Diffusion`.
    """

    def __init__(self, max_timestep=1000):
        super().__init__()
        self.max_timestep = max_timestep

        # Simplified U-Net for MNIST
        self.down1 = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.down2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))

        self.time_embed = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        # Input is the 64 upsampled bottleneck channels concatenated with the
        # 32 skip channels produced by down1, i.e. 96 channels in total.
        self.up_conv1 = nn.Sequential(nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU())
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())

        self.out = nn.Conv2d(32, 1, 3, padding=1)

    def forward(self, x, t):
        """Predict the noise in `x` at timesteps `t`.

        Args:
            x: Image batch of shape (B, 1, H, W).
            t: Timestep per batch item; reshaped to (B, 1) internally.

        Returns:
            Tensor with the same shape as `x`.
        """
        # Time embedding: rescale the step index, then broadcast over H and W.
        t = t.float().view(-1, 1) / self.max_timestep
        t_emb = self.time_embed(t)[:, :, None, None]

        # Down
        x1 = self.down1(x)   # (B, 32, H/2, W/2)
        x2 = self.down2(x1)  # (B, 64, H/4, W/4)

        # Inject time
        x2 = x2 + t_emb

        # Up
        x = self.up1(x2)               # (B, 64, H/2, W/2)
        x = torch.cat([x, x1], dim=1)  # (B, 96, H/2, W/2)
        x = self.up_conv1(x)

        x = self.up2(x)
        x = self.up_conv2(x)

        return self.out(x)

## 4. Training with Loss Visualization

In [None]:
# MNIST training split as tensors, batched and reshuffled every epoch.
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = SimpleUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

loss_history = []  # mean batch loss of each epoch, in order

epochs = 3
print("Starting Training...")
for epoch in range(epochs):
    epoch_loss = 0
    for images, _ in dataloader:
        images = images.to(device)

        # Draw one random timestep per image and noise the batch to that level.
        t = curr_diffusion.sample_timesteps(images.shape[0])
        x_t, noise = curr_diffusion.noise_images(images, t)

        # L_simple: MSE between the injected noise and the model's prediction.
        predicted_noise = model(x_t, t)
        loss = loss_fn(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss/len(dataloader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch}: Loss {avg_loss:.4f}")

# Plot Loss
plt.figure(figsize=(8, 4))
plt.plot(loss_history, marker='o')
# Raw string: otherwise "\t" in "\theta" becomes a TAB and "\e" is an invalid
# escape, which breaks the mathtext in the title.
plt.title(r"Training Loss (MSE between $\epsilon$ and $\epsilon_\theta$)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

## 5. Sampling (Reverse Process)

In [None]:
def sample(model, n, diffusion):
    """Generate `n` images by running the learned reverse process from noise.

    Starts from standard Gaussian noise and applies the ancestral update
    x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_hat_t) * eps) / sqrt(alpha_t)
              + sqrt(beta_t) * z
    for schedule indices T-1 down to 1, with z = 0 on the last update.

    Args:
        model: Noise predictor called as `model(x, t)`; must live on
            `diffusion.device`.
        n: Number of images to generate.
        diffusion: Object exposing `T`, `img_size`, `device`, `alpha`,
            `alpha_hat` and `beta` (for example a `Diffusion` instance).

    Returns:
        Tensor of shape (n, 1, img_size, img_size) clamped to [0, 1].
    """
    run_device = diffusion.device
    size = diffusion.img_size

    # Remember the caller's mode so an already-eval model is not flipped back
    # into training mode as a side effect of sampling.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn((n, 1, size, size), device=run_device)
            for i in reversed(range(1, diffusion.T)):
                t = torch.full((n,), i, dtype=torch.long, device=run_device)
                predicted_noise = model(x, t)

                # Every batch item shares the same step, so scalar schedule
                # entries are enough; they broadcast over the whole batch.
                alpha = diffusion.alpha[i]
                alpha_hat = diffusion.alpha_hat[i]
                beta = diffusion.beta[i]

                # No fresh noise on the final update.
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                mean = (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha)
                x = mean + torch.sqrt(beta) * noise
    finally:
        model.train(was_training)
    return x.clamp(0, 1)

# Draw eight digits from the trained model and lay them out in a single row.
generated_images = sample(model, 8, curr_diffusion)

plt.figure(figsize=(12, 2))
for column, image in enumerate(generated_images, start=1):
    plt.subplot(1, 8, column)
    plt.imshow(image.cpu().squeeze(), cmap='gray')
    plt.axis('off')
plt.suptitle("Generated Samples via Reverse Diffusion", fontsize=16)
plt.show()