# Denoising Diffusion Probabilistic Models (DDPM) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/diffusion_ddpm.ipynb)

Diffusion models generate data by reversing a gradual noising process.

Key Concepts:
1. **Forward Process ($q$):** Gradually add Gaussian noise to an image until it becomes pure noise.
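2. **Reverse Process ($p_\theta$):** Train a neural network $\epsilon_\theta$ (here a small U-Net) to predict the Gaussian noise $\epsilon$ contained in a noisy image $x_t$ at step $t$. Starting from pure noise and removing the predicted noise one step at a time yields a new image.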

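$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z $$

Here $\alpha_t = 1 - \beta_t$ for a fixed noise schedule $\beta_t$, $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, and $z \sim \mathcal{N}(0, I)$ (with $z = 0$ on the final step). This notebook uses $\sigma_t = \sqrt{\beta_t}$.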

In [None]:
!pip install torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_accelerator)
print(f'Using device: {device}')

## 1. Noise Schedule & Forward Process

In [None]:
class Diffusion:
    """Linear-beta noise schedule and closed-form forward process q(x_t | x_0).

    Schedule tensors have length ``T`` and are indexed 0..T-1. Index 0 already
    includes one noising step (``alpha_hat[0] == 1 - beta_start``).

    Args:
        T: number of diffusion steps.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        img_size: side length of the square images; stored for samplers.
        device: device for the schedule tensors and sampled timesteps.
            ``None`` (default) picks CUDA when available, otherwise CPU.
    """

    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02, img_size=28, device=None):
        if device is None:
            # Resolve here instead of binding a module-level global at class-definition time.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.T = T
        self.img_size = img_size
        self.device = torch.device(device)

        self.beta = torch.linspace(beta_start, beta_end, T).to(self.device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_images(self, x, t):
        """Noise ``x`` to schedule index ``t`` in one shot.

        Args:
            x: image batch; its first dim must match ``t`` and the per-sample
                coefficients are broadcast as (B, 1, 1, 1), e.g. x is (B, C, H, W).
            t: integer tensor of schedule indices in [0, T), one per sample.

        Returns:
            ``(x_t, epsilon)``: the noised batch
            ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * epsilon`` and
            the standard-normal noise ``epsilon`` that was used.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        """Draw ``n`` schedule indices uniformly from [0, T).

        Index 0 is included so that the final reverse step, which maps the
        least-noisy level to x_0, is also trained.
        """
        return torch.randint(low=0, high=self.T, size=(n,), device=self.device)

## 2. Simple U-Net (Noise Predictor)

In [None]:
class SimpleUNet(nn.Module):
    """Minimal two-level U-Net noise predictor eps_theta(x_t, t) for 1-channel images.

    For 28x28 inputs the resolutions are 28 -> 14 -> 7 on the way down and
    7 -> 14 -> 28 on the way up. There is one skip connection at 14x14. The
    timestep is injected once, by adding its embedding to the 7x7 bottleneck.
    Height and width must be divisible by 4 so the skip tensors line up.

    Args:
        time_dim: width of the sinusoidal timestep features fed to the time MLP.
    """

    def __init__(self, time_dim=64):
        super().__init__()
        self.time_dim = time_dim

        # Encoder: each stage halves the spatial size.
        self.down1 = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.down2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))

        # Maps timestep features to 64 values so they broadcast-add onto down2's 64 channels.
        self.time_embed = nn.Sequential(nn.Linear(time_dim, 64), nn.ReLU(), nn.Linear(64, 64))

        # Decoder. up_conv1 sees 64 upsampled bottleneck channels + 32 skip channels from down1.
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_conv1 = nn.Sequential(nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU())
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())

        self.out = nn.Conv2d(32, 1, 3, padding=1)

    @staticmethod
    def _timestep_features(t, dim):
        """Return sinusoidal features of shape (B, dim) for the timesteps in ``t``.

        Feeding raw step indices (up to T-1) into a Linear layer gives badly
        scaled activations. Bounded sin/cos features at geometric frequencies
        keep every timestep in [-1, 1].
        """
        t = t.reshape(-1).float()
        half = dim // 2
        exponents = torch.arange(half, device=t.device, dtype=torch.float32) / max(half, 1)
        freqs = 10000.0 ** (-exponents)
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if dim % 2:
            # Odd widths: zero-pad the last feature so the output is exactly `dim` wide.
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x, t):
        """Predict the noise in ``x`` (B, 1, H, W) at integer timesteps ``t`` (B,)."""
        t_emb = self.time_embed(self._timestep_features(t, self.time_dim))[:, :, None, None]

        # Down
        x1 = self.down1(x)   # (B, 32, H/2, W/2), kept for the skip connection
        x2 = self.down2(x1)  # (B, 64, H/4, W/4)

        # Inject time at the bottleneck
        x2 = x2 + t_emb

        # Up: upsample back to x1's resolution and concatenate the skip.
        h = self.up1(x2)
        h = torch.cat([h, x1], dim=1)
        h = self.up_conv1(h)

        h = self.up2(h)
        h = self.up_conv2(h)

        return self.out(h)

## 3. Training Loop

In [None]:
# Data: MNIST training split, converted to float tensors by ToTensor().
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = SimpleUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
curr_diffusion = Diffusion(device=device)
loss_fn = nn.MSELoss()

num_epochs = 3  # Reduced epochs for demo
for epoch in range(num_epochs):
    running_loss = 0
    for batch, _labels in dataloader:
        batch = batch.to(device)

        # Pick a random timestep per image and jump straight to that noise level.
        timesteps = curr_diffusion.sample_timesteps(batch.shape[0])
        noisy, target_noise = curr_diffusion.noise_images(batch, timesteps)

        # Objective: mean-squared error between the true and the predicted noise.
        loss = loss_fn(target_noise, model(noisy, timesteps))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    mean_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch}: Loss {mean_loss:.4f}")

## 4. Sampling (Generation)

In [None]:
def sample(model, n, diffusion, channels=1):
    """Generate ``n`` images by running the full DDPM reverse chain.

    Starts from standard Gaussian noise and applies the ancestral update
    ``x <- (x - (1 - alpha_t) / sqrt(1 - alpha_hat_t) * eps_theta(x, t)) / sqrt(alpha_t)
    + sqrt(beta_t) * z`` for every schedule index T-1, ..., 0. No noise is added
    on the last step (index 0), so the result is the model's estimate of x_0.

    Args:
        model: noise predictor, called as ``model(x_t, t)`` with ``t`` a (n,) long tensor.
        n: number of images to generate.
        diffusion: schedule object exposing ``T``, ``alpha``, ``alpha_hat``,
            ``beta``, ``img_size`` and ``device`` (e.g. ``Diffusion``).
        channels: number of image channels (1 for MNIST).

    Returns:
        Tensor of shape (n, channels, img_size, img_size), clamped to [0, 1].
        The model's train/eval mode is restored to what it was before the call.
    """
    dev = diffusion.device
    size = diffusion.img_size
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn((n, channels, size, size), device=dev)
            # Include index 0: it is the step that maps the least-noisy level to x_0.
            for i in reversed(range(diffusion.T)):
                t = torch.full((n,), i, dtype=torch.long, device=dev)
                predicted_noise = model(x, t)

                alpha = diffusion.alpha[t][:, None, None, None]
                alpha_hat = diffusion.alpha_hat[t][:, None, None, None]
                beta = diffusion.beta[t][:, None, None, None]

                # Fresh noise on every step except the final one.
                noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)

                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
    finally:
        model.train(was_training)
    return x.clamp(0, 1)

# Draw a row of 8 generated digits.
n_show = 8
generated_images = sample(model, n_show, curr_diffusion)

fig, ax = plt.subplots(1, n_show, figsize=(12, 2))
for axis, img in zip(ax, generated_images):
    axis.imshow(img.cpu().squeeze(), cmap='gray')
    axis.axis('off')
plt.show()