# HSIVI Model Evaluation on Colored MNIST

This notebook evaluates a trained HSIVI model using FID (Fréchet Inception Distance) as the main metric.


In [None]:
import os
import sys
import math
from pathlib import Path

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Add project root to path
sys.path.insert(0, '.')

from hsivi_train.config import HSIVIConfig
from hsivi_train.hsivi_trainer import HSIVITrainer
from utils.dataset_h5 import H5ImagesDataset

# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")


## Configuration


In [None]:
# --- Input locations (edit these to match your setup) ---
# Trained HSIVI checkpoint to evaluate.
CHECKPOINT_PATH = "./work_dir/hsivi_colored_mnist/checkpoints/latest.pt"
# Saved training config (optional).
CONFIG_PATH = "./work_dir/hsivi_colored_mnist/config.json"
# Location of the colored MNIST data.
DATA_DIR = "./data"

# --- Evaluation settings ---
NUM_FID_SAMPLES = 10000  # how many samples feed the FID calculation
BATCH_SIZE = 64          # batch size for generation and for the real-data loader
SEED = 42                # random seed for reproducibility

# Seed NumPy and PyTorch (plus every visible GPU) so reruns are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


## Load HSIVI Model


In [None]:
# Resolve the run configuration: the saved file when present, defaults otherwise.
if not os.path.exists(CONFIG_PATH):
    print("Config not found, using default configuration")
    config = HSIVIConfig()
else:
    config = HSIVIConfig.load(CONFIG_PATH)
    print(f"Loaded config from {CONFIG_PATH}")

# Echo the settings that are referenced throughout the rest of the notebook.
print(
    "\nConfiguration:",
    f"  Image size: {config.image_size}x{config.image_size}",
    f"  Channels: {config.channels}",
    f"  Discrete steps (NFE+1): {config.n_discrete_steps}",
    f"  Phi base dim: {config.phi_base_dim}",
    f"  F base dim: {config.f_base_dim}",
    sep="\n",
)


In [None]:
# Build the trainer. No pretrained epsilon network is supplied because this
# notebook only samples from the model.
trainer = HSIVITrainer(config=config, pretrained_epsilon=None, device=device)

# Restore the trained weights, failing loudly when the checkpoint is absent.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

trainer.load_checkpoint(CHECKPOINT_PATH)
print(f"\nLoaded checkpoint from {CHECKPOINT_PATH}")
print(f"Checkpoint trained for {trainer.step} steps")


## Load Real Data for FID Reference


In [None]:
# Real images used as the reference distribution for FID.
dataset = H5ImagesDataset(DATA_DIR)
print(f"Dataset size: {len(dataset)} images")

# Shuffled loader over the real images, using the notebook-wide batch size.
dataloader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)


## Generate Samples


In [None]:
# Generate samples from HSIVI model
@torch.no_grad()
def generate_samples(trainer, num_samples, batch_size=64):
    """Generate samples from a trained HSIVI model.

    Sampling is done in batches of at most ``batch_size``; the final batch is
    shortened so that exactly ``num_samples`` samples are requested in total.

    Args:
        trainer: Object exposing ``phi_net`` (switched to eval mode while
            sampling) and ``sample(batch_size=...)`` returning a tensor.
        num_samples: Total number of samples to return; must be positive.
        batch_size: Maximum number of samples per ``trainer.sample`` call;
            must be positive.

    Returns:
        CPU tensor with the generated samples concatenated along dim 0.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive.
    """
    if num_samples < 1 or batch_size < 1:
        raise ValueError(
            f"num_samples and batch_size must be positive, "
            f"got num_samples={num_samples}, batch_size={batch_size}"
        )

    phi_net = trainer.phi_net
    # Remember the current mode so the caller's train/eval state is restored
    # instead of being forced to train mode.
    was_training = phi_net.training
    phi_net.eval()

    batches = []
    try:
        for start in tqdm(range(0, num_samples, batch_size), desc="Generating samples"):
            # Count samples (not batches) already requested, so the last
            # batch only asks for what is still missing.
            curr_batch_size = min(batch_size, num_samples - start)
            batches.append(trainer.sample(batch_size=curr_batch_size).cpu())
    finally:
        phi_net.train(was_training)

    # The slice is a safety net in case sample() returns more than requested.
    return torch.cat(batches, dim=0)[:num_samples]

