# 🔬 Comprehensive Multimodal Alignment Evaluation

**A thorough evaluation suite for multimodal alignment models**

This notebook evaluates your trained model against standard benchmarks and compares with published baselines.

---

## 📊 Benchmarks Covered

| Benchmark | Task | Metrics | Baselines |
|-----------|------|---------|----------|
| **COCO** | Image-Text Retrieval | R@1, R@5, R@10 | CLIP, BLIP, ImageBind |
| **Flickr30K** | Image-Text Retrieval | R@1, R@5, R@10 | CLIP, BLIP |
| **ESC-50** | Audio Zero-Shot Classification | Accuracy | AudioCLIP, ImageBind, CLAP |
| **AudioCaps** | Audio-Text Retrieval | R@1, R@5, R@10 | CLAP, ImageBind |
| **LibriSpeech** | ASR (via LLM) | WER | Whisper |
| **Matryoshka** | Embedding Compression | R@1 vs Dim | - |

---

## 📚 Reference Papers

1. **CLIP** (Radford et al., 2021) - Learning Transferable Visual Models From Natural Language Supervision
2. **ImageBind** (Girdhar et al., 2023) - One Embedding Space To Bind Them All
3. **BLIP** (Li et al., 2022) - Bootstrapping Language-Image Pre-training
4. **CLAP** (Elizalde et al., 2023) - CLAP: Learning Audio Concepts from Natural Language Supervision
5. **Matryoshka** (Kusupati et al., 2022) - Matryoshka Representation Learning

---

## 0. Setup & Imports

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import os
import io
import math
import random
import json
import warnings
import requests
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
from collections import defaultdict
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Transformers & Datasets
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    CLIPVisionModel,
    CLIPImageProcessor,
    CLIPModel,
    CLIPProcessor,
    WhisperModel,
    WhisperProcessor,
)
from datasets import load_dataset

# Audio processing
try:
    import librosa
    HAS_LIBROSA = True
except ImportError:
    HAS_LIBROSA = False
    print("‚ö†Ô∏è librosa not installed. Run: pip install librosa")

# Metrics
try:
    from torchmetrics.text import WordErrorRate
    HAS_WER = True
except ImportError:
    HAS_WER = False
    print("‚ö†Ô∏è torchmetrics not installed. Run: pip install torchmetrics")

try:
    from sklearn.manifold import TSNE
    from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False
    print("‚ö†Ô∏è sklearn not installed. Run: pip install scikit-learn")

# Keep notebook output readable by hiding library warnings
warnings.filterwarnings('ignore')

# Plot defaults: apply the seaborn theme first, then override individual rcParams
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.titlesize': 14,
})

# Compute device used by every later cell (CUDA when present, else CPU)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"\nüñ•Ô∏è  Device: {device}")
if device.type == "cuda":
    # Report the first GPU's name and total memory in GB
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Configuration

In [None]:
# ============================================================
# EVALUATION CONFIGURATION
# ============================================================

@dataclass
class EvalConfig:
    """Configuration for comprehensive evaluation.

    Collects everything the notebook reads in one place: the checkpoint path,
    the names of the frozen encoders, the layer widths used to rebuild the
    adapters, per-benchmark sample caps, and output settings.
    """
    
    # === Model Paths ===
    # Checkpoint holding the trained adapter / perceiver / projector weights
    checkpoint_path: str = "./checkpoints/multimodal_adapter_poc_1.pt"
    
    # === Encoder Models (must match training!) ===
    vision_model_name: str = "openai/clip-vit-base-patch32"
    text_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    audio_model_name: str = "openai/whisper-base"
    llm_model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"  # For generation tasks
    
    # === Architecture Dimensions ===
    # Feature widths of the frozen encoders and of the shared alignment space
    d_vision: int = 768      # CLIP ViT-B/32 output
    d_text: int = 384        # MiniLM output
    d_audio: int = 512       # Whisper-base output
    d_align: int = 512       # Alignment embedding dimension
    
    # Perceiver (if used)
    use_perceiver: bool = True
    perceiver_dim: int = 512        # width of adapter outputs and perceiver latents
    num_latents: int = 64           # number of learned latent vectors
    num_perceiver_layers: int = 2
    num_attn_heads: int = 8
    
    # LLM projection
    llm_hidden_size: int = 1536  # Qwen2.5-1.5B hidden size
    
    # === Evaluation Settings ===
    batch_size: int = 32
    num_workers: int = 0
    use_fp16: bool = False
    
    # Sample sizes (set to None for full evaluation)
    # NOTE(review): confirm each evaluator actually accepts None before relying on it
    coco_samples: Optional[int] = 1000
    flickr_samples: Optional[int] = 1000
    audiocaps_samples: Optional[int] = 500
    esc50_samples: Optional[int] = 400  # Full ESC-50 = 2000
    librispeech_samples: Optional[int] = 100
    
    # === Matryoshka Evaluation ===
    # Embedding prefix lengths to evaluate
    mrl_dims: Tuple[int, ...] = (32, 64, 128, 256, 512)
    
    # === Output ===
    seed: int = 42
    save_results: bool = True
    results_dir: str = "./eval_results"
    

# Single shared configuration instance used by the rest of the notebook
cfg = EvalConfig()

# Reproducibility
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs (plus all CUDA devices if present).

    Args:
        seed: Value handed to every random number generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Fix every RNG seed before any data is sampled
set_seed(cfg.seed)

# Make sure the output folder exists when results will be written
if cfg.save_results:
    Path(cfg.results_dir).mkdir(parents=True, exist_ok=True)

# Echo the settings that matter most for interpreting the results
print("\n".join([
    f"\nüìã Configuration:",
    f"   Checkpoint: {cfg.checkpoint_path}",
    f"   d_align: {cfg.d_align}",
    f"   Perceiver: {cfg.use_perceiver}",
    f"   MRL dims: {cfg.mrl_dims}",
]))

## 2. Published Baselines

Reference numbers from CLIP, ImageBind, BLIP, CLAP papers for comparison.

In [None]:
# ============================================================
# PUBLISHED BASELINE RESULTS
# ============================================================

