# 02 ‚Äî Vision-Text Alignment (Vision Only) + Alignment Plots (W&B)


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Basic imports & environment setup

import os
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, List, Tuple

import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

from datasets import load_dataset

# Local modules
from imports.in_memory_datasets import *
from imports.multimodal_alignment_perceiver import (
    MultimodalAlignmentConfig,
    MultimodalAlignmentModel,
    contrastive_loss,
    matryoshka_loss,
)
from imports.core import set_seed

# Transformers encoders
from transformers import (
    CLIPVisionModel,
    CLIPImageProcessor,
    AutoTokenizer,
    AutoModel,
    WhisperModel,
    WhisperProcessor,
)



In [3]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Num GPUs:", torch.cuda.device_count())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Torch version: 2.9.0+cu128
CUDA available: True
Num GPUs: 1
Using device: cuda


In [None]:
# Experiment configuration

@dataclass
class ExperimentConfig:
    # General
    seed: int = 42
    num_epochs_mlp: int = 3          # was 3
    num_epochs_perceiver: int = 5   # was 5
    log_every: int = 25

    # Data
    pixmo_parquet_dir: Path = Path.cwd() / "data" / "final_dataset" / "pixmo"
    clotho_parquet_dir: Path = Path.cwd() / "data" / "final_dataset" / "clotho"
    pixmo_parquet_glob: str = "pixmo_train.parquet"
    clotho_parquet_glob: str = "clotho_train.parquet"

    image_hf_dataset: str = "allenai/pixmo-cap"
    audio_hf_dataset: str = "clotho_v2"

    max_image_samples: Optional[int] = None   # keep None for now; if debugging, use e.g. 20_000
    max_audio_samples: Optional[int] = None
    image_size: Tuple[int, int] = (224, 224)
    sample_val_fraction: float = 0.10   # was 0.05 ‚Üí more stable validation

    # Training
    base_batch_size: int = 64           # was 32; scale by num_gpus anyway
    learning_rate: float = 2e-4         # was 1e-4
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    use_amp: bool = True

    # WandB
    use_wandb: bool = True
    wandb_project: str = "edgeglass_final_align"
    wandb_entity: Optional[str] = None
    wandb_mode: str = "online"
    wandb_run_name: Optional[str] = "final_align_perciever_1"

    # Checkpoints (Phase 1)
    root_dir: Path = Path(".").resolve()
    ckpt_root: str = "checkpoints/phase1_multimodal_vision"

    # Alignment / model
    mrl_weight: float = 1.0
    clip_weight: float = 0.25          # was 0.5

cfg = ExperimentConfig()
set_seed(cfg.seed)

ROOT_DIR = cfg.root_dir
CKPT_ROOT = ROOT_DIR / cfg.ckpt_root
MLP_DIR = CKPT_ROOT / "mlp_mrl"
PERCEIVER_DIR = CKPT_ROOT / "perceiver_mrl"

for d in [CKPT_ROOT, MLP_DIR, PERCEIVER_DIR, cfg.pixmo_parquet_dir]:
    d.mkdir(parents=True, exist_ok=True)

print("ROOT_DIR:", ROOT_DIR)
print("Checkpoint root:", CKPT_ROOT)
print("MLP run dir:", MLP_DIR)
print("Perceiver run dir:", PERCEIVER_DIR)

ROOT_DIR: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base
Checkpoint root: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal_vision
MLP run dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal_vision/mlp_mrl
Perceiver run dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal_vision/perceiver_mrl


In [None]:
# Weights & Biases helpers

def init_wandb(run_name: str, variant: str, extra_config: Dict[str, Any] = None):
    if not cfg.use_wandb:
        return None

    import wandb

    base_config = asdict(cfg)
    base_config["variant"] = variant
    if extra_config:
        base_config.update(extra_config)

    run = wandb.init(
        project=cfg.wandb_project,
        entity=cfg.wandb_entity,
        name=run_name,
        mode=cfg.wandb_mode,
        config=base_config,
    )
    return run


