# 03a ‚Äì Multimodal Alignment Retrieval Evaluation

This notebook evaluates **Phase-1 multimodal alignment** checkpoints trained in:

- `02_alig_multi_mlp.ipynb`
- `02_alig_multi_perciever.ipynb`

It computes retrieval metrics for:

- **Vision ‚Üî Text** (PixMo-Cap)
- **Audio ‚Üî Text** (Clotho)

for both model variants:

- **MLP + MRL** (no Perceiver bottleneck)
- **Perceiver + MLP + MRL**

and supports:

- Matryoshka truncation curves over `mm_cfg.mrl_dims`
- Rich metrics (R@K, mean/median rank, mAP@K, NDCG@K)
- Optional logging to **Weights & Biases**.


In [1]:
# ============================================================
# Imports & basic setup
# ============================================================

import os
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, List, Tuple

import math
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset

from datasets import load_dataset

# Transformers encoders
from transformers import (
    CLIPVisionModel,
    CLIPImageProcessor,
    AutoTokenizer,
    AutoModel,
    WhisperModel,
    WhisperProcessor,
)

import matplotlib.pyplot as plt

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# Local project imports
from imports.core import (
    set_seed,
    l2_normalize,
)
from imports.multimodal_alignment_perceiver import (
    MultimodalAlignmentConfig,
    MultimodalAlignmentModel,
)
from imports.in_memory_datasets import (
    InMemoryImageTextDataset,
    collate_in_memory_images,
)

from imports.core import compute_retrieval_metrics
from zipfile import Path as ZipPath



In [None]:
# ============================================================
# Experiment config (mirrors 02_alig_multi_mlp.ipynb)
# ============================================================

@dataclass
class ExperimentConfig:
    # General
    seed: int = 42
    num_epochs_mlp: int = 3
    num_epochs_perceiver: int = 5
    log_every: int = 20

    # Data: if Parquet paths are None, we fall back to HF datasets
    pixmo_parquet_dir: Path = Path.cwd() / "data" / "alignment_offline"
    clotho_parquet_dir: Path = Path.cwd() / "data" / "alignment_offline"
    pixmo_parquet_glob: str = "pixmocap_offline_20000*.parquet"
    clotho_parquet_glob: str = "clotho_development.parquet"

    image_hf_dataset: str = "allenai/pixmo-cap"
    audio_hf_dataset: str = "clotho_v2"

    max_image_samples: Optional[int] = 100   # None for all
    max_audio_samples: Optional[int] = 1000    # None for all
    image_size: Tuple[int, int] = (224, 224)
    sample_val_fraction: float = 0.05  # fraction of samples for validation

    # Training
    base_batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    use_amp: bool = True

    # WandB
    use_wandb: bool = True
    wandb_project: str = "edgeglass_phase1_alignment"
    wandb_entity: Optional[str] = None  # set if you use a team
    wandb_mode: str = "online"  # "offline" or "disabled" etc.
    wandb_run_name: Optional[str] = "phase1_multimodal_eval"

    # Checkpoints (Phase 1)
    root_dir: Path = Path(".").resolve()
    ckpt_root: str = "checkpoints/phase1_multimodal"  # relative to root_dir

    # Alignment / model (weights are mirrored into mm_cfg later)
    mrl_weight: float = 1.0
    clip_weight: float = 0.5


cfg = ExperimentConfig()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
set_seed(cfg.seed)


Device: cuda


In [3]:

ROOT_DIR = cfg.root_dir
CKPT_ROOT = ROOT_DIR / cfg.ckpt_root
MLP_DIR = CKPT_ROOT / "mlp_mrl"
PERCEIVER_DIR = CKPT_ROOT / "perceiver_mrl"

for d in [CKPT_ROOT, MLP_DIR, PERCEIVER_DIR, cfg.pixmo_parquet_dir, cfg.clotho_parquet_dir]:
    d.mkdir(parents=True, exist_ok=True)

print("ROOT_DIR:", ROOT_DIR)
print("Checkpoint root:", CKPT_ROOT)
print("MLP run dir:", MLP_DIR)
print("Perceiver run dir:", PERCEIVER_DIR)

# Matryoshka / alignment config used by the model
mm_cfg = MultimodalAlignmentConfig()
mm_cfg.device = str(device)
mm_cfg.mrl_weight = cfg.mrl_weight
mm_cfg.clip_weight = cfg.clip_weight
print(mm_cfg)