# Draw the evaluation set once; later cells plot, score and export it.
print(f"Generating {NUM_FID_SAMPLES} samples...")
generated_samples = generate_samples(
    trainer, num_samples=NUM_FID_SAMPLES, batch_size=BATCH_SIZE
)

# Quick sanity report: count, tensor shape and value range.
print(
    f"Generated {len(generated_samples)} samples",
    f"Sample shape: {generated_samples.shape}",
    f"Sample range: [{generated_samples.min():.3f}, {generated_samples.max():.3f}]",
    sep="\n",
)


## Visualize Generated Samples


In [None]:
# Visualize some generated samples
def show_samples(samples, title="Generated Samples", nrow=8):
    """Plot up to the first 64 entries of ``samples`` as one tiled figure.

    Args:
        samples: Sliceable batch of images accepted by ``make_grid``.
        title: Text shown above the figure.
        nrow: Forwarded to ``make_grid`` as its ``nrow`` argument.
    """
    shown = samples[:min(64, len(samples))]
    tiled = make_grid(shown, nrow=nrow, padding=2, normalize=False)
    # Move axis 0 to the end before handing the array to imshow
    # (presumably channels-first -> channels-last).
    image = tiled.permute(1, 2, 0).numpy()

    plt.figure(figsize=(12, 12))
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Preview the generated images; the title reports n_discrete_steps - 1 as the NFE.
show_samples(
    generated_samples,
    title=f"HSIVI Generated Samples (NFE={config.n_discrete_steps - 1})",
)


In [None]:
# Visualize some real samples for comparison
real_batch = next(iter(dataloader))
# The loader may yield (images, ...) sequences; keep only the images, matching
# how compute_fid unwraps its batches.
if isinstance(real_batch, (list, tuple)):
    real_batch = real_batch[0]
show_samples(real_batch, "Real Colored MNIST Samples")


## Compute FID Score


In [None]:
def _to_uint8_images(images, unit_range=None):
    """Convert an image batch to uint8 values in [0, 255].

    Args:
        images: Image tensor.
        unit_range: Whether ``images`` is on a [0, 1] scale and must be
            multiplied by 255. When ``None`` this is inferred from the batch:
            a maximum of at most 1.0 is taken to mean a [0, 1] scale.

    Returns:
        Tensor of dtype ``torch.uint8`` with values clamped to [0, 255].
    """
    if unit_range is None:
        unit_range = bool(images.max() <= 1.0)
    if unit_range:
        images = images * 255
    return images.clamp(0, 255).to(torch.uint8)


def compute_fid(
    real_dataloader,
    fake_samples,
    num_real_samples=10000,
    device='cuda',
    batch_size=64,
):
    """
    Compute FID between real and generated samples.

    Args:
        real_dataloader: DataLoader for real images. Batches may be tensors
            or list/tuple batches whose first element is the image tensor.
        fake_samples: Generated samples tensor [N, C, H, W] in range [0, 1]
        num_real_samples: Number of real samples to use (must be positive).
            The last real batch is trimmed so this count is not exceeded.
        device: Device to use
        batch_size: Number of generated images passed to the metric per update.

    Returns:
        FID score

    Raises:
        ValueError: If ``num_real_samples`` or ``batch_size`` is not positive.
    """
    if num_real_samples < 1 or batch_size < 1:
        raise ValueError(
            f"num_real_samples and batch_size must be positive, "
            f"got num_real_samples={num_real_samples}, batch_size={batch_size}"
        )

    # Images are fed as uint8 in [0, 255], which is the normalize=False input
    # contract. With normalize=True the metric rescales its inputs by 255
    # itself, which would wrap around on uint8 data and corrupt the score.
    fid = FrechetInceptionDistance(feature=2048, normalize=False).to(device)

    # Process real images
    print("Processing real images...")
    num_processed = 0
    for batch in tqdm(real_dataloader, desc="Real images"):
        if isinstance(batch, (list, tuple)):
            batch = batch[0]

        # Trim the final batch so exactly num_real_samples images are used.
        batch = batch[:num_real_samples - num_processed].to(device)
        fid.update(_to_uint8_images(batch), real=True)
        num_processed += batch.shape[0]

        if num_processed >= num_real_samples:
            break

    print(f"Processed {num_processed} real images")

    # Process fake images
    print("Processing generated images...")
    # Decide the value scale once for the whole set so every chunk is
    # converted the same way, then move data to the device chunk by chunk
    # rather than all at once.
    fake_unit_range = bool(fake_samples.max() <= 1.0)
    for start in tqdm(range(0, len(fake_samples), batch_size), desc="Generated images"):
        chunk = fake_samples[start:start + batch_size].to(device)
        fid.update(_to_uint8_images(chunk, fake_unit_range), real=False)

    print(f"Processed {len(fake_samples)} generated images")

    # Compute FID
    return fid.compute().item()


