# Experiment 3: Generate WikiArt Images for Evaluation

This notebook generates WikiArt images in bulk for quantitative evaluation:
- 100 images per art style × 27 styles × each configured guidance scale (128×128 pixels)
- Images are saved in an organized directory structure (see below)
- Real WikiArt images (resized to 128×128) are also exported for FID comparison

**Output Structure:**
```
outputs/experiment_3/
├── dataset/
│   ├── style_0_Abstract_Expressionism/
│   ├── style_1_Action_painting/
│   └── ... (27 styles)
├── generated/
│   ├── guidance_0/
│   │   ├── style_0_Abstract_Expressionism/
│   │   └── ... (27 styles)
│   ├── guidance_5/
│   └── ... (9 guidance scales)
└── metrics/
```

## 1. Setup and Configuration

In [None]:
# Project configuration - use absolute paths.
# The root can be overridden with the PROJECT_ROOT environment variable so the
# notebook also runs outside the original author's machine.
import os
from pathlib import Path
import sys

PROJECT_ROOT = Path(
    os.environ.get("PROJECT_ROOT", "/home/doshlom4/work/final_project")
).expanduser()

# Make the project importable with top priority. Drop any earlier copies first so
# re-running this cell does not keep stacking duplicate sys.path entries.
_project_root_str = str(PROJECT_ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _project_root_str]
sys.path.insert(0, _project_root_str)

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import configuration
from config import (
    EXPERIMENT_3_CONFIG,
    INFERENCE_CONFIG,
    TOKENIZER_MAX_LENGTH,
    CLIP_MODEL_NAME,
    WIKIART_STYLES,
    DATASET_CACHE_DIR,
    EXPERIMENT_3_DIR,
    EXPERIMENT_3_DATASET_DIR,
    EXPERIMENT_3_GENERATED_DIR,
    get_latest_wikiart_unet_checkpoint,
    get_wikiart_generated_images_dir,
    get_style_dir,
    ensure_experiment_3_dirs,
)

# Deep learning frameworks
import torch
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm

# Standard libraries
import numpy as np
from PIL import Image
import os

# HuggingFace datasets
from datasets import load_dataset
import torchvision.transforms as transforms

print("Libraries imported successfully")

In [None]:
# Generation configuration
IMAGES_PER_STYLE = EXPERIMENT_3_CONFIG["images_per_class"]  # 100
GUIDANCE_SCALES = EXPERIMENT_3_CONFIG["guidance_scales"]  # [0, 5, 10, 15, 20, 30, 40, 50, 100]
NUM_INFERENCE_STEPS = INFERENCE_CONFIG["num_inference_steps"]  # 50

# Smaller batch size for 128x128 images (memory constraints)
GENERATION_BATCH_SIZE = 4

# Sampling is stochastic: seed every RNG the notebook uses so that a fresh
# "Restart & Run All" reproduces the same noise (torch.manual_seed also seeds
# all CUDA devices).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

assert IMAGES_PER_STYLE > 0, "images_per_class must be positive"
assert GENERATION_BATCH_SIZE > 0, "GENERATION_BATCH_SIZE must be positive"

print("Generation Configuration:")
print(f"  Images per style: {IMAGES_PER_STYLE}")
print(f"  Number of styles: {len(WIKIART_STYLES)}")
print(f"  Guidance scales: {GUIDANCE_SCALES}")
print(f"  Inference steps: {NUM_INFERENCE_STEPS}")
print(f"  Batch size: {GENERATION_BATCH_SIZE}")
print(f"  Seed: {SEED}")
print()
total_images = IMAGES_PER_STYLE * len(WIKIART_STYLES) * len(GUIDANCE_SCALES)
print(f"Total images to generate: {total_images:,}")

In [None]:
# Build the experiment-3 output tree up front, before any cell writes into it
# (the helper is provided by config.py).
ensure_experiment_3_dirs()
output_root = EXPERIMENT_3_DIR
print(f"Output directories created under: {output_root}")

## 2. Load Models

In [None]:
# Prefer the first visible GPU; fall back to CPU when CUDA is unavailable.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print(f"Using device: {device}")

if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem_gb:.2f} GB")

In [None]:
# Restore the newest WikiArt UNet checkpoint and put the model in inference mode.
from models.custom_unet_wikiart import load_wikiart_unet_from_checkpoint

checkpoint_path = get_latest_wikiart_unet_checkpoint()
print(f"Loading checkpoint: {checkpoint_path}")

unet, checkpoint = load_wikiart_unet_from_checkpoint(
    str(checkpoint_path),
    device,
)
unet.eval()  # eval mode: affects dropout / normalization layers, if the model has any

loaded_epoch = checkpoint['epoch']
print(f"\n✓ Loaded WikiArt UNet from epoch {loaded_epoch}")

In [None]:
# CLIP text side only: the tokenizer plus a frozen text encoder that produces
# the prompt embeddings the UNet is conditioned on.
text_encoder = (
    CLIPTextModel.from_pretrained(CLIP_MODEL_NAME)
    .to(device)
    .eval()
    .requires_grad_(False)
)
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)

print(f"✓ Loaded CLIP text encoder: {CLIP_MODEL_NAME}")

## 3. Generation Function

In [None]:
@torch.no_grad()
def generate_batch(
    prompts: list[str],
    guidance_scale: float,
    num_inference_steps: int = 50,
    image_size: int = 128,
    generator: "torch.Generator | None" = None,
) -> list[Image.Image]:
    """
    Generate a batch of WikiArt images with classifier-free guidance (CFG).

    Uses the notebook globals ``tokenizer``, ``text_encoder``, ``unet`` and
    ``device`` loaded in the "Load Models" section.

    The guided noise estimate is ``uncond + guidance_scale * (text - uncond)``,
    so ``guidance_scale=0`` samples unconditionally and ``guidance_scale=1``
    uses the plain conditional prediction.

    Args:
        prompts: List of text prompts (one image is generated per prompt).
        guidance_scale: CFG scale.
        num_inference_steps: Number of denoising steps.
        image_size: Height and width in pixels of the sampled images
            (default 128, the resolution used throughout this notebook).
        generator: Optional ``torch.Generator`` for reproducible sampling. It
            drives both the initial noise and the scheduler's per-step noise.

    Returns:
        List of RGB PIL Images, one per prompt (empty if ``prompts`` is empty).
    """
    batch_size = len(prompts)
    if batch_size == 0:
        # Nothing to do; the tokenizer / torch.cat below would fail on an empty batch.
        return []

    # Encode text prompts
    text_input = tokenizer(
        prompts,
        padding="max_length",
        max_length=TOKENIZER_MAX_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Unconditional (empty-prompt) embeddings for CFG
    uncond_input = tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=TOKENIZER_MAX_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    # Unconditional half first, conditional half second (matches chunk(2) below)
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Setup scheduler
    scheduler = DDPMScheduler(
        beta_schedule=INFERENCE_CONFIG["beta_schedule"],
        num_train_timesteps=INFERENCE_CONFIG["num_train_timesteps"]
    )
    scheduler.set_timesteps(num_inference_steps)

    # Initialize noise in pixel space, shape (B, 3, H, W). Draw it on the
    # generator's own device so a CPU generator also works with a CUDA model.
    noise_device = generator.device if generator is not None else device
    latents = torch.randn(
        (batch_size, 3, image_size, image_size),
        generator=generator,
        device=noise_device,
    ).to(device)
    latents = latents * scheduler.init_noise_sigma  # 1.0 for DDPM; kept for scheduler-agnosticism

    # Denoising loop
    for t in scheduler.timesteps:
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # CFG
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, generator=generator).prev_sample

    # Map from [-1, 1] to [0, 255]. Round rather than truncate, which would
    # otherwise bias every pixel ~0.5 intensity levels darker.
    images = (latents / 2 + 0.5).clamp(0, 1)
    images = (images * 255).round().to(torch.uint8)
    images = images.permute(0, 2, 3, 1).cpu().numpy()  # (B, C, H, W) -> (B, H, W, C)

    return [Image.fromarray(img) for img in images]


print("Generation function defined")

## 4. Export Real WikiArt Images

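Export real WikiArt images for FID comparison. For each style, the first `IMAGES_PER_STYLE` images in dataset order are taken and resized to 128×128, the resolution of the generated samples.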

In [None]:
# Load WikiArt dataset
print("Loading WikiArt dataset...")
wikiart_hf = load_dataset(
    "huggan/wikiart",
    split="train",
    cache_dir=str(DATASET_CACHE_DIR / "huggingface")
)

print(f"Loaded {len(wikiart_hf)} images")

# Detect the style column from the dataset schema. Indexing a row
# (wikiart_hf[0]) would decode an image just to read its keys.
available_columns = wikiart_hf.column_names
if "style" in available_columns:
    style_column = "style"
elif "label" in available_columns:
    style_column = "label"
else:
    raise KeyError(
        f"WikiArt dataset has neither a 'style' nor a 'label' column; "
        f"available columns: {available_columns}"
    )
print(f"Style column: {style_column}")

In [None]:
# Count images per style in dataset.
# Read only the label column: iterating rows (`for item in wikiart_hf`) would
# also decode every image, which is very slow and unnecessary for counting.
from collections import Counter

label_counts = Counter(wikiart_hf[style_column])
style_counts = {i: label_counts.get(i, 0) for i in range(len(WIKIART_STYLES))}
num_out_of_range = sum(n for label, n in label_counts.items() if label not in style_counts)

print("\nImages per style:")
for style_idx, count in style_counts.items():
    print(f"  {WIKIART_STYLES[style_idx]}: {count}")
if num_out_of_range:
    print(f"  (ignored {num_out_of_range} images with labels outside WIKIART_STYLES)")

In [None]:
# Export real images (IMAGES_PER_STYLE per style, first ones in dataset order)
print(f"\nExporting {IMAGES_PER_STYLE} real images per style to {EXPERIMENT_3_DATASET_DIR}")

REAL_IMAGE_SIZE = (128, 128)  # same resolution as the generated samples
num_styles = len(WIKIART_STYLES)

# Choose row indices from the label column only, so just the exported images get
# decoded. A row-by-row scan would decode every image it visits. It would also
# have to walk the entire dataset whenever some style has fewer than
# IMAGES_PER_STYLE images.
rows_per_style = {i: [] for i in range(num_styles)}
for row_idx, style_idx in enumerate(wikiart_hf[style_column]):
    if 0 <= style_idx < num_styles and len(rows_per_style[style_idx]) < IMAGES_PER_STYLE:
        rows_per_style[style_idx].append(row_idx)

export_plan = [
    (style_idx, row_idx)
    for style_idx, row_indices in rows_per_style.items()
    for row_idx in row_indices
]

# Track how many we've saved per style
saved_per_style = {i: 0 for i in range(num_styles)}

for style_idx, row_idx in tqdm(export_plan, desc="Exporting real images"):
    # Get image and resize to 128x128
    image = wikiart_hf[row_idx]['image']
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(REAL_IMAGE_SIZE, Image.LANCZOS)

    # Save image
    style_dir = get_style_dir(EXPERIMENT_3_DATASET_DIR, style_idx)
    style_dir.mkdir(parents=True, exist_ok=True)

    image_path = style_dir / f"real_{saved_per_style[style_idx]:04d}.png"
    image.save(image_path)

    saved_per_style[style_idx] += 1

print("\nExport complete!")
for style_idx, count in saved_per_style.items():
    shortfall_note = "" if count >= IMAGES_PER_STYLE else f"  (only {count} available)"
    print(f"  {WIKIART_STYLES[style_idx]}: {count} images{shortfall_note}")

## 5. Generate Images for All Guidance Scales

In [None]:
# Main generation loop: for every guidance scale and style, top the style's
# folder up to IMAGES_PER_STYLE images. Resumable: PNGs already present count
# toward the target and are never overwritten.
num_styles = len(WIKIART_STYLES)

print(f"\n{'='*70}")
print("Starting WikiArt Image Generation")
print(f"{'='*70}")
print(f"Styles: {num_styles}")
print(f"Images per style: {IMAGES_PER_STYLE}")
print(f"Guidance scales: {GUIDANCE_SCALES}")
print(f"Batch size: {GENERATION_BATCH_SIZE}")
print(f"{'='*70}\n")

total_generated = 0

for guidance_scale in GUIDANCE_SCALES:
    print(f"\n=== Guidance Scale: {guidance_scale} ===")
    # Output root for this guidance scale (loop-invariant across styles)
    guidance_dir = get_wikiart_generated_images_dir(guidance_scale)

    for style_idx, style_name in enumerate(WIKIART_STYLES):
        style_display = style_name.replace('_', ' ')
        prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_display)
        style_label = f"[{style_idx+1:2d}/{num_styles}] {style_display}"

        # Get output directory
        style_dir = get_style_dir(guidance_dir, style_idx)
        style_dir.mkdir(parents=True, exist_ok=True)

        # Check how many already exist
        existing = len(list(style_dir.glob("*.png")))
        if existing >= IMAGES_PER_STYLE:
            print(f"  {style_label}: Already complete ({existing} images)")
            continue

        # Generate remaining images (ceil division into batches)
        remaining = IMAGES_PER_STYLE - existing
        num_batches = (remaining + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE

        generated_count = existing
        file_idx = 0  # candidate index for the next gen_XXXX.png filename

        for _ in tqdm(range(num_batches), desc=f"  {style_label}", leave=False):
            batch_size = min(GENERATION_BATCH_SIZE, IMAGES_PER_STYLE - generated_count)
            images = generate_batch([prompt] * batch_size, guidance_scale, NUM_INFERENCE_STEPS)

            for img in images:
                # After a partial run or a manual deletion the existing files need
                # not be gen_0000..gen_{existing-1}; naming purely by count could
                # overwrite an image we kept. Skip names that are already taken.
                while (style_dir / f"gen_{file_idx:04d}.png").exists():
                    file_idx += 1
                img.save(style_dir / f"gen_{file_idx:04d}.png")
                file_idx += 1
                generated_count += 1
                total_generated += 1

        newly_generated = generated_count - existing
        print(f"  {style_label}: Generated {newly_generated} new images ({generated_count} total)")

print(f"\n{'='*70}")
print("Generation Complete!")
print(f"{'='*70}")
print(f"Total images generated: {total_generated:,}")

## 6. Verify Generated Images

In [None]:
# Verify image counts per guidance scale, and name any style that is off-target
print("Verification: Image counts per guidance scale\n")

expected = IMAGES_PER_STYLE * len(WIKIART_STYLES)

for guidance_scale in GUIDANCE_SCALES:
    guidance_dir = get_wikiart_generated_images_dir(guidance_scale)
    # Distinct name so this cell does not clobber `style_counts` (dataset counts).
    generated_style_counts = {
        style_idx: len(list(get_style_dir(guidance_dir, style_idx).glob("*.png")))
        for style_idx in range(len(WIKIART_STYLES))
    }
    total = sum(generated_style_counts.values())

    incomplete_styles = [
        WIKIART_STYLES[style_idx]
        for style_idx, count in generated_style_counts.items()
        if count != IMAGES_PER_STYLE
    ]
    # Every style must hit the target; a matching grand total alone could hide imbalance.
    status = "✓" if not incomplete_styles else "✗"
    # `:>5g` formats both int and float guidance scales (`:3d` raises on floats).
    print(f"Guidance {guidance_scale:>5g}: {total:5d} images (expected: {expected}) {status}")
    if incomplete_styles:
        print(f"    incomplete styles: {', '.join(incomplete_styles)}")

In [None]:
# Tally the exported real images across every style folder
print("\nReal dataset images:")
total_real = sum(
    len(list(get_style_dir(EXPERIMENT_3_DATASET_DIR, idx).glob("*.png")))
    for idx in range(len(WIKIART_STYLES))
)

expected_real = len(WIKIART_STYLES) * IMAGES_PER_STYLE
status = "✗" if total_real != expected_real else "✓"
print(f"Total: {total_real} images (expected: {expected_real}) {status}")

In [None]:
# Display sample generated images: one row per style, one column per guidance scale
import matplotlib.pyplot as plt

sample_styles = [0, 12, 19]  # indices into WIKIART_STYLES (e.g. Abstract_Expressionism, Impressionism, Pop_Art)
sample_guidance = [0, 5, 10]

# squeeze=False keeps `axes` 2-D even if a single style or guidance is chosen.
fig, axes = plt.subplots(
    len(sample_styles), len(sample_guidance), figsize=(10, 10), squeeze=False
)

for row, style_idx in enumerate(sample_styles):
    for col, guidance in enumerate(sample_guidance):
        ax = axes[row, col]
        guidance_dir = get_wikiart_generated_images_dir(guidance)
        style_dir = get_style_dir(guidance_dir, style_idx)

        # Sort so the "first" image is deterministic (glob order is filesystem-dependent).
        image_paths = sorted(style_dir.glob("*.png"))
        if image_paths:
            # The context manager closes the file; np.asarray loads the pixels before that.
            with Image.open(image_paths[0]) as img:
                ax.imshow(np.asarray(img))
        else:
            ax.text(0.5, 0.5, "missing", ha="center", va="center", transform=ax.transAxes)

        if row == 0:
            ax.set_title(f'w={guidance}')
        if col == 0:
            ax.set_ylabel(WIKIART_STYLES[style_idx].replace('_', '\n'), fontsize=8)
        ax.set_xticks([])
        ax.set_yticks([])

fig.suptitle('Sample Generated Images', fontsize=14)
fig.tight_layout()
plt.show()

## Summary

This notebook generated bulk WikiArt images for evaluation:

**Generated:**
- 100 images × 27 styles × 9 guidance scales = 24,300 generated images
- Up to 100 images × 27 styles = 2,700 real images for FID comparison (fewer if a style has under 100 images in the dataset)

**Output Locations:**
- Real images: `outputs/experiment_3/dataset/`
- Generated images: `outputs/experiment_3/generated/guidance_X/`

**Next steps:**
1. `train2_train_wikiart_classifier.ipynb` - Train art style classifier
2. `metrics1_evaluate_wikiart.ipynb` - Compute FID and classification accuracy