ROOT_DIR: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base
Checkpoint root: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal
MLP run dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/mlp_mrl
Perceiver run dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/perceiver_mrl
MultimodalAlignmentConfig(vision_model_name='openai/clip-vit-base-patch32', audio_model_name='openai/whisper-base', text_model_name='sentence-transformers/all-MiniLM-L6-v2', llm_model_name='Qwen/Qwen2.5-1.5B-Instruct', d_vision=768, d_audio=512, d_text=384, perceiver_dim=512, num_latents=64, num_perceiver_layers=4, num_attn_heads=8, perceiver_mlp_ratio=4.0, perceiver_dropout=0.1, d_align=512, mrl_dims=(64, 128, 256, 512), llm_hidden_size=1536, num_prefix_tokens=64, batch_size=32, learning_rate=0.0001, weight_decay=0.01, num_epochs=10, warm

In [4]:
# ============================================================
# Dataset helpers ‚Äì PixMo (vision-text) and Clotho (audio-text)
# ============================================================

import glob
import torchaudio


def get_pixmo_dataset():
    """Load PixMo-Cap 'train' split from Parquet if available, otherwise from HF."""
    parquet_pattern = cfg.pixmo_parquet_dir / cfg.pixmo_parquet_glob
    matches = sorted(glob.glob(str(parquet_pattern)))
    if matches:
        print(f"Loading PixMo-Cap from Parquet: {matches[-1]}")
        ds_dict = load_dataset("parquet", data_files={"train": matches[-1]})
        ds = ds_dict["train"]
    else:
        print(f"No PixMo Parquet found at pattern {parquet_pattern}, loading from HF: {cfg.image_hf_dataset}")
        ds = load_dataset(cfg.image_hf_dataset, split="train")
    return ds

def build_image_val_loader() -> DataLoader:
    """Create *validation* DataLoader for image-text data using PixMo-Cap."""
    print("\n=== Building Image-Text Datasets (PixMo-Cap) for EVAL ===")
    hf_ds = get_pixmo_dataset()

    if cfg.max_image_samples is not None:
        hf_ds = hf_ds.select(range(min(cfg.max_image_samples, len(hf_ds))))

    # Expect PixMo-Cap columns: image_url + caption
    colnames = hf_ds.column_names
    if "image_url" not in colnames:
        raise ValueError(f"Expected 'image_url' column in PixMo dataset, found: {colnames}")
    if "caption" not in colnames:
        raise ValueError(f"Expected 'caption' column in PixMo dataset, found: {colnames}")

    ds = InMemoryImageTextDataset(
        hf_dataset=hf_ds,
        img_col="image_url",
        txt_col="caption",
        max_samples=None,  # already limited above
        image_size=cfg.image_size,
    )

    # Train/val split via indices (mirror training)
    n = len(ds)
    idx = np.random.permutation(n)
    split = int(n * (1.0 - cfg.sample_val_fraction))
    train_idx, val_idx = idx[:split], idx[split:]

    val_ds = Subset(ds, val_idx)

    num_gpus = max(1, torch.cuda.device_count())
    batch_size = cfg.base_batch_size * num_gpus

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(8, os.cpu_count() or 4),
        collate_fn=collate_in_memory_images,
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    print(f"Image val size: {len(val_ds)} | batch size: {batch_size}")
    return val_loader


image_val_loader = build_image_val_loader()



=== Building Image-Text Datasets (PixMo-Cap) for EVAL ===
Loading PixMo-Cap from Parquet: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_offline/pixmocap_offline_20000.parquet

üì• Pre-loading 100 images into memory...
   Image size: (224, 224)
   Using 32 parallel workers


Loading images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:03<00:00, 25.48it/s]

‚úÖ Loaded 100 images into memory
Image val size: 5 | batch size: 64





In [14]:
# === Audio‚ÄìText Eval Loader (Clotho records, same as training) ===
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import os
import random

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, Subset

# Path to your torch-saved Clotho records file
CLOTHO_RECORDS_PATH = ROOT_DIR / "data" / "alignment_offline" / "clotho_development.parquet"

# Cache resamplers per original sr
_resamplers: Dict[int, torchaudio.transforms.Resample] = {}


