In [None]:
# ! uv pip install tf-keras

In [2]:
import os, math, random, time
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import pandas as pd

from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from transformers import WhisperProcessor, WhisperModel
from sentence_transformers import SentenceTransformer

2025-10-13 11:22:59.745395: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-13 11:23:06.092460: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-10-13 11:23:27.405891: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [3]:
# -----------------------
# Repro & device
# -----------------------
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [4]:
# -----------------------
# Config
# -----------------------
@dataclass
class CFG:
    out_dir: str = "./phase1_artifacts"
    # Data slice to keep runtime manageable. Increase for stronger curves.
    max_train: int = 6000
    max_val: int = 2000
    batch_size: int = 32
    num_workers: int = 4
    epochs: int = 5
    lr: float = 1e-3
    temp: float = 0.07
    # Target text embedding dim
    dim_t: int = 384  # MiniLM all-MiniLM-L6-v2 returns 384-d by default
    # Encoders
    clip_name: str = "openai/clip-vit-base-patch16"
    whisper_name: str = "openai/whisper-small"
    sent_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    # SpokenCOCO split names
    dataset_name: str = "mteb/SpeechCoco"  # SpokenCOCO on HF
    # mixed precision for faster embedding extraction
    amp: bool = True

os.makedirs(CFG.out_dir, exist_ok=True)


