# Train12: Flickr8k Text-to-Image Diffusion (Conditional UNet)

This notebook trains a text-conditional diffusion model on the Flickr8k dataset using:
- Pretrained VAE (AutoencoderKL) so the diffusion model works in latent space (4-channel 32x32 latents for 256x256 images)
- CLIP tokenizer + text encoder for caption embeddings
- UNet2DConditionModel with cross-attention
- DDPM scheduler and EMA for stable training

It follows the style of prior notebooks (train9, train11) and saves checkpoints and samples periodically.

## 1. Set Up Environment and Install Dependencies

If needed, install or update the required libraries. (Skip this step if they are already installed in the HPC environment.)

```bash
# OPTIONAL: Uncomment to install dependencies
# pip install --upgrade diffusers transformers datasets accelerate huggingface_hub safetensors
# pip install xformers==0.0.27.post2  # if CUDA compatible
# pip install clean-fid torchmetrics   # for later FID computation
```

We rely on: `diffusers`, `transformers`, `datasets`, `accelerate`, and optionally `xformers` for memory-efficient attention.

## 2. Runtime and Accelerator Configuration
We use `accelerate` to support multi-GPU / distributed training transparently and mixed precision for speed & memory efficiency.

In [None]:
from __future__ import annotations

import os
import math
import random
from dataclasses import dataclass
from typing import List, Optional, Dict, Any

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset as HFDataset
from accelerate import Accelerator

from diffusers import DDPMScheduler, DDIMScheduler
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer


def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA generators are only seeded when CUDA is usable.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# One shared Accelerator for the whole notebook, configured for fp16 mixed precision.
accelerator = Accelerator(mixed_precision="fp16")
device = accelerator.device

# Report the runtime once (main process only) so multi-process runs do not print duplicates.
if accelerator.is_main_process:
    print(" | ".join((
        "Accelerator initialized",
        f"device={device}",
        f"num_processes={accelerator.num_processes}",
        f"mixed_precision={accelerator.mixed_precision}",
    )))
    if torch.cuda.is_available():
        # NOTE: the name printed as "Current" is always that of CUDA device index 0.
        print(f"CUDA devices: {torch.cuda.device_count()} | Current: {torch.cuda.get_device_name(0)}")


## 3. Experiment Config: Paths and Hyperparameters

In [None]:
@dataclass
class TrainConfig:
    """Paths and hyperparameters for the Flickr8k text-to-image run.

    The two list-valued fields default to ``None`` in the signature and are
    normalized to a fresh empty list in ``__post_init__``, so code that
    iterates or slices them never receives ``None`` and instances never share
    a list object.
    """

    run_name: str = "train12_flickr8k_text2img"
    # Directory receiving checkpoints and sample grids.
    output_dir: str = "./outputs/train12_flickr8k_text2img"
    # Passed to `load_dataset` as its cache directory (resolved at class-definition time).
    cache_dir: str = os.path.abspath("../../dataset_cache")
    # Dataset ids tried in order until one loads.
    dataset_name_candidates: Optional[List[str]] = None
    # Side length images are resized to; the latent side is image_size // 8.
    image_size: int = 256
    train_batch_size: int = 16
    eval_batch_size: int = 16
    lr: float = 1e-4
    weight_decay: float = 1e-2
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    num_workers: int = 2
    # Token length captions are padded/truncated to.
    max_caption_length: int = 77
    # Checkpoint / sample-grid cadence, compared against the global step counter.
    save_every: int = 2000
    sample_every: int = 2000
    mixed_precision: str = "fp16"
    ema_decay: float = 0.9999
    # Probability of replacing a training caption with '' (classifier-free guidance dropout).
    classifier_free_guidance_prob: float = 0.1
    seed: int = 42
    # Prompts rendered when sample grids are produced.
    prompts: Optional[List[str]] = None

    def __post_init__(self):
        """Replace ``None`` list fields with new, per-instance empty lists."""
        if self.dataset_name_candidates is None:
            self.dataset_name_candidates = []
        if self.prompts is None:
            self.prompts = []


# Build the run configuration; only the list-valued fields are overridden here,
# every other hyperparameter keeps its TrainConfig default.
config = TrainConfig(
    prompts=[
        "A child playing with a dog",
        "A man riding a bicycle",
        "Two people sitting on a bench",
        "A group of hikers on a mountain",
        "A dog catching a frisbee in a park",
        "A girl holding a red balloon",
        "A person surfing a big wave",
        "A cat sleeping on a couch",
    ],
    dataset_name_candidates=[
        "nlphuji/flickr8k",
        "flickr8k",
        "yashkant/Flickr8k",
        "conceptofmind/flickr8k",
    ],
)

# Seed all RNGs and make sure the output directory exists before anything writes to it.
set_seed(config.seed)
os.makedirs(config.output_dir, exist_ok=True)

if accelerator.is_main_process:
    print("Output dir: " + os.path.abspath(config.output_dir))
    print("Dataset cache dir: " + str(config.cache_dir))


## 4–6. Flickr8k Dataset: Load, Preprocess, Build DataLoader

In [None]:
def build_image_transforms(image_size: int):
    """Build the preprocessing pipeline applied to every image.

    The pipeline resizes to an ``image_size`` x ``image_size`` square with
    bilinear interpolation, converts to a tensor, and normalizes each of the
    three channels with mean 0.5 and std 0.5.

    Args:
        image_size: Side length of the square output.

    Returns:
        A ``transforms.Compose`` chaining the three steps.
    """
    target_hw = (image_size, image_size)
    resize = transforms.Resize(target_hw, interpolation=transforms.InterpolationMode.BILINEAR)
    # mean 0.5 / std 0.5 per channel: to [-1,1]
    rescale = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    return transforms.Compose([resize, transforms.ToTensor(), rescale])


def try_load_flickr8k(config: TrainConfig) -> Dict[str, HFDataset]:
    """Try each candidate dataset id in order and return the first that loads.

    Args:
        config: Supplies ``dataset_name_candidates`` (ids tried in order) and
            ``cache_dir`` (forwarded to ``load_dataset``).

    Returns:
        Whatever ``load_dataset`` returns for the first candidate that loads;
        the rest of the notebook indexes it by split name.

    Raises:
        RuntimeError: If no candidate could be loaded. The last underlying
            error is included in the message and chained as the cause.
    """
    last_err = None
    for name in config.dataset_name_candidates:
        try:
            ds = load_dataset(name, cache_dir=config.cache_dir)
        except Exception as e:  # deliberate best-effort: any failure moves on to the next candidate
            last_err = e
            if accelerator.is_main_process:
                print(f"Failed loading '{name}': {e}")
        else:
            if accelerator.is_main_process:
                print(f"Loaded dataset '{name}' with splits: {list(ds.keys())}")
            return ds
    # The message is split into two adjacent literals joined by an explicit "\n";
    # a raw line break inside a single-quoted f-string is a syntax error.
    raise RuntimeError(
        f"Could not load Flickr8k from candidates {config.dataset_name_candidates}. Last error: {last_err}\n"
        "You may need to manually place images + captions and implement a local loader."
    ) from last_err


# Load dataset (train/validation/test splits vary by dataset variant)
raw_ds = try_load_flickr8k(config)

# Heuristic: use the 'train' split when present; otherwise fall back to the
# largest available split (previously the first key was taken regardless of size).
if 'train' in raw_ds:
    train_split_name = 'train'
else:
    train_split_name = max(raw_ds.keys(), key=lambda split: len(raw_ds[split]))

# Validation split: prefer 'validation', then 'val'; if neither exists the
# training split is reused, so "validation" data then overlaps the training data.
val_split_name = next(
    (split for split in ('validation', 'val') if split in raw_ds),
    train_split_name,
)

train_ds = raw_ds[train_split_name]
val_ds = raw_ds[val_split_name]

if accelerator.is_main_process:
    print(f"Train split size: {len(train_ds)} | Val split size: {len(val_ds)}")

image_tfms = build_image_transforms(config.image_size)


class Flickr8kCaptionDataset(Dataset):
    """Map-style dataset yielding one transformed image and one caption per item.

    Each underlying example must carry an image under ``'image'`` or ``'img'``
    and may carry captions under ``'caption'``, ``'captions'`` or
    ``'sentences'`` (checked in that order). With probability ``cfg_prob`` the
    caption is replaced by the empty string (classifier-free guidance dropout).

    Args:
        hf_dataset: Indexable collection of example dicts.
        transform: Callable applied to the RGB image.
        max_caption_len: Stored on the instance; not read inside this class.
        cfg_prob: Probability of dropping the caption for an item.
    """

    # Caption field variants, in lookup priority order.
    _CAPTION_KEYS = ('caption', 'captions', 'sentences')

    def __init__(self, hf_dataset: HFDataset, transform, max_caption_len: int, cfg_prob: float = 0.1):
        self.ds = hf_dataset
        self.transform = transform
        self.max_caption_len = max_caption_len
        self.cfg_prob = cfg_prob

    def __len__(self):
        return len(self.ds)

    @staticmethod
    def _entry_to_text(entry) -> str:
        """Turn one caption entry (string, dict with a 'raw' key, or None) into text."""
        if entry is None:
            return ''
        if isinstance(entry, dict):
            # some variants store each sentence as a dict with the text under 'raw'
            return str(entry.get('raw', ''))
        return str(entry)

    def _pick_caption(self, ex) -> str:
        """Return one caption for the example, or '' when none is available.

        A list-valued field yields one uniformly random entry; an empty list
        yields '' (instead of raising or producing the text "[]").
        """
        for key in self._CAPTION_KEYS:
            if key in ex:
                value = ex[key]
                break
        else:
            return ''
        if isinstance(value, (list, tuple)):
            if not value:
                return ''
            value = random.choice(value)
        return self._entry_to_text(value)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": transformed image, "text": caption}`` for item ``idx``.

        Raises:
            KeyError: If the example has neither an 'image' nor an 'img' field.
        """
        ex = self.ds[idx]
        image = ex.get('image')
        if image is None:
            image = ex.get('img')
        if image is None:
            raise KeyError(
                f"Example {idx} has no 'image' or 'img' field; available keys: {list(ex.keys())}"
            )
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = self.transform(image)

        caption = self._pick_caption(ex)

        # Classifier-free guidance dropout
        if random.random() < self.cfg_prob:
            caption = ''

        return {"pixel_values": image, "text": caption}


def collate_fn(batch: List[Dict[str, Any]], tokenizer: CLIPTokenizer, max_len: int):
    """Assemble a list of dataset items into one batch dict.

    Args:
        batch: Items carrying a ``'pixel_values'`` tensor and a ``'text'`` caption.
        tokenizer: Called on the caption list with max-length padding,
            truncation and ``return_tensors='pt'``.
        max_len: Length the captions are padded/truncated to.

    Returns:
        Dict with the stacked ``'pixel_values'``, the tokenizer's
        ``'input_ids'`` and ``'attention_mask'``, and the raw ``'captions'``.
    """
    images, captions = [], []
    for item in batch:
        images.append(item['pixel_values'])
        captions.append(item['text'])

    encoded = tokenizer(
        captions,
        padding='max_length',
        max_length=max_len,
        truncation=True,
        return_tensors='pt',
    )

    return dict(
        pixel_values=torch.stack(images),
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        captions=captions,
    )


# The tokenizer is loaded in a later cell, so only the datasets are wrapped here;
# the DataLoaders are created once the tokenizer exists.
# Training items use caption dropout for classifier-free guidance; validation items never drop captions.
train_dataset = Flickr8kCaptionDataset(
    hf_dataset=train_ds,
    transform=image_tfms,
    max_caption_len=config.max_caption_length,
    cfg_prob=config.classifier_free_guidance_prob,
)
val_dataset = Flickr8kCaptionDataset(
    hf_dataset=val_ds,
    transform=image_tfms,
    max_caption_len=config.max_caption_length,
    cfg_prob=0.0,
)

if accelerator.is_main_process:
    n_train, n_val = len(train_dataset), len(val_dataset)
    print(f"Wrapped train dataset length: {n_train} | val dataset length: {n_val}")


## 7–8. Initialize Tokenizer, Text Encoder, VAE, UNet & Noise Scheduler

In [None]:
# Cell-local import: functools.partial gives a picklable collate function, which
# DataLoader workers need when the multiprocessing start method is "spawn".
from functools import partial

print("Loading CLIP tokenizer & text encoder...")
clip_model_name = "openai/clip-vit-base-patch32"
TokenizerClass = CLIPTokenizer
TextEncoderClass = CLIPTextModel

# Load tokenizer first for DataLoader collate
tokenizer = TokenizerClass.from_pretrained(clip_model_name)

# Build DataLoaders now that we have tokenizer
batch_collate = partial(collate_fn, tokenizer=tokenizer, max_len=config.max_caption_length)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    collate_fn=batch_collate,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.eval_batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
    collate_fn=batch_collate,
)

# The text encoder and VAE are frozen and are not passed to accelerator.prepare,
# so nothing else moves them: place them on the training device here. Otherwise
# they stay on CPU while receiving inputs that were moved to `device`.
print("Loading CLIP text encoder...")
text_encoder = TextEncoderClass.from_pretrained(clip_model_name)
text_encoder.eval()
for p in text_encoder.parameters():
    p.requires_grad = False
text_encoder.to(device)

print("Loading pretrained VAE (AutoencoderKL)...")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.eval()
for p in vae.parameters():
    p.requires_grad = False
vae.to(device)

print("Creating conditional UNet...")
unet = UNet2DConditionModel(
    sample_size=config.image_size // 8,  # latent resolution
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=(
        "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"
    ),
    up_block_types=(
        "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
    ),
    # Cross-attention width must match the text encoder's hidden size.
    cross_attention_dim=text_encoder.config.hidden_size,
)

# Noise scheduler for training
noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000)

print("Models initialized.")


## 9–10. Optimizer, LR Scheduler, EMA, Accelerator Preparation

In [None]:
optimizer = torch.optim.AdamW(unet.parameters(), lr=config.lr, weight_decay=config.weight_decay)

# Optimizer steps each process will take over the whole run (approximation).
# len(train_loader) is measured before accelerator.prepare shards the loader.
steps_per_epoch = max(1, len(train_loader) // accelerator.num_processes)
t_total = steps_per_epoch * config.num_epochs // config.gradient_accumulation_steps
warmup_steps = max(10, t_total // 100)

# A scheduler wrapped by accelerator.prepare advances once per process on every
# .step() call (unless the Accelerator was created with split_batches=True, which
# is not the case here). Scale the schedule length accordingly so warmup and the
# cosine decay span the whole run instead of finishing after 1/num_processes of it.
# With a single process this is a no-op.
scheduler_scale = accelerator.num_processes
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps * scheduler_scale,
    num_training_steps=t_total * scheduler_scale,
)

ema_unet = EMAModel(unet.parameters(), decay=config.ema_decay)

# Enable gradient checkpointing and optional xFormers. Both are best-effort
# memory optimizations: a failure is reported instead of silently ignored, and
# training continues without the feature.
try:
    unet.enable_gradient_checkpointing()
except Exception as exc:
    if accelerator.is_main_process:
        print(f"Gradient checkpointing not enabled: {exc}")

try:
    unet.enable_xformers_memory_efficient_attention()
except Exception as exc:
    if accelerator.is_main_process:
        print(f"xFormers memory-efficient attention not enabled: {exc}")

# Prepare with accelerator
(unet, optimizer, train_loader, val_loader, lr_scheduler) = accelerator.prepare(
    unet, optimizer, train_loader, val_loader, lr_scheduler
)

def encode_text(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.FloatTensor:
    """Encode tokenized captions with the frozen text encoder, without gradients.

    Args:
        input_ids: Token ids; moved to ``device`` before the forward pass.
        attention_mask: Matching attention mask; moved to ``device`` as well.

    Returns:
        The first element of the text encoder's output.
    """
    ids_on_device = input_ids.to(device)
    mask_on_device = attention_mask.to(device)
    with torch.no_grad():
        encoder_out = text_encoder(ids_on_device, attention_mask=mask_on_device)
    return encoder_out[0]

print("Optimization components ready.")


## 11. Training Step Definition (Noise Prediction Objective)
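The loss is $\mathcal{L} = \mathbb{E}_{x_0, c, t, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t, c) \rVert^2\right]$, where $x_t$ is the VAE latent $x_0$ noised to timestep $t$ with Gaussian noise $\epsilon$, $c$ is the caption embedding, and $\epsilon_\theta$ is the UNet's noise prediction.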

In [None]:
def training_step(batch) -> torch.Tensor:
    """Compute the noise-prediction MSE loss for one batch.

    Args:
        batch: Dict with ``'pixel_values'``, ``'input_ids'`` and ``'attention_mask'``.

    Returns:
        Scalar MSE between the UNet's predicted noise and the sampled noise.
    """
    images = batch['pixel_values'].to(device)

    # Encode images with the frozen VAE (no gradients); the sampled latents are scaled by 0.18215.
    with torch.no_grad():
        posterior = vae.encode(images).latent_dist
        latents = posterior.sample() * 0.18215

    # Draw Gaussian noise and one random timestep per sample, then noise the latents.
    noise = torch.randn_like(latents)
    batch_size = latents.shape[0]
    num_steps = noise_scheduler.config.num_train_timesteps
    timesteps = torch.randint(0, num_steps, (batch_size,), device=device).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    text_states = encode_text(batch['input_ids'], batch['attention_mask'])

    # UNet forward and loss run under the accelerator's autocast context.
    with accelerator.autocast():
        predicted = unet(noisy_latents, timesteps, encoder_hidden_states=text_states).sample
        loss = nn.functional.mse_loss(predicted, noise)

    return loss


## 12–15. Distributed Training Loop, Sampling, and Checkpointing

In [None]:
def save_checkpoint(step: int, is_epoch: bool = False):
    """Write a training checkpoint to ``config.output_dir``.

    The file is named ``unet_epoch_<step>.pt`` or ``unet_step_<step>.pt`` and
    holds the UNet state dict, EMA state, optimizer state, LR-scheduler state,
    the step value and the config fields.

    Args:
        step: Step or epoch number used in the file name and stored in the checkpoint.
        is_epoch: Selects the ``epoch_`` prefix instead of ``step_``.
    """
    prefix = "epoch" if is_epoch else "step"
    path = os.path.join(config.output_dir, f"unet_{prefix}_{step}.pt")

    payload = dict(
        unet=accelerator.get_state_dict(unet),
        ema=ema_unet.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        step=step,
        config=vars(config),
    )
    accelerator.save(payload, path)

    if not accelerator.is_main_process:
        return
    print(f"Saved checkpoint: {path}")


def sample_prompts(unet_for_eval: UNet2DConditionModel, prompts: List[str], num_inference_steps: int = 50, guidance_scale: float = 7.5, seed: Optional[int] = 42, save_path: Optional[str] = None):
    """Generate one image per prompt with DDPM sampling and return them as a grid.

    The weights currently loaded in ``unet_for_eval`` are used as-is; callers
    that want EMA weights must copy them in before calling. The model is put in
    eval mode for sampling and returned to train mode afterwards if it was
    training on entry.

    Args:
        unet_for_eval: Denoising model to sample with.
        prompts: Text prompts; must contain at least one entry.
        num_inference_steps: Number of scheduler timesteps.
        guidance_scale: Classifier-free guidance weight; values <= 1.0 disable guidance.
        seed: Seed for a generator local to this call, so the global RNG used by
            training is left untouched. ``None`` draws from the global RNG.
        save_path: If given, the grid is written there (main process only).

    Returns:
        Image grid tensor built with ``make_grid`` (at most 4 images per row).

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("sample_prompts requires at least one prompt")

    use_guidance = guidance_scale > 1.0
    latent_side = config.image_size // 8

    # Local generator instead of torch.manual_seed: reseeding the global RNG here
    # would reset the training noise/timestep stream every time samples are drawn.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps)

    was_training = unet_for_eval.training
    unet_for_eval.eval()

    images = []
    try:
        with torch.no_grad():
            # The empty-prompt embedding is the same for every prompt: compute it once.
            uncond_emb = None
            if use_guidance:
                uncond = tokenizer([""], padding='max_length', max_length=config.max_caption_length, truncation=True, return_tensors='pt')
                uncond_emb = text_encoder(uncond.input_ids.to(device), attention_mask=uncond.attention_mask.to(device))[0]

            for prompt in prompts:
                tokens = tokenizer(prompt, padding='max_length', max_length=config.max_caption_length, truncation=True, return_tensors='pt')
                text_emb = text_encoder(tokens.input_ids.to(device), attention_mask=tokens.attention_mask.to(device))[0]
                if use_guidance:
                    # Batch layout: [unconditional, conditional]
                    text_emb = torch.cat([uncond_emb, text_emb], dim=0)

                latents = torch.randn((1, 4, latent_side, latent_side), generator=generator, device=device)

                for t in scheduler.timesteps:
                    latent_model_input = torch.cat([latents] * 2) if use_guidance else latents
                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                    noise_pred = unet_for_eval(latent_model_input, t, encoder_hidden_states=text_emb).sample
                    if use_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    latents = scheduler.step(noise_pred, t, latents, generator=generator).prev_sample

                # Undo the latent scaling applied at training time, decode, and map [-1,1] -> [0,1].
                latents = latents / 0.18215
                image = vae.decode(latents).sample
                image = (image / 2 + 0.5).clamp(0, 1)
                images.append(image.cpu())
    finally:
        # Do not leave the training model in eval mode after sampling.
        if was_training:
            unet_for_eval.train()

    grid = make_grid(torch.cat(images, dim=0), nrow=min(4, len(images)))
    if save_path and accelerator.is_main_process:
        save_image(grid, save_path)
        print(f"Saved samples to {save_path}")
    return grid


def _sample_with_ema(save_path: str, seed: int):
    """Render the configured prompts with EMA weights, then restore the training weights."""
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())
    try:
        sample_prompts(unet, config.prompts[:8], num_inference_steps=50, guidance_scale=7.5, seed=seed, save_path=save_path)
    finally:
        # Always put the live training weights back, even if sampling fails,
        # and make sure the model is back in train mode.
        ema_unet.restore(unet.parameters())
        unet.train()


def train_loop():
    """Run the full training: optimize the UNet, track an EMA, sample and checkpoint.

    Every batch increments ``global_step``. The main process logs the running
    mean loss every 50 steps, writes a sample grid every ``config.sample_every``
    steps and a checkpoint every ``config.save_every`` steps, and at the end of
    each epoch saves a checkpoint plus a sample grid.
    """
    global_step = 0
    ema_unet.to(device)

    for epoch in range(config.num_epochs):
        unet.train()
        total_loss = 0.0
        for step, batch in enumerate(train_loader):
            with accelerator.accumulate(unet):
                loss = training_step(batch)
                accelerator.backward(loss)
                # Clip only when gradients are synchronized, i.e. on a real optimizer step.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Update the EMA only when the optimizer actually stepped; during
            # gradient accumulation the weights have not changed on other batches.
            if accelerator.sync_gradients:
                ema_unet.step(unet.parameters())

            total_loss += loss.item()
            global_step += 1

            if accelerator.is_main_process and (global_step % 50 == 0):
                avg = total_loss / (step + 1)
                print(f"Epoch {epoch+1}/{config.num_epochs} | step {global_step} | loss {avg:.4f}")

            if accelerator.is_main_process and (global_step % config.sample_every == 0):
                samples_path = os.path.join(config.output_dir, f"samples_step_{global_step}.png")
                _sample_with_ema(samples_path, seed=42)

            if accelerator.is_main_process and (global_step % config.save_every == 0):
                save_checkpoint(global_step, is_epoch=False)

        if accelerator.is_main_process:
            # End of epoch: save checkpoint and do val sampling
            save_checkpoint(epoch + 1, is_epoch=True)
            val_path = os.path.join(config.output_dir, f"val_samples_epoch_{epoch+1}.png")
            _sample_with_ema(val_path, seed=123)

    if accelerator.is_main_process:
        print("Training complete.")


train_loop()


## 17. Inference: Text-to-Image with Trained Weights

In [None]:
def generate(prompt: str, num_inference_steps: int = 50, guidance_scale: float = 7.5, seed: Optional[int] = 42):
    """Generate one image for ``prompt`` with the EMA weights and DDIM sampling.

    The EMA weights are copied into ``unet`` for the duration of the call and
    the previous weights (and train/eval mode) are restored afterwards, even if
    sampling raises.

    Args:
        prompt: Text prompt to render.
        num_inference_steps: Number of DDIM timesteps.
        guidance_scale: Classifier-free guidance weight; values <= 1.0 disable guidance.
        seed: Seeds the global torch RNG (and CUDA RNGs) before sampling;
            ``None`` leaves the RNG state unchanged.

    Returns:
        CPU image tensor with values clamped to [0, 1].
    """
    was_training = unet.training
    unet.eval()
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())

    try:
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000)
        scheduler.set_timesteps(num_inference_steps)

        use_guidance = guidance_scale > 1.0

        # Everything below is inference only, including the text-encoder passes.
        with torch.no_grad():
            tokens = tokenizer([prompt], padding='max_length', max_length=config.max_caption_length, truncation=True, return_tensors='pt')
            text_emb = text_encoder(tokens.input_ids.to(device), attention_mask=tokens.attention_mask.to(device))[0]

            if use_guidance:
                uncond = tokenizer([""], padding='max_length', max_length=config.max_caption_length, truncation=True, return_tensors='pt')
                uncond_emb = text_encoder(uncond.input_ids.to(device), attention_mask=uncond.attention_mask.to(device))[0]
                # Batch layout: [unconditional, conditional]
                text_emb = torch.cat([uncond_emb, text_emb], dim=0)

            latents = torch.randn((1, 4, config.image_size // 8, config.image_size // 8), device=device)

            for t in scheduler.timesteps:
                latent_model_input = torch.cat([latents] * 2) if use_guidance else latents
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_emb).sample
                if use_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the latent scaling applied at training time, decode, and map [-1,1] -> [0,1].
            latents = latents / 0.18215
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
    finally:
        # Put the non-EMA weights and the previous mode back regardless of outcome.
        ema_unet.restore(unet.parameters())
        if was_training:
            unet.train()

    return image.cpu()

# Example prompt generation (after training) - uncomment when ready
# out = generate("A dog running through a field", num_inference_steps=50, guidance_scale=7.5, seed=123)
# save_image(out, os.path.join(config.output_dir, "example_generation.png"))
# print("Saved example_generation.png")


## 6b. Quick Visual Check of a Batch (Optional)

In [None]:
# Visualize a sample batch (run before training if desired)
if accelerator.is_main_process:
    sample_batch = next(iter(train_loader))
    # Once the loader has gone through accelerator.prepare its batches may live on
    # the GPU; move the images to CPU before handing them to NumPy/matplotlib.
    imgs = sample_batch['pixel_values'][:8].detach().cpu()
    caps = sample_batch['captions'][:8]
    grid = make_grid(imgs, nrow=4)
    plt.figure(figsize=(10,5))
    # Undo the [-1,1] normalization; clip guards against tiny float overshoot outside [0,1].
    plt.imshow((grid.permute(1,2,0).numpy() * 0.5 + 0.5).clip(0.0, 1.0))
    plt.axis('off')
    plt.title('Sample Flickr8k Images (first 8)')
    plt.show()
    print("Captions:")
    for i, c in enumerate(caps, start=1):
        print(f"{i}. {c}")