# Reference scores for comparison, laid out as benchmark -> model -> metric.
# Each model entry also carries a "source" citation string.
# NOTE(review): these numbers cannot be checked from this notebook; verify each
# against the cited paper (model variant, zero-shot vs fine-tuned) before
# publishing a comparison.
# NOTE(review): the section comments mention the COCO 5K / Flickr30K 1K test
# sets, while the default EvalConfig caps evaluation at 1000 samples, so the
# comparison may not be like-for-like -- confirm.
BASELINES = {
    # === COCO Image-Text Retrieval (5K test set) ===
    # From original CLIP, BLIP, ImageBind papers
    "coco_retrieval": {
        "CLIP-ViT-B/32": {
            "T2I_R@1": 30.4, "T2I_R@5": 56.0, "T2I_R@10": 67.0,
            "I2T_R@1": 50.1, "I2T_R@5": 75.3, "I2T_R@10": 84.0,
            "source": "CLIP (Radford et al., 2021)"
        },
        "CLIP-ViT-L/14": {
            "T2I_R@1": 36.5, "T2I_R@5": 61.1, "T2I_R@10": 71.6,
            "I2T_R@1": 56.3, "I2T_R@5": 79.4, "I2T_R@10": 86.7,
            "source": "CLIP (Radford et al., 2021)"
        },
        "BLIP-ViT-B": {
            "T2I_R@1": 39.7, "T2I_R@5": 63.8, "T2I_R@10": 74.0,
            "I2T_R@1": 59.2, "I2T_R@5": 82.4, "I2T_R@10": 89.6,
            "source": "BLIP (Li et al., 2022)"
        },
        "ImageBind": {
            "T2I_R@1": 34.8, "T2I_R@5": 60.2, "T2I_R@10": 70.9,
            "I2T_R@1": 53.2, "I2T_R@5": 77.8, "I2T_R@10": 85.6,
            "source": "ImageBind (Girdhar et al., 2023)"
        },
    },
    
    # === Flickr30K Image-Text Retrieval (1K test set) ===
    "flickr_retrieval": {
        "CLIP-ViT-B/32": {
            "T2I_R@1": 68.7, "T2I_R@5": 90.6, "T2I_R@10": 95.2,
            "I2T_R@1": 88.0, "I2T_R@5": 98.7, "I2T_R@10": 99.4,
            "source": "CLIP (Radford et al., 2021)"
        },
        "CLIP-ViT-L/14": {
            "T2I_R@1": 75.6, "T2I_R@5": 93.2, "T2I_R@10": 96.5,
            "I2T_R@1": 92.4, "I2T_R@5": 99.3, "I2T_R@10": 99.8,
            "source": "CLIP (Radford et al., 2021)"
        },
        "BLIP-ViT-B": {
            "T2I_R@1": 80.6, "T2I_R@5": 95.2, "T2I_R@10": 97.8,
            "I2T_R@1": 94.8, "I2T_R@5": 99.7, "I2T_R@10": 99.9,
            "source": "BLIP (Li et al., 2022)"
        },
    },
    
    # === ESC-50 Zero-Shot Audio Classification ===
    "esc50_zeroshot": {
        "Random": {"accuracy": 2.0, "source": "1/50 classes"},
        "Wav2CLIP": {"accuracy": 41.4, "source": "Wav2CLIP (Wu et al., 2022)"},
        "AudioCLIP": {"accuracy": 69.4, "source": "AudioCLIP (Guzhov et al., 2022)"},
        "ImageBind": {"accuracy": 66.9, "source": "ImageBind (Girdhar et al., 2023)"},
        "CLAP": {"accuracy": 82.6, "source": "CLAP (Elizalde et al., 2023)"},
    },
    
    # === AudioCaps Text-Audio Retrieval ===
    "audiocaps_retrieval": {
        "AudioCLIP": {
            "T2A_R@1": 6.8, "T2A_R@5": 19.2, "T2A_R@10": 28.8,
            "A2T_R@1": 9.1, "A2T_R@5": 24.5, "A2T_R@10": 35.1,
            "source": "AudioCLIP (Guzhov et al., 2022)"
        },
        "ImageBind": {
            "T2A_R@1": 8.3, "T2A_R@5": 23.4, "T2A_R@10": 33.5,
            "A2T_R@1": 11.2, "A2T_R@5": 28.7, "A2T_R@10": 40.2,
            "source": "ImageBind (Girdhar et al., 2023)"
        },
        "CLAP": {
            "T2A_R@1": 26.7, "T2A_R@5": 54.4, "T2A_R@10": 67.1,
            "A2T_R@1": 33.9, "A2T_R@5": 61.8, "A2T_R@10": 73.8,
            "source": "CLAP (Elizalde et al., 2023)"
        },
    },
    
    # === ImageNet Zero-Shot ===
    "imagenet_zeroshot": {
        "CLIP-ViT-B/32": {"top1": 63.2, "top5": 87.8, "source": "CLIP"},
        "CLIP-ViT-L/14": {"top1": 75.5, "top5": 93.0, "source": "CLIP"},
        "ImageBind": {"top1": 77.7, "top5": 94.3, "source": "ImageBind"},
    },
    
    # === LibriSpeech ASR (WER - lower is better) ===
    # NOTE(review): confirm whether these are the English-only (.en) or
    # multilingual Whisper figures; the keys do not say.
    "librispeech_asr": {
        "Whisper-tiny": {"wer_clean": 5.6, "source": "Whisper (Radford et al., 2022)"},
        "Whisper-base": {"wer_clean": 4.2, "source": "Whisper (Radford et al., 2022)"},
        "Whisper-small": {"wer_clean": 3.4, "source": "Whisper (Radford et al., 2022)"},
    },
}

# Summarize how many reference models each benchmark has
print("üìä Loaded baselines from published papers:")
for benchmark, models in BASELINES.items():
    print(f"   ‚Ä¢ {benchmark}: {len(models)} models")

## 3. Model Architecture

Define the model architecture to match training configuration.

In [None]:
# ============================================================
# MODEL ARCHITECTURE DEFINITIONS
# ============================================================

class ModalityAdapter(nn.Module):
    """Single affine map from an encoder's feature width to a target width.

    Args:
        in_dim: Size of the last dimension of the incoming features.
        out_dim: Size of the last dimension after projection.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # The attribute name `proj` determines the state_dict keys, so it must stay.
        self.proj = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        """Apply the projection to the last dimension of ``x``."""
        projected = self.proj(x)
        return projected


class MLPAdapter(nn.Module):
    """Two-layer MLP adapter: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        d_in: Input feature width.
        d_out: Output feature width.
        hidden_factor: Hidden width as a multiple of ``d_in`` (truncated to int).
        dropout: Dropout probability used after the activation and after the output layer.
    """

    def __init__(self, d_in: int, d_out: int, hidden_factor: float = 2.0, dropout: float = 0.1):
        super().__init__()
        width = int(d_in * hidden_factor)
        # Positions inside the Sequential define the state_dict keys (net.0, net.1, net.4).
        stages = [
            nn.LayerNorm(d_in),
            nn.Linear(d_in, width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(width, d_out),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack."""
        out = self.net(x)
        return out


class FeedForward(nn.Module):
    """Position-wise feed-forward block (Linear -> GELU -> Linear) used by the Perceiver.

    Args:
        dim: Input and output width.
        mlp_ratio: Hidden width as a multiple of ``dim`` (truncated to int).
    """

    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner = int(dim * mlp_ratio)
        expand = nn.Linear(dim, inner)
        contract = nn.Linear(inner, dim)
        # Sequential positions (0 and 2) define the state_dict keys.
        self.net = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``."""
        out = self.net(x)
        return out


class PerceiverLayer(nn.Module):
    """One Perceiver block: cross-attention to the tokens, self-attention, then FFN.

    All three sub-layers are pre-normalized and wrapped in residual connections.

    Args:
        dim: Width shared by latents and tokens.
        num_heads: Number of attention heads for both attention modules.
        mlp_ratio: Hidden-width multiplier for the feed-forward block.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Attribute names and creation order are kept: they fix the state_dict
        # layout and the order in which random initial weights are drawn.
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_latents_1 = nn.LayerNorm(dim)
        self.ln_tokens = nn.LayerNorm(dim)
        self.ln_latents_2 = nn.LayerNorm(dim)
        self.ln_latents_3 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_ratio)

    def forward(self, latents, tokens, token_mask=None):
        """Update ``latents`` using ``tokens``.

        Args:
            latents: Latent vectors acting as attention queries.
            tokens: Input sequence acting as keys and values.
            token_mask: Optional mask whose truthy entries mark tokens to attend to.

        Returns:
            The updated latents.
        """
        # MultiheadAttention expects True where a key must be IGNORED, hence the inversion.
        if token_mask is None:
            ignore = None
        else:
            ignore = ~token_mask.bool()

        # 1) latents read from the tokens
        queries = self.ln_latents_1(latents)
        memory = self.ln_tokens(tokens)
        cross_out, _ = self.cross_attn(
            queries, memory, memory, key_padding_mask=ignore, need_weights=False
        )
        latents = latents + cross_out

        # 2) latents exchange information among themselves
        normed = self.ln_latents_2(latents)
        mixed, _ = self.self_attn(normed, normed, normed, need_weights=False)
        latents = latents + mixed

        # 3) position-wise feed-forward
        return latents + self.mlp(self.ln_latents_3(latents))