In [5]:
# -----------------------
# Helpers
# -----------------------
def l2(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return x / (x.norm(dim=-1, keepdim=True) + eps)

# -----------------------
# Load dataset (SpokenCOCO)
# Fields: id, image_id, audio, image, text, ...
# Splits: train / validation
# -----------------------
print("Loading SpokenCOCO subset (10%)...")
ds = load_dataset("mteb/SpeechCoco",streaming=True)

Loading SpokenCOCO subset (10%)...


Resolving data files:   0%|          | 0/230 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/111 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/230 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/111 [00:00<?, ?it/s]

In [None]:
from itertools import islice
from copy import deepcopy

class StreamHead:
    """
    Provides __iter__ and a constant __len__=n without materializing to disk.
    `make_stream` must be a callable that returns a *fresh* HF stream each time.
    """
    def __init__(self, make_stream, n):
        self.make_stream = make_stream
        self.n = n

    def __iter__(self):
        # new iterator each time so you can re-iterate in multiple epochs
        stream = self.make_stream()
        return islice(stream, self.n)

    def __len__(self):
        return self.n  # pretend length for code paths that call len()

# Usage:
def make_train_stream():
    return load_dataset("mteb/SpeechCoco", streaming=True, split="train").shuffle(seed=42, buffer_size=5000)

def make_val_stream():
    return load_dataset("mteb/SpeechCoco", streaming=True, split="validation")

train_raw = StreamHead(make_train_stream, n=500)
val_raw   = StreamHead(make_val_stream,   n=200)
print("Train size:", len(train_raw), "Val size:", len(val_raw))

Train size: 500 Val size: 200


In [7]:
train_raw[0]

TypeError: 'StreamHead' object is not subscriptable

In [None]:
# # def take_first_n(split_ds, n):
# #     return split_ds.select(range(min(n, len(split_ds))))

# # train_raw = take_first_n(ds["train"], CFG.max_train)
# # val_raw   = take_first_n(ds["validation"], CFG.max_val)

# # print("Train size:", len(train_raw), "Val size:", len(val_raw))

# # Take only small heads of each split (doesn't store to disk)
# train_stream = ds["train"].take(500)         # first 500 samples
# val_stream   = ds["validation"].take(200)    # first 200 samples


In [7]:
# -----------------------
# Build frozen encoders
# -----------------------
print("Loading frozen encoders ...")
clip_model = CLIPModel.from_pretrained(CFG.clip_name).eval().to(device)
clip_proc  = CLIPProcessor.from_pretrained(CFG.clip_name)

whisper_model = WhisperModel.from_pretrained(CFG.whisper_name).eval().to(device)
whisper_proc  = WhisperProcessor.from_pretrained(CFG.whisper_name)

sent_model = SentenceTransformer(CFG.sent_name, device=str(device))

for p in clip_model.parameters():    p.requires_grad = False
for p in whisper_model.parameters(): p.requires_grad = False
# SentenceTransformer is already in eval/frozen mode for encode()


Loading frozen encoders ...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [8]:
# -----------------------
# Embedding functions
# -----------------------
@torch.no_grad()
def embed_image(pil_images: List):
    # Returns [N, D_v]; CLIP returns 512-d by default for ViT-B/16
    inputs = clip_proc(images=pil_images, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    with torch.autocast(device_type=device.type, dtype=torch.float16) if (CFG.amp and device.type=="cuda") else torch.no_grad():
        img_feats = clip_model.get_image_features(pixel_values=pixel_values)
    return l2(img_feats.float()).cpu()

@torch.no_grad()
def embed_audio(list_of_np_wavs: List, sampling_rate: int):
    # Whisper expects 16k mono; datasets usually provide .array & .sampling_rate
    # We'll resample via processor if needed.
    # Strategy: run encoder, mean-pool last_hidden_state.
    batch_inputs = whisper_proc(
        audio=list_of_np_wavs, sampling_rate=sampling_rate, return_tensors="pt"
    )
    input_features = batch_inputs.input_features.to(device)
    with torch.autocast(device_type=device.type, dtype=torch.float16) if (CFG.amp and device.type=="cuda") else torch.no_grad():
        enc_out = whisper_model.encoder(input_features)
    hidden = enc_out.last_hidden_state  # [B, T, D_a]
    a_emb = hidden.mean(dim=1)          # [B, D_a]
    return l2(a_emb.float()).cpu()

@torch.no_grad()
def embed_text(list_of_strs: List):
    # SentenceTransformer handles batching internally; ensure normalize_embeddings=True
    t = sent_model.encode(list_of_strs, batch_size=64, convert_to_tensor=True, normalize_embeddings=True)
    return t.cpu().float()  # [N, dim_t]


In [9]:
# -----------------------
# Pre-encode a slice of the dataset for speed
# -----------------------
def preencode_split(split, max_batch_images=32, max_batch_audio=16, max_batch_text=128):
    Vs, As, Ts = [], [], []
    # image/audio/text are aligned per row; keep the order for retrieval GT.
    # 1) Images
    pil_buf = []
    idxs_img = []
    for i, row in enumerate(split):
        pil_buf.append(row["image"])
        idxs_img.append(i)
        if len(pil_buf) == max_batch_images or i == len(split)-1:
            Vs.append(embed_image(pil_buf))
            pil_buf.clear()
    V = torch.cat(Vs, dim=0)

    # 2) Audio (datasets audio object provides "array" and "sampling_rate")
    wav_buf, sr_buf = [], []
    for i, row in enumerate(split):
        aud = row["audio"]
        wav_buf.append(aud["array"])
        sr_buf.append(aud["sampling_rate"])
        if len(wav_buf) == max_batch_audio or i == len(split)-1:
            # resample to first SR in batch if differs? WhisperProcessor handles list with distinct sampling_rate as separate calls
            # For simplicity, call per-item if SRs differ in the batch
            if len(set(sr_buf)) == 1:
                As.append(embed_audio(wav_buf, sr_buf[0]))
            else:
                for w, s in zip(wav_buf, sr_buf):
                    As.append(embed_audio([w], s))
                As = [torch.cat(As, dim=0)]
            wav_buf, sr_buf = [], []
    A = torch.cat(As, dim=0)

    # 3) Text
    texts = [row["text"] for row in split]
    # chunk long lists to avoid OOM inside sentence-transformers
    Tparts = []
    for i in range(0, len(texts), max_batch_text):
        Tparts.append(embed_text(texts[i:i+max_batch_text]))
    T = torch.cat(Tparts, dim=0)

    assert V.shape[0] == A.shape[0] == T.shape[0] == len(split)
    return V, A, T

print("Pre-encoding TRAIN ... (this may take a few minutes the first time)")
V_tr, A_tr, T_tr = preencode_split(train_raw)
print("Pre-encoding VAL ...")
V_va, A_va, T_va = preencode_split(val_raw)

dim_v = V_tr.shape[1]
dim_a = A_tr.shape[1]
print(f"Emb dims -> vision: {dim_v}, audio: {dim_a}, text: {T_tr.shape[1]} (target={CFG.dim_t})")


Pre-encoding TRAIN ... (this may take a few minutes the first time)


Resolving data files:   0%|          | 0/230 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/111 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/230 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/111 [00:00<?, ?it/s]

TypeError: only a single or a list of entries is supported but got type=<class 'dict'>

In [None]:
# -----------------------
# Torch datasets over the embeddings
# -----------------------
class TripletEmbedDataset(Dataset):
    def __init__(self, V, A, T):
        self.V = V; self.A = A; self.T = T
        assert len(V) == len(A) == len(T)
    def __len__(self): return len(self.V)
    def __getitem__(self, i):
        return self.V[i], self.A[i], self.T[i], i  # use index as "class" for retrieval GT

train_ds = TripletEmbedDataset(V_tr, A_tr, T_tr)
val_ds   = TripletEmbedDataset(V_va, A_va, T_va)

train_loader = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, drop_last=True)
val_loader   = DataLoader(val_ds, batch_size=CFG.batch_size, shuffle=False,  num_workers=CFG.num_workers)


In [None]:
# -----------------------
# Linear adapters + InfoNCE
# -----------------------
class LinearAdapter(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(din),
            nn.Linear(din, dout, bias=False),
            nn.LayerNorm(dout)
        )
    def forward(self, x): return l2(self.net(x))

class PairwiseAligner(nn.Module):
    def __init__(self, dim_v, dim_a, dim_t):
        super().__init__()
        self.v = LinearAdapter(dim_v, dim_t)
        self.a = LinearAdapter(dim_a, dim_t)
    def forward(self, v, a):
        return self.v(v), self.a(a)

def info_nce_symmetric(z1: torch.Tensor, z2: torch.Tensor, temp: float) -> torch.Tensor:
    sim = (z1 @ z2.t()) / temp
    y = torch.arange(z1.size(0), device=z1.device)
    return 0.5*(F.cross_entropy(sim, y) + F.cross_entropy(sim.t(), y))

model = PairwiseAligner(dim_v, dim_a, CFG.dim_t).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=CFG.lr)


