In [8]:
# Tri-Modal Alignment Demo (Phases 0–3)
# Author: ChatGPT (for Vedaang)
#
# This notebook implements Phases 0–3 of your experiment plan:
#  - Phase 0: Setup, datasets, frozen encoders, sanity checks
#  - Phase 1: Pairwise-only training (i<->t and a<->t separately) + full tri-direction eval
#  - Phase 2: Image-Hub training (i<->t + i<->a) with pseudo (i,a,t) triplets via CLIP retrieval
#  - Phase 3: Tri-modal adapters with cycle consistency (shared space, frozen encoders)
#
# Outputs: retrieval metrics (R@1/5/10, mAP), cosine histograms, comparison tables.
#
# Notes:
# - Uses small, controllable subsets to be runnable on a single GPU.
# - Uses public HF datasets (COCO captions, AudioCaps) and HF encoders (CLIP, Whisper, MiniLM).
# - Encoders are *frozen*; we train light adapters/projectors.
# - Pseudo (i,a,t) triplets are created by retrieving a COCO image for each AudioCaps caption via CLIP.
#
# Before running, ensure internet access in the kernel (to fetch models/datasets) and enough disk cache.
# If offline, pre-download datasets/models in ~/.cache/huggingface.

# %% [markdown]
# ## 0. Setup
# * Installs (if needed)
# * Imports & Config
# * Reproducibility and device

# %%

In [9]:

import sys, os, math, random, time, json, itertools, functools, gc
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
except Exception as e:
    raise

# Light installs: uncomment if you need to install in the environment
# !pip -q install transformers datasets torchaudio tqdm scikit-learn matplotlib pillow

import torchaudio
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, AutoProcessor,
    CLIPProcessor, CLIPModel,
    WhisperFeatureExtractor, WhisperModel,
)
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and all CUDA devices)
# so that adapter initialisation and DataLoader shuffling are repeatable across runs.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; everything below moves models/tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [10]:
# %% [markdown]
# ## 0.1 Config
# You can tweak subset sizes and dims here for quick iterations.

# %%
@dataclass
class Config:
    """Central knobs for the experiment: subset sizes, embedding width, audio/image
    preprocessing and optimisation settings.

    All fields have defaults, so ``Config()`` yields the quick-iteration setup and any
    field can be overridden by keyword, e.g. ``Config(batch_size=32)``.
    """

    # Data sizes (keep small for quick runs; scale later)
    coco_images_train: int = 4000   # number of COCO images for (i,t) train
    coco_images_val: int = 1000     # number for (i,t) eval
    audiocaps_train: int = 4000     # number of AudioCaps audio clips for (a,t) train
    audiocaps_val: int = 1000       # number for (a,t) eval

    # Subset used as candidate image pool for pseudo (i,a,t) triplets.
    # NOTE(review): with the defaults this exceeds coco_images_train, so the pool
    # builder tops up from the validation stream and eval images can end up in
    # the triplet training data — confirm this is intended.
    image_pool_for_triplets: int = 5000

    # Common embedding dim for adapters
    embed_dim: int = 512

    # Matryoshka (optional, Phase 5) – placeholders.
    # The default is an immutable tuple, so the annotation is a tuple type
    # (it was previously declared as List[int], which did not match the value).
    matryoshka_widths: Tuple[int, ...] = (64, 128, 256, 512)

    # Audio settings
    target_sr: int = 16000          # sample rate every clip is resampled to
    max_audio_sec: float = 12.0     # clips are trimmed / zero-padded to this length

    # Image settings
    image_size: int = 224

    # Train settings
    batch_size: int = 64
    num_workers: int = 4
    epochs_pairwise: int = 3
    epochs_hub: int = 3
    epochs_trimodal: int = 3
    lr: float = 1e-3
    weight_decay: float = 0.01
    fp16: bool = True               # only takes effect when running on CUDA

    # Eval
    eval_batch: int = 128

# Single shared configuration instance read by the rest of the notebook.
CFG = Config()
print(CFG)


In [11]:

# %% [markdown]
# ## 0.2 Encoders (Frozen)
# - Vision: CLIP ViT-B/32
# - Text: MiniLM (sentence-transformers/all-MiniLM-L6-v2)
# - Audio: Whisper encoder (mean-pooled hidden states)

# %%
# Vision (CLIP)
# Each backbone is loaded once, switched to eval mode, moved to `device` and
# frozen; only the light adapters defined later are ever trained.
clip_model_id = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id)
clip_model = clip_model.eval().to(device)
clip_model.requires_grad_(False)

# Text (MiniLM)
text_model_id = "sentence-transformers/all-MiniLM-L6-v2"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
text_model = AutoModel.from_pretrained(text_model_id)
text_model = text_model.eval().to(device)
text_model.requires_grad_(False)

# Audio (Whisper encoder) — only the encoder half of the model is kept.
whisper_id = "openai/whisper-small"
whisper_feat = WhisperFeatureExtractor.from_pretrained(whisper_id)
whisper_enc = WhisperModel.from_pretrained(whisper_id).encoder
whisper_enc = whisper_enc.eval().to(device)
whisper_enc.requires_grad_(False)


In [12]:
# %% [markdown]
# ## 0.3 Helper: Frozen encoder wrappers -> fixed-size vectors

# %%
@torch.no_grad()
def encode_text(texts: List[str]) -> torch.Tensor:
    """Embed a batch of strings with the frozen MiniLM encoder.

    Token states are mean-pooled over the non-padding positions and the pooled
    vector is L2-normalised.

    Args:
        texts: the sentences to embed.

    Returns:
        Tensor of shape [len(texts), H] on `device`, unit-norm along the last dim.

    Note:
        Runs under ``torch.no_grad()`` rather than ``torch.inference_mode()``.
        The training loops feed this output straight into trainable adapters, and
        inference-mode tensors cannot be saved for backward by autograd.
    """
    toks = text_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    out = text_model(**toks)
    last = out.last_hidden_state  # [B, L, H]
    mask = toks.attention_mask.unsqueeze(-1)  # [B, L, 1]; zero on padding tokens
    summed = (last * mask).sum(dim=1)
    # clamp guards against a division by zero for an all-padding row
    denom = mask.sum(dim=1).clamp(min=1)
    emb = summed / denom
    emb = F.normalize(emb, dim=-1)
    return emb

