# Experiment 02B WGAN fine-tuning of Diffusion U-Net
The last of three related experiments. Here, I fine-tune the U-Net from
experiment 00 as if it were the generator in a GAN setup, to see whether it
converges faster than the GAN in experiment 01.

In [None]:
# imports

# pretend we are in the root folder:
import os
import sys
sys.path.append("../")

from udl_2024_package.nn import unet_factory, SimpleDiscriminator
from udl_2024_package.diffusion import DiffusionModel
from udl_2024_package.wgan import WGANWithGradientPenalty
from udl_2024_package.datasets import remove_dataset_labels, default_img_transforms

import torch
from torchvision import datasets
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import wandb


Experiment configuration and hyper-parameters:

In [None]:
# WandB config:
project_name = "udl_2025_diffusion_gan"
group_name = "experiment_02b_wgan_finetune"

# All hyper-parameters of the run in one place. The dict is handed to
# `wandb.init` as the run config and read by the later cells that build the
# data loaders, the models, the callbacks and the trainer.
config = {
    # Input data and training:
    "batch_size": 256,          # Already at the real-run value (256)

    # Active dataset. The indented keys below describe the dataset and must be
    # kept in sync with "dataset_cls".
    "dataset_cls": datasets.CIFAR10,
        "ds_name": "cifar10",   # Warning: change when changing dataset
        "num_channels": 3,      # Warning: change when changing dataset
        "img_size": 32,         # Warning: change when changing dataset
        # NOTE(review): "extra_transforms" is not read by any cell visible in
        # this notebook - confirm whether it should be applied to the datasets.
        "extra_transforms": [], # Warning: change when changing dataset

    # Alternative dataset (MNIST). Enabling it also requires
    # `from torchvision import transforms`, which this notebook does not import.
    # "dataset_cls": datasets.MNIST,
    #     "ds_name": "mnist",     # Warning: change when changing dataset
    #     "num_channels": 1,      # Warning: change when changing dataset
    #     "img_size": 32,         # Warning: change when changing dataset
    #     "extra_transforms": [   # Warning: change when changing dataset
    #         transforms.Resize(32)
    #     ],

    "max_epochs": 150,
    "dl_num_workers": 4,

    # WGAN model (WGAN-GP paper defaults):
    # NOTE(review): "gp_weight" and "critic_iterations" look like the paper
    # values; the optimizer settings (separate generator learning rate, betas,
    # weight decay) appear to be tuned for this experiment - confirm.
    "optimizer_cls": torch.optim.Adam,
    "gen_optimizer_args": { "lr": 1e-5, "betas": (0.5, 0.99), "weight_decay": 1e-5 },
    "cri_optimizer_args": { "lr": 1e-4, "betas": (0.5, 0.99) },
    "gp_weight": 10,
    "critic_iterations": 5,
    # NOTE(review): the cell that consumes this flag freezes `unet.down_blocks`
    # and `unet.mid_block`, whereas the key name suggests the up path - confirm
    # which one is intended.
    "freeze_gen_upnet": True,

    # DDPM values (not really used, but for loading checkpoint):
    "ddpm_steps": 1000,
    "ddpm_beta_start": 0.0001,
    "ddpm_beta_end": 0.02,

    # U-Net config:
    "block_out_channels": [128, 256, 256, 256],
    "layers_per_block": 2,

    # Starting point for training:
    # (path is relative to the directory the notebook is run from)
    "unet_checkpoint": "./cifar10_peachy-totem-7_epochepoch=94.ckpt",

    # Critic config:
    # (four entries each; the first two entries request downsampling)
    "cri_channel_list": [128] * 4,
    "cri_kernel_list": [3] * 4,
    "cri_downsample_list": [True] * 2 + [False] * 2
}

# Datasets are stored in "$TMPDIR/datasets" when TMPDIR is set, otherwise in
# "./datasets" relative to the current working directory.
datasets_path = os.path.join(os.environ.get("TMPDIR", os.curdir), "datasets")

