# MNIST Diffusion - Let's get to real images!

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.diffusion.training_utils import sample_xt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

## Load the MNIST Dataset

In [None]:
# Load MNIST with pixel values rescaled to [-1, 1]:
# ToTensor() produces floats in [0, 1] and Normalize((0.5,), (0.5,)) then
# applies (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

mnist_dataset = datasets.MNIST(root="data", train=True, transform=transform, download=True)

# Materialize the whole dataset in a SINGLE pass. Every `mnist_dataset[i]`
# access re-applies the transform, so indexing each sample once (instead of
# once for the images and once more for the labels) halves the preprocessing.
samples = [mnist_dataset[i] for i in range(len(mnist_dataset))]
mnist_data = torch.stack([image for image, _ in samples])
mnist_labels = torch.tensor([label for _, label in samples])
del samples  # the per-sample tensors are no longer needed once stacked

print(f"Dataset shape: {mnist_data.shape}")
print(f"Labels shape: {mnist_labels.shape}")
# Report the per-image shape from the data itself rather than hard-coding it.
print(f"Input shape per image: {tuple(mnist_data.shape[1:])}")

## Explore the Dataset

In [None]:
# Index of the sample to inspect (a fixed index, so the cell is reproducible)
sample_idx = 5
sample_image = mnist_data[sample_idx][0]  # first (only) channel, for imshow
sample_digit = mnist_labels[sample_idx].item()

print(f"Sample label: {sample_digit}")

# Render the digit as a grayscale image without axis ticks
plt.imshow(sample_image, cmap="gray")
plt.title(f"Sample MNIST Digit: {sample_digit}")
plt.axis("off")
plt.show()

## Forward Diffusion on MNIST

In [None]:
# Linear noise schedule with 1,000 diffusion steps (reused by later cells)
schedule = LinearNoiseSchedule(time_steps=1_000)

# Visualize on the first 128 training images
x0_batch = mnist_data[:128]

# Diffusion steps to display, from the first step (t=0) to the last (t=999)
time_steps = [0, 50, 200, 500, 999]

plt.figure(figsize=(15, 3))
for panel, t in enumerate(time_steps, start=1):
    # sample_xt is given one time step per batch element
    t_tensor = torch.full((x0_batch.shape[0],), t)
    xt, _, _ = sample_xt(x0_batch, schedule, t=t_tensor)

    # Only the first image of the noised batch is shown in each panel
    plt.subplot(1, len(time_steps), panel)
    plt.imshow(xt[0, 0].cpu(), cmap="gray")
    plt.title(f"t={t}")
    plt.axis("off")

plt.show()

## Train the CNN Denoiser

Now let's train a UNet-style CNN to learn the reverse diffusion process on MNIST! Unlike the simple MLP we used for the toy moons dataset, this CNN architecture uses:

- **Convolutional layers** to preserve spatial structure
- **Encoder-decoder architecture** (UNet) with skip connections
- **Sinusoidal time embeddings** for better time conditioning

The beauty of our generic training pipeline is that we use the **exact same `train_denoiser` function** - only the model architecture changes!

**Note**: Training on images takes considerably longer than training on the toy dataset. For a full training run, you might want to train for 50,000-100,000 epochs. For quick testing, we'll use far fewer epochs here.

In [None]:
from src.diffusion_playground.models import CNNDenoiser
from src.diffusion_playground.training.denoiser_trainer import train_denoiser

# mnist_data already holds the whole training set as a single tensor, so it
# can be passed to train_denoiser as-is - no need to collect batches from a
# DataLoader!
print(f"Training data shape: {mnist_data.shape}")

In [None]:
# Instantiate the CNN denoiser for single-channel (grayscale) input
model = CNNDenoiser(
    in_channels=1,
    base_channels=32,
    time_emb_dim=128,
)

# Print the model's representation for a quick look at its structure
print("\nModel architecture:")
print(model)

In [None]:
# Train the model with the same generic training function used for the toy data.
# NOTE: epochs=10 is only a quick smoke test; a full run needs on the order of
# 50k-100k epochs. A run this short will not produce the
# `checkpoint_epoch_1000.pt` file that the evaluation cell loads.
# NOTE(review): save_every presumably is the checkpoint interval in epochs -
# confirm against train_denoiser.
training_options = {
    "epochs": 10,
    "lr": 1e-3,
    "batch_size": 128,
    "checkpoint_dir": "checkpoints/mnist_cnn",
    "save_every": 3,
}

