# PyTorch Tutorial: Generative Diffusion Models

Diffusion models (such as those behind Stable Diffusion and DALL-E 2) generate images by learning to reverse a gradual noising process. Generation starts from pure noise and refines it step by step into an image.

In this notebook, we will implement the core mathematics of **DDPM (Denoising Diffusion Probabilistic Models)**.

## Learning Objectives
- Understand the Forward Diffusion Process (Adding Noise)
- Understand the Reverse Diffusion Process (Denoising)
- Implement the Noise Schedule
- Build a heavily simplified, U-Net-inspired network for noise prediction


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(42)

## 1. The Forward Process (Adding Noise)

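We take an image $x_0$ and add a small amount of Gaussian noise at each of $T$ steps, controlled by a variance schedule $\beta_1, \dots, \beta_T$, until the result $x_T$ is approximately pure noise.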

Formula:
$$ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) $$

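Because a chain of Gaussian steps is again Gaussian, we can jump directly to any step $t$. With $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$:
$$ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon $$
where $\epsilon \sim \mathcal{N}(0, I)$.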
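The closed form can be checked numerically. The sketch below iterates the one-step update for the mean coefficient and the noise variance, then compares the result with $\sqrt{\bar{\alpha}_t}$ and $1 - \bar{\alpha}_t$. The names `betas_check`, `mean_coef`, and `var` are illustrative, and the schedule values are assumed to match the linear schedule used in this notebook.

```python
import torch

# Assumed schedule: same endpoints and length as the notebook's linear schedule.
betas_check = torch.linspace(1e-4, 0.02, 200, dtype=torch.float64)
alpha_bar = torch.cumprod(1.0 - betas_check, dim=0)

# Apply q(x_t | x_{t-1}) one step at a time, tracking only the coefficient
# that multiplies x_0 and the variance of the accumulated noise.
mean_coef = torch.tensor(1.0, dtype=torch.float64)
var = torch.tensor(0.0, dtype=torch.float64)
for beta in betas_check:
    mean_coef = torch.sqrt(1.0 - beta) * mean_coef  # mean is scaled by sqrt(1 - beta_t)
    var = (1.0 - beta) * var + beta                 # old noise shrinks, beta_t of new noise is added

# Both comparisons should agree with the closed-form expression.
print(torch.allclose(mean_coef, alpha_bar[-1].sqrt()))
print(torch.allclose(var, 1.0 - alpha_bar[-1]))
```

The variance recursion is the step that makes the shortcut work: if the variance before a step is $1 - \bar{\alpha}_{t-1}$, then after the step it is $(1 - \beta_t)(1 - \bar{\alpha}_{t-1}) + \beta_t = 1 - \bar{\alpha}_t$.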

In [None]:
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

# Define schedule
T = 200
betas = linear_beta_schedule(timesteps=T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def get_index_from_list(vals, t, x_shape):
    """Helper to get value at index t and reshape to match x"""
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """Takes an image batch and timesteps t; returns (noisy image at t, noise used).

    Note: `device` is currently unused; tensors stay on the device of x_0 and t.
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(torch.sqrt(alphas_cumprod), t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(torch.sqrt(1. - alphas_cumprod), t, x_0.shape)

    # Closed form: mean (sqrt(alpha_bar_t) * x_0) + standard deviation (sqrt(1 - alpha_bar_t)) * noise
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise

# Visualize
image = torch.zeros((1, 3, 64, 64)) # Dummy black image
image[:, :, 16:48, 16:48] = 1.0 # White square in middle

plt.figure(figsize=(10, 3))
for idx, t_val in enumerate([0, 50, 100, 199]):
    t = torch.tensor([t_val])
    noisy_image, _ = forward_diffusion_sample(image, t)
    
    plt.subplot(1, 4, idx+1)
    plt.imshow(noisy_image[0].permute(1, 2, 0).clamp(0, 1))
    plt.title(f"t={t_val}")
    plt.axis('off')
plt.show()
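How close $x_T$ gets to pure noise depends on $\bar{\alpha}_T$, the fraction of the original signal variance that survives the last step. The sketch below prints it for two schedule lengths with the same $\beta$ endpoints. The helper name `final_alpha_bar` is illustrative. With $T = 200$ the value should come out on the order of 0.1, so the final image still carries some signal, whereas with $T = 1000$ (the length used in the original DDPM setup with these endpoints) it is close to zero.

```python
import torch

def final_alpha_bar(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return alpha_bar at the last step of a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, timesteps)
    return torch.cumprod(1.0 - betas, dim=0)[-1].item()

for steps in (200, 1000):
    a_bar = final_alpha_bar(steps)
    # sqrt(alpha_bar_T) is the weight on x_0 in the closed-form expression for x_T.
    print(f"T={steps}: alpha_bar_T={a_bar:.5f}, weight on x_0={a_bar ** 0.5:.5f}")
```

If the weight on $x_0$ at the last step is not negligible, sampling from $\mathcal{N}(0, I)$ at generation time does not quite match what the model saw during training; a longer schedule or a larger `beta_end` would be one way to reduce the gap.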

## 2. The Reverse Process (The Model)

We need a neural network that takes a noisy image $x_t$ and the timestep $t$, and predicts the noise $\epsilon$ that was added.

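In practice this network is typically a **U-Net**. The model below is a heavily simplified stand-in: it keeps the spatial resolution fixed and has no skip connections.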

In [None]:
class SimpleUNet(nn.Module):
    """A very simplified U-Net for demonstration"""
    def __init__(self):
        super().__init__()
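        # Encoder convolutions (stride 1, so the spatial resolution is unchanged)
        self.down1 = nn.Conv2d(3, 64, 3, padding=1)
        self.down2 = nn.Conv2d(64, 128, 3, padding=1)

        # Time embedding (simplified): maps the raw timestep to one value per channel
        self.time_mlp = nn.Linear(1, 128)

        # Decoder transposed convolutions (stride 1, so no actual upsampling happens)
        self.up1 = nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.up2 = nn.ConvTranspose2d(64, 3, 3, padding=1)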

    def forward(self, x, t):
        # Embed time
        t = t.float().view(-1, 1)
        t_emb = self.time_mlp(t).view(-1, 128, 1, 1)
        
        # Down
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        
        # Add time info
        x2 = x2 + t_emb
        
        # Up
        x = F.relu(self.up1(x2))
        x = self.up2(x)
        return x

model = SimpleUNet()
print("Model created!")
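The simplified time embedding above feeds the raw integer timestep into a single linear layer. DDPM-style implementations commonly use a sinusoidal embedding instead, so that nearby timesteps get similar, bounded feature vectors. The following is a minimal sketch of that idea, not part of the model above; the class name `SinusoidalTimeEmbedding` and the base constant 10000 are assumptions borrowed from the usual Transformer positional encoding.

```python
import math
import torch
import torch.nn as nn

class SinusoidalTimeEmbedding(nn.Module):
    """Hypothetical helper: maps integer timesteps of shape (B,) to features of shape (B, dim)."""
    def __init__(self, dim):
        super().__init__()
        assert dim % 2 == 0, "dim must be even (half sine, half cosine)"
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to roughly 1/10000.
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        args = t.float()[:, None] * freqs[None, :]          # (B, half)
        return torch.cat([args.sin(), args.cos()], dim=-1)  # (B, dim)

emb = SinusoidalTimeEmbedding(128)(torch.tensor([0, 10, 199]))
print(emb.shape)
```

One way to use it would be to replace `self.time_mlp` with `nn.Sequential(SinusoidalTimeEmbedding(128), nn.Linear(128, 128))` and pass `t` to it without the `view(-1, 1)` reshape.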

## 3. Training Loop

1. Sample an image $x_0$ from the training data.
2. Sample a random timestep $t$ uniformly from the $T$ available steps.
3. Add noise to get $x_t$.
4. Model predicts the noise: $\hat{\epsilon} = \text{Model}(x_t, t)$.
5. Loss is MSE between real noise $\epsilon$ and predicted noise $\hat{\epsilon}$.
6. Backpropagate the loss and update the model's weights.

In [None]:
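# Single training step on dummy data
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 1. Get data
x_0 = torch.randn(4, 3, 64, 64) # Random stand-in for a batch of 4 images
t = torch.randint(0, T, (4,)) # One random timestep per image

# 2. Forward diffusion
x_t, noise = forward_diffusion_sample(x_0, t)

# 3. Predict noise
noise_pred = model(x_t, t)

# 4. Loss (prediction first, target second)
loss = F.mse_loss(noise_pred, noise)

# 5. Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Loss: {loss.item():.4f}")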
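In real training the steps listed above are repeated many times, with a fresh timestep and fresh noise drawn on every iteration. The sketch below shows a short loop on a toy batch. It is self-contained, so it uses a small stand-in network (`TinyNoisePredictor`) and a compact forward-diffusion helper (`q_sample`); both names are hypothetical and are not the notebook's own model or helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

T = 200
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in model, kept small so the loop runs quickly."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.time_mlp = nn.Linear(1, 16)
        self.conv2 = nn.Conv2d(16, 3, 3, padding=1)

    def forward(self, x, t):
        # Scale t to [0, 1) before embedding (an assumption of this sketch).
        t_emb = self.time_mlp(t.float().view(-1, 1) / T).view(-1, 16, 1, 1)
        h = F.relu(self.conv1(x)) + t_emb
        return self.conv2(h)

def q_sample(x_0, t):
    """Closed-form forward diffusion: returns (x_t, noise)."""
    noise = torch.randn_like(x_0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise, noise

toy_model = TinyNoisePredictor()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)

# Toy data: four copies of a white square on a black background.
toy_batch = torch.zeros(4, 3, 16, 16)
toy_batch[:, :, 4:12, 4:12] = 1.0

losses = []
for step in range(5):
    t = torch.randint(0, T, (toy_batch.shape[0],))  # new timesteps each iteration
    x_t, noise = q_sample(toy_batch, t)             # new noise each iteration
    loss = F.mse_loss(toy_model(x_t, t), noise)
    toy_optimizer.zero_grad()
    loss.backward()
    toy_optimizer.step()
    losses.append(loss.item())

print(losses)
```

Five steps are far too few to learn anything useful; the loop only illustrates the order of operations. Because the timestep and noise are resampled each time, the loss is noisy from one iteration to the next even when training is going well.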

## 4. Sampling (Generation)

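To generate an image:
1. Start with pure noise $x_T \sim \mathcal{N}(0, I)$.
2. Loop backwards through the timesteps, from the last step down to the first.
3. At each step, use the model's noise prediction to estimate the slightly less noisy $x_{t-1}$.

*(Each update subtracts a scaled copy of the predicted noise from $x_t$ and, at every step except the final one, adds a small amount of fresh Gaussian noise, because each reverse step is itself a sample from a Gaussian rather than a deterministic update.)*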
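A minimal sketch of the sampling loop, following the standard DDPM update $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \hat{\epsilon} \right) + \sigma_t z$. It assumes the same linear schedule as above and uses the posterior variance for $\sigma_t^2$, which is one of the two common choices (the other is $\sigma_t^2 = \beta_t$). The names `sample` and `dummy_model` are illustrative, and `dummy_model` is an untrained placeholder, so the output is not a meaningful image.

```python
import torch
import torch.nn.functional as F

T = 200
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar at the previous step, with alpha_bar = 1 before the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

@torch.no_grad()
def sample(model, shape):
    """Run the reverse process from pure noise; model(x, t) must return predicted noise."""
    x = torch.randn(shape)  # x_T
    for i in reversed(range(T)):
        t = torch.full((shape[0],), i, dtype=torch.long)
        eps_hat = model(x, t)
        # Mean of p(x_{t-1} | x_t): remove the scaled predicted noise, then rescale.
        mean = (x - betas[i] / torch.sqrt(1.0 - alphas_cumprod[i]) * eps_hat) / torch.sqrt(alphas[i])
        if i > 0:
            x = mean + torch.sqrt(posterior_variance[i]) * torch.randn_like(x)
        else:
            x = mean  # no noise is added at the final step
    return x

def dummy_model(x, t):
    """Placeholder noise predictor (assumption): always predicts zero noise."""
    return torch.zeros_like(x)

samples = sample(dummy_model, (2, 3, 8, 8))
print(samples.shape)
```

Any module with the same `(x, t)` call signature as the model in Section 2 could be passed as `model` in place of the placeholder once it has been trained.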

## Key Takeaways

1. **Diffusion**: A process of slowly destroying data with noise, then learning to reverse it.
2. **Forward Process**: Fixed mathematical formula (no learning).
3. **Reverse Process**: Learned by a U-Net.
4. **Objective**: Predict the noise that was added.