# Experiment 3: WikiArt Inference with Classifier-Free Guidance

This notebook explores the trained WikiArt text-to-image model, demonstrating:
- Generation across all 27 art styles
- Effect of different guidance scales
- Comparison with real WikiArt images
- Multiple samples per style to show diversity
- Custom prompt variations

## 1. Setup and Configuration

In [None]:
# Project configuration - resolve the project root and make it importable
import os
import sys
from pathlib import Path

# The default is the original development checkout. Set the
# WIKIART_PROJECT_ROOT environment variable to run the notebook from a
# different machine or directory without editing this cell.
PROJECT_ROOT = Path(
    os.environ.get("WIKIART_PROJECT_ROOT", "/home/doshlom4/work/final_project")
)

# Only register the path once so re-running this cell in the same kernel
# does not keep stacking duplicate entries onto sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import configuration
from config import (
    EXPERIMENT_3_CONFIG,
    INFERENCE_CONFIG,
    TOKENIZER_MAX_LENGTH,
    CLIP_MODEL_NAME,
    WIKIART_STYLES,
    DATASET_CACHE_DIR,
    get_latest_wikiart_unet_checkpoint,
    get_wikiart_unet_checkpoint_path,
)

# Deep learning frameworks
import torch
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm

# Standard libraries
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# HuggingFace datasets
from datasets import load_dataset
import torchvision.transforms as transforms

print("Libraries imported successfully")

In [None]:
# Summarise the settings pulled in from the project config
num_styles = len(WIKIART_STYLES)
inference_steps = INFERENCE_CONFIG['num_inference_steps']
prompt_template = EXPERIMENT_3_CONFIG['prompt_template']

print("WikiArt Configuration:")
print(f"  Number of styles: {num_styles}")
print(f"  Inference steps: {inference_steps}")
print(f"  Prompt template: {prompt_template}")
print()

# Index/name listing of every style the notebook will iterate over
print("Art Styles:")
for position, art_style in enumerate(WIKIART_STYLES):
    print(f"  {position:2d}: {art_style}")

## 2. Load Models

In [None]:
# Choose the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

# Report which GPU was picked up and how much memory it offers
if has_cuda:
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_properties.total_memory / 1e9:.2f} GB")

In [None]:
# Restore the trained WikiArt UNet from the most recent checkpoint
from models.custom_unet_wikiart import (
    CustomUNet2DConditionModelWikiArt,
    load_wikiart_unet_from_checkpoint,
)

# Ask the project config which checkpoint is the latest one
checkpoint_path = get_latest_wikiart_unet_checkpoint()
print(f"Loading checkpoint: {checkpoint_path}")

# The loader returns the model together with the raw checkpoint dict
unet, checkpoint = load_wikiart_unet_from_checkpoint(str(checkpoint_path), device)
unet.eval()  # switch to evaluation mode for inference

trained_epoch = checkpoint['epoch']
print(f"\n✓ Loaded WikiArt UNet from epoch {trained_epoch}")
parameter_count = unet.get_num_parameters()
print(f"  Parameters: {parameter_count:,}")

In [None]:
# CLIP tokenizer + text encoder used to turn prompts into conditioning vectors
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)
text_encoder = CLIPTextModel.from_pretrained(CLIP_MODEL_NAME)
text_encoder = text_encoder.to(device)

# The encoder is only used for inference here: freeze it and put it in eval mode
text_encoder.requires_grad_(False)
text_encoder.eval()

print(f"✓ Loaded CLIP text encoder: {CLIP_MODEL_NAME}")

## 3. Generation Functions

