<p>
# ============================================================
# 05_eval_embeddings.ipynb
# Edge Assistant — Embedding Evaluation Notebook
#
# This notebook assumes:
# - You have an "aligned" model (encoders + perceiver + projectors)
# - You have a saved checkpoint: ckpt_path (see EvalConfig below)
#
# It evaluates:
#   1. Text↔Image retrieval (R@1/5/10)
#   2. Text↔Audio retrieval (R@1/5/10)
#   3. Audio→Image zero-shot classification (ESC-50-style)
#   4. Matryoshka compression curve (accuracy vs K dims/tokens)
#   5. Gramian Volume (latent geometry)
#
# You plug in:
#   - Your model in `load_aligned_model`
#   - Your DataLoaders for COCO/Flickr/AudioCaps/Clotho/ESC-50/etc.
# ============================================================
</p>

In [None]:


import os
import math
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default look to every matplotlib figure in the notebook.
sns.set_theme()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


In [None]:
# ============================================================
# Config & Reproducibility
# ============================================================

@dataclass
class EvalConfig:
    """Tunable settings for this evaluation notebook.

    One instance, `cfg`, is created immediately below and passed to the
    model-loading and evaluation code in later cells.
    """
    # --- model ---
    # Path of the saved checkpoint; referenced by the example skeleton inside
    # `load_aligned_model`.
    ckpt_path: str = "./checkpoints/aligned_model.pt"  # <--- EDIT THIS
    # Request bfloat16 for the loaded model (the load cell also checks that
    # the hardware reports bf16 support before honouring this).
    use_bfloat16: bool = False

    # --- evaluation ---
    # NOTE(review): batch_size / num_workers are not read by any evaluation
    # function in this notebook — presumably meant for the DataLoaders you
    # build yourself; confirm.
    batch_size: int = 32
    num_workers: int = 4

    # matryoshka / dimensionality
    # emb_dim_image is used by `run_all_evals` as the largest K and as
    # `full_dim` for the Matryoshka curve.
    emb_dim_image: int = 1024   # set to your projector output dim
    # NOTE(review): emb_dim_audio / emb_dim_text are not read anywhere in the
    # visible notebook.
    emb_dim_audio: int = 1024
    emb_dim_text: int = 1024

    # gramian volume
    # Maximum number of tri-modal samples scored by `eval_gramian_volume`.
    gram_n_samples: int = 256

# Module-level config instance shared by the cells below.
cfg = EvalConfig()

def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device
    when CUDA is available) with the same value so runs are repeatable."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)


In [None]:
# ============================================================
# Model Loading & Wrapper
# ============================================================

class OmniEmbedWrapper(nn.Module):
    """
    Uniform embedding interface on top of an arbitrary aligned backbone:
        - embed_image(pixel_values)             -> (B, D)
        - embed_audio(audio_values)             -> (B, D)
        - embed_text(input_ids, attention_mask) -> (B, D)

    For each modality the backbone's `encode_<modality>` method is preferred;
    `forward_<modality>` is the fallback. If neither exists the call raises
    NotImplementedError. Adapt the internals to your model's API.
    """
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def _first_available(self, method_names, error_message, *inputs):
        # Call the first backbone method that exists, in the order given.
        for method_name in method_names:
            if hasattr(self.backbone, method_name):
                return getattr(self.backbone, method_name)(*inputs)
        raise NotImplementedError(error_message)

    @torch.no_grad()
    def embed_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # EXPECTED: pixel_values: (B, C, H, W)
        return self._first_available(
            ("encode_image", "forward_image"),
            "Backbone has no image encoder method.",
            pixel_values,
        )

    @torch.no_grad()
    def embed_audio(self, audio_values: torch.Tensor) -> torch.Tensor:
        # EXPECTED: audio_values: (B, T) or (B, C, T) depending on your pipeline
        return self._first_available(
            ("encode_audio", "forward_audio"),
            "Backbone has no audio encoder method.",
            audio_values,
        )

    @torch.no_grad()
    def embed_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Both tensors are forwarded positionally: (input_ids, attention_mask).
        return self._first_available(
            ("encode_text", "forward_text"),
            "Backbone has no text encoder method.",
            input_ids,
            attention_mask,
        )


