# Tiny-ImageNet - Let's step up the Game! 🏃‍➡️

This notebook repeats the CIFAR-10 experiment with the same technique (the same time- and class-conditional wrapper), but this time... for Tiny-ImageNet (64x64 images)!

## 1. Prerequisites

- Imports, data, etc.

The usual thing 😉

In [None]:
import torch
import matplotlib.pyplot as plt

from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.data_loader.imagenet import load_imagenet

from src.diffusion_playground.models.backbones.cnn_denoiser_xl_attention import CNNDenoiserXLAttention

from src.diffusion_playground.models.time_and_class_conditioned_model import TimeAndClassConditionedModel
from src.diffusion_playground.training.denoiser_trainer import train_conditioned_denoiser

# Set global constants
PROJECT_DIR = "."
TIME_EMB_DIM = 128
BASE_CHANNELS = 128
CHECKPOINTS_DIR = f"{PROJECT_DIR}/checkpoints/cnn_denoiser_conditioned_xl_attention"

# Specific to the dataset: images are RGB, i.e. three input channels
INPUT_CHANNELS = 3

# Load the training split. Besides the data and labels, the loader returns a
# mapping from class index to human-readable class name (used for plot titles).
data, labels, class_idx_to_name = load_imagenet(
    split="train",
    path_data="../../../data/tiny-imagenet",
)

# Determine the number of classes: one entry per class in the index -> name mapping
NUM_CLASSES = len(class_idx_to_name)
print(f"Number of classes: {NUM_CLASSES}")

# Determine device: use a CUDA GPU when one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Define noise schedule: a LinearNoiseSchedule over 1,000 diffusion time steps
schedule = LinearNoiseSchedule(time_steps=1_000)

## 2. Backbone 🦴

Create the backbone: the denoising network at the core of the diffusion process.

In [None]:
# Instantiate the attention-augmented XL CNN denoiser that operates on
# INPUT_CHANNELS-channel (RGB) images, sized by the global width/embedding constants.
backbone = CNNDenoiserXLAttention(
    time_emb_dim=TIME_EMB_DIM,
    base_channels=BASE_CHANNELS,
    in_channels=INPUT_CHANNELS,
)
# Place the backbone's parameters on the selected device
# (kept as the cell's last expression so the notebook shows the module summary).
backbone.to(device)

## 3. Embed the model into Conditional Wrapper 🎁

Condition the model on time and class

In [None]:
# Wrap the backbone so that it is conditioned on both the diffusion time step
# and the class label.
model = TimeAndClassConditionedModel(
    backbone_model=backbone,
    num_classes=NUM_CLASSES,
    time_emb_dim=TIME_EMB_DIM,
)
# Move the whole wrapper to `device`. Only the backbone was moved above; any
# layers the wrapper creates itself (presumably class/time embeddings - TODO
# confirm) would otherwise remain on the CPU. `.to()` on an already-placed
# module is a no-op, so this is safe even if the trainer also moves the model.
model.to(device)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Train the Model 💪

Let's send that guy to the gym!

In [None]:
# Launch (or resume) training of the time- and class-conditioned denoiser
train_conditioned_denoiser(
    # What to train, and on which data / schedule
    model=model,
    noise_schedule=schedule,
    data=data,
    labels=labels,
    # Optimisation hyper-parameters
    batch_size=128,
    lr=1e-3,
    epochs=1_000,
    # Checkpointing: `save_every` is presumably measured in epochs, and
    # `resume=True` presumably continues from an existing checkpoint in
    # CHECKPOINTS_DIR - confirm against the trainer implementation.
    checkpoint_dir=CHECKPOINTS_DIR,
    save_every=10,
    resume=True,
)

## 5. Inference

Load a checkpoint! 🏁

In [None]:
from src.diffusion_playground.training.denoiser_trainer import load_checkpoint

# Pick the checkpoint file to evaluate and join it onto the checkpoint directory
cp_name = "checkpoint_epoch_5.pt"
checkpoint_path = "/".join((CHECKPOINTS_DIR, cp_name))

# Restore the checkpoint into `model` (presumably loading its weights onto
# `device`); the returned mapping carries the "epoch" and "loss" metadata
# reported below and reused in the sampling plot title.
checkpoint_info = load_checkpoint(model, checkpoint_path, device=device)
print(f"Loaded model trained for {checkpoint_info['epoch']} epochs")
print(f"Training loss: {checkpoint_info['loss']:.6f}")

## 6. Generate some Samples 🏭

Let's see how well we perform...

In [None]:
import math

from src.diffusion_playground.diffusion.backward import generate_samples_conditioned

# Setup for generation: put the model in inference mode (affects layers such
# as dropout / batch-norm, if present) and request `num_samples` images of a
# single class.
model.eval()
num_samples = 4
class_idx = 7
class_idxes = [class_idx] * num_samples
class_name = class_idx_to_name[class_idx]

# Generate samples: Tiny-ImageNet images are 64x64 with INPUT_CHANNELS channels
images = generate_samples_conditioned(
    model=model,
    noise_schedule=schedule,
    image_shape=(INPUT_CHANNELS, 64, 64),
    class_labels=torch.tensor(class_idxes),
    device=device,
)

# Visualize on a near-square grid derived from num_samples, so changing the
# sample count no longer breaks the plot (4 samples -> 2x2 at 8x8 inches).
n_cols = math.ceil(math.sqrt(num_samples))
n_rows = math.ceil(num_samples / n_cols)
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
for idx, ax in enumerate(axes.flat):
    ax.axis("off")
    if idx >= num_samples:
        continue  # grid cell without a sample stays blank
    img = images[idx].detach().cpu()
    # imshow needs (H, W, C); convert if the sampler returned channel-first (C, H, W)
    if img.ndim == 3 and img.shape[0] == INPUT_CHANNELS and img.shape[-1] != INPUT_CHANNELS:
        img = img.permute(1, 2, 0)
    # NOTE(review): imshow clips float data to [0, 1]; if samples live in
    # [-1, 1] they should be rescaled first - confirm the sampler's output range.
    ax.imshow(img)

# Title, Layout, Show
title = (
    f"Samples Generated by a Model Trained on Tiny-ImageNet - {cp_name}\n"
    f"Class: {class_name} | Epoch: {checkpoint_info['epoch']} | "
    f"Loss: {checkpoint_info['loss']:.6f}"
)
plt.suptitle(title, fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()