def log_alignment_histograms(
    run,
    z_a: torch.Tensor,
    z_b: torch.Tensor,
    prefix: str,
    max_points: int = 512,
):
    """Log positive vs negative cosine similarity histograms to W&B.

    This is a lightweight way to 'see' alignment progress:
    - Positive sims: cosine between matching pairs
    - Negative sims: cosine between random mismatched pairs
    """
    if run is None:
        return

    import wandb

    with torch.no_grad():
        z_a = F.normalize(z_a, dim=-1)
        z_b = F.normalize(z_b, dim=-1)

        # Sample subset to keep logging light
        n = min(z_a.size(0), max_points)
        idx = torch.randperm(z_a.size(0))[:n]
        za = z_a[idx]
        zb = z_b[idx]

        # Positive similarities
        pos_sims = (za * zb).sum(dim=-1).cpu().numpy()

        # Negative similarities (shuffle)
        shuffle_idx = torch.randperm(n)
        neg_sims = (za * zb[shuffle_idx]).sum(dim=-1).cpu().numpy()

        run.log({
            f"{prefix}/pos_sim": wandb.Histogram(pos_sims),
            f"{prefix}/neg_sim": wandb.Histogram(neg_sims),
        })

In [None]:
# Data: load PixMo-Cap and MusicCaps from local Parquet (or HF fallback)

import glob

def get_pixmo_dataset():
    """Load PixMo-Cap 'train' split from Parquet if available, otherwise from HF."""
    parquet_pattern = cfg.pixmo_parquet_dir / cfg.pixmo_parquet_glob
    matches = sorted(glob.glob(str(parquet_pattern)))
    if matches:
        print(f"Loading PixMo-Cap from Parquet: {matches[-1]}")
        ds_dict = load_dataset("parquet", data_files={"train": matches[-1]})
        ds = ds_dict["train"]
    else:
        print(f"No PixMo Parquet found at pattern {parquet_pattern}, loading from HF: {cfg.image_hf_dataset}")
        ds = load_dataset(cfg.image_hf_dataset, split="train")
    return ds

In [7]:
def build_image_datasets() -> Tuple[DataLoader, DataLoader]:
    """Create train/val DataLoaders for image-text data using PixMo-Cap."""
    print("\n=== Building Image-Text Datasets (PixMo-Cap) ===")
    hf_ds = get_pixmo_dataset()

    if cfg.max_image_samples is not None:
        hf_ds = hf_ds.select(range(min(cfg.max_image_samples, len(hf_ds))))

    # Expect PixMo-Cap columns: image_url + caption
    colnames = hf_ds.column_names
    if "image_url" not in colnames:
        raise ValueError(f"Expected 'image_url' column in PixMo dataset, found: {colnames}")
    if "caption" not in colnames:
        raise ValueError(f"Expected 'caption' column in PixMo dataset, found: {colnames}")

    ds = InMemoryImageTextDataset(
        hf_dataset=hf_ds,
        img_col="image_url",
        txt_col="caption",
        max_samples=None,  # already limited above
        image_size=cfg.image_size,
    )

    # Train/val split via indices
    n = len(ds)
    idx = np.random.permutation(n)
    split = int(n * (1.0 - cfg.sample_val_fraction))
    train_idx, val_idx = idx[:split], idx[split:]

    train_ds = Subset(ds, train_idx)
    val_ds = Subset(ds, val_idx)

    # Scale batch size with number of GPUs
    num_gpus = max(1, torch.cuda.device_count())
    batch_size = cfg.base_batch_size * num_gpus

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=min(8, os.cpu_count() or 4),
        collate_fn=collate_in_memory_images,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(8, os.cpu_count() or 4),
        collate_fn=collate_in_memory_images,
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    print(f"Image train size: {len(train_ds)} | val size: {len(val_ds)} | batch size: {batch_size}")
    return train_loader, val_loader

image_train_loader, image_val_loader = build_image_datasets()


=== Building Image-Text Datasets (PixMo-Cap) ===
Loading PixMo-Cap from Parquet: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/final_dataset/pixmo/pixmo_train.parquet

