# CNNDenoiserLarge - Let's get bigger! 💪

This notebook implements the training procedure for the CNNDenoiserLarge - a larger version of the standard CNNDenoiser as shown (and trained) in the notebook `cnn_denoiser.ipynb`.

In [None]:
import math

import torch
import matplotlib.pyplot as plt

from src.diffusion_playground.diffusion.backward import generate_samples
from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.evaluation.image_generation_results import generate_samples_from_checkpoints
from src.diffusion_playground.data_loader.cifar_10_dataset import load_cifar_10
from src.diffusion_playground.models import CNNDenoiserLarge
from src.diffusion_playground.training.denoiser_trainer import load_checkpoint, train_denoiser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed torch's RNGs (all devices) before the model is created, so weight initialisation and
# sampled images are reproducible under Restart & Run All. Some GPU kernels may still be
# non-deterministic, so bit-exact results on CUDA are not guaranteed.
SEED = 42
torch.manual_seed(SEED)

## Load the Dataset

Load the CIFAR-10 dataset and apply the `ToTensor` and `Normalize` transformations.

In [None]:
# Read CIFAR-10 from local storage only; with download=False the files presumably have to be
# on disk already (confirm against load_cifar_10). The labels are not used by the unconditional
# denoiser trained below.
cifar_data, cifar_labels = load_cifar_10(
    download=False,
)

## Create the CNNDenoiserLarge Model and Noise Schedule

Instantiate the larger model and create a linear noise schedule.

In [None]:
# Create the CNN denoiser model for RGB images (in_channels=3).
model = CNNDenoiserLarge(
    in_channels=3,
    base_channels=128,
    time_emb_dim=128,
)
model.to(device)

# Noise schedule over 1,000 diffusion time steps; the same object is used for training and sampling.
schedule = LinearNoiseSchedule(time_steps=1_000)

# Print model summary
num_params = sum(p.numel() for p in model.parameters())
print("\nModel architecture:")
print(model)
print(f"\nTotal parameters: {num_params:,}")

## Train the Model

Train the model on the CIFAR-10 dataset.

In [None]:
# Train the model.
# NOTE(review): with resume=True, train_denoiser presumably continues from the latest checkpoint
# in checkpoint_dir when one exists — confirm. Also confirm whether "epochs" means full passes
# over the dataset or single optimisation steps; 100_000 full passes over CIFAR-10 would be very long.
train_denoiser(
    model=model,
    data=cifar_data,
    noise_schedule=schedule,
    epochs=100_000,
    lr=1e-3,
    batch_size=128,
    checkpoint_dir="checkpoints/cnn_denoiser_large",
    # Presumably writes a checkpoint every 1,000 epochs (files like checkpoint_epoch_30000.pt).
    save_every=1_000,
    resume=True,
)

## Reverse Diffusion

### Load a checkpoint

Load an arbitrary checkpoint; the corresponding `.pt` file must be present locally.

In [None]:
# Pick which saved checkpoint to inspect (the file must already exist locally).
cp_name = "checkpoint_epoch_30000.pt"
checkpoint_path = f"checkpoints/cnn_denoiser_large/{cp_name}"

# Load the checkpoint into `model`; the returned metadata (epoch, loss) is reused in the plot title below.
checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
print(f"Loaded model trained for {checkpoint_info['epoch']} epochs")
print(f"Training loss: {checkpoint_info['loss']:.6f}")

### Generate Samples in-line

Show a few generated samples in this cell.

In [None]:
# Setup for generation: eval mode switches off training-only layer behaviour.
model.eval()
num_samples = 9

# Generate samples. no_grad avoids building an autograd graph through every reverse-diffusion step.
with torch.no_grad():
    images = generate_samples(
        model=model,
        noise_schedule=schedule,
        image_shape=(3, 32, 32),
        num_samples=num_samples,
        device=device,
    )

# Visualize in a near-square grid sized from num_samples (3x3 for 9 samples), so changing
# num_samples neither raises IndexError nor silently drops images.
ncols = max(1, math.ceil(math.sqrt(num_samples)))
nrows = max(1, math.ceil(num_samples / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(8, 8), squeeze=False)
for idx, ax in enumerate(axes.flat):
    if idx < num_samples:
        # NOTE(review): assumes generate_samples returns images that imshow can render directly
        # (channel-last, displayable value range) — confirm.
        ax.imshow(images[idx].cpu())
    # Hide axes everywhere, including unused grid cells.
    ax.axis("off")

# Title, Layout, Show
title = f"Generated CIFAR-10 Images - {cp_name}\nEpoch: {checkpoint_info['epoch']} | Loss: {checkpoint_info['loss']:.6f}"
fig.suptitle(title, fontsize=12, fontweight='bold')
fig.tight_layout()
plt.show()

In [None]:
# Generate sample grids for several training checkpoints and write them to the docs folder.
# "53M" is only a label; presumably it is the approximate parameter count printed above — update it
# if the architecture changes.
# NOTE(review): checkpoint_dir is relative to the working directory while output_dir climbs three
# levels above it; confirm both resolve to the intended locations.
generate_samples_from_checkpoints(
    model=model,
    model_name="53M",
    # Use the device the model was moved to earlier instead of a hard-coded "cpu", so this cell
    # agrees with the rest of the notebook. Passed as a string, matching the original argument type.
    device=str(device),
    checkpoint_epochs=[1000, 25000, 50000, 75000, 100000],
    checkpoint_dir="./checkpoints/cnn_denoiser_large",
    output_dir="../../../docs/cifar-10-cnn/cnn-denoiser-large",
    noise_schedule=schedule,
    image_shape=(3, 32, 32),
)