def load_aligned_model(cfg: EvalConfig) -> OmniEmbedWrapper:
    """
    Project-specific hook that must be filled in before the notebook can run.

    EDIT THIS FUNCTION for your project. A finished implementation should:
      1. Instantiate your aligned model class (Perceiver + projectors)
      2. Load the checkpoint
      3. Return the model wrapped in OmniEmbedWrapper

    As shipped, the function unconditionally raises NotImplementedError.
    """

    # Example skeleton (replace the import with your own model class):
    #
    #   from edge_glass.models.aligned_model import AlignedModel
    #   model = AlignedModel(...)
    #
    #   state = torch.load(cfg.ckpt_path, map_location="cpu")
    #   if "model" in state:
    #       model.load_state_dict(state["model"])
    #   else:
    #       model.load_state_dict(state)
    #
    #   model.to(device)
    #   model.eval()
    #   return OmniEmbedWrapper(model)

    not_implemented_msg = "Implement load_aligned_model(cfg) to return OmniEmbedWrapper(backbone)."
    raise NotImplementedError(not_implemented_msg)

# Try loading (will error until you implement)
# omni = load_aligned_model(cfg)


In [None]:
# Build the wrapped model (raises NotImplementedError until
# `load_aligned_model` has been implemented for your project).
omni = load_aligned_model(cfg)

# bfloat16 is only used when requested AND we are actually on a CUDA device
# that reports bf16 support. The `device == "cuda"` guard keeps
# `torch.cuda.is_bf16_supported()` from being called on CPU-only hosts, where
# older PyTorch releases raise instead of returning False.
use_bf16 = cfg.use_bfloat16 and device == "cuda" and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float32
omni = omni.to(device=device, dtype=dtype)
omni.eval()


In [None]:
# ============================================================
# Retrieval Metrics (R@K)
# ============================================================

def recall_at_k(sim_matrix: torch.Tensor, k: int = 1) -> float:
    """
    Fraction of queries whose paired target is among the k highest-scoring
    targets.

    sim_matrix: (N_query, N_target), higher score = more similar.
    Ground truth is assumed to be 'diagonal' (i <-> i).

    k is clamped to N_target, so requesting e.g. R@10 on a gallery with fewer
    than 10 targets returns a value instead of raising inside `topk`.

    Returns a Python float in [0, 1].
    """
    # torch.topk requires k <= size of the reduced dimension.
    k_eff = min(k, sim_matrix.size(-1))

    # indices of top-k targets for each query
    topk = sim_matrix.topk(k_eff, dim=-1).indices
    correct = torch.arange(sim_matrix.size(0), device=sim_matrix.device)
    # A query is a hit when its own index appears anywhere in its top-k row.
    hits = (topk == correct.unsqueeze(-1)).any(dim=-1).float().mean().item()
    return hits