def collate_clotho_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate function for ClothoAudioCaptionDataset (eval).

    - Takes variable-length waveforms at original sr (typically 44100)
    - Resamples each to 16 kHz for Whisper
    - Pads to max length in the batch
    """
    audios_16k = []
    lengths_16k = []
    captions = []
    file_names = []

    for b in batch:
        wav = b["audio"]        # Tensor [T_orig]
        sr_orig = int(b["sr"])  # e.g. 44100

        # --- resample to 16 kHz ---
        if sr_orig != 16000:
            if sr_orig not in _resamplers:
                _resamplers[sr_orig] = torchaudio.transforms.Resample(
                    orig_freq=sr_orig,
                    new_freq=16000,
                )
            wav = _resamplers[sr_orig](wav)  # [T_16k]

        audios_16k.append(wav)
        lengths_16k.append(wav.shape[0])
        captions.append(b["caption"])
        file_names.append(b["file_name"])

    # Pad to max length at 16 kHz
    max_len = max(lengths_16k)
    B = len(audios_16k)

    padded = audios_16k[0].new_zeros(B, max_len)  # (B, T_max_16k)
    for i, a in enumerate(audios_16k):
        padded[i, : a.shape[0]] = a

    return {
        "audio": padded,                                    # (B, T_max_16k)
        "audio_lengths": torch.tensor(lengths_16k, dtype=torch.long),
        "sr": 16000,                                        # now fixed to 16k
        "captions": captions,                               # list[str]
        "file_names": file_names,                           # list[str]
    }


class ClothoAudioCaptionDataset(Dataset):
    """
    Dataset reading the torch-saved Clotho records used in training.
    Each item returns a single (audio, caption) pair.
    """

    def __init__(
        self,
        records_path: str,
        pick_random_caption: bool = True,
    ) -> None:
        super().__init__()
        self.records: List[Dict[str, Any]] = torch.load(records_path)
        self.pick_random_caption = pick_random_caption

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.records[idx]
        waveform = item["waveform"]       # torch.FloatTensor [T]
        sr = item["sr"]
        file_name = item["file_name"]
        captions = item["captions"]       # list[str]

        if self.pick_random_caption:
            caption = random.choice(captions)
        else:
            caption = captions[0]

        return {
            "audio": waveform,
            "sr": sr,
            "caption": caption,
            "file_name": file_name,
        }


def build_audio_eval_dataloader() -> Optional[DataLoader]:
    """
    Build evaluation DataLoader for audio‚Äìtext retrieval using the same
    Clotho records format + collate function as used in training.
    """
    if not CLOTHO_RECORDS_PATH.exists():
        print(f"[Eval] Clotho records file not found: {CLOTHO_RECORDS_PATH}")
        print("       Skipping audio alignment evaluation.")
        return None

    print(f"[Eval] Loading Clotho records from: {CLOTHO_RECORDS_PATH}")
    ds = ClothoAudioCaptionDataset(
        records_path=str(CLOTHO_RECORDS_PATH),
        pick_random_caption=False,  # deterministic captions for eval
    )

    n = len(ds)
    print(f"[Eval] Clotho total records: {n}")

    # Optional subsampling for speed
    if MAX_EVAL_SAMPLES is not None and n > MAX_EVAL_SAMPLES:
        idx = torch.arange(n)[:MAX_EVAL_SAMPLES]
        ds = Subset(ds, idx)
        print(f"[Eval] Using first {MAX_EVAL_SAMPLES} Clotho samples for eval")

    num_gpus = max(1, torch.cuda.device_count())
    batch_size = BATCH_SIZE * num_gpus  # match training scaling if desired

    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,                      # eval -> no shuffle
        num_workers=min(8, os.cpu_count() or 4),
        collate_fn=collate_clotho_batch,    # same collate as training
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    print(f"[Eval] Clotho eval batches: {len(loader)} | batch size: {batch_size}")
    return loader


# Build the eval loader (or None if file missing)
audio_val_loader = build_audio_eval_dataloader()

[Eval] Loading Clotho records from: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_offline/clotho_development.parquet
[Eval] Clotho total records: 3839
[Eval] Using first 1000 Clotho samples for eval
[Eval] Clotho eval batches: 16 | batch size: 64


In [15]:
# ============================================================
# Frozen encoders: CLIP (vision), Whisper (audio), MiniLM (text)
# ============================================================

# Vision encoder (CLIP)
vision_model_name = mm_cfg.vision_model_name
print("\nLoading CLIP vision encoder:", vision_model_name)
clip_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
clip_vision = CLIPVisionModel.from_pretrained(vision_model_name)
clip_vision.to(device)
clip_vision.eval()
for p in clip_vision.parameters():
    p.requires_grad = False

# Text encoder (MiniLM)
text_model_name = mm_cfg.text_model_name
print("Loading text encoder:", text_model_name)
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_encoder = AutoModel.from_pretrained(text_model_name)
text_encoder.to(device)
text_encoder.eval()
for p in text_encoder.parameters():
    p.requires_grad = False

# Audio encoder (Whisper)
audio_model_name = mm_cfg.audio_model_name
print("Loading audio encoder:", audio_model_name)
whisper_processor = WhisperProcessor.from_pretrained(audio_model_name)
whisper_encoder = WhisperModel.from_pretrained(audio_model_name).get_encoder()
whisper_encoder.to(device)
whisper_encoder.eval()
for p in whisper_encoder.parameters():
    p.requires_grad = False


# ============================================================
# Feature encoding helpers (match 02_alig_multi_mlp.ipynb)
# ============================================================

def encode_images_to_features(images: List) -> torch.Tensor:
    """Encode a batch of PIL images to CLIP patch features (B, T, 768)."""
    with torch.no_grad():
        inputs = clip_processor(images=images, return_tensors="pt").to(device)
        outputs = clip_vision(**inputs)
        feats = outputs.last_hidden_state  # (B, T, 768)
    return feats


def encode_texts_to_features(texts: List[str]) -> torch.Tensor:
    """Encode a batch of texts to token features (B, L, 384)."""
    with torch.no_grad():
        tokens = text_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)
        outputs = text_encoder(**tokens)
        feats = outputs.last_hidden_state  # (B, L, 384)
    return feats


def encode_audio_to_features(audio_batch: torch.Tensor, sr: int) -> torch.Tensor:
    """Encode a batch of waveforms [B, T] at 16 kHz to Whisper features (B, T_feat, 512)."""
    assert isinstance(audio_batch, torch.Tensor), "audio_batch must be Tensor[B, T]"
    if sr != 16000:
        raise ValueError(f"Expected 16 kHz audio for Whisper, got sr={sr}")

    B, T_max = audio_batch.shape
    features: List[torch.Tensor] = []

    with torch.no_grad():
        for i in range(B):
            wav = audio_batch[i].detach().cpu().float().numpy()  # (T,)

            inputs = whisper_processor(
                wav,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(device)

            out = whisper_encoder(inputs.input_features)          # (1, T_feat_i, 512)
            feat_i = out.last_hidden_state.squeeze(0)            # (T_feat_i, 512)
            features.append(feat_i)

    # Pad along time dimension to get (B, T_feat_max, 512)
    max_T_feat = max(f.shape[0] for f in features)
    hidden_dim = features[0].shape[1]

    feats_batch = torch.zeros(B, max_T_feat, hidden_dim, device=device)
    for i, f in enumerate(features):
        feats_batch[i, : f.shape[0]] = f

    return feats_batch



Loading CLIP vision encoder: openai/clip-vit-base-patch32
Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
Loading audio encoder: openai/whisper-base


In [16]:
# ============================================================
# Checkpoint loading ‚Äì MLP + Perceiver variants
# ============================================================

def load_alignment_model_from_dir(ckpt_dir: Path, tag: str = "best") -> Tuple[MultimodalAlignmentModel, Dict[str, Any]]:
    """Load MultimodalAlignmentModel + config dicts from a checkpoint directory."""
    ckpt_path = ckpt_dir / f"{tag}.pt"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    print(f"\nLoading checkpoint from: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=device, weights_only=False)

    mm_cfg_dict = state.get("mm_config", None)
    exp_cfg_dict = state.get("exp_config", None)

    if mm_cfg_dict is not None:
        mm_cfg_local = MultimodalAlignmentConfig(**mm_cfg_dict)
    else:
        mm_cfg_local = MultimodalAlignmentConfig()

    mm_cfg_local.device = str(device)

    model = MultimodalAlignmentModel(mm_cfg_local).to(device)
    model.load_state_dict(state["model_state"], strict=True)
    model.eval()

    return model, {"mm_config": mm_cfg_local, "exp_config": exp_cfg_dict, "state": state}


# Try to load both variants if available
model_mlp, info_mlp = None, None
model_perceiver, info_perceiver = None, None

if (MLP_DIR / "best.pt").exists():
    model_mlp, info_mlp = load_alignment_model_from_dir(MLP_DIR, tag="best")
else:
    print("‚ö†Ô∏è No MLP checkpoint found (best.pt).")

if (PERCEIVER_DIR / "best.pt").exists():
    model_perceiver, info_perceiver = load_alignment_model_from_dir(PERCEIVER_DIR, tag="best")
else:
    print("‚ö†Ô∏è No Perceiver checkpoint found (best.pt).")

assert model_mlp is not None or model_perceiver is not None, "No checkpoints found ‚Äì cannot run eval."



Loading checkpoint from: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/mlp_mrl/best.pt

Loading checkpoint from: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/perceiver_mrl/best.pt


In [17]:
# ============================================================
# Retrieval metrics (extend core.compute_retrieval_metrics)
# ============================================================

def compute_full_retrieval_metrics(z_q: torch.Tensor, z_k: torch.Tensor) -> Dict[str, float]:
    """Compute recall@K plus rank-based metrics for 1-1 aligned pairs."""
    base = compute_retrieval_metrics(z_q, z_k)  # uses l2_normalize inside

    # Similarity matrix
    z_q_n = l2_normalize(z_q)
    z_k_n = l2_normalize(z_k)
    sims = z_q_n @ z_k_n.T  # (N, N)

    N = sims.size(0)
    ranks = []

    for i in range(N):
        scores = sims[i]
        # Sort in descending order; get rank of the *true* index i
        sorted_idx = torch.argsort(scores, descending=True)
        rank_pos = (sorted_idx == i).nonzero(as_tuple=False).item()  # 0-based
        ranks.append(rank_pos + 1)  # 1-based

    ranks = torch.tensor(ranks, dtype=torch.float32)
    mean_rank = ranks.mean().item()
    median_rank = ranks.median().item()
    mrr = (1.0 / ranks).mean().item()

    def ndcg_at_k(k: int) -> float:
        gains = []
        for r in ranks.tolist():
            if r <= k:
                gains.append(1.0 / math.log2(r + 1.0))
            else:
                gains.append(0.0)
        # IDCG = 1/log2(1+1) = 1
        return float(sum(gains) / len(gains))

    res = dict(base)
    res["mean_rank"] = mean_rank
    res["median_rank"] = median_rank
    res["MRR"] = mrr
    res["NDCG@10"] = ndcg_at_k(10)
    res["NDCG@50"] = ndcg_at_k(50)
    return res


In [18]:
# ============================================================
# Embedding collection helpers
# ============================================================

def get_model_module(model: nn.Module) -> nn.Module:
    """Unwrap DataParallel if needed."""
    return model.module if isinstance(model, nn.DataParallel) else model


def encode_modality(
    model: MultimodalAlignmentModel,
    feats: torch.Tensor,
    mask: Optional[torch.Tensor],
    modality: str,
    use_perceiver: bool,
) -> torch.Tensor:
    """Encode features into aligned space (mirrors training helper)."""
    m = get_model_module(model)

    if modality == "vision":
        adapter = m.vision_adapter
    elif modality == "audio":
        adapter = m.audio_adapter
    elif modality == "text":
        adapter = m.text_adapter
    else:
        raise ValueError(f"Unknown modality: {modality}")

    tokens = adapter(feats)  # (B, T, perceiver_dim)

    if use_perceiver and hasattr(m, "perceiver") and m.perceiver is not None:
        latents = m.perceiver(tokens, mask)  # (B, K, perceiver_dim)
    else:
        # MLP-only: treat tokens as latents and pool across sequence
        latents = tokens  # (B, T, perceiver_dim)

    z = m.alignment_projector(latents)  # (B, d_align)
    return z


@torch.no_grad()
def collect_pair_embeddings(
    model: MultimodalAlignmentModel,
    loader: DataLoader,
    modality_a: str,
    modality_b: str,
    use_perceiver: bool,
    max_batches: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect aligned embeddings for a chosen pair: {vision,audio,text}."""
    model.eval()
    all_a, all_b = [], []

    for b_idx, batch in enumerate(tqdm(loader, desc=f"Embeddings {modality_a}-{modality_b}")):
        # Move tensors to device
        batch_t = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch_t[k] = v.to(device)
            else:
                batch_t[k] = v

        if modality_a == "vision":
            feats_a = encode_images_to_features(batch_t["images"])
            mask_a = None
        elif modality_a == "audio":
            feats_a = encode_audio_to_features(batch_t["audio"], int(batch_t["sr"]))
            mask_a = None
        elif modality_a == "text":
            feats_a = encode_texts_to_features(batch_t["captions"])
            mask_a = None
        else:
            raise ValueError(modality_a)

        if modality_b == "vision":
            feats_b = encode_images_to_features(batch_t["images"])
            mask_b = None
        elif modality_b == "audio":
            feats_b = encode_audio_to_features(batch_t["audio"], int(batch_t["sr"]))
            mask_b = None
        elif modality_b == "text":
            feats_b = encode_texts_to_features(batch_t["captions"])
            mask_b = None
        else:
            raise ValueError(modality_b)

        z_a = encode_modality(model, feats_a, mask_a, modality_a, use_perceiver)
        z_b = encode_modality(model, feats_b, mask_b, modality_b, use_perceiver)

        all_a.append(z_a.cpu())
        all_b.append(z_b.cpu())

        if max_batches is not None and (b_idx + 1) >= max_batches:
            break

    emb_a = torch.cat(all_a, dim=0)
    emb_b = torch.cat(all_b, dim=0)
    print(f"Collected {emb_a.shape[0]} pairs for {modality_a}-{modality_b}")
    return emb_a, emb_b


