# Metrics 3: FID Evaluation using pytorch-fid

This notebook computes Fréchet Inception Distance (FID) scores using the standard pytorch-fid library.

**Prerequisites:**
- Run `inference1_t2i_mnist_cfg.ipynb` to pre-generate images for each guidance scale
- Images should be saved in `outputs/experiment_1/generated/guidance_X/digit_Y/`
- Real MNIST images should be saved in `outputs/experiment_1/dataset/digit_Y/`

**Note on FID for MNIST:**
FID relies on an Inception-V3 network trained on ImageNet, which expects 299×299 RGB inputs. Applying it to MNIST (28×28 grayscale)
may not give meaningful results because:
1. MNIST images must be upscaled roughly 10× and converted to RGB
2. Inception features are tuned for natural images, not handwritten digits

Even so, the score provides a standardized baseline for comparing guidance scales.
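
For reference, FID is commonly defined as the Fréchet distance between two Gaussians fitted to the Inception features of the real and generated images. The symbols below are notation introduced here, not names used elsewhere in this notebook: $\mu_r, \Sigma_r$ are the mean and covariance of the real-image features, and $\mu_g, \Sigma_g$ those of the generated-image features.

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

The score is 0 when the two sets of statistics coincide. In practice, finite sample sizes keep even a real-vs-real comparison above zero.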

In [18]:
# Setup paths and imports
from pathlib import Path
import sys

PROJECT_ROOT = Path("/home/doshlom4/work/final_project")
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

Project root: /home/doshlom4/work/final_project


In [19]:
# Load configuration
from config import (
    EXPERIMENT_1_CONFIG,
    OUTPUTS_DIR,
    EXPERIMENT_1_DIR,
    EXPERIMENT_1_DATASET_DIR,
    EXPERIMENT_1_METRICS_DIR,
    get_generated_images_dir,
    get_digit_dir,
)

# Unpack experiment configuration
GUIDANCE_SCALES = EXPERIMENT_1_CONFIG["guidance_scales"]
IMAGES_PER_DIGIT = EXPERIMENT_1_CONFIG["images_per_digit"]
DIGITS = EXPERIMENT_1_CONFIG["digits"]
EXPERIMENT_OUTPUT_DIR = EXPERIMENT_1_DIR
DATASET_DIR = EXPERIMENT_1_DATASET_DIR

print("Configuration loaded:")
print(f"  Guidance scales: {GUIDANCE_SCALES}")
print(f"  Images per digit: {IMAGES_PER_DIGIT}")
print(f"  Dataset directory: {DATASET_DIR}")
print(f"  Experiment output: {EXPERIMENT_OUTPUT_DIR}")

Configuration loaded:
  Guidance scales: [0, 5, 10, 15, 20, 30, 40, 50, 100]
  Images per digit: 100
  Dataset directory: /home/doshlom4/work/final_project/outputs/experiment_1/dataset
  Experiment output: /home/doshlom4/work/final_project/outputs/experiment_1


In [20]:
# Import standard libraries
import os
import subprocess
import json
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt

from utils.image_utils import count_images_in_folder

print("Libraries imported successfully")

Libraries imported successfully


## 1. Verify Pre-generated Images

Check that all required images have been generated before computing FID.

In [21]:
def check_image_availability() -> Dict:
    """
    Check which guidance scales have complete image sets.
    
    Returns:
        Dict with status for each guidance scale
    """
    status = {}
    
    # Check real images
    real_count = 0
    for digit in DIGITS:
        digit_dir = DATASET_DIR / f"digit_{digit}"
        real_count += count_images_in_folder(digit_dir)
    
    expected_real = IMAGES_PER_DIGIT * len(DIGITS)
    print(f"Real MNIST images: {real_count}/{expected_real}")
    
    if real_count < expected_real:
        print("  ⚠️  Missing real images! Run inference notebook first.")
    else:
        print("  ✓ Complete")
    
    print()
    print("Generated images by guidance scale:")
    
    for guidance_scale in GUIDANCE_SCALES:
        guidance_dir = get_generated_images_dir(guidance_scale)
        gen_count = 0
        
        for digit in DIGITS:
            digit_dir = get_digit_dir(guidance_dir, digit)
            gen_count += count_images_in_folder(digit_dir)
        
        expected = IMAGES_PER_DIGIT * len(DIGITS)
        is_complete = gen_count >= expected
        
        status[guidance_scale] = {
            "count": gen_count,
            "expected": expected,
            "complete": is_complete,
        }
        
        symbol = "✓" if is_complete else "⚠️"
        print(f"  w={guidance_scale:3d}: {gen_count:4d}/{expected} {symbol}")
    
    return status

availability = check_image_availability()

Real MNIST images: 1000/1000
  ✓ Complete

Generated images by guidance scale:
  w=  0:    0/1000 ⚠️
  w=  5:    0/1000 ⚠️
  w= 10:    0/1000 ⚠️
  w= 15:    0/1000 ⚠️
  w= 20:    0/1000 ⚠️
  w= 30:    0/1000 ⚠️
  w= 40:    0/1000 ⚠️
  w= 50:    0/1000 ⚠️
  w=100:    0/1000 ⚠️
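
`count_images_in_folder` is imported from `utils.image_utils`, and its implementation is not shown in this notebook. The sketch below shows what such a helper might look like, assuming it counts image files directly inside one folder and returns 0 when the folder does not exist. The name `count_images_sketch` and the suffix set are hypothetical.

```python
from pathlib import Path

# Assumed set of image extensions; the real helper may accept others.
ASSUMED_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images_sketch(folder: Path) -> int:
    """Count image files directly inside `folder` (no recursion)."""
    folder = Path(folder)
    if not folder.is_dir():
        # A missing folder is treated as containing zero images.
        return 0
    return sum(
        1
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in ASSUMED_IMAGE_SUFFIXES
    )
```

If the real helper behaves this way, a count of 0 for every guidance scale would mean the generated folders are missing or empty.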


## 2. Compute FID using pytorch-fid

The pytorch-fid library computes FID between two folders of images.
We'll compute FID for each guidance scale vs the real MNIST test images.
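
One caveat worth checking before running the comparison: at the time of writing, pytorch-fid appears to collect only the image files directly inside each folder and does not descend into subfolders such as `digit_0/`. If that holds for the installed version, the per-digit layout used here needs to be flattened first. The sketch below copies every image under a root into a single flat folder. `flatten_image_tree` and the suffix set are hypothetical, and the destination should live outside the source tree.

```python
import shutil
from pathlib import Path

# Assumed set of image extensions to copy.
FLATTEN_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


def flatten_image_tree(src_root: Path, dst_dir: Path) -> int:
    """Copy all images under `src_root` (recursively) into `dst_dir`.

    Subfolder names are folded into the file name, e.g.
    digit_3/0007.png -> digit_3_0007.png, so names stay unique.
    Returns the number of files copied.
    """
    src_root, dst_dir = Path(src_root), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    for src in sorted(src_root.rglob("*")):
        if src.is_file() and src.suffix.lower() in FLATTEN_SUFFIXES:
            flat_name = "_".join(src.relative_to(src_root).parts)
            shutil.copy2(src, dst_dir / flat_name)
            copied += 1
    return copied
```

The flattened folders could then be passed to the FID computation in place of the nested ones.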

In [22]:
def compute_fid_pytorch(real_dir: Path, generated_dir: Path, device: str = "cuda") -> float:
    """
    Compute FID between two directories using pytorch-fid.
    
    Args:
        real_dir: Directory containing real images
        generated_dir: Directory containing generated images
        device: Device to use (cuda or cpu)
    
    Returns:
        FID score
    """
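    # Use the notebook's own interpreter so the subprocess sees the same
    # environment; a bare "python" may resolve to a different installation.
    cmd = [
        sys.executable, "-m", "pytorch_fid",
        str(real_dir),
        str(generated_dir),
        "--device", device,
    ]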
    
    result = subprocess.run(cmd, capture_output=True, text=True)
    
    if result.returncode != 0:
        print(f"Error: {result.stderr}")
        return float('nan')
    
    # Parse FID from output
    # Output format: "FID:  123.456"
    output = result.stdout.strip()
    try:
        fid_value = float(output.split()[-1])
        return fid_value
    except (ValueError, IndexError):
        print(f"Could not parse FID from output: {output}")
        return float('nan')

print("FID computation function defined")

FID computation function defined
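
The parser above takes the last whitespace-separated token of stdout. A slightly more defensive sketch, assuming the `FID:  <number>` line format noted in the code comment, searches for that pattern explicitly so that extra log lines do not break parsing. `parse_fid_output` is a hypothetical helper, not part of pytorch-fid.

```python
import math
import re

# Matches e.g. "FID:  123.456" or "FID: 1.5e+02".
_FID_LINE = re.compile(r"FID:\s*([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_fid_output(stdout: str) -> float:
    """Return the last FID value found in `stdout`, or NaN if none is found."""
    matches = _FID_LINE.findall(stdout)
    if not matches:
        return math.nan
    return float(matches[-1])
```

Returning NaN on failure keeps the behaviour compatible with the NaN filtering used in the plotting and summary cells.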


In [23]:
# Test that pytorch-fid is available
try:
    # sys.executable is the interpreter running this kernel.
    result = subprocess.run([sys.executable, "-m", "pytorch_fid", "--help"],
                            capture_output=True, text=True)
    if result.returncode == 0:
        print("✓ pytorch-fid is available")
    else:
        print("⚠️  pytorch-fid not found. Install with: pip install pytorch-fid")
except Exception as e:
    print(f"Error checking pytorch-fid: {e}")

⚠️  pytorch-fid not found. Install with: pip install pytorch-fid
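
Whether the subprocess check succeeds depends on which interpreter it launches and which environment that interpreter sees. As a complement, the sketch below asks the current kernel directly whether the module is importable, using only the standard library. `module_available` is a hypothetical helper name.

```python
import importlib.util
import sys


def module_available(module_name: str) -> bool:
    """Return True if `module_name` can be found by the current interpreter."""
    return importlib.util.find_spec(module_name) is not None


print(f"Interpreter: {sys.executable}")
print(f"pytorch_fid importable: {module_available('pytorch_fid')}")
```

If the module is importable here but the subprocess check fails, the two are probably running in different environments.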


In [24]:
def compute_all_fid_scores(guidance_scales: List[int] = None) -> Dict[int, float]:
    """
    Compute FID for all guidance scales.
    
    Args:
        guidance_scales: List of guidance scales to compute. If None, use all.
    
    Returns:
        Dict mapping guidance scale to FID score
    """
    if guidance_scales is None:
        guidance_scales = GUIDANCE_SCALES
    
    fid_scores = {}
    
    print(f"Computing FID for {len(guidance_scales)} guidance scales...")
    print(f"Real images directory: {DATASET_DIR}")
    print()
    
    for i, guidance_scale in enumerate(guidance_scales):
        generated_dir = get_generated_images_dir(guidance_scale)
        
        # Check if complete
        if guidance_scale in availability and not availability[guidance_scale]["complete"]:
            print(f"[{i+1}/{len(guidance_scales)}] w={guidance_scale}: Skipping (incomplete)")
            fid_scores[guidance_scale] = float('nan')
            continue
        
        print(f"[{i+1}/{len(guidance_scales)}] w={guidance_scale}: Computing FID...")
        
        fid = compute_fid_pytorch(DATASET_DIR, generated_dir)
        fid_scores[guidance_scale] = fid
        
        print(f"  FID = {fid:.2f}")
    
    return fid_scores

print("FID computation function defined")

FID computation function defined


In [25]:
# Compute FID for all available guidance scales
# This may take a few minutes

fid_results = compute_all_fid_scores()

Computing FID for 9 guidance scales...
Real images directory: /home/doshlom4/work/final_project/outputs/experiment_1/dataset

[1/9] w=0: Skipping (incomplete)
[2/9] w=5: Skipping (incomplete)
[3/9] w=10: Skipping (incomplete)
[4/9] w=15: Skipping (incomplete)
[5/9] w=20: Skipping (incomplete)
[6/9] w=30: Skipping (incomplete)
[7/9] w=40: Skipping (incomplete)
[8/9] w=50: Skipping (incomplete)
[9/9] w=100: Skipping (incomplete)


## 3. Visualize Results

In [26]:
def plot_fid_vs_guidance(fid_scores: Dict[int, float], title: str = "FID vs Guidance Scale (pytorch-fid)"):
    """
    Plot FID scores against guidance scale.
    
    Args:
        fid_scores: Dict mapping guidance scale to FID
        title: Plot title
    """
    # Filter out NaN values
    valid_scales = [k for k, v in fid_scores.items() if not np.isnan(v)]
    valid_fids = [fid_scores[k] for k in valid_scales]
    
    if not valid_scales:
        print("No valid FID scores to plot")
        return
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    ax.plot(valid_scales, valid_fids, 'b-o', linewidth=2, markersize=8)
    
    # Mark the minimum
    min_idx = np.argmin(valid_fids)
    min_scale = valid_scales[min_idx]
    min_fid = valid_fids[min_idx]
    ax.scatter([min_scale], [min_fid], color='red', s=200, zorder=5, 
               label=f'Best: w={min_scale}, FID={min_fid:.2f}')
    
    ax.set_xlabel('Guidance Scale (w)', fontsize=12)
    ax.set_ylabel('FID Score (lower is better)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save figure
    output_path = EXPERIMENT_OUTPUT_DIR / "fid_vs_guidance_pytorch_fid.png"
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"Saved plot to: {output_path}")
    
    plt.show()

plot_fid_vs_guidance(fid_results)

No valid FID scores to plot


In [27]:
def print_fid_summary(fid_scores: Dict[int, float]):
    """
    Print a summary table of FID scores.
    """
    print("\n" + "="*50)
    print("FID Score Summary (pytorch-fid with Inception-V3)")
    print("="*50)
    print(f"{'Guidance Scale':<15} {'FID Score':<15} {'Rank':<10}")
    print("-"*50)
    
    # Sort by FID
    valid_items = [(k, v) for k, v in fid_scores.items() if not np.isnan(v)]
    sorted_items = sorted(valid_items, key=lambda x: x[1])
    
    for rank, (scale, fid) in enumerate(sorted_items, 1):
        marker = " ← Best" if rank == 1 else ""
        print(f"w={scale:<12} {fid:<15.2f} {rank:<10}{marker}")
    
    print("="*50)
    
    if sorted_items:
        best_scale, best_fid = sorted_items[0]
        print(f"\nBest guidance scale: w={best_scale} with FID={best_fid:.2f}")

print_fid_summary(fid_results)


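==================================================
FID Score Summary (pytorch-fid with Inception-V3)
==================================================
Guidance Scale  FID Score       Rank      
--------------------------------------------------
==================================================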


## 4. Save Results

In [28]:
def save_results(fid_scores: Dict[int, float]):
    """
    Save FID results to JSON and text files.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    # Prepare results dict
    results = {
        "timestamp": timestamp,
        "method": "pytorch-fid (Inception-V3)",
        "images_per_digit": IMAGES_PER_DIGIT,
        "total_images_per_guidance": IMAGES_PER_DIGIT * len(DIGITS),
        "fid_scores": {str(k): v for k, v in fid_scores.items() if not np.isnan(v)},
    }
    
    # Find best
    valid_items = [(k, v) for k, v in fid_scores.items() if not np.isnan(v)]
    if valid_items:
        best_scale, best_fid = min(valid_items, key=lambda x: x[1])
        results["best_guidance_scale"] = best_scale
        results["best_fid"] = best_fid
    
    # Save JSON
    json_path = EXPERIMENT_OUTPUT_DIR / "fid_results_pytorch_fid.json"
    with open(json_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Saved JSON results to: {json_path}")
    
    # Save text report
    txt_path = EXPERIMENT_OUTPUT_DIR / "fid_results_pytorch_fid.txt"
    with open(txt_path, 'w') as f:
        f.write("FID Evaluation Report (pytorch-fid)\n")
        f.write("="*50 + "\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write("Method: pytorch-fid (Inception-V3)\n")
        f.write(f"Images per digit: {IMAGES_PER_DIGIT}\n")
        f.write(f"Total images per guidance scale: {IMAGES_PER_DIGIT * len(DIGITS)}\n")
        f.write("\n")
        f.write("Results:\n")
        f.write("-"*50 + "\n")
        
        for scale, fid in sorted(fid_scores.items()):
            if not np.isnan(fid):
                f.write(f"  w={scale:<3}: FID = {fid:.2f}\n")
        
        if valid_items:
            f.write("\n")
            f.write(f"Best: w={best_scale} with FID={best_fid:.2f}\n")
    
    print(f"Saved text report to: {txt_path}")
    
    return results

saved_results = save_results(fid_results)

Saved JSON results to: /home/doshlom4/work/final_project/outputs/experiment_1/fid_results_pytorch_fid.json
Saved text report to: /home/doshlom4/work/final_project/outputs/experiment_1/fid_results_pytorch_fid.txt
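
Because JSON object keys are always strings, the guidance scales come back as `"5"`, `"10"`, and so on when the file is reloaded. A small sketch for reading the scores back into an `int`-keyed dict, assuming the `fid_scores` field written by `save_results` above (`load_fid_scores` is a hypothetical helper):

```python
import json
from pathlib import Path
from typing import Dict


def load_fid_scores(json_path: Path) -> Dict[int, float]:
    """Load the saved FID scores and restore integer guidance-scale keys."""
    with open(json_path, "r") as f:
        payload = json.load(f)
    # A missing "fid_scores" field is treated as an empty result set.
    return {int(k): float(v) for k, v in payload.get("fid_scores", {}).items()}
```

An empty dict here means no guidance scale produced a valid score in the saved run.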


## 5. Compare with TorchMetrics FID (Optional)

If you have results from `metrics2_evaluate_t2i_mnist.ipynb`, you can compare them here.

In [29]:
# Load TorchMetrics results if available
torchmetrics_results_path = OUTPUTS_DIR / "mnist_evaluation_report.txt"

if torchmetrics_results_path.exists():
    print("TorchMetrics results found:")
    with open(torchmetrics_results_path, 'r') as f:
        print(f.read())
else:
    print(f"No TorchMetrics results at: {torchmetrics_results_path}")
    print("Run metrics2_evaluate_t2i_mnist.ipynb to generate them.")

TorchMetrics results found:
MNIST Text-to-Image Diffusion Model Evaluation Report
Number of samples: 100
Real-vs-Real Baseline FID: 111.93

----------------------------------------------------------------------
MNIST-FID Results (lower is better)
----------------------------------------------------------------------
Guidance Scale w=  0: FID = 2340.00
Guidance Scale w=  1: FID = 2287.32
Guidance Scale w=  5: FID = 2189.91
Guidance Scale w= 10: FID = 1842.67
Guidance Scale w= 20: FID = 1261.67

----------------------------------------------------------------------
Conditional Accuracy Results (higher is better)
----------------------------------------------------------------------
Guidance Scale w=  0: Accuracy =   8.00%
  Per-digit accuracy:
    Digit 0:   0.00%
    Digit 1:   0.00%
    Digit 2:  10.00%
    Digit 3:  20.00%
    Digit 4:   0.00%
    Digit 5:   0.00%
    Digit 6:   0.00%
    Digit 7:   0.00%
    Digit 8:  10.00%
    Digit 9:  40.00%

Guidance Scale w=  1: Accuracy =  70.
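
To compare the two sets of scores numerically rather than by eye, the FID lines of the text report can be parsed into a dict. The sketch below assumes the `Guidance Scale w=  N: FID = X` line format shown in the output above. `parse_report_fids` is a hypothetical helper.

```python
import re
from typing import Dict

# Matches lines such as "Guidance Scale w= 10: FID = 1842.67".
# Accuracy lines are skipped because they do not contain "FID =".
_REPORT_FID = re.compile(r"Guidance Scale w=\s*(\d+):\s*FID\s*=\s*(\d+(?:\.\d+)?)")


def parse_report_fids(report_text: str) -> Dict[int, float]:
    """Extract {guidance_scale: fid} pairs from the text report."""
    return {int(w): float(fid) for w, fid in _REPORT_FID.findall(report_text)}
```

The resulting dict uses the same `int` keys as `fid_results`, so scales present in both can be compared directly.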

## Summary

This notebook computed FID scores using the standard pytorch-fid library with Inception-V3 features.

**Important Notes:**
1. Inception-V3 FID is the de facto standard metric for evaluating image generation quality
2. However, the underlying network was trained on ImageNet (299×299 RGB natural images)
3. For MNIST (28×28 grayscale), its features may not capture digit quality well
4. Consider using MNIST-specific metrics (e.g., classifier-based) for more meaningful comparisons
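
As an illustration of the classifier-based alternative in point 4, the sketch below computes overall and per-digit conditional accuracy from two label sequences. It assumes the predicted labels come from some pre-trained MNIST classifier applied to the generated images, which is not shown here. `conditional_accuracy` is a hypothetical helper name.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def conditional_accuracy(
    targets: Sequence[int], predictions: Sequence[int]
) -> Tuple[float, Dict[int, float]]:
    """Return (overall accuracy, per-digit accuracy).

    `targets` are the digits the generator was asked to draw;
    `predictions` are the classifier's labels for the generated images.
    """
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if not targets:
        raise ValueError("at least one sample is required")

    hits: Dict[int, int] = defaultdict(int)
    totals: Dict[int, int] = defaultdict(int)
    for target, predicted in zip(targets, predictions):
        totals[target] += 1
        hits[target] += int(target == predicted)

    per_digit = {digit: hits[digit] / totals[digit] for digit in sorted(totals)}
    overall = sum(hits.values()) / len(targets)
    return overall, per_digit
```

Unlike FID, this measures whether each sample matches its conditioning digit, not how close the overall image distribution is to the real data.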

**Expected Behavior:**
- FID should generally decrease as the guidance scale increases, up to a point
- Too high a guidance scale may reduce sample diversity (mode collapse) and increase FID
- The optimal guidance scale for CFG is typically between 5 and 15, though this depends on the model and on how the scale is defined