train_denoiser(
    model=model,
    data=mnist_data,
    noise_schedule=schedule,
    **training_options,
)

## Test the Reverse Diffusion Process

Let's test our trained model by performing the reverse diffusion process - generating images from pure noise! Even with minimal training, this will verify that all our interfaces work correctly together.

In [None]:
from src.diffusion_playground.training.denoiser_trainer import load_checkpoint

from pathlib import Path

# Load a saved checkpoint for evaluation so this cell can run on its own,
# without training first. This is a FIXED epoch chosen by name, not
# automatically the best-performing checkpoint.
cp_name = "checkpoint_epoch_1000.pt"
checkpoint_path = f"checkpoints/mnist_cnn/{cp_name}"

# Fail early with an actionable message: a short smoke-test training run
# does not reach this epoch, so the file may simply not exist yet.
if not Path(checkpoint_path).is_file():
    raise FileNotFoundError(
        f"Checkpoint not found: {checkpoint_path}. "
        "Train long enough to produce it, or set cp_name to an existing checkpoint."
    )

# Create a fresh model instance (in case we're running this cell standalone)
model = CNNDenoiser(in_channels=1, base_channels=32, time_emb_dim=128)
model.to(device)

# Restore the weights; the returned info is read below for 'epoch' and 'loss'.
# NOTE(review): 'loss' is presumably the loss recorded when the checkpoint was
# saved, not necessarily the best loss of the run - confirm in load_checkpoint.
checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
print(f"Loaded model trained for {checkpoint_info['epoch']} epochs")
print(f"Best training loss: {checkpoint_info['loss']:.6f}")

In [None]:
import math

# Setup for generation
model.eval()
num_samples = 16
# NOTE: this rebinds `time_steps`, which the forward-diffusion cell uses for
# its list of steps to plot; that cell reassigns it when re-run.
time_steps = schedule.time_steps

# Start from pure Gaussian noise, created directly on the target device
# (avoids allocating on the CPU and copying the tensor over).
xt = torch.randn(num_samples, 1, 28, 28, device=device)

print(f"Starting reverse diffusion from noise shape: {xt.shape}")

# Reverse diffusion loop with STOCHASTIC sampling (adds noise for diversity!)
# `t` counts down from time_steps to 1; the schedule tensors are read at
# index `t - 1`.
# NOTE(review): the model is conditioned on `t` while the schedule is read at
# `t - 1`. The forward-diffusion cell passes steps 0..999 to `sample_xt`; if
# training conditions the model on those same 0-based steps, the model should
# presumably receive `t - 1` here - confirm against train_denoiser.
with torch.no_grad():
    for t in reversed(range(1, time_steps + 1)):
        # One time-step entry per sample in the batch
        t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # Predict the noise contained in the current sample
        pred_noise = model(xt, t_tensor)

        # Schedule parameters for this step
        beta_t = schedule.betas[t - 1]
        alpha_t = schedule.alphas[t - 1]
        alpha_bar_t = schedule.alpha_bars[t - 1]

        # Mean of the reverse distribution:
        #   (x_t - beta_t / sqrt(1 - alpha_bar_t) * predicted_noise) / sqrt(alpha_t)
        mean = (xt - beta_t / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)

        if t > 1:
            # Add fresh Gaussian noise with standard deviation sqrt(beta_t)
            xt = mean + torch.sqrt(beta_t) * torch.randn_like(xt)
        else:
            # No noise at the final step
            xt = mean

        # Print progress every 100 steps
        if t % 100 == 0:
            print(f"  Step {time_steps - t + 1}/{time_steps} (t={t})")

print("Reverse diffusion complete!")

# Visualize generated samples with checkpoint information.
# The grid is derived from num_samples (4x4 at 8x8 inches for 16 samples), so
# changing num_samples no longer breaks the plotting code.
grid_cols = math.ceil(math.sqrt(num_samples))
grid_rows = math.ceil(num_samples / grid_cols)
fig, axes = plt.subplots(
    grid_rows,
    grid_cols,
    figsize=(2 * grid_cols, 2 * grid_rows),
    squeeze=False,  # always a 2-D axes array, even for a single sample
)
for idx, ax in enumerate(axes.flat):
    if idx < num_samples:
        ax.imshow(xt[idx, 0].cpu(), cmap="gray")
    ax.axis("off")  # also blanks any unused trailing panels

