# Alignment Model Evaluation and Explainability

This notebook provides comprehensive evaluation and explainability analysis for aligned vision-text models.

## Features:

### 1. Retrieval Evaluation
- Image-to-text and text-to-image retrieval
- Comprehensive metrics: R@K, mAP, NDCG, rank statistics
- Comparison between MLP and Perceiver architectures

### 2. Matryoshka Representation Learning (MRL)
- Performance across different embedding dimensions
- Efficiency vs. accuracy trade-offs
- MRL curves and analysis

### 3. Explainability Analysis
- Embedding space visualization (PCA, t-SNE)
- Dimension importance analysis
- Modality separation metrics
- Similarity distribution analysis

### 4. Benchmarking
- Standardized evaluation protocols
- Comparable metrics across models
- Performance reports

**Hardware Requirements:**
- 1 GPU (H200 recommended)
- ~20-30GB GPU memory
- ~32GB RAM for large-scale visualization

## 1. Setup and Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / "src"))
Path.cwd().parent / "src"

PosixPath('/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src')

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import wandb
from datetime import datetime
import json

In [4]:
# Import our modular evaluation components
from config import load_config
from models import MultimodalAlignmentModel
from data.dataset_builder import build_image_datasets_from_parquet
from data.transforms import get_image_transforms

# Evaluation modules
from evaluation import (
    RetrievalMetrics,
    compute_retrieval_metrics,
    compute_mrl_performance,
    AlignmentBenchmark,
    ExplainabilityAnalyzer,
)

# Visualization
from utils.visualization import TrainingVisualizer



In [5]:
# Set up matplotlib
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"Using device: {device}")

PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA H200
GPU Memory: 150.11 GB
Using device: cuda


## 2. Configuration

In [6]:
# Load evaluation configuration
eval_config_path = "../configs/pixmo_alignment.yaml"
eval_config = load_config(eval_config_path)

print("="*60)
print("EVALUATION CONFIGURATION")
print("="*60)
print(f"\nExperiment: {eval_config.name}")
print(f"\nDataset:")
print(f"  Test parquet: {eval_config.dataset.test_parquet}")
print(f"  Batch size: {eval_config.dataset.batch_size}")
print(f"  Max samples: {eval_config.dataset.max_samples}")
print(f"\nMetrics:")
print(f"  Recall@K: {eval_config.metrics.recall_at_k}")
print(f"  MRL enabled: {eval_config.mrl.enabled}")
print(f"  MRL dimensions: {eval_config.mrl.dimensions}")
print(f"\nExplainability:")
print(f"  Enabled: {eval_config.explainability.enabled}")
print(f"  Embedding viz method: {eval_config.explainability.embedding_viz.method}")
print("="*60)

EVALUATION CONFIGURATION

Experiment: pixmo_vision_text_alignment

Dataset:
  Test parquet: /home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo/pixmo_test.parquet
  Batch size: 128


AttributeError: 'DatasetConfig' object has no attribute 'max_samples'

## 3. Select Model to Evaluate

You can evaluate either the MLP-based or Perceiver-based alignment model.

In [None]:
# Choose which model to evaluate
MODEL_TYPE = "perceiver_mrl"  # Options: "pixmo_mlp" or "perceiver_mrl"

if hasattr(eval_config, 'checkpoints') and MODEL_TYPE in eval_config.checkpoints:
    checkpoint_info = eval_config.checkpoints[MODEL_TYPE]
    model_config_path = checkpoint_info.config
    checkpoint_path = checkpoint_info.path
else:
    print(f"⚠️  No checkpoint info for {MODEL_TYPE} in config. Using defaults/discovery.")
    model_config_path = eval_config_path
    
    # Try different possible checkpoint locations
    possible_dirs = [
        Path("checkpoints") / "pixmo_alignment",
        Path("checkpoints") / "perceiver_mrl_alignment",
        Path("outputs") / "pixmo_alignment",
    ]
    # Add explicit output dir from config if present
    if hasattr(eval_config, 'trainer') and eval_config.trainer.output_dir:
         possible_dirs.insert(0, Path(eval_config.trainer.output_dir))
    if hasattr(eval_config, 'trainer') and eval_config.trainer.ckpt_dir:
         possible_dirs.insert(0, Path(eval_config.trainer.ckpt_dir))
    
    checkpoint_path = None
    for ckpt_dir in possible_dirs:
        if ckpt_dir.exists():
            pt_files = list(ckpt_dir.glob("*.pt"))
            if pt_files:
                # Prefer 'best' in name, or recent ones
                best = next((f for f in pt_files if 'best' in f.name), pt_files[0])
                checkpoint_path = str(best)
                print(f"  Found checkpoint in {ckpt_dir}")
                break
    
    if checkpoint_path is None:
        checkpoint_path = "checkpoints/pixmo_alignment/best_model.pt"
        print(f"  Could not find checkpoint file, defaulting to: {checkpoint_path}")

