<a href="https://colab.research.google.com/github/0jg/DDPM-Toy-Example/blob/main/ddpm_cifar10_cats.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DDPM Diffusion Models: CIFAR-10 Cats Example

This notebook demonstrates Denoising Diffusion Probabilistic Models (DDPM) applied to real images — specifically cat images from the CIFAR-10 dataset. Building on the concepts from the toy example, we'll see how these same principles scale to generate 32×32 RGB images.

## What We'll Cover

We'll implement the forward diffusion process for images, then train a U-Net style neural network to reverse this process. The architecture is more sophisticated than the simple MLP used for 2D points, but the core mathematics remain identical.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
# torch.backends.mps.is_available() is also present on older PyTorch releases
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(f"Using device: {device}")

# Plotting utils
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['STIXGeneral', 'DejaVu Serif'],
    'mathtext.fontset': 'stix',
    'font.size': 12,
    'figure.dpi': 150
})

## Step 1: Load CIFAR-10 Cat Images

CIFAR-10 contains 60,000 32×32 colour images in 10 classes, split into 50,000 training and 10,000 test images. Class index 3 is "cat", and we extract every cat image from both splits: 5,000 for training and 1,000 for testing. We normalize pixel values to [-1, 1], which is standard for diffusion models.
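
As a quick check on the normalization, the sketch below applies the same affine map that a mean of 0.5 and a standard deviation of 0.5 produce, using plain NumPy instead of torchvision. The helper names `to_model_range` and `to_display_range` are illustrative and do not appear elsewhere in this notebook.

```python
import numpy as np

def to_model_range(pixels_01):
    """Map pixel values from [0, 1] to [-1, 1], i.e. (x - 0.5) / 0.5."""
    return (pixels_01 - 0.5) / 0.5

def to_display_range(pixels_pm1):
    """Inverse map from [-1, 1] back to [0, 1] for plotting."""
    return (pixels_pm1 + 1.0) / 2.0

pixels = np.array([0.0, 0.25, 0.5, 1.0])
scaled = to_model_range(pixels)        # values -1, -0.5, 0, 1
restored = to_display_range(scaled)    # back to 0, 0.25, 0.5, 1

print(scaled)
print(restored)
```

The inverse map is the same `(img + 1) / 2` step used before every `imshow` call in the cells that follow.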

In [None]:
def load_cifar10_cats():
    """
    Load cat images from CIFAR-10.
    
    CIFAR-10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
    Cat is class index 3.
    
    Returns:
        Dataset of cat images normalized to [-1, 1]
    """
    # Transform: convert to tensor and normalize to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Maps [0,1] to [-1,1]
    ])
    
    # Download and load training data
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    
    # Filter for cats (class 3) using the stored labels, which avoids
    # loading and transforming every image just to read its label
    cat_class = 3
    train_cat_indices = [i for i, label in enumerate(train_dataset.targets) if label == cat_class]
    test_cat_indices = [i for i, label in enumerate(test_dataset.targets) if label == cat_class]
    
    train_cats = Subset(train_dataset, train_cat_indices)
    test_cats = Subset(test_dataset, test_cat_indices)
    
    print(f"Found {len(train_cats)} cat images in training set")
    print(f"Found {len(test_cats)} cat images in test set")
    print(f"Image shape: 3 x 32 x 32 (RGB)")
    
    return train_cats, test_cats

# Load the dataset
train_cats, test_cats = load_cifar10_cats()

# Display some cat images
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i, ax in enumerate(axes.flat):
    img, _ = train_cats[i]
    # Convert from [-1,1] to [0,1] for display
    img = (img.permute(1, 2, 0) + 1) / 2
    ax.imshow(img.clip(0, 1))
    ax.axis('off')
plt.suptitle('Sample Cat Images from CIFAR-10', fontsize=14)
plt.tight_layout()
plt.show()

## Step 2: Define the Forward Diffusion Process

The forward process is identical to the toy example — we gradually add Gaussian noise over T timesteps. The only difference is that our data is now 3×32×32 images instead of 2D points. The mathematics remain exactly the same:

$$q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I)$$
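
The closed form above can be checked against repeated single noising steps. The sketch below assumes the standard DDPM single-step transition $q(x_t | x_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}\, x_{t-1}, \beta_t I)$, which this notebook does not write out, and the linear schedule defaults used by `DDPMSchedule`. It tracks only the scalar coefficient on $x_0$ and the accumulated noise variance, so no images are needed.

```python
import math

# Linear beta schedule with the same defaults as DDPMSchedule in this notebook.
T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

mean_coeff = 1.0   # coefficient multiplying x_0 after applying t single steps
variance = 0.0     # variance of the accumulated noise after t single steps
alpha_bar = 1.0    # running product of alpha_s = 1 - beta_s

for beta in betas:
    alpha = 1.0 - beta
    # One step: x_t = sqrt(alpha) * x_{t-1} + sqrt(beta) * eps
    mean_coeff *= math.sqrt(alpha)
    variance = alpha * variance + beta
    alpha_bar *= alpha

print(f"sqrt(alpha_bar_T) = {math.sqrt(alpha_bar):.6f}, step by step = {mean_coeff:.6f}")
print(f"1 - alpha_bar_T   = {1 - alpha_bar:.6f}, step by step = {variance:.6f}")
```

Both pairs agree up to floating-point error, which is why training can jump directly to any timestep instead of simulating the whole chain.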

In [None]:
class DDPMSchedule:
    """
    Manages the noise schedule for the diffusion process.
    
    Same as the toy example, with parameters tuned for images.
    """
    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.num_timesteps = num_timesteps
        
        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        
        # Pre-compute useful quantities
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # For q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        
        # For posterior q(x_{t-1} | x_t, x_0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        
    def to(self, device):
        """Move all tensors to specified device."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        return self
    
    def q_sample(self, x_0, t, noise=None):
        """
        Sample from q(x_t | x_0) - the forward diffusion process.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        
        return sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise

# Create schedule
schedule = DDPMSchedule(num_timesteps=1000).to(device)
print(f"Created noise schedule with {schedule.num_timesteps} timesteps")

## Visualize the Forward Diffusion Process

Let's see what happens to a cat image as we progressively add noise. Watch how the recognizable cat gradually dissolves into pure Gaussian noise.
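
To put numbers on that dissolution, the following sketch computes the weight on the clean image, $\sqrt{\bar{\alpha}_t}$, and the weight on the noise, $\sqrt{1-\bar{\alpha}_t}$, at the noisy timesteps plotted in the next cell. It rebuilds the linear schedule in plain Python under the notebook's default settings; the dictionary name `weights` is illustrative.

```python
import math

T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

# alphas_cumprod[t] is the product of (1 - beta_s) for s = 0..t
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

weights = {}
for t in [100, 250, 500, 750, 999]:
    signal = math.sqrt(alphas_cumprod[t])        # multiplies x_0
    noise = math.sqrt(1.0 - alphas_cumprod[t])   # multiplies epsilon
    weights[t] = (signal, noise)
    print(f"t = {t:3d}: signal weight {signal:.3f}, noise weight {noise:.3f}")
```

With this schedule the signal weight is already well below one half by the middle of the chain and below 0.01 at the final step, which matches the visual impression of the image fading into noise.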

In [None]:
# Get a sample cat image
sample_img, _ = train_cats[0]
x_0 = sample_img.unsqueeze(0).to(device)  # Add batch dimension

# Visualize diffusion at different timesteps
timesteps_to_show = [0, 100, 250, 500, 750, 999]
fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(12, 2.5))

for ax, t in zip(axes, timesteps_to_show):
    if t == 0:
        x_t = x_0
    else:
        t_tensor = torch.tensor([t], device=device)
        x_t = schedule.q_sample(x_0, t_tensor)
    
    # Convert to displayable format
    img = x_t[0].cpu().permute(1, 2, 0)
    img = (img + 1) / 2  # [-1,1] to [0,1]
    ax.imshow(img.clip(0, 1))
    ax.set_title(f't = {t}')
    ax.axis('off')

plt.suptitle('Forward Diffusion Process: Cat → Noise', fontsize=14)
plt.tight_layout()
plt.show()

print("As t increases, the cat image gradually becomes indistinguishable from random noise.")

## Step 3: Define the U-Net Denoising Network

For images, we need a more sophisticated architecture than the simple MLP used for 2D points. The U-Net architecture is the standard choice for diffusion models because:

1. **Encoder-Decoder structure**: Captures both local details and global context
2. **Skip connections**: Preserve fine details during upsampling
3. **Multi-scale processing**: Efficiently handles spatial hierarchies in images

The network still predicts the noise $\varepsilon$ that was added, conditioned on the timestep $t$.
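
Before reading the model code, it can help to trace how channel counts and spatial sizes change through the encoder, bottleneck, and decoder. The sketch below is bookkeeping only and contains no neural network: `trace_unet_shapes` is a hypothetical helper, and the channel widths (64, 128, 256) are taken from the `SimpleUNet` class defined in this notebook, assuming the upsampling layers keep the channel count unchanged as they do there.

```python
def trace_unet_shapes(size=32, in_channels=3):
    """Return (stage, channels, height/width) tuples for a 3-level U-Net."""
    stages = []
    skips = []
    channels = in_channels

    # Encoder: a block sets the channel count, then 2x2 pooling halves the size.
    for name, out_channels in [("enc1", 64), ("enc2", 128), ("enc3", 256)]:
        channels = out_channels
        stages.append((name, channels, size))
        skips.append((channels, size))
        size //= 2

    stages.append(("bottleneck", channels, size))

    # Decoder: upsampling doubles the size, then the matching skip is concatenated.
    for name, out_channels in [("dec3", 128), ("dec2", 64), ("dec1", 64)]:
        size *= 2
        skip_channels, skip_size = skips.pop()
        assert size == skip_size, "skip and upsampled maps must align spatially"
        stages.append((name + " input", channels + skip_channels, size))
        channels = out_channels
        stages.append((name + " output", channels, size))
    return stages

for stage, channels, size in trace_unet_shapes():
    print(f"{stage:12s} {channels:4d} channels at {size}x{size}")
```

The concatenation with each skip connection is what doubles the input width of every decoder block (512, 256, and 128 channels), and three pooling steps place the bottleneck at 4×4 for a 32×32 input.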

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Embed the timestep using sinusoidal functions.
    Same as the toy example - allows the network to understand "how noisy" the input is.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class ResidualBlock(nn.Module):
    """
    Residual block with time embedding injection.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()
    
    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.silu(h)
        
        # Add time embedding
        t_emb = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_emb
        
        h = self.conv2(h)
        h = self.norm2(h)
        h = F.silu(h)
        
        return h + self.shortcut(x)


class SimpleUNet(nn.Module):
    """
    A simplified U-Net for 32x32 images.
    
    Architecture:
    - Encoder: 32x32 → 16x16 → 8x8
    - Bottleneck: 4x4
    - Decoder: 4x4 → 8x8 → 16x16 → 32x32
    """
    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128):
        super().__init__()
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
        )
        
        # Encoder
        self.enc1 = ResidualBlock(in_channels, 64, time_emb_dim)
        self.enc2 = ResidualBlock(64, 128, time_emb_dim)
        self.enc3 = ResidualBlock(128, 256, time_emb_dim)
        
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = ResidualBlock(256, 256, time_emb_dim)
        
        # Decoder
        self.up3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec3 = ResidualBlock(512, 128, time_emb_dim)  # 256 + 256 from skip
        
        self.up2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
        self.dec2 = ResidualBlock(256, 64, time_emb_dim)   # 128 + 128 from skip
        
        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = ResidualBlock(128, 64, time_emb_dim)   # 64 + 64 from skip
        
        # Output
        self.out = nn.Conv2d(64, out_channels, 1)
    
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t)
        
        # Encoder path
        e1 = self.enc1(x, t_emb)           # 32x32
        e2 = self.enc2(self.pool(e1), t_emb)  # 16x16
        e3 = self.enc3(self.pool(e2), t_emb)  # 8x8
        
        # Bottleneck
        b = self.bottleneck(self.pool(e3), t_emb)  # 4x4
        
        # Decoder path with skip connections
        d3 = self.up3(b)                          # 8x8
        d3 = self.dec3(torch.cat([d3, e3], dim=1), t_emb)
        
        d2 = self.up2(d3)                         # 16x16
        d2 = self.dec2(torch.cat([d2, e2], dim=1), t_emb)
        
        d1 = self.up1(d2)                         # 32x32
        d1 = self.dec1(torch.cat([d1, e1], dim=1), t_emb)
        
        return self.out(d1)


# Create model
model = SimpleUNet().to(device)
num_params = sum(p.numel() for p in model.parameters())
print(f"Created U-Net with {num_params:,} parameters")

## Step 4: Training the Denoising Network

The training process is identical to the toy example:

1. Take a batch of images $x_0$
2. Sample random timesteps $t$
3. Sample noise $\varepsilon$ and create noisy images $x_t$
4. Predict the noise with our U-Net
5. Minimize MSE between predicted and actual noise

$$\mathcal{L} = \mathbb{E}_{x_0, \varepsilon, t}\left[\|\varepsilon - \varepsilon_\theta(x_t, t)\|^2\right]$$
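
The five steps can be walked through once in NumPy to see what the loss value means. This is a sketch under loud assumptions: random arrays stand in for CIFAR-10 images, and a "model" that always predicts zero stands in for the U-Net, so the result is a reference point and not a training result.

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 1: a stand-in batch of 64 "images" (3x32x32) with values in [-1, 1].
x_0 = rng.uniform(-1.0, 1.0, size=(64, 3, 32, 32))

# Linear schedule matching the notebook's defaults.
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

# Steps 2-3: one random timestep per image, then noise and the noisy image x_t.
t = rng.integers(0, 1000, size=64)
noise = rng.standard_normal(x_0.shape)
a = np.sqrt(alphas_cumprod[t]).reshape(-1, 1, 1, 1)
b = np.sqrt(1.0 - alphas_cumprod[t]).reshape(-1, 1, 1, 1)
x_t = a * x_0 + b * noise

# Steps 4-5 with a trivial predictor that always outputs zero noise.
predicted_noise = np.zeros_like(x_t)
baseline_loss = np.mean((noise - predicted_noise) ** 2)
print(f"Loss of the zero predictor: {baseline_loss:.3f}")

# A perfect noise prediction would determine the clean image exactly.
x0_recovered = (x_t - b * noise) / a
print("x_0 recovered from true noise:", np.allclose(x0_recovered, x_0))
```

Because $\varepsilon$ has unit variance, the zero predictor scores close to 1.0, so a training loss is informative mainly by how far it falls below that level.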

In [None]:
def train_ddpm(model, dataset, schedule, num_epochs=50, batch_size=64, lr=1e-3):
    """
    Train the denoising network on images.
    """
    # The loader yields (image, label) batches; the labels are discarded in the loop below
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    model.train()
    losses = []
    
    for epoch in tqdm(range(num_epochs), desc="Training"):
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_data in dataloader:
            # Handle both (image, label) tuples and just images
            if isinstance(batch_data, (list, tuple)):
                x_0 = batch_data[0].to(device)
            else:
                x_0 = batch_data.to(device)
            
            batch_size_actual = x_0.shape[0]
            
            # Sample random timesteps
            t = torch.randint(0, schedule.num_timesteps, (batch_size_actual,), device=device)
            
            # Sample noise
            noise = torch.randn_like(x_0)
            
            # Create noisy images
            x_t = schedule.q_sample(x_0, t, noise)
            
            # Predict noise
            predicted_noise = model(x_t, t.float())
            
            # Compute loss
            loss = criterion(predicted_noise, noise)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
        
        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")
    
    return losses

# Train the model
print("Starting training...")
print("Note: For best results, train for 100+ epochs. Using 50 for demonstration.")
losses = train_ddpm(model, train_cats, schedule, num_epochs=50, batch_size=64, lr=1e-3)

In [None]:
# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(losses, color='k')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

print(f"\nTraining complete! Final loss: {losses[-1]:.6f}")

## Step 5: Sampling from the Learned Model

Now we generate new cat images! The sampling process is the same as the toy example:

1. Start with pure Gaussian noise
2. For each timestep from T-1 to 0, use the model to predict and remove noise
3. The final result should be a new cat image!

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon_\theta(x_t, t)\right) + \sigma_t z$$
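
The noise scale $\sigma_t$ in this update comes from the posterior variance stored in `DDPMSchedule`, $\sigma_t^2 = \beta_t (1 - \bar{\alpha}_{t-1}) / (1 - \bar{\alpha}_t)$ with $\bar{\alpha}_{-1} = 1$. The sketch below recomputes it in plain Python for the notebook's default linear schedule; it is an illustration of the formula and not a replacement for the class.

```python
import math

T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

# alphas_cumprod[t] = product of (1 - beta_s) for s <= t; the "previous" value at t = 0 is 1.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)
alphas_cumprod_prev = [1.0] + alphas_cumprod[:-1]

# sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
posterior_variance = [
    b * (1.0 - prev) / (1.0 - cur)
    for b, prev, cur in zip(betas, alphas_cumprod_prev, alphas_cumprod)
]

for t in [0, 1, 10, 100, 999]:
    sigma = math.sqrt(posterior_variance[t])
    print(f"t = {t:3d}: beta_t = {betas[t]:.6f}, sigma_t = {sigma:.6f}")
```

The variance is exactly zero at t = 0, consistent with adding no noise on the final step, and it approaches $\beta_t$ for large t.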

In [None]:
@torch.no_grad()
def sample_ddpm(model, schedule, num_samples=16, image_size=32, channels=3, save_trajectory=False):
    """
    Generate images by reversing the diffusion process.
    """
    model.eval()
    
    # Start from pure noise
    x = torch.randn(num_samples, channels, image_size, image_size, device=device)
    
    trajectory = [x.cpu()] if save_trajectory else None
    
    # Reverse diffusion
    for t in tqdm(reversed(range(schedule.num_timesteps)), desc="Sampling", total=schedule.num_timesteps):
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.float)
        
        # Predict noise
        predicted_noise = model(x, t_batch)
        
        # Compute x_{t-1}
        alpha_t = schedule.alphas[t]
        alpha_cumprod_t = schedule.alphas_cumprod[t]
        beta_t = schedule.betas[t]
        
        # Mean of p(x_{t-1} | x_t)
        mean = (1 / torch.sqrt(alpha_t)) * (
            x - (beta_t / torch.sqrt(1 - alpha_cumprod_t)) * predicted_noise
        )
        
        # Add noise (except for t=0)
        if t > 0:
            noise = torch.randn_like(x)
            sigma_t = torch.sqrt(schedule.posterior_variance[t])
            x = mean + sigma_t * noise
        else:
            x = mean
        
        # Save trajectory at specific steps
        if save_trajectory and t in [999, 900, 750, 500, 250, 100, 50, 0]:
            trajectory.append(x.cpu())
    
    if save_trajectory:
        return x, trajectory
    return x

# Generate samples
print("Generating cat images...")
generated_images, trajectory = sample_ddpm(model, schedule, num_samples=16, save_trajectory=True)

## Visualize the Reverse Process

Let's watch cat images emerge from noise! This is the reverse of what we saw in the forward process visualization.

In [None]:
# Show the reverse diffusion trajectory for one sample
fig, axes = plt.subplots(1, len(trajectory), figsize=(14, 2))
timesteps_labels = ['t=999', 't=900', 't=750', 't=500', 't=250', 't=100', 't=50', 't=0']

for ax, traj, label in zip(axes, trajectory, ['t=1000'] + timesteps_labels):
    img = traj[0].permute(1, 2, 0)  # First sample
    img = (img + 1) / 2  # [-1,1] to [0,1]
    ax.imshow(img.clip(0, 1))
    ax.set_title(label)
    ax.axis('off')

plt.suptitle('Reverse Diffusion: Noise → Cat', fontsize=14)
plt.tight_layout()
plt.show()

## Generated Samples

Let's look at a grid of generated cat images and compare them to real cats from CIFAR-10.

In [None]:
fig, axes = plt.subplots(2, 8, figsize=(12, 3))

# Generated samples
for i in range(8):
    img = generated_images[i].cpu().permute(1, 2, 0)
    img = (img + 1) / 2
    axes[0, i].imshow(img.clip(0, 1))
    # Hide ticks only: axis('off') would also hide the row label set below
    axes[0, i].set_xticks([])
    axes[0, i].set_yticks([])
    if i == 0:
        axes[0, i].set_ylabel('Generated', fontsize=12)

# Real samples for comparison
for i in range(8):
    img, _ = train_cats[i + 100]  # Different samples than shown earlier
    img = img.permute(1, 2, 0)
    img = (img + 1) / 2
    axes[1, i].imshow(img.clip(0, 1))
    axes[1, i].set_xticks([])
    axes[1, i].set_yticks([])
    if i == 0:
        axes[1, i].set_ylabel('Real', fontsize=12)

plt.suptitle('Generated vs Real Cat Images', fontsize=14)
plt.tight_layout()
plt.show()

print("Top row: Generated samples from our trained model")
print("Bottom row: Real cat images from CIFAR-10")
print("\nNote: With more training epochs and a larger model, quality improves significantly.")

## Key Takeaways

This notebook demonstrated DDPM applied to real images:

**Same Mathematics**: The forward process, loss function, and sampling algorithm are identical to the toy example. Only the data and network architecture changed.

**U-Net Architecture**: For images, we use a U-Net instead of an MLP. The encoder-decoder structure processes the image at several resolutions, and the skip connections carry fine spatial detail from the encoder directly to the decoder.

**Scale Considerations**: CIFAR-10 is still a small dataset with low resolution. State-of-the-art models use:
- Much larger U-Nets with attention mechanisms
- Higher resolutions (256×256, 512×512, or more)
- More training steps and larger batch sizes
- Techniques like classifier-free guidance for better sample quality

**From Toy to Real**: Moving from 2D points to 32×32 RGB images (3,072 dimensions) required no change to the diffusion mathematics; only the denoising network had to grow. The same principles apply whether generating simple shapes or photorealistic images!