@torch.no_grad()
def encode_image_pil(pil_list: List["PIL.Image.Image"]) -> torch.Tensor:
    """Embed a batch of PIL images with the frozen CLIP image tower.

    Args:
        pil_list: images to preprocess with `clip_processor` and encode.

    Returns:
        Tensor of shape [len(pil_list), D] on `device`, L2-normalised.

    Note:
        Uses ``torch.no_grad()`` instead of ``torch.inference_mode()`` so the
        result is an ordinary tensor. The training loops pass it into trainable
        adapters, and inference-mode tensors cannot be saved for backward.
    """
    inputs = clip_processor(images=pil_list, return_tensors="pt").to(device)
    out = clip_model.get_image_features(**inputs)  # [B, 512]
    emb = F.normalize(out, dim=-1)
    return emb

@torch.no_grad()
def encode_text_with_clip(texts: List[str]) -> torch.Tensor:
    """Embed captions with the CLIP *text* tower (used only for text->image retrieval
    when building pseudo triplets).

    Args:
        texts: captions to embed.

    Returns:
        Tensor of shape [len(texts), D] on `device`, L2-normalised, living in the
        same space as `clip_model.get_image_features`.
    """
    # truncation=True keeps over-long captions within the tokenizer's maximum
    # length; without it a long caption would overflow the text tower's context.
    inputs = clip_processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(device)
    out = clip_model.get_text_features(**inputs)
    return F.normalize(out, dim=-1)

@torch.no_grad()
def encode_audio_wave(wave: torch.Tensor, sr: int) -> torch.Tensor:
    """Embed one audio clip with the frozen Whisper encoder.

    Args:
        wave: waveform, either [T] or [C, T] (channels first — assumed, confirm
            against the data source). Multi-channel input is averaged to mono.
        sr: sampling rate of `wave` in Hz.

    Returns:
        Tensor of shape [1, H] on `device`: the temporal mean of the encoder's
        last hidden states, L2-normalised.

    Note:
        Uses ``torch.no_grad()`` instead of ``torch.inference_mode()`` because
        the training loops pass the result into trainable adapters, and
        inference-mode tensors cannot be saved for backward by autograd.
    """
    # Normalise the layout to [1, T]. A [C, T] array with C > 1 would otherwise
    # reach the feature extractor as a 2-D array and be treated as C separate
    # clips, yielding C embeddings instead of one.
    if wave.dim() == 1:
        wave = wave.unsqueeze(0)
    elif wave.size(0) > 1:
        wave = wave.mean(dim=0, keepdim=True)

    if sr != CFG.target_sr:
        wave = torchaudio.functional.resample(wave, sr, CFG.target_sr)

    # trim/pad to max_audio_sec so every clip has the same number of samples
    max_len = int(CFG.target_sr * CFG.max_audio_sec)
    n_samples = wave.size(1)
    if n_samples > max_len:
        wave = wave[:, :max_len]
    elif n_samples < max_len:
        wave = F.pad(wave, (0, max_len - n_samples))

    # NOTE(review): the feature extractor may apply its own padding to a fixed
    # window, in which case the mean-pool below also averages padded frames — confirm.
    inputs = whisper_feat(wave.squeeze(0).cpu().numpy(), sampling_rate=CFG.target_sr, return_tensors="pt")
    feats = inputs["input_features"].to(device)
    out = whisper_enc(feats)  # encoder-only forward
    last = out.last_hidden_state  # [B, T, H]
    emb = last.mean(dim=1)  # temporal mean-pool
    emb = F.normalize(emb, dim=-1)
    return emb


In [26]:
from datasets import load_dataset

# --- helpers --------------------------------------------------------------

def _load_head(dataset_name: str, split: str, n: int, streaming: bool = True):
    """
    Fetch the leading `n` samples of `split` from `dataset_name`.

    With streaming=True the split is opened as a stream and limited with
    .take(n); with streaming=False the slice syntax `split[:n]` is passed to
    `load_dataset` instead.
    """
    if not streaming:
        sliced_split = f"{split}[:{n}]"
        return load_dataset(dataset_name, split=sliced_split)
    stream = load_dataset(dataset_name, split=split, streaming=True)
    return stream.take(n)

# ---------- helpers ----------
def _first_caption_field(example):
    # For COCO: 'captions' is a list of dicts [{'text': ...}, ...]
    caps = example.get("captions") or []
    if caps:
        return caps[0]["text"] if isinstance(caps[0], dict) and "text" in caps[0] else str(caps[0])
    return example.get("caption", "")

# ---------- COCO (images + text) ----------
def load_coco_subsets(train_n: int, val_n: int):
    """
    Build disjoint streaming train/val subsets from the single 'train' split of
    sentence-transformers/coco-captions.

    Two independent streams are opened: the first `train_n` rows form the train
    subset; the val subset skips those `train_n` rows and takes the next `val_n`.
    Each row is mapped to carry an "image" and a single "caption".
    """
    def pick_first_caption(example):
        return {"image": example["image"], "caption": _first_caption_field(example)}

    def open_mapped_stream():
        # A fresh stream per subset, exactly as many opens as subsets.
        raw = load_dataset("sentence-transformers/coco-captions",
                           split="train", streaming=True)
        return raw.map(pick_first_caption)

    train_part = open_mapped_stream().take(train_n)
    # skipping `train_n` rows keeps the two subsets from overlapping
    val_part = open_mapped_stream().skip(train_n).take(val_n)
    return train_part, val_part