print(f"Evaluating: {MODEL_TYPE}")
print(f"Config: {model_config_path}")
print(f"Checkpoint: {checkpoint_path}")


## 4. Load Model

In [None]:
# Load model configuration
model_config = load_config(model_config_path)

print("\nModel Configuration:")
print(f"  Vision Encoder: {model_config.vision_encoder.model_name}")
print(f"  Projection Dim: {model_config.vision_encoder.projection_dim}")
print(f"  MRL Dimensions: {model_config.vision_encoder.mrl_dimensions}")

if hasattr(model_config.vision_encoder, 'perceiver_num_latents'):
    print(f"  Perceiver Latents: {model_config.vision_encoder.perceiver_num_latents}")
    print(f"  Perceiver Layers: {model_config.vision_encoder.perceiver_num_layers}")

In [None]:
# Create and load model
print("\nLoading model...")
model = MultimodalAlignmentModel(model_config).to(device)

# Load checkpoint
if Path(checkpoint_path).exists():
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print(f"✓ Loaded checkpoint from {checkpoint_path}")
    
    if 'epoch' in checkpoint:
        print(f"  Epoch: {checkpoint['epoch']}")
    if 'best_val_loss' in checkpoint:
        print(f"  Best val loss: {checkpoint['best_val_loss']:.4f}")
else:
    print(f"⚠️  Checkpoint not found: {checkpoint_path}")
    print("   Using randomly initialized model")

model.eval()

# Print parameter counts
print(f"\nModel Architecture:")
model.print_parameter_counts()

## 5. Load Evaluation Dataset

In [None]:
# Load evaluation configuration
eval_config_path = "../configs/pixmo_alignment.yaml"
eval_config = load_config(eval_config_path)

print("="*60)
print("EVALUATION CONFIGURATION")
print("="*60)
print(f"\nExperiment: {eval_config.name}")
print(f"\nDataset:")
print(f"  Test parquet: {eval_config.dataset.test_parquet}")
print(f"  Batch size: {eval_config.dataset.batch_size}")
# Safe access
max_samples = getattr(eval_config.dataset, 'max_samples', None)
print(f"  Max samples: {max_samples}")

print(f"\nMetrics:")
# Default metrics
metrics_cfg = getattr(eval_config, 'metrics', None)
recall_at_k = getattr(metrics_cfg, 'recall_at_k', [1, 5, 10]) if metrics_cfg else [1, 5, 10]
print(f"  Recall@K: {recall_at_k}")

# MRL
use_mrl = False
mrl_dims = []
if hasattr(eval_config, 'mrl'):
    use_mrl = eval_config.mrl.enabled
    mrl_dims = eval_config.mrl.dimensions
elif hasattr(eval_config, 'vision_encoder') and hasattr(eval_config.vision_encoder, 'use_mrl'):
    use_mrl = eval_config.vision_encoder.use_mrl
    mrl_dims = getattr(eval_config.vision_encoder, 'mrl_dimensions', [])

print(f"  MRL enabled: {use_mrl}")
print(f"  MRL dimensions: {mrl_dims}")

print(f"\nExplainability:")
explain_cfg = getattr(eval_config, 'explainability', None)
explain_enabled = getattr(explain_cfg, 'enabled', False) if explain_cfg else False
viz_method = "pca"
if explain_enabled and hasattr(explain_cfg, 'embedding_viz'):
    viz_method = getattr(explain_cfg.embedding_viz, 'method', 'pca')

print(f"  Enabled: {explain_enabled}")
print(f"  Embedding viz method: {viz_method}")
print("="*60)


In [None]:
# Create dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=eval_config.dataset.batch_size,
    shuffle=False,
    num_workers=eval_config.dataset.num_workers,
    pin_memory=eval_config.dataset.pin_memory,
)

print(f"Test batches: {len(test_loader)}")

## 6. Initialize Evaluation Tools

In [None]:
# Initialize benchmark
benchmark = AlignmentBenchmark(
    model=model,
    device=device,
    mrl_dims=eval_config.mrl.dimensions if eval_config.mrl.enabled else None,
)

