# Experiment 00: Pure Diffusion
This is the first of three related experiments. Here, I (pre-)train a diffusion model.

Experiment 3 then uses this model as its base model and fine-tunes it with
a GAN-based optimization target.

In [None]:
# imports

# standard library
import os
import sys
from typing import Any

# third-party
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import wandb

# local package
# pretend we are in the root folder:
sys.path.append("../")

from udl_2024_package.nn import unet_factory
from udl_2024_package.diffusion import DiffusionModel
from udl_2024_package.datasets import remove_dataset_labels, default_img_transforms


## Experiment configuration and hyperparameters

In [None]:
# WandB config:
project_name = "udl_2025_diffusion_gan"
group_name = "experiment_00_pure_diffusion"

config = {
    # Reproducibility: one global seed, applied below via `L.seed_everything`.
    "seed": 42,

    # Input data and training:
    "batch_size": 256,          # value for the real run; lower it for quick debug runs

    # Dataset-specific settings -- change ALL of the indented keys together
    # when switching datasets.
    "dataset_cls": datasets.CIFAR10,
        "ds_name": "cifar10",
        "num_channels": 3,
        "img_size": 32,         # CIFAR-10 images are natively 32x32
        "extra_transforms": [], # no resizing needed for CIFAR-10

    # Alternative dataset (needs `transforms` from torchvision):
    # "dataset_cls": datasets.MNIST,
    #     "ds_name": "mnist",
    #     "num_channels": 1,
    #     "img_size": 32,
    #     "extra_transforms": [   # MNIST is 28x28; resize to the U-Net input size
    #         transforms.Resize(32)
    #     ],

    "max_epochs": 150,
    "dl_num_workers": 4,

    # Diffusion model (DDPM paper defaults):
    "optimizer_cls": torch.optim.Adam,
    "optimizer_args": {"lr": 2e-4},
    "ddpm_steps": 1000,
    "ddpm_beta_start": 0.0001,
    "ddpm_beta_end": 0.02,

    # U-Net config:
    "block_out_channels": [128, 256, 256, 256],
    "layers_per_block": 2
}

# Seed Python, NumPy and torch RNGs before any data loading or model init.
# `workers=True` also makes Lightning derive distinct seeds for dataloader workers.
L.seed_everything(config["seed"], workers=True)

# Datasets are downloaded to $TMPDIR (e.g. node-local scratch on a cluster)
# when it is set, otherwise to the current directory.
datasets_path = os.path.join(os.environ.get("TMPDIR", os.curdir), "datasets")

# Shared DataLoader settings; the dataloader cell overrides `shuffle` for validation.
dataloader_kwargs = {
    "batch_size": config["batch_size"],
    "shuffle": True,
    "num_workers": config["dl_num_workers"],
    "pin_memory": True,
}

## Loading the data

In [None]:
# Dataset-specific extras from the config (e.g. `Resize(32)` for MNIST) run first,
# followed by the package's default image transforms. With an empty
# `extra_transforms` list this is exactly `default_img_transforms(...)`.
# NOTE(review): extras are applied to the raw dataset sample, assuming the
# default transforms are the step that converts it to a tensor -- TODO confirm.
ds_transforms = transforms.Compose(
    [*config["extra_transforms"], default_img_transforms(config["num_channels"])]
)

train_ds = config["dataset_cls"](datasets_path, transform=ds_transforms, download=True, train=True)
val_ds   = config["dataset_cls"](datasets_path, transform=ds_transforms, download=True, train=False)

# Strip class labels, presumably so batches contain images only (unconditional
# diffusion) -- see `remove_dataset_labels`.
train_ds = remove_dataset_labels(train_ds)
val_ds   = remove_dataset_labels(val_ds)

# Only the training data is shuffled: a fixed validation order keeps validation
# runs comparable across epochs, and Lightning warns about shuffled val loaders.
train_dl = DataLoader(train_ds, **dataloader_kwargs)
val_dl   = DataLoader(val_ds, **{**dataloader_kwargs, "shuffle": False})

## Setting up WandB

In [None]:
# Open the WandB run explicitly: the run object is handed to Lightning's logger
# below, and its name is reused in the checkpoint file names.
run = wandb.init(config=config, group=group_name, project=project_name)

# Wrap the existing run instead of letting WandbLogger start a second one.
wandb_logger = WandbLogger(experiment=run)

## Lightning callbacks

In [None]:
# Save a checkpoint every 5 epochs and keep every one of them (save_top_k=-1).
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{group_name}_models",
    # Yields e.g. "cifar10_<run-name>_epoch4.ckpt". With the default
    # auto_insert_metric_name=True, Lightning would rewrite "{epoch}" as
    # "epoch={epoch}", producing the duplicated "..._epochepoch=4.ckpt".
    filename=f"{config['ds_name']}_{run.name}_epoch{{epoch}}",
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=-1,
)

class LogImageSample(L.Callback):
    """Samples a few images from the diffusion model and logs them to WandB.

    Lightning calls ``on_validation_end`` after each validation loop (this
    includes the pre-training sanity check, if enabled).

    Args:
        logger: WandB logger used to upload the images.
        config: experiment config; reads ``"num_channels"`` and ``"img_size"``.
        num_images: number of images sampled per validation run (default 4).
    """

    def __init__(self, logger: WandbLogger, config: dict[str, Any], num_images: int = 4):
        super().__init__()
        self.logger = logger
        self.num_images = num_images
        self.channels = config["num_channels"]
        self.img_size = config["img_size"]

    def on_validation_end(self, trainer: L.Trainer, pl_module: DiffusionModel) -> None:
        """Draw ``num_images`` samples and log them under ``"generated images"``."""
        shape = (self.num_images, self.channels, self.img_size, self.img_size)
        sample = pl_module.sample_img(shape)
        # Map [-1, 1] -> [0, 1] (assumes that is the sampler's output range --
        # TODO confirm against DiffusionModel), then clamp: sampler outputs can
        # overshoot, and out-of-range pixels distort how WandB renders the image.
        sample = (sample.detach().cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
        # One WandB image per sample (iterates over the batch dimension).
        self.logger.log_image(
            key="generated images",
            images=list(sample)
        )

# Image-sampling callback bound to this run's WandB logger and dataset shape.
log_img_callback = LogImageSample(logger=wandb_logger, config=config)

## Setting up the model and the trainer

In [None]:
# Denoising backbone: a U-Net shaped by the dataset and architecture settings.
unet = unet_factory(**{
    "img_size": config["img_size"],
    "img_channels": config["num_channels"],
    "block_out_channels": config["block_out_channels"],
    "layers_per_block": config["layers_per_block"],
})

# DDPM wrapper around the U-Net; noise schedule and optimizer come from the config.
model = DiffusionModel(unet, **{
    "optimizer_cls": config["optimizer_cls"],
    "optimizer_args": config["optimizer_args"],
    "steps": config["ddpm_steps"],
    "beta_start": config["ddpm_beta_start"],
    "beta_end": config["ddpm_beta_end"],
})

# Trainer: logs to WandB, saves periodic checkpoints and sample images.
trainer = L.Trainer(
    callbacks=[checkpoint_callback, log_img_callback],
    logger=wandb_logger,
    max_epochs=config["max_epochs"],
)

## Training

In [None]:
# Train, then close the WandB run explicitly: in a notebook the kernel outlives
# the training run, so the run would otherwise stay open until the kernel exits.
# A crash or interrupt marks the run as failed (non-zero exit code).
try:
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
except BaseException:
    run.finish(exit_code=1)
    raise
else:
    run.finish()