class PerceiverResampler(nn.Module):
    """Map a variable-length token sequence onto a fixed set of learned latents.

    Args:
        dim: Width of tokens and latents.
        num_latents: Number of learned latent vectors (the output length).
        num_layers: Number of stacked PerceiverLayer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-width multiplier for each block's feed-forward part.
    """

    def __init__(self, dim: int, num_latents: int, num_layers: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.dim = dim
        # Learned queries, initialized as Gaussian noise scaled by 1/sqrt(dim).
        init = torch.randn(num_latents, dim) / math.sqrt(dim)
        self.latents = nn.Parameter(init)
        blocks = [PerceiverLayer(dim, num_heads, mlp_ratio) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, tokens, token_mask=None):
        """Return the latents after every layer has attended to ``tokens``.

        Args:
            tokens: Batched input sequence; the first dimension is the batch.
            token_mask: Optional mask forwarded unchanged to each layer.
        """
        batch = tokens.shape[0]
        # Share the same learned latents across the batch (expand makes no copy).
        state = self.latents.unsqueeze(0).expand(batch, -1, -1)
        for block in self.layers:
            state = block(state, tokens, token_mask)
        return state

In [None]:
# ============================================================
# EVALUATION MODEL WRAPPER
# ============================================================

class AlignedModelForEval(nn.Module):
    """
    Evaluation wrapper that bundles the trained adapters with frozen encoders.

    Provides unified interface:
        - embed_image(images) -> (B, D)
        - embed_text(texts) -> (B, D_text)
        - embed_audio(waveforms) -> (B, D)
        - generate(modality_feats, modality_type, prompt) -> str

    NOTE(review): image/audio embeddings have the width produced by
    ``projector`` (or by the adapter when no perceiver is used), whereas
    ``embed_text`` returns the pooled output of the text encoder with no
    adapter applied. With the default EvalConfig these look like 1536 vs 384
    dimensions, which cannot be compared by dot product -- confirm whether a
    text-side projection is missing.
    NOTE(review): when ``perceiver`` is None the ``projector`` is never applied
    by this class -- confirm that matches how the checkpoint was trained.
    """

    def __init__(
        self,
        vision_adapter: nn.Module,
        audio_adapter: nn.Module,
        perceiver: Optional[nn.Module],
        projector: nn.Module,
        cfg: "EvalConfig",
    ):
        """Store the trained modules and load the frozen encoders named in ``cfg``.

        Args:
            vision_adapter: Maps vision encoder features to the perceiver width.
            audio_adapter: Maps audio encoder features to the perceiver width.
            perceiver: Resampler applied after the adapters, or None to skip it.
            projector: Applied to the perceiver latents.
            cfg: Evaluation configuration (model names are read from it).
        """
        super().__init__()
        self.cfg = cfg
        self.device = device  # module-level device chosen during notebook setup

        # Adapters (loaded from checkpoint)
        self.vision_adapter = vision_adapter
        self.audio_adapter = audio_adapter
        self.perceiver = perceiver
        self.projector = projector

        # Load frozen encoders
        print("\nüì¶ Loading frozen encoders...")

        # Vision (CLIP)
        self.vision_processor = CLIPImageProcessor.from_pretrained(cfg.vision_model_name)
        self.vision_encoder = self._frozen(CLIPVisionModel.from_pretrained(cfg.vision_model_name))
        print(f"   ‚úì Vision: {cfg.vision_model_name}")

        # Audio (Whisper) -- only the encoder half is kept
        self.audio_processor = WhisperProcessor.from_pretrained(cfg.audio_model_name)
        self.audio_encoder = self._frozen(WhisperModel.from_pretrained(cfg.audio_model_name).encoder)
        print(f"   ‚úì Audio: {cfg.audio_model_name}")

        # Text encoder for retrieval
        self.text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name)
        self.text_encoder = self._frozen(AutoModel.from_pretrained(cfg.text_model_name))
        print(f"   ‚úì Text (retrieval): {cfg.text_model_name}")

        # LLM for generation (loaded on demand)
        self.llm_model = None
        self.llm_tokenizer = None

    def _frozen(self, module: nn.Module) -> nn.Module:
        """Move ``module`` to the eval device, switch to eval mode, disable gradients."""
        module.to(self.device).eval()
        for p in module.parameters():
            p.requires_grad = False
        return module

    def _load_llm_if_needed(self):
        """Lazy load LLM for generation tasks (no-op once it is loaded)."""
        if self.llm_model is None:
            print(f"   Loading LLM: {self.cfg.llm_model_name}...")
            self.llm_tokenizer = AutoTokenizer.from_pretrained(self.cfg.llm_model_name)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.cfg.llm_model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
            self.llm_model.eval()
            print(f"   ‚úì LLM loaded")

    def _project_and_pool(self, adapter: nn.Module, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Shared adapter -> (perceiver -> projector) -> mean-pool path for both modalities."""
        features = features.to(self.device)
        mask = mask.to(self.device)

        tokens = adapter(features)
        if self.perceiver is not None:
            latents = self.perceiver(tokens, mask)
            z = self.projector(latents)
            return z.mean(dim=1)  # Pool latents

        # No perceiver: average only the valid tokens so padding cannot leak
        # into the embedding. With an all-True mask this is the plain mean.
        weights = mask.unsqueeze(-1).to(tokens.dtype)
        return (tokens * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

    @torch.no_grad()
    def encode_image_features(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Project pre-extracted image features to alignment space.

        Args:
            features: Vision encoder token features, batch first.
            mask: Per-token validity mask; truthy entries are real tokens.

        Returns:
            One pooled embedding per batch item.
        """
        return self._project_and_pool(self.vision_adapter, features, mask)

    @torch.no_grad()
    def encode_audio_features(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Project pre-extracted audio features to alignment space.

        Args:
            features: Audio encoder frame features, batch first.
            mask: Per-frame validity mask; truthy entries are real frames.

        Returns:
            One pooled embedding per batch item.
        """
        return self._project_and_pool(self.audio_adapter, features, mask)

    @torch.no_grad()
    def embed_image(self, images: "Union[List[Image.Image], torch.Tensor]") -> torch.Tensor:
        """Encode images to aligned embeddings.

        Args:
            images: Either a list of PIL images (run through the CLIP image
                processor) or an already prepared ``pixel_values`` tensor.
        """
        if isinstance(images, list):
            inputs = self.vision_processor(images=images, return_tensors="pt")
            pixel_values = inputs["pixel_values"]
        else:
            pixel_values = images

        pixel_values = pixel_values.to(self.device)

        # Get features from CLIP
        outputs = self.vision_encoder(pixel_values=pixel_values)
        features = outputs.last_hidden_state  # (B, T, D)

        # Every token is valid for images, so the mask is all True
        mask = torch.ones(features.shape[0], features.shape[1], dtype=torch.bool, device=self.device)
        return self.encode_image_features(features, mask)

    @torch.no_grad()
    def embed_audio(self, waveforms: np.ndarray, sr: int = 16000) -> torch.Tensor:
        """Encode audio waveforms to aligned embeddings.

        Args:
            waveforms: Raw audio passed to the Whisper processor.
            sr: Sampling rate of ``waveforms``.
        """
        # Process with Whisper
        inputs = self.audio_processor(
            waveforms,
            sampling_rate=sr,
            return_tensors="pt",
        )
        input_features = inputs["input_features"].to(self.device)

        # Encode
        outputs = self.audio_encoder(input_features)
        features = outputs.last_hidden_state  # (B, T, D)

        # NOTE(review): all frames are marked valid, including any padding the
        # processor added -- confirm this matches training.
        mask = torch.ones(features.shape[0], features.shape[1], dtype=torch.bool, device=self.device)
        return self.encode_audio_features(features, mask)

    @torch.no_grad()
    def embed_text(self, texts: List[str]) -> torch.Tensor:
        """
        Encode texts with the retrieval text encoder (attention-masked mean pooling).

        Inputs are truncated to 128 tokens. Use ``embed_text_llm`` instead when
        the comparison should be made against LLM input embeddings.
        """
        tokens = self.text_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.text_encoder(**tokens)

        # Mean pooling over non-padding positions
        hidden = outputs.last_hidden_state
        mask = tokens["attention_mask"].unsqueeze(-1).float()
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)

        return pooled

    @torch.no_grad()
    def embed_text_llm(self, texts: List[str]) -> torch.Tensor:
        """Encode texts using LLM embeddings (for comparison with projected features).

        Mean-pools the LLM's input token embeddings over non-padding positions
        and returns the result as float32. Inputs are truncated to 64 tokens.
        """
        self._load_llm_if_needed()

        tokens = self.llm_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.llm_model.device)

        token_embs = self.llm_model.get_input_embeddings()(tokens.input_ids)
        mask = tokens.attention_mask.unsqueeze(-1)
        pooled = (token_embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        return pooled.float()

    @torch.no_grad()
    def generate(self, modality_feats: torch.Tensor, modality_type: str, prompt: str, max_new_tokens: int = 50) -> str:
        """Generate text from modality features.

        The projected modality embeddings are prepended to the embedded prompt
        and fed to the LLM through ``inputs_embeds``.

        Args:
            modality_feats: Encoder features for one sample, batch first.
            modality_type: ``"vision"`` or ``"audio"``.
            prompt: Text appended after the modality embeddings.
            max_new_tokens: Generation length limit.

        Returns:
            The decoded text of the first sequence in the batch.

        Raises:
            ValueError: If ``modality_type`` is not recognized.
        """
        self._load_llm_if_needed()

        # Select adapter
        if modality_type == "vision":
            adapter = self.vision_adapter
        elif modality_type == "audio":
            adapter = self.audio_adapter
        else:
            raise ValueError(f"Unknown modality: {modality_type}")

        modality_feats = modality_feats.to(self.device)
        # Mask follows the actual batch size instead of assuming a single sample
        mask = torch.ones(
            modality_feats.shape[0], modality_feats.shape[1], dtype=torch.bool, device=self.device
        )

        # Project modality
        tokens = adapter(modality_feats)
        if self.perceiver is not None:
            latents = self.perceiver(tokens, mask)
            modality_embeds = self.projector(latents)
        else:
            # NOTE(review): without the projector these embeddings keep the
            # adapter width -- confirm it equals the LLM hidden size.
            modality_embeds = tokens

        # Move to LLM device/dtype; get_input_embeddings() avoids depending on
        # the architecture-specific `.model.embed_tokens` attribute path.
        embed_layer = self.llm_model.get_input_embeddings()
        llm_dev = embed_layer.weight.device
        llm_dtype = embed_layer.weight.dtype
        modality_embeds = modality_embeds.to(device=llm_dev, dtype=llm_dtype)

        # Embed text prompt
        text_inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(llm_dev)
        text_embeds = embed_layer(text_inputs.input_ids)

        # Concatenate
        combined = torch.cat([modality_embeds, text_embeds], dim=1)

        # Generate (sampling is enabled, so outputs vary unless the RNG is seeded)
        out = self.llm_model.generate(
            inputs_embeds=combined,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.llm_tokenizer.eos_token_id,
        )

        return self.llm_tokenizer.batch_decode(out, skip_special_tokens=True)[0]

## 4. Load Trained Model

In [None]:
# ============================================================
# LOAD CHECKPOINT
# ============================================================

def load_model(cfg: EvalConfig) -> AlignedModelForEval:
    """
    Rebuild the trainable modules from ``cfg`` and load their weights.

    Expected checkpoint structure (optionally nested under a 'model' key):
        - 'vision_adapter': state_dict
        - 'audio_adapter': state_dict
        - 'perceiver': state_dict
        - 'projector': state_dict

    Any component missing from the checkpoint keeps its random initialization
    and a warning is printed. If the checkpoint file does not exist at all,
    every component stays random (useful only for testing the pipeline).

    Args:
        cfg: Evaluation configuration (checkpoint path and layer sizes).

    Returns:
        The evaluation wrapper with all modules on the eval device, in eval mode.
    """
    print(f"\nüì¶ Loading checkpoint: {cfg.checkpoint_path}")

    # Initialize architectures
    vision_adapter = ModalityAdapter(cfg.d_vision, cfg.perceiver_dim)
    audio_adapter = ModalityAdapter(cfg.d_audio, cfg.perceiver_dim)

    if cfg.use_perceiver:
        perceiver = PerceiverResampler(
            dim=cfg.perceiver_dim,
            num_latents=cfg.num_latents,
            num_layers=cfg.num_perceiver_layers,
            num_heads=cfg.num_attn_heads,
        )
        projector = nn.Linear(cfg.perceiver_dim, cfg.llm_hidden_size)
    else:
        perceiver = None
        projector = nn.Linear(cfg.perceiver_dim, cfg.d_align)

    # Checkpoint key -> module; a None module is not part of this configuration
    components = {
        "vision_adapter": vision_adapter,
        "audio_adapter": audio_adapter,
        "perceiver": perceiver,
        "projector": projector,
    }

    # Load weights if checkpoint exists
    if Path(cfg.checkpoint_path).exists():
        # SECURITY: torch.load unpickles the file; only open checkpoints from a
        # trusted source (consider weights_only=True if the file holds only tensors).
        checkpoint = torch.load(cfg.checkpoint_path, map_location="cpu")

        # Handle different checkpoint formats
        state = checkpoint['model'] if 'model' in checkpoint else checkpoint

        # Load each component, reporting the ones the checkpoint does not provide
        for name, module in components.items():
            if module is None:
                continue
            if name in state:
                module.load_state_dict(state[name])
                print(f"   ‚úì Loaded {name}")
            else:
                print(f"   WARNING: '{name}' not found in checkpoint; keeping random weights.")

        if 'epoch' in checkpoint:
            print(f"   Checkpoint from epoch: {checkpoint['epoch']}")
    else:
        print(f"   ‚ö†Ô∏è Checkpoint not found! Using random weights.")
        print(f"   This is for testing the evaluation pipeline only.")

    # Move to device
    for module in components.values():
        if module is not None:
            module.to(device).eval()

    # Create wrapper
    model = AlignedModelForEval(
        vision_adapter=vision_adapter,
        audio_adapter=audio_adapter,
        perceiver=perceiver,
        projector=projector,
        cfg=cfg,
    )

    return model


# Build the evaluation model once; later cells share this instance.
# nn.Module.eval() returns the module itself, so the call can be chained.
model = load_model(cfg).eval()
print("\n‚úÖ Model ready for evaluation")

## 5. Core Evaluation Functions

In [None]:
# ============================================================
# RETRIEVAL METRICS
# ============================================================

def l2_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Scale ``x`` to unit Euclidean length along ``dim``.

    Args:
        x: Tensor to normalize.
        dim: Dimension over which the norm is computed.

    Returns:
        A tensor of the same shape as ``x``.
    """
    unit = F.normalize(x, p=2, dim=dim)
    return unit


def recall_at_k(sim_matrix: torch.Tensor, k: int = 1) -> float:
    """
    Compute Recall@K (in percent) from a query-by-target similarity matrix.

    Row ``i`` is assumed to have its ground-truth match in column ``i``.

    Args:
        sim_matrix: Similarities with one row per query and one column per target.
        k: Number of top-scoring targets inspected per query.

    Returns:
        Percentage of queries whose match is among their top-``k`` targets;
        NaN when there are no queries.
    """
    num_queries = sim_matrix.size(0)
    if num_queries == 0:
        return float("nan")

    # topk runs over the target axis, so k must be limited by the number of
    # targets (columns), not by the number of queries (rows).
    k = min(k, sim_matrix.size(-1))
    targets = torch.arange(num_queries, device=sim_matrix.device)
    topk_indices = sim_matrix.topk(k, dim=-1).indices
    hits = (topk_indices == targets.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item() * 100


def compute_retrieval_metrics(
    query_embs: torch.Tensor,
    target_embs: torch.Tensor,
    ks: Union[List[int], Tuple[int, ...]] = (1, 5, 10),
) -> Tuple[Dict[str, float], torch.Tensor]:
    """Compute R@K, MeanRank and MedianRank for paired embeddings.

    Query ``i`` is assumed to match target ``i``. Both sides are cast to
    float32 and L2-normalized, so similarities are cosine similarities.

    Args:
        query_embs: One embedding per row for the queries.
        target_embs: One embedding per row for the retrieval candidates.
        ks: Cut-offs for Recall@K.

    Returns:
        ``(metrics, sim)`` where ``metrics`` holds ``"R@<k>"`` (percent),
        ``"MeanRank"`` and ``"MedianRank"`` (ranks are 0-based: a perfect
        retrieval has rank 0), and ``sim`` is the query-by-target similarity
        matrix.

    Raises:
        ValueError: If either side is empty or the embedding widths differ.
    """
    if query_embs.size(0) == 0 or target_embs.size(0) == 0:
        raise ValueError("compute_retrieval_metrics needs at least one query and one target")
    if query_embs.size(-1) != target_embs.size(-1):
        raise ValueError(
            "Embedding widths differ: queries have "
            f"{query_embs.size(-1)} dims, targets have {target_embs.size(-1)} dims"
        )

    # Same operation as the l2_normalize helper
    q = F.normalize(query_embs.float(), p=2, dim=-1)
    t = F.normalize(target_embs.float(), p=2, dim=-1)
    sim = q @ t.T
    num_queries = sim.size(0)
    targets = torch.arange(num_queries, device=sim.device)

    # Sort each row once; the column position of the true match is its rank,
    # from which every R@K follows without a separate top-k per cut-off.
    sorted_indices = sim.argsort(dim=-1, descending=True)
    ranks = (sorted_indices == targets.unsqueeze(-1)).nonzero(as_tuple=True)[1]

    metrics = {}
    for k in ks:
        # Queries with no matching column have no rank entry and count as misses
        metrics[f"R@{k}"] = (ranks < k).sum().item() / num_queries * 100

    # Rank statistics
    metrics["MeanRank"] = ranks.float().mean().item()
    metrics["MedianRank"] = ranks.median().item()

    return metrics, sim


def compute_bidirectional_retrieval(
    embs_a: torch.Tensor,
    embs_b: torch.Tensor,
    name_a: str = "A",
    name_b: str = "B",
) -> Dict[str, float]:
    """Run retrieval in both directions and merge the prefixed metrics.

    Args:
        embs_a: Embeddings of the first modality (row i pairs with row i of ``embs_b``).
        embs_b: Embeddings of the second modality.
        name_a: Short tag for the first modality, used in metric keys.
        name_b: Short tag for the second modality, used in metric keys.

    Returns:
        A dict with keys such as ``"A2B_R@1"`` and ``"B2A_R@1"``; the A->B
        entries come first.
    """
    directions = (
        (f"{name_a}2{name_b}", embs_a, embs_b),
        (f"{name_b}2{name_a}", embs_b, embs_a),
    )
    merged = {}
    for prefix, queries, gallery in directions:
        scores, _ = compute_retrieval_metrics(queries, gallery)
        for metric_name, value in scores.items():
            merged[f"{prefix}_{metric_name}"] = value
    return merged

In [None]:
# ============================================================
# EMBEDDING QUALITY METRICS
# ============================================================

def compute_alignment_uniformity(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
    t: float = 2.0,
) -> Dict[str, float]:
    """
    Alignment and uniformity scores (Wang & Isola, 2020) for paired embeddings.

    Both inputs are cast to float32 and L2-normalized first.

    - Alignment: mean squared distance between paired rows (lower = better)
    - Uniformity: log of the mean Gaussian kernel over all row pairs within
      one set, with temperature ``t`` (lower = better)

    Returns:
        Dict with ``alignment``, ``uniformity_a``, ``uniformity_b`` and
        ``uniformity_avg``.
    """
    def _uniformity(points: torch.Tensor) -> float:
        # Squared pairwise distances -> Gaussian kernel -> log of the mean
        squared = torch.pdist(points, p=2).pow(2)
        return squared.mul(-t).exp().mean().log().item()

    unit_a = l2_normalize(z_a.float())
    unit_b = l2_normalize(z_b.float())

    # Alignment: how far apart matched pairs sit
    gap = unit_a - unit_b
    align_score = gap.pow(2).sum(dim=-1).mean().item()

    # Uniformity: how evenly each set is spread
    unif_a = _uniformity(unit_a)
    unif_b = _uniformity(unit_b)

    return {
        "alignment": align_score,
        "uniformity_a": unif_a,
        "uniformity_b": unif_b,
        "uniformity_avg": (unif_a + unif_b) / 2,
    }


def compute_gramian_volume(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
    z_c: Optional[torch.Tensor] = None,
    n_samples: int = 256,
) -> Dict[str, float]:
    """
    Gram-determinant statistics for multimodal alignment quality.

    For each of the first ``n_samples`` items, the unit-normalized embeddings
    of the two (or three) modalities form a Gram matrix; the reported value is
    ``|det(G)|``, i.e. the *squared* volume of the parallelotope they span.
    Values near 0 mean the embeddings are nearly parallel (better aligned);
    values near 1 mean they are nearly orthogonal.

    Args:
        z_a: Embeddings of the first modality, one row per item.
        z_b: Embeddings of the second modality, row-aligned with ``z_a``.
        z_c: Optional third modality, row-aligned with ``z_a``.
        n_samples: Maximum number of leading rows to evaluate.

    Returns:
        Dict with ``gramian_mean``, ``gramian_std`` and ``gramian_median``.
    """
    count = min(n_samples, z_a.size(0))
    sources = [z_a, z_b] if z_c is None else [z_a, z_b, z_c]

    # (count, num_modalities, dim): one small set of unit vectors per item
    vectors = torch.stack(
        [F.normalize(z[:count].detach().float(), p=2, dim=-1) for z in sources],
        dim=1,
    )

    # Batched Gram matrices and determinants replace the per-item Python loop
    gram = vectors @ vectors.transpose(1, 2)
    # float64 keeps the statistics as regular (JSON-friendly) doubles
    volumes = torch.det(gram).abs().double().cpu().numpy()

    return {
        "gramian_mean": np.mean(volumes),
        "gramian_std": np.std(volumes),
        "gramian_median": np.median(volumes),
    }

## 6. Image-Text Retrieval Evaluation

In [None]:
# ============================================================
# IMAGE-TEXT RETRIEVAL (COCO / Flickr30K)
# ============================================================

def _extract_image_caption(ex, img_col: str, txt_col: str):
    """Return ``(RGB image, caption)`` for one dataset row, or None if it has no usable image."""
    if img_col == "image_url":
        resp = requests.get(ex[img_col], timeout=5)
        resp.raise_for_status()  # do not try to decode an HTTP error page as an image
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    else:
        img = ex[img_col]
        if not isinstance(img, Image.Image):
            return None
        img = img.convert("RGB")

    # Captions may be a plain string, a list (first entry is used) or a dict
    # carrying the text under "raw".
    caption = ex[txt_col]
    if isinstance(caption, list):
        caption = caption[0] if caption else ""
    if isinstance(caption, dict):
        caption = caption.get("raw", str(caption))

    return img, str(caption)


@torch.no_grad()
def evaluate_image_text_retrieval(
    model: AlignedModelForEval,
    dataset_name: str = "coco",
    max_samples: Optional[int] = 1000,
    batch_size: int = 32,
) -> Dict[str, Any]:
    """
    Evaluate image-text retrieval on COCO or Flickr30K.

    Streams image/caption pairs, embeds them in batches, then computes
    bidirectional retrieval metrics and alignment/uniformity on the first 500
    pairs. Rows that cannot be fetched or decoded are skipped and counted;
    errors raised by the model itself are not swallowed.

    Args:
        model: Evaluation wrapper providing ``embed_image`` and ``embed_text``.
        dataset_name: ``"coco"`` or ``"flickr"`` (case-insensitive).
        max_samples: Maximum number of pairs to collect; None streams the whole split.
        batch_size: Number of pairs embedded per model call.

    Returns:
        Dict with ``metrics`` plus the collected ``img_embs`` and ``txt_embs`` (on CPU).

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
        RuntimeError: If no usable image/caption pair could be collected.
    """
    print(f"\nüñºÔ∏è  Evaluating Image-Text Retrieval on {dataset_name.upper()}...")

    # Load dataset
    dataset_key = dataset_name.lower()
    if dataset_key == "coco":
        try:
            ds = load_dataset("HuggingFaceM4/COCO", split="validation", streaming=True)
            img_col, txt_col = "image", "sentences"
        except Exception as err:
            # Best-effort fallback, but report why the primary source failed
            print(f"   Could not load HuggingFaceM4/COCO ({type(err).__name__}: {err})")
            print("   Using PixMo-Cap as fallback...")
            ds = load_dataset("allenai/pixmo-cap", split="train", streaming=True)
            img_col, txt_col = "image_url", "caption"
    elif dataset_key == "flickr":
        ds = load_dataset("nlphuji/flickr30k", split="test", streaming=True)
        img_col, txt_col = "image", "caption"
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Collect embeddings
    img_chunks, txt_chunks = [], []
    batch_imgs, batch_txts = [], []
    count = 0
    skipped = 0
    last_error = None

    def embed_pending():
        # Embed whatever is queued and reset the queue. Deliberately not
        # wrapped in try/except: a model failure must surface.
        if not batch_imgs:
            return
        img_chunks.append(model.embed_image(batch_imgs).cpu())
        txt_chunks.append(model.embed_text(batch_txts).cpu())
        batch_imgs.clear()
        batch_txts.clear()

    pbar = tqdm(ds, total=max_samples, desc=f"Encoding {dataset_name}")

    for ex in pbar:
        if max_samples is not None and count >= max_samples:
            break

        # Only data access is best-effort; bad rows are skipped and counted
        try:
            pair = _extract_image_caption(ex, img_col, txt_col)
        except Exception as err:
            skipped += 1
            last_error = err
            continue
        if pair is None:
            skipped += 1
            continue

        batch_imgs.append(pair[0])
        batch_txts.append(pair[1])
        count += 1

        # Process batch
        if len(batch_imgs) >= batch_size:
            embed_pending()

    # Process remaining
    embed_pending()

    if skipped:
        detail = f" (last error: {type(last_error).__name__}: {last_error})" if last_error else ""
        print(f"   Skipped {skipped} unusable examples{detail}")

    if not img_chunks:
        raise RuntimeError(
            f"No image/caption pairs could be collected from '{dataset_name}'"
        )

    # Concatenate
    all_img_embs = torch.cat(img_chunks, dim=0)
    all_txt_embs = torch.cat(txt_chunks, dim=0)

    print(f"   Collected {all_img_embs.size(0)} pairs")

    # Compute metrics
    metrics = compute_bidirectional_retrieval(all_txt_embs, all_img_embs, "T", "I")

    # Embedding quality
    quality = compute_alignment_uniformity(all_img_embs[:500], all_txt_embs[:500])
    metrics.update({f"quality_{k}": v for k, v in quality.items()})

    # Print results
    print(f"\n   üìä {dataset_name.upper()} Results:")
    print(f"      Text‚ÜíImage: R@1={metrics['T2I_R@1']:.1f}%, R@5={metrics['T2I_R@5']:.1f}%, R@10={metrics['T2I_R@10']:.1f}%")
    print(f"      Image‚ÜíText: R@1={metrics['I2T_R@1']:.1f}%, R@5={metrics['I2T_R@5']:.1f}%, R@10={metrics['I2T_R@10']:.1f}%")

    return {
        "metrics": metrics,
        "img_embs": all_img_embs,
        "txt_embs": all_txt_embs,
    }

In [None]:
# COCO image<->caption retrieval; sample cap and batch size come from the config
coco_results = evaluate_image_text_retrieval(
    model, "coco", max_samples=cfg.coco_samples, batch_size=cfg.batch_size
)

## 7. ESC-50 Zero-Shot Audio Classification

In [None]:
# ============================================================
# ESC-50 ZERO-SHOT CLASSIFICATION
# ============================================================

# The 50 ESC-50 category names. Underscores are kept here so the names can be
# compared with the dataset's "category" field; the zero-shot prompts replace
# them with spaces.
ESC50_CLASSES = [
    'dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects', 'sheep', 'crow',
    'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds', 
    'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm',
    'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing', 
    'footsteps', 'laughing', 'brushing_teeth', 'snoring', 'drinking_sipping',
    'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening', 
    'washing_machine', 'vacuum_cleaner', 'clock_alarm', 'clock_tick', 'glass_breaking',
    'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine', 
    'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw'
]

@torch.no_grad()
def evaluate_esc50_zeroshot(
    model: AlignedModelForEval,
    max_samples: int = 400,
) -> Dict[str, Any]:
    """
    Zero-shot audio classification on ESC-50.

    Every class name is turned into a text prompt and embedded once. Each audio
    clip is loaded at 16 kHz, embedded, and assigned the class whose prompt
    embedding has the largest dot product with the L2-normalized audio embedding.

    Args:
        model: Evaluation wrapper; its ``embed_text`` and ``embed_audio`` methods
            are the only parts used here.
        max_samples: Upper bound on the number of successfully scored clips.

    Returns:
        A dict with ``accuracy`` (percentage, 0-100) and ``num_samples``. When
        the dataset was loaded it also contains ``predictions`` and
        ``ground_truth``, two parallel lists of class-name strings. If librosa
        is missing or the dataset cannot be loaded, only ``accuracy`` (0.0) and
        ``num_samples`` (0) are returned.
    """
    print(f"\nüîä Evaluating ESC-50 Zero-Shot Classification...")

    if not HAS_LIBROSA:
        print("   ‚ö†Ô∏è librosa not installed. Skipping.")
        return {"accuracy": 0.0, "num_samples": 0}

    # Create text embeddings for each class (done once, reused for every clip)
    prompts = [f"The sound of {c.replace('_', ' ')}." for c in ESC50_CLASSES]
    class_embs = l2_normalize(model.embed_text(prompts))

    # Load ESC-50 as a stream so nothing has to be downloaded up front
    try:
        ds = load_dataset("ashraq/esc50", split="train", streaming=True)
    except Exception as e:
        print(f"   ‚ö†Ô∏è Could not load ESC-50: {e}")
        return {"accuracy": 0.0, "num_samples": 0}

    # Keep the raw encoded bytes so that librosa does the decoding/resampling.
    # The Audio feature lives in the `datasets` package itself (it is not an
    # attribute of the `load_dataset` function).
    try:
        from datasets import Audio

        ds = ds.cast_column("audio", Audio(decode=False))
    except Exception:
        # Fall back to the dataset-level switch for disabling feature decoding.
        try:
            ds = ds.decode(False)
        except Exception as e:
            print(f"   ‚ö†Ô∏è Could not load ESC-50: {e}")
            return {"accuracy": 0.0, "num_samples": 0}

    predictions, ground_truth = [], []
    skipped = 0
    first_error = None

    for ex in tqdm(ds, total=max_samples, desc="ESC-50"):
        if len(predictions) >= max_samples:
            break

        # Best effort per clip: a clip that cannot be decoded or embedded is
        # skipped, but counted so the failure is visible in the report below.
        try:
            wav, _ = librosa.load(io.BytesIO(ex["audio"]["bytes"]), sr=16000)
            category = ex["category"]

            audio_emb = l2_normalize(model.embed_audio(wav, sr=16000))

            # NOTE(review): .item() assumes embed_audio yields exactly one
            # embedding per clip - confirm against AlignedModelForEval.
            pred_idx = (audio_emb @ class_embs.T).argmax(dim=-1).item()
        except Exception as e:
            skipped += 1
            if first_error is None:
                first_error = e
            continue

        predictions.append(ESC50_CLASSES[pred_idx])
        ground_truth.append(category)

    # Plain exact-match accuracy; computed directly so this function does not
    # depend on scikit-learn, which the file treats as optional (HAS_SKLEARN).
    if predictions:
        n_correct = sum(p == g for p, g in zip(predictions, ground_truth))
        accuracy = n_correct / len(predictions) * 100
    else:
        accuracy = 0.0

    print(f"\n   üìä ESC-50 Results:")
    print(f"      Zero-Shot Accuracy: {accuracy:.2f}%")
    print(f"      Samples evaluated: {len(predictions)}")
    if skipped:
        print(f"      Samples skipped: {skipped} (first error: {first_error!r})")

    return {
        "accuracy": accuracy,
        "num_samples": len(predictions),
        "predictions": predictions,
        "ground_truth": ground_truth,
    }

In [None]:
# Zero-shot ESC-50 classification, capped at the sample count configured in cfg.
esc50_results = evaluate_esc50_zeroshot(
    model,
    max_samples=cfg.esc50_samples,
)

## 8. Matryoshka Representation Learning Evaluation

In [None]:
# ============================================================
# MATRYOSHKA EVALUATION
# ============================================================

def evaluate_matryoshka(
    img_embs: torch.Tensor,
    txt_embs: torch.Tensor,
    dims: Tuple[int, ...] = (32, 64, 128, 256, 512),
) -> Dict[int, Dict[str, float]]:
    """
    Measure retrieval quality when embeddings are cut to their leading columns.

    For every requested width that does not exceed the last-axis size of
    ``img_embs``, both embedding sets are sliced to ``[:, :width]`` and passed
    to ``compute_bidirectional_retrieval``. Wider requests are ignored.

    Returns:
        Mapping ``width -> {"T2I_R@1", "T2I_R@5", "I2T_R@1"}`` for each
        evaluated width, in the order the widths appear in ``dims``.
    """
    print(f"\nüìê Evaluating Matryoshka Representations...")

    embedding_width = img_embs.size(-1)
    kept_keys = ("T2I_R@1", "T2I_R@5", "I2T_R@1")
    per_width: Dict[int, Dict[str, float]] = {}

    for width in [d for d in dims if d <= embedding_width]:
        scores = compute_bidirectional_retrieval(
            txt_embs[:, :width], img_embs[:, :width], "T", "I"
        )
        per_width[width] = {key: scores[key] for key in kept_keys}

        t2i, i2t = scores['T2I_R@1'], scores['I2T_R@1']
        print(f"   Dim={width:4d}: T2I R@1={t2i:5.1f}%, I2T R@1={i2t:5.1f}%")

    return per_width


def plot_matryoshka_curve(mrl_results: Dict, save_path: Optional[str] = None):
    """
    Draw Recall@1 against embedding dimension for both retrieval directions.

    ``mrl_results`` maps a dimension to a dict holding ``"T2I_R@1"`` and
    ``"I2T_R@1"``. The x axis is log2-scaled with one tick per dimension.
    The figure is written to ``save_path`` when one is given, then shown.
    """
    ordered_dims = sorted(mrl_results.keys())

    # (result key, matplotlib format string, legend label) per plotted line
    line_specs = (
        ("T2I_R@1", 'b-o', 'Text‚ÜíImage R@1'),
        ("I2T_R@1", 'r-s', 'Image‚ÜíText R@1'),
    )
    # Collect the y-values before any figure is created
    curves = [
        (fmt, label, [mrl_results[d][key] for d in ordered_dims])
        for key, fmt, label in line_specs
    ]

    _, ax = plt.subplots(figsize=(10, 6))
    for fmt, label, recalls in curves:
        ax.plot(ordered_dims, recalls, fmt, label=label, linewidth=2, markersize=8)

    ax.set_xlabel('Embedding Dimension', fontsize=12)
    ax.set_ylabel('Recall@1 (%)', fontsize=12)
    ax.set_title('Matryoshka Representation Learning\nAccuracy vs Embedding Dimension', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    # Scale first, then ticks: ticks are set after the axis scale is chosen
    ax.set_xscale('log', base=2)
    ax.set_xticks(ordered_dims)
    ax.set_xticklabels(ordered_dims)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"   Saved: {save_path}")
    plt.show()

In [None]:
# Matryoshka sweep over the COCO embeddings; skipped when COCO produced nothing.
if coco_results:
    mrl_results = evaluate_matryoshka(
        coco_results["img_embs"], coco_results["txt_embs"], dims=cfg.mrl_dims
    )

    # Only write the figure to disk when saving is enabled in cfg.
    plot_matryoshka_curve(
        mrl_results,
        save_path=None if not cfg.save_results else f"{cfg.results_dir}/matryoshka_curve.png",
    )

## 9. Compare Against Baselines

In [None]:
# ============================================================
# BASELINE COMPARISON
# ============================================================

def create_comparison_table(
    our_results: Dict[str, float],
    baselines: Dict[str, Dict[str, float]],
    metrics: List[str],
    our_name: str = "Ours",
) -> pd.DataFrame:
    """
    Build a table with one row per baseline followed by a row for our model.

    Each row has a ``"Model"`` column plus one column per entry of ``metrics``;
    a metric that a model does not report is shown as ``"-"``.
    """

    def _as_row(label, scores):
        # "Model" goes first so that it becomes the leading column.
        return {"Model": label, **{name: scores.get(name, "-") for name in metrics}}

    table_rows = [_as_row(label, scores) for label, scores in baselines.items()]
    table_rows.append(_as_row(f"üî• {our_name}", our_results))

    return pd.DataFrame(table_rows)


def plot_comparison_chart(
    our_results: Dict[str, float],
    baselines: Dict[str, Dict[str, float]],
    metric: str,
    title: str,
    save_path: Optional[str] = None,
    higher_is_better: bool = True,
):
    """
    Bar chart of one metric across all baselines, with our model as the last bar.

    Baseline bars are steel blue and our bar is coral. A model that does not
    report ``metric`` is drawn with value 0. Each bar is labelled with its value
    to one decimal place. The figure is saved when ``save_path`` is given and
    then shown.
    """
    labels = [*baselines.keys(), "Ours"]
    scores = [entry.get(metric, 0) for entry in baselines.values()]
    scores.append(our_results.get(metric, 0))
    palette = ['steelblue' for _ in baselines] + ['coral']

    _, ax = plt.subplots(figsize=(12, 6))
    drawn = ax.bar(labels, scores, color=palette, edgecolor='black', linewidth=1.2)

    # Value label centred just above the top edge of each bar
    for rect, score in zip(drawn, scores):
        top = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2
        ax.annotate(
            f'{score:.1f}',
            xy=(centre_x, top),
            xytext=(0, 3),
            textcoords="offset points",
            ha='center', va='bottom', fontsize=10, fontweight='bold',
        )

    ax.set_ylabel(metric, fontsize=12)
    ax.set_title(title, fontsize=14)
    # Leave 15% headroom so the value labels stay inside the axes
    ax.set_ylim(0, max(scores) * 1.15)
    plt.xticks(rotation=45, ha='right')

    # Corner hint telling the reader which direction is better
    if higher_is_better:
        hint = "‚Üë Higher is better"
    else:
        hint = "‚Üì Lower is better"
    ax.text(0.02, 0.98, hint, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', style='italic', color='gray')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ============================================================
# CREATE COMPARISON TABLES & CHARTS
# ============================================================

print(f"\n{'=' * 70}")
print("üìä COMPARISON WITH PUBLISHED BASELINES")
print("=" * 70)

# --- COCO Retrieval: table of four recall metrics, chart of T2I R@1 ---
if coco_results:
    our_coco = {
        key: coco_results["metrics"][key]
        for key in ("T2I_R@1", "T2I_R@5", "T2I_R@10", "I2T_R@1", "I2T_R@5", "I2T_R@10")
    }

    print("\nüì∑ COCO Image-Text Retrieval:")
    coco_table = create_comparison_table(
        our_coco, BASELINES["coco_retrieval"], ["T2I_R@1", "T2I_R@5", "I2T_R@1", "I2T_R@5"]
    )
    print(coco_table.to_string(index=False))

    plot_comparison_chart(
        our_coco,
        BASELINES["coco_retrieval"],
        "T2I_R@1",
        "COCO Text‚ÜíImage Retrieval R@1 Comparison",
        save_path=None if not cfg.save_results else f"{cfg.results_dir}/coco_comparison.png",
    )

# --- ESC-50: only reported when a non-zero accuracy was measured ---
if esc50_results and esc50_results.get("accuracy", 0) > 0:
    our_esc50 = dict(accuracy=esc50_results["accuracy"])

    print("\nüîä ESC-50 Zero-Shot Audio Classification:")
    esc50_table = create_comparison_table(our_esc50, BASELINES["esc50_zeroshot"], ["accuracy"])
    print(esc50_table.to_string(index=False))

    plot_comparison_chart(
        our_esc50,
        BASELINES["esc50_zeroshot"],
        "accuracy",
        "ESC-50 Zero-Shot Accuracy Comparison",
        save_path=None if not cfg.save_results else f"{cfg.results_dir}/esc50_comparison.png",
    )

## 10. Visualizations

In [None]:
# ============================================================
# VISUALIZATION FUNCTIONS
# ============================================================

def plot_similarity_heatmap(
    img_embs: torch.Tensor,
    txt_embs: torch.Tensor,
    n_samples: int = 32,
    save_path: Optional[str] = None,
):
    """
    Show the text-by-image similarity matrix of the first ``n_samples`` pairs.

    Both embedding sets are sliced, cast to float and passed through
    ``l2_normalize`` before the matrix product. Rows index texts and columns
    index images. The figure is saved when ``save_path`` is given, then shown.
    """

    def _leading_unit_rows(embs):
        return l2_normalize(embs[:n_samples].float())

    image_rows = _leading_unit_rows(img_embs)
    text_rows = _leading_unit_rows(txt_embs)
    similarity = torch.matmul(text_rows, image_rows.T).cpu().numpy()

    _, ax = plt.subplots(figsize=(10, 8))
    # Diverging colour map centred on zero over the fixed range [-1, 1]
    heatmap_options = dict(
        cmap='RdBu_r', center=0, vmin=-1, vmax=1,
        square=True, cbar_kws={'shrink': 0.8},
    )
    sns.heatmap(similarity, ax=ax, **heatmap_options)

    ax.set_title('Text-Image Similarity Matrix\n(Diagonal should be highest)', fontsize=14)
    ax.set_xlabel('Image Index', fontsize=12)
    ax.set_ylabel('Text Index', fontsize=12)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def plot_tsne_embeddings(
    img_embs: torch.Tensor,
    txt_embs: torch.Tensor,
    n_samples: int = 200,
    save_path: Optional[str] = None,
):
    """
    Plot a joint 2-D t-SNE projection of paired image and text embeddings.

    The first ``n`` rows of each tensor are projected together, where ``n`` is
    the smallest of ``n_samples`` and the two row counts. Images are drawn in
    coral and texts in steel blue; the first 30 pairs are joined by a thin grey
    line. The random seed is read from the module-level ``cfg.seed``.

    Args:
        img_embs: Image embeddings; row ``i`` is paired with row ``i`` of
            ``txt_embs``.
        txt_embs: Text embeddings.
        n_samples: Maximum number of pairs to project.
        save_path: If given, the figure is also written to this path.

    Returns:
        None. Prints a message and returns early when scikit-learn is
        unavailable or fewer than two pairs are present.
    """
    if not HAS_SKLEARN:
        print("   sklearn not available for t-SNE")
        return

    # Use only rows that exist in both sets so every image keeps its text partner.
    n = min(n_samples, img_embs.size(0), txt_embs.size(0))
    if n < 2:
        print("   Not enough paired embeddings for t-SNE")
        return

    # Same conversion as plot_similarity_heatmap: cast to float and move to CPU
    # first, because .numpy() rejects CUDA tensors and bfloat16.
    img_sub = img_embs[:n].detach().float().cpu().numpy()
    txt_sub = txt_embs[:n].detach().float().cpu().numpy()
    combined = np.vstack([img_sub, txt_sub])

    # scikit-learn requires perplexity < number of points; keep the usual 30
    # whenever there are enough points and shrink it only for tiny inputs.
    perplexity = min(30, combined.shape[0] - 1)

    print("   Running t-SNE...")
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=cfg.seed)
    coords = tsne.fit_transform(combined)

    fig, ax = plt.subplots(figsize=(10, 8))
    # Rows were stacked images-then-texts, so split the projection the same way.
    img_coords = coords[:n]
    txt_coords = coords[n:]

    ax.scatter(img_coords[:, 0], img_coords[:, 1], c='coral', label='Image', alpha=0.6, s=50)
    ax.scatter(txt_coords[:, 0], txt_coords[:, 1], c='steelblue', label='Text', alpha=0.6, s=50)

    # Connect paired samples (first 30 only, to keep the plot readable)
    for i in range(min(30, n)):
        ax.plot([img_coords[i, 0], txt_coords[i, 0]],
                [img_coords[i, 1], txt_coords[i, 1]],
                'gray', alpha=0.3, linewidth=0.5)

    ax.set_xlabel('t-SNE Dim 1', fontsize=12)
    ax.set_ylabel('t-SNE Dim 2', fontsize=12)
    ax.set_title('t-SNE of Image & Text Embeddings\n(Lines connect paired samples)', fontsize=14)
    ax.legend(fontsize=11)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Render the similarity heatmap and the t-SNE plot from the COCO embeddings
if coco_results:
    print("\nüìà Creating Visualizations...")

    plot_similarity_heatmap(
        coco_results["img_embs"], coco_results["txt_embs"], n_samples=32,
        save_path=None if not cfg.save_results else f"{cfg.results_dir}/similarity_heatmap.png",
    )

    plot_tsne_embeddings(
        coco_results["img_embs"], coco_results["txt_embs"], n_samples=200,
        save_path=None if not cfg.save_results else f"{cfg.results_dir}/tsne_embeddings.png",
    )

## 11. Final Summary Report

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

def generate_summary(coco_results, mrl_results, esc50_results, cfg):
    """
    Collect the run configuration and all available results into one dict.

    The result has two keys: ``"config"`` (checkpoint path, ``d_align`` and
    ``mrl_dims`` read from ``cfg``) and ``"results"``. A benchmark appears under
    ``"results"`` only when its argument is truthy:

    * ``"coco"``: the COCO metrics without entries whose name starts with
      ``"quality"``.
    * ``"matryoshka"``: ``mrl_results`` itself (not copied).
    * ``"esc50"``: ``accuracy`` and ``num_samples``, each defaulting to 0.
    """
    run_config = {
        "checkpoint": cfg.checkpoint_path,
        "d_align": cfg.d_align,
        "mrl_dims": cfg.mrl_dims,
    }
    collected = {}

    if coco_results:
        retrieval_only = {}
        for name, value in coco_results["metrics"].items():
            if name.startswith("quality"):
                continue
            retrieval_only[name] = value
        collected["coco"] = retrieval_only

    if mrl_results:
        collected["matryoshka"] = mrl_results

    if esc50_results:
        collected["esc50"] = {
            field_name: esc50_results.get(field_name, 0)
            for field_name in ("accuracy", "num_samples")
        }

    return {"config": run_config, "results": collected}


# Build the summary from whichever result objects exist in this session
summary = generate_summary(
    None if 'coco_results' not in dir() else coco_results,
    None if 'mrl_results' not in dir() else mrl_results,
    None if 'esc50_results' not in dir() else esc50_results,
    cfg,
)

print(f"\n{'=' * 70}")
print("üìã FINAL EVALUATION SUMMARY")
print("=" * 70)

# Human-readable recap of every benchmark that produced results
if 'coco_results' in dir() and coco_results:
    m = coco_results["metrics"]
    print(f"\nüì∑ Image-Text Retrieval (COCO):")
    print("   Text‚ÜíImage: " + " | ".join(f"R@{k}={m[f'T2I_R@{k}']:.1f}%" for k in (1, 5, 10)))
    print("   Image‚ÜíText: " + " | ".join(f"R@{k}={m[f'I2T_R@{k}']:.1f}%" for k in (1, 5, 10)))

if 'mrl_results' in dir() and mrl_results:
    print(f"\nüìê Matryoshka (MRL):")
    for dim, r in sorted(mrl_results.items()):
        print(f"   Dim {dim:4d}: T2I R@1={r['T2I_R@1']:5.1f}%")

if 'esc50_results' in dir() and esc50_results.get("accuracy", 0) > 0:
    print(f"\nüîä Audio Classification (ESC-50):")
    print(f"   Zero-Shot Accuracy: {esc50_results['accuracy']:.2f}%")

# Persist the summary as JSON; values json cannot encode natively go through str()
if cfg.save_results:
    report_path = f"{cfg.results_dir}/evaluation_report.json"
    with open(report_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\n‚úÖ Report saved: {report_path}")

print(f"\n{'=' * 70}")
print("‚úÖ Evaluation Complete!")
print("=" * 70)

---

## 📚 Benchmark Reference Table

| Paper | Model | COCO T2I R@1 | COCO I2T R@1 | ESC-50 Acc | Year |
|-------|-------|--------------|--------------|------------|------|
| CLIP | ViT-B/32 | 30.4% | 50.1% | - | 2021 |
| CLIP | ViT-L/14 | 36.5% | 56.3% | - | 2021 |
| BLIP | ViT-B | 39.7% | 59.2% | - | 2022 |
| ImageBind | ViT-H | 34.8% | 53.2% | 66.9% | 2023 |
| CLAP | - | - | - | 82.6% | 2023 |
| **Ours** | - | **?%** | **?%** | **?%** | 2024 |

---