def truncate_embeddings(z: torch.Tensor, dim: int) -> torch.Tensor:
    """Matryoshka truncation: keep only the first `dim` dimensions."""
    return z[..., :dim]


In [19]:
# ============================================================
# Evaluation loops (R@K, ranks, MRL curves)
# ============================================================

def eval_alignment_pair_mrl(
    model: MultimodalAlignmentModel,
    loader: DataLoader,
    modality_a: str,
    modality_b: str,
    use_perceiver: bool,
    mrl_dims: List[int],
    prefix: str,
) -> Dict[str, float]:
    """Evaluate retrieval for a modality pair over MRL dims."""
    z_a_full, z_b_full = collect_pair_embeddings(
        model=model,
        loader=loader,
        modality_a=modality_a,
        modality_b=modality_b,
        use_perceiver=use_perceiver,
        max_batches=None,
    )

    results: Dict[str, float] = {}

    for d in mrl_dims:
        z_a = truncate_embeddings(z_a_full, d)
        z_b = truncate_embeddings(z_b_full, d)

        # A -> B and B -> A
        metrics_a2b = compute_full_retrieval_metrics(z_a, z_b)
        metrics_b2a = compute_full_retrieval_metrics(z_b, z_a)

        for k in (1, 5, 10, 50):
            results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/R@{k}"] = metrics_a2b.get(f"R@{k}", 0.0)
            results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/R@{k}"] = metrics_b2a.get(f"R@{k}", 0.0)

        results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/mean_rank"] = metrics_a2b["mean_rank"]
        results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/median_rank"] = metrics_a2b["median_rank"]
        results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/MRR"] = metrics_a2b["MRR"]
        results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/NDCG@10"] = metrics_a2b["NDCG@10"]
        results[f"{prefix}/d{d}/{modality_a}_to_{modality_b}/NDCG@50"] = metrics_a2b["NDCG@50"]

        results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/mean_rank"] = metrics_b2a["mean_rank"]
        results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/median_rank"] = metrics_b2a["median_rank"]
        results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/MRR"] = metrics_b2a["MRR"]
        results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/NDCG@10"] = metrics_b2a["NDCG@10"]
        results[f"{prefix}/d{d}/{modality_b}_to_{modality_a}/NDCG@50"] = metrics_b2a["NDCG@50"]

    return results


