# CNNDenoiser on CIFAR-10 🚀

This notebook trains the CNN-based, U-Net-style CNNDenoiser on CIFAR-10 and then uses it for inference (sample generation).

## Prerequisites

- Imports
- Set global constants
- Load data
- Set device
- Define the noise schedule
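
To make the noise schedule step more concrete, here is a minimal pure-Python sketch of what a linear schedule typically computes. It is not the code of LinearNoiseSchedule: the function name `linear_beta_schedule` and the defaults `beta_start=1e-4` and `beta_end=0.02` are assumptions, borrowed from common DDPM setups.

```python
# Sketch only: beta_start / beta_end are assumed DDPM-style defaults,
# not values read from LinearNoiseSchedule.
def linear_beta_schedule(time_steps, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alpha_bars) for a linearly increasing noise variance."""
    if time_steps < 2:
        raise ValueError("time_steps must be at least 2")
    step = (beta_end - beta_start) / (time_steps - 1)
    betas = [beta_start + i * step for i in range(time_steps)]

    alpha_bars = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_bar_t = product of (1 - beta_s) for s <= t
        alpha_bars.append(running)
    return betas, alpha_bars


betas, alpha_bars = linear_beta_schedule(1_000)
print(f"first alpha_bar: {alpha_bars[0]:.6f}, last alpha_bar: {alpha_bars[-1]:.2e}")
```

The cumulative product `alpha_bars` shrinks toward zero, so late time steps contain almost pure noise.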

In [None]:
import torch
import matplotlib.pyplot as plt

from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.data_loader.cifar_10_dataset import load_cifar_10

from src.diffusion_playground.models.time_and_class_conditioned_model import TimeAndClassConditionedModel
from src.diffusion_playground.training.denoiser_trainer import train_conditioned_denoiser

# Set global constants
TIME_EMB_DIM = 128
BASE_CHANNELS = 128
PROJECT_DIR = "."
CHECKPOINTS_DIR = f"{PROJECT_DIR}/checkpoints/cnn_denoiser"

# Specific to CIFAR-10 dataset
NUM_CLASSES = 10
INPUT_CHANNELS = 3

# Load data
cifar_data, cifar_labels, class_idx_to_name = load_cifar_10()

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Define noise schedule
schedule = LinearNoiseSchedule(time_steps=1_000)

## Create the Model Backbone

Choose the model that serves as the backbone of the diffusion process. The backbone is placed inside a wrapper that handles time conditioning and, optionally, further conditioning such as class labels for controlled generation.

***Note: The backbone model can be adapted here!***

In [None]:
from src.diffusion_playground.models.backbones.cnn_denoiser import CNNDenoiser
from src.diffusion_playground.models.backbones.cnn_denoiser_large import CNNDenoiserLarge
from src.diffusion_playground.models.backbones.cnn_denoiser_large_attention import CNNDenoiserLargeAttention

# Create the CNN denoiser model for RGB images
backbone = CNNDenoiser(
    in_channels=INPUT_CHANNELS,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
)
backbone.to(device)

## Create the Diffusion Wrapper

This wrapper defines how the model is conditioned: on the time step only, or additionally on class labels for controlled generation.
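
As an illustration of time conditioning, the sketch below builds a sinusoidal time-step embedding, a common choice in diffusion models. How TimeAndClassConditionedModel actually embeds the time step is not shown in this notebook, so the function name `sinusoidal_time_embedding` and the `max_period` parameter are hypothetical. Only the default size of 128 mirrors TIME_EMB_DIM.

```python
import math


def sinusoidal_time_embedding(t, dim=128, max_period=10_000.0):
    """Map a time step t to a `dim`-sized vector of sines and cosines (assumed scheme)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Geometrically spaced frequencies, starting at 1 and decaying toward 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    # First half holds the sines, second half the cosines
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


emb = sinusoidal_time_embedding(t=250)
print(len(emb))
```

A class-conditioned wrapper could add a learned per-class vector of the same size to this embedding, but that is one possible design, not a description of the wrapper used here.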

In [None]:
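model = TimeAndClassConditionedModel(
    backbone_model=backbone,
    num_classes=NUM_CLASSES,
    time_emb_dim=TIME_EMB_DIM,
)
# Move the wrapper's own parameters (e.g. embeddings) to the same device as the backbone
model.to(device)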

# Print model parameters
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

## Create the MetricsTracker Instance

Create the MetricsTracker instance that periodically saves metrics and generated samples during training.
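
The `num_fid_samples` argument below suggests that the tracker computes a Fréchet-style distance between real and generated images. The following is a deliberately simplified sketch of that idea, not the MetricsTracker implementation: it assumes diagonal covariances and works on plain feature lists, whereas a standard FID uses Inception features and full covariance matrices. The name `diagonal_frechet_distance` is hypothetical.

```python
from statistics import fmean, pstdev


def diagonal_frechet_distance(real_feats, fake_feats):
    """Squared Frechet distance between two feature sets, assuming diagonal covariances."""
    dims = len(real_feats[0])
    total = 0.0
    for d in range(dims):
        real_col = [row[d] for row in real_feats]
        fake_col = [row[d] for row in fake_feats]
        mean_gap = fmean(real_col) - fmean(fake_col)
        std_gap = pstdev(real_col) - pstdev(fake_col)
        # Per dimension: (mu_r - mu_f)^2 + (sigma_r - sigma_f)^2
        total += mean_gap ** 2 + std_gap ** 2
    return total


real = [[0.0, 1.0], [2.0, 3.0]]
fake = [[1.0, 1.0], [3.0, 3.0]]  # same spread, first feature shifted by 1
print(diagonal_frechet_distance(real, fake))
```

Lower values mean the generated feature statistics are closer to the real ones.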

In [None]:
from src.diffusion_playground.evaluation.metrics_tracker import MetricsTracker

IMAGE_SHAPE = (32, 32, 3)
EVAL_CLASSES = list(range(NUM_CLASSES))

metrics_tracker = MetricsTracker(
    checkpoint_dir=CHECKPOINTS_DIR,
    noise_schedule=schedule,
    image_shape=IMAGE_SHAPE,
    eval_classes=EVAL_CLASSES,
    real_images=cifar_data,
    num_samples_per_class=4,
    num_fid_samples=256,
    device=device,
    class_names=class_idx_to_name,
)

## Train the Model

Train the model on the CIFAR-10 dataset.
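
For orientation, the sketch below shows the forward noising step and the mean squared error loss that a typical DDPM-style training iteration relies on. Whether `train_conditioned_denoiser` predicts the noise in exactly this way is an assumption; the helper names `add_noise` and `mse` are hypothetical, and plain Python lists stand in for image tensors.

```python
import math
import random


def add_noise(x0, alpha_bar_t, noise):
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    signal_scale = math.sqrt(alpha_bar_t)
    noise_scale = math.sqrt(1.0 - alpha_bar_t)
    return [signal_scale * x + noise_scale * e for x, e in zip(x0, noise)]


def mse(pred, target):
    """Mean squared error between predicted and true noise."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


rng = random.Random(0)
x0 = [0.5, -0.25, 1.0]  # stand-in for a flattened, normalized image
noise = [rng.gauss(0.0, 1.0) for _ in x0]
x_t = add_noise(x0, alpha_bar_t=0.5, noise=noise)

# A perfect noise predictor would return `noise` exactly, giving zero loss
print(mse(noise, noise))
```

During training, the model would receive `x_t`, the time step, and the class label, and its output would take the place of the first argument to `mse`.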

In [None]:
# Train the model
train_conditioned_denoiser(
    model=model,
    data=cifar_data,
    labels=cifar_labels,
    noise_schedule=schedule,
    epochs=1_000,
    lr=1e-3,
    batch_size=128,
    checkpoint_dir=CHECKPOINTS_DIR,
    save_every=1,
    resume=True,
    metrics_tracker=metrics_tracker,
)

## Inference

### Load a Checkpoint

Load any saved checkpoint. The corresponding .pt file must exist locally in the checkpoint directory.
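
To see which checkpoints are available before choosing one, a small helper along the following lines may be useful. It is a sketch that assumes all files follow the `checkpoint_epoch_<N>.pt` naming used in the cell below; the helper `list_checkpoints` is hypothetical and not part of the project.

```python
import re
from pathlib import Path

CHECKPOINT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.pt$")


def list_checkpoints(checkpoint_dir):
    """Return (epoch, path) pairs sorted by epoch number rather than by file name."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return []
    found = []
    for path in directory.glob("checkpoint_epoch_*.pt"):
        match = CHECKPOINT_PATTERN.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    return sorted(found, key=lambda pair: pair[0])


checkpoints = list_checkpoints("./checkpoints/cnn_denoiser")
if checkpoints:
    epoch, path = checkpoints[-1]
    print(f"Latest checkpoint: {path.name} (epoch {epoch})")
else:
    print("No checkpoints found")
```

Sorting by the parsed epoch number avoids the pitfall of lexicographic ordering, where `checkpoint_epoch_10.pt` would sort before `checkpoint_epoch_2.pt`.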

In [None]:
from src.diffusion_playground.training.denoiser_trainer import load_checkpoint

# Load checkpoint for testing
cp_name = "checkpoint_epoch_400.pt"
checkpoint_path = f"{CHECKPOINTS_DIR}/{cp_name}"

# Load the checkpoint
checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
print(f"Loaded model trained for {checkpoint_info['epoch']} epochs")
print(f"Training loss: {checkpoint_info['loss']:.6f}")

### Generate Samples Inline

Generate a few samples for a single class and display them inline.

CIFAR-10 class names:

- 0: Airplane
- 1: Automobile
- 2: Bird
- 3: Cat
- 4: Deer
- 5: Dog
- 6: Frog
- 7: Horse
- 8: Ship
- 9: Truck
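
If selecting a class by name is more convenient than by index, the mapping above can be inverted. The sketch below hard-codes the names from the list above; the helper `class_index` is hypothetical and independent of the `class_idx_to_name` object returned by `load_cifar_10`.

```python
CIFAR10_CLASS_NAMES = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]


def class_index(name):
    """Return the CIFAR-10 label index for a class name, ignoring case."""
    lookup = {n.lower(): i for i, n in enumerate(CIFAR10_CLASS_NAMES)}
    key = name.strip().lower()
    if key not in lookup:
        raise ValueError(f"Unknown class {name!r}; expected one of {CIFAR10_CLASS_NAMES}")
    return lookup[key]


print(class_index("horse"))  # 7
```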

In [None]:
from src.diffusion_playground.diffusion.backward import generate_samples_conditioned

# Setup for generation
model.eval()
num_samples = 4
class_idx = 7
class_idxes = [class_idx] * num_samples
class_name = class_idx_to_name[class_idx]

# Generate samples
images = generate_samples_conditioned(
    model=model,
    noise_schedule=schedule,
    image_shape=(3, 32, 32),
    class_labels=torch.tensor(class_idxes),
    device=device
)

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
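for idx, ax in enumerate(axes.flat):
    img = images[idx].detach().cpu()
    # imshow expects (H, W, C); permute if the sample is channels-first (C, H, W)
    if img.ndim == 3 and img.shape[0] == 3:
        img = img.permute(1, 2, 0)
    ax.imshow(img)
    ax.axis("off")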

# Title, Layout, Show
title = f"Generated CIFAR-10 Images - {cp_name}\nClass: {class_name} | Epoch: {checkpoint_info['epoch']} | Loss: {checkpoint_info['loss']:.6f}"
plt.suptitle(title, fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()