# 03a – Alignment Retrieval Evaluation

This notebook evaluates **Phase-1 alignment checkpoints** for retrieval:

- Vision ↔ Text retrieval (image→text, text→image)
- (Optional) Audio ↔ Text and Vision ↔ Audio retrieval for multimodal Perceiver models
- **Matryoshka** multi-scale evaluation (performance vs. truncation dimensionality)
- Comparison between **pre-alignment** (frozen encoders) and **post-alignment** (projectors / Perceiver)

All metrics and plots are logged to **Weights & Biases** for easy comparison across:

- Different encoder pairs (e.g., DINOv2 + MiniLM vs CLIP + MiniLM)
- Different alignment heads (MLP-only vs Perceiver+MLP+MRL)
- Different Matryoshka scales.


## Evaluation Plan & Metrics

We follow common evaluation practices from **Freeze-Align, ImageBind, Matryoshka Multimodal Models, and Unified-IO 2**:

### 1. Retrieval Metrics
For each modality pair (e.g., image‚Üîtext, audio‚Üîtext):

- **Recall@K** for K ∈ {1, 5, 10, 50}
  - Image→Text: given an image embedding, rank all texts.
  - Text→Image: given a caption embedding, rank all images.
- **Mean Rank (MR)** and **Median Rank (MedR)**
- **Mean Average Precision @K (mAP@K)** with K ‚àà {10, 50}
- **NDCG@K** for ranking quality (K ‚àà {10, 50})

These mirror the metrics used in CLIP- and Freeze-Align-style image-text retrieval benchmarks.
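
As a sanity check on these definitions, here is a minimal NumPy sketch, independent of the notebook's `compute_full_retrieval_metrics`, that derives every metric from the 1-indexed rank r of the true match. The 3×3 similarity matrix is a toy input made up for illustration. With a single relevant item per query, AP@K reduces to 1/r and NDCG@K to 1/log2(r + 1) whenever r ≤ K, and both are 0 otherwise.

```python
import numpy as np


def retrieval_metrics_from_sims(sims: np.ndarray, ks=(1, 5, 10), map_ks=(10,), ndcg_ks=(10,)) -> dict:
    """Metrics for one-to-one retrieval where query i should match key i.

    sims[i, j] is the similarity between query i and key j.
    """
    n = sims.shape[0]
    # Descending sort: order[i] lists key indices from most to least similar.
    order = np.argsort(-sims, axis=1, kind="stable")
    # 0-indexed position of the true key (index i) in row i, then 1-indexed rank.
    ranks = np.argmax(order == np.arange(n)[:, None], axis=1)
    r = ranks + 1

    out = {}
    for k in ks:
        out[f"R@{k}"] = 100.0 * float(np.mean(r <= k))
    out["mean_rank"] = float(np.mean(r))
    out["median_rank"] = float(np.median(r))
    # One relevant item per query: AP@K = 1/r if r <= K else 0.
    for k in map_ks:
        out[f"mAP@{k}"] = float(np.mean(np.where(r <= k, 1.0 / r, 0.0)))
    # One relevant item per query: IDCG = 1, DCG@K = 1/log2(r + 1) if r <= K else 0.
    for k in ndcg_ks:
        out[f"NDCG@{k}"] = float(np.mean(np.where(r <= k, 1.0 / np.log2(r + 1), 0.0)))
    return out


# Toy matrix: the true key lands at rank 1, 2 and 3 for queries 0, 1 and 2.
sims = np.array([
    [0.9, 0.1, 0.0],
    [0.8, 0.5, 0.1],
    [0.7, 0.6, 0.2],
])
m = retrieval_metrics_from_sims(sims)
print(m)
```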

### 2. Matryoshka Multi-Scale Evaluation
Given Matryoshka dimensions `mrl_dims = [d1, d2, ..., dN]` used during training:

- Compute retrieval metrics (R@K, mAP@K) at each scale by truncating embeddings to their first `d_i` dims and re-normalizing them.
- This lets us draw **performance vs. dimension** curves (e.g., R@1 vs. dim), similar to Matryoshka.
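
Truncation alone changes vector norms, so cosine similarity has to be recomputed on re-normalized prefixes. Below is a minimal NumPy sketch of the R@1-vs-dimension sweep, assuming `q[i]` matches `k[i]`. The helper names `l2n` and `recall_at_1_per_dim` are hypothetical.

```python
import numpy as np


def l2n(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise L2 normalization; all-zero rows stay zero."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def recall_at_1_per_dim(q: np.ndarray, k: np.ndarray, dims) -> dict:
    """R@1 (in %) after truncating both sides to their first d dims and re-normalizing."""
    out = {}
    for d in sorted(dims):
        qd, kd = l2n(q[:, :d]), l2n(k[:, :d])  # re-normalize AFTER truncation
        sims = qd @ kd.T
        hits = np.argmax(sims, axis=1) == np.arange(q.shape[0])
        out[d] = 100.0 * float(hits.mean())
    return out
```

The notebook's `compute_full_retrieval_metrics` already calls `l2_normalize` on the sliced tensors, which plays the same role as `l2n` here.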

### 3. Baseline vs. Aligned Comparison

If we can access frozen encoder outputs (without alignment projectors):
- Evaluate retrieval in the **original encoder spaces** (e.g., DINOv2 CLS vs MiniLM sentence embedding).
- Evaluate retrieval in the **aligned space** (after projectors / Perceiver).
- Log relative gains (ΔR@K, ΔmAP, ΔNDCG) for quick comparison.
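
Relative gains can be computed key by key once the baseline and aligned runs produce flat dicts of scalar metrics with matching names (for example `text_to_image/R@1`). The helper below is hypothetical. For mean and median rank, a negative delta is an improvement.

```python
def relative_gains(baseline: dict, aligned: dict) -> dict:
    """Absolute (aligned - baseline) and relative (%) change for every shared metric key."""
    lower_is_better = ("mean_rank", "median_rank")
    out = {}
    for key in sorted(baseline.keys() & aligned.keys()):
        b, a = float(baseline[key]), float(aligned[key])
        delta = a - b
        # Relative change is undefined when the baseline value is exactly zero.
        rel = 100.0 * delta / abs(b) if b != 0 else float("nan")
        out[key] = {
            "delta": delta,
            "rel_pct": rel,
            "lower_is_better": key.endswith(lower_is_better),
        }
    return out
```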

### 4. Diagnostics & Visualizations

- Histograms of **positive ranks** (position of the true match in the ranked list).
- Distributions of **intra-modal** vs **cross-modal** cosine similarities.
- Optional: t-SNE / UMAP scatter plot of aligned embeddings colored by class / concept (if labels exist).
- Aggregated tables for quick copy-paste into the report.
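
The cosine-similarity groups compared in these diagnostics can be split out as shown below. This NumPy sketch assumes `img[i]` and `txt[i]` are a matched pair and that no embedding is all zeros. The function name is hypothetical.

```python
import numpy as np


def similarity_distributions(img: np.ndarray, txt: np.ndarray) -> dict:
    """Split cosine similarities into matched, mismatched and intra-modal groups."""
    img = img / np.linalg.norm(img, axis=1, keepdims=True)
    txt = txt / np.linalg.norm(txt, axis=1, keepdims=True)
    n = img.shape[0]
    off_diag = ~np.eye(n, dtype=bool)

    cross = img @ txt.T
    return {
        "cross_pos": np.diag(cross),           # matched image-text pairs
        "cross_neg": cross[off_diag],          # mismatched image-text pairs
        "intra_img": (img @ img.T)[off_diag],  # image-image, excluding self-similarity
        "intra_txt": (txt @ txt.T)[off_diag],  # text-text, excluding self-similarity
    }
```

Each returned array can be passed to `plt.hist` or wrapped in `wandb.Histogram`. In a well-aligned space, `cross_pos` should typically sit clearly above `cross_neg`.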


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# === Imports & Environment Setup ===
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import math
import json

import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from tqdm import tqdm

import wandb

from imports.core import AlignmentConfig, VisionTextAligner, l2_normalize
from imports.multimodal_alignment_perceiver import MultimodalAlignmentModel, MultimodalAlignmentConfig

from datasets import load_dataset

from imports.in_memory_datasets import (
    InMemoryImageTextDataset,
    collate_in_memory_images,
)

from imports.train import load_checkpoint as load_vt_checkpoint  # Phase-1 vision-text loader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)




Using device: cuda


In [3]:
# === Config: Paths, Checkpoints, and WandB ===
from dataclasses import asdict

# Root of your EdgeGlass / alignment project
ROOT_DIR = Path(os.environ.get('EDGE_GLASS_ROOT', '.')).resolve()

# Directory where Phase-1 alignment checkpoints are saved
# Example pattern (adjust as needed):
#   ROOT_DIR / 'checkpoints' / 'phase1' / 'vision_text' / 'minilm_siglip' / 'best.pt'
CHECKPOINT_PATH = ROOT_DIR / "checkpoints" / "phase1" / "vision_text" / "best.pt"

# If you also want to evaluate the multimodal Perceiver alignment model,
# set this path as well (optional).
CHECKPOINT_MLP_MULTI = ROOT_DIR / "checkpoints" / "phase1_multimodal" / "mlp_mrl" / "best.pt"
CHECKPOINT_PERCEIVER_MULTI = ROOT_DIR / "checkpoints" / "phase1_multimodal" / "perceiver_mrl" / "best.pt"


# Dataset settings (evaluation split)
# We reuse the same loader utilities as training; you can also swap in
# your Parquet-based feature datasets if desired.
DATASET_NAME = 'pixmo_cap'  # or whatever you used in training
EVAL_SPLIT = 'val'          # or 'test' if you created a test split
MAX_EVAL_SAMPLES = 500    # cap for quick eval; set None for full

# Dataloader settings
BATCH_SIZE = 64
NUM_WORKERS = 2

# WandB config
WANDB_PROJECT = 'edgeglass_phase1_alignment'
WANDB_ENTITY = None  # set your entity if needed
RUN_NAME = 'phase1_alignment_retrieval_eval'

wandb_config = {
    'checkpoint_path': str(CHECKPOINT_PATH),
    'perceiver_checkpoint_path': str(CHECKPOINT_PERCEIVER_MULTI),
    'mlp_checkpoint_path': str(CHECKPOINT_MLP_MULTI),
    'dataset_name': DATASET_NAME,
    'eval_split': EVAL_SPLIT,
    'max_eval_samples': MAX_EVAL_SAMPLES,
    'batch_size': BATCH_SIZE,
}

wandb_run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name=RUN_NAME,
    job_type='alignment_eval',
    config=wandb_config,
)


[34m[1mwandb[0m: Currently logged in as: [33mvedaangchopra[0m ([33mvedaangchopra_gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
from datasets import Dataset as HFDataset  # at top with other imports

# Path to the Parquet file that contains the image data + captions
PIXMO_PARQUET_PATH = ROOT_DIR / "data" / "alignment_offline" / "pixmocap_offline_20000.parquet"

# Column names inside the Parquet
PIXMO_IMAGE_COL = "image_bytes"   # or "image_path" or whatever you used
PIXMO_TEXT_COL = "caption"


In [5]:
def build_eval_dataloader() -> DataLoader:
    """
    Build evaluation dataloader using the local PixMo-Cap parquet file,
    matching the in-memory setup used during training.
    """
    assert PIXMO_PARQUET_PATH.exists(), f"PixMo parquet not found at {PIXMO_PARQUET_PATH}"

    print(f"[Eval] Loading PixMo-Cap from parquet: {PIXMO_PARQUET_PATH}")
    pixmo_local = load_dataset(
        "parquet",
        data_files={"train": str(PIXMO_PARQUET_PATH)},
    )["train"]

    print("[Eval] Columns:", pixmo_local.column_names)

    if "split" in pixmo_local.column_names:
        before = len(pixmo_local)
        pixmo_local = pixmo_local.filter(lambda ex: ex["split"] == EVAL_SPLIT)
        print(f"[Eval] Filtered by split='{EVAL_SPLIT}': {before} -> {len(pixmo_local)} samples")

    if MAX_EVAL_SAMPLES is not None:
        n = min(MAX_EVAL_SAMPLES, len(pixmo_local))
        pixmo_local = pixmo_local.select(range(n))
        print(f"[Eval] Capped eval samples to {len(pixmo_local)}")

    print("\n[Eval] Example row preview:")
    ex0 = pixmo_local[0]
    print({k: str(v)[:80] for k, v in ex0.items()})

    dataset = InMemoryImageTextDataset(
        hf_dataset=pixmo_local,
        img_col=PIXMO_IMAGE_COL,
        txt_col=PIXMO_TEXT_COL,
        max_samples=None,
        image_size=(224, 224),
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_in_memory_images,
        pin_memory=True,
    )

    print(f"[Eval] Batches per epoch: {len(dataloader)}")
    return dataloader


eval_loader = build_eval_dataloader()
print("Eval batches:", len(eval_loader))


[Eval] Loading PixMo-Cap from parquet: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_offline/pixmocap_offline_20000.parquet
[Eval] Columns: ['image_bytes', 'caption', 'image_url', 'sample_id']
[Eval] Capped eval samples to 500

[Eval] Example row preview:
{'image_bytes': "b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x01\\x01\\x00\\x00\\x01\\x00\\x01\\x00\\x00\\xff\\xe2\\x", 'caption': 'This aerial photograph showcases a meticulously organized array of travel essent', 'image_url': 'https://i.redd.it/wbibz0yne60c1.jpg', 'sample_id': 'pixmo_0004864'}

📥 Pre-loading 500 images into memory...
   Image size: (224, 224)
   Using 32 parallel workers


Loading images: 100%|██████████| 500/500 [00:00<00:00, 1957.36it/s]


✅ Loaded 500 images into memory
   ⚠️  500 images failed to load (using fallback)
[Eval] Batches per epoch: 8
Eval batches: 8
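
The loader output above reports that all 500 images fell back, while the row preview shows `image_bytes` starting with `\xff\xd8\xff`, the JPEG signature. This notebook does not show whether `InMemoryImageTextDataset` expects raw bytes, a path, or a decoded image, so a quick check of what the column really holds may help narrow this down. The helper below is hypothetical and relies only on the standard JPEG and PNG magic numbers. The `{"bytes": ..., "path": ...}` branch assumes the layout of an undecoded Hugging Face `Image` feature.

```python
from typing import Any


def sniff_image_value(value: Any) -> str:
    """Report what an image-column cell actually holds (hypothetical diagnostic)."""
    # Assumption: undecoded HF `Image` cells look like {"bytes": ..., "path": ...}.
    if isinstance(value, dict) and "bytes" in value:
        value = value["bytes"]
    if isinstance(value, (bytes, bytearray, memoryview)):
        head = bytes(value[:8])
        if head.startswith(b"\xff\xd8\xff"):  # JPEG SOI marker
            return "jpeg-bytes"
        if head.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG signature
            return "png-bytes"
        return "unknown-bytes"
    if isinstance(value, str):
        return "url" if value.startswith(("http://", "https://")) else "path-or-string"
    return type(value).__name__
```

If a cell comes back as `jpeg-bytes`, it can usually be decoded with Pillow via `Image.open(io.BytesIO(data)).convert("RGB")`.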


In [6]:
def load_aligned_vision_text_model(checkpoint_path: Path) -> Tuple[VisionTextAligner, AlignmentConfig]:
    """
    Load Phase-1 VisionTextAligner using train.py's checkpoint format.
    """
    # IMPORTANT: mirror the config you used during training
    cfg = AlignmentConfig(
        vision_model_name="openai/clip-vit-base-patch32",
        text_model_name="sentence-transformers/all-MiniLM-L6-v2",
        d_align=512,
        mrl_dims=[64, 128, 256, 512],
        device=str(device),
    )
    model = VisionTextAligner(cfg).to(device)

    # This populates vision_adapter and text_adapter from checkpoint
    load_vt_checkpoint(model, str(checkpoint_path))

    model.eval()
    return model, cfg


vt_model, vt_cfg = load_aligned_vision_text_model(CHECKPOINT_PATH)
print("Loaded VisionTextAligner with d_align =", vt_cfg.d_align)
print("Matryoshka dims:", vt_cfg.mrl_dims)


[VisionEncoder] Loaded openai/clip-vit-base-patch32, hidden_size=768
[TextEncoder] Loaded sentence-transformers/all-MiniLM-L6-v2, hidden_size=384
[VisionTextAligner] d_vision=768, d_text=384, d_align=512


FileNotFoundError: [Errno 2] No such file or directory: '/storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1/vision_text/best.pt'
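
The cell above fails because nothing exists at `CHECKPOINT_PATH`. Before hard-coding a path, listing the checkpoints that are actually present can help. Below is a minimal standard-library sketch. `find_checkpoints` is hypothetical and assumes only that checkpoints live somewhere under `ROOT_DIR / "checkpoints"`.

```python
from pathlib import Path
from typing import List


def find_checkpoints(root: Path, filename: str = "best.pt") -> List[Path]:
    """List every `filename` under `root/checkpoints`, newest first."""
    ckpt_dir = Path(root) / "checkpoints"
    if not ckpt_dir.is_dir():
        return []
    hits = [p for p in ckpt_dir.rglob(filename) if p.is_file()]
    # Sort by modification time so the most recent run comes first.
    return sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)
```

For example, `for p in find_checkpoints(ROOT_DIR): print(p)`.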

In [None]:
# === Helper: Exhaustive Retrieval Metrics ===

def compute_full_retrieval_metrics(
    q: Tensor,
    k: Tensor,
    ks: Tuple[int, ...] = (1, 5, 10, 50),
    map_ks: Tuple[int, ...] = (10, 50),
    ndcg_ks: Tuple[int, ...] = (10, 50),
) -> Dict[str, float]:
    """Compute rich retrieval metrics for one‚Äëto‚Äëone matching.

    Assumes q[i] should match k[i]. This is the typical setting for
    image‚Äìcaption retrieval on captioned datasets.
    """
    q = l2_normalize(q)
    k = l2_normalize(k)

    sims = q @ k.t()  # (N, N)
    N = sims.size(0)
    targets = torch.arange(N, device=sims.device)

    # Sort in descending similarity
    _, indices = sims.sort(dim=1, descending=True)

    # Position of the correct item for each query
    # ranks[i] is the 0-indexed rank of the correct match
    ranks = (indices == targets.unsqueeze(1)).nonzero(as_tuple=False)[:, 1]

    metrics: Dict[str, float] = {}

    # Recall@K
    for k_val in ks:
        hit_rate = (ranks < k_val).float().mean().item()
        metrics[f'R@{k_val}'] = hit_rate * 100.0

    # Rank statistics (1-indexed for readability)
    metrics['mean_rank'] = (ranks.float() + 1).mean().item()
    metrics['median_rank'] = (ranks.float() + 1).median().item()

    # mAP@K: one relevant item per query, so AP@K = 1 / (rank + 1) if rank < K else 0 (rank is 0-indexed)
    for k_val in map_ks:
        ap = torch.where(ranks < k_val, 1.0 / (ranks.float() + 1.0), torch.zeros_like(ranks, dtype=torch.float))
        metrics[f'mAP@{k_val}'] = ap.mean().item()

    # NDCG@K (single relevant item per query)
    # DCG = 1 / log2(rank + 2) if rank < K else 0; IDCG = 1
    for k_val in ndcg_ks:
        gains = torch.where(
            ranks < k_val,
            1.0 / torch.log2(ranks.float() + 2.0),
            torch.zeros_like(ranks, dtype=torch.float),
        )
        metrics[f'NDCG@{k_val}'] = gains.mean().item()

    # Also return raw ranks for diagnostic plots
    metrics['ranks_tensor'] = ranks
    return metrics
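
Sorting each row is not strictly necessary to get ranks: the 0-indexed rank of the positive equals the number of keys that score strictly higher than it. The NumPy sketch below uses that counting approach. In PyTorch, the same idea would be `(sims > sims.diag().unsqueeze(1)).sum(dim=1)`. Unlike a full sort, counting also fixes a tie convention: a key tied with the positive does not push it down.

```python
import numpy as np


def ranks_by_counting(sims: np.ndarray) -> np.ndarray:
    """0-indexed rank of the true key for each query, without sorting.

    rank[i] = number of keys scored strictly higher than the positive key i.
    Ties with the positive are resolved optimistically (the positive wins).
    """
    pos = np.diag(sims)[:, None]     # (N, 1) positive-pair scores
    return (sims > pos).sum(axis=1)  # (N,)
```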


In [None]:
# === Vision–Text Retrieval Evaluation ===

@torch.no_grad()
def collect_vision_text_embeddings(
    model: VisionTextAligner,
    loader: DataLoader,
    use_text_as_query: bool = True,
    max_batches: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Collect aligned embeddings for image–text pairs in the aligned space.
    """
    model = model.to(device)
    model.eval()

    all_img = []
    all_txt = []

    for batch_idx, batch in enumerate(tqdm(loader, desc="Collecting embeddings")):
        images = batch["images"]
        texts = batch["captions"]

        z_img = model.encode_vision(images)  # (B, d_align)
        z_txt = model.encode_text(texts)     # (B, d_align)

        all_img.append(z_img.cpu())
        all_txt.append(z_txt.cpu())

        if max_batches is not None and (batch_idx + 1) >= max_batches:
            break

    emb_img = torch.cat(all_img, dim=0)
    emb_txt = torch.cat(all_txt, dim=0)
    print("Collected", emb_img.shape[0], "pairs")

    if use_text_as_query:
        q, k = emb_txt, emb_img
    else:
        q, k = emb_img, emb_txt

    return q, k
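
`compute_full_retrieval_metrics` builds the full N × N similarity matrix from these embeddings. That is fine at the 500-sample cap but grows quadratically when `MAX_EVAL_SAMPLES = None`. Below is a NumPy sketch that scores queries in blocks, assuming L2-normalized inputs where `q[i]` matches `k[i]`. `chunked_ranks` is a hypothetical name.

```python
import numpy as np


def chunked_ranks(q: np.ndarray, k: np.ndarray, chunk: int = 1024) -> np.ndarray:
    """0-indexed rank of k[i] for query q[i], computed `chunk` queries at a time.

    Peak memory is O(chunk * N) instead of O(N^2).
    """
    n = q.shape[0]
    ranks = np.empty(n, dtype=np.int64)
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        sims = q[start:stop] @ k.T                    # (stop - start, N) block of scores
        rows = np.arange(stop - start)
        pos = sims[rows, start + rows][:, None]       # positive score for each query in the block
        ranks[start:stop] = (sims > pos).sum(axis=1)  # keys scored strictly above the positive
    return ranks
```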


In [None]:


@torch.no_grad()
def eval_vision_text_retrieval(model: VisionTextAligner, cfg: AlignmentConfig, loader: DataLoader) -> Dict[str, float]:
    """Evaluate image‚Üîtext retrieval (both directions) including Matryoshka scales.
    Logs everything to WandB.
    """
    # 1. Collect aligned embeddings
    txt_queries, img_keys = collect_vision_text_embeddings(
        model, loader, use_text_as_query=True,
    )
    img_queries, txt_keys = img_keys, txt_queries  # reuse for opposite direction

    results: Dict[str, float] = {}

    # 2. Full-dimension metrics
    txt_img_metrics = compute_full_retrieval_metrics(txt_queries, img_keys)
    img_txt_metrics = compute_full_retrieval_metrics(img_queries, txt_keys)

    # Log basic metrics (R@K etc.)
    for k_val in (1, 5, 10, 50):
        results[f'text_to_image/R@{k_val}'] = txt_img_metrics[f'R@{k_val}']
        results[f'image_to_text/R@{k_val}'] = img_txt_metrics[f'R@{k_val}']

    results['text_to_image/mean_rank'] = txt_img_metrics['mean_rank']
    results['text_to_image/median_rank'] = txt_img_metrics['median_rank']
    results['image_to_text/mean_rank'] = img_txt_metrics['mean_rank']
    results['image_to_text/median_rank'] = img_txt_metrics['median_rank']

    for k_val in (10, 50):
        results[f'text_to_image/mAP@{k_val}'] = txt_img_metrics[f'mAP@{k_val}']
        results[f'image_to_text/mAP@{k_val}'] = img_txt_metrics[f'mAP@{k_val}']
        results[f'text_to_image/NDCG@{k_val}'] = txt_img_metrics[f'NDCG@{k_val}']
        results[f'image_to_text/NDCG@{k_val}'] = img_txt_metrics[f'NDCG@{k_val}']

    # Log histograms of ranks
    wandb.log({
        'text_to_image/rank_hist': wandb.Histogram(txt_img_metrics['ranks_tensor'].cpu().numpy()),
        'image_to_text/rank_hist': wandb.Histogram(img_txt_metrics['ranks_tensor'].cpu().numpy()),
    })

    # 3. Matryoshka evaluation (if dims are defined)
    if cfg.mrl_dims and len(cfg.mrl_dims) > 0:
        dims_sorted = sorted(cfg.mrl_dims)
        for d in dims_sorted:
            qt = txt_queries[:, :d]
            ki = img_keys[:, :d]
            qi = img_queries[:, :d]
            kt = txt_keys[:, :d]

            t2i = compute_full_retrieval_metrics(qt, ki)
            i2t = compute_full_retrieval_metrics(qi, kt)

            for k_val in (1, 5, 10, 50):
                results[f'mrl_dim_{d}/text_to_image/R@{k_val}'] = t2i[f'R@{k_val}']
                results[f'mrl_dim_{d}/image_to_text/R@{k_val}'] = i2t[f'R@{k_val}']

            for k_val in (10, 50):
                results[f'mrl_dim_{d}/text_to_image/mAP@{k_val}'] = t2i[f'mAP@{k_val}']
                results[f'mrl_dim_{d}/image_to_text/mAP@{k_val}'] = i2t[f'mAP@{k_val}']

    # 4. Log all scalar metrics to WandB
    wandb.log(results)
    return results


vt_retrieval_results = eval_vision_text_retrieval(vt_model, vt_cfg, eval_loader)
vt_retrieval_results
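
For the "aggregated tables" diagnostic, the flat `mrl_dim_{d}/{direction}/R@{k}` keys produced by `eval_vision_text_retrieval` can be pivoted into a Markdown table. The helper below is hypothetical.

```python
import re
from typing import Dict, List


def mrl_table(results: Dict[str, float], direction: str = "text_to_image") -> str:
    """Markdown table of R@K per Matryoshka dim, parsed from keys like 'mrl_dim_64/text_to_image/R@1'."""
    pat = re.compile(rf"^mrl_dim_(\d+)/{re.escape(direction)}/R@(\d+)$")
    cells: Dict[int, Dict[int, float]] = {}
    for key, val in results.items():
        m = pat.match(key)
        if m:
            cells.setdefault(int(m.group(1)), {})[int(m.group(2))] = float(val)
    if not cells:
        return ""
    ks: List[int] = sorted({k for row in cells.values() for k in row})
    lines = ["| dim | " + " | ".join(f"R@{k}" for k in ks) + " |",
             "|---|" + "---|" * len(ks)]
    for d in sorted(cells):
        row = cells[d]
        # Missing (dim, K) combinations are left as empty cells.
        lines.append(f"| {d} | " + " | ".join(f"{row[k]:.2f}" if k in row else "" for k in ks) + " |")
    return "\n".join(lines)
```

For example, `print(mrl_table(vt_retrieval_results, "image_to_text"))`. The same rows could also be logged with `wandb.Table`.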

In [None]:
# === Optional: Multimodal Perceiver Retrieval (Vision, Audio, Text) ===

@torch.no_grad()
def load_multimodal_alignment_model(checkpoint_path: Path) -> Optional[MultimodalAlignmentModel]:
    if not checkpoint_path.exists():
        print(f"[Eval] Multimodal checkpoint not found at {checkpoint_path}, skipping.")
        return None

    state = torch.load(checkpoint_path, map_location="cpu")
    mm_cfg_dict = state["mm_config"]
    mm_cfg = MultimodalAlignmentConfig(**mm_cfg_dict)
    mm_cfg.device = str(device)

    model = MultimodalAlignmentModel(mm_cfg).to(device)
    model.load_state_dict(state["model_state"])
    model.eval()
    print(f"[Eval] Loaded multimodal alignment model from {checkpoint_path}")
    return model


@torch.no_grad()
def eval_multimodal_pairs(model: MultimodalAlignmentModel, df: pd.DataFrame, max_samples: Optional[int] = None) -> Dict[str, float]:
    """Evaluate retrieval for (vision, audio, text) feature triples.

    Assumes `df` has columns like:
        - 'vision_feats'  : (T_v, D_v) flattened or serialized
        - 'audio_feats'   : (T_a, D_a)
        - 'text_feats'    : (T_t, D_t)

    You can adapt this to your actual feature storage format.
    """
    # This is a template: you'll need to plug in your own loading logic
    # for Perceiver feature datasets. For now we just show the metric
    # computation pattern once you have aligned embeddings.
    
    # Placeholders for demonstration
    # vision_aligned, audio_aligned, text_aligned = ...
    # For now, we skip implementation to avoid breaking the notebook.
    print('⚠️ Perceiver multimodal eval is a template; fill in feature loading for your dataset.')
    return {}


perceiver_model = load_multimodal_alignment_model(CHECKPOINT_PERCEIVER_MULTI)
if perceiver_model is not None:
    # TODO: replace this with real multimodal feature dataframe
    dummy_df = pd.DataFrame()
    perceiver_results = eval_multimodal_pairs(perceiver_model, dummy_df)
else:
    perceiver_results = {}

perceiver_results
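
`eval_multimodal_pairs` currently returns `{}`. This notebook does not show how `MultimodalAlignmentModel` produces per-modality aligned embeddings, so that step is left open. Once those embeddings are available, the metric side could follow the NumPy sketch below, which assumes row i is the same sample in every modality. `pairwise_recall_at_1` is hypothetical. Replacing the `argmax` check with `compute_full_retrieval_metrics` would give the full metric set for each pair.

```python
from itertools import permutations
from typing import Dict

import numpy as np


def pairwise_recall_at_1(embs: Dict[str, np.ndarray]) -> Dict[str, float]:
    """R@1 (%) for every ordered modality pair, e.g. 'audio_to_text/R@1'."""
    normed = {m: e / np.linalg.norm(e, axis=1, keepdims=True) for m, e in embs.items()}
    out = {}
    for src, dst in permutations(normed, 2):
        sims = normed[src] @ normed[dst].T
        hits = np.argmax(sims, axis=1) == np.arange(sims.shape[0])
        out[f"{src}_to_{dst}/R@1"] = 100.0 * float(hits.mean())
    return out
```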

In [None]:
# === Diagnostics: Plots for Ranks & Matryoshka Scales ===

def plot_rank_histogram(ranks: Tensor, title: str, max_rank: int = 100):
    ranks_np = ranks.cpu().numpy()
    ranks_np = np.clip(ranks_np, 0, max_rank)

    plt.figure(figsize=(6, 4))
    plt.hist(ranks_np + 1, bins=min(max_rank, 100))
    plt.xlabel('Rank (1-indexed)')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.yscale('log')
    plt.tight_layout()
    plt.show()


def plot_mrl_curves(results: Dict[str, float], direction: str = 'text_to_image'):
    dims = []
    r1_vals = []
    r5_vals = []
    r10_vals = []

    for key, val in results.items():
        if key.startswith('mrl_dim_') and key.endswith(f'/{direction}/R@1'):
            dim = int(key.split('/')[0].split('_')[-1])
            dims.append(dim)
    
    dims = sorted(set(dims))
    if not dims:
        print('No Matryoshka results found in metrics dict.')
        return

    for d in dims:
        r1_vals.append(results[f'mrl_dim_{d}/{direction}/R@1'])
        r5_vals.append(results[f'mrl_dim_{d}/{direction}/R@5'])
        r10_vals.append(results[f'mrl_dim_{d}/{direction}/R@10'])

    plt.figure(figsize=(6, 4))
    plt.plot(dims, r1_vals, marker='o', label='R@1')
    plt.plot(dims, r5_vals, marker='o', label='R@5')
    plt.plot(dims, r10_vals, marker='o', label='R@10')
    plt.xlabel('Matryoshka dimension (d)')
    plt.ylabel('Recall (%)')
    plt.title(f'Matryoshka Retrieval vs Dim ({direction})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Example usage once `vt_retrieval_results` is computed:
if 'text_to_image/R@1' in vt_retrieval_results:
    print('Text→Image R@1:', vt_retrieval_results['text_to_image/R@1'])
    plot_mrl_curves(vt_retrieval_results, direction='text_to_image')
    plot_mrl_curves(vt_retrieval_results, direction='image_to_text')


## Summary & Next Steps

This notebook provides a retrieval evaluation suite for Phase-1 alignment models:

- Standard image↔text retrieval metrics (R@K, MR, MedR, mAP, NDCG)
- Matryoshka multi-scale analysis vs. embedding dimension
- A template for extending evaluation to multimodal Perceiver models (vision, audio, text)
- All results logged to **Weights & Biases** for cross-experiment comparison.

**Next:**
- Plug your Parquet-based PixMo-Cap / audio feature loaders into `build_eval_dataloader` and the Perceiver template.
- Add baselines (e.g., CLIP, raw DINOv2 + MiniLM retrieval) and log them with a different `job_type` / `group` in WandB.
- Use the logged tables & plots directly in your report, comparing against Freeze-Align, ImageBind, Matryoshka, and Unified-IO 2.