def compute_retrieval_scores(query_emb: torch.Tensor,
                             target_emb: torch.Tensor,
                             prefix: str = "") -> dict:
    """
    Cosine-similarity retrieval metrics for paired embeddings.

    query_emb, target_emb: (N, D) tensors where row i of one is the match of
    row i of the other.
    prefix: string prepended to every metric key (not to "sim_matrix").

    Returns a dict with:
      - "{prefix}R@1", "{prefix}R@5", "{prefix}R@10": recall values in [0, 1]
      - "{prefix}MeanRank", "{prefix}MedianRank": rank of the true match,
        ZERO-based (a perfect retrieval has rank 0)
      - "sim_matrix": the (N, N) cosine-similarity matrix on CPU

    Half-precision inputs (float16 / bfloat16) are upcast to float32 first:
    ranking in 8-10 bits of mantissa produces spurious ties, and a bfloat16
    similarity matrix cannot be converted to NumPy for plotting.
    """
    low_precision = (torch.float16, torch.bfloat16)
    if query_emb.dtype in low_precision:
        query_emb = query_emb.float()
    if target_emb.dtype in low_precision:
        target_emb = target_emb.float()

    q = F.normalize(query_emb, dim=-1)
    t = F.normalize(target_emb, dim=-1)
    sim = q @ t.T  # (N, N)

    r1 = recall_at_k(sim, 1)
    r5 = recall_at_k(sim, 5)
    r10 = recall_at_k(sim, 10)

    # Position of each query's own index in its descending-sorted row.
    sorted_indices = sim.argsort(dim=-1, descending=True)
    targets = torch.arange(sim.size(0), device=sim.device)
    ranks = (sorted_indices == targets.unsqueeze(-1)).nonzero(as_tuple=False)[:, 1]
    mean_rank = ranks.float().mean().item()
    median_rank = ranks.median().item()

    return {
        f"{prefix}R@1": r1,
        f"{prefix}R@5": r5,
        f"{prefix}R@10": r10,
        f"{prefix}MeanRank": mean_rank,
        f"{prefix}MedianRank": median_rank,
        "sim_matrix": sim.detach().cpu()
    }


In [None]:
# ============================================================
# Dataset Interfaces (YOU NEED TO PROVIDE DATALOADERS)
# ============================================================

"""
The notebook expects DataLoaders of the following form:

1) Image–Text retrieval (e.g., COCO/Flickr):
   batch = {
       "pixel_values": FloatTensor (B, C, H, W),
       "input_ids": LongTensor (B, L),
       "attention_mask": LongTensor (B, L)
   }

2) Text–Audio retrieval (e.g., AudioCaps/Clotho):
   batch = {
       "audio_values": FloatTensor (B, T) or (B, C, T),
       "input_ids": LongTensor (B, L),
       "attention_mask": LongTensor (B, L)
   }

3) Audio–Image classification (ESC-50 style):
   train batch: {
       "pixel_values": FloatTensor (B, C, H, W),
       "label": LongTensor (B,)
   }
   test batch: {
       "audio_values": FloatTensor (B, T),
       "label": LongTensor (B,)
   }

4) Tri-modal for GRAM test (e.g., VALOR):
   batch = {
       "pixel_values": FloatTensor (B, C, H, W),
       "audio_values": FloatTensor (B, T),
       "input_ids": LongTensor (B, L),
       "attention_mask": LongTensor (B, L),
   }

You can build these DataLoaders in a separate cell or module.
Below we only define EVALUATION functions that *consume* them.
"""


In [None]:
# ============================================================
# Text → Image Retrieval
# ============================================================

def eval_text_to_image(omni: OmniEmbedWrapper,
                       dataloader: DataLoader,
                       max_batches: int | None = None) -> dict:
    """
    Embed every (image, caption) pair from `dataloader`, score Text→Image
    retrieval with `compute_retrieval_scores` (keys prefixed "T2I_"), print
    the scalar metrics and return the full metrics dict.

    max_batches: stop after this many batches; None means use all of them.
    """
    image_chunks, text_chunks = [], []

    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Text→Image")):
        if max_batches is not None and batch_idx >= max_batches:
            break

        # Move the three inputs to the eval device in a fixed order.
        inputs = {
            key: batch[key].to(device)
            for key in ("pixel_values", "input_ids", "attention_mask")
        }

        with torch.no_grad():
            image_chunks.append(omni.embed_image(inputs["pixel_values"]))
            text_chunks.append(
                omni.embed_text(inputs["input_ids"], inputs["attention_mask"])
            )

    all_images = torch.cat(image_chunks, dim=0)
    all_texts = torch.cat(text_chunks, dim=0)

    # Text is the query side, images are the gallery.
    metrics = compute_retrieval_scores(all_texts, all_images, prefix="T2I_")

    print("Text→Image Retrieval:")
    for name, value in metrics.items():
        if name.endswith("sim_matrix"):
            continue  # the matrix is returned, not printed
        if isinstance(value, float):
            print(f"  {name}: {value:.4f}")
        else:
            print(f"  {name}: {value}")
    return metrics