In [None]:
@torch.no_grad()
def generate_images(
    prompts: list[str],
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    show_progress: bool = True,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:
    """
    Generate WikiArt images using classifier-free guidance (CFG).

    The UNet is run on a doubled batch (unconditional + conditional) at every
    denoising step and the two noise predictions are blended as
    ``uncond + guidance_scale * (cond - uncond)``.

    Args:
        prompts: List of text prompts (a single string is treated as one prompt)
        guidance_scale: CFG scale. 0 uses only the unconditional prediction,
            1 uses only the conditional prediction, larger values extrapolate
            further towards the prompt.
        num_inference_steps: Number of denoising steps
        show_progress: Whether to show progress bar
        generator: Optional ``torch.Generator`` used for the initial noise and
            for the scheduler's sampling noise. Pass a seeded generator to make
            a call reproducible; ``None`` keeps using the global RNG.

    Returns:
        Tensor of images (B, 3, 128, 128) in [0, 1] range

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")
    batch_size = len(prompts)

    def _encode(texts: list[str]) -> torch.Tensor:
        # Conditional and unconditional prompts share one tokenization recipe
        # (including truncation) so both halves of the CFG batch line up.
        tokens = tokenizer(
            texts,
            padding="max_length",
            max_length=TOKENIZER_MAX_LENGTH,
            truncation=True,
            return_tensors="pt"
        )
        return text_encoder(tokens.input_ids.to(device))[0]

    cond_embeddings = _encode(prompts)
    # The empty prompt serves as the "no conditioning" input for CFG
    uncond_embeddings = _encode([""] * batch_size)

    # Order matters: chunk(2) below expects [unconditional, conditional]
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    # Start from Gaussian noise directly in image space. The noise is sampled
    # on the generator's own device (if one is given) and then moved, so a CPU
    # generator can be combined with a CUDA model.
    noise_device = generator.device if generator is not None else device
    latents = torch.randn(
        (batch_size, 3, 128, 128), generator=generator, device=noise_device
    ).to(device)

    # Setup scheduler
    scheduler = DDPMScheduler(
        beta_schedule=INFERENCE_CONFIG["beta_schedule"],
        num_train_timesteps=INFERENCE_CONFIG["num_train_timesteps"]
    )
    scheduler.set_timesteps(num_inference_steps)

    # Denoising loop
    timesteps = tqdm(scheduler.timesteps, desc="Generating") if show_progress else scheduler.timesteps

    for t in timesteps:
        # Duplicate the sample so one UNet call yields both predictions
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # CFG: move from the unconditional prediction towards the conditional one
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, generator=generator).prev_sample

    # Map x -> x / 2 + 0.5 and clamp to [0, 1]
    # (assumes the model was trained on images scaled to [-1, 1] - TODO confirm)
    images = (latents / 2 + 0.5).clamp(0, 1)

    return images


def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """
    Convert tensor (C, H, W) in [0, 1] to PIL Image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level, so
    slightly out-of-range inputs saturate instead of wrapping around and
    in-range inputs are not biased downwards by truncation.

    Args:
        tensor: Image tensor laid out as (C, H, W).

    Returns:
        The image as a PIL Image (a single-channel input yields a 2-D image).
    """
    pixels = (
        tensor.detach()      # tolerate tensors that still carry autograd history
        .cpu()
        .float()             # NumPy cannot represent every torch dtype (e.g. bfloat16)
        .clamp(0, 1)
        .mul(255)
        .round()
        .to(torch.uint8)
        .numpy()
    )
    array = np.transpose(pixels, (1, 2, 0))  # CHW -> HWC
    if array.shape[2] == 1:
        # PIL expects grayscale data as (H, W), not (H, W, 1)
        array = array[:, :, 0]
    return Image.fromarray(array)


print("Generation functions defined")

## 4. Generate All Art Styles

Generate one sample for each of the 27 art styles.

In [None]:
# Produce a single sample for every WikiArt style at one fixed CFG strength
guidance_scale = 7.5
generated_images = {}
total_styles = len(WIKIART_STYLES)

print(f"Generating images for all {total_styles} art styles (guidance_scale={guidance_scale})...\n")

for position, style_name in enumerate(WIKIART_STYLES, start=1):
    # Style identifiers use underscores; prompts use the spaced form
    style_display = style_name.replace('_', ' ')
    prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_display)

    # One prompt in, one image out - keep only that image, keyed by style id
    batch = generate_images([prompt], guidance_scale=guidance_scale, show_progress=False)
    generated_images[style_name] = batch[0]

    print(f"  [{position:2d}/{total_styles}] {style_display}")

print("\nGeneration complete!")

In [None]:
# Lay the per-style samples out on a 6-column grid
n_cols = 6
n_rows = -(-len(WIKIART_STYLES) // n_cols)  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows))
axes = axes.flatten()

for ax, (style_name, img_tensor) in zip(axes, generated_images.items()):
    # imshow wants HWC, the generator returns CHW
    ax.imshow(img_tensor.permute(1, 2, 0).cpu().numpy())
    # Break long style names across lines so titles do not overlap
    ax.set_title(style_name.replace('_', '\n'), fontsize=8)
    ax.axis('off')

# Blank out the unused cells at the end of the grid
for ax in axes[len(WIKIART_STYLES):]:
    ax.axis('off')

plt.suptitle(f'Generated WikiArt Samples - All 27 Styles (guidance={guidance_scale})', fontsize=16)
plt.tight_layout()
plt.show()

## 5. Compare with Real WikiArt Images

In [None]:
# Load WikiArt dataset for comparison
print("Loading WikiArt dataset for comparison...")
wikiart_hf = load_dataset(
    "huggan/wikiart",
    split="train",
    cache_dir=str(DATASET_CACHE_DIR / "huggingface")
)

print(f"Loaded {len(wikiart_hf)} images")

In [None]:
# Pick the label column from the dataset schema (no need to decode a sample)
style_column = 'style' if 'style' in wikiart_hf.column_names else 'label'
print(f"Style column: {style_column}")

# Find the first row of each style by scanning only the label column.
# Iterating over full rows would decode every image along the way, which is
# far slower than reading the labels alone.
# NOTE(review): assumes the label ids index WIKIART_STYLES in the same order
# as the dataset's own class names - confirm against the dataset features.
first_row_per_style = {}
for row_idx, style_idx in enumerate(wikiart_hf[style_column]):
    if style_idx < len(WIKIART_STYLES):
        # setdefault keeps the earliest row seen for each style
        first_row_per_style.setdefault(WIKIART_STYLES[style_idx], row_idx)

    # Stop when we have all styles
    if len(first_row_per_style) == len(WIKIART_STYLES):
        break

# Get one real image per style - only these rows are actually decoded
real_images_per_style = {
    style_name: wikiart_hf[row_idx]['image']
    for style_name, row_idx in first_row_per_style.items()
}

print(f"Found real images for {len(real_images_per_style)} styles")

In [None]:
# Side-by-side comparison: Real vs Generated
selected_styles = [
    "Impressionism", "Baroque", "Cubism", "Pop_Art",
    "Renaissance" if "Renaissance" in WIKIART_STYLES else "High_Renaissance",
    "Expressionism"
]
# Filter to available styles
selected_styles = [s for s in selected_styles if s in generated_images]

# squeeze=False keeps `axes` two-dimensional even if only one style survives
# the filter above; otherwise axes[i, 0] would fail on a 1-D array.
fig, axes = plt.subplots(
    len(selected_styles), 2,
    figsize=(8, 4 * len(selected_styles)),
    squeeze=False,
)

for i, style_name in enumerate(selected_styles):
    real_ax, gen_ax = axes[i, 0], axes[i, 1]
    style_label = style_name.replace("_", " ")

    # Real image (left column); show a placeholder if none was found
    if style_name in real_images_per_style:
        real_img = real_images_per_style[style_name]
        if real_img.mode != 'RGB':
            real_img = real_img.convert('RGB')
        # Match the generator's 128x128 output size for a fair visual comparison
        real_img_resized = real_img.resize((128, 128))
        real_ax.imshow(real_img_resized)
    else:
        real_ax.text(0.5, 0.5, 'N/A', ha='center', va='center')
    real_ax.set_title(f'Real: {style_label}')
    real_ax.axis('off')

    # Generated image (right column); CHW tensor -> HWC array for imshow
    gen_img = generated_images[style_name].permute(1, 2, 0).cpu().numpy()
    gen_ax.imshow(gen_img)
    gen_ax.set_title(f'Generated: {style_label}')
    gen_ax.axis('off')

plt.suptitle('Real vs Generated WikiArt Images', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Guidance Scale Ablation

Explore the effect of different guidance scales on generation quality.

In [None]:
# Test different guidance scales
test_styles = ["Impressionism", "Pop_Art", "Baroque"]
guidance_scales = [0, 2, 5, 7.5, 10, 15, 20]
# Every panel is sampled from the same seed so that, within a row, the only
# thing that changes between columns is the guidance scale.
# Note: this re-seeds PyTorch's global RNG, so later cells become
# deterministic for a given execution order as well.
GUIDANCE_ABLATION_SEED = 0

# squeeze=False keeps `axes` 2-D even if one of the lists has a single entry
fig, axes = plt.subplots(len(test_styles), len(guidance_scales),
                          figsize=(3 * len(guidance_scales), 3 * len(test_styles)),
                          squeeze=False)

for row, style_name in enumerate(test_styles):
    style_display = style_name.replace('_', ' ')
    prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_display)

    # The loop variable is deliberately NOT called `guidance_scale`: that name
    # is a notebook-level setting read by other cells and must not be clobbered.
    for col, scale in enumerate(guidance_scales):
        torch.manual_seed(GUIDANCE_ABLATION_SEED)
        images = generate_images([prompt], guidance_scale=scale, show_progress=False)
        img = images[0].permute(1, 2, 0).cpu().numpy()

        ax = axes[row, col]
        ax.imshow(img)
        if row == 0:
            ax.set_title(f'w={scale}')
        if col == 0:
            ax.set_ylabel(style_display, fontsize=10)
        # Hide ticks but keep the axis on so the row label stays visible
        ax.set_xticks([])
        ax.set_yticks([])

plt.suptitle('Effect of Guidance Scale on WikiArt Generation', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Multiple Samples Per Style

Show diversity in generation for the same style.

In [None]:
# Draw several independent samples per style to visualise variety
diversity_styles = ["Impressionism", "Cubism", "Pop_Art"]
num_samples = 5
guidance_scale = 7.5

grid_rows = len(diversity_styles)
fig, axes = plt.subplots(grid_rows, num_samples,
                          figsize=(3 * num_samples, 3 * grid_rows))

for row, style_name in enumerate(diversity_styles):
    style_display = style_name.replace('_', ' ')
    prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_display)

    for col in range(num_samples):
        ax = axes[row, col]

        # Same prompt each time; every call to generate_images draws fresh noise
        sample = generate_images([prompt], guidance_scale=guidance_scale, show_progress=False)[0]
        ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
        ax.axis('off')

        # Label each row once, on its left-most panel
        if col == 0:
            ax.set_title(style_display, fontsize=10, loc='left')

plt.suptitle(f'Generation Diversity (guidance={guidance_scale})', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Inference Steps Ablation

Show how the number of denoising steps affects generation quality.

In [None]:
# Test different numbers of inference steps
test_style = "Impressionism"
inference_steps_list = [10, 20, 30, 50, 75, 100]
guidance_scale = 7.5
# Seed before every run so all panels start from the same initial noise and
# differences are attributable to the step count rather than to a new draw.
# Note: this re-seeds PyTorch's global RNG, so later cells become
# deterministic for a given execution order as well.
STEPS_ABLATION_SEED = 0

style_display = test_style.replace('_', ' ')
prompt = EXPERIMENT_3_CONFIG["prompt_template"].format(style_name=style_display)

# squeeze=False keeps `axes` 2-D even when the list has a single entry
fig, axes = plt.subplots(1, len(inference_steps_list),
                         figsize=(3 * len(inference_steps_list), 3),
                         squeeze=False)

for i, num_steps in enumerate(inference_steps_list):
    torch.manual_seed(STEPS_ABLATION_SEED)
    images = generate_images([prompt], guidance_scale=guidance_scale,
                             num_inference_steps=num_steps, show_progress=False)
    img = images[0].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow

    ax = axes[0, i]
    ax.imshow(img)
    ax.set_title(f'{num_steps} steps')
    ax.axis('off')

plt.suptitle(f'Effect of Inference Steps on "{style_display}" (guidance={guidance_scale})', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Custom Prompt Variations

Test the model with different prompt formulations.

In [None]:
# Probe how sensitive the model is to the wording of the prompt
base_style = "Impressionism"
prompt_variations = [
    template.format(style=base_style)
    for template in (
        "A painting in the style of {style}",
        "An {style} painting",
        "A landscape in {style} style",
        "A portrait in the {style} movement",
        "{style} artwork",
        "Beautiful {style} painting",
    )
]

guidance_scale = 7.5

# 2 x 3 grid: one panel per prompt variation
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

for ax, prompt in zip(axes, prompt_variations):
    sample = generate_images([prompt], guidance_scale=guidance_scale, show_progress=False)[0]

    ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
    # Show the exact prompt text above its image
    ax.set_title(f'"{prompt}"', fontsize=8, wrap=True)
    ax.axis('off')

plt.suptitle('Prompt Variations for Same Style', fontsize=14)
plt.tight_layout()
plt.show()

## 10. Unconditional Generation

Generate images without any text conditioning (guidance_scale = 0).

In [None]:
# Sample without text guidance by setting the CFG scale to zero
num_unconditional = 6

fig, axes = plt.subplots(1, num_unconditional, figsize=(3 * num_unconditional, 3))

for sample_number, ax in enumerate(axes, start=1):
    # With guidance_scale=0 the CFG blend in generate_images reduces to the
    # empty-prompt (unconditional) prediction, so the prompt text passed here
    # does not steer the result.
    batch = generate_images(["A painting"], guidance_scale=0, show_progress=False)

    ax.imshow(batch[0].permute(1, 2, 0).cpu().numpy())
    ax.set_title(f'Sample {sample_number}')
    ax.axis('off')

plt.suptitle('Unconditional WikiArt Generation (guidance_scale=0)', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook explored the trained WikiArt text-to-image model:

**Demonstrations:**
1. Generated samples for all 27 art styles
2. Compared generated images with real WikiArt images
3. Showed effect of guidance scale (0 to 20)
4. Demonstrated generation diversity for same style
5. Tested different numbers of inference steps
6. Explored prompt variations
7. Showed unconditional generation

**Next steps:**
- `generate_images.ipynb` - Bulk generation for evaluation
- `train2_train_wikiart_classifier.ipynb` - Train art style classifier
- `metrics1_evaluate_wikiart.ipynb` - Compute FID and classification metrics