# CIFAR-10 Diffusion - RGB Images Here We Come! 🚀

Time to level up from grayscale MNIST to full-color CIFAR-10! This notebook demonstrates our generic diffusion pipeline working on **32×32 RGB images** with minimal changes to the architecture.

## Load the CIFAR-10 Dataset

CIFAR-10 contains 50,000 training images across 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. Each image is **32×32 pixels with 3 color channels (RGB)**.
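
To make the tensor layout concrete, here is a minimal sketch of the value range and channel order this notebook works with. It uses a random tensor in place of a real image, and the helper names `normalize` and `to_display` are made up for illustration. `normalize` mirrors what `transforms.Normalize` does with mean 0.5 and std 0.5.

```python
import torch


def normalize(img01: torch.Tensor) -> torch.Tensor:
    """Map a (C, H, W) image from [0, 1] to [-1, 1], like Normalize(0.5, 0.5)."""
    return (img01 - 0.5) / 0.5


def to_display(img: torch.Tensor) -> torch.Tensor:
    """Map a (C, H, W) image in [-1, 1] back to (H, W, C) in [0, 1] for imshow."""
    return ((img + 1) / 2).clamp(0, 1).permute(1, 2, 0)


# Random stand-in for one CIFAR-10 image after ToTensor(): 3 channels, 32x32 pixels
raw = torch.rand(3, 32, 32)
model_input = normalize(raw)
restored = to_display(model_input)

print("model input:", tuple(model_input.shape), "display:", tuple(restored.shape))
```

The diffusion code sees channel-first tensors in [-1, 1], while matplotlib expects channel-last images in [0, 1], which is why the visualization cells convert before plotting.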

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from src.diffusion_playground.diffusion.backward import generate_samples
from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.diffusion.training_utils import sample_xt
from src.diffusion_playground.evaluation.image_generation_results import generate_samples_from_checkpoints

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

## Explore the Dataset

In [None]:
# Load CIFAR-10 dataset directly as a tensor
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize RGB channels to [-1, 1]
])

cifar_dataset = datasets.CIFAR10(root="data", train=True, transform=transform, download=False)

# Extract all images and labels in a single pass (each access re-applies the transform)
samples = [cifar_dataset[i] for i in range(len(cifar_dataset))]
cifar_data = torch.stack([image for image, _ in samples])
cifar_labels = torch.tensor([label for _, label in samples])

print(f"Dataset shape: {cifar_data.shape}")
print(f"Labels shape: {cifar_labels.shape}")
print(f"Input shape per image: {tuple(cifar_data.shape[1:])}")
print(f"\nClasses: {cifar_dataset.classes}")

In [None]:
# Visualize a few random samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for idx, ax in enumerate(axes.flat):
    sample_idx = torch.randint(0, len(cifar_data), (1,)).item()

    # Convert from [-1, 1] to [0, 1] for visualization
    img = (cifar_data[sample_idx].permute(1, 2, 0) + 1) / 2
    img = torch.clamp(img, 0, 1)  # Ensure values are in valid range

    ax.imshow(img.cpu())
    ax.set_title(f"{cifar_dataset.classes[cifar_labels[sample_idx]]}")
    ax.axis("off")

plt.suptitle("Sample CIFAR-10 Images", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Forward Diffusion on CIFAR-10

Let's visualize how images progressively become noise during the forward diffusion process. We'll see the same image at different time steps from clean (t=0) to pure noise (t=999).
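
For intuition, here is a minimal sketch of the closed-form forward step, assuming the standard DDPM formulation `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. The helpers `linear_alpha_bars` and `q_sample` and the beta range (1e-4 to 0.02) are assumptions for illustration only. The notebook's own `LinearNoiseSchedule` and `sample_xt` may be implemented differently.

```python
import torch


def linear_alpha_bars(time_steps: int = 1_000,
                      beta_start: float = 1e-4,
                      beta_end: float = 2e-2) -> torch.Tensor:
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed range)."""
    betas = torch.linspace(beta_start, beta_end, time_steps)
    return torch.cumprod(1.0 - betas, dim=0)


def q_sample(x0: torch.Tensor, t: torch.Tensor, alpha_bars: torch.Tensor):
    """Jump straight from x0 to x_t; returns the noisy batch and the noise used."""
    noise = torch.randn_like(x0)
    a_bar = alpha_bars[t].view(-1, 1, 1, 1)  # broadcast over (C, H, W)
    xt = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return xt, noise


alpha_bars = linear_alpha_bars()
x0 = torch.rand(4, 3, 32, 32) * 2 - 1  # random stand-in batch in [-1, 1]

for step in [0, 50, 200, 500, 999]:
    t = torch.full((x0.shape[0],), step, dtype=torch.long)
    xt, _ = q_sample(x0, t, alpha_bars)
    signal = alpha_bars[step].sqrt().item()
    print(f"t={step:>3}  signal weight={signal:.4f}  xt std={xt.std().item():.3f}")
```

Under these assumptions the signal weight starts near 1 at t=0 and is close to 0 at t=999, which is why the last panel looks like pure noise.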

In [None]:
# Create NoiseSchedule
schedule = LinearNoiseSchedule(time_steps=1_000)

# Use a batch from our data for visualization
x0_batch = cifar_data[:128]  # Take first 128 samples

time_steps = [0, 50, 200, 500, 999]

plt.figure(figsize=(15, 3))
for i, t in enumerate(time_steps):
    # Create a batch of time steps (one per sample in the batch)
    t_tensor = torch.full((x0_batch.shape[0],), t)
    xt, _, _ = sample_xt(x0_batch, schedule, t=t_tensor)

    # Convert from [-1, 1] to [0, 1] for visualization
    img = (xt[0].permute(1, 2, 0) + 1) / 2
    img = torch.clamp(img, 0, 1)

    plt.subplot(1, len(time_steps), i + 1)
    plt.imshow(img.cpu())
    plt.title(f"t={t}")
    plt.axis("off")

plt.suptitle("Forward Diffusion Process (Clean → Noise)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Train the CNN Denoiser

Now let's train our UNet-style CNN to learn the reverse diffusion process on CIFAR-10! 

**Key differences from MNIST:**
- **Input channels**: 3 (RGB) instead of 1 (grayscale)
- **Base channels**: 64 instead of 32 (more capacity for complex color images)
- **Image size**: 32×32 instead of 28×28

Everything else stays the same - same training pipeline, same loss function, same noise schedule! 🎯
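
Here is a minimal sketch of why so little has to change, assuming the usual DDPM objective (mean squared error between the predicted and the true noise). `TinyDenoiser` is a made-up stand-in, not the real `CNNDenoiser`. The `model(x, t)` call convention and the `training_step` helper are assumptions for illustration and may not match `train_denoiser`.

```python
import torch
from torch import nn


class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for a denoiser: predicts noise from (x_t, t)."""

    def __init__(self, in_channels: int, base_channels: int = 8, time_steps: int = 1_000):
        super().__init__()
        self.time_emb = nn.Embedding(time_steps, base_channels)
        self.conv_in = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.conv_out = nn.Conv2d(base_channels, in_channels, 3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Add the time embedding as a per-channel bias
        h = self.conv_in(x) + self.time_emb(t)[:, :, None, None]
        return self.conv_out(torch.relu(h))


def training_step(model, optimizer, x0, alpha_bars) -> float:
    """One noise-prediction step: sample t, noise x0, regress the noise."""
    t = torch.randint(0, alpha_bars.shape[0], (x0.shape[0],))
    noise = torch.randn_like(x0)
    a_bar = alpha_bars[t].view(-1, 1, 1, 1)
    xt = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    loss = nn.functional.mse_loss(model(xt, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Assumed linear beta schedule; only the channel count and image size differ per dataset
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1_000), dim=0)
losses = {}
for name, channels, size in [("MNIST-like", 1, 28), ("CIFAR-like", 3, 32)]:
    model = TinyDenoiser(in_channels=channels)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x0 = torch.rand(16, channels, size, size) * 2 - 1  # random stand-in batch
    losses[name] = training_step(model, optimizer, x0, alpha_bars)
    print(f"{name}: loss={losses[name]:.4f}")
```

The training step never mentions the number of channels or the image size. Both are absorbed by the model constructor and the shape of the data tensor.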

In [None]:
from src.diffusion_playground.models import CNNDenoiser
from src.diffusion_playground.training.denoiser_trainer import train_denoiser

# Data is already loaded as cifar_data tensor!
print(f"Training data shape: {cifar_data.shape}")

In [None]:
# Create the CNN denoiser model for RGB images
model = CNNDenoiser(
    in_channels=3,  # RGB (3 channels) instead of grayscale (1 channel)
    base_channels=64,  # More capacity than MNIST (was 32)
    time_emb_dim=128  # Same time embedding dimension
)

# Print model summary
print("\nModel architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Train the model (using the same generic training function!)
# Note: CIFAR-10 is more complex than MNIST, so it may need longer training
# The training will automatically RESUME from the latest checkpoint if interrupted!
train_denoiser(
    model=model,
    data=cifar_data,
    noise_schedule=schedule,
    epochs=100_000,
    lr=1e-3,
    batch_size=128,
    checkpoint_dir="checkpoints/cifar10_cnn",
    save_every=1_000,
    resume=True,
)

### 🔄 Auto-Resume Feature

The training will **automatically resume** from the latest checkpoint if interrupted! This is perfect for:
- ⏰ Google Colab sessions that time out
- 🔌 Unexpected disconnections
- 🛑 Manual interruptions

Just re-run the training cell and it will pick up where it left off. No manual checkpoint management needed!
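
How `train_denoiser` locates the latest checkpoint is not shown in this notebook. As a rough sketch of one way it could work, the helper below (the hypothetical `find_latest_checkpoint`) scans a directory for files named like `checkpoint_epoch_90000.pt` and picks the highest epoch. It compares epochs numerically, since a plain string sort would rank `9000` above `10000`.

```python
import re
from pathlib import Path
from typing import Optional

# File name pattern taken from the checkpoint name used later in this notebook
CHECKPOINT_NAME = re.compile(r"checkpoint_epoch_(\d+)\.pt")


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if there is none."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None

    latest_epoch, latest_path = -1, None
    for path in directory.iterdir():
        match = CHECKPOINT_NAME.fullmatch(path.name)
        if match and int(match.group(1)) > latest_epoch:
            latest_epoch, latest_path = int(match.group(1)), path
    return latest_path


latest = find_latest_checkpoint("checkpoints/cifar10_cnn")
print("Resuming from:", latest if latest else "scratch (no checkpoint found)")
```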

## Test the Reverse Diffusion Process

Let's test our trained model by generating RGB images from pure noise! This will verify that all interfaces work correctly for CIFAR-10 before we move to long training on Google Colab.
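
For reference, here is a minimal sketch of what a reverse diffusion loop can look like, assuming standard DDPM ancestral sampling with variance `beta_t`. The names `ddpm_sample`, `to_display`, and `dummy_model` are made up for illustration. The project's `generate_samples` may use a different sampler, signature, or output format.

```python
import torch


@torch.no_grad()
def ddpm_sample(model, betas, image_shape, num_samples, device="cpu"):
    """Start from Gaussian noise and denoise step by step (DDPM ancestral sampling)."""
    betas = betas.to(device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(num_samples, *image_shape, device=device)
    for step in reversed(range(betas.shape[0])):
        t = torch.full((num_samples,), step, device=device, dtype=torch.long)
        predicted_noise = model(x, t)
        # Posterior mean: remove the predicted noise contribution, then rescale
        coef = betas[step] / (1.0 - alpha_bars[step]).sqrt()
        mean = (x - coef * predicted_noise) / alphas[step].sqrt()
        # Add fresh noise on every step except the last one
        noise = torch.randn_like(x) if step > 0 else torch.zeros_like(x)
        x = mean + betas[step].sqrt() * noise
    return x


def to_display(x: torch.Tensor) -> torch.Tensor:
    """(N, C, H, W) in [-1, 1] -> (N, H, W, C) in [0, 1] for plt.imshow."""
    return ((x.clamp(-1, 1) + 1) / 2).permute(0, 2, 3, 1)


def dummy_model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Placeholder for a trained denoiser: always predicts zero noise."""
    return torch.zeros_like(x)


betas = torch.linspace(1e-4, 2e-2, 50)  # short assumed schedule to keep the demo fast
samples = ddpm_sample(dummy_model, betas, image_shape=(3, 32, 32), num_samples=4)
images = to_display(samples)
print(images.shape)
```

With the dummy model the output is still noise. The point is only to check the shapes and value ranges that a trained denoiser would flow through.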

In [None]:
from src.diffusion_playground.training.denoiser_trainer import load_checkpoint

# Load checkpoint for testing
cp_name = "checkpoint_epoch_90000.pt"
checkpoint_path = f"checkpoints/cifar10_cnn/{cp_name}"

# Create a fresh model instance
model = CNNDenoiser(in_channels=3, base_channels=64, time_emb_dim=128)
model.to(device)

# Load the checkpoint
checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
print(f"Loaded model trained for {checkpoint_info['epoch']} epochs")
print(f"Training loss: {checkpoint_info['loss']:.6f}")

In [None]:
# Setup for generation
model.eval()
num_samples = 9

images = generate_samples(
    model=model,
    noise_schedule=schedule,
    image_shape=(3, 32, 32),
    num_samples=num_samples,
    device=device
)

# Visualize generated RGB samples
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for idx, ax in enumerate(axes.flat):
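    # No conversion here: generate_samples is assumed to return display-ready
    # images, i.e. (H, W, C) in [0, 1]. If it returns (C, H, W) tensors in
    # [-1, 1], use ((images[idx].permute(1, 2, 0) + 1) / 2).clamp(0, 1) instead.
    img = images[idx]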

    ax.imshow(img.cpu())
    ax.axis("off")

# Create informative title with checkpoint details
title = f"Generated CIFAR-10 Images - {cp_name}\n"
title += f"Epoch: {checkpoint_info['epoch']} | Loss: {checkpoint_info['loss']:.6f}"
plt.suptitle(title, fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

## Generate Documentation Visualizations

Create and save visualizations for multiple checkpoints to use in the README documentation.

In [None]:
# Generate the samples
generate_samples_from_checkpoints(
    model=model,
    dataset_name_trained_on="CIFAR-10",
    device="cpu",
    checkpoint_epochs=[1000, 25000, 50000, 75000],
    checkpoint_dir="./checkpoints/cifar10_cnn",
    output_dir="../../../docs/cifar-10-cnn",
    noise_schedule=schedule,
    image_shape=(3, 32, 32),
)