In [None]:
# ============================================================
# Text → Audio Retrieval
# ============================================================

def eval_text_to_audio(omni: OmniEmbedWrapper,
                       dataloader: DataLoader,
                       max_batches: int | None = None) -> dict:
    """
    Embed every (audio clip, caption) pair from `dataloader`, score
    Text→Audio retrieval with `compute_retrieval_scores` (keys prefixed
    "T2A_"), print the scalar metrics and return the full metrics dict.

    max_batches: stop after this many batches; None means use all of them.
    """
    audio_chunks, text_chunks = [], []

    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Text→Audio")):
        if max_batches is not None and batch_idx >= max_batches:
            break

        # Move the three inputs to the eval device in a fixed order.
        inputs = {
            key: batch[key].to(device)
            for key in ("audio_values", "input_ids", "attention_mask")
        }

        with torch.no_grad():
            audio_chunks.append(omni.embed_audio(inputs["audio_values"]))
            text_chunks.append(
                omni.embed_text(inputs["input_ids"], inputs["attention_mask"])
            )

    all_audio = torch.cat(audio_chunks, dim=0)
    all_texts = torch.cat(text_chunks, dim=0)

    # Text is the query side, audio clips are the gallery.
    metrics = compute_retrieval_scores(all_texts, all_audio, prefix="T2A_")

    print("Text→Audio Retrieval:")
    for name, value in metrics.items():
        if name.endswith("sim_matrix"):
            continue  # the matrix is returned, not printed
        if isinstance(value, float):
            print(f"  {name}: {value:.4f}")
        else:
            print(f"  {name}: {value}")
    return metrics


In [None]:
# ============================================================
# Audio → Image Zero-shot Classification (ESC-50 style)
# ============================================================

from sklearn.metrics import accuracy_score

def build_image_class_prototypes(omni: OmniEmbedWrapper,
                                 train_loader: DataLoader,
                                 n_classes: int) -> torch.Tensor:
    """
    Compute one prototype per class as the mean of that class's image
    embeddings.

    Labels are read from batch["label"] and used directly as indices in
    [0, n_classes). Raises ValueError if a class received no samples.
    Returns a (n_classes, D) tensor on the evaluation device.
    """
    per_class = [list() for _ in range(n_classes)]

    for batch in tqdm(train_loader, desc="Building image prototypes"):
        images = batch["pixel_values"].to(device)
        batch_labels = batch["label"]  # (B,)

        with torch.no_grad():
            # Embeddings are kept on the CPU while they are being collected.
            embedded = omni.embed_image(images).cpu()  # (B, D)

        for row, label in zip(embedded, batch_labels):
            per_class[label.item()].append(row)

    # Reduce every bucket to its mean vector, failing on the first empty one.
    class_means = []
    for class_id, members in enumerate(per_class):
        if not members:
            raise ValueError(f"No samples found for class {class_id}.")
        stacked = torch.stack(members, dim=0)
        class_means.append(stacked.mean(dim=0))

    return torch.stack(class_means, dim=0).to(device)  # (C, D)


def eval_audio_zero_shot_classification(omni: OmniEmbedWrapper,
                                        proto_mat: torch.Tensor,
                                        test_loader: DataLoader) -> float:
    """
    Classify each audio clip as the class whose prototype in `proto_mat`
    (C, D) has the highest cosine similarity with the clip's embedding.

    Prints and returns the top-1 accuracy against batch["label"].
    """
    unit_protos = F.normalize(proto_mat, dim=-1)

    predicted, expected = [], []
    for batch in tqdm(test_loader, desc="Audio→Image Zero-shot"):
        waveforms = batch["audio_values"].to(device)
        batch_labels = batch["label"].tolist()

        with torch.no_grad():
            audio_unit = F.normalize(omni.embed_audio(waveforms), dim=-1)  # (B, D)
            scores = audio_unit @ unit_protos.T  # (B, C)
            best_class = scores.argmax(dim=-1)

        predicted += best_class.cpu().tolist()
        expected += batch_labels

    acc = accuracy_score(expected, predicted)
    print(f"Audio→Image zero-shot Top-1 accuracy: {acc:.4f}")
    return acc


