## Text-to-Image Conditional Diffusion with Classifier-Free Guidance

In this notebook, we will train a **text-to-image conditional diffusion model** on a cropped version of the **TF Flowers** dataset. The goal is to generate flower images conditioned on semantic information while retaining the ability to generate unconditioned samples.

Our model combines:
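- A pretrained **CLIP** model for conditioning (image embeddings during training, text-prompt embeddings at sampling time, both living in CLIP's shared embedding space);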
- A **U-Net–based DDPM** for image generation;
- **Classifier-Free Diffusion Guidance (CFG)** to strengthen conditioning at sampling time without requiring a separate classifier.

To this end, we train a Conditional Diffusion Model, where the image denoising network is conditioned on both the diffusion timestep $t$, embedded using a sinusoidal embedding, and a context embedding $c$ derived from text. The overall architecture follows the **Denoising Diffusion Probabilistic Models (DDPM)** framework ([Ho et al., 2020](https://arxiv.org/pdf/2006.11239)), with a modified U-Net that supports conditional inputs. As in standard diffusion models, timestep embeddings provide the model with information about the current noise level during denoising.

To condition the model, we inject text-derived embeddings, represented by the context vector $c$, into the U-Net. Sample quality is further improved via CFG, which allows us to control how strongly the conditioning influences generation at inference time. To enable this mechanism, the model is trained to operate both _with conditioning_ and _without conditioning_. This is achieved by randomly masking the context input during training using a Bernoulli distribution. During sampling, noise predictions with and without conditioning are linearly combined using a _guidance weight_, and this amplified noise estimate is then used in the reverse diffusion process.
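As a minimal sketch of the guidance step, assuming the parameterization of Ho & Salimans (2022), the guided estimate is $\tilde{\epsilon} = (1 + w)\,\epsilon_\theta(x_t, t, c) - w\,\epsilon_\theta(x_t, t, \varnothing)$, where $w$ is the guidance weight. The function name `guided_noise` and the NumPy arrays standing in for UNet outputs are illustrative assumptions; the notebook's own sampling code (inside the `handsoncv` package) is not shown here and may use a different weighting convention.

```python
import numpy as np


def guided_noise(eps_cond: np.ndarray, eps_uncond: np.ndarray, w: float) -> np.ndarray:
    """Classifier-free guidance: push the conditional noise estimate away from
    the unconditional one by the guidance weight w (w = 0 returns eps_cond)."""
    return (1.0 + w) * eps_cond - w * eps_uncond


# Toy noise predictions for one 3x32x32 image (stand-ins for UNet outputs)
rng = np.random.default_rng(0)
eps_c = rng.standard_normal((3, 32, 32))  # prediction with context c
eps_u = rng.standard_normal((3, 32, 32))  # prediction with masked context

eps_guided = guided_noise(eps_c, eps_u, w=2.0)
```

Both estimates are commonly obtained from the same network in a single batched forward pass, with the context masked out for the unconditional half of the batch.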

In the following cell, we import the modules and functions defined in our `src` package, set random seeds for reproducibility, and define the parameters and file paths used throughout the notebook.

In [None]:
%load_ext autoreload
%autoreload 2

import os

# On a multi-GPU system, this hides all GPUs except the first.
# It must be set before CUDA is initialized, so we set it before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import clip
import wandb
import torchvision.transforms as transforms
import random

from torch.utils.data import DataLoader, Subset

# Custom modules
from handsoncv.datasets import generate_clip_metadata, TFflowersCLIPDataset
from handsoncv.models import UNet 
from handsoncv.utils import DDPM, set_seed, seed_worker
from handsoncv.training import train_diffusion

# Hardware & Paths
NOTEBOOK_DIR = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, "..", ".."))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

# Folders we frequently use across the experiments' notebooks
ROOT_PATH = os.path.join(PROJECT_ROOT, "Assignment-3")
ROOT_DATA = os.path.join(ROOT_PATH, "data")
DATA_DIR = f"{ROOT_DATA}/cropped_flowers"
SAMPLE_DIR = f"{ROOT_DATA}/05_flowers_images"
CSV_PATH = f"{ROOT_DATA}/clip_embeddings_metadata.csv"

CHECKPOINTS_DIR = os.path.join(ROOT_PATH, "checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

# Numpy and Torch Reproducibility
SEED = 42
set_seed(SEED)

# Base Configuration Parameters
BATCH_SIZE = 128

cuda
Seeds set to 42 for reproducibility.


### Data Preparation 

In this cell, we load a pretrained CLIP `ViT-B/32` model and use it to precompute image embeddings for every image in the dataset (the train/validation split happens later). For each image, we extract its CLIP image embedding and store it, together with the image path, in a `.csv` file in the data directory.

This preprocessing step is performed once and avoids repeatedly encoding images with CLIP during training, which would be computationally expensive. If the metadata file already exists, this step is skipped to ensure reproducibility and reduce unnecessary computation.
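The caching logic described above can be sketched as follows. This is a hypothetical stand-in for `generate_clip_metadata`, whose implementation is not shown in this notebook: `write_embedding_csv`, `read_embedding_csv`, and the `encode_fn` callback are illustrative names, and the assumed CSV layout (image path followed by one column per embedding dimension) may differ from the project's actual file. With the real model, `encode_fn` would wrap `clip_preprocess` and `clip_model.encode_image`.

```python
import csv
import os
import tempfile
from typing import Callable, Dict, List, Sequence


def write_embedding_csv(
    image_paths: Sequence[str],
    encode_fn: Callable[[str], Sequence[float]],
    csv_path: str,
) -> bool:
    """Encode each image once and cache (path, embedding) rows in a CSV.

    Returns False without doing any work if the cache file already exists.
    """
    if os.path.exists(csv_path):
        return False  # cache hit: skip the expensive encoding pass
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        for path in image_paths:
            writer.writerow([path, *encode_fn(path)])
    return True


def read_embedding_csv(csv_path: str) -> Dict[str, List[float]]:
    """Load the cache back into a {path: embedding} dictionary."""
    with open(csv_path, newline="") as f:
        return {row[0]: [float(v) for v in row[1:]] for row in csv.reader(f)}


def dummy_encode(path: str) -> List[float]:
    """Hypothetical 4-d 'encoder' used in place of CLIP for this toy example."""
    return [float(len(path)), 0.0, 1.0, -1.0]


paths = ["daisy/0001.jpg", "roses/0002.jpg"]
cache = os.path.join(tempfile.mkdtemp(), "clip_embeddings_metadata.csv")

first = write_embedding_csv(paths, dummy_encode, cache)   # creates the file
second = write_embedding_csv(paths, dummy_encode, cache)  # skipped: file exists
embeddings = read_embedding_csv(cache)
```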

In [2]:
# Prepare metadata (creates clip_embeddings_metadata.csv if it does not exist yet)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE)

if not os.path.exists(CSV_PATH):
    print("Generating CLIP metadata...")
    generate_clip_metadata(DATA_DIR, CSV_PATH, clip_model, clip_preprocess, DEVICE)

We train on a modified version of the **TF Flowers dataset** consisting of cropped color images focused on the flower itself. These crops reduce background clutter and encourage the model to focus on the shape, color, and texture of the flowers. The dataset was provided as part of NVIDIA’s course [Generative AI with Diffusion Models](https://learn.nvidia.com/courses/course-detail?course_id=course-v1:DLI+C-FX-08+V1) and was downloaded as a zipped archive via a shared Google Drive link.

The reduced version of the dataset contains only three classes (daisies, roses, and sunflowers) and roughly 1100 images in total. With so few samples, the model is unlikely to capture the full distribution of natural flower images. To mitigate overfitting and improve generalization, we apply random horizontal flipping as a data augmentation technique during training.

---

All images are resized to 32×32 pixels, following NVIDIA’s reference implementation and model constraints. While this enables efficient training, it also leads to a significant loss of fine-grained spatial detail, which further limits the model’s ability to learn high-frequency structures and detailed textures.

The dataset is split into 95% training and 5% validation subsets using a fixed random seed to ensure reproducibility. The split is performed once by shuffling the dataset indices and creating separate `Subset` instances for training and validation.
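A minimal sketch of a reproducible 95/5 index split is shown here. Unlike the data-loading cell, which shuffles with Python's global `random` state, this variant uses a dedicated `random.Random(seed)` instance, so the split does not depend on how many random calls happened earlier in the notebook. `split_indices` is a hypothetical helper, not part of `handsoncv`.

```python
import random
from typing import List, Tuple


def split_indices(n: int, train_frac: float = 0.95, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle range(n) with a private RNG and cut it into train/val index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # independent of the global random state
    split = int(train_frac * n)
    return indices[:split], indices[split:]


train_idx, val_idx = split_indices(1100)  # ~1100 images, as in the cropped dataset
```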

Images are normalized to the range [-1, 1], which matches the input assumptions of the diffusion model. Training data uses both base preprocessing and augmentation, while validation data uses only deterministic base transforms. The custom `TFflowersCLIPDataset` returns a transformed RGB image tensor of shape $(3, 32, 32)$ and the corresponding CLIP embedding vector.
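The affine map implemented by the `Lambda` transform, together with its inverse for turning generated samples back into displayable images, can be sketched in NumPy. `to_model_range` and `to_display_range` are hypothetical names, and clipping in the inverse is an assumption, since the reverse process can overshoot $[-1, 1]$ slightly.

```python
import numpy as np


def to_model_range(x01: np.ndarray) -> np.ndarray:
    """Map pixels from [0, 1] (the output of ToTensor) to [-1, 1]."""
    return x01 * 2.0 - 1.0


def to_display_range(x: np.ndarray) -> np.ndarray:
    """Map model-space values back to [0, 1], clipping any overshoot."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


pixels = np.array([0.0, 0.25, 0.5, 1.0])
scaled = to_model_range(pixels)      # [-1.0, -0.5, 0.0, 1.0]
restored = to_display_range(scaled)  # [0.0, 0.25, 0.5, 1.0]
```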

In [3]:
# Base transforms used by both training and validation data
base_t = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)
]

# Training: Base + Augmentation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # Augmentation also used in NVIDIA's notebook
    *base_t 
])

# Validation: Base only
val_transform = transforms.Compose(base_t)

# Cropped TF Flowers Data Loading
# Temporary, untransformed dataset instance used only to count the samples
temp_ds = TFflowersCLIPDataset(CSV_PATH)
dataset_size = len(temp_ds)
indices = list(range(dataset_size))
split = int(0.95 * dataset_size)

# Shuffle indices once
random.shuffle(indices)
train_indices, val_indices = indices[:split], indices[split:]

# Create a seeded Generator to pass to the DataLoaders
g = torch.Generator()
g.manual_seed(SEED)

# Create two separate Dataset Instances
train_ds = Subset(TFflowersCLIPDataset(CSV_PATH, transform=train_transform), train_indices)
val_ds = Subset(TFflowersCLIPDataset(CSV_PATH, transform=val_transform), val_indices)

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, worker_init_fn=seed_worker, generator=g, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, generator=g)

### Training the UNet–DDPM Model

In this section, we configure and train a **conditional UNet-based DDPM** for image generation. We define the training hyperparameters, initialize the diffusion process, construct the UNet architecture, and launch the training loop. The diffusion process uses $T = 400$ timesteps with a linear noise schedule ($\beta_t$ increasing from $10^{-4}$ to $0.02$), and conditioning is provided via CLIP image embeddings.
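For reference, the linear schedule defines $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s \le t} \alpha_s$, and the forward process can sample $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$ in closed form. The NumPy sketch below shows this together with the standard DDPM training target (mean squared error between the true and predicted noise). It is not the `DDPM` class from `handsoncv.utils`, whose internals may differ, and the all-zeros "prediction" is a placeholder for the UNet.

```python
import numpy as np

T = 400
betas = np.linspace(1e-4, 0.02, T)  # same endpoints as BETAS in the training cell
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)     # \bar{alpha}_t, decreasing towards 0


def q_sample(x0: np.ndarray, t: int, eps: np.ndarray) -> np.ndarray:
    """Closed-form forward diffusion: noise x0 directly to timestep t (0-indexed)."""
    a_bar = alpha_bars[t]
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 32, 32))  # a "clean" image in [-1, 1]
t = int(rng.integers(0, T))                     # random timestep
eps = rng.standard_normal(x0.shape)
x_t = q_sample(x0, t, eps)

eps_pred = np.zeros_like(eps)                   # placeholder for model(x_t, t, c)
loss = np.mean((eps - eps_pred) ** 2)           # simple noise-prediction loss
```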

To ensure reproducibility, random seeds are reset prior to model initialization. Training uses the Adam optimizer with a `ReduceLROnPlateau` learning rate scheduler that adaptively reduces the learning rate when the validation loss stagnates (here it halves the rate after 15 epochs without improvement, down to a minimum of $5 \times 10^{-5}$). During training, CFG is enabled by randomly dropping the conditioning information with probability `drop_prob=0.1`. Periodically, the model generates samples from a small set of text prompts to qualitatively monitor progress. All training statistics and metrics are logged to [diffusion-model-assessment-v2](https://wandb.ai/handsoncv-research/diffusion-model-assessment-v2/workspace); for reference, see the runs labeled `*-training`, especially the pinned ones.
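The context dropout used for CFG training can be sketched as follows, assuming that a dropped sample's embedding is replaced by zeros (a common choice; the exact masking inside `train_diffusion` and `UNet` is not shown here). With `drop_prob=0.1`, roughly one sample in ten is trained unconditionally. `mask_context` is a hypothetical helper.

```python
import numpy as np


def mask_context(c: np.ndarray, drop_prob: float, rng: np.random.Generator) -> np.ndarray:
    """Zero each sample's context vector independently with probability drop_prob.

    c has shape (batch, embed_dim); a per-sample Bernoulli(1 - drop_prob)
    'keep' mask is broadcast over the embedding dimension.
    """
    keep = rng.binomial(n=1, p=1.0 - drop_prob, size=(c.shape[0], 1))
    return c * keep


rng = np.random.default_rng(0)
c = rng.standard_normal((128, 512))             # a batch of 512-d CLIP embeddings
c_masked = mask_context(c, drop_prob=0.1, rng=rng)
dropped = int(np.sum(~c_masked.any(axis=1)))    # rows that became all zeros
```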

---

#### UNet-DDPM Architecture

The model follows a UNet architecture ([Ronneberger et al., 2015](https://arxiv.org/pdf/1505.04597)) tailored for diffusion-based generative modeling and conditional image synthesis.

1. Encoder (Downsampling Path):
- An initial residual convolution block projects the input image into feature space.
- Two downsampling stages composed of:
  - Convolutional blocks with **Group Normalization** and **GELU** activations.
  - Spatial downsampling via **rearrangement-based pooling**.

2. Bottleneck:
- Operates directly on a **low-resolution spatial feature map** (e.g., 8×8) rather than flattening features into a vector.
- Optionally includes a **self-attention block** to capture long-range spatial dependencies.

3. Conditioning Mechanism:
- Diffusion timesteps are embedded using **sinusoidal positional encodings**.
- CLIP conditioning vectors are projected into spatial embeddings using learned embedding blocks.
- Conditioning is injected via **scale-and-shift modulation** during the upsampling stages.
- A Bernoulli mask randomly removes conditioning information during training, which enables **classifier-free guidance** at sampling time.

4. Decoder (Upsampling Path):
- Upsampling is performed using **nearest-neighbor interpolation followed by convolution**, avoiding transposed convolutions.
- Skip connections from the encoder preserve fine-grained spatial details.
- A final convolution maps features back to RGB space, producing a **noise prediction** at each diffusion timestep.
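Two of the conditioning pieces listed above can be sketched in NumPy: a Transformer-style sinusoidal timestep embedding, and scale-and-shift modulation of a feature map. Both are illustrative assumptions: the notebook's `UNet` may order the sine and cosine halves differently, and the exact modulation formula (here `features * scale + shift`, with `scale` and `shift` derived from the CLIP and time embeddings) is not shown in this notebook. Random linear maps stand in for the learned embedding blocks.

```python
import math
import numpy as np


def sinusoidal_embedding(t: np.ndarray, dim: int) -> np.ndarray:
    """Map integer timesteps of shape (batch,) to (batch, dim) sin/cos features."""
    half = dim // 2
    # Geometric range of frequencies from 1 down to 1/10000
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / half)
    args = t[:, None].astype(np.float64) * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)


def modulate(features: np.ndarray, scale: np.ndarray, shift: np.ndarray) -> np.ndarray:
    """Scale-and-shift: features (B, C, H, W); scale/shift (B, C) broadcast over H, W."""
    return features * scale[:, :, None, None] + shift[:, :, None, None]


rng = np.random.default_rng(0)
B, C, H, W = 4, 16, 8, 8
t = np.array([0, 10, 200, 399])
t_emb = sinusoidal_embedding(t, dim=8)  # (4, 8), cf. t_embed_dim=8
c_emb = rng.standard_normal((B, 512))   # CLIP-sized context vectors

# Random projections stand in for the learned embedding blocks
scale = c_emb @ rng.standard_normal((512, C)) * 0.01 + 1.0
shift = t_emb @ rng.standard_normal((8, C))
h = modulate(rng.standard_normal((B, C, H, W)), scale, shift)
```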

> **Note.**
> Compared to the NVIDIA reference model, this implementation preserves the bottleneck as a 2D spatial feature map instead of flattening it and reprojecting with an MLP, which helps maintain spatial coherence. An optional multi-head self-attention block is added at the bottleneck to capture global spatial dependencies. Additionally, transposed convolutions are replaced with nearest-neighbor upsampling followed by convolution to improve training stability and reduce checkerboard artifacts.
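The resampling operations described above can be sketched with plain NumPy reshapes. Here "rearrangement-based pooling" is assumed to mean a space-to-depth rearrangement (each 2×2 patch folded into the channel axis, equivalent to the einops pattern `b c (h p1) (w p2) -> b (c p1 p2) h w`), and nearest-neighbor upsampling repeats each pixel 2×2. In the model, each operation would be followed by a learned convolution, which is omitted; the function names are hypothetical.

```python
import numpy as np


def space_to_depth(x: np.ndarray, p: int = 2) -> np.ndarray:
    """(B, C, H, W) -> (B, C*p*p, H/p, W/p): fold each p x p patch into channels."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)  # axes: b c h p1 w p2
    x = x.transpose(0, 1, 3, 5, 2, 4)          # axes: b c p1 p2 h w
    return x.reshape(b, c * p * p, h // p, w // p)


def upsample_nearest(x: np.ndarray, s: int = 2) -> np.ndarray:
    """(B, C, H, W) -> (B, C, H*s, W*s) by repeating each pixel s x s times."""
    return x.repeat(s, axis=2).repeat(s, axis=3)


x = np.arange(16, dtype=float).reshape(1, 1, 4, 4)
down = space_to_depth(x)     # (1, 4, 2, 2): lossless, no values discarded
up = upsample_nearest(down)  # (1, 4, 4, 4)
```

Unlike max pooling, the space-to-depth step keeps every input value, leaving it to the following convolution to decide what to discard.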


In [None]:
# Training Configuration 
EPOCHS = 200
LEARNING_RATE = 1e-4
SUBSET_SIZE = len(train_ds) + len(val_ds) 

# Model and recommended DDPM hyperparameters
T = 400
IMG_CH = 3
IMG_SIZE = train_loader.dataset[0][0].shape[-1]
BETAS = torch.linspace(0.0001, 0.02, T).to(DEVICE)
# For OpenAI's CLIP, c_embed_dim is stored in model.visual.output_dim
CLIP_EMBED_DIM = clip_model.visual.output_dim 

# Set Seed again for Ensuring Same Model Initialization at Every Run
set_seed(SEED)

ddpm = DDPM(BETAS, DEVICE)
model = UNet(
    T, 
    IMG_CH, 
    IMG_SIZE, 
    down_chs=(256, 256, 512), 
    t_embed_dim=8, 
    c_embed_dim=CLIP_EMBED_DIM
).to(DEVICE)
print("Num params: ", sum(p.numel() for p in model.parameters()))

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    'min', 
    patience=15,  # Reduce the LR after 15 epochs without val-loss improvement
    factor=0.5,   # Halve the LR instead of cutting it more aggressively
    min_lr=5e-5   # Never let the LR drop below 5e-5
)
BOTTLE_EMB_CHANNELS = model.down2.model[-2].model[0].out_channels

# Define list of text prompts to generate images for 
text_list = [
    "A round white daisy with a yellow center",
    "An orange sunflower with a big brown center",
    "A deep red rose flower"
]

# Initialize W&B Run
run = wandb.init(
    project="diffusion-model-assessment-v2", 
    name="ddpm_unet_training",
    config={
        "architecture": "ddpm_unet",
        "strategy": "generative_modeling_without_ema_without_selfatt_without_aug",
        "downsample_mode": "maxpool",
        "embedding_size": BOTTLE_EMB_CHANNELS,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "optimizer_type": "Adam",
        "subset_size": SUBSET_SIZE,
        "seed": SEED,
    }
)

# Execute Training
train_diffusion(
    model=model,
    ddpm=ddpm,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    device=DEVICE,
    drop_prob=0.1,
    save_dir=CHECKPOINTS_DIR,
    sample_save_dir=SAMPLE_DIR,
    clip_model=clip_model,   # Pass the clip model for evaluation
    clip_preprocess=clip_preprocess,  # Pass the clip preprocess for evaluation
    cond_list=text_list,   # Pass the text prompts list for evaluation
    scheduler=scheduler
)

wandb.finish()

Seeds set to 42 for reproducibility.
Num params:  34122243


[34m[1mwandb[0m: Currently logged in as: [33mguarino-vanessa-emanuela[0m ([33mhandsoncv-research[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: Train Loss: 0.9390 | Val Loss: 0.4279
Saved samples to /home/vanessa/Documents/repos/Applied-Hands-On-Computer-Vision/Assignment-3/data/05_images/sample_ep00.png
Epoch 0: Val Loss: 0.4279 | CLIP Score: 0.1886
Saved and logged samples for epoch 0
--- Saved new best Val model to /home/vanessa/Documents/repos/Applied-Hands-On-Computer-Vision/Assignment-3/checkpoints ---
--- Saved new best CLIP model to /home/vanessa/Documents/repos/Applied-Hands-On-Computer-Vision/Assignment-3/checkpoints ---
Epoch 1: Train Loss: 0.2770 | Val Loss: 0.1984
--- Saved new best Val model to /home/vanessa/Documents/repos/Applied-Hands-On-Computer-Vision/Assignment-3/checkpoints ---
Epoch 2: Train Loss: 0.2042 | Val Loss: 0.1700
--- Saved new best Val model to /home/vanessa/Documents/repos/Applied-Hands-On-Computer-Vision/Assignment-3/checkpoints ---
Epoch 3: Train Loss: 0.1730 | Val Loss: 0.1813
Epoch 4: Train Loss: 0.1629 | Val Loss: 0.1491
--- Saved new best Val model to /home/vanessa/Documents/repo

Run history:

| Metric | History |
|---|---|
| clip_score | ▁▂▃▁▃▃▃▅▅▅▆▆▇▇▆▇▆▇▆▇▇▇▇▇▇▇▆▇█▇▇▇▇█▇▆▇▆▆█ |
| epoch | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇███ |
| epoch_time_sec | ▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁█▁█▁▁█▁▁▁█▁▁█▁▁█▁█▁█▁██ |
| learning_rate | ████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| peak_gpu_mem_mb | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_loss | █▆▄▃▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| val_loss | █▇▆▅▅▅▅▃▃▄▄▃▃▃▄▃▃▄▃▄▃▃▃▃▃▃▃▄▃▄▁▃▃▂▂▂▃▁▃▃ |

Run summary:

| Metric | Value |
|---|---|
| clip_score | 0.28703 |
| epoch | 199.0 |
| epoch_time_sec | 18.99681 |
| learning_rate | 5e-05 |
| peak_gpu_mem_mb | 6310.4668 |
| train_loss | 0.06653 |
| val_loss | 0.0888 |