In [None]:
# Compute FID between the real data and the HSIVI samples generated above.
print(f"\nComputing FID with {NUM_FID_SAMPLES} samples...")
print("="*50)

fid_score = compute_fid(
    real_dataloader=dataloader,
    fake_samples=generated_samples,
    num_real_samples=NUM_FID_SAMPLES,
    device=device
)

print("="*50)
# The emoji below was previously mis-encoded (mojibake) in the printed text.
print(f"\n🎯 FID Score: {fid_score:.2f}")
print(f"   Number of function evaluations (NFE): {config.n_discrete_steps - 1}")


## Compare with Baseline DDPM (Optional)

If you have a trained DDPM baseline, compare its sampling quality against HSIVI at the same NFE.


In [None]:
# Optional baseline: a DDPM checkpoint sampled with the same step budget as HSIVI.
DDPM_CHECKPOINT = "./ckpts/model-5.pt"  # Update path

# Later cells are gated on this flag; it is simply "does the file exist".
compare_with_baseline = os.path.exists(DDPM_CHECKPOINT)

if not compare_with_baseline:
    print(f"DDPM checkpoint not found at {DDPM_CHECKPOINT}")
    print("Skipping baseline comparison")
else:
    # Imported only on this branch, so they are not needed when the baseline is skipped.
    from diffusion.ddpm import Unet, GaussianDiffusion
    from ema_pytorch import EMA

    # Denoising network wrapped in a diffusion process.
    ddpm_model = Unet(
        dim=64, dim_mults=(1, 2, 4), channels=config.channels, flash_attn=False
    )
    ddpm_diffusion = GaussianDiffusion(
        ddpm_model,
        image_size=config.image_size,
        timesteps=1000,
        # Same NFE as HSIVI
        sampling_timesteps=config.n_discrete_steps - 1,
    )

    # Wrap in EMA and restore the averaged weights stored under the 'ema' key.
    ema = EMA(ddpm_diffusion, beta=0.995, update_every=10)
    ckpt = torch.load(DDPM_CHECKPOINT, map_location=device, weights_only=True)
    ema.load_state_dict(ckpt['ema'])

    # The EMA copy is the model used for sampling.
    ema_model = ema.ema_model.to(device)
    ema_model.eval()

    print(f"Loaded DDPM baseline from {DDPM_CHECKPOINT}")


In [None]:
if compare_with_baseline:
    # Generate DDPM samples with same NFE
    print(f"\nGenerating {NUM_FID_SAMPLES} DDPM samples with NFE={config.n_discrete_steps - 1}...")

    @torch.no_grad()
    def generate_ddpm_samples(model, num_samples, batch_size=64):
        """Generate ``num_samples`` samples from ``model`` in batches.

        Args:
            model: Object exposing ``eval()`` and ``sample(batch_size=...)``.
            num_samples: Total number of samples to return.
            batch_size: Maximum number of samples per ``model.sample`` call.

        Returns:
            CPU tensor with the samples concatenated along dim 0.
        """
        model.eval()
        batches = []

        for start in tqdm(range(0, num_samples, batch_size), desc="Generating DDPM samples"):
            # Track samples (not batches) already requested so the final
            # batch only asks for the remainder.
            curr_batch_size = min(batch_size, num_samples - start)
            batches.append(model.sample(batch_size=curr_batch_size).cpu())

        # The slice is a safety net in case sample() returns more than requested.
        return torch.cat(batches, dim=0)[:num_samples]

    ddpm_samples = generate_ddpm_samples(ema_model, NUM_FID_SAMPLES, BATCH_SIZE)
    print(f"Generated {len(ddpm_samples)} DDPM samples")

    # Visualize DDPM samples
    show_samples(ddpm_samples, f"DDPM (DDIM) Samples (NFE={config.n_discrete_steps - 1})")