In [None]:
# ============================================================
# Matryoshka Compression Curve
# ============================================================

def matryoshka_curve_image(omni: OmniEmbedWrapper,
                           dataloader: DataLoader,
                           ks: list[int],
                           full_dim: int) -> dict:
    """
    Evaluate Text→Image R@1 while truncating BOTH the text and the image
    embeddings to their first K dimensions (Matryoshka-style).

    ks: prefix lengths to evaluate; each must satisfy 1 <= K <= full_dim.
    full_dim: the embedding width you expect from the projectors.

    Returns {K: R@1} for every K in `ks`.

    Raises:
      AssertionError: a K exceeds `full_dim` (checked before any embedding
        work, so a bad configuration fails immediately).
      ValueError: a K is < 1, or a K exceeds the width of the embeddings the
        model actually produces. Slicing past the real width would silently
        evaluate fewer dimensions than the label claims.
    """
    ks = list(ks)

    # Validate the request up front; the embedding pass below is expensive.
    for k in ks:
        if k < 1:
            raise ValueError(f"K must be a positive integer, got {k}.")
        assert k <= full_dim, f"K={k} > full_dim={full_dim}"

    # First, cache FULL embeddings once
    img_embs = []
    txt_embs = []
    width_checked = False

    for batch in tqdm(dataloader, desc="Caching full embeddings"):
        pixels = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)

        with torch.no_grad():
            i_emb = omni.embed_image(pixels)   # (B, D)
            t_emb = omni.embed_text(input_ids, mask)

        if not width_checked:
            # Compare against what the model really emits, on the first batch
            # only, so a config/model mismatch surfaces early.
            actual_dim = min(i_emb.size(-1), t_emb.size(-1))
            oversized = [k for k in ks if k > actual_dim]
            if oversized:
                raise ValueError(
                    f"Requested K values {oversized} exceed the actual "
                    f"embedding width {actual_dim}."
                )
            width_checked = True

        img_embs.append(i_emb.cpu())
        txt_embs.append(t_emb.cpu())

    img_embs = torch.cat(img_embs, dim=0)
    txt_embs = torch.cat(txt_embs, dim=0)

    results = {}

    for k in ks:
        print(f"\nEvaluating Matryoshka at K={k} dims")

        # Keep only the leading K coordinates of both modalities.
        i_k = img_embs[:, :k].to(device)
        t_k = txt_embs[:, :k].to(device)

        metrics = compute_retrieval_scores(t_k, i_k, prefix=f"K{k}_")
        r1 = metrics[f"K{k}_R@1"]
        results[k] = r1
        print(f"  R@1 (K={k}): {r1:.4f}")

    return results


In [None]:
# ============================================================
# Gramian Volume (GRAM) — Latent Geometry
# ============================================================

def gramian_volume_triplet(a: torch.Tensor,
                           b: torch.Tensor,
                           c: torch.Tensor) -> float:
    """
    a, b, c: (D,) vectors.
    Compute det(G) where G_ij = <v_i, v_j> for the L2-normalized vectors.

    det(G) is the SQUARED volume of the parallelepiped spanned by the three
    unit vectors: 1.0 when they are mutually orthogonal, 0.0 when they are
    linearly dependent (e.g. perfectly aligned).

    float16 / bfloat16 inputs are upcast to float32 first, because
    `torch.det` is only implemented for float32/float64 (and complex) and
    would otherwise raise when the model runs in half precision.
    """
    low_precision = (torch.float16, torch.bfloat16)
    unit_vectors = []
    for vec in (a, b, c):
        if vec.dtype in low_precision:
            vec = vec.float()
        unit_vectors.append(F.normalize(vec, dim=-1))

    # (3, D) @ (D, 3) -> the 3x3 Gram matrix of pairwise inner products.
    stacked = torch.stack(unit_vectors, dim=0)
    G = stacked @ stacked.T
    return torch.det(G).item()


