# Experiment 3: Train WikiArt Text-to-Image Diffusion Model

This notebook trains a text-to-image diffusion model on the WikiArt dataset using classifier-free guidance (CFG).

**Key Differences from Experiment 2 (CIFAR-10):**
- **Image size:** 128×128 (vs 32×32)
- **Classes:** 27 art styles (vs 10 object categories)
- **Dataset:** HuggingFace WikiArt dataset
- **Model capacity:** Larger UNet (5 blocks vs 4) for complex artistic images
- **Training:** More epochs (100 vs 50), smaller batch size (16 vs 128)

**Model Architecture:**
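- Custom UNet2DConditionModel for 128×128 RGB images, operating directly in pixel space (no VAE)
- Frozen CLIP text encoder for text conditioning
- Caption dropout during training (10% of captions replaced with the empty prompt) so that classifier-free guidance can be applied at sampling time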
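For reference, the training objective and sampling rule behind this setup can be written in standard DDPM ε-prediction notation. This is a sketch in our own notation, not taken from the project's config: $x_0$ is an image, $c$ its caption, $\varnothing$ the empty prompt, $w$ the guidance scale and $\bar\alpha_t$ the cumulative product of the noise schedule. `F.mse_loss` in the notebook averages over elements, which only rescales the squared norm below.

```latex
% Training: with probability 0.1 the caption c is replaced by the empty prompt \varnothing
\mathcal{L}(\theta) = \mathbb{E}_{x_0,\, \epsilon \sim \mathcal{N}(0, I),\, t}
  \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\; t,\; \tilde c\right) \right\|_2^2,
\qquad
\tilde c = \begin{cases} \varnothing & \text{with probability } 0.1 \\ c & \text{otherwise} \end{cases}

% Sampling: classifier-free guidance with scale w
\hat\epsilon_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \varnothing)
  + w \left( \epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing) \right)
```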

**Training Configuration:**
- 100 epochs
- Batch size: 16 (memory constraints for larger images)
- Learning rate: 1e-5
- Checkpoints every 10 epochs

## 1. Setup and Configuration

In [None]:
# Project configuration - use absolute paths
from pathlib import Path
import sys

PROJECT_ROOT = Path("/home/doshlom4/work/final_project")
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import configuration
from config import (
    EXPERIMENT_3_CONFIG,
    TRAIN_WIKIART_CONFIG,
    INFERENCE_CONFIG,
    UNET_WIKIART_CONFIG,
    TOKENIZER_MAX_LENGTH,
    CLIP_MODEL_NAME,
    WIKIART_STYLES,
    CHECKPOINTS_DIR,
    DATASET_CACHE_DIR,
    UNET_WIKIART_CHECKPOINT_PREFIX,
    get_wikiart_unet_checkpoint_path,
    get_latest_wikiart_unet_checkpoint,
    ensure_experiment_3_dirs,
)

# Deep learning frameworks
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm

# Standard libraries
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import random

# HuggingFace datasets
from datasets import load_dataset
import torchvision.transforms as transforms
from PIL import Image

print("Libraries imported successfully")

In [None]:
# Print configuration
print("WikiArt Training Configuration:")
print(f"  Styles: {len(WIKIART_STYLES)} art styles")
print(f"  Epochs: {TRAIN_WIKIART_CONFIG['num_epochs']}")
print(f"  Batch size: {TRAIN_WIKIART_CONFIG['batch_size']}")
print(f"  Learning rate: {TRAIN_WIKIART_CONFIG['learning_rate']}")
print(f"  Checkpoint every: {TRAIN_WIKIART_CONFIG['checkpoint_every_n_epochs']} epochs")
print()
print("UNet Configuration:")
print(f"  Sample size: {UNET_WIKIART_CONFIG['sample_size']}")
print(f"  Channels: {UNET_WIKIART_CONFIG['in_channels']}")
print(f"  Block channels: {UNET_WIKIART_CONFIG['block_out_channels']}")
print()
print("Art Styles:")
for i, style in enumerate(WIKIART_STYLES):
    print(f"  {i}: {style}")

In [None]:
# Create directories
ensure_experiment_3_dirs()
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
print("Directory structure created")

## 2. Load WikiArt Dataset from HuggingFace

In [None]:
# Load WikiArt dataset from HuggingFace
# This may take a while on first download
print("Loading WikiArt dataset from HuggingFace...")
wikiart_hf = load_dataset(
    "huggan/wikiart",
    split="train",
    cache_dir=str(DATASET_CACHE_DIR / "huggingface")
)

print(f"\nWikiArt dataset loaded: {len(wikiart_hf)} images")
print(f"Features: {wikiart_hf.features}")

In [None]:
# Explore dataset structure
sample = wikiart_hf[0]
print("Sample item keys:", sample.keys())
print(f"Image type: {type(sample['image'])}")
print(f"Image size: {sample['image'].size}")

# Check which columns contain style information
for key in sample.keys():
    if key != 'image':
        print(f"{key}: {sample[key]}")

In [None]:
# Visualize some samples from the dataset
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i, ax in enumerate(axes.flat):
    sample = wikiart_hf[i * 1000]
    ax.imshow(sample['image'])
    style_idx = sample.get('style', sample.get('label', i))
    if isinstance(style_idx, int) and style_idx < len(WIKIART_STYLES):
        title = WIKIART_STYLES[style_idx][:20]
    else:
        title = f"Sample {i}"
    ax.set_title(title, fontsize=8)
    ax.axis('off')

plt.suptitle('WikiArt Dataset Samples', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Define transforms: resize to 128x128 (aspect ratio is not preserved) and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # RGB normalization to [-1, 1]
])

print("Transform defined: Resize(128x128) -> ToTensor -> Normalize to [-1, 1]")
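`ToTensor` scales 8-bit pixel values to [0, 1], and `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` then computes `(x - 0.5) / 0.5` per channel. `generate_image` later undoes this with `x / 2 + 0.5` followed by a clamp. The pure-Python sketch below applies both mappings to single pixel values; the helper names are hypothetical and not part of the project.

```python
def normalize_pixel(value_uint8: int) -> float:
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], like ToTensor + Normalize(0.5, 0.5)."""
    x = value_uint8 / 255.0      # ToTensor: [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5       # Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]


def denormalize_pixel(value: float) -> float:
    """Map a model-space value back to [0, 1], clamping like (x / 2 + 0.5).clamp(0, 1)."""
    return min(max(value / 2 + 0.5, 0.0), 1.0)
```

For example, `normalize_pixel(0)` gives -1.0 and `normalize_pixel(255)` gives 1.0, and `denormalize_pixel` maps them back to 0.0 and 1.0. Values the sampler overshoots past [-1, 1] are clipped by the clamp.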

In [None]:
# Create custom PyTorch dataset with text captions
class WikiArtWithCaptions(Dataset):
    """
    WikiArt dataset with text captions for each image.
    Maps style labels to text prompts.
    """
    def __init__(self, hf_dataset, transform=None, prompt_template="A painting in the style of {style_name}"):
        self.dataset = hf_dataset
        self.transform = transform
        self.prompt_template = prompt_template
        self.style_names = WIKIART_STYLES
        
        # Determine the style column name (varies across dataset versions).
        # Inspect column names instead of decoding the first image.
        columns = hf_dataset.column_names
        if 'style' in columns:
            self.style_column = 'style'
        elif 'label' in columns:
            self.style_column = 'label'
        else:
            # Fallback: no style column, so styles are assigned cyclically by index
            # (captions will not reflect the true style of each image)
            self.style_column = None
            print("Warning: No style column found; assigning styles cyclically by index")
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        
        # Ensure image is RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
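        # Get style label
        if self.style_column:
            style_idx = item[self.style_column]
            # An out-of-range label means WIKIART_STYLES does not match the dataset's
            # label set; fail loudly instead of silently remapping to the wrong style
            if not 0 <= style_idx < len(self.style_names):
                raise IndexError(
                    f"Style label {style_idx} out of range for {len(self.style_names)} styles"
                )
        else:
            style_idx = idx % len(self.style_names)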
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Create caption
        style_name = self.style_names[style_idx].replace('_', ' ')
        caption = self.prompt_template.format(style_name=style_name)
        
        return image, caption, style_idx


# Create dataset with captions
train_dataset = WikiArtWithCaptions(
    wikiart_hf,
    transform=transform,
    prompt_template=EXPERIMENT_3_CONFIG["prompt_template"]
)

# Test
img, caption, label = train_dataset[0]
print(f"Sample image shape: {img.shape}")
print(f"Sample caption: '{caption}'")
print(f"Sample label: {label} ({WIKIART_STYLES[label]})")
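The label-to-prompt step in `__getitem__` can be isolated as a small pure function, which makes it easy to check captions without loading any images. This is a sketch: `EXAMPLE_STYLES` is a hypothetical three-entry stand-in for `WIKIART_STYLES`, and unlike a modulo fallback it rejects out-of-range labels rather than wrapping them.

```python
# Hypothetical style list for illustration; the notebook uses WIKIART_STYLES from config.
EXAMPLE_STYLES = ["Abstract_Expressionism", "Baroque", "Impressionism"]
TEMPLATE = "A painting in the style of {style_name}"  # same as the class's default template


def build_caption(style_idx: int, style_names=EXAMPLE_STYLES, template=TEMPLATE) -> str:
    """Turn an integer style label into a text prompt."""
    if not 0 <= style_idx < len(style_names):
        # Out-of-range labels indicate a mismatch between the style list and the dataset
        raise IndexError(f"style label {style_idx} out of range for {len(style_names)} styles")
    style_name = style_names[style_idx].replace("_", " ")  # underscores -> spaces
    return template.format(style_name=style_name)
```

For instance, `build_caption(0)` returns `"A painting in the style of Abstract Expressionism"`.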

In [None]:
# Create dataloader
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_WIKIART_CONFIG["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

print(f"Dataloader created: {len(train_loader)} batches per epoch")
print(f"Total images: {len(train_dataset)}")
print(f"Batch size: {TRAIN_WIKIART_CONFIG['batch_size']}")
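The number printed by `len(train_loader)` follows from ceiling division, because the `DataLoader` above keeps the default `drop_last=False` and yields one final partial batch. The sketch below shows the arithmetic; the 81,000-image figure in the usage note is only a round stand-in for the "81K+" quoted in the summary, and the real count is printed by the cell above.

```python
def batches_per_epoch(num_images: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches a DataLoader yields per epoch for a map-style dataset."""
    if drop_last:
        return num_images // batch_size               # the trailing partial batch is discarded
    return (num_images + batch_size - 1) // batch_size  # ceiling division keeps the partial batch
```

With 81,000 images and a batch size of 16 this gives 5,063 batches per epoch, or 506,300 optimizer steps over 100 epochs. That step count is why the smaller batch size makes each epoch noticeably longer than in the CIFAR-10 experiment.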

## 3. Load Models

In [None]:
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Import and create custom UNet model
from models.custom_unet_wikiart import CustomUNet2DConditionModelWikiArt

unet = CustomUNet2DConditionModelWikiArt().to(device)

# Print model info
num_params = unet.get_num_parameters()
print(f"✓ Created WikiArt UNet model")
print(f"  Parameters: {num_params:,}")
print(f"  Sample size: {unet.config.sample_size}")
print(f"  In channels: {unet.config.in_channels}")
print(f"  Out channels: {unet.config.out_channels}")

In [None]:
# Load CLIP text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained(CLIP_MODEL_NAME).to(device)
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)

# Freeze text encoder
text_encoder.requires_grad_(False)
text_encoder.eval()

print(f"✓ Loaded CLIP text encoder: {CLIP_MODEL_NAME}")
print(f"  Tokenizer max length: {TOKENIZER_MAX_LENGTH}")

In [None]:
# Create noise scheduler
noise_scheduler = DDPMScheduler(
    beta_schedule=TRAIN_WIKIART_CONFIG["beta_schedule"],
    num_train_timesteps=TRAIN_WIKIART_CONFIG["num_train_timesteps"],
)

print(f"✓ Created DDPM scheduler")
print(f"  Beta schedule: {TRAIN_WIKIART_CONFIG['beta_schedule']}")
print(f"  Timesteps: {TRAIN_WIKIART_CONFIG['num_train_timesteps']}")
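`noise_scheduler.add_noise` in the training loop implements the DDPM forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$. The pure-Python sketch below reproduces that computation on flat lists. It assumes the config selects the `"linear"` schedule; the notebook does not pass `beta_start`/`beta_end`, so `DDPMScheduler`'s defaults of 1e-4 and 0.02 apply. If the config uses a different `beta_schedule`, the betas differ but `add_noise` is unchanged.

```python
import math


def linear_betas(num_train_timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Evenly spaced betas from beta_start to beta_end (a 'linear' schedule)."""
    if num_train_timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def alphas_cumprod(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def add_noise(x0, noise, t, alpha_bar):
    """Sample x_t ~ q(x_t | x_0) elementwise: sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    a = alpha_bar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]
```

Under these assumptions $\bar\alpha_t$ falls from about 0.9999 at $t=0$ to roughly $4\times10^{-5}$ at $t=999$, so the last timesteps are almost pure noise. This is why sampling can start from `torch.randn`.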

## 4. Training Setup with Checkpoint Resume

In [None]:
# Create optimizer
optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=TRAIN_WIKIART_CONFIG["learning_rate"],
)

print(f"✓ Created AdamW optimizer")
print(f"  Learning rate: {TRAIN_WIKIART_CONFIG['learning_rate']}")

In [None]:
# Training configuration
CFG_DROPOUT_PROB = 0.1  # Probability of dropping text conditioning (for CFG training)
NUM_EPOCHS = TRAIN_WIKIART_CONFIG["num_epochs"]
CHECKPOINT_EVERY = TRAIN_WIKIART_CONFIG["checkpoint_every_n_epochs"]

# Resume from checkpoint settings
RESUME_TRAINING = True  # Resume from the latest checkpoint if one exists; set to False to train from scratch

print(f"Training configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  CFG dropout probability: {CFG_DROPOUT_PROB}")
print(f"  Checkpoint every: {CHECKPOINT_EVERY} epochs")
print(f"  Resume training: {RESUME_TRAINING}")

In [None]:
def save_checkpoint(epoch, unet, optimizer, loss_history):
    """Save training checkpoint."""
    checkpoint_path = get_wikiart_unet_checkpoint_path(epoch)
    
    torch.save({
        "epoch": epoch,
        "unet_state_dict": unet.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss_history": loss_history,
    }, checkpoint_path)
    
    print(f"  ✓ Saved checkpoint: {checkpoint_path}")


def load_checkpoint(unet, optimizer):
    """Load latest checkpoint if available."""
    try:
        checkpoint_path = get_latest_wikiart_unet_checkpoint()
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        unet.load_state_dict(checkpoint["unet_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        
        start_epoch = checkpoint["epoch"] + 1
        loss_history = checkpoint.get("loss_history", [])
        
        print(f"✓ Resumed from checkpoint: {checkpoint_path}")
        print(f"  Starting from epoch {start_epoch}")
        print(f"  Previous loss history: {len(loss_history)} epochs")
        
        return start_epoch, loss_history
    
    except FileNotFoundError:
        print("No checkpoint found, starting from scratch")
        return 1, []


print("Checkpoint functions defined")
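The resume logic depends on two rules in this notebook: a checkpoint is written when `epoch % CHECKPOINT_EVERY == 0` or at the final epoch, and `load_checkpoint` restarts at the saved epoch plus one. The hypothetical helpers below restate those rules so the schedule can be checked without touching the filesystem.

```python
def checkpoint_epochs(num_epochs: int, every: int):
    """Epochs at which the training loop saves a checkpoint (every N epochs, plus the last one)."""
    return [e for e in range(1, num_epochs + 1) if e % every == 0 or e == num_epochs]


def resume_start_epoch(last_saved_epoch):
    """Epoch training restarts from: saved epoch + 1, or 1 when no checkpoint exists."""
    return 1 if last_saved_epoch is None else last_saved_epoch + 1
```

With 100 epochs and checkpoints every 10, this yields ten checkpoints (epochs 10, 20, ..., 100). If a run stops during epoch 37, the latest checkpoint is epoch 30, so training restarts at epoch 31 and epochs 31 to 36 are repeated.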

In [None]:
# Try to resume from checkpoint
if RESUME_TRAINING:
    start_epoch, loss_history = load_checkpoint(unet, optimizer)
else:
    start_epoch = 1
    loss_history = []
    print("Starting fresh training (not resuming)")

## 5. Training Loop

In [None]:
def train_one_epoch(epoch, unet, train_loader, optimizer, noise_scheduler, text_encoder, tokenizer, device):
    """
    Train for one epoch with classifier-free guidance.
    """
    unet.train()
    total_loss = 0
    num_batches = 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
    
    for images, captions, labels in progress_bar:
        images = images.to(device)  # (B, 3, 128, 128)
        batch_size = images.shape[0]
        
        # Randomly drop captions for CFG training
        drop_mask = torch.rand(batch_size) < CFG_DROPOUT_PROB
        captions_with_dropout = [
            "" if drop_mask[i] else captions[i] 
            for i in range(batch_size)
        ]
        
        # Encode text
        text_input = tokenizer(
            captions_with_dropout,
            padding="max_length",
            max_length=TOKENIZER_MAX_LENGTH,
            truncation=True,
            return_tensors="pt"
        )
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
        
        # Sample noise
        noise = torch.randn_like(images)
        
        # Sample random timesteps
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (batch_size,), device=device
        ).long()
        
        # Add noise to images
        noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
        
        # Predict noise
        noise_pred = unet(noisy_images, timesteps, encoder_hidden_states=text_embeddings).sample
        
        # Compute loss
        loss = F.mse_loss(noise_pred, noise)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    avg_loss = total_loss / num_batches
    return avg_loss

print("Training function defined")
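The caption-dropout step inside `train_one_epoch` (`torch.rand(batch_size) < CFG_DROPOUT_PROB`) can be written with the standard library alone. The sketch below makes the same independent per-caption decision. `apply_caption_dropout` is a hypothetical name, and the optional `rng` argument exists only to make runs reproducible.

```python
import random


def apply_caption_dropout(captions, p_uncond: float = 0.1, rng=None):
    """Replace each caption with "" independently with probability p_uncond."""
    if rng is None:
        rng = random.Random()
    # rng.random() is uniform on [0, 1), so p_uncond=0 never drops and p_uncond=1 always drops
    return ["" if rng.random() < p_uncond else c for c in captions]
```

Because the empty string is tokenized just like a real prompt, the same UNet learns both $\epsilon_\theta(x_t, t, c)$ and $\epsilon_\theta(x_t, t, \varnothing)$. `generate_image` needs both predictions at every denoising step.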

In [None]:
# Main training loop
print(f"\n{'='*70}")
print(f"Starting WikiArt Text-to-Image Training")
print(f"{'='*70}")
print(f"Starting epoch: {start_epoch}")
print(f"Total epochs: {NUM_EPOCHS}")
print(f"Batches per epoch: {len(train_loader)}")
print(f"Total training images: {len(train_dataset)}")
print(f"{'='*70}\n")

training_start_time = datetime.now()

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    epoch_start = datetime.now()
    
    # Train one epoch
    avg_loss = train_one_epoch(
        epoch, unet, train_loader, optimizer, 
        noise_scheduler, text_encoder, tokenizer, device
    )
    
    loss_history.append(avg_loss)
    epoch_time = (datetime.now() - epoch_start).total_seconds()
    
    print(f"Epoch {epoch}/{NUM_EPOCHS} - Loss: {avg_loss:.4f} - Time: {epoch_time:.1f}s")
    
    # Save checkpoint
    if epoch % CHECKPOINT_EVERY == 0 or epoch == NUM_EPOCHS:
        save_checkpoint(epoch, unet, optimizer, loss_history)

total_time = (datetime.now() - training_start_time).total_seconds()

print(f"\n{'='*70}")
print(f"Training Complete!")
print(f"{'='*70}")
print(f"Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"Final loss: {loss_history[-1]:.4f}")

## 6. Visualize Training Progress

In [None]:
# Plot loss curve
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(loss_history) + 1), loss_history, 'b-', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('WikiArt Training Loss', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Minimum loss: {min(loss_history):.4f} at epoch {loss_history.index(min(loss_history)) + 1}")

## 7. Test Generation

In [None]:
@torch.no_grad()
def generate_image(prompt: str, guidance_scale: float, num_images: int = 1) -> torch.Tensor:
    """
    Generate WikiArt images using classifier-free guidance.
    """
    unet.eval()
    
    # Encode the text prompt, repeated so its batch size matches the unconditional embeddings
    text_input = tokenizer(
        [prompt] * num_images,
        padding="max_length",
        max_length=TOKENIZER_MAX_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    
    # Unconditional embeddings
    uncond_input = tokenizer(
        [""] * num_images,
        padding="max_length",
        max_length=TOKENIZER_MAX_LENGTH,
        return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    
    # Concatenate for CFG
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    
    # Initialize Gaussian noise directly in pixel space (3x128x128); there is no VAE, so "latents" are images
    latents = torch.randn((num_images, 3, 128, 128), device=device)
    
    # Setup scheduler
    scheduler = DDPMScheduler(
        beta_schedule=INFERENCE_CONFIG["beta_schedule"],
        num_train_timesteps=INFERENCE_CONFIG["num_train_timesteps"]
    )
    scheduler.set_timesteps(INFERENCE_CONFIG["num_inference_steps"])
    
    # Denoising loop
    for t in tqdm(scheduler.timesteps, desc="Generating", leave=False):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    
    # Denormalize from [-1, 1] to [0, 1]
    image = (latents / 2 + 0.5).clamp(0, 1)
    
    return image

print("Generation function defined")
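The guidance line in `generate_image` combines the two halves of the batched prediction; `noise_pred.chunk(2)` returns the unconditional half first because `uncond_embeddings` are concatenated first. The sketch below restates that combination on plain lists. `cfg_combine` is a hypothetical helper, not part of the notebook.

```python
def cfg_combine(eps_uncond, eps_text, guidance_scale: float):
    """Classifier-free guidance: eps_uncond + w * (eps_text - eps_uncond), elementwise."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_text)]
```

Setting `w = 0` returns the unconditional prediction and `w = 1` returns the plain conditional one. Values above 1 extrapolate away from the unconditional prediction, which is what the guidance-scale sweep later in this notebook varies.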

In [None]:
# Test generation for a selection of art styles
# Show 9 different styles (3x3 grid)
selected_styles = [0, 4, 9, 12, 15, 19, 21, 23, 26]  # Representative styles

fig, axes = plt.subplots(3, 3, figsize=(12, 12))

for i, style_idx in enumerate(selected_styles):
    style_name = WIKIART_STYLES[style_idx].replace('_', ' ')
    prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_name)
    
    image = generate_image(prompt, guidance_scale=7.5, num_images=1)
    img = image[0].permute(1, 2, 0).cpu().numpy()
    
    ax = axes[i // 3, i % 3]
    ax.imshow(img)
    ax.set_title(style_name, fontsize=10)
    ax.axis('off')

plt.suptitle('Generated WikiArt Samples (guidance scale = 7.5)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Test different guidance scales on one style
test_style = "Impressionism"
guidance_scales = [0, 2, 5, 7.5, 10, 15]

fig, axes = plt.subplots(1, len(guidance_scales), figsize=(18, 3))

prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=test_style)  # same template as training

for i, guidance_scale in enumerate(guidance_scales):
    image = generate_image(prompt, guidance_scale=guidance_scale, num_images=1)
    img = image[0].permute(1, 2, 0).cpu().numpy()
    
    axes[i].imshow(img)
    axes[i].set_title(f'w={guidance_scale}')
    axes[i].axis('off')

plt.suptitle(f'Generated "{test_style}" at different guidance scales', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook trained a WikiArt text-to-image diffusion model with classifier-free guidance.

**What was accomplished:**
- Loaded WikiArt dataset from HuggingFace (81K+ images, 27 art styles)
- Created custom UNet model for 128×128 RGB images
- Trained with classifier-free guidance (10% caption dropout)
- Saved checkpoints every 10 epochs for resume capability
- Tested generation with different guidance scales

**Next steps:**
1. `inference1_t2i_wikiart_cfg.ipynb` - Detailed inference exploration
2. `generate_images.ipynb` - Bulk image generation for evaluation
3. `train2_train_wikiart_classifier.ipynb` - Train art style classifier
4. `metrics1_evaluate_wikiart.ipynb` - FID and classification accuracy