In [None]:
if compare_with_baseline:
    # Score the DDPM samples against the same real data loader used for HSIVI.
    print(f"\nComputing DDPM FID...")

    ddpm_fid_score = compute_fid(
        real_dataloader=dataloader,
        fake_samples=ddpm_samples,
        num_real_samples=NUM_FID_SAMPLES,
        device=device
    )

    print(f"\n" + "="*50)
    print(f"COMPARISON RESULTS (NFE = {config.n_discrete_steps - 1})")
    print("="*50)
    print(f"  HSIVI FID: {fid_score:.2f}")
    print(f"  DDPM (DDIM) FID: {ddpm_fid_score:.2f}")
    print("="*50)

    # Lower FID is reported as the better model.
    if fid_score < ddpm_fid_score:
        # The check mark below was previously mis-encoded (mojibake).
        print(f"  ✅ HSIVI is better by {ddpm_fid_score - fid_score:.2f} FID points")
    else:
        print(f"  DDPM is better by {fid_score - ddpm_fid_score:.2f} FID points")


## Save Results


In [None]:
# Save generated samples
output_dir = Path("./hsivi_evaluation")
# parents=True keeps this working if the path is later changed to a nested one.
output_dir.mkdir(parents=True, exist_ok=True)

# Save sample grid (first 64 samples, 8 per row)
nrow = 8
grid = make_grid(generated_samples[:64], nrow=nrow, padding=2)
grid_path = output_dir / "hsivi_samples_grid.png"
save_image(grid, grid_path)
print(f"Saved sample grid to {grid_path}")

# Save individual samples (at most the first 100)
samples_dir = output_dir / "samples"
samples_dir.mkdir(parents=True, exist_ok=True)
individual_samples = generated_samples[:100]
for i, sample in enumerate(individual_samples):
    save_image(sample, samples_dir / f"sample_{i:04d}.png")
# Report the number actually written, which is below 100 for small runs.
print(f"Saved {len(individual_samples)} individual samples to {samples_dir}")


In [None]:
# Persist the headline numbers and the key config fields as JSON.
import json

results = {
    "model": "HSIVI",
    "checkpoint": CHECKPOINT_PATH,
    "n_discrete_steps": config.n_discrete_steps,
    "nfe": config.n_discrete_steps - 1,
    "num_samples": NUM_FID_SAMPLES,
    "fid_score": fid_score,
    # Config attributes copied verbatim, keyed by their attribute names.
    "config": {
        field: getattr(config, field)
        for field in (
            "image_size",
            "channels",
            "phi_base_dim",
            "f_base_dim",
            "skip_type",
            "image_gamma",
            "independent_log_gamma",
        )
    },
}

# The baseline score only exists when the DDPM comparison cells ran.
if compare_with_baseline:
    results.update(ddpm_fid_score=ddpm_fid_score)

with open(output_dir / "evaluation_results.json", "w") as f:
    json.dump(results, f, indent=2)

print(f"\nSaved evaluation results to {output_dir / 'evaluation_results.json'}")


## Summary


In [None]:
# Final recap of the run. The emoji in the two FID lines were previously
# mis-encoded (mojibake) in the printed text.
print("\n" + "="*60)
print("EVALUATION SUMMARY")
print("="*60)
print(f"Model: HSIVI")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Training steps: {trainer.step}")
print(f"Number of Function Evaluations (NFE): {config.n_discrete_steps - 1}")
print(f"Number of samples for FID: {NUM_FID_SAMPLES}")
print(f"\n🎯 FID Score: {fid_score:.2f}")
# The baseline line is only available when the DDPM comparison ran.
if compare_with_baseline:
    print(f"📊 DDPM (DDIM) FID at same NFE: {ddpm_fid_score:.2f}")
print("="*60)
