# Experiment 3: Evaluate WikiArt Generation Quality

This notebook evaluates the quality of generated WikiArt images using:

1. **FID (Fréchet Inception Distance)** - Measures similarity between real and generated image distributions
2. **Classification Accuracy** - How well generated images match the prompted art style

**Evaluation Structure:**
- Compare generated images at different guidance scales
- Compute FID per guidance scale
- Compute per-style classification accuracy
- Generate evaluation report

## 1. Setup and Configuration

In [None]:
# Project configuration - use absolute paths.
# PROJECT_ROOT defaults to the original author's checkout but can be pointed
# elsewhere by setting the FINAL_PROJECT_ROOT environment variable, so the
# notebook is not tied to a single machine.
import os
from pathlib import Path
import sys

PROJECT_ROOT = Path(os.environ.get("FINAL_PROJECT_ROOT", "/home/doshlom4/work/final_project"))

# Put the project root first on sys.path so the project modules imported in the
# next cell (config, wikiart_classifier) resolve from it. Checking sys.path[0]
# keeps re-runs of this cell from stacking duplicate entries.
if not sys.path or sys.path[0] != str(PROJECT_ROOT):
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import configuration
from config import (
    EXPERIMENT_3_CONFIG,
    WIKIART_STYLES,
    EXPERIMENT_3_DIR,
    EXPERIMENT_3_DATASET_DIR,
    EXPERIMENT_3_GENERATED_DIR,
    EXPERIMENT_3_METRICS_DIR,
    get_wikiart_generated_images_dir,
    get_style_dir,
)

from wikiart_classifier import (
    WikiArtClassifier,
    load_wikiart_classifier,
    get_wikiart_classifier_checkpoint_path,
    WikiArtDataset,
    get_wikiart_transforms,
)

# Deep learning frameworks
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Standard libraries
import json
import os
import re
import subprocess

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Reaching this line means every import in this cell resolved.
print("Libraries imported successfully")

In [None]:
# Evaluation configuration, read from the shared experiment-3 config.
GUIDANCE_SCALES = EXPERIMENT_3_CONFIG["guidance_scales"]
IMAGES_PER_STYLE = EXPERIMENT_3_CONFIG["images_per_class"]

# Echo the settings and the directories involved so the run is self-documenting.
for config_line in (
    "Evaluation Configuration:",
    f"  Guidance scales: {GUIDANCE_SCALES}",
    f"  Images per style: {IMAGES_PER_STYLE}",
    f"  Number of styles: {len(WIKIART_STYLES)}",
    "\nDirectories:",
    f"  Dataset: {EXPERIMENT_3_DATASET_DIR}",
    f"  Generated: {EXPERIMENT_3_GENERATED_DIR}",
    f"  Metrics: {EXPERIMENT_3_METRICS_DIR}",
):
    print(config_line)

## 2. Verify Generated Images

In [None]:
# Count the PNG images present for every style, once per guidance scale and once
# for the real reference set, and flag totals that differ from the expected count.


def count_images_per_style(base_dir: Path) -> dict:
    """Return {style_idx: number of *.png files} for every WikiArt style under base_dir.

    The style sub-directory is located with ``get_style_dir``; a missing
    directory counts as 0 images.
    """
    counts = {}
    for style_idx in range(len(WIKIART_STYLES)):
        style_dir = get_style_dir(base_dir, style_idx)
        counts[style_idx] = len(list(style_dir.glob("*.png"))) if style_dir.exists() else 0
    return counts


print("Image counts per guidance scale:\n")

image_counts = {}
expected = IMAGES_PER_STYLE * len(WIKIART_STYLES)

for guidance_scale in GUIDANCE_SCALES:
    style_counts = count_images_per_style(get_wikiart_generated_images_dir(guidance_scale))
    total = sum(style_counts.values())
    image_counts[guidance_scale] = {
        'total': total,
        'per_style': style_counts
    }

    status = "✓" if total == expected else f"✗ (expected {expected})"
    # `>3` instead of `3d`: integer scales render identically, and non-integer
    # scales (e.g. 7.5) no longer raise ValueError.
    print(f"Guidance {guidance_scale:>3}: {total:5d} images {status}")

# Count real images
real_total = sum(count_images_per_style(EXPERIMENT_3_DATASET_DIR).values())

expected_real = IMAGES_PER_STYLE * len(WIKIART_STYLES)
status = "✓" if real_total == expected_real else f"✗ (expected {expected_real})"
print(f"\nReal images: {real_total} {status}")

## 3. Compute FID Scores

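FID is computed with `pytorch-fid`, run as a subprocess on flat directories of symlinks: one directory holds all real images, and one directory per guidance scale holds the generated images.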

In [None]:
# Prepare flat directories for FID computation.
# pytorch-fid expects flat directories with all images, so every per-style
# image is exposed through a symlink in a single directory.

def create_flat_directory(source_base_dir: Path, flat_dir: Path, num_styles: "int | None" = None) -> int:
    """
    Create a flat directory with symlinks to all images for FID computation.

    Args:
        source_base_dir: Base directory with one sub-directory per style
            (located via ``get_style_dir``).
        flat_dir: Destination directory, created if missing. Files and symlinks
            already directly inside it are removed first, so links left over
            from a previous run do not leak into the FID statistics.
        num_styles: Number of style indices to scan. Defaults to
            ``len(WIKIART_STYLES)`` instead of a hard-coded 27.

    Returns:
        Number of ``*.png`` images linked into ``flat_dir``.
    """
    if num_styles is None:
        num_styles = len(WIKIART_STYLES)

    flat_dir.mkdir(parents=True, exist_ok=True)

    # Clear existing symlinks/files (sub-directories are left untouched).
    for f in flat_dir.glob("*"):
        if f.is_symlink() or f.is_file():
            f.unlink()

    # Create symlinks for all images
    count = 0
    for style_idx in range(num_styles):
        style_dir = get_style_dir(source_base_dir, style_idx)
        if not style_dir.exists():
            continue
        for img_path in style_dir.glob("*.png"):
            # Style prefix keeps identically named files from different styles apart.
            link_path = flat_dir / f"style{style_idx:02d}_{img_path.name}"
            if not link_path.exists():
                # Absolute target: a relative target would be interpreted
                # relative to flat_dir and leave a dangling link.
                link_path.symlink_to(img_path.resolve())
            count += 1

    return count

print("Flat directory function defined")

In [None]:
# Expose every real (reference) image through a single directory of symlinks,
# the flat layout pytorch-fid reads.
real_flat_dir = EXPERIMENT_3_METRICS_DIR.joinpath("real_flat")
real_count = create_flat_directory(EXPERIMENT_3_DATASET_DIR, real_flat_dir)
print(f"Created flat directory for real images: {real_count} images")

In [None]:
# Compute FID for each guidance scale.
# pytorch-fid runs as a subprocess on two flat directories: the shared real set
# (real_flat_dir) and a per-guidance-scale directory of generated images.
fid_results = {}
fid_device = "cuda" if torch.cuda.is_available() else "cpu"

# A non-negative decimal such as "12" or "12.345" -- never a lone ".".
_FID_NUMBER = r'(\d+(?:\.\d+)?)'


def _parse_fid_output(stdout: str, stderr: str):
    """Return the FID reported by pytorch-fid, or None if it cannot be found.

    The ``FID: <value>`` line is looked up in stdout first, then in the combined
    output. The last-resort fallback only accepts a bare number at the very end
    of stdout, so trailing warnings/progress text on stderr is never mistaken
    for the score.
    """
    for text in (stdout, stdout + stderr):
        match = re.search(r'FID:\s*' + _FID_NUMBER, text)
        if match:
            return float(match.group(1))
    match = re.search(_FID_NUMBER + r'\s*$', stdout.strip())
    return float(match.group(1)) if match else None


print("Computing FID scores...\n")

for guidance_scale in GUIDANCE_SCALES:
    # Create flat directory for generated images
    gen_flat_dir = EXPERIMENT_3_METRICS_DIR / f"gen_flat_guidance_{guidance_scale}"
    guidance_dir = get_wikiart_generated_images_dir(guidance_scale)

    gen_count = create_flat_directory(guidance_dir, gen_flat_dir)

    if gen_count == 0:
        print(f"Guidance {guidance_scale}: No images found, skipping")
        continue

    # Only the subprocess launch is guarded, so bugs in result handling are not
    # silently turned into a "failed" FID.
    try:
        # sys.executable (not "python" from PATH) makes the subprocess use this
        # kernel's interpreter/environment, where pytorch_fid is installed.
        result = subprocess.run(
            [
                sys.executable, "-m", "pytorch_fid",
                str(real_flat_dir),
                str(gen_flat_dir),
                "--device", fid_device,
            ],
            capture_output=True,
            text=True,
            timeout=600,  # 10 minute timeout
        )
    except subprocess.TimeoutExpired:
        print(f"Guidance {guidance_scale}: Timeout")
        fid_results[guidance_scale] = None
        continue
    except Exception as e:  # e.g. OSError while launching the interpreter
        print(f"Guidance {guidance_scale}: Error - {e}")
        fid_results[guidance_scale] = None
        continue

    if result.returncode != 0:
        # A failed run is not parsed: numbers inside a traceback are not an FID.
        print(f"Guidance {guidance_scale}: pytorch-fid exited with code "
              f"{result.returncode}:\n{result.stderr[-2000:]}")
        fid_value = None
    else:
        fid_value = _parse_fid_output(result.stdout, result.stderr)
        if fid_value is None:
            print(f"Could not parse FID from output: {result.stdout + result.stderr}")

    fid_results[guidance_scale] = fid_value

    # `>3` instead of `3d`: integer scales render identically, and non-integer
    # scales (e.g. 7.5) no longer raise after the FID was computed.
    if fid_value is not None:
        print(f"Guidance {guidance_scale:>3}: FID = {fid_value:.2f}")
    else:
        print(f"Guidance {guidance_scale:>3}: FID computation failed")

print("\nFID computation complete!")

In [None]:
# Plot FID vs guidance scale for the scales whose FID was computed successfully.
# Scales are sorted so the line runs left-to-right, matching the accuracy plot
# even if GUIDANCE_SCALES is not configured in ascending order.
valid_guidance = sorted(g for g in GUIDANCE_SCALES if fid_results.get(g) is not None)
valid_fid = [fid_results[g] for g in valid_guidance]

if valid_fid:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(valid_guidance, valid_fid, 'bo-', linewidth=2, markersize=8)
    ax.set_xlabel('Guidance Scale', fontsize=12)
    ax.set_ylabel('FID (lower is better)', fontsize=12)
    ax.set_title('WikiArt FID vs Guidance Scale', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Mark best FID (lowest value)
    best_idx = int(np.argmin(valid_fid))
    ax.scatter([valid_guidance[best_idx]], [valid_fid[best_idx]],
               color='green', s=200, zorder=5, label=f'Best: w={valid_guidance[best_idx]}')
    ax.legend()

    fig.tight_layout()
    fig.savefig(EXPERIMENT_3_METRICS_DIR / "fid_vs_guidance.png", dpi=150)
    plt.show()

    print(f"\nBest FID: {valid_fid[best_idx]:.2f} at guidance scale {valid_guidance[best_idx]}")
else:
    print("No valid FID results to plot")

## 4. Classification Accuracy

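Evaluate how well generated images match the prompted art style. The trained WikiArt style classifier predicts a style for each generated image. An image counts as correct when the prediction matches the style it was generated for.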

In [None]:
# Pick the compute device: CUDA when a GPU is visible to torch, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the trained WikiArt style classifier. When its checkpoint is missing,
# `classifier` is set to None and the accuracy cells below skip their work.
# NOTE(review): load_wikiart_classifier presumably reads the same checkpoint
# path checked here -- confirm in wikiart_classifier.
classifier_path = get_wikiart_classifier_checkpoint_path()

if not classifier_path.exists():
    print(f"✗ Classifier not found at {classifier_path}")
    print("  Please run train2_train_wikiart_classifier.ipynb first")
    classifier = None
else:
    classifier, checkpoint = load_wikiart_classifier(device)
    classifier.eval()  # inference mode for the evaluation below
    print("✓ Loaded WikiArt classifier")

In [None]:
import torchvision.transforms as transforms

# Only the evaluation-time transform is needed; the training transform
# returned alongside it is unused here.
# NOTE(review): must match the image size the classifier was trained on -- confirm.
CLASSIFIER_IMAGE_SIZE = 128
_unused_train_transform, test_transform = get_wikiart_transforms(image_size=CLASSIFIER_IMAGE_SIZE)

print("Transforms loaded")

In [None]:
def compute_classification_accuracy(classifier, gen_dir, device, transform):
    """
    Compute style-classification accuracy for generated images.

    Each ``*.png`` in ``get_style_dir(gen_dir, style_idx)`` is treated as an
    image generated for style ``style_idx``. It counts as correct when the
    classifier's top-scoring class equals that index.

    Args:
        classifier: Style classifier (torch module), or None.
        gen_dir: Base directory containing one sub-directory per style.
        device: Device the image tensors are moved to before inference.
        transform: Callable turning a PIL RGB image into a tensor
            (presumably C x H x W -- a batch dimension is added here).

    Returns:
        overall_acc: Overall accuracy in percent (0 if no images were found).
        per_style_acc: Dict of style_name -> accuracy in percent (0 for styles
            with no images).
        Both are None when ``classifier`` is None.
    """
    if classifier is None:
        return None, None

    classifier.eval()

    correct_per_style = {i: 0 for i in range(len(WIKIART_STYLES))}
    total_per_style = {i: 0 for i in range(len(WIKIART_STYLES))}

    # Inference only: disable autograd once for the whole sweep.
    with torch.no_grad():
        for style_idx in range(len(WIKIART_STYLES)):
            style_dir = get_style_dir(gen_dir, style_idx)

            if not style_dir.exists():
                continue

            for img_path in style_dir.glob("*.png"):
                # `with` closes the file handle (previously every image leaked
                # one until garbage collection); convert() returns an
                # independent in-memory copy that stays valid afterwards.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert('RGB')
                image_tensor = transform(image).unsqueeze(0).to(device)

                # Predict
                output = classifier(image_tensor)
                _, pred = output.max(1)

                total_per_style[style_idx] += 1
                if pred.item() == style_idx:
                    correct_per_style[style_idx] += 1

    # Compute accuracies
    overall_correct = sum(correct_per_style.values())
    overall_total = sum(total_per_style.values())
    overall_acc = 100.0 * overall_correct / overall_total if overall_total > 0 else 0

    per_style_acc = {}
    for style_idx in range(len(WIKIART_STYLES)):
        if total_per_style[style_idx] > 0:
            acc = 100.0 * correct_per_style[style_idx] / total_per_style[style_idx]
        else:
            acc = 0
        per_style_acc[WIKIART_STYLES[style_idx]] = acc

    return overall_acc, per_style_acc

print("Classification accuracy function defined")

In [None]:
# Compute classification accuracy for each guidance scale.
accuracy_results = {}

if classifier is not None:
    print("Computing classification accuracy...\n")

    for guidance_scale in tqdm(GUIDANCE_SCALES, desc="Guidance scales"):
        guidance_dir = get_wikiart_generated_images_dir(guidance_scale)

        # tqdm.write (not print) so messages do not break the progress bar.
        if not guidance_dir.exists():
            tqdm.write(f"Guidance {guidance_scale}: Directory not found")
            continue

        overall_acc, per_style_acc = compute_classification_accuracy(
            classifier, guidance_dir, device, test_transform
        )

        accuracy_results[guidance_scale] = {
            'overall': overall_acc,
            'per_style': per_style_acc
        }

        # `>3` instead of `3d`: also formats non-integer guidance scales.
        tqdm.write(f"Guidance {guidance_scale:>3}: Accuracy = {overall_acc:.2f}%")

    print("\nClassification accuracy computation complete!")
else:
    print("Skipping classification accuracy (classifier not loaded)")

In [None]:
# Plot classification accuracy against guidance scale and highlight the best scale.
if accuracy_results:
    guidance_list = sorted(accuracy_results)
    accuracy_list = [accuracy_results[scale]['overall'] for scale in guidance_list]
    best_idx = np.argmax(accuracy_list)
    best_scale = guidance_list[best_idx]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(guidance_list, accuracy_list, 'ro-', linewidth=2, markersize=8)
    ax.set_xlabel('Guidance Scale', fontsize=12)
    ax.set_ylabel('Classification Accuracy (%)', fontsize=12)
    ax.set_title('WikiArt Classification Accuracy vs Guidance Scale', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Highlight the highest-accuracy point
    ax.scatter([best_scale], [accuracy_list[best_idx]],
               color='green', s=200, zorder=5, label=f'Best: w={best_scale}')
    ax.legend()

    fig.tight_layout()
    fig.savefig(EXPERIMENT_3_METRICS_DIR / "accuracy_vs_guidance.png", dpi=150)
    plt.show()

    print(f"\nBest accuracy: {max(accuracy_list):.2f}% at guidance scale {best_scale}")

## 5. Combined Metrics Visualization

In [None]:
# Plot FID (left y-axis) and classification accuracy (right y-axis) against
# guidance scale, using only the scales for which both metrics exist.
# Sorted so the lines run left-to-right regardless of config order.
common_guidance = sorted(
    g for g in GUIDANCE_SCALES
    if fid_results.get(g) is not None and g in accuracy_results
)

# Create a figure only when there is something to draw. Previously an empty
# figure was created (and displayed) whenever the two result sets did not overlap.
if common_guidance:
    fid_vals = [fid_results[g] for g in common_guidance]
    acc_vals = [accuracy_results[g]['overall'] for g in common_guidance]

    fig, ax1 = plt.subplots(figsize=(12, 6))

    # FID on left axis
    color1 = 'tab:blue'
    ax1.set_xlabel('Guidance Scale', fontsize=12)
    ax1.set_ylabel('FID (lower is better)', color=color1, fontsize=12)
    ax1.plot(common_guidance, fid_vals, 'o-', color=color1, linewidth=2, markersize=8, label='FID')
    ax1.tick_params(axis='y', labelcolor=color1)

    # Accuracy on right axis
    ax2 = ax1.twinx()
    color2 = 'tab:red'
    ax2.set_ylabel('Classification Accuracy (%)', color=color2, fontsize=12)
    ax2.plot(common_guidance, acc_vals, 's-', color=color2, linewidth=2, markersize=8, label='Accuracy')
    ax2.tick_params(axis='y', labelcolor=color2)

    ax2.set_title('WikiArt: FID and Classification Accuracy vs Guidance Scale', fontsize=14)

    # Combined legend for both axes
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    fig.tight_layout()
    fig.savefig(EXPERIMENT_3_METRICS_DIR / "combined_metrics.png", dpi=150)
    plt.show()

## 6. Per-Style Analysis

In [None]:
# Per-style accuracy breakdown at the guidance scale with the best overall accuracy.


def _accuracy_color(acc):
    """Bar colour for an accuracy in percent: green >= 50, orange >= 30, else red."""
    if acc >= 50:
        return 'green'
    if acc >= 30:
        return 'orange'
    return 'red'


if accuracy_results:
    best_guidance = max(accuracy_results, key=lambda scale: accuracy_results[scale]['overall'])
    per_style_acc = accuracy_results[best_guidance]['per_style']

    positions = range(len(WIKIART_STYLES))
    style_names = [s.replace('_', ' ')[:15] for s in WIKIART_STYLES]  # short tick labels
    accuracies = [per_style_acc[s] for s in WIKIART_STYLES]
    colors = [_accuracy_color(acc) for acc in accuracies]

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions, accuracies, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(style_names, rotation=45, ha='right', fontsize=8)
    ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='50%')
    ax.set_xlabel('Art Style')
    ax.set_ylabel('Classification Accuracy (%)')
    ax.set_title(f'WikiArt Per-Style Accuracy (Guidance Scale = {best_guidance})', fontsize=14)
    ax.legend()
    fig.tight_layout()
    fig.savefig(EXPERIMENT_3_METRICS_DIR / "per_style_accuracy.png", dpi=150)
    plt.show()

    # Rank styles from highest to lowest accuracy and report both ends.
    sorted_styles = sorted(per_style_acc.items(), key=lambda item: item[1], reverse=True)

    print(f"\nBest performing styles (guidance={best_guidance}):")
    for style, acc in sorted_styles[:5]:
        print(f"  {style}: {acc:.1f}%")

    print("\nWorst performing styles:")
    for style, acc in sorted_styles[-5:]:
        print(f"  {style}: {acc:.1f}%")

## 7. Save Results

In [None]:
# Collect every metric into one JSON-serialisable dict. Guidance scales become
# string keys because JSON object keys must be strings.
evaluation_results = {
    "experiment": "WikiArt Text-to-Image Generation (Experiment 3)",
    "num_styles": len(WIKIART_STYLES),
    "images_per_style": IMAGES_PER_STYLE,
    "guidance_scales": GUIDANCE_SCALES,
    "fid_scores": {str(scale): fid for scale, fid in fid_results.items()},
    "classification_accuracy": {
        str(scale): {"overall": acc["overall"], "per_style": acc["per_style"]}
        for scale, acc in accuracy_results.items()
    },
}

# Best FID = lowest successfully computed score.
valid_fid_scores = {scale: fid for scale, fid in fid_results.items() if fid is not None}
if valid_fid_scores:
    best_fid_guidance = min(valid_fid_scores, key=valid_fid_scores.get)
    evaluation_results["best_fid"] = {
        "guidance_scale": best_fid_guidance,
        "fid": valid_fid_scores[best_fid_guidance],
    }

# Best accuracy = highest overall classification accuracy.
if accuracy_results:
    best_acc_guidance = max(accuracy_results, key=lambda scale: accuracy_results[scale]['overall'])
    evaluation_results["best_accuracy"] = {
        "guidance_scale": best_acc_guidance,
        "accuracy": accuracy_results[best_acc_guidance]['overall'],
    }

# Save to JSON
results_path = EXPERIMENT_3_METRICS_DIR / "evaluation_results.json"
with results_path.open('w') as f:
    json.dump(evaluation_results, f, indent=2)

print(f"Results saved to: {results_path}")

In [None]:
# Generate a human-readable text report from the computed metrics and save it
# next to the JSON results.
report_lines = [
    "=" * 70,
    "WikiArt Text-to-Image Evaluation Report",
    "=" * 70,
    "",
    "Configuration:",
    f"  Number of art styles: {len(WIKIART_STYLES)}",
    f"  Images per style: {IMAGES_PER_STYLE}",
    f"  Guidance scales evaluated: {GUIDANCE_SCALES}",
    "",
    "FID Scores (lower is better):",
]

# `>3` rather than `3d` so non-integer guidance scales (e.g. 7.5) also format;
# integer scales render exactly as before.
for g in GUIDANCE_SCALES:
    fid = fid_results.get(g)
    fid_text = f"{fid:.2f}" if fid is not None else "N/A"
    report_lines.append(f"  Guidance {g:>3}: {fid_text}")

if 'best_fid' in evaluation_results:
    report_lines.append("")
    report_lines.append(f"  Best FID: {evaluation_results['best_fid']['fid']:.2f} "
                        f"(guidance={evaluation_results['best_fid']['guidance_scale']})")

report_lines.append("")
report_lines.append("Classification Accuracy (higher is better):")

# Scales without an accuracy result are listed as N/A, mirroring the FID
# section, instead of being silently omitted.
for g in GUIDANCE_SCALES:
    acc_result = accuracy_results.get(g)
    acc_text = f"{acc_result['overall']:.2f}%" if acc_result is not None else "N/A"
    report_lines.append(f"  Guidance {g:>3}: {acc_text}")

if 'best_accuracy' in evaluation_results:
    report_lines.append("")
    report_lines.append(f"  Best Accuracy: {evaluation_results['best_accuracy']['accuracy']:.2f}% "
                        f"(guidance={evaluation_results['best_accuracy']['guidance_scale']})")

report_lines.append("")
report_lines.append("=" * 70)

report_text = "\n".join(report_lines)
print(report_text)

# Save report
report_path = EXPERIMENT_3_METRICS_DIR / "evaluation_report.txt"
with open(report_path, 'w') as f:
    f.write(report_text)

print(f"\nReport saved to: {report_path}")

## 8. Visual Comparison

In [None]:
# Show sample images at different guidance scales: one row per sample style, the
# first column a real image, then one generated image per guidance scale.
sample_styles = [0, 12, 19]  # Abstract_Expressionism, Impressionism, Pop_Art
# NOTE(review): hard-coded; scales absent from GUIDANCE_SCALES (or with no
# generated images) render as blank panels.
sample_guidance = [0, 5, 10, 20]


def _show_first_png(ax, image_dir):
    """Draw the alphabetically first *.png in image_dir on ax; leave ax blank if none."""
    if not image_dir.exists():
        return
    # sorted(): glob order is filesystem-dependent, so without it a different
    # sample could be shown on every machine/run.
    png_paths = sorted(image_dir.glob("*.png"))
    if png_paths:
        # imshow is called inside `with`, so the pixels are read before the
        # file handle is closed.
        with Image.open(png_paths[0]) as img:
            ax.imshow(img)


# squeeze=False keeps `axes` 2-D even if only one style or scale is sampled.
fig, axes = plt.subplots(len(sample_styles), len(sample_guidance) + 1,
                         figsize=(3 * (len(sample_guidance) + 1), 3 * len(sample_styles)),
                         squeeze=False)

for row, style_idx in enumerate(sample_styles):
    # Real image
    real_ax = axes[row, 0]
    _show_first_png(real_ax, get_style_dir(EXPERIMENT_3_DATASET_DIR, style_idx))
    real_ax.set_title('Real' if row == 0 else '')
    real_ax.set_ylabel(WIKIART_STYLES[style_idx].replace('_', '\n'), fontsize=8)
    real_ax.set_xticks([])
    real_ax.set_yticks([])

    # Generated images at different guidance scales
    for col, guidance in enumerate(sample_guidance, start=1):
        gen_ax = axes[row, col]
        guidance_dir = get_wikiart_generated_images_dir(guidance)
        _show_first_png(gen_ax, get_style_dir(guidance_dir, style_idx))

        if row == 0:
            gen_ax.set_title(f'w={guidance}')
        gen_ax.set_xticks([])
        gen_ax.set_yticks([])

plt.suptitle('WikiArt: Real vs Generated at Different Guidance Scales', fontsize=14)
plt.tight_layout()
plt.savefig(EXPERIMENT_3_METRICS_DIR / "visual_comparison.png", dpi=150)
plt.show()

## Summary

This notebook evaluated WikiArt text-to-image generation quality:

**Metrics computed:**
1. FID (Fréchet Inception Distance) for each guidance scale
2. Classification accuracy (prompt adherence) for each guidance scale
3. Per-style accuracy breakdown

**Outputs saved:**
- `evaluation_results.json` - All numerical results
- `evaluation_report.txt` - Human-readable summary
- `fid_vs_guidance.png` - FID plot
- `accuracy_vs_guidance.png` - Accuracy plot
- `combined_metrics.png` - Combined FID and accuracy
- `per_style_accuracy.png` - Per-style breakdown
- `visual_comparison.png` - Sample image comparison

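**Key findings:** Markdown cells do not substitute Python variables, so read the actual values from one of these:
- the printed report above
- `evaluation_report.txt`
- the `best_fid` / `best_accuracy` entries of `evaluation_results.json`

Items:
- Best FID at guidance scale: see `best_fid.guidance_scale`
- Best accuracy at guidance scale: see `best_accuracy.guidance_scale`
- Trade-off between image quality (FID) and prompt adherence (accuracy)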