# Keyword arguments forwarded to `DataLoader`.
dataloader_kwargs = {
    "batch_size": config["batch_size"],
    "shuffle": True,
    "num_workers": config["dl_num_workers"],
    "pin_memory": True,
}

Getting the dataloaders

In [None]:
# Project-default image transforms, parameterised by the number of channels.
# NOTE(review): config["extra_transforms"] is not applied here - confirm
# whether `default_img_transforms` is supposed to receive it.
ds_transforms = default_img_transforms(config["num_channels"])

# Download (if necessary) and open the training and held-out splits.
train_ds = config["dataset_cls"](datasets_path, transform=ds_transforms, download=True, train=True)
val_ds   = config["dataset_cls"](datasets_path, transform=ds_transforms, download=True, train=False)

# Strip the class labels with the project helper; the sampling function used
# later calls `torch.randn_like` on the batch, so batches must be plain tensors.
train_ds = remove_dataset_labels(train_ds)
val_ds   = remove_dataset_labels(val_ds)

train_dl = DataLoader(train_ds, **dataloader_kwargs)
# Validation batches are served in a fixed order: shuffling them brings no
# benefit, makes validation metrics less comparable between epochs, and
# triggers a Lightning warning. All other loader settings are shared.
val_dl   = DataLoader(val_ds, **{**dataloader_kwargs, "shuffle": False})

Function for generating samples with the WGAN generator. Note that:
- It needs to pass a timestep (`ts`) to the model, because the model was
  actually built for diffusion. I always pass 999, which corresponds to the
  pure-noise timestep.
- In contrast to the variant in notebook 02A, this one is identical to the one
  in notebook 01: we treat the model output directly as the generator output.
  See 02A for what I tried first, and why it didn't work.

In [None]:
def generate_samples(generator, real_batch, timestep=999):
    """Generate one sample per element of ``real_batch`` with the generator.

    Args:
        generator: Callable invoked as ``generator(noise, ts)``.
        real_batch: Tensor of real data. Only its shape, dtype and device are
            used, to create matching Gaussian noise; its values are ignored.
        timestep: Integer diffusion timestep passed to the generator for every
            sample. Defaults to 999, the value this notebook uses as the
            pure-noise timestep.

    Returns:
        Whatever ``generator(noise, ts)`` returns; it is treated directly as
        the generator output.
    """
    noise = torch.randn_like(real_batch)
    # One (identical) integer timestep per sample, on the same device as the
    # batch so the generator does not have to move it.
    ts = torch.full(
        (len(real_batch),),
        timestep,
        dtype=torch.long,
        device=real_batch.device,
    )
    return generator(noise, ts)

Setting up WandB

In [None]:
# Open the WandB run for this experiment and attach the full `config` dict to
# it. For a dry run without any logging, pass mode="disabled" to `wandb.init`
# (and remember to remove it again before the real run).
run = wandb.init(project=project_name, group=group_name, config=config)

# Make Lightning log into the run created above.
wandb_logger = WandbLogger(experiment=run)

Lightning Callback functions

In [None]:
# Save a checkpoint every 10 epochs and keep all of them (save_top_k=-1).
#
# By default Lightning expands every "{metric}" placeholder in `filename` to
# "metric=<value>". Combined with the literal "epoch" in the template this
# produced names such as "..._epochepoch=94.ckpt". Turning the automatic
# insertion off yields "..._epoch94.ckpt" instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{group_name}_models",
    filename=f"{config['ds_name']}_{run.name}_epoch{{epoch}}",
    auto_insert_metric_name=False,
    every_n_epochs=10,
    save_top_k=-1,
)

from typing import Any