In [None]:

# -----------------------
# Train (v↔t and a↔t only)
# -----------------------
for epoch in range(1, CFG.epochs+1):
    model.train(); losses=[]
    for Vb, Ab, Tb, _ in train_loader:
        Vb, Ab, Tb = Vb.to(device), Ab.to(device), Tb.to(device)
        Vt, At = model(Vb, Ab)
        loss = info_nce_symmetric(Vt, Tb, CFG.temp) + info_nce_symmetric(At, Tb, CFG.temp)
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
    print(f"Epoch {epoch:02d} | train loss: {np.mean(losses):.4f}")


In [None]:
# -----------------------
# Eval utils
# -----------------------
@torch.no_grad()
def collect_embeds(loader, model):
    model.eval()
    Vt_all, At_all, T_all = [], [], []
    for Vb, Ab, Tb, _ in loader:
        Vb, Ab, Tb = Vb.to(device), Ab.to(device), Tb.to(device)
        Vt, At = model(Vb, Ab)
        Vt_all.append(Vt.cpu()); At_all.append(At.cpu()); T_all.append(Tb.cpu())
    return torch.cat(Vt_all), torch.cat(At_all), torch.cat(T_all)

def cosine_matrix(X, Y): return X @ Y.t()

def recall_at_k(sim: torch.Tensor, k: int) -> float:
    topk = torch.topk(sim, k, dim=1).indices
    gt = torch.arange(sim.size(0)).unsqueeze(1)
    return float((topk == gt).any(dim=1).float().mean() * 100.0)

def pos_neg(sim: torch.Tensor):
    pos = sim.diag().numpy()
    neg = sim[~torch.eye(sim.size(0), dtype=torch.bool)].numpy()
    return pos, neg

In [None]:
# -----------------------
# Evaluate on VAL
# -----------------------
Vt, At, Tt = collect_embeds(val_loader, model)

sim_ti = cosine_matrix(Tt, Vt)  # text->image
sim_ta = cosine_matrix(Tt, At)  # text->audio
sim_ia = cosine_matrix(Vt, At)  # image->audio (NO direct loss)  <-- should be weak

metrics = []
def add(name, sim):
    metrics.append({
        "Pair": name,
        "R@1":  round(recall_at_k(sim, 1), 2),
        "R@5":  round(recall_at_k(sim, 5), 2),
        "R@10": round(recall_at_k(sim, 10), 2),
        "MeanCos(PosDiag)": round(sim.diag().mean().item(), 3)
    })
add("t→i", sim_ti); add("i→t", sim_ti.t())
add("t→a", sim_ta); add("a→t", sim_ta.t())
add("i→a", sim_ia); add("a→i", sim_ia.t())

df = pd.DataFrame(metrics)
display(df)

csv_path = os.path.join(CFG.out_dir, "B0_pairwise_baseline_metrics.csv")
df.to_csv(csv_path, index=False)
print("Saved:", csv_path)

# -----------------------
# Histograms
# -----------------------
def plot_hist(pos, neg, title, fname):
    plt.figure(figsize=(6,4))
    plt.hist(pos, bins=50, alpha=0.6, label="Positives")
    plt.hist(neg, bins=50, alpha=0.6, label="Negatives")
    plt.xlabel("Cosine similarity"); plt.ylabel("Count")
    plt.title(title); plt.legend(); plt.tight_layout()
    out = os.path.join(CFG.out_dir, fname)
    plt.savefig(out, dpi=150); plt.show()
    print("Saved:", out)

p_ti, n_ti = pos_neg(sim_ti)
p_ta, n_ta = pos_neg(sim_ta)
p_ia, n_ia = pos_neg(sim_ia)

plot_hist(p_ti, n_ti, "t↔i Cosine — Pos vs Neg", "hist_t_i.png")
plot_hist(p_ta, n_ta, "t↔a Cosine — Pos vs Neg", "hist_t_a.png")
plot_hist(p_ia, n_ia, "i↔a Cosine — Pos vs Neg (no direct loss)", "hist_i_a.png")

print("Done.")