# ---------- AudioCaps (audio + text) ----------
def load_audiocaps_subsets(train_n: int, val_n: int):
    """
    Open streaming AudioCaps subsets: `train_n` rows of 'train' and `val_n` rows
    of 'validation' (or of 'test' when opening 'validation' raises).

    Each row is mapped to carry the original "audio" entry plus a normalised
    "caption" string.
    """
    def norm_audio(example):
        # Keep HF Audio feature (lazy decode). Add normalized caption.
        return {"audio": example["audio"], "caption": _first_caption_field(example)}

    def open_split(split_name, limit):
        stream = load_dataset("d0rj/audiocaps", split=split_name, streaming=True)
        return stream.map(norm_audio).take(limit)

    train_part = open_split("train", train_n)

    try:
        val_part = open_split("validation", val_n)
    except Exception:
        val_part = open_split("test", val_n)

    return train_part, val_part


In [28]:
print("Loading datasets (streaming, no full download)...")
# Subset sizes come from CFG; both helpers are called with streaming enabled internally.
COCO_TRAIN, COCO_VAL = load_coco_subsets(CFG.coco_images_train, CFG.coco_images_val)
AUDIO_TRAIN, AUDIO_VAL = load_audiocaps_subsets(CFG.audiocaps_train, CFG.audiocaps_val)


In [30]:
# Peek one example from each (doesn't materialize the whole thing)
# Pull just the first row of every subset, in train/val order for COCO then AudioCaps.
for _subset in (COCO_TRAIN, COCO_VAL, AUDIO_TRAIN, AUDIO_VAL):
    print(next(iter(_subset)))


In [None]:
# %% [markdown]
# ## 0.5 DataLoaders

# %%
class CocoITDataset(Dataset):
    """Map-style (image, caption) dataset over rows with "image" and "caption" keys.

    `hf_ds` may be anything indexable with a length (list, map-style HF dataset)
    or a plain iterable such as a streaming dataset. Streams support neither
    `len()` nor integer indexing, which a shuffling DataLoader needs, so
    non-indexable input is read once into a list at construction time.
    """

    def __init__(self, hf_ds):
        indexable = hasattr(hf_ds, "__len__") and hasattr(hf_ds, "__getitem__")
        # Materialising keeps every row in memory; sizes are bounded by Config.
        self.ds = hf_ds if indexable else list(hf_ds)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return `(image, caption)` for row `idx`."""
        ex = self.ds[int(idx)]
        img = ex["image"]  # PIL
        cap = ex["caption"]
        return img, cap

class AudioTDataset(Dataset):
    """Map-style (waveform, sample_rate, caption) dataset.

    Accepts rows in either of two layouts:
      * ``{"audio": {"array": ..., "sampling_rate": ...}, "caption": ...}`` —
        the decoded-audio mapping that the AudioCaps loader in this notebook
        passes through (assumed layout of the HF Audio feature — confirm for
        the installed `datasets` version);
      * ``{"audio": <samples>, "sr": <int>, "caption": ...}`` — the flat layout
        this class originally expected.

    Like the COCO dataset, a non-indexable source (e.g. a streaming dataset) is
    read once into a list, because a shuffling DataLoader needs `len()` and
    integer indexing.
    """

    def __init__(self, hf_ds):
        indexable = hasattr(hf_ds, "__len__") and hasattr(hf_ds, "__getitem__")
        self.ds = hf_ds if indexable else list(hf_ds)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return `(wave [1, T] float32, sr int, caption)` for row `idx`."""
        ex = self.ds[int(idx)]
        audio = ex["audio"]
        if isinstance(audio, dict):
            samples, sr = audio["array"], audio["sampling_rate"]
        else:
            samples, sr = audio, ex["sr"]
        wave = torch.tensor(samples).float().unsqueeze(0)  # [1, T]
        cap = ex["caption"]
        return wave, int(sr), cap

# Options shared by all four loaders: the worker count, and a collate function
# that transposes a list of per-sample tuples into per-field tuples
# (no tensor stacking, since images/waveforms differ in size).
_loader_common = dict(
    num_workers=CFG.num_workers,
    collate_fn=lambda b: list(zip(*b)),
)

# Train loaders shuffle and use the training batch size; val loaders keep order
# and use the (larger) eval batch size.
coco_train_loader = DataLoader(CocoITDataset(COCO_TRAIN), batch_size=CFG.batch_size, shuffle=True, **_loader_common)
coco_val_loader = DataLoader(CocoITDataset(COCO_VAL), batch_size=CFG.eval_batch, shuffle=False, **_loader_common)

audio_train_loader = DataLoader(AudioTDataset(AUDIO_TRAIN), batch_size=CFG.batch_size, shuffle=True, **_loader_common)
audio_val_loader = DataLoader(AudioTDataset(AUDIO_VAL), batch_size=CFG.eval_batch, shuffle=False, **_loader_common)

print("Loaders ready.")


In [None]:

# %% [markdown]
# ## 0.6 Pseudo (i,a,t) Triplets via CLIP Retrieval (for Phases 2 & 3)
# For each AudioCaps caption, find the nearest COCO image (from a pool) using CLIP text<->image similarity.

# %%
# Build an image pool from COCO_TRAIN + COCO_VAL
# Candidate images: the first CFG.image_pool_for_triplets COCO_TRAIN images,
# topped up from COCO_VAL when the train stream is shorter than the pool size.
pool_imgs = list(itertools.islice((ex["image"] for ex in COCO_TRAIN), CFG.image_pool_for_triplets))
if len(pool_imgs) < CFG.image_pool_for_triplets:
    # top-up with validation if needed
    remaining = CFG.image_pool_for_triplets - len(pool_imgs)
    pool_imgs += list(itertools.islice((ex["image"] for ex in COCO_VAL), remaining))
print("Image pool size:", len(pool_imgs))

# Precompute CLIP image embeddings for the pool
img_pool_embs = []
BS = 64
for i in tqdm(range(0, len(pool_imgs), BS), desc="Encoding image pool"):
    img_pool_embs.append(encode_image_pil(pool_imgs[i:i+BS]).cpu())
img_pool_embs = torch.cat(img_pool_embs, dim=0)  # [P, 512]
img_pool_embs = F.normalize(img_pool_embs, dim=-1)

