# CelebA-HQ Latent Diffusion Training with Classifier-Free Guidance

This notebook trains a text-conditioned diffusion model on CelebA-HQ using **latent diffusion**: the UNet operates in the latent space of a pretrained VAE rather than in pixel space, which makes training faster.

## Key Differences from Pixel-Space Training:
1. **VAE Encoding**: Images are encoded to 32×32×4 latents before training
2. **Faster Training**: Height and width are each 8× smaller, so the UNet processes 64× fewer spatial positions per image
3. **Latent U-Net**: 4-channel input/output (latent channels)
4. **Same CFG**: Classifier-free guidance works the same way
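
The size reduction listed above can be checked with a few lines of arithmetic. This is a minimal sketch: the helper names `latent_shape` and `count` are illustrative and not part of the project, and the 8× downsample factor and 4 latent channels are the values the VAE reports when it is loaded.

```python
def latent_shape(image_size: int, downsample: int = 8, latent_channels: int = 4):
    """Return (channels, height, width) of the VAE latent for a square image."""
    if image_size % downsample != 0:
        raise ValueError("image_size must be divisible by the downsample factor")
    side = image_size // downsample
    return (latent_channels, side, side)


def count(shape):
    """Number of scalar values in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


pixel_shape = (3, 256, 256)
lat_shape = latent_shape(256)

# Ratio of spatial positions (height * width) and of total scalar values.
spatial_ratio = (256 * 256) // (lat_shape[1] * lat_shape[2])
value_ratio = count(pixel_shape) // count(lat_shape)

print(lat_shape)       # (4, 32, 32)
print(spatial_ratio)   # 64
print(value_ratio)     # 48
```

Each latent holds 48× fewer values than the RGB image it encodes, because the 64× spatial reduction is partly offset by going from 3 to 4 channels.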

## Training Configuration:
- **Image size**: 256×256 → 32×32 latents (8× downsampling per side)
- **Model**: Custom UNet2DConditionModel (latent space)
- **VAE**: Pretrained SD VAE (frozen, `stabilityai/sd-vae-ft-mse`)
- **Text encoder**: CLIP ViT-B/32 (512-dim embeddings)
- **CFG**: 10% unconditional dropout
- **Batch size**: 32
- **Epochs**: 100

In [1]:
# Setup: Define project root and add to path
from pathlib import Path
import sys

PROJECT_ROOT = Path("/home/doshlom4/work/final_project")
sys.path.insert(0, str(PROJECT_ROOT))
print(f"Project root: {PROJECT_ROOT}")

Project root: /home/doshlom4/work/final_project


In [2]:
# Imports
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
import torchvision.transforms as transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Import custom modules
from config import (
    TRAIN_CELEBA_LDM_CONFIG,
    EXPERIMENT_4_CONFIG,
    DATASET_CACHE_DIR,
    CHECKPOINTS_DIR,
    UNET_CELEBA_LDM_CHECKPOINT_PREFIX,
    CELEBA_DATASET_NAME,
    CLIP_MODEL_NAME,
    ensure_experiment_4_dirs,
)
from models.custom_unet_celeba_ldm import CustomUNet2DConditionModelCelebaLDM
from models.vae_wrapper import VAEWrapper
from custom_datasets.celeba_hq_dataset import CelebAHQWithCaptions

print("✓ Imports successful")

✓ Imports successful


In [3]:
# Configuration
NUM_EPOCHS = TRAIN_CELEBA_LDM_CONFIG["num_epochs"]
LEARNING_RATE = TRAIN_CELEBA_LDM_CONFIG["learning_rate"]
BATCH_SIZE = TRAIN_CELEBA_LDM_CONFIG["batch_size"]
NUM_TRAIN_TIMESTEPS = TRAIN_CELEBA_LDM_CONFIG["num_train_timesteps"]
BETA_SCHEDULE = TRAIN_CELEBA_LDM_CONFIG["beta_schedule"]
CHECKPOINT_EVERY_N_EPOCHS = TRAIN_CELEBA_LDM_CONFIG["checkpoint_every_n_epochs"]
CFG_DROPOUT_PROB = TRAIN_CELEBA_LDM_CONFIG["cfg_dropout_prob"]

IMAGE_SIZE = EXPERIMENT_4_CONFIG["image_size"]
LATENT_SIZE = EXPERIMENT_4_CONFIG["latent_size"]

print("Training configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Image size: {IMAGE_SIZE}×{IMAGE_SIZE}")
print(f"  Latent size: {LATENT_SIZE}×{LATENT_SIZE}")
print(f"  CFG dropout: {CFG_DROPOUT_PROB*100:.0f}%")

Training configuration:
  Epochs: 100
  Learning rate: 1e-05
  Batch size: 32
  Image size: 256×256
  Latent size: 32×32
  CFG dropout: 10%


In [4]:
# Setup device and create directories
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Create experiment directories
ensure_experiment_4_dirs()
print("\n✓ Directories created")

Using device: cuda
GPU: NVIDIA A100 80GB PCIe
CUDA version: 11.8

✓ Directories created


## Step 1: Load CelebA Dataset

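Load the CelebA dataset with attributes from the local cache. CelebA contains roughly 200K celebrity face images at 178×218 resolution, each with 40 binary attributes (gender, age, hair color, facial hair, accessories, etc.); the `train` split loaded here has 162,770 of them. Note that `flwrlabs/celeba` is the original aligned CelebA rather than the higher-resolution CelebA-HQ, so the images are upscaled to reach 256×256.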
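
To make the role of the binary attributes concrete, here is a minimal sketch of how a text prompt could be assembled from them. The `PHRASES` table and `build_caption` function are hypothetical; the project's `CelebAHQWithCaptions` class may select different attributes and use different wording.

```python
# Hypothetical mapping from CelebA attribute names to caption phrases.
# The real dataset class may use other attributes and other wording.
PHRASES = {
    "Blond_Hair": "blond hair",
    "Black_Hair": "black hair",
    "Eyeglasses": "eyeglasses",
    "Smiling": "a smile",
}


def build_caption(attributes: dict) -> str:
    """Turn a dict of binary attributes into a short prompt.

    An attribute counts as present when its value equals 1, which covers
    0/1, False/True, and -1/1 encodings alike.
    """
    subject = "a young" if attributes.get("Young", 0) == 1 else "an older"
    gender = "man" if attributes.get("Male", 0) == 1 else "woman"
    parts = [phrase for name, phrase in PHRASES.items()
             if attributes.get(name, 0) == 1]

    caption = f"A photo of {subject} {gender}"
    if parts:
        caption += " with " + ", ".join(parts)
    return caption


print(build_caption({"Young": 1, "Male": 0, "Blond_Hair": 1, "Smiling": 1}))
# A photo of a young woman with blond hair, a smile
print(build_caption({"Male": 1, "Eyeglasses": 1}))
# A photo of an older man with eyeglasses
```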

In [5]:
# Load CelebA dataset with attributes
print("Loading CelebA dataset...")
print(f"Dataset: {CELEBA_DATASET_NAME}")
print(f"Cache directory: {DATASET_CACHE_DIR}")
print()

try:
    print(f"📥 Loading '{CELEBA_DATASET_NAME}' from cache...")
    
    # Load dataset from cached location
    celeba_hq = load_dataset(
        CELEBA_DATASET_NAME,
        cache_dir=str(DATASET_CACHE_DIR / "huggingface"),
        split="train"
    )
    
    print(f"✓ Loaded {len(celeba_hq)} images")
    print(f"   Image size: {celeba_hq[0]['image'].size}")
    print(f"   Dataset features: {list(celeba_hq.features.keys())}")
    print(f"   Attributes: {len([k for k in celeba_hq.features.keys() if k not in ['image', 'celeb_id']])} binary attributes")
    print("\n✓ Dataset ready for training!")
    
except Exception as e:
    print(f"✗ Failed to load dataset: {e}")
    celeba_hq = None

Loading CelebA dataset...
Dataset: flwrlabs/celeba
Cache directory: /home/doshlom4/work/final_project/dataset_cache

📥 Loading 'flwrlabs/celeba' from cache...
✓ Loaded 162770 images
   Image size: (178, 218)
   Dataset features: ['image', 'celeb_id', '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
   Attributes: 40 binary attributes

✓ Dataset ready for training!


## Step 2: Initialize Models

Initialize all models:
1. **VAE** (frozen, pretrained) - for encoding/decoding
2. **UNet** (trainable) - noise prediction in latent space
3. **CLIP Text Encoder** (frozen) - text embeddings
4. **Noise Scheduler** - DDPM with squared cosine schedule
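
The squared cosine schedule and the forward noising step it drives can be sketched in plain Python. This assumes `BETA_SCHEDULE` is set to the `squaredcos_cap_v2` option of `DDPMScheduler`, whose value is not shown in this notebook; the function names below are illustrative, and the scalar `add_noise` stands in for the tensor version the scheduler provides.

```python
import math


def alpha_bar(progress: float) -> float:
    """Cumulative signal level at progress t/T in [0, 1] (squared cosine)."""
    return math.cos((progress + 0.008) / 1.008 * math.pi / 2) ** 2


def cosine_betas(num_steps: int, max_beta: float = 0.999):
    """Per-step noise variances derived from the alpha_bar curve."""
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        # Cap beta so the final steps do not become numerically singular.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def add_noise(x0: float, eps: float, alphas_cumprod, t: int) -> float:
    """Forward process: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps."""
    a = alphas_cumprod[t]
    return math.sqrt(a) * x0 + math.sqrt(1 - a) * eps


betas = cosine_betas(1000)

# The running product of (1 - beta) is the signal fraction kept at each step.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1 - beta
    alphas_cumprod.append(running)

print(f"{alphas_cumprod[0]:.5f}")    # close to 1: almost no noise at t=0
print(f"{alphas_cumprod[-1]:.2e}")   # close to 0: almost pure noise at t=999
print(add_noise(1.0, 0.5, alphas_cumprod, 500))  # mix of signal and noise
```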

In [6]:
# Initialize VAE (frozen, pretrained)
print("Loading VAE...")
vae = VAEWrapper().to(device)
print("✓ VAE loaded and frozen")

Loading VAE...
Loading VAE from: stabilityai/sd-vae-ft-mse
VAE loaded and frozen:
  - Latent channels: 4
  - Downsample factor: 8x
  - Scale factor: 0.18215
✓ VAE loaded and frozen


In [7]:
# Initialize UNet (trainable)
print("Initializing UNet...")
unet = CustomUNet2DConditionModelCelebaLDM().to(device)
unet.print_parameter_count()
print("✓ UNet initialized")

Initializing UNet...
Number of trainable parameters: 109,370,116 (109.4M)
✓ UNet initialized


In [8]:
# Initialize CLIP text encoder and tokenizer (frozen)
print("Loading CLIP...")
text_encoder = CLIPTextModel.from_pretrained(
    CLIP_MODEL_NAME,
    cache_dir=str(DATASET_CACHE_DIR / "huggingface"),
).to(device)
text_encoder.eval()
text_encoder.requires_grad_(False)

tokenizer = CLIPTokenizer.from_pretrained(
    CLIP_MODEL_NAME,
    cache_dir=str(DATASET_CACHE_DIR / "huggingface"),
)
print("✓ CLIP loaded and frozen")

Loading CLIP...
✓ CLIP loaded and frozen


In [9]:
# Initialize noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=NUM_TRAIN_TIMESTEPS,
    beta_schedule=BETA_SCHEDULE,
)
print(f"✓ Noise scheduler initialized ({NUM_TRAIN_TIMESTEPS} timesteps)")

✓ Noise scheduler initialized (1000 timesteps)


## Step 3: Prepare Dataset

Create transforms and dataset. Images are:
1. Resized so the shorter side is 256 pixels, then center-cropped to 256×256
2. Converted to a tensor
3. Normalized to [-1, 1] (the input range the VAE expects)
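
The geometry and value range of this preprocessing can be traced by hand for a 178×218 CelebA image. The sketch below uses illustrative helper names and plain integer arithmetic; truncating with `int()` is an assumption, and torchvision's own rounding may differ by a pixel in some cases.

```python
def resize_shorter_side(width: int, height: int, target: int):
    """Scale so the shorter side equals `target`, keeping the aspect ratio.

    Truncation with int() is an assumption; a library may round differently.
    """
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width: int, height: int, size: int):
    """Return (left, top, right, bottom) of a centered size x size crop."""
    left = (width - size) // 2
    top = (height - size) // 2
    return left, top, left + size, top + size


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Map a [0, 1] pixel value to [-1, 1], as Normalize([0.5], [0.5]) does."""
    return (value - mean) / std


resized = resize_shorter_side(178, 218, 256)
print(resized)                           # (256, 313)
print(center_crop_box(*resized, 256))    # (0, 28, 256, 284)
print(normalize(0.0), normalize(1.0))    # -1.0 1.0
```

The crop keeps the full width and discards 57 rows split between the top and bottom of the upscaled image.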

In [10]:
# Define transforms
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # Normalize to [-1, 1]
])

# Create dataset (only if celeba_hq was loaded)
if celeba_hq is not None:
    dataset = CelebAHQWithCaptions(
        hf_dataset=celeba_hq,
        transform=transform,
    )
    
    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    
    print(f"✓ Dataset ready: {len(dataset)} images")
    print(f"✓ DataLoader ready: {len(dataloader)} batches")
else:
    print("⚠ Skipping dataset creation (dataset not loaded)")
    dataloader = None

CelebA-HQ Dataset initialized:
  - Total images: 162770
  - Available attributes: 40
  - Attributes used for prompting: 16
✓ Dataset ready: 162770 images
✓ DataLoader ready: 5087 batches


## Step 4: Training Setup

Initialize optimizer and prepare for training.
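
Before starting a long run, it can help to estimate the training budget from numbers already reported in this notebook. The sketch below is back-of-the-envelope arithmetic only; the throughput is the 3.48 it/s shown by the progress bar of the interrupted run in Step 5 and will vary with hardware and data loading.

```python
import math

num_images = 162_770        # size of the train split reported by the dataset cell
batch_size = 32
num_epochs = 100
observed_it_per_s = 3.48    # from the progress bar of the interrupted run

# The last batch of each epoch is partial, hence the ceiling.
steps_per_epoch = math.ceil(num_images / batch_size)
total_steps = steps_per_epoch * num_epochs
hours = total_steps / observed_it_per_s / 3600

print(steps_per_epoch)      # 5087
print(total_steps)          # 508700
print(f"{hours:.1f} h")     # 40.6 h
```

At that rate a full 100-epoch run takes on the order of 40 hours, which is worth knowing before choosing `CHECKPOINT_EVERY_N_EPOCHS`.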

In [11]:
# Initialize optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE)
print(f"✓ Optimizer initialized (AdamW, lr={LEARNING_RATE})")

✓ Optimizer initialized (AdamW, lr=1e-05)


## Step 5: Training Loop

Key steps in each training iteration:
1. **Encode images to latents** using frozen VAE
2. **Add noise** to latents (in latent space, not pixel space)
3. **Get text embeddings** from CLIP
4. **Apply CFG dropout**: Replace 10% of prompts with empty string
5. **Predict noise** using UNet
6. **Compute loss** and update weights

**Note**: This is much faster than pixel-space training because the UNet operates on 32×32×4 latents instead of 256×256×3 images.
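
The CFG dropout in step 4 can be isolated from the rest of the loop. This is a framework-free sketch: `apply_cfg_dropout` is an illustrative name, and it uses Python's `random` module where the training cell uses `torch.rand`, but the per-caption logic is the same.

```python
import random


def apply_cfg_dropout(captions, dropout_prob, rng):
    """Replace each caption with "" independently with probability dropout_prob."""
    return ["" if rng.random() < dropout_prob else caption
            for caption in captions]


rng = random.Random(0)  # seeded only to make the sketch reproducible
captions = ["A photo of a young woman with blond hair"] * 10_000
dropped = apply_cfg_dropout(captions, dropout_prob=0.1, rng=rng)

empty_fraction = dropped.count("") / len(dropped)
print(f"{empty_fraction:.3f}")   # close to 0.1
```

Because the empty prompt is encoded by the same text encoder, the UNet learns an unconditional prediction alongside the conditional one without any architectural change.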

In [12]:
# Training loop
print("Starting training...\n")

global_step = 0
for epoch in range(NUM_EPOCHS):
    unet.train()
    epoch_loss = 0.0
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    for batch_idx, (images, captions, attributes) in enumerate(progress_bar):
        images = images.to(device)  # (B, 3, 256, 256)
        
        # Step 1: Encode images to latents using VAE
        with torch.no_grad():
            latents = vae.encode(images)  # (B, 4, 32, 32)
        
        # Step 2: Sample noise to add to latents
        noise = torch.randn_like(latents)
        
        # Step 3: Sample random timesteps
        timesteps = torch.randint(
            0, NUM_TRAIN_TIMESTEPS, (latents.shape[0],),
            device=device,
        ).long()
        
        # Step 4: Add noise to latents according to timesteps
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        
        # Step 5: Get text embeddings
        # Apply CFG dropout: randomly replace captions with empty string
        cfg_mask = torch.rand(len(captions)) < CFG_DROPOUT_PROB
        captions_cfg = [
            "" if cfg_mask[i] else captions[i]
            for i in range(len(captions))
        ]
        
        # Tokenize and encode
        text_inputs = tokenizer(
            captions_cfg,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        
        with torch.no_grad():
            encoder_hidden_states = text_encoder(
                text_inputs.input_ids.to(device)
            )[0]  # (B, 77, 512)
        
        # Step 6: Predict noise using UNet
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        ).sample
        
        # Step 7: Compute loss
        loss = F.mse_loss(noise_pred, noise)
        
        # Step 8: Backprop and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track loss
        epoch_loss += loss.item()
        global_step += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "avg_loss": f"{epoch_loss/(batch_idx+1):.4f}",
        })
    
    # End of epoch
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average loss: {avg_epoch_loss:.4f}")
    
    # Save checkpoint
    if (epoch + 1) % CHECKPOINT_EVERY_N_EPOCHS == 0:
        checkpoint_path = CHECKPOINTS_DIR / f"{UNET_CELEBA_LDM_CHECKPOINT_PREFIX}{epoch}.pt"
        
        torch.save({
            "epoch": epoch,
            "global_step": global_step,
            "unet_state_dict": unet.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_epoch_loss,
            "config": {
                "image_size": IMAGE_SIZE,
                "latent_size": LATENT_SIZE,
                "vae_model": "stabilityai/sd-vae-ft-mse",
                "vae_scale_factor": vae.scale_factor,
            },
        }, checkpoint_path)
        
        print(f"✓ Saved checkpoint: {checkpoint_path.name}")

print("\n✓ Training complete!")

Starting training...



Epoch 1/100:  15%|█▍        | 743/5087 [03:33<20:48,  3.48it/s, loss=0.4304, avg_loss=0.4162]


KeyboardInterrupt: 

## Step 6: Save Final Model

Save the final trained model checkpoint.
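
When locating checkpoints afterwards, note that the periodic ones written by the training loop are named with the zero-based epoch index, while the save condition uses `epoch + 1`. The sketch below lists which files a run would produce; `checkpoint_schedule` is an illustrative helper, and the prefix and interval are placeholders rather than the project's real settings.

```python
def checkpoint_schedule(num_epochs: int, every_n: int, prefix: str):
    """List (zero_based_epoch, filename) for each periodic checkpoint."""
    schedule = []
    for epoch in range(num_epochs):
        if (epoch + 1) % every_n == 0:
            # The filename uses the zero-based epoch index, as in the loop.
            schedule.append((epoch, f"{prefix}{epoch}.pt"))
    return schedule


# Both the interval and the prefix below are placeholders.
for epoch, name in checkpoint_schedule(num_epochs=100, every_n=25,
                                       prefix="unet_epoch_"):
    print(epoch, name)
# 24 unet_epoch_24.pt
# 49 unet_epoch_49.pt
# 74 unet_epoch_74.pt
# 99 unet_epoch_99.pt
```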

In [None]:
# Save final checkpoint
final_checkpoint_path = CHECKPOINTS_DIR / f"{UNET_CELEBA_LDM_CHECKPOINT_PREFIX}final.pt"

torch.save({
    "epoch": epoch,  # last epoch reached; equals NUM_EPOCHS - 1 only if training ran to completion
    "global_step": global_step,
    "unet_state_dict": unet.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "config": {
        "image_size": IMAGE_SIZE,
        "latent_size": LATENT_SIZE,
        "vae_model": "stabilityai/sd-vae-ft-mse",
        "vae_scale_factor": vae.scale_factor,
    },
}, final_checkpoint_path)

print(f"✓ Saved final checkpoint: {final_checkpoint_path}")

## Next Steps

1. **Inference**: Use `inference1_t2i_celeba_hq_ldm_cfg.ipynb` to generate images
2. **Metrics**: Use `metrics1_evaluate_celeba_hq.ipynb` to compute FID and CLIP scores
3. **Classifier**: Train attribute classifier in `train2_train_celeba_attribute_classifier.ipynb`

The trained model can generate 256×256 face images from text prompts like:
- "A photo of a young woman with blond hair, smiling"
- "A portrait of an older man with eyeglasses"
- "A young person with black hair and no accessories"
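
At sampling time, classifier-free guidance typically combines the unconditional and conditional noise predictions at each denoising step. The sketch below shows that combination on plain lists; `cfg_combine` is an illustrative name, and the inference notebook may implement guidance differently.

```python
def cfg_combine(noise_uncond, noise_cond, guidance_scale):
    """Guided prediction: uncond + scale * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u)
            for u, c in zip(noise_uncond, noise_cond)]


uncond = [0.0, 1.0, -1.0]   # prediction for the empty prompt
cond = [1.0, 1.0, 0.0]      # prediction for the text prompt

print(cfg_combine(uncond, cond, 0.0))   # [0.0, 1.0, -1.0]  (unconditional)
print(cfg_combine(uncond, cond, 1.0))   # [1.0, 1.0, 0.0]   (conditional)
print(cfg_combine(uncond, cond, 3.0))   # [3.0, 1.0, 2.0]   (extrapolated)
```

A scale above 1 pushes the prediction further in the direction the prompt suggests, which is why the model needs the 10% unconditional dropout during training.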