# Initialize explainability analyzer
explainability = ExplainabilityAnalyzer(
    model=model,
    device=device,
)

# Initialize visualizer
output_dir = Path(eval_config.visualization.output_dir) / MODEL_TYPE
output_dir.mkdir(parents=True, exist_ok=True)

visualizer = TrainingVisualizer(
    save_dir=output_dir,
    style=eval_config.visualization.style,
)

print(f"Output directory: {output_dir}")

## 7. Initialize Weights & Biases (Optional)

In [None]:
if eval_config.logging.use_wandb:
    wandb.init(
        project=eval_config.logging.wandb_project,
        entity=eval_config.logging.wandb_entity,
        name=f"{eval_config.logging.run_name}_{MODEL_TYPE}",
        config={
            "model_type": MODEL_TYPE,
            "checkpoint": checkpoint_path,
            "dataset": eval_config.dataset.test_parquet,
            "batch_size": eval_config.dataset.batch_size,
        },
    )
    print("✓ Weights & Biases initialized")
else:
    print("⊘ Weights & Biases logging disabled")

## 8. Run Full Evaluation

In [None]:
# Run comprehensive evaluation
results = benchmark.run_full_evaluation(
    dataloader=test_loader,
    max_batches=None,  # Use full dataset
)

## 9. Detailed Retrieval Analysis

In [None]:
# Extract metrics
i2t_metrics = results['retrieval']['i2t']
t2i_metrics = results['retrieval']['t2i']

print("="*60)
print("RETRIEVAL PERFORMANCE SUMMARY")
print("="*60)

print("\nImage-to-Text Retrieval:")
print(f"  R@1:   {i2t_metrics.r_at_1:.2f}%")
print(f"  R@5:   {i2t_metrics.r_at_5:.2f}%")
print(f"  R@10:  {i2t_metrics.r_at_10:.2f}%")
print(f"  R@50:  {i2t_metrics.r_at_50:.2f}%")
print(f"  Mean Rank: {i2t_metrics.mean_rank:.2f}")
print(f"  Median Rank: {i2t_metrics.median_rank:.0f}")
print(f"  mAP@10: {i2t_metrics.map_at_10:.4f}")
print(f"  NDCG@10: {i2t_metrics.ndcg_at_10:.4f}")

print("\nText-to-Image Retrieval:")
print(f"  R@1:   {t2i_metrics.r_at_1:.2f}%")
print(f"  R@5:   {t2i_metrics.r_at_5:.2f}%")
print(f"  R@10:  {t2i_metrics.r_at_10:.2f}%")
print(f"  R@50:  {t2i_metrics.r_at_50:.2f}%")
print(f"  Mean Rank: {t2i_metrics.mean_rank:.2f}")
print(f"  Median Rank: {t2i_metrics.median_rank:.0f}")
print(f"  mAP@10: {t2i_metrics.map_at_10:.4f}")
print(f"  NDCG@10: {t2i_metrics.ndcg_at_10:.4f}")

print("="*60)

In [None]:
# Log to wandb
if eval_config.logging.use_wandb:
    wandb.log({
        "i2t/R@1": i2t_metrics.r_at_1,
        "i2t/R@5": i2t_metrics.r_at_5,
        "i2t/R@10": i2t_metrics.r_at_10,
        "i2t/mean_rank": i2t_metrics.mean_rank,
        "t2i/R@1": t2i_metrics.r_at_1,
        "t2i/R@5": t2i_metrics.r_at_5,
        "t2i/R@10": t2i_metrics.r_at_10,
        "t2i/mean_rank": t2i_metrics.mean_rank,
    })

## 10. MRL Performance Analysis

In [None]:
if results['mrl']:
    print("="*60)
    print("MRL PERFORMANCE ANALYSIS")
    print("="*60)
    
    print("\nPerformance vs. Dimension (Image→Text):")
    print(f"{'Dim':>6} | {'R@1':>7} | {'R@5':>7} | {'R@10':>7} | {'Mean Rank':>10}")
    print("-" * 50)
    
    for dim in sorted(results['mrl'].keys()):
        metrics = results['mrl'][dim]['i2t']
        print(f"{dim:6d} | {metrics.r_at_1:6.2f}% | {metrics.r_at_5:6.2f}% | {metrics.r_at_10:6.2f}% | {metrics.mean_rank:10.2f}")
    
    print("="*60)
else:
    print("MRL evaluation not enabled")