# Read the AudioCaps train subset exactly once. AUDIO_TRAIN may be a streaming
# dataset, which supports neither len() nor integer indexing, so rows are
# collected while iterating. The "audio" entry is accepted either as a
# decoded-audio mapping ({"array", "sampling_rate"}) or as raw samples with a
# separate "sr" field.
audio_rows = []  # each: {"audio": float32 np.ndarray, "sr": int, "caption": str}
for ex in tqdm(AUDIO_TRAIN, desc="Retrieving images for AudioCaps captions"):
    audio = ex["audio"]
    if isinstance(audio, dict):
        samples, sr = audio["array"], audio["sampling_rate"]
    else:
        samples, sr = audio, ex["sr"]
    audio_rows.append({
        "audio": np.asarray(samples, dtype=np.float32),
        "sr": int(sr),
        "caption": ex["caption"],
    })
text_caps = [row["caption"] for row in audio_rows]

# Batch-encode text with CLIP for speed
cap_embs = []
for i in tqdm(range(0, len(text_caps), BS), desc="Encoding caps (CLIP)"):
    cap_embs.append(encode_text_with_clip(text_caps[i:i+BS]).detach().cpu())
cap_embs = torch.cat(cap_embs, dim=0)  # [N, 512]

# cosine sim -> nearest image index
cap_embs = F.normalize(cap_embs, dim=-1)
img_pool_embs_t = img_pool_embs.t()  # [512, P]
sims = cap_embs @ img_pool_embs_t  # [N, P]
nearest = sims.argmax(dim=1).tolist()

# Pair every audio clip with its top-1 retrieved image; the caption is shared.
pseudo_triplets = [
    {"image": pool_imgs[img_idx], "audio": row["audio"], "sr": row["sr"], "caption": row["caption"]}
    for row, img_idx in zip(audio_rows, nearest)
]

print("Pseudo triplets created:", len(pseudo_triplets))

# %% [markdown]
# ## 0.7 Adapters & Losses

# %%
class LinearAdapter(nn.Module):
    """Single linear projection followed by LayerNorm and L2 normalisation.

    Maps frozen-encoder features of width `in_dim` to unit vectors of width
    `out_dim`. Submodule names (`proj`, `ln`) are part of the saved state dict.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Project `x` ([..., in_dim]) and return unit-norm [..., out_dim]."""
        projected = self.ln(self.proj(x))
        return F.normalize(projected, dim=-1)

class MLPAdapter(nn.Module):
    """Two-layer GELU MLP followed by LayerNorm and L2 normalisation.

    Drop-in alternative to `LinearAdapter` with one hidden layer of width
    `hidden`. Submodule names (`net`, `ln`) are part of the saved state dict.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden: int = 1024):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)
        self.ln = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Project `x` ([..., in_dim]) and return unit-norm [..., out_dim]."""
        hidden_out = self.ln(self.net(x))
        return F.normalize(hidden_out, dim=-1)

# InfoNCE loss (symmetric)

def info_nce(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.07):
    """Symmetric InfoNCE over a batch of index-aligned pairs.

    Row i of `z1` and row i of `z2` form the positive pair; every other row in
    the batch serves as a negative. The loss is the mean of the z1->z2 and
    z2->z1 cross-entropies over temperature-scaled dot products.

    Args:
        z1, z2: [B, D] embeddings (expected to be L2-normalised by the caller).
        temperature: divisor applied to the similarity logits.

    Returns:
        Scalar loss tensor.
    """
    scaled_sim = (z1 @ z2.t()) / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    loss_12 = F.cross_entropy(scaled_sim, labels)
    loss_21 = F.cross_entropy(scaled_sim.t(), labels)
    return (loss_12 + loss_21) / 2

# Cycle consistency: pull z_v and z_a together when they share the same caption

def cycle_consistency(z_a: torch.Tensor, z_v: torch.Tensor, margin: float = 0.0):
    """Mean cosine distance between index-aligned audio and vision embeddings.

    Args:
        z_a, z_v: embeddings of equal shape; row i of each is a positive pair.
            The dot product equals cosine similarity only for unit-norm inputs.
        margin: accepted for interface compatibility; not used by the computation.

    Returns:
        Scalar tensor: mean over pairs of (1 - dot product along the last dim).
    """
    pair_sim = torch.sum(z_a * z_v, dim=-1)
    return (1 - pair_sim).mean()

# %% [markdown]
# ## 0.8 Utilities: batching encode, metrics, plots

# %%
@torch.inference_mode()
def batch_encode_images(pils: List["PIL.Image.Image"]) -> torch.Tensor:
    """Encode a list of PIL images in chunks of `CFG.eval_batch`.

    The annotation is a string, matching `encode_image_pil`: `Image` is not
    imported in this notebook, so an evaluated `Image.Image` annotation raises
    NameError as soon as the function is defined.

    Args:
        pils: images to embed (must be non-empty; `torch.cat` rejects an empty list).

    Returns:
        CPU tensor [len(pils), D], L2-normalised.
    """
    embs = []
    for i in range(0, len(pils), CFG.eval_batch):
        embs.append(encode_image_pil(pils[i:i+CFG.eval_batch]).cpu())
    return F.normalize(torch.cat(embs, dim=0), dim=-1)

@torch.inference_mode()
def batch_encode_texts(texts: List[str]) -> torch.Tensor:
    """Encode strings with `encode_text` in chunks of `CFG.eval_batch`.

    Returns a CPU tensor [len(texts), H], L2-normalised.
    """
    step = CFG.eval_batch
    chunks = [
        encode_text(texts[start:start + step]).cpu()
        for start in range(0, len(texts), step)
    ]
    return F.normalize(torch.cat(chunks, dim=0), dim=-1)