def run_eval_for_model(
    model: MultimodalAlignmentModel,
    tag: str,
    mrl_dims: Optional[List[int]] = None,
) -> Dict[str, float]:
    """Run full eval (vision-text + audio-text) for a given model variant."""
    if mrl_dims is None:
        # fall back to model's config if available
        mm_cfg_model = get_model_module(model).config if hasattr(get_model_module(model), "config") else mm_cfg
        mrl_dims = list(mm_cfg_model.mrl_dims)

    results_all: Dict[str, float] = {}

    # Vision‚ÄìText (PixMo)
    if image_val_loader is not None:
        print(f"\n[{tag}] Evaluating Vision‚ÄìText alignment...")
        vt = eval_alignment_pair_mrl(
            model=model,
            loader=image_val_loader,
            modality_a="vision",
            modality_b="text",
            use_perceiver=(tag == "perceiver"),  # convention
            mrl_dims=mrl_dims,
            prefix=f"{tag}/vision_text",
        )
        results_all.update(vt)
    else:
        print("No image_val_loader, skipping vision-text eval.")

    # Audio‚ÄìText (Clotho)
    if audio_val_loader is not None:
        print(f"\n[{tag}] Evaluating Audio‚ÄìText alignment...")
        at = eval_alignment_pair_mrl(
            model=model,
            loader=audio_val_loader,
            modality_a="audio",
            modality_b="text",
            use_perceiver=(tag == "perceiver"),  # convention
            mrl_dims=mrl_dims,
            prefix=f"{tag}/audio_text",
        )
        results_all.update(at)
    else:
        print("No audio_val_loader, skipping audio-text eval.")

    return results_all