# Copied from 01, instead of 02A:
class LogImageSample(L.Callback):
    """Log a fixed set of generated images to WandB after every validation run.

    The noise and timestep tensors are created once in ``__init__`` and reused
    on every call, so the logged images always come from the same inputs and
    can be compared from one validation run to the next.

    Args:
        logger: WandB logger the images are sent to.
        config: Experiment configuration; the keys ``"num_channels"`` and
            ``"img_size"`` determine the shape of the noise.
        num_samples: Number of images to generate and log.
        diffusion_step: Timestep passed to the generator for every image
            (typically the last step of the diffusion model).
    """

    def __init__(
        self,
        logger: WandbLogger,
        config: dict[str, Any],
        num_samples: int = 6,
        diffusion_step: int = 999,
    ):
        super().__init__()
        self.logger = logger
        self.noise_sample = torch.randn((
            num_samples,
            config["num_channels"],
            config["img_size"],
            config["img_size"],
        ))
        self.ts = torch.full((num_samples,), diffusion_step)

    def on_validation_end(self, trainer, pl_module: WGANWithGradientPenalty):
        """Run the generator on the fixed inputs and log the resulting images."""
        generator = pl_module.gen

        # The inputs are created on the CPU; move them once to wherever the
        # module lives and keep them there for subsequent calls.
        if self.noise_sample.device != pl_module.device:
            self.noise_sample = self.noise_sample.to(pl_module.device)
            self.ts = self.ts.to(pl_module.device)

        sample = generator(self.noise_sample, self.ts)
        # Map the output from [-1, 1] to [0, 1] for display - assumes the
        # training images were normalised with mean = std = 0.5 (TODO confirm
        # against `default_img_transforms`).
        sample = sample.detach().cpu() * 0.5 + 0.5
        self.logger.log_image(
            key="generated images",
            images=list(sample),
        )


log_img_callback = LogImageSample(wandb_logger, config)

Setting up the models and the trainer

In [None]:
# Build a U-Net with the configured architecture; it serves as the container
# into which the checkpoint is loaded.
unet = unet_factory(
    layers_per_block=config["layers_per_block"],
    block_out_channels=config["block_out_channels"],
    img_channels=config["num_channels"],
    img_size=config["img_size"],
)

# Restore the diffusion wrapper from the checkpoint and keep only the network
# inside it (`.model`). The optimizer and DDPM arguments are passed for loading
# only; the wrapper itself is discarded right away.
unet = DiffusionModel.load_from_checkpoint(
    config["unet_checkpoint"],
    model=unet,
    steps=config["ddpm_steps"],
    beta_start=config["ddpm_beta_start"],
    beta_end=config["ddpm_beta_end"],
    optimizer_cls=config["optimizer_cls"],
    optimizer_args=config["cri_optimizer_args"],
).model

# Optionally exclude the down blocks and the mid block from training.
if config["freeze_gen_upnet"]:
    for frozen_part in (unet.down_blocks, unet.mid_block):
        for weight in frozen_part.parameters():
            weight.requires_grad = False

# The critic that scores real versus generated images.
critic = SimpleDiscriminator(
    downsample_list=config["cri_downsample_list"],
    kernel_list=config["cri_kernel_list"],
    channel_list=config["cri_channel_list"],
    in_channels=config["num_channels"],
)

# Combine the fine-tuned U-Net (as generator) and the critic in the WGAN-GP
# Lightning module.
model = WGANWithGradientPenalty(
    generator=unet,
    critic=critic,
    generator_func=generate_samples,
    critic_iterations=config["critic_iterations"],
    gp_weight=config["gp_weight"],
    optimizer_cls=config["optimizer_cls"],
    cri_optimizer_args=config["cri_optimizer_args"],
    gen_optimizer_args=config["gen_optimizer_args"],
)

# Let WandB track the model; "all" requests both gradients and parameters.
wandb_logger.watch(model, log_freq=500, log="all")

In [None]:
# Lightning trainer for the run: logs to WandB, writes periodic checkpoints and
# logs sample images. The progress bar is switched off.
trainer = L.Trainer(
    logger=wandb_logger,
    callbacks=[checkpoint_callback, log_img_callback],
    max_epochs=config["max_epochs"],
    enable_progress_bar=False,
)

Doing the actual training run

In [None]:
# Always close the WandB run, also when training raises or is interrupted, so
# it does not stay open in the notebook kernel. The exit code tells WandB
# whether the run finished successfully (0) or not (1).
exit_code = 1
try:
    trainer.fit(model, train_dl, val_dl)
    exit_code = 0
finally:
    wandb.finish(exit_code)