@torch.inference_mode()
def batch_encode_audio(waves: List[np.ndarray], srs: List[int]) -> torch.Tensor:
    """Encode waveforms one clip at a time, grouped in chunks of `CFG.eval_batch`.

    Args:
        waves: per-clip sample arrays; each is converted to a float tensor and
            given a leading channel dim before being passed to `encode_audio_wave`.
        srs: sampling rate for each entry of `waves`.

    Returns:
        CPU tensor [len(waves), H], L2-normalised.
    """
    step = CFG.eval_batch
    chunk_embs = []
    for start in range(0, len(waves), step):
        stop = start + step
        per_clip = [
            encode_audio_wave(torch.tensor(samples).float().unsqueeze(0).to(device), rate)
            for samples, rate in zip(waves[start:stop], srs[start:stop])
        ]
        # move each finished chunk off the accelerator before starting the next
        chunk_embs.append(torch.cat(per_clip, dim=0).detach().cpu())
    return F.normalize(torch.cat(chunk_embs, dim=0), dim=-1)

@torch.inference_mode()
def recall_at_k(sim: torch.Tensor, k: int = 1) -> float:
    """Fraction of queries whose positive appears among the top-k gallery items.

    Args:
        sim: [N, M] query-by-gallery similarity matrix; the positive for query i
            is gallery item i (the diagonal).
        k: cutoff. Values larger than the gallery size are clamped, since
            `topk` raises when asked for more items than exist.

    Returns:
        Recall@k as a Python float in [0, 1].
    """
    k = min(k, sim.size(1))
    topk = sim.topk(k, dim=1).indices
    # build the target indices on sim's device so the comparison also works off-CPU
    correct = torch.arange(sim.size(0), device=sim.device).view(-1, 1)
    hits = (topk == correct).any(dim=1).float().mean().item()
    return hits

@torch.inference_mode()
def map_score(sim: torch.Tensor) -> float:
    """Mean average precision when each query's only positive is the diagonal entry.

    Args:
        sim: [N, N] query-by-gallery similarity matrix.

    Returns:
        Mean over queries of `average_precision_score` against a one-hot label row.
    """
    n_queries = sim.size(0)
    # identity matrix: row i marks gallery item i as the single relevant item
    labels = np.eye(n_queries, dtype=np.float32)
    scores = sim.cpu().numpy()
    per_query_ap = [
        average_precision_score(labels[row], scores[row])
        for row in range(n_queries)
    ]
    return float(np.mean(per_query_ap))

@torch.inference_mode()
def eval_retrieval(z_q: torch.Tensor, z_g: torch.Tensor, direction: str):
    """Compute R@1/5/10 and mAP for retrieving `z_g` rows with `z_q` rows.

    Both inputs are re-normalised and compared by cosine similarity; query i's
    positive is assumed to be gallery row i.

    Args:
        z_q: [N, D] query embeddings.
        z_g: [N, D] gallery embeddings, index-aligned with `z_q`.
        direction: label used as key prefix, e.g. "t->i".

    Returns:
        Dict with keys "<direction>_R@1", "_R@5", "_R@10", "_mAP" (in that order).
    """
    cosine = F.normalize(z_q, dim=-1) @ F.normalize(z_g, dim=-1).t()
    metrics = {f"{direction}_R@{k}": recall_at_k(cosine, k) for k in (1, 5, 10)}
    metrics[f"{direction}_mAP"] = map_score(cosine)
    return metrics


def plot_pos_neg_hist(z1: torch.Tensor, z2: torch.Tensor, title: str):
    """Overlay histograms of positive vs. negative cosine similarities.

    Positives are the diagonal of the `z1`-by-`z2` cosine matrix (index-aligned
    pairs); negatives are every off-diagonal entry.
    """
    cosine = (F.normalize(z1, dim=-1) @ F.normalize(z2, dim=-1).t()).cpu().numpy()
    off_diagonal = ~np.eye(cosine.shape[0], dtype=bool)
    positives = cosine.diagonal()
    negatives = cosine[off_diagonal]

    fig, ax = plt.subplots()
    for values, label in ((positives, "positives"), (negatives, "negatives")):
        ax.hist(values, bins=40, alpha=0.6, label=label)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel("cosine similarity")
    ax.set_ylabel("count")
    plt.show()

# %% [markdown]
# ## 1. Phase 0 Sanity: Encode & Baseline (no training)

# %%
# Build eval tensors for (i,t) and (a,t)
# Walk each validation stream once, collecting all needed fields in the same
# pass (a second pass over a streaming dataset would re-read the data).
val_imgs, val_caps_it = [], []
for ex in COCO_VAL:
    val_imgs.append(ex["image"])
    val_caps_it.append(ex["caption"])

# The "audio" entry is accepted either as a decoded-audio mapping
# ({"array", "sampling_rate"}) or as raw samples with a separate "sr" field;
# batch_encode_audio needs plain sample arrays plus a rate per clip.
val_audio, val_srs, val_caps_at = [], [], []
for ex in AUDIO_VAL:
    audio = ex["audio"]
    if isinstance(audio, dict):
        samples, sr = audio["array"], audio["sampling_rate"]
    else:
        samples, sr = audio, ex["sr"]
    val_audio.append(np.asarray(samples, dtype=np.float32))
    val_srs.append(int(sr))
    val_caps_at.append(ex["caption"])

print("Encoding eval sets...")
Z_i_val = batch_encode_images(val_imgs)
Z_t_it_val = batch_encode_texts(val_caps_it)
Z_a_val = batch_encode_audio(val_audio, val_srs)
Z_t_at_val = batch_encode_texts(val_caps_at)

# NOTE(review): image and text/audio features come from different encoders
# (CLIP vs MiniLM vs Whisper); the cross-modal matmul only works if their widths
# happen to match, and the scores are not expected to be meaningful — confirm.
print("Eval (no adapters, just frozen encoders, cosine retrieval):")
res_baseline = {}
res_baseline.update(eval_retrieval(Z_t_it_val, Z_i_val, "t->i"))
res_baseline.update(eval_retrieval(Z_i_val, Z_t_it_val, "i->t"))
res_baseline.update(eval_retrieval(Z_t_at_val, Z_a_val, "t->a"))
res_baseline.update(eval_retrieval(Z_a_val, Z_t_at_val, "a->t"))
print(json.dumps(res_baseline, indent=2))