In [20]:
# ============================================================
# Main: run eval for available checkpoints + (optional) W&B
# ============================================================

use_wandb = cfg.use_wandb and WANDB_AVAILABLE

run = None
if use_wandb:
    run = wandb.init(
        project=cfg.wandb_project,
        entity=cfg.wandb_entity,
        mode=cfg.wandb_mode,
        name=cfg.wandb_run_name,
        config={
            "eval": True,
            "ckpt_root": str(CKPT_ROOT),
            "mlp_dir": str(MLP_DIR),
            "perceiver_dir": str(PERCEIVER_DIR),
        },
    )

all_results: Dict[str, float] = {}

if model_mlp is not None:
    res_mlp = run_eval_for_model(model_mlp, tag="mlp")
    all_results.update(res_mlp)

if model_perceiver is not None:
    res_perc = run_eval_for_model(model_perceiver, tag="perceiver")
    all_results.update(res_perc)

print("\n===== SUMMARY METRICS =====")
for k in sorted(all_results.keys()):
    print(f"{k}: {all_results[k]:.4f}")

if run is not None:
    wandb.log(all_results)
    run.finish()



[mlp] Evaluating Vision‚ÄìText alignment...


Embeddings vision-text: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:01<00:00,  1.94s/it]


Collected 5 pairs for vision-text

[mlp] Evaluating Audio‚ÄìText alignment...


Embeddings audio-text: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [00:15<00:00,  1.04it/s]


Collected 1000 pairs for audio-text

[perceiver] Evaluating Vision‚ÄìText alignment...


Embeddings vision-text: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:02<00:00,  2.32s/it]


Collected 5 pairs for vision-text

[perceiver] Evaluating Audio‚ÄìText alignment...


Embeddings audio-text: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [00:14<00:00,  1.13it/s]


Collected 1000 pairs for audio-text

