[31mERROR: Could not find a version that satisfies the requirement open_clip (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for open_clip[0m[31m
[0m

In [9]:
# ---- Quick config ----
USE_REAL_MODELS = True     # flip to True to use CLIP / Whisper / MiniLM encoders
DEVICE = "cuda:0"
SEED = 42

# Data sizes for quick runs (adjust later)
N_IT_PAIRS = 3000           # image-text pairs
N_AT_PAIRS = 3000           # audio-text pairs
VAL_FRAC = 0.1

# Embedding sizes
EMB_VIS  = 512
EMB_AUD  = 512
EMB_TXT  = 384

# Adapter output dimension (base/budget sweep will vary)
ADAPT_OUT_DIMS = [128, 256, 512]

# InfoNCE / Training
BATCH_SIZE = 128
EPOCHS = 2
LR = 3e-4
TAU = 0.07   # temperature

# Repro
import random, numpy as np, torch
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

print("Config loaded.")


Config loaded.


In [10]:
import math, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def cosine_sim(a, b):
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    return a @ b.T  # [N, M]

def recall_at_k(sim, k=1):
    # sim: [N, M], rows are queries, cols are candidates; assume gold index matches (i->i)
    ranks = sim.argsort(dim=1, descending=True)
    correct = torch.arange(sim.size(0), device=sim.device).unsqueeze(1)
    hit = (ranks[:, :k] == correct).any(dim=1).float()
    return hit.mean().item()

def mAP(sim):
    # Simple mAP assuming 1 positive per row at same index
    ranks = sim.argsort(dim=1, descending=True)
    idx = torch.arange(sim.size(0), device=sim.device)
    pos_rank = (ranks == idx.unsqueeze(1)).nonzero()[:,1] + 1  # 1-based
    return (1.0 / pos_rank.float()).mean().item()

def info_nce_loss(z_q, z_k, tau=0.07):
    # z_q, z_k: [B, D]
    z_q = F.normalize(z_q, dim=-1)
    z_k = F.normalize(z_k, dim=-1)
    logits = (z_q @ z_k.T) / tau
    labels = torch.arange(z_q.size(0), device=z_q.device)
    loss_qk = F.cross_entropy(logits, labels)
    loss_kq = F.cross_entropy(logits.T, labels)
    return (loss_qk + loss_kq) / 2

def tsne_plot(emb_list, colors, title="t-SNE"):
    x = torch.cat(emb_list, dim=0).detach().cpu().numpy()
    tsne = TSNE(n_components=2, init='pca', perplexity=30, learning_rate='auto')
    xy = tsne.fit_transform(x)
    n = [e.shape[0] for e in emb_list]
    idxs = np.cumsum([0]+n)
    plt.figure(figsize=(6,5))
    for i,(c,label) in enumerate(colors):
        plt.scatter(xy[idxs[i]:idxs[i+1],0], xy[idxs[i]:idxs[i+1],1], s=6, alpha=0.7, label=label)
    plt.title(title); plt.legend(); plt.show()

def plot_pos_neg_hist(pos_sims, neg_sims, title="Cosine hist"):
    plt.figure(figsize=(6,4))
    plt.hist(pos_sims, bins=50, alpha=0.7, label="positives")
    plt.hist(neg_sims, bins=50, alpha=0.7, label="negatives")
    plt.legend(); plt.title(title); plt.xlabel("cosine"); plt.ylabel("count")
    plt.show()

def split_train_val(N, val_frac=0.1):
    N_val = int(N*val_frac)
    idx = torch.randperm(N)
    return idx[N_val:], idx[:N_val]


In [11]:
# Synthetic encoders: fast, deterministic, good for pipeline & plots today
class FakeEncoder(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        torch.manual_seed(SEED)
        self.proj = nn.Linear(in_dim, out_dim, bias=False)
        # initialize to be near-orthogonal-ish
        nn.init.orthogonal_(self.proj.weight)
    def forward(self, x):
        # x is ignored; just return learned anchors (we’ll feed IDs as one-hot)
        return F.normalize(self.proj(x), dim=-1)

def make_one_hot_ids(n, dim):
    eye = torch.eye(dim)[:n] if dim >= n else F.pad(torch.eye(dim), (0,0,0,n-dim))
    return eye[:n].to(DEVICE)

# Real encoders (optional): CLIP, MiniLM, Whisper
def build_real_encoders():
    """
    Returns dict with encode_image, encode_text, encode_audio callables.
    Expect user to have:
      - open_clip_torch
      - sentence_transformers
      - transformers (for Whisper encoder-only)
    """
    import open_clip, torch
    from PIL import Image
    from sentence_transformers import SentenceTransformer
    from transformers import WhisperModel, WhisperFeatureExtractor
    import torchaudio

    clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
        "ViT-B-16", pretrained="laion2b_s34b_b88k", device=DEVICE
    )
    clip_tokenizer = open_clip.get_tokenizer("ViT-B-16")
    txt_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
    whisper = WhisperModel.from_pretrained("openai/whisper-small").to(DEVICE).eval()
    feat_ext = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

    @torch.no_grad()
    def encode_image(pil_list):
        imgs = torch.stack([clip_preprocess(im).to(DEVICE) for im in pil_list])
        return F.normalize(clip_model.encode_image(imgs), dim=-1)

    @torch.no_grad()
    def encode_text(str_list):
        return F.normalize(torch.tensor(txt_model.encode(str_list, convert_to_numpy=True, normalize_embeddings=True)).to(DEVICE), dim=-1)

    @torch.no_grad()
    def encode_audio(wave_list, sr_list):
        # Whisper: use encoder states mean pooled as an embedding proxy
        embs = []
        for wav, sr in zip(wave_list, sr_list):
            if sr != 16000:
                wav = torchaudio.functional.resample(torch.tensor(wav), sr, 16000).numpy()
            inputs = feat_ext(wav, sampling_rate=16000, return_tensors="pt")
            enc = whisper.encoder(inputs.input_features.to(DEVICE)).last_hidden_state.mean(dim=1)
            embs.append(F.normalize(enc, dim=-1))
        return torch.cat(embs, dim=0)

    return dict(encode_image=encode_image, encode_text=encode_text, encode_audio=encode_audio)

# Build encoder handles
if USE_REAL_MODELS:
    enc = build_real_encoders()
    Dv = 512; Dt = 384; Da = 768
else:
    Dv, Dt, Da = EMB_VIS, EMB_TXT, EMB_AUD
    # one-hot ID spaces to simulate “semantics” via shared caption IDs
    # We’ll generate data below.
    enc = dict(encode_image=None, encode_text=None, encode_audio=None)

print("Encoders ready. Real:", USE_REAL_MODELS)


ModuleNotFoundError: No module named 'open_clip'

In [8]:
from typing import Dict

def make_synthetic_data(n_it=3000, n_at=3000, vocab=4000):
    """
    Generate:
      - shared caption IDs for i<->t and a<->t
      - embeddings from one-hot IDs passed through FakeEncoder as "frozen encoders"
    """
    # Shared caption IDs simulate cross-modality semantics
    cap_ids_it = torch.randint(0, vocab, (n_it,))
    cap_ids_at = torch.randint(0, vocab, (n_at,))
    # “Inputs” for fake encoders are one-hots in a common index space
    oh_dim = max(vocab, max(n_it, n_at))
    img_ids = make_one_hot_ids(n_it, oh_dim)
    txt_ids_it = torch.zeros_like(img_ids); txt_ids_it.scatter_(1, cap_ids_it.view(-1,1).cpu(), 1.0); txt_ids_it = txt_ids_it.to(DEVICE)
    aud_ids = make_one_hot_ids(n_at, oh_dim)
    txt_ids_at = torch.zeros_like(aud_ids); txt_ids_at.scatter_(1, cap_ids_at.view(-1,1).cpu(), 1.0); txt_ids_at = txt_ids_at.to(DEVICE)
    return dict(
        img_ids=img_ids, txt_ids_it=txt_ids_it, cap_ids_it=cap_ids_it.to(DEVICE),
        aud_ids=aud_ids, txt_ids_at=txt_ids_at, cap_ids_at=cap_ids_at.to(DEVICE),
        oh_dim=oh_dim
    )

def encode_synthetic(fake: Dict, Dv, Dt, Da):
    enc_v = FakeEncoder(fake['oh_dim'], Dv).to(DEVICE).eval()
    enc_t = FakeEncoder(fake['oh_dim'], Dt).to(DEVICE).eval()
    enc_a = FakeEncoder(fake['oh_dim'], Da).to(DEVICE).eval()
    with torch.no_grad():
        v = enc_v(fake['img_ids'])
        t_it = enc_t(fake['txt_ids_it'])
        a = enc_a(fake['aud_ids'])
        t_at = enc_t(fake['txt_ids_at'])
    return v, t_it, a, t_at

# Optional: real loaders (fill in later if needed)
def load_real_it_pairs(limit):
    """
    Return list_of_PIL_images, list_of_captions (same length)
    Implement with HF `datasets` coco_captions or your local COCO subset.
    """
    raise NotImplementedError("Plug your real (image, text) loader here.")

def load_real_at_pairs(limit):
    """
    Return list_of_waveforms, list_of_sample_rates, list_of_captions
    Implement with AudioCaps/Clotho loaders in your environment.
    """
    raise NotImplementedError("Plug your real (audio, text) loader here.")

# Build working arrays
if USE_REAL_MODELS:
    # Example scaffold (uncomment when you fill loaders)
    # pil_imgs, caps_it = load_real_it_pairs(N_IT_PAIRS)
    # wavs, srs, caps_at = load_real_at_pairs(N_AT_PAIRS)
    # v = enc['encode_image'](pil_imgs)
    # t_it = enc['encode_text'](caps_it)
    # a = enc['encode_audio'](wavs, srs)
    # t_at = enc['encode_text'](caps_at)
    raise RuntimeError("Real-mode loaders not yet implemented in this snippet. Use synthetic for now or plug your loaders.")
else:
    synth = make_synthetic_data(N_IT_PAIRS, N_AT_PAIRS, vocab=4000)
    V, T_it, A, T_at = encode_synthetic(synth, Dv, Dt, Da)

# Train/val splits by index alignment
tr_it, va_it = split_train_val(V.shape[0], VAL_FRAC)
tr_at, va_at = split_train_val(A.shape[0], VAL_FRAC)

print("Data ready. Synthetic:", not USE_REAL_MODELS)


RuntimeError: Expected all tensors to be on the same device, but got index is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_scatter__value)

In [None]:
class Adapter(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        )
    def forward(self, x): return self.net(x)

def train_pairwise(V, T_it, A, T_at, out_dim=256, epochs=2, tau=0.07):
    va = Adapter(V.shape[1], out_dim).to(DEVICE)
    ta = Adapter(T_it.shape[1], out_dim).to(DEVICE)
    aa = Adapter(A.shape[1], out_dim).to(DEVICE)
    # We share text adapter for both it/at; easy to change if you want per-pair text adapters
    opt = torch.optim.AdamW(list(va.parameters()) + list(ta.parameters()) + list(aa.parameters()), lr=LR)

    # Prepare quick index sets
    it_idx = tr_it
    at_idx = tr_at

    for ep in range(1, epochs+1):
        va.train(); ta.train(); aa.train()
        # minibatches over the *larger* split
        max_len = max(len(it_idx), len(at_idx))
        perm_it = it_idx[torch.randperm(len(it_idx))]
        perm_at = at_idx[torch.randperm(len(at_idx))]
        for s in range(0, max_len, BATCH_SIZE):
            # gather it batch
            it_slice = perm_it[s:s+BATCH_SIZE]
            at_slice = perm_at[s:s+BATCH_SIZE]
            # handle length mismatch by wrapping indices
            if len(it_slice)==0 and len(at_slice)==0: break
            loss = 0.0

            if len(it_slice)>0:
                vq = va(V[it_slice])
                tk = ta(T_it[it_slice])
                loss += info_nce_loss(vq, tk, tau)

            if len(at_slice)>0:
                aq = aa(A[at_slice])
                tk2 = ta(T_at[at_slice])  # same text adapter
                loss += info_nce_loss(aq, tk2, tau)

            opt.zero_grad()
            loss.backward()
            opt.step()

        # quick val metrics
        va.eval(); ta.eval(); aa.eval()
        with torch.no_grad():
            vq = F.normalize(va(V[va_it]), dim=-1); tk = F.normalize(ta(T_it[va_it]), dim=-1)
            aq = F.normalize(aa(A[va_at]), dim=-1); tk2 = F.normalize(ta(T_at[va_at]), dim=-1)

            sim_it = vq @ tk.T
            sim_at = aq @ tk2.T

            r1_it = recall_at_k(sim_it, 1); r5_it = recall_at_k(sim_it, 5)
            r1_at = recall_at_k(sim_at, 1); r5_at = recall_at_k(sim_at, 5)

        print(f"[ep {ep}] it R@1={r1_it:.3f} R@5={r5_it:.3f} | at R@1={r1_at:.3f} R@5={r5_at:.3f}")

    return va, ta, aa

# Train quick baseline
adV, adT, adA = train_pairwise(V, T_it, A, T_at, out_dim=ADAPT_OUT_DIMS[-1], epochs=EPOCHS, tau=TAU)
print("Adapters trained (pairwise baseline).")


In [None]:
@torch.no_grad()
def eval_all(adV, adT, adA, V, T_it, A, T_at, split=('val','val','val','val')):
    adV.eval(); adT.eval(); adA.eval()
    # choose split indices
    I = va_it if split[0]=='val' else tr_it
    IT = va_it if split[1]=='val' else tr_it
    J = va_at if split[2]=='val' else tr_at
    AT = va_at if split[3]=='val' else tr_at

    v = F.normalize(adV(V[I]), dim=-1)
    t1 = F.normalize(adT(T_it[IT]), dim=-1)
    a = F.normalize(adA(A[J]), dim=-1)
    t2 = F.normalize(adT(T_at[AT]), dim=-1)

    # i<->t
    sim_it = v @ t1.T
    # a<->t
    sim_at = a @ t2.T
    # i<->a (emergent via trained adapters, no direct loss yet)
    # align by their own indices sizes; use min to form matched pairs
    N = min(v.shape[0], a.shape[0])
    sim_ia = v[:N] @ a[:N].T

    metrics = dict(
        it_R1 = recall_at_k(sim_it, 1),
        it_R5 = recall_at_k(sim_it, 5),
        at_R1 = recall_at_k(sim_at, 1),
        at_R5 = recall_at_k(sim_at, 5),
        ia_R1 = recall_at_k(sim_ia, 1),
        ia_R5 = recall_at_k(sim_ia, 5),
        it_mAP = mAP(sim_it),
        at_mAP = mAP(sim_at),
        ia_mAP = mAP(sim_ia)
    )
    return metrics, sim_it, sim_at, sim_ia

metrics, sim_it, sim_at, sim_ia = eval_all(adV, adT, adA, V, T_it, A, T_at, split=('val','val','val','val'))
print(metrics)

# Cosine histograms (positives are diagonal; negatives are off-diagonal)
def diag_offdiag(sim):
    pos = sim.diag().detach().cpu().numpy()
    neg = sim.detach().cpu().numpy()[~np.eye(sim.shape[0], dtype=bool)]
    return pos, neg

p_it, n_it = diag_offdiag(sim_it)
p_at, n_at = diag_offdiag(sim_at)
p_ia, n_ia = diag_offdiag(sim_ia)

plot_pos_neg_hist(p_it, n_it, title="i↔t cosine")
plot_pos_neg_hist(p_at, n_at, title="a↔t cosine")
plot_pos_neg_hist(p_ia, n_ia, title="i↔a cosine (emergent)")

# t-SNE (mixed modalities)
# Use a small sample for speed
K = 600
v_s = F.normalize(adV(V[va_it][:K]), dim=-1)
t_s = F.normalize(adT(T_it[va_it][:K]), dim=-1)
a_s = F.normalize(adA(A[va_at][:K]), dim=-1)
tsne_plot([v_s, t_s, a_s], colors=[('C0','image'),('C1','text'),('C2','audio')], title="t-SNE of aligned embeddings")


In [None]:
def train_with_consistency(V, T_it, A, T_at, out_dim=256, epochs=2, tau=0.07, w_cons=0.2):
    va = Adapter(V.shape[1], out_dim).to(DEVICE)
    ta = Adapter(T_it.shape[1], out_dim).to(DEVICE)
    aa = Adapter(A.shape[1], out_dim).to(DEVICE)
    opt = torch.optim.AdamW(list(va.parameters()) + list(ta.parameters()) + list(aa.parameters()), lr=LR)

    it_idx = tr_it
    at_idx = tr_at

    # create simple caption-matched synthetic i-a positives via nearest in text space
    # (for real data: group by same caption string or high-sim threshold)
    # Here we use *indexes* as proxy (controlled synthetic setting).
    for ep in range(1, epochs+1):
        va.train(); ta.train(); aa.train()
        max_len = max(len(it_idx), len(at_idx))
        perm_it = it_idx[torch.randperm(len(it_idx))]
        perm_at = at_idx[torch.randperm(len(at_idx))]
        for s in range(0, max_len, BATCH_SIZE):
            it_slice = perm_it[s:s+BATCH_SIZE]
            at_slice = perm_at[s:s+BATCH_SIZE]
            if len(it_slice)==0 and len(at_slice)==0: break
            loss = 0.0

            if len(it_slice)>0:
                vq = va(V[it_slice]); tk = ta(T_it[it_slice])
                loss += info_nce_loss(vq, tk, tau)

            if len(at_slice)>0:
                aq = aa(A[at_slice]); tk2 = ta(T_at[at_slice])
                loss += info_nce_loss(aq, tk2, tau)

            # consistency: push image/audio of *same batch index* closer
            if len(it_slice)>0 and len(at_slice)>0:
                K = min(len(it_slice), len(at_slice))
                vq_c = F.normalize(va(V[it_slice[:K]]), dim=-1)
                aq_c = F.normalize(aa(A[at_slice[:K]]), dim=-1)
                # InfoNCE between v and a as if paired
                loss += w_cons * info_nce_loss(vq_c, aq_c, tau)

            opt.zero_grad(); loss.backward(); opt.step()

        # quick log
        va.eval(); ta.eval(); aa.eval()
        with torch.no_grad():
            vq = F.normalize(va(V[va_it]), dim=-1); tk = F.normalize(ta(T_it[va_it]), dim=-1)
            aq = F.normalize(aa(A[va_at]), dim=-1); tk2 = F.normalize(ta(T_at[va_at]), dim=-1)
            sim_it = vq @ tk.T; sim_at = aq @ tk2.T
        print(f"[ep {ep}] it R@1={recall_at_k(sim_it,1):.3f} | at R@1={recall_at_k(sim_at,1):.3f}")

    return va, ta, aa

# Train a quick joint model with consistency
adV_joint, adT_joint, adA_joint = train_with_consistency(V, T_it, A, T_at, out_dim=ADAPT_OUT_DIMS[-1], epochs=EPOCHS, tau=TAU, w_cons=0.2)

metrics_joint, _, _, sim_ia_joint = eval_all(adV_joint, adT_joint, adA_joint, V, T_it, A, T_at)
print("Joint metrics:", metrics_joint)

p_ia_j, n_ia_j = diag_offdiag(sim_ia_joint)
plot_pos_neg_hist(p_ia_j, n_ia_j, title="i↔a cosine (with consistency)")


In [None]:
@torch.no_grad()
def budget_curve(adV, adT, adA, dims=[128,256,512]):
    print("Budget curve (dim vs R@1 for i↔t / a↔t / i↔a):")
    baseV, baseT, baseA = adV, adT, adA
    for d in dims:
        redV = Adapter(baseV.net[0].in_features, d).to(DEVICE); redV.load_state_dict(baseV.state_dict(), strict=False)
        redT = Adapter(baseT.net[0].in_features, d).to(DEVICE); redT.load_state_dict(baseT.state_dict(), strict=False)
        redA = Adapter(baseA.net[0].in_features, d).to(DEVICE); redA.load_state_dict(baseA.state_dict(), strict=False)
        m, sim_it, sim_at, sim_ia = eval_all(redV, redT, redA, V, T_it, A, T_at)
        print(f"  d={d:4d} | i↔t R@1={m['it_R1']:.3f}  a↔t R@1={m['at_R1']:.3f}  i↔a R@1={m['ia_R1']:.3f}")

print("Pairwise baseline curve:")
budget_curve(adV, adT, adA, ADAPT_OUT_DIMS)

print("Joint (consistency) curve:")
budget_curve(adV_joint, adT_joint, adA_joint, ADAPT_OUT_DIMS)


In [None]:
def negative_control_shuffle_text(T_tensor, idx):
    # Shuffle the alignment between embeddings & indices
    shuf = idx[torch.randperm(len(idx))]
    return T_tensor[shuf]

@torch.no_grad()
def eval_with_text_shuffle(adV, adT, adA):
    v = F.normalize(adV(V[va_it]), dim=-1)
    t = F.normalize(adT(negative_control_shuffle_text(T_it, va_it)), dim=-1)
    a = F.normalize(adA(A[va_at]), dim=-1)
    t2 = F.normalize(adT(negative_control_shuffle_text(T_at, va_at)), dim=-1)
    sim_it = v @ t.T; sim_at = a @ t2.T
    print(f"NegCtrl — i↔t R@1={recall_at_k(sim_it,1):.3f}  a↔t R@1={recall_at_k(sim_at,1):.3f}")

print("Negative control (pairwise):")
eval_with_text_shuffle(adV, adT, adA)

print("Negative control (joint):")
eval_with_text_shuffle(adV_joint, adT_joint, adA_joint)


In [None]:
@torch.no_grad()
def oracle_text_bridge_R1(adV, adT, adA):
    # Rank i->t and t->a, then compose ranks for an oracle upper bound
    v = F.normalize(adV(V[va_it]), dim=-1)
    t1 = F.normalize(adT(T_it[va_it]), dim=-1)
    a = F.normalize(adA(A[va_at]), dim=-1)
    t2 = F.normalize(adT(T_at[va_at]), dim=-1)

    sim_i_t = v @ t1.T
    sim_t_a = t2 @ a.T

    # For each image i, pick best text, then that text's best audio
    top_t = sim_i_t.argmax(dim=1)          # [Ni]
    sim_i_a_via_t = sim_t_a[top_t]         # [Ni, Na]
    r1 = recall_at_k(sim_i_a_via_t, 1)
    return r1

print("Oracle text bridge upper bound R@1 (pairwise adapters):", oracle_text_bridge_R1(adV, adT, adA))
print("Oracle text bridge upper bound R@1 (joint adapters):   ", oracle_text_bridge_R1(adV_joint, adT_joint, adA_joint))