üì• Pre-loading 14000 images into memory...
   Image size: (224, 224)
   Using 32 parallel workers


Loading images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14000/14000 [03:44<00:00, 62.41it/s] 


‚úÖ Loaded 14000 images into memory
   ‚ö†Ô∏è  99 images failed to load (using fallback)
Image train size: 12600 | val size: 1400 | batch size: 64


In [8]:
from typing import List, Dict, Any
import torch
import torchaudio

# cache resamplers per original sr
_resamplers: Dict[int, torchaudio.transforms.Resample] = {}



## Loading Model Architectures for Alignment

In [9]:
# Frozen encoders: CLIP (vision), Whisper (audio), MiniLM (text)

# Multimodal alignment config (from multimodal_alignment_perceiver.py)

mm_cfg = MultimodalAlignmentConfig()
mm_cfg.device = str(device)
mm_cfg.mrl_weight = cfg.mrl_weight
mm_cfg.clip_weight = cfg.clip_weight

print(mm_cfg)

MultimodalAlignmentConfig(vision_model_name='openai/clip-vit-base-patch32', audio_model_name='openai/whisper-base', text_model_name='sentence-transformers/all-MiniLM-L6-v2', llm_model_name='Qwen/Qwen2.5-1.5B-Instruct', d_vision=768, d_audio=512, d_text=384, perceiver_dim=512, num_latents=64, num_perceiver_layers=4, num_attn_heads=8, perceiver_mlp_ratio=4.0, perceiver_dropout=0.1, d_align=512, mrl_dims=(64, 128, 256, 512), llm_hidden_size=1536, num_prefix_tokens=64, batch_size=32, learning_rate=0.0001, weight_decay=0.01, num_epochs=10, warmup_ratio=0.1, max_grad_norm=1.0, temperature=0.07, mrl_weight=1.0, clip_weight=0.25, seed=42, device='cuda', dtype='float32')


In [10]:
# Vision encoder (CLIP)
vision_model_name = mm_cfg.vision_model_name
print("\nLoading CLIP vision encoder:", vision_model_name)
clip_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
clip_vision = CLIPVisionModel.from_pretrained(vision_model_name)
clip_vision.to(device)
clip_vision.eval()
for p in clip_vision.parameters():
    p.requires_grad = False


Loading CLIP vision encoder: openai/clip-vit-base-patch32


In [11]:
# Text encoder (MiniLM)
text_model_name = mm_cfg.text_model_name
print("Loading text encoder:", text_model_name)
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_encoder = AutoModel.from_pretrained(text_model_name)
text_encoder.to(device)
text_encoder.eval()
for p in text_encoder.parameters():
    p.requires_grad = False

# Audio encoder (Whisper)
audio_model_name = mm_cfg.audio_model_name
print("Loading audio encoder:", audio_model_name)
whisper_processor = WhisperProcessor.from_pretrained(audio_model_name)
whisper_encoder = WhisperModel.from_pretrained(audio_model_name).get_encoder()
whisper_encoder.to(device)
whisper_encoder.eval()
for p in whisper_encoder.parameters():
    p.requires_grad = False



Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
Loading audio encoder: openai/whisper-base


In [12]:

def encode_images_to_features(images: List) -> torch.Tensor:
    """Encode a batch of PIL images to CLIP patch features (B, T, 768)."""
    with torch.no_grad():
        inputs = clip_processor(images=images, return_tensors="pt").to(device)
        outputs = clip_vision(**inputs)
        feats = outputs.last_hidden_state  # (B, T, 768)
    return feats


def encode_texts_to_features(texts: List[str]) -> torch.Tensor:
    """Encode a batch of texts to token features (B, L, 384)."""
    with torch.no_grad():
        tokens = text_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)
        outputs = text_encoder(**tokens)
        feats = outputs.last_hidden_state  # (B, L, 384)
    return feats

import numpy as np

