# Lab 2.6.1: Diffusion Model Theory - From Noise to Art

**Module:** 2.6 - Diffusion Models  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand the mathematical foundation of diffusion models
- [ ] Implement forward diffusion (adding noise) from scratch
- [ ] Visualize noise schedules and their effects
- [ ] Build a simple U-Net for denoising
- [ ] Implement reverse diffusion to generate new images
- [ ] Train a DDPM on MNIST and generate digits!

---

## 📚 Prerequisites

- Completed: Module 2.5 (Hugging Face Ecosystem)
- Knowledge of: PyTorch basics, neural network fundamentals
- Math comfort: Basic probability, normal distributions
- **Required packages:**
  - `torch>=2.0.0`
  - `torchvision>=0.15.0`
  - `matplotlib>=3.7.0`
  - `tqdm`

---

## 🌍 Real-World Context

**Diffusion models power the AI art revolution!**

- **Midjourney** uses diffusion to create stunning artwork
- **DALL-E 3** generates images from text descriptions
- **Stable Diffusion** is the open-source king of image generation
- **Sora** uses diffusion principles for video generation

Understanding how diffusion works gives you the foundation to use, customize, and even build these systems.

---

## 🧒 ELI5: What is Diffusion?

> **Imagine you have a beautiful photo of a cat.** 🐱
>
> Now imagine slowly adding TV static to it - a tiny bit at first, then more and more.
> After 1000 steps of adding static, your cat photo is completely unrecognizable -
> it's just pure noise, like a TV with no signal.
>
> **That's the "forward" process - destroying the image with noise.**
>
> Now here's the magic: What if we could train a neural network to **reverse this process**?
> 
> If we show the AI many examples of "noisy image at step 500" → "slightly less noisy image at step 499",
> it learns to remove noise one step at a time.
>
> **The "reverse" process - starting from pure noise and gradually revealing an image!**
>
> The amazing part? Once trained, we can start with *random* noise and the model will
> denoise it into a *new* image that looks like the training data - a brand new cat! 🐱✨

### The Key Insight

```
Forward:  Real Image  ───[add noise]───>  Pure Noise    (easy, just math)
                                              │
Reverse:  New Image   <───[remove noise]────  │          (learned by neural network)
```

---

## Part 1: Setting Up Our Environment

Let's import everything we need and check our DGX Spark GPU.

In [None]:
# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import math

# Data loading and image-grid utilities (used in Parts 3, 5 and 6)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Total Memory: {total_mem:.1f} GB")
    print("\nDGX Spark's 128GB unified memory = room for all experiments! 🚀")

    # bfloat16 suits DGX Spark's Blackwell GPU. Note: the cells below train and
    # sample in float32; `dtype` is only reported here, not applied to the model.
    dtype = torch.bfloat16
else:
    dtype = torch.float32

print(f"\nUsing dtype: {dtype}")

---

## Part 2: Understanding the Forward Process

### The Math Behind Adding Noise

In diffusion models, we add Gaussian noise to images according to a **schedule**.

At each timestep $t$, we have:

$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

Where:
- $x_0$ = Original clean image
- $x_t$ = Noisy image at timestep $t$  
- $\epsilon$ = Random Gaussian noise (same shape as image)
- $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ = Cumulative product of $\alpha_i = 1 - \beta_i$ (the fraction of signal variance kept at step $i$)
- $\beta_t$ = Noise schedule: the variance of the Gaussian noise added at step $t$
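Why can we jump straight to any $x_t$ in one shot? In DDPM, each single step is $x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon_t$, and stacking Gaussian steps gives another Gaussian. The sketch below is plain Python (no PyTorch), and its helper names are illustrative only. It tracks the exact mean and variance of one pixel through the step-by-step process and compares them with $\sqrt{\bar{\alpha}_t}\,x_0$ and $1-\bar{\alpha}_t$. It assumes the linear schedule defaults from `NoiseScheduler` (β from 1e-4 to 0.02 over 1000 steps) and the same 0-based indexing as `alphas_cumprod[t]`.

```python
import math

def linear_betas(T=1000, start=1e-4, end=0.02):
    # Same values as torch.linspace(start, end, T) in NoiseScheduler
    return [start + (end - start) * i / (T - 1) for i in range(T)]

def step_by_step_moments(x0, betas, t):
    """Exact mean and variance of x_t after applying noising steps 0..t one at a time."""
    mean, var = x0, 0.0  # x_0 is a fixed pixel value, so it starts with zero variance
    for beta in betas[: t + 1]:
        mean = math.sqrt(1.0 - beta) * mean  # the signal is scaled down...
        var = (1.0 - beta) * var + beta      # ...old noise shrinks and fresh noise is added
    return mean, var

def closed_form_moments(x0, betas, t):
    """Mean and variance implied by x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    alpha_bar = math.prod(1.0 - b for b in betas[: t + 1])
    return math.sqrt(alpha_bar) * x0, 1.0 - alpha_bar

betas = linear_betas()
for t in (0, 250, 999):
    print(t, step_by_step_moments(0.8, betas, t), closed_form_moments(0.8, betas, t))
```

Both columns agree at every timestep, which is why training can sample any $t$ directly instead of simulating the chain.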

### 🧒 ELI5: The Noise Schedule

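> Think of $\beta_t$ as a "noise dial" that ranges from 0 to 1:
> - $\beta = 0$: No noise added (pure signal)
> - $\beta = 1$: All noise (no signal)
>
> The schedule controls how fast the remaining signal, $\bar{\alpha}_t$, drains away. With a
> **linear** schedule, $\beta_t$ grows in equal increments and $\bar{\alpha}_t$ collapses toward
> zero well before the last step. With a **cosine** schedule, $\bar{\alpha}_t$ falls slowly at
> first, roughly linearly in the middle, then slowly again near the end - like an S-curve.
>
> Cosine usually works better because it preserves image structure for more of the process,
> so fewer steps are spent on inputs that are already almost pure noise.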
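To put numbers on the "S-curve" idea, here is a plain-Python sketch (illustrative helper names, not part of the lab's code) that builds $\bar{\alpha}_t$ for both schedules. The cosine part uses the formula from `_cosine_schedule`, with β capped at 0.999 as in the Improved DDPM paper. It also finds the first timestep where the signal-to-noise ratio $\bar{\alpha}_t / (1-\bar{\alpha}_t)$ drops below 1, meaning that noise now contributes more variance than the image does.

```python
import math

def cumulative_alpha_bar(betas):
    """alpha_bar[t] = prod_{i<=t} (1 - beta_i), the same indexing as torch.cumprod."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def linear_alpha_bar(T=1000, start=1e-4, end=0.02):
    betas = [start + (end - start) * i / (T - 1) for i in range(T)]
    return cumulative_alpha_bar(betas)

def cosine_alpha_bar(T=1000, s=0.008, max_beta=0.999):
    # f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2); beta_t = 1 - f(t + 1) / f(t), clipped
    f = [math.cos(((t / T) + s) / (1 + s) * math.pi / 2) ** 2 for t in range(T + 1)]
    betas = [min(max(1.0 - f[t + 1] / f[t], 1e-4), max_beta) for t in range(T)]
    return cumulative_alpha_bar(betas)

def first_step_below_half(alpha_bar):
    """First timestep where noise variance exceeds signal variance (SNR < 1)."""
    return next(t for t, a in enumerate(alpha_bar) if a < 0.5)

lin, cos_ = linear_alpha_bar(), cosine_alpha_bar()
print(f"alpha_bar at t=500: linear {lin[500]:.3f}, cosine {cos_[500]:.3f}")
print("SNR < 1 from t =", first_step_below_half(lin), "(linear) vs",
      first_step_below_half(cos_), "(cosine)")
```

The linear schedule crosses the half-signal point far earlier than the cosine one. This is the same effect the SNR plot below shows on a log scale.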

In [None]:
class NoiseScheduler:
    """
    Manages the noise schedule for diffusion models.
    
    This class handles:
    - Computing beta (noise variance) at each timestep
    - Computing alpha (signal preservation) at each timestep
    - Adding noise to images (forward process)
    - Precomputing the posterior variance used when sampling
    """
    
    def __init__(
        self, 
        num_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        schedule_type: str = "cosine"
    ):
        self.num_timesteps = num_timesteps
        
        # Compute beta schedule
        if schedule_type == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        elif schedule_type == "cosine":
            # Cosine schedule from "Improved DDPM" paper
            self.betas = self._cosine_schedule(num_timesteps)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")
        
        # Compute derived values
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        
        # For sampling (reverse process)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        
        # Posterior variance for sampling
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
    
    def _cosine_schedule(self, num_timesteps: int, s: float = 0.008):
        """Cosine schedule as in 'Improved DDPM' paper."""
        steps = num_timesteps + 1
        x = torch.linspace(0, num_timesteps, steps)
        alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.999)  # the paper caps beta_t at 0.999 to avoid a singularity near t = T
    
    def add_noise(self, x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None):
        """
        Add noise to images according to the forward process.
        
        x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise
        
        Args:
            x_0: Clean images, shape (B, C, H, W)
            t: Timesteps, shape (B,)
            noise: Optional pre-generated noise
            
        Returns:
            Noisy images x_t and the noise that was added
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Get coefficients for each sample in batch
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1).to(x_0.device)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1).to(x_0.device)
        
        # Forward diffusion
        x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
        
        return x_t, noise
    
    def to(self, device):
        """Move all tensors to specified device."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        return self


# Create scheduler
scheduler = NoiseScheduler(num_timesteps=1000, schedule_type="cosine")
print(f"Created noise scheduler with {scheduler.num_timesteps} timesteps")
print(f"Beta range: {scheduler.betas[0]:.6f} to {scheduler.betas[-1]:.6f}")
print(f"Alpha cumprod range: {scheduler.alphas_cumprod[-1]:.6f} to {scheduler.alphas_cumprod[0]:.6f}")

### Visualizing the Noise Schedules

Let's compare linear vs cosine schedules to understand why cosine is preferred.

In [None]:
# Create both schedulers for comparison
linear_scheduler = NoiseScheduler(num_timesteps=1000, schedule_type="linear")
cosine_scheduler = NoiseScheduler(num_timesteps=1000, schedule_type="cosine")

# Plot comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

timesteps = np.arange(1000)

# Plot 1: Beta values (noise added at each step)
axes[0].plot(timesteps, linear_scheduler.betas.numpy(), label='Linear', alpha=0.8)
axes[0].plot(timesteps, cosine_scheduler.betas.numpy(), label='Cosine', alpha=0.8)
axes[0].set_xlabel('Timestep')
axes[0].set_ylabel('Beta (noise variance)')
axes[0].set_title('Noise Added Per Step')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Alpha cumprod (signal remaining)
axes[1].plot(timesteps, linear_scheduler.alphas_cumprod.numpy(), label='Linear', alpha=0.8)
axes[1].plot(timesteps, cosine_scheduler.alphas_cumprod.numpy(), label='Cosine', alpha=0.8)
axes[1].set_xlabel('Timestep')
axes[1].set_ylabel('Alpha Cumprod (signal remaining)')
axes[1].set_title('Signal Preservation Over Time')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: SNR (Signal-to-Noise Ratio)
linear_snr = linear_scheduler.alphas_cumprod / (1 - linear_scheduler.alphas_cumprod + 1e-8)
cosine_snr = cosine_scheduler.alphas_cumprod / (1 - cosine_scheduler.alphas_cumprod + 1e-8)
axes[2].semilogy(timesteps, linear_snr.numpy(), label='Linear', alpha=0.8)
axes[2].semilogy(timesteps, cosine_snr.numpy(), label='Cosine', alpha=0.8)
axes[2].set_xlabel('Timestep')
axes[2].set_ylabel('SNR (log scale)')
axes[2].set_title('Signal-to-Noise Ratio')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📊 Key Observations:")
print("- Cosine schedule keeps alpha_cumprod high for longer, so more steps still show a recognizable image")
print("- Linear schedule drives alpha_cumprod close to 0 long before t=1000, so many late steps are nearly pure noise")
print("- Both end at (almost) pure noise by the final step")

---

## Part 3: Visualizing the Forward Process

Let's load MNIST and watch an image get progressively noisier!

In [None]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

# Get a sample image
sample_image, label = train_dataset[0]
sample_image = sample_image.unsqueeze(0)  # Add batch dimension

print(f"Dataset size: {len(train_dataset)}")
print(f"Image shape: {sample_image.shape}")
print(f"Label: {label}")
print(f"Pixel range: [{sample_image.min():.2f}, {sample_image.max():.2f}]")

In [None]:
def visualize_forward_diffusion(image, scheduler, timesteps_to_show):
    """
    Visualize the forward diffusion process at multiple timesteps.
    
    This shows how an image progressively becomes noise.
    """
    n_steps = len(timesteps_to_show)
    fig, axes = plt.subplots(1, n_steps, figsize=(2.5 * n_steps, 3))
    
    # Use the same noise for all timesteps (to see progression clearly)
    noise = torch.randn_like(image)
    
    for idx, t in enumerate(timesteps_to_show):
        t_tensor = torch.tensor([t])
        noisy_image, _ = scheduler.add_noise(image, t_tensor, noise)
        
        # Convert for display
        img_display = noisy_image.squeeze().numpy()
        img_display = (img_display + 1) / 2  # [-1, 1] -> [0, 1]
        
        axes[idx].imshow(img_display, cmap='gray', vmin=0, vmax=1)

        # Fraction of the variance that still comes from the clean image (alpha_cumprod).
        # Shown in the title, because axis('off') would also hide an xlabel.
        alpha = scheduler.alphas_cumprod[t].item()
        axes[idx].set_title(f't = {t}\nSignal: {alpha*100:.1f}%', fontsize=10)
        axes[idx].axis('off')
    
    plt.suptitle('Forward Diffusion: Adding Noise Over Time', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

# Visualize at different timesteps
timesteps_to_show = [0, 50, 100, 250, 500, 750, 900, 999]
visualize_forward_diffusion(sample_image, cosine_scheduler, timesteps_to_show)

print("\n🔍 Notice how:")
print("  - t=0: Original image (100% signal)")
print("  - t=250: Still recognizable, but fuzzy")
print("  - t=500: Heavily degraded (roughly half the variance is now noise)")
print("  - t=999: Pure noise (almost 0% signal)")

### ✋ Try It Yourself: Experiment with Different Images

Try visualizing the forward process on different MNIST digits.

<details>
<summary>💡 Hint</summary>

```python
# Get a different digit
different_image, different_label = train_dataset[42]  # or any index
different_image = different_image.unsqueeze(0)
visualize_forward_diffusion(different_image, cosine_scheduler, timesteps_to_show)
```
</details>

In [None]:
# YOUR CODE HERE: Try different images from the dataset
# Experiment with different indices to see different digits



---

## Part 4: Building the U-Net Denoiser

### 🧒 ELI5: Why U-Net?

> The U-Net is like a smart photo editor:
> 
> 1. **Encoder** (going down): "Zoom out" to understand the big picture
>    - What kind of digit is this? Where are the main strokes?
> 
> 2. **Bottleneck** (the bottom): Process the high-level understanding
> 
> 3. **Decoder** (going up): "Zoom in" to fill in details
>    - Now that I know it's a "7", let me sharpen the edges!
> 
> 4. **Skip connections**: Pass the encoder's feature maps at each resolution straight to the decoder
>    - "Here's exactly what you're working with at each scale"
> 
> The model also gets told "what timestep is this?" so it knows how much
> noise it's dealing with. Removing 1% noise is different from removing 50%!
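How does a single integer like `t = 500` become something a network can use? The sketch below mirrors `get_timestep_embedding` from the next code cell, using the same geometric frequencies and the same sin/cos layout. It is written in plain Python for a single timestep so it runs without PyTorch. The `distance` helper is an illustrative addition that shows nearby timesteps receive similar vectors while distant ones do not.

```python
import math

def sinusoidal_embedding(t, dim):
    """Pure-Python sketch of a sin/cos timestep embedding for a single timestep t."""
    half = dim // 2
    scale = math.log(10000) / (half - 1)
    freqs = [math.exp(-scale * i) for i in range(half)]  # geometric: 1 down to 1/10000
    args = [t * f for f in freqs]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # odd sizes get one zero of padding
    return emb

def distance(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

e10, e11, e500 = (sinusoidal_embedding(t, 128) for t in (10, 11, 500))
print("neighboring steps:", round(distance(e10, e11), 2))
print("distant steps:    ", round(distance(e10, e500), 2))
```

High-frequency components separate neighboring timesteps, while low-frequency ones change slowly across the whole range. Together they give every noise level a distinct fingerprint.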

### Architecture Overview

```
Input (noisy image) ─┐
                     │
    ┌────────────────┼────────────────┐
    │   [Conv Block] │ + Timestep     │
    │        ↓       │   Embedding    │
    │   [Conv Block]─┼───────────────→│ Skip
    │        ↓       │                │
    │   [Downsample] │                │
    │        ↓       │                │
    │   [Conv Block]─┼───────────────→│ Skip
    │        ↓       │                │
    │   [Bottleneck] │                │
    │        ↓       │                │
    │   [Upsample]   │                │
    │        ↓       │                │
    │   [Conv Block]←┼────────────────┘ Concat
    │        ↓       │
    │   [Upsample]   │
    │        ↓       │
    │   [Conv Block]←┼────────────────┘ Concat
    │        ↓       │
    └───>[Output]────┘
            │
            ↓
    Predicted Noise
```
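The diagram shows two resolution levels, but `SimpleUNet` in the next cell pools three times. With 28×28 MNIST inputs, that creates an odd-size mismatch. The sketch below (plain Python; the helper names are illustrative) applies the standard output-size formulas for `nn.MaxPool2d(2)` and `nn.ConvTranspose2d(kernel_size=2, stride=2)` to show exactly where the decoder needs the `F.interpolate` call.

```python
def maxpool_out(n, kernel=2, stride=2):
    # nn.MaxPool2d(2) with no padding: floor((n - kernel) / stride) + 1
    return (n - kernel) // stride + 1

def conv_transpose_out(n, kernel=2, stride=2, padding=0, output_padding=0):
    # nn.ConvTranspose2d output size with dilation=1
    return (n - 1) * stride - 2 * padding + kernel + output_padding

sizes = [28]
for _ in range(3):  # pool before enc2, before enc3, and before the bottleneck
    sizes.append(maxpool_out(sizes[-1]))
print("encoder sizes:", sizes)  # [28, 14, 7, 3]

# Each decoder stage upsamples, then concatenates with the matching encoder map
for name, low, skip in (("up3", sizes[3], sizes[2]),
                        ("up2", sizes[2], sizes[1]),
                        ("up1", sizes[1], sizes[0])):
    up = conv_transpose_out(low)
    note = "" if up == skip else "  <- needs resizing"
    print(f"{name}: {low} -> {up}, skip connection is {skip}{note}")
```

Only the deepest stage mismatches: pooling 7 gives 3 (flooring drops a row and column), and upsampling 3 gives 6 instead of 7.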

In [None]:
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Create sinusoidal timestep embeddings.
    
    This gives the model a unique "fingerprint" for each timestep,
    allowing it to know how noisy the input is.
    
    Similar to positional embeddings in Transformers!
    """
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
    emb = timesteps[:, None].float() * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb


class ResidualBlock(nn.Module):
    """
    Residual block with timestep conditioning.
    
    The timestep embedding is added to allow the network to behave
    differently depending on the noise level.
    """
    
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
        # Project timestep embedding to channel dimension
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        
        # Skip connection (if channels change)
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()
    
    def forward(self, x, t_emb):
        # First conv
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.silu(h)  # SiLU activation (smoother than ReLU)
        
        # Add timestep embedding
        t = self.time_mlp(t_emb)[:, :, None, None]  # (B, C) -> (B, C, 1, 1)
        h = h + t
        
        # Second conv
        h = self.conv2(h)
        h = self.norm2(h)
        h = F.silu(h)
        
        # Skip connection
        return h + self.skip(x)


class SimpleUNet(nn.Module):
    """
    A simplified U-Net for MNIST diffusion.
    
    This is a minimal implementation for educational purposes.
    Production models (like Stable Diffusion) are much larger!
    """
    
    def __init__(
        self, 
        in_channels: int = 1,
        out_channels: int = 1,
        base_channels: int = 64,
        time_emb_dim: int = 128
    ):
        super().__init__()
        
        self.time_emb_dim = time_emb_dim
        
        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim),
        )
        
        # Encoder (downsampling path)
        self.enc1 = ResidualBlock(in_channels, base_channels, time_emb_dim)
        self.enc2 = ResidualBlock(base_channels, base_channels * 2, time_emb_dim)
        self.enc3 = ResidualBlock(base_channels * 2, base_channels * 4, time_emb_dim)
        
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = ResidualBlock(base_channels * 4, base_channels * 4, time_emb_dim)
        
        # Decoder (upsampling path)
        self.up3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 2, stride=2)
        self.dec3 = ResidualBlock(base_channels * 8, base_channels * 2, time_emb_dim)  # *8 due to concat
        
        self.up2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 2, stride=2)
        self.dec2 = ResidualBlock(base_channels * 4, base_channels, time_emb_dim)
        
        self.up1 = nn.ConvTranspose2d(base_channels, base_channels, 2, stride=2)
        self.dec1 = ResidualBlock(base_channels * 2, base_channels, time_emb_dim)
        
        # Output projection
        self.out_conv = nn.Conv2d(base_channels, out_channels, 1)
    
    def forward(self, x, t):
        """
        Forward pass.
        
        Args:
            x: Noisy image, shape (B, C, H, W)
            t: Timesteps, shape (B,)
            
        Returns:
            Predicted noise, shape (B, C, H, W)
        """
        # Get timestep embedding
        t_emb = get_timestep_embedding(t, self.time_emb_dim)
        t_emb = self.time_mlp(t_emb)
        
        # Encoder
        e1 = self.enc1(x, t_emb)      # (B, 64, 28, 28)
        e2 = self.enc2(self.pool(e1), t_emb)  # (B, 128, 14, 14)
        e3 = self.enc3(self.pool(e2), t_emb)  # (B, 256, 7, 7)
        
        # Bottleneck
        b = self.bottleneck(self.pool(e3), t_emb)  # (B, 256, 3, 3)
        
        # Decoder with skip connections
        # Resize if needed: pooling 7x7 gives 3x3, and upsampling that gives 6x6, not 7x7
        d3 = self.up3(b)  # (B, 256, 6, 6) -> need (B, 256, 7, 7)
        d3 = F.interpolate(d3, size=e3.shape[2:])  # Match encoder size
        d3 = self.dec3(torch.cat([d3, e3], dim=1), t_emb)  # (B, 128, 7, 7)
        
        d2 = self.up2(d3)
        d2 = F.interpolate(d2, size=e2.shape[2:])
        d2 = self.dec2(torch.cat([d2, e2], dim=1), t_emb)  # (B, 64, 14, 14)
        
        d1 = self.up1(d2)
        d1 = F.interpolate(d1, size=e1.shape[2:])
        d1 = self.dec1(torch.cat([d1, e1], dim=1), t_emb)  # (B, 64, 28, 28)
        
        # Output
        return self.out_conv(d1)  # (B, 1, 28, 28)


# Test the model
model = SimpleUNet(in_channels=1, out_channels=1, base_channels=64).to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,} ({n_params/1e6:.2f}M)")

# Test forward pass
test_x = torch.randn(4, 1, 28, 28).to(device)
test_t = torch.randint(0, 1000, (4,)).to(device)
test_out = model(test_x, test_t)
print(f"Input shape: {test_x.shape}")
print(f"Output shape: {test_out.shape}")
print("Model forward pass successful!")

---

## Part 5: Training the Diffusion Model

### The Training Objective

We train the model to **predict the noise** that was added to an image:

$$\mathcal{L} = \mathbb{E}_{x_0, \epsilon, t} \left[ \|\epsilon - \epsilon_\theta(x_t, t)\|^2 \right]$$

Where:
- $\epsilon$ = The actual noise we added
- $\epsilon_\theta(x_t, t)$ = The noise our model predicts

### 🧒 ELI5: Why Predict Noise?

> It's like a game of "spot the difference":
> 
> 1. We take a clean picture
> 2. We add some static (noise) to it - we know exactly what static we added
> 3. We show the noisy picture to the model and ask: "What static do you see?"
> 4. The model guesses the static
> 5. We compare its guess to the real static and train it to be more accurate
>
> After training, the model becomes really good at seeing "what doesn't belong" in an image!

In [None]:
def train_one_epoch(model, dataloader, optimizer, scheduler, device):
    """
    Train for one epoch.
    
    For each batch:
    1. Sample random timesteps
    2. Add noise to images
    3. Predict the noise
    4. Compute MSE loss
    5. Backpropagate
    """
    model.train()
    total_loss = 0
    
    pbar = tqdm(dataloader, desc="Training")
    for images, _ in pbar:
        images = images.to(device)
        batch_size = images.shape[0]
        
        # Sample random timesteps for each image
        t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
        
        # Add noise
        noisy_images, noise = scheduler.add_noise(images, t)
        
        # Predict the noise
        noise_pred = model(noisy_images, t)
        
        # Compute loss (simple MSE)
        loss = F.mse_loss(noise_pred, noise)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(dataloader)


# Setup training
batch_size = 128
learning_rate = 3e-4
num_epochs = 10  # Increase for better results!

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

model = SimpleUNet(in_channels=1, out_channels=1, base_channels=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = NoiseScheduler(num_timesteps=1000, schedule_type="cosine").to(device)

print(f"Training configuration:")
print(f"  Batch size: {batch_size}")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {num_epochs}")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Batches per epoch: {len(train_loader)}")

In [None]:
# Train the model!
print("Starting training...")
print("(This will take a few minutes on DGX Spark)\n")

losses = []
for epoch in range(num_epochs):
    avg_loss = train_one_epoch(model, train_loader, optimizer, scheduler, device)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

# Plot training curve
plt.figure(figsize=(10, 4))
plt.plot(losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.grid(True, alpha=0.3)
plt.show()

print(f"\n Training complete! Final loss: {losses[-1]:.4f}")

---

## Part 6: Sampling (The Reverse Process)

Now for the magic - generating NEW digits from pure noise!

### The Sampling Algorithm (DDPM)

Starting from pure noise $x_T \sim \mathcal{N}(0, I)$:

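For $t = T, T-1, \ldots, 1$:
1. Predict the noise: $\hat{\epsilon} = \epsilon_\theta(x_t, t)$
2. Estimate the clean image: $\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t} \cdot \hat{\epsilon}}{\sqrt{\bar{\alpha}_t}}$
3. Compute the previous step: $x_{t-1} = \tilde{\mu}_t(x_t, \hat{x}_0) + \sigma_t \cdot z$, where $z \sim \mathcal{N}(0, I)$ for $t > 1$ and $z = 0$ at the final step, with $\bar{\alpha}_0 = 1$ and

   $$\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,\hat{x}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t, \qquad \sigma_t^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$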
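One reverse step for a single pixel can be sketched in plain Python. It uses the same posterior-mean and posterior-variance formulas as the `sample` function below, with the notebook's 0-based `t` and $\bar{\alpha}_{t-1} = 1$ at `t = 0`. The names `ddpm_step`, `x_t` and `eps_hat` are illustrative. The sketch also checks a known identity: without clipping $\hat{x}_0$, the same mean can be written directly in terms of the predicted noise, $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\epsilon}\big)$ (DDPM paper, Algorithm 2).

```python
import math

def ddpm_step(x_t, eps_hat, t, betas, alpha_bar, z=0.0):
    """One DDPM reverse step for a single pixel, with 0-indexed t as in the notebook's code."""
    beta = betas[t]
    alpha = 1.0 - beta
    a_bar = alpha_bar[t]
    a_bar_prev = alpha_bar[t - 1] if t > 0 else 1.0  # alpha_bar before the first step is 1
    x0_hat = (x_t - math.sqrt(1.0 - a_bar) * eps_hat) / math.sqrt(a_bar)
    mean = (math.sqrt(a_bar_prev) * beta / (1.0 - a_bar)) * x0_hat \
        + (math.sqrt(alpha) * (1.0 - a_bar_prev) / (1.0 - a_bar)) * x_t
    var = beta * (1.0 - a_bar_prev) / (1.0 - a_bar)
    x_prev = mean + math.sqrt(var) * z if t > 0 else mean  # no noise on the last step
    return x_prev, mean, var

# Linear schedule from Part 2
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bar, prod = [], 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bar.append(prod)

x_prev, mean, var = ddpm_step(x_t=0.3, eps_hat=-0.7, t=500, betas=betas, alpha_bar=alpha_bar, z=0.5)
# Same mean written directly in terms of the predicted noise
mean_eps = (0.3 - betas[500] / math.sqrt(1.0 - alpha_bar[500]) * (-0.7)) / math.sqrt(1.0 - betas[500])
print(mean, mean_eps, var)
```

The two forms agree only when $\hat{x}_0$ is not clipped. The notebook's sampler clamps `x0_pred` to $[-1, 1]$, which is why it uses the $\hat{x}_0$ form.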

### 🧒 ELI5: The Reverse Process

> Imagine you're an artist who's really good at "cleaning up" blurry photos.
> 
> 1. Someone gives you a picture of pure static (random noise)
> 2. You squint at it and think "hmm, this COULD be a '7' if I clean it up"
> 3. You remove a tiny bit of static - now it's slightly less noisy
> 4. You look again: "yes, definitely looking more like a '7'!"
> 5. You repeat 999 more times, each time removing a little noise
> 6. At the end: a crisp, clear digit '7'!
>
> The model learned to "imagine" what's under the static.

In [None]:
@torch.no_grad()
def sample(model, scheduler, num_samples=16, image_size=28, num_channels=1, device='cuda'):
    """
    Generate samples using the trained model.
    
    This implements the DDPM sampling algorithm.
    """
    model.eval()
    
    # Start from pure noise
    x = torch.randn(num_samples, num_channels, image_size, image_size, device=device)
    
    # Store intermediate steps for visualization
    intermediates = [x.cpu().clone()]
    
    # Reverse diffusion
    for t in tqdm(reversed(range(scheduler.num_timesteps)), desc="Sampling", total=scheduler.num_timesteps):
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)
        
        # Predict noise
        noise_pred = model(x, t_batch)
        
        # Get coefficients
        alpha = scheduler.alphas[t]
        alpha_cumprod = scheduler.alphas_cumprod[t]
        alpha_cumprod_prev = scheduler.alphas_cumprod_prev[t]
        beta = scheduler.betas[t]
        
        # Predict x_0
        x0_pred = (x - torch.sqrt(1 - alpha_cumprod) * noise_pred) / torch.sqrt(alpha_cumprod)
        x0_pred = torch.clamp(x0_pred, -1, 1)  # Clip for stability
        
        # Compute posterior mean
        posterior_mean = (
            beta * torch.sqrt(alpha_cumprod_prev) / (1 - alpha_cumprod) * x0_pred +
            (1 - alpha_cumprod_prev) * torch.sqrt(alpha) / (1 - alpha_cumprod) * x
        )
        
        # Add noise (except for t=0)
        if t > 0:
            noise = torch.randn_like(x)
            posterior_variance = scheduler.posterior_variance[t]
            x = posterior_mean + torch.sqrt(posterior_variance) * noise
        else:
            x = posterior_mean
        
        # Store intermediate (every 100 steps)
        if t % 100 == 0:
            intermediates.append(x.cpu().clone())
    
    return x, intermediates


# Generate samples!
print("Generating new digits from noise...")
samples, intermediates = sample(model, scheduler, num_samples=16, device=device)
print(f"Generated {samples.shape[0]} samples!")

In [None]:
# Display generated samples
def show_samples(samples, title="Generated Samples"):
    """Display a grid of samples."""
    samples = samples.cpu()
    # Denormalize from [-1, 1] to [0, 1]
    samples = (samples + 1) / 2
    samples = torch.clamp(samples, 0, 1)
    
    grid = make_grid(samples, nrow=4, padding=2, normalize=False)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.show()

show_samples(samples, "Generated MNIST Digits")

print("\n🎉 Congratulations! You just generated images using a diffusion model you trained!")
print("\nTips for better results:")
print("  - Train for more epochs (20-50)")
print("  - Use a larger model (increase base_channels)")
print("  - Try different learning rates")

In [None]:
# Visualize the reverse process step by step
def visualize_reverse_process(intermediates):
    """Show how an image emerges from noise."""
    n_steps = len(intermediates)
    fig, axes = plt.subplots(1, n_steps, figsize=(2 * n_steps, 2.5))
    
    # Take the first sample from each intermediate
    for i, inter in enumerate(intermediates):
        img = inter[0].squeeze().numpy()
        img = (img + 1) / 2  # Denormalize
        img = np.clip(img, 0, 1)
        
        t = 1000 - (i * 100)
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f't={max(t, 0)}')
        axes[i].axis('off')
    
    plt.suptitle('Reverse Diffusion: From Noise to Image', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

visualize_reverse_process(intermediates)

print("\n🔍 Watch how the digit emerges from pure noise!")
print("   This is the reverse of the forward process we saw earlier.")

---

## ⚠️ Common Mistakes

### Mistake 1: Wrong Normalization Range

```python
# ❌ Wrong: Images in [0, 1]
transform = transforms.Compose([
    transforms.ToTensor(),  # Gives [0, 1]
])

# ✅ Right: Images in [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Centers at 0
])
```
**Why:** Noise is centered at 0, so images should be too!
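Per pixel, `transforms.Normalize((0.5,), (0.5,))` computes `(p - 0.5) / 0.5`. The plotting code in this notebook undoes that with `(x + 1) / 2`. The short sketch below (plain Python; the helper names are illustrative) traces black, mid-gray and white through both steps.

```python
def normalize(p, mean=0.5, std=0.5):
    # Per-pixel effect of transforms.Normalize((0.5,), (0.5,)): (p - mean) / std
    return (p - mean) / std

def denormalize(x):
    # Inverse used before plotting in this notebook: (x + 1) / 2
    return (x + 1) / 2

for p in (0.0, 0.5, 1.0):  # black, mid-gray, white as produced by ToTensor()
    print(p, "->", normalize(p), "->", denormalize(normalize(p)))
```

Mid-gray lands exactly at 0, the mean of the Gaussian noise, so the clean data and the pure-noise endpoint share the same center.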

---

### Mistake 2: Forgetting to Clip During Sampling

```python
# ❌ Wrong: No clipping
x0_pred = (x - sqrt_1m_alpha * noise_pred) / sqrt_alpha

# ✅ Right: Clip to valid range
x0_pred = (x - sqrt_1m_alpha * noise_pred) / sqrt_alpha
x0_pred = torch.clamp(x0_pred, -1, 1)
```
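**Why:** `x0_pred` divides by $\sqrt{\bar{\alpha}_t}$, which is close to 0 at large $t$, so even a small error in the predicted noise becomes a huge error in the estimated clean image. Without clipping to $[-1, 1]$, these estimates can explode and ruin generation!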
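If the prediction is $\hat{\epsilon} = \epsilon + \delta$, then $\hat{x}_0 = x_0 - \delta\,\sqrt{1-\bar{\alpha}_t}/\sqrt{\bar{\alpha}_t}$. This sketch (plain Python, assuming the linear schedule from Part 2; `x0_error_gain` is an illustrative name) prints that gain factor at a few timesteps.

```python
import math

# Linear schedule from Part 2 (beta from 1e-4 to 0.02 over 1000 steps)
T = 1000
alpha_bar, prod = [], 1.0
for i in range(T):
    prod *= 1.0 - (1e-4 + (0.02 - 1e-4) * i / (T - 1))
    alpha_bar.append(prod)

def x0_error_gain(t):
    """Factor by which an error in eps_hat is multiplied in x0_pred at timestep t."""
    a = alpha_bar[t]
    return math.sqrt(1.0 - a) / math.sqrt(a)

for t in (0, 250, 500, 750, 999):
    print(f"t={t:4d}  error gain {x0_error_gain(t):10.3f}")
```

Near $t = 0$ the error is shrunk, but near the start of sampling it is multiplied by more than a hundred. Those are exactly the steps where `torch.clamp` matters most.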

---

### Mistake 3: Training with Very High Loss

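If your loss stays above 0.5 (for reference, always predicting zero noise gives a loss of about 1.0, because the noise has unit variance):
- Check your timestep embedding (is it being used?)
- Verify the noise schedule (alphas_cumprod should go from 1→0)
- Make sure images and noise have the same shape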
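The schedule check can be automated. Below is a hypothetical helper (`check_schedule` is not part of the lab's code, and its tolerances are arbitrary choices) that works on a plain list, so in the notebook you would pass `scheduler.alphas_cumprod.tolist()`.

```python
def check_schedule(alpha_bar, start_tol=0.01, end_tol=0.01):
    """Hypothetical sanity check: alpha_bar must fall strictly from ~1 to ~0."""
    if any(a <= b for a, b in zip(alpha_bar, alpha_bar[1:])):
        raise ValueError("alpha_cumprod must strictly decrease over time")
    if alpha_bar[0] < 1.0 - start_tol:
        raise ValueError(f"alpha_cumprod starts at {alpha_bar[0]:.4f}, expected close to 1")
    if alpha_bar[-1] > end_tol:
        raise ValueError(f"alpha_cumprod ends at {alpha_bar[-1]:.4f}, expected close to 0")
    return True

def linear_alpha_bar(T, start=1e-4, end=0.02):
    out, prod = [], 1.0
    for i in range(T):
        prod *= 1.0 - (start + (end - start) * i / (T - 1))
        out.append(prod)
    return out

print(check_schedule(linear_alpha_bar(1000)))  # True
try:
    check_schedule(linear_alpha_bar(100))  # too few steps: the signal never reaches ~0
except ValueError as err:
    print("caught:", err)
```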

---

### Mistake 4: Generating Blurry Images

- Train for more epochs!
- Increase model capacity (more channels)
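- Sample with all `scheduler.num_timesteps` steps the model was trained with (the DDPM update assumes consecutive timesteps; skipping steps needs a different sampler such as DDIM)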

---

## 🎉 Checkpoint

You've learned:
- ✅ How forward diffusion adds noise to images
- ✅ Different noise schedules (linear vs cosine)
- ✅ The U-Net architecture for denoising
- ✅ How to train a model to predict noise
- ✅ The reverse diffusion sampling process
- ✅ How to generate new images from random noise!

---

## 🚀 Challenge (Optional)

### Challenge 1: Class-Conditional Generation
Modify the model to generate specific digits by adding class conditioning!

<details>
<summary>💡 Hint</summary>

Add an embedding layer for class labels, similar to the timestep embedding:
```python
self.class_emb = nn.Embedding(10, time_emb_dim)  # 10 digits

def forward(self, x, t, class_labels):
    t_emb = get_timestep_embedding(t, self.time_emb_dim)
    t_emb = self.time_mlp(t_emb)
    c_emb = self.class_emb(class_labels)
    emb = t_emb + c_emb  # Combine!
    ...
```
</details>

### Challenge 2: Faster Sampling with DDIM
Implement DDIM sampling to generate images in 50 steps instead of 1000!
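As a starting point, here is a scalar sketch of the deterministic DDIM update (η = 0), assuming the same ε-prediction model and 0-based timesteps as this notebook. DDIM reuses a DDPM-trained noise predictor; only the sampler changes. The helper names and the evenly spaced 50-step subset are illustrative choices. With a toy "model" that always returns the true noise, the sampler walks back to $x_0$ exactly, which makes a handy unit test before plugging in the real U-Net.

```python
import math

def ddim_timesteps(num_train_steps=1000, num_sample_steps=50):
    """Evenly spaced subset of training timesteps, largest first: 980, 960, ..., 0."""
    stride = num_train_steps // num_sample_steps
    return list(range(0, num_train_steps, stride))[::-1]

def ddim_step(x_t, eps_hat, a_bar_t, a_bar_prev):
    """Deterministic DDIM update (eta = 0) for one pixel."""
    x0_hat = (x_t - math.sqrt(1.0 - a_bar_t) * eps_hat) / math.sqrt(a_bar_t)
    # Re-noise the clean estimate to the earlier noise level, reusing the same eps_hat
    return math.sqrt(a_bar_prev) * x0_hat + math.sqrt(1.0 - a_bar_prev) * eps_hat

def ddim_sample(x_start, predict_noise, alpha_bar, steps):
    x = x_start
    for i, t in enumerate(steps):
        a_bar_prev = alpha_bar[steps[i + 1]] if i + 1 < len(steps) else 1.0
        x = ddim_step(x, predict_noise(x, t), alpha_bar[t], a_bar_prev)
    return x

# Linear schedule from Part 2
T = 1000
alpha_bar, prod = [], 1.0
for i in range(T):
    prod *= 1.0 - (1e-4 + (0.02 - 1e-4) * i / (T - 1))
    alpha_bar.append(prod)

steps = ddim_timesteps()
x0_true, eps_true = 0.42, -1.3
a = alpha_bar[steps[0]]
x_start = math.sqrt(a) * x0_true + math.sqrt(1.0 - a) * eps_true
# Toy "model" that always returns the true noise: DDIM then recovers x0 exactly
print(len(steps), steps[:3], ddim_sample(x_start, lambda x, t: eps_true, alpha_bar, steps))
```

In the notebook version, `x` becomes a batch tensor, `predict_noise` becomes `model(x, t_batch)`, and clamping `x0_hat` as in `sample` is a reasonable addition.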

### Challenge 3: Try Fashion-MNIST
Adapt your model to generate clothing items from Fashion-MNIST!

---

## 📖 Further Reading

- [DDPM Paper](https://arxiv.org/abs/2006.11239) - "Denoising Diffusion Probabilistic Models"
- [Improved DDPM](https://arxiv.org/abs/2102.09672) - Cosine schedule and other improvements
- [DDIM Paper](https://arxiv.org/abs/2010.02502) - Faster sampling
- [Lilian Weng's Blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) - Excellent explanation
- [The Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion) - Code walkthrough

---

## 🧹 Cleanup

In [None]:
# Save the trained model (optional)
from pathlib import Path

# Create output directory if it doesn't exist
output_dir = Path("./model_checkpoints")
output_dir.mkdir(parents=True, exist_ok=True)

model_path = output_dir / "mnist_diffusion_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'losses': losses,
}, model_path)
print(f"Model saved to {model_path}")

# Release cached GPU memory (tensors that are still referenced, such as `model`, stay allocated)
import gc
gc.collect()
torch.cuda.empty_cache()
print("Cached GPU memory released")

---

## Next Steps

Now that you understand diffusion theory, proceed to:

**Lab 2.6.2: Stable Diffusion Generation** - Learn to use production-grade diffusion models for text-to-image generation!

You'll learn:
- Loading SDXL on DGX Spark
- Prompt engineering techniques
- Guidance scale and its effects
- Generating stunning images!