In [4]:
! uv pip install faiss

In [1]:
import os, json, math, random, time, io
from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel, AutoProcessor, CLIPProcessor, CLIPModel, WhisperModel, WhisperProcessor
from sentence_transformers import SentenceTransformer
from PIL import Image
import torchaudio
import faiss

from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt

In [None]:


# ---------- Repro ----------
def set_seed(seed=42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic mode.

    Seeds, in order: Python's `random`, NumPy's global RNG, torch's CPU RNG and
    the RNG of every CUDA device. Then turns off cuDNN autotuning so repeated
    runs pick the same kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix all RNGs before any model or sampling code runs.
set_seed(42)

# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------- Metrics ----------
def recall_at_k(scores: np.ndarray, gt: List[int], ks=(1,5,10)):
    """Recall@k for retrieval with exactly one relevant item per query.

    Args:
        scores: [num_queries, num_items] similarity matrix (higher is better).
        gt: gt[i] is the column index of the item relevant to query i. Only the
            first len(gt) rows of `scores` are evaluated.
        ks: cutoffs to report.

    Returns:
        Dict mapping 'R@{k}' to the fraction of queries whose relevant item is
        among the k highest-scoring items. Every value is NaN when `gt` is
        empty, because the ratio is undefined (this used to raise
        ZeroDivisionError).

    Raises:
        IndexError: `gt` has more entries than `scores` has rows.
    """
    n_queries = len(gt)
    if n_queries == 0:
        return {f'R@{k}': float('nan') for k in ks}

    ranks = np.argsort(-scores, axis=1)
    if n_queries > ranks.shape[0]:
        raise IndexError(
            f"gt has {n_queries} entries but scores only has {ranks.shape[0]} rows"
        )

    # Find each query's relevant item in its ranking once, then reuse that
    # position for every cutoff instead of re-scanning the top-k slice per k.
    matches = ranks[:n_queries] == np.asarray(gt).reshape(-1, 1)
    found = matches.any(axis=1)          # False when gt[i] is not a valid column
    first_pos = matches.argmax(axis=1)   # 0-based rank; only meaningful where found

    out = {}
    for k in ks:
        hits = int(np.count_nonzero(found & (first_pos < k)))
        out[f'R@{k}'] = hits / n_queries
    return out

def mean_average_precision(scores: np.ndarray, gt: List[int]):
    """mAP over all items with a single relevant item per query.

    With one relevant index per query, the average precision of a query is
    1 / (1-based rank of that index), or 0.0 if the index never appears in the
    ranking. The result is the mean of those values over `gt`.

    Args:
        scores: [num_queries, num_items] similarity matrix (higher is better).
        gt: gt[i] is the relevant column index for query i.

    Returns:
        The mean as a Python float.
    """
    order = np.argsort(-scores, axis=1)

    def _reciprocal_rank(query_idx, target):
        # positions (0-based) at which the target occurs in this query's ranking
        where_hit = np.flatnonzero(order[query_idx] == target)
        if where_hit.size == 0:
            return 0.0
        return 1.0 / (where_hit[0] + 1)

    per_query = [_reciprocal_rank(q, target) for q, target in enumerate(gt)]
    return float(np.mean(per_query))

def cosine_sim(a: np.ndarray, b: np.ndarray):
    """Pairwise cosine similarity between the rows of `a` and the rows of `b`.

    Each row is divided by its L2 norm plus 1e-8 (so all-zero rows stay zero
    instead of producing NaN), then the two row-normalised matrices are
    multiplied.

    Returns:
        Array of shape [a.shape[0], b.shape[0]].
    """
    def _unit_rows(mat):
        row_norms = np.linalg.norm(mat, axis=1, keepdims=True)
        return mat / (row_norms + 1e-8)

    return np.matmul(_unit_rows(a), _unit_rows(b).T)

def plot_cos_hist(pos_sims, neg_sims, title):
    """Overlay 50-bin histograms of negative and positive cosine similarities.

    Negatives are drawn first and positives on top, both semi-transparent,
    in a new 5 x 3.2 inch figure that is shown immediately.
    """
    _, ax = plt.subplots(figsize=(5,3.2))
    for sims, name in ((neg_sims, "negatives"), (pos_sims, "positives")):
        ax.hist(sims, bins=50, alpha=0.6, label=name)
    ax.set_title(title)
    ax.set_xlabel("cosine similarity")
    ax.set_ylabel("count")
    ax.legend()
    plt.show()


In [None]:
@dataclass
class Config:
    """Tunable sizes for the retrieval experiments in this notebook."""
    # sample sizes (tune to your GPU / time budget)
    N_COCO: int = 5000        # passed as `n` to load_coco_pairs
    N_AUDIOCAPS: int = 5000   # passed as `n` to load_audiocaps_pairs

    # image size / audio sampling
    IMG_SIZE: int = 224       # NOTE(review): not referenced in the visible cells - confirm it is still needed
    AUDIO_SR: int = 16000     # sampling rate handed to load_audiocaps_pairs and whisper_embed_batch

    # batch sizes
    BATCH_CLIP: int = 64      # CLIP image and text embedding loops
    BATCH_AUDIO: int = 32     # Whisper audio embedding loop
    BATCH_TEXT: int = 128     # MiniLM text embedding loop

# Single shared instance read by the cells below.
cfg = Config()

# Optional local fallbacks: the loaders below only read these when the Hugging Face path raises.
# Each root must contain a *.jsonl file (one JSON object per line) whose paths are relative to the root:
#   COCO      -> {"image_path": "...", "caption": "..."}
#   AudioCaps -> {"wav_path": "...", "caption": "..."}
LOCAL_COCO = None          # e.g., "/data/coco_2017_val_subset/"
LOCAL_AUDIOCAPS = None     # e.g., "/data/audiocaps_subset/"


In [None]:
def load_coco_pairs(n=5000, local_root=None):
    """Load up to `n` (PIL image, caption) pairs from COCO captions.

    Tries the Hugging Face `coco_captions` 2017 validation split first, keeping
    the first caption of every example that has one. If anything in that path
    raises, falls back to a local directory.

    Args:
        n: maximum number of pairs to return.
        local_root: optional directory containing a `*.jsonl` file whose lines
            are {"image_path": "...", "caption": "..."}; image paths are
            resolved relative to `local_root`. When several jsonl files exist
            the alphabetically first one is used.

    Returns:
        List of (image, caption) tuples, at most `n` long.

    Raises:
        RuntimeError: the Hugging Face path failed and `local_root` is not set
            (the original error is attached as `__cause__`).
        AssertionError: `local_root` contains no jsonl file.
    """
    data = []
    try:
        ds = load_dataset("coco_captions", "2017", split="validation")
        # each item: {image, captions:[{id, caption}]}
        for ex in ds.select(range(min(n, len(ds)))):
            caps = [c["caption"] for c in ex["captions"]]
            if not caps:
                continue
            data.append((ex["image"], caps[0]))
    except Exception as e:
        if not local_root:
            raise RuntimeError(f"HF coco_captions failed and no local_root provided: {e}") from e
        # Expect local_root to contain 'images/' and a jsonl with {"image_path":"...", "caption":"..."}
        import json, glob
        # The HF loop may have failed part-way through; discard whatever it
        # collected so HF rows and local rows are never mixed in one sample.
        data = []
        jfiles = sorted(glob.glob(os.path.join(local_root, "*.jsonl")))
        assert jfiles, "Provide a jsonl with image_path, caption"
        with open(jfiles[0], "r", encoding="utf-8") as f:
            for line in f:
                if len(data) >= n:
                    break
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                obj = json.loads(line)
                # convert() returns an independent in-memory copy, so the file
                # handle can be released as soon as the block exits.
                with Image.open(os.path.join(local_root, obj["image_path"])) as im:
                    img = im.convert("RGB")
                data.append((img, obj["caption"]))
    return data[:n]

# Load the COCO sample used by the CLIP cell below.
coco_pairs = load_coco_pairs(n=cfg.N_COCO, local_root=LOCAL_COCO)

# Quick look: how many pairs came back and the start of the first caption.
(len(coco_pairs), coco_pairs[0][1][:100])


In [None]:
def load_audiocaps_pairs(n=5000, sr=16000, local_root=None):
    """Load up to `n` (waveform, sample_rate, caption) triples from AudioCaps.

    Tries the Hugging Face `audiocaps` dataset first (the `balanced_train`
    config, then the default config). If loading or decoding fails at any
    point, falls back to a local directory.

    Args:
        n: maximum number of triples to return.
        sr: target sampling rate; audio at another rate is resampled.
        local_root: optional directory containing a `*.jsonl` file whose lines
            are {"wav_path": "...", "caption": "..."}; wav paths are resolved
            relative to `local_root`. When several jsonl files exist the
            alphabetically first one is used.

    Returns:
        List of (np.ndarray waveform, sr, caption) tuples, at most `n` long.
        Locally loaded multi-channel files are downmixed to one channel.

    Raises:
        RuntimeError: the Hugging Face path failed and `local_root` is not set
            (the original error is attached as `__cause__`).
        AssertionError: `local_root` contains no jsonl file.
    """
    out = []
    try:
        # Both load attempts sit inside the outer try so that a failure of the
        # second one still reaches the local fallback below.
        try:
            ds = load_dataset("audiocaps", "balanced_train", split="validation")
        except Exception:
            # Some configs differ; fall back if needed
            ds = load_dataset("audiocaps", split="validation")

        for ex in ds.select(range(min(n, len(ds)))):
            # ex has "audio": {'array', 'sampling_rate'}, 'caption'
            au = ex["audio"]
            wav = torch.tensor(au["array"], dtype=torch.float32)
            wav_sr = au["sampling_rate"]
            if wav_sr != sr:
                wav = torchaudio.functional.resample(wav, wav_sr, sr)
            out.append((wav.numpy(), sr, ex["caption"]))
    except Exception as e:
        if not local_root:
            raise RuntimeError(f"HF audiocaps failed and no local_root provided: {e}") from e
        # Expect local_root with wav files + jsonl {"wav_path":"...", "caption":"..."}
        import json, glob
        # Discard anything the HF loop collected before failing so the two
        # sources are never mixed in one sample.
        out = []
        jfiles = sorted(glob.glob(os.path.join(local_root, "*.jsonl")))
        assert jfiles, "Provide a jsonl with wav_path, caption"
        with open(jfiles[0], "r", encoding="utf-8") as f:
            for line in f:
                if len(out) >= n:
                    break
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                obj = json.loads(line)
                wav, wav_sr = torchaudio.load(os.path.join(local_root, obj["wav_path"]))
                # Average over the leading (channel) axis: identical to squeeze(0)
                # for mono files, and yields a 1-D signal for stereo ones too.
                wav = wav.mean(dim=0)
                wav = torchaudio.functional.resample(wav, wav_sr, sr)
                out.append((wav.numpy(), sr, obj["caption"]))
    return out[:n]

# Load the AudioCaps sample used by the Whisper/MiniLM cell below.
audio_pairs = load_audiocaps_pairs(n=cfg.N_AUDIOCAPS, sr=cfg.AUDIO_SR, local_root=LOCAL_AUDIOCAPS)

# Quick look: how many triples came back and the start of the first caption.
(len(audio_pairs), audio_pairs[0][2][:100])


In [None]:
# CLIP checkpoint shared by the image and text towers; eval() puts the model in inference mode.
clip_name = "openai/clip-vit-base-patch16"
clip_proc  = CLIPProcessor.from_pretrained(clip_name)
clip_model = CLIPModel.from_pretrained(clip_name)
clip_model = clip_model.to(device).eval()

# --- build arrays ---
# Position i in `images` and `texts` comes from the same COCO pair.
images = [pair[0] for pair in coco_pairs]
texts  = [pair[1] for pair in coco_pairs]
N = len(images)

# --- embed images ---
img_chunks = []
with torch.no_grad():
    for start in tqdm(range(0, N, cfg.BATCH_CLIP), desc="CLIP image embed"):
        pixel_inputs = clip_proc(images=images[start:start+cfg.BATCH_CLIP], return_tensors="pt").to(device)
        img_feats = clip_model.get_image_features(**pixel_inputs)
        # L2-normalise so a dot product between embeddings is a cosine similarity
        img_chunks.append(F.normalize(img_feats, dim=-1).cpu())
img_emb = torch.cat(img_chunks, dim=0).numpy()

# --- embed texts ---
txt_chunks = []
with torch.no_grad():
    for start in tqdm(range(0, N, cfg.BATCH_CLIP), desc="CLIP text embed"):
        token_inputs = clip_proc(
            text=texts[start:start+cfg.BATCH_CLIP],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        txt_feats = clip_model.get_text_features(**token_inputs)
        txt_chunks.append(F.normalize(txt_feats, dim=-1).cpu())
txt_emb = torch.cat(txt_chunks, dim=0).numpy()

# --- retrieval image index w/ FAISS ---
d = img_emb.shape[1]
index = faiss.IndexFlatIP(d)
index.add(img_emb)  # already normalized = use inner product == cosine
sims_t2i = []
bs = 256
# Only the 100 best neighbours per query are kept, so ask FAISS for exactly
# that many. The previous k=N search followed by D[:, :100] stored *views*,
# which kept every full N-wide result array alive and saved no memory.
top_k = min(100, N)
for i in tqdm(range(0, N, bs), desc="t->i sims"):
    q = txt_emb[i:i+bs]
    D, I = index.search(q, k=top_k)
    sims_t2i.append((D, I))
# Dense score matrices for the metrics below (all pairs, both directions).
score_t2i = cosine_sim(txt_emb, img_emb)
score_i2t = cosine_sim(img_emb, txt_emb)

# Pairs are aligned by index, so query i's relevant item is item i.
gt = [*range(N)]

# Recall@k plus mAP, once per retrieval direction.
m_t2i = {**recall_at_k(score_t2i, gt), "mAP": mean_average_precision(score_t2i, gt)}
m_i2t = {**recall_at_k(score_i2t, gt), "mAP": mean_average_precision(score_i2t, gt)}

rounded_t2i = {name: round(value,4) for name, value in m_t2i.items()}
rounded_i2t = {name: round(value,4) for name, value in m_i2t.items()}
print("E1.1 — CLIP t->i:", rounded_t2i)
print("E1.1 — CLIP i->t:", rounded_i2t)

# Cosine hist: positives vs negatives (sampled)
# Positives are the diagonal (caption i vs its own image i).
pos_sims = np.diagonal(score_t2i).copy()
if N > 1:
    # One random negative per row. A non-zero cyclic shift can never land on
    # the diagonal, so a matching pair is never counted as a "negative"
    # (plain randint(0, N) could pick column i for row i).
    neg_idx = (np.arange(N) + np.random.randint(1, N, size=(N,))) % N
else:
    # With fewer than two pairs there is no negative to sample.
    neg_idx = np.empty(0, dtype=np.int64)
neg_sims = score_t2i[np.arange(len(neg_idx)), neg_idx]
plot_cos_hist(pos_sims, neg_sims, "CLIP t->i cosine distribution")


In [None]:
# --- Whisper encoder (no decoder) ---
# Only `whisper.encoder` is called further down; the processor turns raw waveforms into model inputs.
whisper_name = "openai/whisper-base"
whisper_proc = WhisperProcessor.from_pretrained(whisper_name)
whisper = WhisperModel.from_pretrained(whisper_name)
whisper = whisper.to(device).eval()

# trim/pad to ~10s for consistency
# NOTE(review): target_len_s is not referenced anywhere in the visible cells - confirm whether trimming was meant to be applied.
target_len_s = 10.0

def whisper_embed_batch(wavs: List[np.ndarray], sr: int):
    """Embed a batch of waveforms with the Whisper encoder.

    Runs the module-level `whisper_proc` on `wavs`, feeds the resulting
    `input_features` through `whisper.encoder`, mean-pools the last hidden
    state over its time axis and L2-normalises each vector.

    Args:
        wavs: list of waveforms.
        sr: sampling rate passed to the processor.

    Returns:
        NumPy array with one unit-length embedding per input waveform.

    NOTE(review): the mean runs over every encoder frame, so frames that come
    from padding added by the processor (if any) are averaged in too -
    confirm that is intended.
    """
    # pad/trim and log-mel inside processor
    feats = whisper_proc(wavs, sampling_rate=sr, return_tensors="pt").input_features
    feats = feats.to(device)  # [B, 80, T]
    with torch.no_grad():
        hidden = whisper.encoder(feats).last_hidden_state  # [B, T', C]
        pooled = F.normalize(hidden.mean(dim=1), dim=-1)   # simple mean-pool, then unit norm
    return pooled.cpu().numpy()

# --- MiniLM text encoder ---
minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

# Position i in `audios` and `a_texts` comes from the same AudioCaps triple.
audios = [triple[0] for triple in audio_pairs]
a_texts = [triple[2] for triple in audio_pairs]
M = len(audios)

# --- audio embeds ---
a_chunks = [
    whisper_embed_batch(audios[lo:lo+cfg.BATCH_AUDIO], cfg.AUDIO_SR)
    for lo in tqdm(range(0, M, cfg.BATCH_AUDIO), desc="Whisper audio embed")
]
a_emb = np.concatenate(a_chunks, axis=0)

# --- text embeds ---
t_chunks = [
    minilm.encode(
        a_texts[lo:lo+cfg.BATCH_TEXT],
        batch_size=cfg.BATCH_TEXT,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    for lo in tqdm(range(0, M, cfg.BATCH_TEXT), desc="MiniLM text embed")
]
t_emb = np.concatenate(t_chunks, axis=0)

# --- retrieval ---
# cosine_sim multiplies the two matrices, so both modalities must have the same
# width. The audio and text embeddings come from two unrelated encoders, and when
# their widths differ the original call raised a matmul shape error.
# NOTE(review): whisper-base appears to emit 512-d states and all-MiniLM-L6-v2
# 384-d vectors - confirm; if so this branch is always taken.
a_ret, t_ret = a_emb, t_emb
if a_emb.shape[1] != t_emb.shape[1]:
    # PCA needs n_components <= min(n_samples, n_features) for each modality.
    d_common = min(a_emb.shape[1], t_emb.shape[1], M)
    print(f"[E1.2] embedding widths differ (audio={a_emb.shape[1]}, text={t_emb.shape[1]}); "
          f"projecting both to {d_common} dims with PCA before scoring")
    # NOTE(review): two independent PCAs only reconcile the shapes; they do not
    # align the spaces, so treat the numbers as an unaligned baseline.
    a_ret = PCA(n_components=d_common, random_state=42).fit_transform(a_emb)
    t_ret = PCA(n_components=d_common, random_state=42).fit_transform(t_emb)

score_a2t = cosine_sim(a_ret, t_ret)
score_t2a = cosine_sim(t_ret, a_ret)
gt_audio = list(range(M))  # aligned pairs by index

m_a2t = recall_at_k(score_a2t, gt_audio); m_a2t["mAP"] = mean_average_precision(score_a2t, gt_audio)
m_t2a = recall_at_k(score_t2a, gt_audio); m_t2a["mAP"] = mean_average_precision(score_t2a, gt_audio)
print("E1.2 — audio->text:", {k: round(v,4) for k,v in m_a2t.items()})
print("E1.2 — text->audio:", {k: round(v,4) for k,v in m_t2a.items()})

# Cosine hist example (a->t)
# Positives are the diagonal (clip i vs its own caption i).
pos_sims = np.diagonal(score_a2t).copy()
if M > 1:
    # One random negative per row. A non-zero cyclic shift can never land on
    # the diagonal, so a matching pair is never counted as a "negative"
    # (plain randint(0, M) could pick column i for row i).
    neg_idx = (np.arange(M) + np.random.randint(1, M, size=(M,))) % M
else:
    # With fewer than two pairs there is no negative to sample.
    neg_idx = np.empty(0, dtype=np.int64)
neg_sims = score_a2t[np.arange(len(neg_idx)), neg_idx]
plot_cos_hist(pos_sims, neg_sims, "Whisper+MiniLM a->t cosine distribution")


In [None]:
K = min(700, N, M)  # per-modality count for the plot
# Use:
# - CLIP image embeddings (img_emb)
# - CLIP text embeddings for COCO (txt_emb)
# - Whisper audio embeddings (a_emb)
# All three blocks must share the same width for the row-wise concatenation.
X = np.concatenate([img_emb[:K], txt_emb[:K], a_emb[:K]], axis=0)
labels = (["image"]*K) + (["text"]*K) + (["audio"]*K)

# PCA -> UMAP (faster & cleaner)
# PCA cannot return more components than min(n_samples, n_features); clamp so a
# small sample (e.g. a quick run with tiny N_COCO / N_AUDIOCAPS) does not fail.
n_pca = min(50, X.shape[0], X.shape[1])
pca = PCA(n_components=n_pca, random_state=42).fit_transform(X)
um = umap.UMAP(n_neighbors=30, min_dist=0.1, random_state=42).fit_transform(pca)

# Plot
label_arr = np.asarray(labels)
plt.figure(figsize=(6,5))
for mod in ("image", "text", "audio"):
    idx = np.flatnonzero(label_arr == mod)  # rows belonging to this modality
    plt.scatter(um[idx,0], um[idx,1], s=8, label=mod, alpha=0.7)
plt.title("E1.4 — 2D projection (UMAP on PCA)")
plt.legend()
plt.show()