# %% [markdown]
# ## 2. Phase 1 — Pairwise-only training
# Train independent adapters for (i<->t) and (a<->t) with InfoNCE. Then *evaluate all three* directions, including i<->a (which should be weak).

# %%
# Infer encoder output dims
# Adapter input widths are read off the cached validation embeddings.
with torch.inference_mode():
    d_img = Z_i_val.size(1)
    d_txt = Z_t_it_val.size(1)
    d_aud = Z_a_val.size(1)
print("Encoder dims:", d_img, d_txt, d_aud)

# Projectors for Pairwise (separate text heads)
# proj_v/proj_t_v are trained on (image, text); proj_a/proj_t_a on (audio, text).
proj_v = LinearAdapter(d_img, CFG.embed_dim).to(device)
proj_t_v = LinearAdapter(d_txt, CFG.embed_dim).to(device)
proj_a = LinearAdapter(d_aud, CFG.embed_dim).to(device)
proj_t_a = LinearAdapter(d_txt, CFG.embed_dim).to(device)

# Only the (image, text) pair of adapters is optimised in this first loop.
opt = torch.optim.AdamW(list(proj_v.parameters()) + list(proj_t_v.parameters()), lr=CFG.lr, weight_decay=CFG.weight_decay)

# Loss scaling is active only when fp16 is requested AND we are on CUDA;
# this scaler object is reused by the later training loops.
scaler = torch.cuda.amp.GradScaler(enabled=CFG.fp16 and device.type=="cuda")

print("Training (i<->t) pairwise...")
for epoch in range(CFG.epochs_pairwise):
    proj_v.train(); proj_t_v.train()
    pbar = tqdm(coco_train_loader, desc=f"[Pairwise i<->t] epoch {epoch+1}")
    for imgs, caps in pbar:
        # The loader yields tuples per field, hence the list(...) conversions.
        with torch.cuda.amp.autocast(enabled=CFG.fp16 and device.type=="cuda"):
            # encode
            z_i = encode_image_pil(list(imgs))
            z_t = encode_text(list(caps))
            # project
            z_i = proj_v(z_i)
            z_t = proj_t_v(z_t)
            loss = info_nce(z_i, z_t)
        # Standard AMP step: clear grads, backprop the scaled loss, step, update scale.
        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        pbar.set_postfix({"loss": float(loss.item())})

# Separate optimiser for the (audio, text) adapters; the image-side adapters are untouched here.
opt2 = torch.optim.AdamW(list(proj_a.parameters()) + list(proj_t_a.parameters()), lr=CFG.lr, weight_decay=CFG.weight_decay)
print("Training (a<->t) pairwise...")
for epoch in range(CFG.epochs_pairwise):
    proj_a.train(); proj_t_a.train()
    pbar = tqdm(audio_train_loader, desc=f"[Pairwise a<->t] epoch {epoch+1}")
    for waves, srs, caps in pbar:
        # batch encode audio
        # Clips are encoded one at a time and concatenated into a [B, H] batch.
        z_a_list = []
        for w, sr in zip(waves, srs):
            w = w.to(device)
            z_a_list.append(encode_audio_wave(w, int(sr)))
        z_a = torch.cat(z_a_list, dim=0)
        z_t = encode_text(list(caps))
        # Unlike the i<->t loop, the frozen encoders run outside autocast here;
        # only the adapter projections and the loss are autocast.
        with torch.cuda.amp.autocast(enabled=CFG.fp16 and device.type=="cuda"):
            z_a = proj_a(z_a)
            z_t = proj_t_a(z_t)
            loss = info_nce(z_a, z_t)
        opt2.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt2)
        scaler.update()
        pbar.set_postfix({"loss": float(loss.item())})

# Eval Phase 1
@torch.inference_mode()
def eval_phase1():
    """Evaluate the pairwise-trained adapters on all six retrieval directions.

    Returns:
        (res, Zi, Zt_it, Za, Zt_at): the metrics dict followed by the projected
        validation embeddings (CPU) for images, COCO captions, audio and
        AudioCaps captions.

    NOTE(review): the i<->a directions pair row j of the COCO validation images
    with row j of the AudioCaps validation audio. Those rows come from different
    datasets, so the "diagonal is positive" assumption of `eval_retrieval` has
    no ground truth behind it here — confirm how i<->a should be scored.
    """
    for head in (proj_v, proj_t_v, proj_a, proj_t_a):
        head.eval()

    def project(head, feats):
        return head(feats.to(device)).cpu()

    Zi = project(proj_v, Z_i_val)
    Zt_it = project(proj_t_v, Z_t_it_val)   # text head from the (i,t) model
    Za = project(proj_a, Z_a_val)
    Zt_at = project(proj_t_a, Z_t_at_val)   # text head from the (a,t) model

    # (query, gallery, label); the image/audio heads were never trained against
    # each other, so i<->a is expected to be poor.
    directions = (
        (Zt_it, Zi, "t->i"),
        (Zi, Zt_it, "i->t"),
        (Zt_at, Za, "t->a"),
        (Za, Zt_at, "a->t"),
        (Zi, Za, "i->a"),
        (Za, Zi, "a->i"),
    )
    res = {}
    for queries, gallery, label in directions:
        res.update(eval_retrieval(queries, gallery, label))
    return res, Zi, Zt_it, Za, Zt_at

# Run the Phase 1 evaluation and keep the projected embeddings for plotting.
res_p1, Zi_p1, Zt_it_p1, Za_p1, Zt_at_p1 = eval_phase1()
print("Phase 1 results:")
print(json.dumps(res_p1, indent=2))

# Histograms
# Positive = same row index in both embedding sets, negative = any other pairing.
plot_pos_neg_hist(Zt_it_p1, Zi_p1, title="Phase 1: t<->i cosine")
plot_pos_neg_hist(Zt_at_p1, Za_p1, title="Phase 1: t<->a cosine")
plot_pos_neg_hist(Zi_p1, Za_p1, title="Phase 1: i<->a cosine (expected weak)")

# %% [markdown]
# ## 3. Phase 2 — Image Hub (i<->t + i<->a)
# Use pseudo triplets to supervise i<->a; keep encoders frozen; train small adapters for all three.