===== SUMMARY METRICS =====
mlp/audio_text/d128/audio_to_text/MRR: 0.1375
mlp/audio_text/d128/audio_to_text/NDCG@10: 0.1628
mlp/audio_text/d128/audio_to_text/NDCG@50: 0.2335
mlp/audio_text/d128/audio_to_text/R@1: 0.0570
mlp/audio_text/d128/audio_to_text/R@10: 0.3060
mlp/audio_text/d128/audio_to_text/R@5: 0.1970
mlp/audio_text/d128/audio_to_text/R@50: 0.0000
mlp/audio_text/d128/audio_to_text/mean_rank: 70.6780
mlp/audio_text/d128/audio_to_text/median_rank: 29.0000
mlp/audio_text/d128/text_to_audio/MRR: 0.1609
mlp/audio_text/d128/text_to_audio/NDCG@10: 0.1875
mlp/audio_text/d128/text_to_audio/NDCG@50: 0.2662
mlp/audio_text/d128/text_to_audio/R@1: 0.0700
mlp/audio_text/d128/text_to_audio/R@10: 0.3380
mlp/audio_text/d128/text_to_audio/R@5: 0.2390
mlp/audio_text/d128/text_to_audio/R@50: 0.0000
mlp/audio_text/d128/text_to_audio/mean_rank: 58.5580
mlp/audio_text/d128/text_to_audio/median_rank: 22.0000
mlp/audio_text/d256/audio_to_text/MRR: 0.1417
mlp/audio

0,1
mlp/audio_text/d128/audio_to_text/MRR,‚ñÅ
mlp/audio_text/d128/audio_to_text/NDCG@10,‚ñÅ
mlp/audio_text/d128/audio_to_text/NDCG@50,‚ñÅ
mlp/audio_text/d128/audio_to_text/R@1,‚ñÅ
mlp/audio_text/d128/audio_to_text/R@10,‚ñÅ
mlp/audio_text/d128/audio_to_text/R@5,‚ñÅ
mlp/audio_text/d128/audio_to_text/R@50,‚ñÅ
mlp/audio_text/d128/audio_to_text/mean_rank,‚ñÅ
mlp/audio_text/d128/audio_to_text/median_rank,‚ñÅ
mlp/audio_text/d128/text_to_audio/MRR,‚ñÅ

0,1
mlp/audio_text/d128/audio_to_text/MRR,0.13751
mlp/audio_text/d128/audio_to_text/NDCG@10,0.16276
mlp/audio_text/d128/audio_to_text/NDCG@50,0.23348
mlp/audio_text/d128/audio_to_text/R@1,0.057
mlp/audio_text/d128/audio_to_text/R@10,0.306
mlp/audio_text/d128/audio_to_text/R@5,0.197
mlp/audio_text/d128/audio_to_text/R@50,0
mlp/audio_text/d128/audio_to_text/mean_rank,70.678
mlp/audio_text/d128/audio_to_text/median_rank,29
mlp/audio_text/d128/text_to_audio/MRR,0.16094


NameError: name 'Tensor' is not defined

## Diagnostics & Plots

To interpret alignment quality (similar to plots in **Freeze‚ÄëAlign** and related work), we provide:

- **Rank histograms** (already logged to W&B)
- **Rank CDF plots** ‚Äì how often the correct match is in the top‚ÄëK
- **Recall@K curves** ‚Äì R@K vs. K
- **Positive vs. Negative similarity distributions** ‚Äì to see separation of matched vs. unmatched pairs
- **Matryoshka curves** ‚Äì R@K vs. embedding dimension


In [22]:
# === Plotting Utilities ===
from torch import nn, Tensor


def plot_rank_histogram(ranks: Tensor, title: str, max_rank: int = 100):
    ranks_np = ranks.cpu().numpy()
    ranks_np = np.clip(ranks_np, 0, max_rank)

    plt.figure(figsize=(6, 4))
    plt.hist(ranks_np + 1, bins=min(max_rank, 100))
    plt.xlabel('Rank (1‚Äëindexed)')
    plt.ylabel('Frequency (log scale)')
    plt.yscale('log')
    plt.title(title)
    plt.tight_layout()
    plt.show()