def encode_audio_to_features(audio_batch: torch.Tensor, sr: int) -> torch.Tensor:
    """
    Encode a batch of padded 16 kHz audio waveforms using Whisper.

    Args:
        audio_batch: Tensor of shape (B, T_max_16k)
        sr: sampling rate (should be 16000 after collate)

    Returns:
        feats: Tensor of shape (B, T_feat_max, 512)
    """
    assert isinstance(audio_batch, torch.Tensor), "audio_batch must be Tensor[B, T]"
    if sr != 16000:
        # In case something upstream goes wrong, fail loudly
        raise ValueError(f"Expected 16 kHz audio for Whisper, got sr={sr}")

    B, T_max = audio_batch.shape
    features = []

    with torch.no_grad():
        for i in range(B):
            # 1) Take single waveform [T]
            wav = audio_batch[i].detach().cpu().float().numpy()  # (T,)

            # 2) Run WhisperProcessor on this single example at 16k
            inputs = whisper_processor(
                wav,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(device)

            # 3) Encode with Whisper encoder
            out = whisper_encoder(inputs.input_features)          # (1, T_feat_i, 512)
            feat_i = out.last_hidden_state.squeeze(0)            # (T_feat_i, 512)
            features.append(feat_i)

    # 4) Pad along time dimension to get (B, T_feat_max, 512)
    max_T_feat = max(f.shape[0] for f in features)
    hidden_dim = features[0].shape[1]

    feats_batch = torch.zeros(B, max_T_feat, hidden_dim, device=device)
    for i, f in enumerate(features):
        feats_batch[i, : f.shape[0]] = f

    return feats_batch


In [13]:
# Multimodal alignment model (Perceiver backbone)

model = MultimodalAlignmentModel(mm_cfg).to(device)
print("MultimodalAlignmentModel created.")


MultimodalAlignmentModel created.


In [20]:
# Multi-GPU support
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
    print(f"Wrapping model in DataParallel for {num_gpus} GPUs")
    model = nn.DataParallel(model)


def get_model_module(m: nn.Module) -> nn.Module:
    """Return underlying module if wrapped in DataParallel."""
    return m.module if isinstance(m, nn.DataParallel) else m


def count_trainable_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


print("Trainable parameters:", f"{count_trainable_params(model):,}")

Trainable parameters: 21,621,760


In [21]:
# Checkpoint utilities

def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    best_metric: float,
    out_dir: Path,
    tag: str,
):
    out_dir.mkdir(parents=True, exist_ok=True)
    state = {
        "epoch": epoch,
        "best_metric": best_metric,
        "model_state": get_model_module(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "mm_config": mm_cfg.__dict__,
        "exp_config": asdict(cfg),
    }
    ckpt_path = out_dir / f"{tag}.pt"
    torch.save(state, ckpt_path)
    print(f"Saved checkpoint to: {ckpt_path}")


def load_checkpoint(
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    ckpt_path: Path,
    strict: bool = True,
):
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    get_model_module(model).load_state_dict(state["model_state"], strict=strict)
    if optimizer is not None and "optimizer_state" in state:
        optimizer.load_state_dict(state["optimizer_state"])
    print(f"Loaded checkpoint from epoch {state.get('epoch', 'N/A')} at {ckpt_path}")
    return state

In [22]:
# Training helpers ‚Äì MLP-only and Perceiver variants

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(enabled=cfg.use_amp)


def encode_modality(
    model: MultimodalAlignmentModel,
    feats: torch.Tensor,
    mask: Optional[torch.Tensor],
    modality: str,
    use_perceiver: bool,
) -> torch.Tensor:
    """Encode features into aligned space.

    If use_perceiver=True, uses full Perceiver pipeline.
    If False, bypasses Perceiver and uses adapters + alignment projector directly.
    """
    m = get_model_module(model)

    if modality == "vision":
        adapter = m.vision_adapter
    elif modality == "audio":
        adapter = m.audio_adapter
    elif modality == "text":
        adapter = m.text_adapter
    else:
        raise ValueError(f"Unknown modality: {modality}")

    tokens = adapter(feats)  # (B, T, perceiver_dim)

    if use_perceiver:
        latents = m.perceiver(tokens, mask)  # (B, K, perceiver_dim)
    else:
        # MLP-only: treat tokens as latents and pool across sequence
        latents = tokens  # (B, T, perceiver_dim)

    z = m.alignment_projector(latents)  # (B, d_align)
    return z


def compute_alignment_losses(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Compute MRL + CLIP contrastive losses between two modalities."""
    loss_mrl = matryoshka_loss(
        z_a,
        z_b,
        dims=mm_cfg.mrl_dims,
        temperature=mm_cfg.temperature,
    )
    loss_clip = contrastive_loss(
        z_a,
        z_b,
        temperature=mm_cfg.temperature,
    )
    loss_total = cfg.mrl_weight * loss_mrl + cfg.clip_weight * loss_clip
    return {
        "loss_total": loss_total,
        "loss_mrl": loss_mrl,
        "loss_clip": loss_clip,
    }

  scaler = GradScaler(enabled=cfg.use_amp)


In [23]:
@torch.no_grad()
def validate_alignment(
    model: nn.Module,
    image_val_loader: DataLoader,
    use_perceiver: bool,
    max_batches: int = 20,
    run=None,
    prefix: str = "val",
) -> Dict[str, float]:
    """Validate *vision‚Äìtext only* alignment.

    Computes VT embeddings, logs histograms to W&B, and returns summary stats.
    """
    model.eval()

    all_z_v = []
    all_z_t_img = []

    for b_idx, batch in enumerate(image_val_loader):
        if b_idx >= max_batches:
            break
        images = batch["images"]
        texts = batch["captions"]

        vision_feats = encode_images_to_features(images)
        text_feats = encode_texts_to_features(texts)

        z_v = encode_modality(model, vision_feats, None, "vision", use_perceiver)
        z_t = encode_modality(model, text_feats, None, "text", use_perceiver)

        all_z_v.append(z_v.cpu())
        all_z_t_img.append(z_t.cpu())

    if not all_z_v:
        return {}

    z_v_all = torch.cat(all_z_v, dim=0)
    z_t_all = torch.cat(all_z_t_img, dim=0)

    # Log histograms for VT
    log_alignment_histograms(run, z_v_all.to(device), z_t_all.to(device), f"{prefix}/vision_text")

    metrics = {
        f"{prefix}/num_samples_vt": float(z_v_all.size(0)),
    }

    if run is not None:
        run.log(metrics)

    print(f"Validation ({prefix}) | VT samples: {metrics[f'{prefix}/num_samples_vt']}")

    return metrics


In [24]:
# === Variant B: Perceiver + MLP + MRL (full model) ===

# Re-instantiate model to avoid interference from MLP-only run
model_perceiver = MultimodalAlignmentModel(mm_cfg).to(device)
if torch.cuda.device_count() > 1:
    model_perceiver = nn.DataParallel(model_perceiver)

print("Perceiver model trainable params:", count_trainable_params(model_perceiver))

optimizer_perceiver = torch.optim.AdamW(
    get_model_module(model_perceiver).parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)

best_metric_perceiver = -float("inf")

run_perceiver = init_wandb(run_name="02_multimodal_alignment_perceiver_mrl", variant="perceiver_mrl")


Perceiver model trainable params: 21621760


In [25]:

for epoch in range(1, cfg.num_epochs_perceiver + 1):
    stats = train_one_epoch(
        epoch=epoch,
        model=model_perceiver,
        optimizer=optimizer_perceiver,
        image_loader=image_train_loader,
        use_perceiver=True,
        run=run_perceiver,
    )

    val_metrics = validate_alignment(
        model=model_perceiver,
        image_val_loader=image_val_loader,
        use_perceiver=True,
        max_batches=20,
        run=run_perceiver,
        prefix="val_perceiver",
    )

    current_metric = val_metrics.get("val_perceiver/num_samples_vt", 0.0)
    if current_metric >= best_metric_perceiver:
        best_metric_perceiver = current_metric
        save_checkpoint(
            model=model_perceiver,
            optimizer=optimizer_perceiver,
            epoch=epoch,
            best_metric=best_metric_perceiver,
            out_dir=PERCEIVER_DIR,
            tag="best",
        )

# Save final checkpoint
save_checkpoint(
    model=model_perceiver,
    optimizer=optimizer_perceiver,
    epoch=cfg.num_epochs_perceiver,
    best_metric=best_metric_perceiver,
    out_dir=PERCEIVER_DIR,
    tag="final",
)

if run_perceiver is not None:
    run_perceiver.finish()


NameError: name 'train_one_epoch' is not defined

In [None]:
# === Detailed alignment analysis for best checkpoints ===

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import wandb


@torch.no_grad()
def collect_val_embeddings(
    model: nn.Module,
    image_val_loader: DataLoader,
    use_perceiver: bool,
    max_batches: int = 10,
):
    """Collect aligned embeddings for a subset of the image-text val set."""
    model.eval()

    all_z_v = []
    all_z_t = []

    for b_idx, batch in enumerate(image_val_loader):
        if b_idx >= max_batches:
            break

        images = batch["images"]
        texts = batch["captions"]

        vision_feats = encode_images_to_features(images)
        text_feats = encode_texts_to_features(texts)

        z_v = encode_modality(model, vision_feats, None, "vision", use_perceiver)
        z_t = encode_modality(model, text_feats, None, "text", use_perceiver)

        all_z_v.append(z_v.cpu())
        all_z_t.append(z_t.cpu())

    if not all_z_v:
        raise RuntimeError("No validation batches collected for embeddings.")

    z_v_all = torch.cat(all_z_v, dim=0)
    z_t_all = torch.cat(all_z_t, dim=0)

    return z_v_all, z_t_all


In [None]:


def make_similarity_heatmap(z_v: torch.Tensor, z_t: torch.Tensor, title: str = ""):
    """Create a cosine similarity heatmap between image and text embeddings."""
    z_v_norm = F.normalize(z_v, dim=-1)
    z_t_norm = F.normalize(z_t, dim=-1)

    sim_matrix = z_v_norm @ z_t_norm.T  # (N, N)

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(sim_matrix.numpy(), aspect="auto", cmap="viridis")
    ax.set_title(title or "Cosine Similarity (vision ‚Üî text)")
    ax.set_xlabel("Text index")
    ax.set_ylabel("Image index")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    return fig, sim_matrix


In [None]:


def make_tsne_plot(z_v: torch.Tensor, z_t: torch.Tensor, title: str = ""):
    """Create a t-SNE 2D scatter of vision and text aligned embeddings."""
    z_v_np = z_v.numpy()
    z_t_np = z_t.numpy()

    n_v = z_v_np.shape[0]
    n_t = z_t_np.shape[0]

    all_embeds = np.concatenate([z_v_np, z_t_np], axis=0)
    perplexity = min(30, max(5, (n_v + n_t) // 5))
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        init="random",
        learning_rate="auto",
    )
    all_2d = tsne.fit_transform(all_embeds)

    v_2d = all_2d[:n_v]
    t_2d = all_2d[n_v:]

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(v_2d[:, 0], v_2d[:, 1], label="vision", alpha=0.7, s=20)
    ax.scatter(t_2d[:, 0], t_2d[:, 1], label="text", alpha=0.7, s=20, marker="x")
    ax.set_title(title or "t-SNE of aligned embeddings (vision & text)")
    ax.set_xlabel("t-SNE dim 1")
    ax.set_ylabel("t-SNE dim 2")
    ax.legend()

    return fig



In [None]:

def analyze_best_checkpoint(
    model: nn.Module,
    ckpt_dir: Path,
    variant_name: str,
    use_perceiver: bool,
    max_batches: int = 10,
):
    """Load best checkpoint, compute alignment diagnostics, and log plots to W&B."""
    best_ckpt_path = ckpt_dir / "best.pt"
    if not best_ckpt_path.exists():
        print(f"‚ö†Ô∏è Best checkpoint not found at {best_ckpt_path}, skipping analysis.")
        return

    print(f"\n=== Alignment analysis for variant: {variant_name} ===")
    print(f"Loading best checkpoint from: {best_ckpt_path}")
    _ = load_checkpoint(model, optimizer=None, ckpt_path=best_ckpt_path, strict=True)

    # Collect embeddings
    z_v_all, z_t_all = collect_val_embeddings(
        model=model,
        image_val_loader=image_val_loader,
        use_perceiver=use_perceiver,
        max_batches=max_batches,
    )

    analysis_run = init_wandb(
        run_name=f"02_alignment_analysis_{variant_name}",
        variant=f"{variant_name}_analysis",
        extra_config={"checkpoint_path": str(best_ckpt_path)},
    )

    # Histograms
    if analysis_run is not None:
        log_alignment_histograms(
            analysis_run,
            z_a=z_v_all.to(device),
            z_b=z_t_all.to(device),
            prefix=f"analysis/{variant_name}/vision_text",
            max_points=512,
        )

    # Heatmap
    fig_heatmap, sim_matrix = make_similarity_heatmap(
        z_v_all,
        z_t_all,
        title=f"{variant_name}: cosine similarity (vision ‚Üî text)",
    )

    sim_np = sim_matrix.numpy()
    diag = np.diag(sim_np)
    off_diag = sim_np[~np.eye(sim_np.shape[0], dtype=bool)]

    diag_mean = float(diag.mean())
    diag_std = float(diag.std())
    off_mean = float(off_diag.mean())
    off_std = float(off_diag.std())

    if analysis_run is not None:
        analysis_run.log(
            {
                f"analysis/{variant_name}/sim_heatmap": wandb.Image(fig_heatmap),
                f"analysis/{variant_name}/diag_mean": diag_mean,
                f"analysis/{variant_name}/diag_std": diag_std,
                f"analysis/{variant_name}/offdiag_mean": off_mean,
                f"analysis/{variant_name}/offdiag_std": off_std,
            }
        )
    plt.close(fig_heatmap)

    print(
        f"Diagonal similarity: mean={diag_mean:.3f}, std={diag_std:.3f} | "
        f"Off-diagonal: mean={off_mean:.3f}, std={off_std:.3f}"
    )

    # t-SNE
    fig_tsne = make_tsne_plot(
        z_v_all,
        z_t_all,
        title=f"{variant_name}: t-SNE of aligned embeddings",
    )
    if analysis_run is not None:
        analysis_run.log({f"analysis/{variant_name}/tsne": wandb.Image(fig_tsne)})
        analysis_run.log(
            {
                f"analysis/{variant_name}/num_samples": float(z_v_all.size(0)),
                f"analysis/{variant_name}/d_align": float(z_v_all.size(1)),
            }
        )
        analysis_run.finish()
    plt.close(fig_tsne)


analyze_best_checkpoint(
    model=model_perceiver,
    ckpt_dir=PERCEIVER_DIR,
    variant_name="perceiver_mrl",
    use_perceiver=True,
    max_batches=10,
)


## Outputs & Next Steps

This notebook produces **Phase 1 multimodal alignment checkpoints**:

- **MLP + MRL (no Perceiver)**: `checkpoints/phase1_multimodal/mlp_mrl/`
- **Perceiver + MLP + MRL**: `checkpoints/phase1_multimodal/perceiver_mrl/`

Each directory contains:

- `best.pt` ‚Äì best model according to a validation proxy metric (placeholder for now)
- `final.pt` ‚Äì final epoch checkpoint

Additionally, W&B runs include:

- Training curves (`train/loss_*`, `val_*/num_samples_*`).
- Cosine similarity **histograms** (pos vs neg pairs) for vision‚Äìtext and audio‚Äìtext.
- Cosine similarity **heatmaps** (vision ‚Üî text) for best checkpoints.
- **t-SNE visualizations** of aligned embeddings (vision & text) for best checkpoints.

These checkpoints and diagnostics are the foundation for Phase 2 experiments:

- LLM decoder alignment (normal decoder).
- TRM decoder alignment.
- MoE decoder alignment.
- Full retrieval-based evaluation and Matryoshka ablations.