## 11. Visualization: Rank Analysis

In [None]:
# Plot rank histograms
visualizer.plot_rank_histogram(
    ranks=i2t_metrics.ranks,
    title="Image→Text Retrieval Ranks",
    save_name="i2t_rank_histogram.png",
)

visualizer.plot_rank_histogram(
    ranks=t2i_metrics.ranks,
    title="Text→Image Retrieval Ranks",
    save_name="t2i_rank_histogram.png",
)

print(f"✓ Rank histograms saved to {output_dir}")

In [None]:
# Plot rank CDFs
visualizer.plot_rank_cdf(
    ranks=i2t_metrics.ranks,
    title="Image→Text Rank CDF",
    save_name="i2t_rank_cdf.png",
)

visualizer.plot_rank_cdf(
    ranks=t2i_metrics.ranks,
    title="Text→Image Rank CDF",
    save_name="t2i_rank_cdf.png",
)

print(f"✓ Rank CDFs saved to {output_dir}")

In [None]:
# Display one of the plots
from IPython.display import Image as IPImage, display
display(IPImage(filename=str(output_dir / 'i2t_rank_histogram.png')))

## 12. Visualization: Similarity Analysis

In [None]:
# Get embeddings
vision_embs = results['embeddings']['vision']
text_embs = results['embeddings']['text']

# Compute similarity distributions
vision_norm = F.normalize(vision_embs, p=2, dim=-1)
text_norm = F.normalize(text_embs, p=2, dim=-1)

sims = torch.matmul(vision_norm, text_norm.t())
N = sims.size(0)

# Positive pairs (diagonal)
pos_sims = sims.diag().cpu().numpy()

# Negative pairs (off-diagonal, sample subset)
mask = torch.eye(N, dtype=torch.bool, device=sims.device)
neg_sims_all = sims[~mask].cpu().numpy()
neg_sims = np.random.choice(neg_sims_all, size=min(10000, len(neg_sims_all)), replace=False)

# Plot
visualizer.plot_similarity_distributions(
    positive_sims=pos_sims,
    negative_sims=neg_sims,
    title="Positive vs. Negative Pair Similarities",
    save_name="similarity_distributions.png",
)

print(f"✓ Similarity distributions saved to {output_dir}")

# Display
display(IPImage(filename=str(output_dir / 'similarity_distributions.png')))

In [None]:
# Plot similarity matrix (subset)
visualizer.plot_similarity_matrix(
    vision_embs=vision_embs.numpy(),
    text_embs=text_embs.numpy(),
    n_samples=50,
    save_name="similarity_matrix.png",
)

print(f"✓ Similarity matrix saved to {output_dir}")
display(IPImage(filename=str(output_dir / 'similarity_matrix.png')))

## 13. Visualization: MRL Curves

In [None]:
if results['mrl']:
    # Plot MRL performance curves for different metrics
    for metric in ['r_at_1', 'r_at_5', 'r_at_10']:
        visualizer.plot_mrl_curves(
            mrl_results=results['mrl'],
            metric=metric,
            title=f"MRL Performance: {metric.upper()}",
            save_name=f"mrl_{metric}.png",
        )
    
    print(f"✓ MRL curves saved to {output_dir}")
    
    # Display R@1 curve
    display(IPImage(filename=str(output_dir / 'mrl_r_at_1.png')))
else:
    print("MRL evaluation not enabled")

## 14. Visualization: Embedding Space

In [None]:
# Visualize embedding space
visualizer.plot_embedding_space(
    vision_embs=vision_embs.numpy(),
    text_embs=text_embs.numpy(),
    method=eval_config.explainability.embedding_viz.method,
    n_samples=eval_config.explainability.embedding_viz.n_samples,
    save_name="embedding_space.png",
)

print(f"✓ Embedding space visualization saved to {output_dir}")
display(IPImage(filename=str(output_dir / 'embedding_space.png')))

## 15. Explainability Analysis

In [None]:
if eval_config.explainability.enabled:
    # Generate comprehensive explainability report
    explainability_report = explainability.generate_explainability_report(
        vision_embs=vision_embs,
        text_embs=text_embs,
        save_dir=output_dir,
    )
else:
    print("Explainability analysis not enabled")

## 16. Retrieval Examples

In [None]:
# Show some retrieval examples
num_examples = 5

# Get a batch for visualization
sample_batch = next(iter(test_loader))
images_sample = sample_batch['image'][:num_examples]
texts_sample = sample_batch['text'][:num_examples]