def plot_rank_cdf(ranks: Tensor, title: str, max_k: int = 50):
    ranks_np = ranks.cpu().numpy() + 1  # 1‚Äëindexed
    ks = np.arange(1, max_k + 1)
    cdf = [(ranks_np <= k).mean() * 100.0 for k in ks]

    plt.figure(figsize=(6, 4))
    plt.plot(ks, cdf, marker='o')
    plt.xlabel('K')
    plt.ylabel('P(rank ‚â§ K) [%]')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def plot_recall_curve(results: Dict[str, float], base_prefix: str, direction: str):
    ks = [1, 5, 10, 50]
    vals = [results.get(f'{base_prefix}/{direction}/R@{k}', np.nan) for k in ks]

    plt.figure(figsize=(6, 4))
    plt.plot(ks, vals, marker='o')
    plt.xlabel('K')
    plt.ylabel('Recall@K [%]')
    plt.title(f'Recall Curve: {direction}')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def plot_similarity_distributions(q: Tensor, k: Tensor, title: str, num_neg: int = 1024):
    q = l2_normalize(q)
    k = l2_normalize(k)

    with torch.no_grad():
        sims = q @ k.t()  # (N, N)
        N = sims.size(0)
        pos = sims.diag().cpu().numpy()

        # Sample a subset of negatives for visualization
        mask = torch.eye(N, device=sims.device).bool()
        neg = sims.masked_fill(mask, float('-inf'))
        neg_vals = neg[neg > -1e9].cpu().numpy()
        if len(neg_vals) > num_neg:
            idx = np.random.choice(len(neg_vals), size=num_neg, replace=False)
            neg_vals = neg_vals[idx]

    plt.figure(figsize=(6, 4))
    plt.hist(neg_vals, bins=50, alpha=0.6, label='Negatives')
    plt.hist(pos, bins=50, alpha=0.6, label='Positives')
    plt.xlabel('Cosine similarity')
    plt.ylabel('Count')
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_mrl_curves(results: Dict[str, float], base_prefix: str, direction: str):
    """
    Plot Matryoshka curves (R@K vs embedding dim) for a given direction.

    Expects keys of the form:
      {base_prefix}/mrl_dim_{d}/{direction}/R@K

    Example:
      vision_text/image_text/mrl_dim_64/text_to_image/R@1
    """
    dims = []

    for key in results.keys():
        # We only care about keys that:
        #  - start with base_prefix
        #  - contain '/mrl_dim_'
        #  - end with '/{direction}/R@1'
        if not key.startswith(base_prefix):
            continue
        if '/mrl_dim_' not in key:
            continue
        if not key.endswith(f'{direction}/R@1'):
            continue

        parts = key.split('/')
        # Find the segment like "mrl_dim_64"
        dim_part = None
        for p in parts:
            if p.startswith('mrl_dim_'):
                dim_part = p
                break

        if dim_part is None:
            continue

        try:
            d = int(dim_part.split('_')[-1])
            dims.append(d)
        except ValueError:
            # In case something weird slips through
            continue

    dims = sorted(set(dims))
    if not dims:
        print('No Matryoshka results found for', direction)
        return

    r1_vals, r5_vals, r10_vals = [], [], []
    for d in dims:
        r1_vals.append(results.get(f'{base_prefix}/mrl_dim_{d}/{direction}/R@1', np.nan))
        r5_vals.append(results.get(f'{base_prefix}/mrl_dim_{d}/{direction}/R@5', np.nan))
        r10_vals.append(results.get(f'{base_prefix}/mrl_dim_{d}/{direction}/R@10', np.nan))

    plt.figure(figsize=(6, 4))
    plt.plot(dims, r1_vals, marker='o', label='R@1')
    plt.plot(dims, r5_vals, marker='o', label='R@5')
    plt.plot(dims, r10_vals, marker='o', label='R@10')
    plt.xlabel('Matryoshka dimension (d)')
    plt.ylabel('Recall [%]')
    plt.title(f'Matryoshka Retrieval vs Dim ({direction})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


In [None]:
# Run eval (Phase-1 vision‚Äìtext only)
image_text_results, (img_emb, txt_emb, ranks_t2i, ranks_i2t) = eval_image_text_retrieval(
    aligned_model,
    aligned_cfg,
    pixmo_eval_loader,
    prefix=f"{model_kind}/image_text",
)


In [None]:
# === Visualize Image‚ÄìText Alignment Diagnostics ===

base_prefix = f'{"multimodal"}/image_text'

# Rank histograms & CDFs
plot_rank_histogram(ranks_t2i, title='Text‚ÜíImage Rank Histogram')
plot_rank_histogram(ranks_i2t, title='Image‚ÜíText Rank Histogram')
plot_rank_cdf(ranks_t2i, title='Text‚ÜíImage Rank CDF')
plot_rank_cdf(ranks_i2t, title='Image‚ÜíText Rank CDF')

# Recall curves
plot_recall_curve(image_text_results, base_prefix, direction='text_to_image')
plot_recall_curve(image_text_results, base_prefix, direction='image_to_text')

# Similarity distributions
plot_similarity_distributions(txt_emb, img_emb, title='Text‚ÄìImage Similarity Distributions (Aligned Space)')

# Matryoshka curves (if available)
base_prefix = f'{model_kind}/image_text'
plot_mrl_curves(image_text_results, base_prefix, direction='text_to_image')
plot_mrl_curves(image_text_results, base_prefix, direction='image_to_text')

NameError: name 'ranks_t2i' is not defined

: 