# Create informative title with checkpoint details
# (cp_name and checkpoint_info come from the checkpoint-loading cell)
title = f"Generated MNIST Digits - {cp_name}\n"
title += f"Epoch: {checkpoint_info['epoch']} | Loss: {checkpoint_info['loss']:.6f}"
plt.suptitle(title, fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

## Generate Documentation Visualizations

Create and save visualizations for multiple checkpoints to use in the README documentation.

In [None]:
import os
from pathlib import Path

import math

# Create output directory if it doesn't exist
output_dir = Path("../../../docs/mnist-cnn")
output_dir.mkdir(parents=True, exist_ok=True)

# Checkpoints to visualize
checkpoint_epochs = [1000, 10000, 50000, 79000]

# Settings that do not change between checkpoints, hoisted out of the loop
num_samples = 9  # 3x3 = 9 samples for documentation
grid_cols = math.ceil(math.sqrt(num_samples))
grid_rows = math.ceil(num_samples / grid_cols)
separator = "=" * 60
saved_paths = []  # figures actually written, for the final summary

for epoch in checkpoint_epochs:
    print(f"\n{separator}")
    print(f"Processing checkpoint: epoch {epoch}")
    print(separator)

    # Define checkpoint path (kept as a string for load_checkpoint)
    cp_name = f"checkpoint_epoch_{epoch}.pt"
    checkpoint_path = f"checkpoints/mnist_cnn/{cp_name}"

    # Skip checkpoints that have not been produced (best effort, no error)
    if not Path(checkpoint_path).is_file():
        print(f"⚠️  Checkpoint not found: {checkpoint_path}")
        print("   Skipping...")
        continue

    # Load checkpoint into a fresh model instance
    model = CNNDenoiser(in_channels=1, base_channels=32, time_emb_dim=128)
    model.to(device)
    checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
    model.eval()

    # Start from pure Gaussian noise, created directly on the target device
    xt = torch.randn(num_samples, 1, 28, 28, device=device)

    print(f"Generating {num_samples} samples...")

    # Stochastic reverse diffusion: `t` counts down from schedule.time_steps
    # to 1, and the schedule tensors are read at index `t - 1`.
    # NOTE(review): the model is conditioned on `t` while the schedule is read
    # at `t - 1`; if training conditions the model on 0-based steps, it should
    # presumably receive `t - 1` here - confirm against train_denoiser.
    with torch.no_grad():
        for t in reversed(range(1, schedule.time_steps + 1)):
            t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
            pred_noise = model(xt, t_tensor)

            beta_t = schedule.betas[t - 1]
            alpha_t = schedule.alphas[t - 1]
            alpha_bar_t = schedule.alpha_bars[t - 1]

            # Mean of the reverse distribution for this step
            mean = (xt - beta_t / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)

            if t > 1:
                # Add fresh Gaussian noise with standard deviation sqrt(beta_t)
                xt = mean + torch.sqrt(beta_t) * torch.randn_like(xt)
            else:
                # No noise at the final step
                xt = mean

    # Create visualization; the grid follows num_samples (3x3 at 6x6 inches for 9)
    fig, axes = plt.subplots(
        grid_rows,
        grid_cols,
        figsize=(2 * grid_cols, 2 * grid_rows),
        squeeze=False,  # always a 2-D axes array, even for a single sample
    )
    for idx, ax in enumerate(axes.flat):
        if idx < num_samples:
            ax.imshow(xt[idx, 0].cpu(), cmap="gray")
        ax.axis("off")  # also blanks any unused trailing panels

    # Add informative title
    title = f"Generated MNIST Digits - {cp_name}\n"
    title += f"Epoch: {checkpoint_info['epoch']} | Loss: {checkpoint_info['loss']:.6f}"
    plt.suptitle(title, fontsize=12, fontweight='bold')
    plt.tight_layout()

    # Save figure
    output_filename = f"generated_samples_epoch_{epoch}.png"
    output_path = output_dir / output_filename
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved visualization to: {output_path}")
    saved_paths.append(output_path)

    plt.show()
    plt.close(fig)  # release this figure before the next checkpoint

# Final summary: do not claim success when every checkpoint was skipped
print(f"\n{separator}")
if saved_paths:
    print("✓ All visualizations completed!")
else:
    print("⚠️  No visualizations were generated: none of the checkpoints were found.")
print(f"  Saved {len(saved_paths)} of {len(checkpoint_epochs)} checkpoint visualizations")
print(separator)