# %%
# Fresh adapters for hub training
# One adapter per modality, all projecting into the same CFG.embed_dim space.
hub_v = LinearAdapter(d_img, CFG.embed_dim).to(device)
hub_t = LinearAdapter(d_txt, CFG.embed_dim).to(device)
hub_a = LinearAdapter(d_aud, CFG.embed_dim).to(device)

# A single optimiser updates all three adapters jointly.
opt_hub = torch.optim.AdamW(list(hub_v.parameters()) + list(hub_t.parameters()) + list(hub_a.parameters()), lr=CFG.lr, weight_decay=CFG.weight_decay)

# Build simple loaders for pseudo (i,a,t) triplets
class PseudoTripletDS(Dataset):
    """Map-style view over a sequence of pseudo (image, audio, text) triplets.

    Each element of `triples` must be a mapping with the keys "image", "audio",
    "sr" and "caption".
    """

    def __init__(self, triples):
        self.triples = triples

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return `(image, audio, sr, caption)` with `sr` coerced to int."""
        row = self.triples[idx]
        return (row["image"], row["audio"], int(row["sr"]), row["caption"])

# Shuffled loader over the pseudo triplets; the collate function transposes the
# batch into per-field tuples (images, audios, sample rates, captions).
triplet_loader = DataLoader(PseudoTripletDS(pseudo_triplets), batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, collate_fn=lambda b: list(zip(*b)))

print("Training Hub (i<->t + i<->a)...")
for epoch in range(CFG.epochs_hub):
    hub_v.train(); hub_t.train(); hub_a.train()
    pbar = tqdm(triplet_loader, desc=f"[Hub i<->t,i<->a] epoch {epoch+1}")
    for imgs, auds, srs, caps in pbar:
        # Encode
        z_i = encode_image_pil(list(imgs))
        z_t = encode_text(list(caps))
        # Audio is encoded clip by clip and concatenated into a [B, H] batch.
        z_a_list = []
        for a, sr in zip(auds, srs):
            w = torch.tensor(a).float().unsqueeze(0).to(device)
            z_a_list.append(encode_audio_wave(w, int(sr)))
        z_a = torch.cat(z_a_list, dim=0)

        with torch.cuda.amp.autocast(enabled=CFG.fp16 and device.type=="cuda"):
            zi = hub_v(z_i)
            zt = hub_t(z_t)
            za = hub_a(z_a)
            # Hub losses: i<->t and i<->a only
            # (audio and text are never contrasted directly; the image is the hub)
            loss = info_nce(zi, zt) + info_nce(zi, za)
        opt_hub.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt_hub)
        scaler.update()
        pbar.set_postfix({"loss": float(loss.item())})

@torch.inference_mode()
def eval_phase2():
    """Evaluate the image-hub adapters on all six retrieval directions.

    The single text head `hub_t` projects both caption sets.

    Returns:
        (res, Zi, Zt_it, Za, Zt_at): metrics dict plus projected validation
        embeddings (CPU).

    NOTE(review): i<->a pairs COCO validation images with AudioCaps validation
    audio by row index; these come from different datasets, so the diagonal
    positives assumed by `eval_retrieval` have no ground truth here — confirm.
    """
    for head in (hub_v, hub_t, hub_a):
        head.eval()

    def project(head, feats):
        return head(feats.to(device)).cpu()

    Zi = project(hub_v, Z_i_val)
    Zt_it = project(hub_t, Z_t_it_val)
    Za = project(hub_a, Z_a_val)
    Zt_at = project(hub_t, Z_t_at_val)  # same text head

    # (query, gallery, label). a<->t had no direct loss: it tests generalisation.
    directions = (
        (Zt_it, Zi, "t->i"),
        (Zi, Zt_it, "i->t"),
        (Zi, Za, "i->a"),
        (Za, Zi, "a->i"),
        (Zt_at, Za, "t->a"),
        (Za, Zt_at, "a->t"),
    )
    res = {}
    for queries, gallery, label in directions:
        res.update(eval_retrieval(queries, gallery, label))
    return res, Zi, Zt_it, Za, Zt_at

# Run the Phase 2 evaluation and keep the projected embeddings for plotting.
res_p2, Zi_p2, Zt_it_p2, Za_p2, Zt_at_p2 = eval_phase2()
print("Phase 2 (Hub) results:")
print(json.dumps(res_p2, indent=2))

# Only the image<->audio similarity distribution is plotted for this phase.
plot_pos_neg_hist(Zi_p2, Za_p2, title="Phase 2: i<->a cosine (hub)")

# %% [markdown]
# ## 4. Phase 3 — Tri-modal Adapters + Cycle Consistency (Ours)
# Jointly train v/a/t adapters in a *shared* space with InfoNCE on (i<->t) and (a<->t), plus cycle loss tying (i,a) that share the same caption.

# %%
# Fresh adapters for the tri-modal run, one per modality, sharing CFG.embed_dim.
tri_v = LinearAdapter(d_img, CFG.embed_dim).to(device)
tri_t = LinearAdapter(d_txt, CFG.embed_dim).to(device)
tri_a = LinearAdapter(d_aud, CFG.embed_dim).to(device)

# A single optimiser updates all three adapters jointly.
opt_tri = torch.optim.AdamW(list(tri_v.parameters()) + list(tri_t.parameters()) + list(tri_a.parameters()), lr=CFG.lr, weight_decay=CFG.weight_decay)

print("Training Tri-modal (i<->t, a<->t) + Cycle(i~a|t)...")
for epoch in range(CFG.epochs_trimodal):
    tri_v.train(); tri_t.train(); tri_a.train()
    pbar = tqdm(triplet_loader, desc=f"[Tri-Modal + Cycle] epoch {epoch+1}")
    for imgs, auds, srs, caps in pbar:
        # Encode
        z_i = encode_image_pil(list(imgs))
        z_t = encode_text(list(caps))
        # Audio is encoded clip by clip and concatenated into a [B, H] batch.
        z_a_list = []
        for a, sr in zip(auds, srs):
            w = torch.tensor(a).float().unsqueeze(0).to(device)
            z_a_list.append(encode_audio_wave(w, int(sr)))
        z_a = torch.cat(z_a_list, dim=0)

        with torch.cuda.amp.autocast(enabled=CFG.fp16 and device.type=="cuda"):
            zi = tri_v(z_i)
            zt = tri_t(z_t)
            za = tri_a(z_a)
            # Text is the anchor for both contrastive terms; the cycle term
            # (weight 0.2) directly pulls the paired audio and image together.
            loss = info_nce(zi, zt) + info_nce(za, zt) + 0.2 * cycle_consistency(za, zi)
        opt_tri.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt_tri)
        scaler.update()
        pbar.set_postfix({"loss": float(loss.item())})

@torch.inference_mode()
def eval_phase3():
    """Evaluate the tri-modal (cycle-consistent) adapters on all six directions.

    The single text head `tri_t` projects both caption sets.

    Returns:
        (res, Zi, Zt_it, Za, Zt_at): metrics dict plus projected validation
        embeddings (CPU).

    NOTE(review): i<->a pairs COCO validation images with AudioCaps validation
    audio by row index; these come from different datasets, so the diagonal
    positives assumed by `eval_retrieval` have no ground truth here — confirm.
    """
    for head in (tri_v, tri_t, tri_a):
        head.eval()

    def project(head, feats):
        return head(feats.to(device)).cpu()

    Zi = project(tri_v, Z_i_val)
    Zt_it = project(tri_t, Z_t_it_val)
    Za = project(tri_a, Z_a_val)
    Zt_at = project(tri_t, Z_t_at_val)

    # All six directions, as (query, gallery, label).
    directions = (
        (Zt_it, Zi, "t->i"),
        (Zi, Zt_it, "i->t"),
        (Zt_at, Za, "t->a"),
        (Za, Zt_at, "a->t"),
        (Zi, Za, "i->a"),
        (Za, Zi, "a->i"),
    )
    res = {}
    for queries, gallery, label in directions:
        res.update(eval_retrieval(queries, gallery, label))
    return res, Zi, Zt_it, Za, Zt_at

# Run the Phase 3 evaluation and keep the projected embeddings for plotting.
res_p3, Zi_p3, Zt_it_p3, Za_p3, Zt_at_p3 = eval_phase3()
print("Phase 3 (Ours) results:")
print(json.dumps(res_p3, indent=2))

# Only the image<->audio similarity distribution is plotted for this phase.
plot_pos_neg_hist(Zi_p3, Za_p3, title="Phase 3: i<->a cosine (ours)")

# %% [markdown]
# ## 5. Comparison Table (Phases 1–3)

# %%
import pandas as pd

# unify keys
# Column order of the comparison table: four metrics for each of six directions.
keys = [
    "t->i_R@1","t->i_R@5","t->i_R@10","t->i_mAP",
    "i->t_R@1","i->t_R@5","i->t_R@10","i->t_mAP",
    "t->a_R@1","t->a_R@5","t->a_R@10","t->a_mAP",
    "a->t_R@1","a->t_R@5","a->t_R@10","a->t_mAP",
    "i->a_R@1","i->a_R@5","i->a_R@10","i->a_mAP",
    "a->i_R@1","a->i_R@5","a->i_R@10","a->i_mAP",
]

def row_of(res: Dict[str, float]):
    """Lay out one results dict in `keys` order, NaN for any metric it lacks."""
    return [res[k] if k in res else float('nan') for k in keys]

# One row per phase, in experiment order.
_phase_results = [res_p1, res_p2, res_p3]
df = pd.DataFrame(
    [row_of(phase_res) for phase_res in _phase_results],
    index=["Pairwise-only", "Hub (image)", "Ours (Tri+Cycle)"],
    columns=keys,
)
print(df.round(4))

# Matplotlib bar chart: focus on i<->a R@1 across the three settings
fig, ax = plt.subplots()
vals = [phase_res.get("i->a_R@1", 0.0) for phase_res in _phase_results]
ax.bar(["Pairwise", "Hub", "Ours"], vals)
ax.set_title("i->a R@1 across settings")
ax.set_ylabel("Recall@1")
plt.show()

# %% [markdown]
# ## 6. Save Artifacts

# %%
out_dir = Path("artifacts")
out_dir.mkdir(parents=True, exist_ok=True)

# One checkpoint file per trained adapter group; each file stores a dict that
# maps the adapter's variable name to its state dict.
_checkpoints = {
    "phase1_it.pt": {"proj_v": proj_v, "proj_t_v": proj_t_v},
    "phase1_at.pt": {"proj_a": proj_a, "proj_t_a": proj_t_a},
    "phase2_hub.pt": {"hub_v": hub_v, "hub_t": hub_t, "hub_a": hub_a},
    "phase3_ours.pt": {"tri_v": tri_v, "tri_t": tri_t, "tri_a": tri_a},
}
for _filename, _adapters in _checkpoints.items():
    _payload = {name: adapter.state_dict() for name, adapter in _adapters.items()}
    torch.save(_payload, out_dir/_filename)

# The comparison table is written next to the checkpoints, row labels included.
df.to_csv(out_dir/"comparison_phase1_2_3.csv", index=True)

print("Saved:", list(out_dir.iterdir()))

# %% [markdown]
# ---
# ### Notes & Next Steps
# - Phases 4–7 can build on these adapters:
#   * Phase 4: profile FLOPs/VRAM (use torch.cuda.max_memory_allocated, measure timestamps, count parameters in adapters).
#   * Phase 5: add Matryoshka heads (multiple width slices in adapters) + token resampling.
#   * Phase 6: apply audio noise / image blur and re-run eval.
#   * Phase 7: ablations (turn off cycle, swap LinearAdapter with MLPAdapter, vary pseudo triplet %).
# - If memory becomes a bottleneck, reduce subset sizes in Config.
# - To speed up triplet retrieval, cache CLIP image pool embeddings to disk.