# Get embeddings for this batch
with torch.no_grad():
    outputs = model(images=images_sample.to(device), texts=texts_sample, return_embeddings=True)
    query_embs = outputs.vision_emb

print("="*60)
print("IMAGE→TEXT RETRIEVAL EXAMPLES")
print("="*60)

for i in range(min(num_examples, len(texts_sample))):
    if not texts_sample[i]:  # Skip dropped texts
        continue
        
    print(f"\nExample {i+1}:")
    print(f"Ground Truth: {texts_sample[i][:100]}...")
    
    # Retrieve top-5 matches
    indices, scores = explainability.compute_retrieval_attention(
        query_emb=query_embs[i],
        key_embs=text_embs,
        top_k=5,
    )
    
    print("Top-5 Retrieved:")
    for rank, (idx, score) in enumerate(zip(indices, scores), 1):
        # Get text from dataset (note: might need to handle indexing)
        print(f"  {rank}. [{score:.3f}] Sample {idx}")

print("="*60)

## 17. Save Results

In [None]:
# Save evaluation results
results_dir = Path(eval_config.output.save_dir) / MODEL_TYPE
results_dir.mkdir(parents=True, exist_ok=True)

benchmark.save_results(results, results_dir)

print(f"\n✓ Results saved to {results_dir}")

## 18. Generate Summary Report

In [None]:
# Create summary report
summary = {
    'model_type': MODEL_TYPE,
    'checkpoint': str(checkpoint_path),
    'dataset': str(eval_config.dataset.test_parquet),
    'num_samples': len(test_dataset),
    'embedding_dim': vision_embs.shape[1],
    'retrieval': {
        'i2t': i2t_metrics.to_dict(),
        't2i': t2i_metrics.to_dict(),
    },
    'similarity': results['similarity'],
    'timestamp': datetime.now().isoformat(),
}

# Add MRL summary if available
if results['mrl']:
    summary['mrl'] = {
        dim: {
            'i2t': metrics['i2t'].to_dict(),
            't2i': metrics['t2i'].to_dict(),
        }
        for dim, metrics in results['mrl'].items()
    }

# Save summary
with open(results_dir / 'summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"✓ Summary report saved to {results_dir / 'summary.json'}")

# Print summary
print("\n" + "="*60)
print("EVALUATION SUMMARY")
print("="*60)
print(f"\nModel: {MODEL_TYPE}")
print(f"Samples: {len(test_dataset)}")
print(f"Embedding dim: {vision_embs.shape[1]}")
print(f"\nBest I2T R@1: {i2t_metrics.r_at_1:.2f}%")
print(f"Best T2I R@1: {t2i_metrics.r_at_1:.2f}%")
print(f"\nResults saved to: {results_dir}")
print(f"Visualizations saved to: {output_dir}")
print("="*60)

## 19. Compare Models (Optional)

If you've evaluated both models, you can compare them here.

In [None]:
# Load results from both models for comparison
# This cell is optional and requires running the evaluation for both models

def load_model_results(model_type):
    """Load saved results for a model."""
    results_path = Path(eval_config.output.save_dir) / model_type / 'summary.json'
    if results_path.exists():
        with open(results_path, 'r') as f:
            return json.load(f)
    return None

# Try to load both model results
mlp_results = load_model_results('pixmo_mlp')
perceiver_results = load_model_results('perceiver_mrl')

if mlp_results and perceiver_results:
    print("="*60)
    print("MODEL COMPARISON")
    print("="*60)
    print(f"\n{'Metric':<20} | {'MLP':>10} | {'Perceiver':>10}")
    print("-" * 45)
    
    metrics_to_compare = ['R@1', 'R@5', 'R@10', 'mean_rank']
    
    for metric in metrics_to_compare:
        mlp_val = mlp_results['retrieval']['i2t'][metric]
        perc_val = perceiver_results['retrieval']['i2t'][metric]
        print(f"{metric:<20} | {mlp_val:10.2f} | {perc_val:10.2f}")
    
    print("="*60)
else:
    print("Run evaluation for both models to enable comparison")

## 20. Finish

In [None]:
# Finish wandb run
if eval_config.logging.use_wandb:
    wandb.finish()

print("\n✓ Evaluation complete!")
print(f"\nNext steps:")
print(f"1. Review visualizations in {output_dir}")
print(f"2. Check detailed results in {results_dir}")
print(f"3. Compare with other models or baselines")
print(f"4. Use insights to improve model architecture or training")