def eval_gramian_volume(omni: OmniEmbedWrapper,
                        dataloader: DataLoader,
                        n_samples: int) -> list[float]:
    """
    Score up to `n_samples` tri-modal samples with `gramian_volume_triplet`
    (image, audio, text embedding of the same sample) and print the mean and
    median of the collected values.

    Expects tri-modal batches with keys:
      - "pixel_values"
      - "audio_values"
      - "input_ids"
      - "attention_mask"

    Returns the list of per-sample values (at most `n_samples` long; empty
    when `n_samples` <= 0 or the loader yields nothing).
    """
    vols: list[float] = []

    if n_samples > 0:
        for batch in tqdm(dataloader, desc="GRAM volume"):
            pixels = batch["pixel_values"].to(device)
            audio = batch["audio_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            with torch.no_grad():
                i_emb = omni.embed_image(pixels)       # (B, D)
                a_emb = omni.embed_audio(audio)        # (B, D)
                t_emb = omni.embed_text(input_ids, mask)

            # Only the samples still needed are copied, and each modality is
            # moved to the CPU once per batch rather than once per sample.
            take = min(i_emb.size(0), n_samples - len(vols))
            i_cpu = i_emb[:take].cpu()
            a_cpu = a_emb[:take].cpu()
            t_cpu = t_emb[:take].cpu()

            for j in range(take):
                vols.append(gramian_volume_triplet(i_cpu[j], a_cpu[j], t_cpu[j]))

            if len(vols) >= n_samples:
                break

    print(f"Collected {len(vols)} Gramian volumes.")
    if vols:
        # np.mean / np.median of an empty list would warn and print nan.
        print(f"Mean: {np.mean(vols):.6f}, Median: {np.median(vols):.6f}")
    return vols


In [None]:
# ============================================================
# MAIN DRIVER — Plug in your loaders here
# ============================================================

# These should be real DataLoaders from your code.
# For now we keep them as placeholders.
# Placeholder loaders — every one stays None until you assign a real
# DataLoader; `run_all_evals` skips any evaluation whose loader is None.
#   coco_val_loader       -> image-text (COCO/Flickr)
#   audiocaps_val_loader  -> text-audio (AudioCaps/Clotho)
#   esc50_train_loader    -> image-labels for prototypes
#   esc50_test_loader     -> audio-labels for classification
#   tri_modal_loader      -> VALOR/VGGSound-style (I+A+T)
(
    coco_val_loader,
    audiocaps_val_loader,
    esc50_train_loader,
    esc50_test_loader,
    tri_modal_loader,
) = (None,) * 5


def _fallback_plot_matryoshka_curve(mat_res: dict) -> None:
    """Line plot of R@1 against the number of kept embedding dimensions."""
    ks = sorted(mat_res)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(ks, [mat_res[k] for k in ks], marker="o")
    ax.set_xscale("log")
    ax.set(
        title="Matryoshka Compression Curve (Text→Image)",
        xlabel="Embedding dims kept (K)",
        ylabel="R@1",
    )
    fig.tight_layout()
    plt.show()


def _fallback_plot_similarity_heatmap(sim_matrix, title: str = "") -> None:
    """Heatmap of a (queries x targets) similarity matrix."""
    if isinstance(sim_matrix, torch.Tensor):
        # float() first: NumPy has no bfloat16 dtype.
        sim_matrix = sim_matrix.detach().float().cpu().numpy()
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(np.asarray(sim_matrix), ax=ax, cmap="viridis")
    ax.set(title=title, xlabel="Target index", ylabel="Query index")
    fig.tight_layout()
    plt.show()


def run_all_evals(omni: OmniEmbedWrapper, cfg: EvalConfig, n_classes: int = 50):
    """
    Run every evaluation whose module-level loader is set (not None) and
    collect the outputs in one dict.

    Loaders are read from the notebook globals `coco_val_loader`,
    `audiocaps_val_loader`, `esc50_train_loader`, `esc50_test_loader` and
    `tri_modal_loader`.

    n_classes: number of classes for the zero-shot classification step
      (50 for ESC-50, the default).

    Result keys (present only for the evaluations that ran):
      - "T2I_*" / "T2A_*": retrieval metrics
      - "T2I_sim_matrix" / "T2A_sim_matrix": per-task similarity matrices
      - "sim_matrix": kept for backward compatibility; holds the matrix of
        the LAST retrieval task that ran, since both tasks write this key
      - "matryoshka": {K: R@1}
      - "ESC50_zeroshot_acc": float
      - "gram_volumes": list of floats

    Plotting uses `plot_matryoshka_curve` / `plot_similarity_heatmap` when
    the notebook defines them and otherwise falls back to the minimal
    plotters defined above, so a missing helper cell no longer raises
    NameError after the expensive evaluation has already run.
    """
    results = {}

    plot_curve = globals().get("plot_matryoshka_curve", _fallback_plot_matryoshka_curve)
    plot_heatmap = globals().get("plot_similarity_heatmap", _fallback_plot_similarity_heatmap)

    # 1) Text→Image retrieval (COCO/Flickr)
    if coco_val_loader is not None:
        t2i_metrics = eval_text_to_image(omni, coco_val_loader)
        results.update(t2i_metrics)
        # Stored under its own key so the T2A step cannot overwrite it.
        results["T2I_sim_matrix"] = t2i_metrics["sim_matrix"]

        # Matryoshka curve on same loader (optional)
        ks = [8, 16, 32, 64, 128, 256, cfg.emb_dim_image]
        ks = [k for k in ks if k <= cfg.emb_dim_image]
        mat_res = matryoshka_curve_image(
            omni, coco_val_loader, ks=ks, full_dim=cfg.emb_dim_image
        )
        results["matryoshka"] = mat_res
        plot_curve(mat_res)

        # Heatmap for a small subset
        sim_matrix = t2i_metrics["sim_matrix"]
        plot_heatmap(sim_matrix[:32, :32], title="T→I Similarity (First 32 examples)")

    # 2) Text→Audio retrieval (AudioCaps/Clotho)
    if audiocaps_val_loader is not None:
        t2a_metrics = eval_text_to_audio(omni, audiocaps_val_loader)
        results.update(t2a_metrics)
        results["T2A_sim_matrix"] = t2a_metrics["sim_matrix"]

    # 3) Audio→Image zero-shot classification (ESC-50)
    if esc50_train_loader is not None and esc50_test_loader is not None:
        proto_mat = build_image_class_prototypes(omni, esc50_train_loader, n_classes)
        esc_acc = eval_audio_zero_shot_classification(omni, proto_mat, esc50_test_loader)
        results["ESC50_zeroshot_acc"] = esc_acc

    # 4) Gramian Volume (tri-modal binding)
    if tri_modal_loader is not None:
        vols = eval_gramian_volume(omni, tri_modal_loader, n_samples=cfg.gram_n_samples)
        results["gram_volumes"] = vols
        plt.figure(figsize=(6,4))
        plt.hist(vols, bins=30)
        plt.title("Distribution of Gramian Volumes (Tri-modal)")
        plt.xlabel("Volume")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

    return results


# Example usage (after you've constructed all loaders):

# omni = load_aligned_model(cfg).to(device)
# omni.eval()
# all_results = run_all_evals(omni, cfg)
# all_results
