# 🚀 Multimodal Alignment Training with Perceiver Resampler

**Complete training pipeline for multimodal alignment**

This notebook implements:

## Training Phases

| Phase | Objective | Trainable | Loss |
|-------|-----------|-----------|------|
| **Phase 1** | Align modalities | Adapters + Perceiver + Alignment Projector | MRL + CLIP Contrastive |
| **Phase 2** | LLM Integration | LLM Projector (+ LoRA) | Language Modeling |

---

## Architecture Overview

```
Image/Audio/Text → Frozen Encoder → Adapter → Perceiver → Projector → Aligned Embedding
                                                  ↓
                                              (Phase 2) → LLM Projector → LLM → Generated Text
```

---

## 0. Setup & Installation

In [None]:
# Install dependencies (uncomment if needed)
# !pip install torch transformers datasets accelerate sentencepiece
# !pip install librosa soundfile
# !pip install wandb  # Optional: for logging

In [None]:
import os
import math
import random
import json
import time
import warnings
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional, Tuple, Any, Union
from collections import defaultdict
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Transformers
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    CLIPVisionModel,
    CLIPImageProcessor,
    WhisperModel,
    WhisperProcessor,
    get_cosine_schedule_with_warmup,
)
from datasets import load_dataset
from PIL import Image
import requests

# Audio
try:
    import librosa
    HAS_LIBROSA = True
except ImportError:
    HAS_LIBROSA = False
    print("‚ö†Ô∏è librosa not installed")

# Logging (optional)
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

warnings.filterwarnings('ignore')

# Device
# Pick the first CUDA GPU when one is visible, otherwise fall back to CPU.
# `device` is a module-level global read by the encoder wrapper, the model
# construction cell and the checkpoint loader further down.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üñ•Ô∏è  Device: {device}")
if torch.cuda.is_available():
    # Report the name and total memory (in decimal GB) of GPU 0 only.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Configuration

In [None]:
@dataclass
class TrainingConfig:
    """
    Complete configuration for multimodal alignment training.

    Adjust these parameters based on your:
    - GPU memory (reduce batch_size if OOM)
    - Dataset size (increase epochs for small datasets)
    - Quality requirements (increase perceiver_layers for better quality)

    Constraints implied by the model code in this notebook:
    - ``perceiver_dim`` must be divisible by ``num_attn_heads`` (the attention
      module splits the feature dimension evenly across heads).
    - Entries of ``mrl_dims`` larger than ``d_align`` are skipped by
      ``matryoshka_loss``.
    - ``d_vision`` / ``d_audio`` / ``d_text`` / ``llm_hidden_size`` must match
      the hidden sizes of the models named in the "Model Names" section.
    """
    
    # === Experiment ===
    experiment_name: str = "multimodal_align_v1"
    seed: int = 42  # passed to set_seed() right after the config is created
    
    # === Model Names ===
    # Hugging Face hub identifiers of the frozen encoders and the LLM.
    vision_model_name: str = "openai/clip-vit-base-patch32"
    audio_model_name: str = "openai/whisper-base"
    text_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    llm_model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"  # Small for training
    
    # === Encoder Dimensions ===
    # Input widths of the three modality adapters.
    d_vision: int = 768      # CLIP ViT-B/32
    d_audio: int = 512       # Whisper-base
    d_text: int = 384        # MiniLM
    
    # === Perceiver Architecture ===
    perceiver_dim: int = 512          # width of adapter outputs and latents
    num_latents: int = 64             # number of learned latent vectors
    num_perceiver_layers: int = 4
    num_attn_heads: int = 8
    perceiver_mlp_ratio: float = 4.0  # FFN hidden size = perceiver_dim * ratio
    dropout: float = 0.1
    
    # === Alignment ===
    d_align: int = 512                              # output width of the alignment projector
    mrl_dims: Tuple[int, ...] = (64, 128, 256, 512)  # prefix lengths used by the Matryoshka loss
    
    # === LLM ===
    llm_hidden_size: int = 896  # Qwen2.5-0.5B
    
    # === Phase 1: Alignment Training ===
    phase1_epochs: int = 5
    phase1_batch_size: int = 32
    phase1_lr: float = 1e-4
    phase1_weight_decay: float = 0.01
    phase1_warmup_ratio: float = 0.1  # fraction of total steps used for LR warmup
    phase1_max_samples: Optional[int] = 10000  # Set to None for full dataset
    
    # Loss weights
    # Phase 1 loss = mrl_weight * matryoshka_loss + clip_weight * contrastive_loss
    mrl_weight: float = 1.0
    clip_weight: float = 0.5
    temperature: float = 0.07  # softmax temperature shared by both losses
    
    # === Phase 2: LLM Training ===
    phase2_epochs: int = 3
    phase2_batch_size: int = 8
    phase2_lr: float = 2e-5
    phase2_max_samples: Optional[int] = 5000  # None is accepted by CaptioningDataset (no truncation)
    max_seq_length: int = 128
    
    # === Training ===
    # NOTE(review): gradient_accumulation_steps is not read by train_phase1 in
    # this notebook; confirm whether accumulation is implemented elsewhere.
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold
    use_amp: bool = True  # Mixed precision
    num_workers: int = 0  # DataLoader worker processes
    
    # === Checkpointing ===
    checkpoint_dir: str = "./checkpoints"
    save_every_n_steps: int = 500
    # NOTE(review): eval_every_n_steps is not read by train_phase1 in this notebook.
    eval_every_n_steps: int = 250
    
    # === Logging ===
    use_wandb: bool = False  # only effective when the wandb import succeeded
    log_every_n_steps: int = 50


# Create config
# Module-level configuration instance with all defaults; every later cell
# (model, encoders, training loops) reads its settings from this object.
cfg = TrainingConfig()

# Set seed
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available the generators of all visible GPUs are seeded too.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed all RNGs before any model weights are created.
set_seed(cfg.seed)

# Create directories
# The root is created with parents=True; the per-phase sub-directories rely
# on the root already existing.
Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(f"{cfg.checkpoint_dir}/phase1").mkdir(exist_ok=True)
Path(f"{cfg.checkpoint_dir}/phase2").mkdir(exist_ok=True)

# Print a short summary of the most important settings.
print("üìã Configuration:")
print(f"   Experiment: {cfg.experiment_name}")
print(f"   Perceiver: {cfg.num_latents} latents, {cfg.num_perceiver_layers} layers")
print(f"   Phase 1: {cfg.phase1_epochs} epochs, batch={cfg.phase1_batch_size}")
print(f"   Phase 2: {cfg.phase2_epochs} epochs, batch={cfg.phase2_batch_size}")

## 2. Model Architecture

In [None]:
# ============================================================
# PERCEIVER COMPONENTS
# ============================================================

class FeedForward(nn.Module):
    """Two-layer MLP: expand by ``mlp_ratio``, GELU, project back to ``dim``.

    Dropout is applied after the activation and again after the output
    projection.
    """
    def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        inner_dim = int(dim * mlp_ratio)
        stages = [
            nn.Linear(dim, inner_dim),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),   # contract
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed


class PerceiverAttention(nn.Module):
    """Multi-head attention with pre-LayerNorm on both queries and keys/values.

    Used for cross-attention (``q`` = latents, ``kv`` = input tokens) and for
    self-attention (``q`` and ``kv`` are the same tensor).

    Args:
        dim: Feature width of queries, keys/values and the output.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail at construction time instead of with an opaque ``view`` error
        # in the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)
        
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, q, kv, mask=None):
        """Attend from ``q`` over ``kv``.

        Args:
            q: Query tensor of shape (B, N_q, D).
            kv: Key/value tensor of shape (B, N_kv, D).
            mask: Optional (B, N_kv) tensor; non-zero/True entries mark
                positions that may be attended to.

        Returns:
            Tensor of shape (B, N_q, D).
        """
        B, N_q, D = q.shape
        N_kv = kv.shape[1]
        
        q = self.ln_q(q)
        kv = self.ln_kv(kv)
        
        # (B, N, D) -> (B, heads, N, head_dim)
        Q = self.q_proj(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = (Q @ K.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            # Broadcast the key mask over heads and query positions.
            blocked = ~mask.bool().unsqueeze(1).unsqueeze(2)
            # Use the dtype's most negative finite value rather than -inf:
            # masked positions still get zero weight after softmax, but a row
            # whose keys are all masked no longer produces NaN.
            attn = attn.masked_fill(blocked, torch.finfo(attn.dtype).min)
        
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # (B, heads, N_q, head_dim) -> (B, N_q, D)
        out = (attn @ V).transpose(1, 2).reshape(B, N_q, D)
        return self.out_proj(out)


class PerceiverLayer(nn.Module):
    """One Perceiver block: cross-attention, self-attention, then an FFN.

    Every sub-layer is wrapped in a residual connection. The attention
    modules normalise their own inputs; the FFN input is normalised here.
    """
    def __init__(self, dim: int, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        self.cross_attn = PerceiverAttention(dim, num_heads=num_heads, dropout=dropout)
        self.self_attn = PerceiverAttention(dim, num_heads=num_heads, dropout=dropout)
        self.ffn = FeedForward(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        self.ln = nn.LayerNorm(dim)

    def forward(self, latents, tokens, mask=None):
        """Update ``latents`` from ``tokens``; ``mask`` applies to ``tokens`` only."""
        # Latents read from the input tokens.
        from_tokens = self.cross_attn(latents, tokens, mask)
        state = latents + from_tokens
        # Latents exchange information among themselves (no mask).
        state = state + self.self_attn(state, state)
        # Position-wise refinement.
        return state + self.ffn(self.ln(state))


class PerceiverResampler(nn.Module):
    """Compress a variable-length token sequence into ``num_latents`` vectors.

    A learned set of latent vectors is passed through ``num_layers`` Perceiver
    blocks, each attending to the input tokens, and is layer-normalised at
    the end.
    """
    def __init__(self, dim: int, num_latents: int = 64, num_layers: int = 4, 
                 num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        # Latents start as Gaussian noise scaled by 1/sqrt(dim).
        init_scale = dim ** -0.5
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * init_scale)

        blocks = []
        for _ in range(num_layers):
            blocks.append(PerceiverLayer(dim, num_heads, mlp_ratio, dropout))
        self.layers = nn.ModuleList(blocks)

        self.ln_out = nn.LayerNorm(dim)

    def forward(self, tokens, mask=None):
        """Return (B, num_latents, dim) latents for a (B, N, dim) ``tokens`` batch."""
        batch_size = tokens.size(0)
        # Share the learned latents across the batch without copying them.
        state = self.latents[None].expand(batch_size, -1, -1)

        for block in self.layers:
            state = block(state, tokens, mask)

        return self.ln_out(state)

In [None]:
# ============================================================
# ADAPTERS AND PROJECTORS
# ============================================================

class MLPAdapter(nn.Module):
    """Map encoder features of width ``in_dim`` to width ``out_dim``.

    LayerNorm -> Linear -> GELU -> Dropout -> Linear, with a hidden width of
    ``int(in_dim * hidden_factor)``.
    """
    def __init__(self, in_dim: int, out_dim: int, hidden_factor: float = 2.0, dropout: float = 0.1):
        super().__init__()
        inner_dim = int(in_dim * hidden_factor)
        stages = (
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, out_dim),
        )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the adapter to the last dimension of ``x``."""
        adapted = self.net(x)
        return adapted


class AlignmentProjector(nn.Module):
    """Mean-pool the latents, then LayerNorm + Linear into the alignment space."""
    def __init__(self, dim: int, out_dim: int):
        super().__init__()
        norm = nn.LayerNorm(dim)
        linear = nn.Linear(dim, out_dim)
        self.proj = nn.Sequential(norm, linear)

    def forward(self, latents):
        """Reduce (B, K, D) latents to one (B, out_dim) embedding per sample."""
        summary = torch.mean(latents, dim=1)  # average over the K latents
        return self.proj(summary)


class LLMProjector(nn.Module):
    """LayerNorm + Linear that maps each latent to the LLM embedding width."""
    def __init__(self, dim: int, llm_dim: int):
        super().__init__()
        norm = nn.LayerNorm(dim)
        linear = nn.Linear(dim, llm_dim)
        self.proj = nn.Sequential(norm, linear)

    def forward(self, latents):
        """Project every latent independently; the sequence length is unchanged."""
        projected = self.proj(latents)
        return projected

In [None]:
# ============================================================
# COMPLETE MODEL
# ============================================================

class MultimodalAlignmentModel(nn.Module):
    """
    Trainable part of the pipeline: adapters, shared Perceiver and projectors.

    Phase 1 (Alignment): modality adapter -> perceiver -> alignment_proj
    Phase 2 (LLM): perceiver latents -> llm_proj
    """

    def __init__(self, cfg: TrainingConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.perceiver_dim

        # One adapter per modality brings encoder features to a common width.
        self.vision_adapter = MLPAdapter(cfg.d_vision, width, dropout=cfg.dropout)
        self.audio_adapter = MLPAdapter(cfg.d_audio, width, dropout=cfg.dropout)
        self.text_adapter = MLPAdapter(cfg.d_text, width, dropout=cfg.dropout)

        # A single Perceiver is shared by all modalities.
        self.perceiver = PerceiverResampler(
            width,
            cfg.num_latents,
            cfg.num_perceiver_layers,
            cfg.num_attn_heads,
            cfg.perceiver_mlp_ratio,
            cfg.dropout,
        )

        # Heads: pooled alignment embedding (Phase 1), per-latent LLM tokens (Phase 2).
        self.alignment_proj = AlignmentProjector(width, cfg.d_align)
        self.llm_proj = LLMProjector(width, cfg.llm_hidden_size)

    def encode_modality(self, features, adapter, mask=None):
        """Run ``features`` through ``adapter`` and the shared perceiver."""
        return self.perceiver(adapter(features), mask)

    def _embed(self, features, adapter, mask):
        """Return ``(alignment embedding, latents)`` for one modality."""
        latents = self.encode_modality(features, adapter, mask)
        return self.alignment_proj(latents), latents

    def encode_vision(self, features, mask=None):
        """Embed vision-encoder features; returns ``(z, latents)``."""
        return self._embed(features, self.vision_adapter, mask)

    def encode_audio(self, features, mask=None):
        """Embed audio-encoder features; returns ``(z, latents)``."""
        return self._embed(features, self.audio_adapter, mask)

    def encode_text(self, features, mask=None):
        """Embed text-encoder features; returns ``(z, latents)``."""
        return self._embed(features, self.text_adapter, mask)

    def project_to_llm(self, latents):
        """Map perceiver latents into the LLM embedding space."""
        projected = self.llm_proj(latents)
        return projected

    def get_phase1_params(self):
        """List the parameters optimised in Phase 1 (everything except ``llm_proj``)."""
        phase1_modules = (
            self.vision_adapter,
            self.audio_adapter,
            self.text_adapter,
            self.perceiver,
            self.alignment_proj,
        )
        params = []
        for module in phase1_modules:
            params.extend(module.parameters())
        return params

    def get_phase2_params(self):
        """List the parameters optimised in Phase 2 (``llm_proj`` only)."""
        return [param for param in self.llm_proj.parameters()]


# Create model
# Build the trainable model from the global config and move it to `device`.
model = MultimodalAlignmentModel(cfg).to(device)

# Count parameters
# At this point nothing has been frozen yet, so both numbers cover all
# adapters, the perceiver and both projectors.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nüìä Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable: {trainable_params:,}")

## 3. Frozen Encoders

In [None]:
# ============================================================
# LOAD FROZEN ENCODERS
# ============================================================

class FrozenEncoders:
    """Holds the pretrained vision, audio and text encoders used as feature extractors.

    Each encoder is moved to the module-level ``device``, switched to eval
    mode and has gradients disabled; all ``encode_*`` methods run under
    ``torch.no_grad()``.
    """

    @staticmethod
    def _freeze(module):
        """Move ``module`` to ``device``, set eval mode and disable its gradients."""
        module.to(device).eval()
        module.requires_grad_(False)
        return module

    def __init__(self, cfg: TrainingConfig):
        self.cfg = cfg
        self.device = device

        print("\nüì¶ Loading frozen encoders...")

        # Vision (CLIP)
        vision_name = cfg.vision_model_name
        self.vision_processor = CLIPImageProcessor.from_pretrained(vision_name)
        self.vision_encoder = self._freeze(CLIPVisionModel.from_pretrained(vision_name))
        print(f"   ‚úì Vision: {cfg.vision_model_name}")

        # Audio (Whisper) - only the encoder half of the model is kept
        audio_name = cfg.audio_model_name
        self.audio_processor = WhisperProcessor.from_pretrained(audio_name)
        self.audio_encoder = self._freeze(WhisperModel.from_pretrained(audio_name).encoder)
        print(f"   ‚úì Audio: {cfg.audio_model_name}")

        # Text (Sentence-BERT)
        text_name = cfg.text_model_name
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_name)
        self.text_encoder = self._freeze(AutoModel.from_pretrained(text_name))
        print(f"   ‚úì Text: {cfg.text_model_name}")

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image]) -> torch.Tensor:
        """Return the CLIP last hidden state for a list of PIL images."""
        processed = self.vision_processor(images=images, return_tensors="pt")
        pixels = processed["pixel_values"].to(self.device)
        hidden = self.vision_encoder(pixel_values=pixels).last_hidden_state
        return hidden  # (B, 50, 768)

    @torch.no_grad()
    def encode_audio(self, waveforms: List[np.ndarray], sr: int = 16000) -> torch.Tensor:
        """Return the Whisper encoder last hidden state for raw waveforms sampled at ``sr``."""
        processed = self.audio_processor(
            waveforms, sampling_rate=sr, return_tensors="pt"
        )
        features = processed["input_features"].to(self.device)
        hidden = self.audio_encoder(features).last_hidden_state
        return hidden  # (B, T, 512)

    @torch.no_grad()
    def encode_texts(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(token hidden states, attention mask)`` for a list of strings.

        Texts are padded to the longest item and truncated to 128 tokens.
        """
        encoded = self.text_tokenizer(
            texts, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(self.device)
        hidden = self.text_encoder(**encoded).last_hidden_state
        attention_mask = encoded["attention_mask"]
        return hidden, attention_mask  # (B, L, 384), (B, L)


# Load encoders
# Instantiating the container loads all three pretrained models (vision,
# audio, text) onto `device`; the instance is shared by both training phases.
encoders = FrozenEncoders(cfg)

## 4. Loss Functions

In [None]:
# ============================================================
# LOSS FUNCTIONS
# ============================================================

def contrastive_loss(z_a: torch.Tensor, z_b: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """
    Symmetric InfoNCE loss over a batch of paired embeddings.

    Row ``i`` of ``z_a`` and row ``i`` of ``z_b`` form the positive pair; all
    other rows in the batch act as negatives. Embeddings are L2-normalised
    first, so the logits are cosine similarities divided by ``temperature``.
    The result is the mean of the a->b and b->a cross-entropies.
    """
    unit_a = F.normalize(z_a, dim=-1)
    unit_b = F.normalize(z_b, dim=-1)

    similarity = (unit_a @ unit_b.T) / temperature
    # The matching partner of row i sits on the diagonal.
    targets = torch.arange(unit_a.size(0), device=unit_a.device)

    a_to_b = F.cross_entropy(similarity, targets)
    b_to_a = F.cross_entropy(similarity.T, targets)

    return 0.5 * (a_to_b + b_to_a)


def matryoshka_loss(
    z_a: torch.Tensor, 
    z_b: torch.Tensor, 
    dims: Tuple[int, ...] = (64, 128, 256, 512),
    temperature: float = 0.07,
) -> torch.Tensor:
    """
    Matryoshka Representation Learning loss.

    Averages ``contrastive_loss`` over several prefixes of the embedding so
    that truncated embeddings remain usable at inference time.

    Args:
        z_a: Embeddings of shape (B, D).
        z_b: Paired embeddings of shape (B, D).
        dims: Prefix lengths to train at; entries larger than ``D`` are skipped.
        temperature: Softmax temperature forwarded to ``contrastive_loss``.

    Returns:
        Scalar tensor with the mean loss over the usable prefix lengths. When
        no entry of ``dims`` fits, a zero scalar tensor is returned, so callers
        can always treat the result as a tensor (e.g. call ``.item()``).
    """
    usable_dims = [dim for dim in dims if dim <= z_a.size(-1)]
    if not usable_dims:
        # Keep the documented tensor return type instead of a Python float.
        return z_a.new_zeros(())

    total_loss = 0.0
    for dim in usable_dims:
        total_loss = total_loss + contrastive_loss(
            z_a[:, :dim], z_b[:, :dim], temperature
        )

    return total_loss / len(usable_dims)


def compute_alignment_metrics(z_a: torch.Tensor, z_b: torch.Tensor) -> Dict[str, float]:
    """Summarise how well paired embeddings line up within a batch.

    Both inputs are L2-normalised first. Returned keys:
      - "alignment": mean squared Euclidean distance of matching pairs (lower is better)
      - "pos_sim":   mean cosine similarity of matching pairs (higher is better)
      - "R@1":       percentage of rows of ``z_a`` whose most similar row of
                     ``z_b`` is their own partner
    """
    unit_a = F.normalize(z_a, dim=-1)
    unit_b = F.normalize(z_b, dim=-1)

    pair_sq_dist = ((unit_a - unit_b) ** 2).sum(dim=-1).mean().item()
    pair_cosine = (unit_a * unit_b).sum(dim=-1).mean().item()

    # Retrieval: for each a-row, is its nearest b-row the one with the same index?
    nearest = (unit_a @ unit_b.T).argmax(dim=-1)
    own_index = torch.arange(unit_a.size(0), device=unit_a.device)
    hit_rate = (nearest == own_index).float().mean().item()

    return {
        "alignment": pair_sq_dist,
        "pos_sim": pair_cosine,
        "R@1": hit_rate * 100,
    }

## 5. Data Loading

In [None]:
# ============================================================
# PHASE 1 DATASET: IMAGE-TEXT PAIRS
# ============================================================

class ImageTextDataset(Dataset):
    """Dataset of (PIL image, caption string) pairs for Phase 1.

    Tries a list of Hugging Face datasets in order and uses the first one
    that loads. If none can be loaded, falls back to 1000 identical dummy
    samples so that the pipeline can still be smoke-tested.

    Args:
        encoders: Frozen encoder container (stored, not used by this class).
        max_samples: Keep only the first ``max_samples`` items; ``None`` or 0
            keeps the full dataset.
    """

    # (hub name, split, image column, caption column), tried in this order.
    _SOURCES = (
        ("yerevann/coco-karpathy", "train", "image", "sentences"),
        ("nlphuji/flickr30k", "test", "image", "caption"),
    )
    
    def __init__(self, encoders: FrozenEncoders, max_samples: Optional[int] = None):
        self.encoders = encoders
        self.dataset = None
        
        print("\nüìö Loading image-text dataset...")
        
        # Try different datasets
        for name, split, img_col, txt_col in self._SOURCES:
            try:
                self.dataset = load_dataset(name, split=split)
            except Exception as err:
                # Best effort: report why this source failed and try the next
                # one. Unlike a bare ``except``, this lets KeyboardInterrupt
                # and SystemExit propagate.
                print(f"   Could not load {name}: {err}")
                continue
            self.img_col = img_col
            self.txt_col = txt_col
            print(f"   Using: {name}")
            break
        
        if self.dataset is None:
            # Fallback: create dummy data
            print("   ‚ö†Ô∏è Using dummy data (no dataset available)")
        
        # Compare against None explicitly: an empty dataset is falsy but is
        # still a real dataset.
        if self.dataset is not None and max_samples:
            self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        
        if self.dataset is not None:
            print(f"   Samples: {len(self.dataset)}")
    
    def __len__(self):
        """Number of samples (1000 in dummy mode)."""
        return len(self.dataset) if self.dataset is not None else 1000
    
    def __getitem__(self, idx):
        """Return ``{"image": PIL.Image (RGB), "caption": str}`` for item ``idx``."""
        if self.dataset is None:
            # Return dummy data
            return {
                "image": Image.new("RGB", (224, 224), color="white"),
                "caption": "A dummy caption for testing.",
            }
        
        item = self.dataset[idx]
        
        # Get image: either an already decoded PIL image or a dict with raw bytes.
        img = item[self.img_col]
        if not isinstance(img, Image.Image):
            img = Image.open(BytesIO(img["bytes"])).convert("RGB")
        else:
            img = img.convert("RGB")
        
        # Get caption: use the first of several captions, and unwrap
        # dict-style captions via their "raw" field.
        caption = item[self.txt_col]
        if isinstance(caption, list):
            caption = caption[0] if caption else ""
        if isinstance(caption, dict):
            caption = caption.get("raw", str(caption))
        
        return {"image": img, "caption": str(caption)}


def collate_phase1(batch, encoders):
    """Turn a list of ``{"image", "caption"}`` samples into encoder features.

    The frozen encoders are run here, so the returned batch already contains
    vision features, text features and the text attention mask.
    """
    images, captions = [], []
    for sample in batch:
        images.append(sample["image"])
        captions.append(sample["caption"])

    # Encode with frozen encoders
    vision_features = encoders.encode_images(images)
    text_features, text_mask = encoders.encode_texts(captions)

    return dict(
        vision_features=vision_features,
        text_features=text_features,
        text_mask=text_mask,
    )

In [None]:
# ============================================================
# PHASE 2 DATASET: CAPTIONING
# ============================================================

class CaptioningDataset(Dataset):
    """Dataset of (PIL image, caption string) pairs for Phase 2 LLM training.

    Loads the ``yerevann/coco-karpathy`` train split; if that fails, falls
    back to 500 identical dummy samples.

    Args:
        encoders: Frozen encoder container (stored, not used by this class).
        llm_tokenizer: Tokenizer of the LLM (stored as ``self.tokenizer``).
        max_samples: Keep only the first ``max_samples`` items; ``None`` or 0
            keeps the full dataset.
    """
    
    def __init__(self, encoders: FrozenEncoders, llm_tokenizer, max_samples: Optional[int] = None):
        self.encoders = encoders
        self.tokenizer = llm_tokenizer
        
        print("\nüìö Loading captioning dataset...")
        
        try:
            self.dataset = load_dataset("yerevann/coco-karpathy", split="train")
            self.img_col = "image"
            self.txt_col = "sentences"
        except Exception as err:
            # Best effort: keep going with dummy data, but say why. Unlike a
            # bare ``except``, this lets KeyboardInterrupt and SystemExit
            # propagate.
            self.dataset = None
            print(f"   Could not load yerevann/coco-karpathy: {err}")
            print("   ‚ö†Ô∏è Using dummy data")
        
        # Compare against None explicitly: an empty dataset is falsy but is
        # still a real dataset.
        if self.dataset is not None and max_samples:
            self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        
        if self.dataset is not None:
            print(f"   Samples: {len(self.dataset)}")
    
    def __len__(self):
        """Number of samples (500 in dummy mode)."""
        return len(self.dataset) if self.dataset is not None else 500
    
    def __getitem__(self, idx):
        """Return ``{"image": PIL.Image (RGB), "caption": str}`` for item ``idx``."""
        if self.dataset is None:
            return {
                "image": Image.new("RGB", (224, 224)),
                "caption": "A sample image.",
            }
        
        item = self.dataset[idx]
        
        # Either an already decoded PIL image or a dict with raw bytes.
        img = item[self.img_col]
        if not isinstance(img, Image.Image):
            img = Image.open(BytesIO(img["bytes"])).convert("RGB")
        else:
            img = img.convert("RGB")
        
        # Use the first of several captions, and unwrap dict-style captions
        # via their "raw" field.
        caption = item[self.txt_col]
        if isinstance(caption, list):
            caption = caption[0] if caption else ""
        if isinstance(caption, dict):
            caption = caption.get("raw", str(caption))
        
        return {"image": img, "caption": str(caption)}

## 6. Training Utilities

In [None]:
# ============================================================
# TRAINING UTILITIES
# ============================================================

class TrainingLogger:
    """Keeps an in-memory history of metrics and optionally mirrors it to wandb."""

    def __init__(self, use_wandb: bool = False, project_name: str = "multimodal_align"):
        # wandb is only used when requested AND the import succeeded.
        self.use_wandb = use_wandb and HAS_WANDB
        # metric name -> list of (step, value) tuples in logging order
        self.metrics_history = defaultdict(list)

        if not self.use_wandb:
            return
        wandb.init(project=project_name)

    def log(self, metrics: Dict[str, float], step: int):
        """Record every entry of ``metrics`` at ``step``."""
        history = self.metrics_history
        for name in metrics:
            history[name].append((step, metrics[name]))

        if self.use_wandb:
            wandb.log(metrics, step=step)

    def plot(self, metric_name: str):
        """Show a line plot of one metric; does nothing if it was never logged."""
        series = self.metrics_history.get(metric_name)
        if series is None:
            return

        steps, values = zip(*series)
        plt.figure(figsize=(10, 4))
        plt.plot(steps, values)
        plt.xlabel("Step")
        plt.ylabel(metric_name)
        plt.title(f"Training: {metric_name}")
        plt.grid(True, alpha=0.3)
        plt.show()


def save_checkpoint(model, optimizer, epoch, step, loss, path):
    """Write model and optimizer state plus progress counters to ``path``."""
    state = {
        "epoch": epoch,
        "step": step,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(state, path)
    print(f"   üíæ Saved checkpoint: {path}")


def load_checkpoint(model, optimizer, path):
    """Restore a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module whose ``state_dict`` is restored in place.
        optimizer: Optimizer to restore in place, or ``None`` to load only
            the model weights.
        path: Checkpoint file.

    Returns:
        ``(epoch, step)`` as stored in the checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    # Explicit None check rather than relying on the optimizer's truthiness.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["step"]

---

# 🎯 PHASE 1: Alignment Training

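**Objective**: Align vision, audio, and text in a shared embedding space. (The training loop in this notebook uses image–text pairs only, so the audio adapter is created but receives no gradient updates here.)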

**Trainable Components**:
- Vision Adapter
- Audio Adapter  
- Text Adapter
- Perceiver Resampler
- Alignment Projector

**Loss**: Matryoshka + CLIP Contrastive

---

In [None]:
# ============================================================
# PHASE 1: ALIGNMENT TRAINING
# ============================================================

def train_phase1(model, encoders, cfg, resume_from=None):
    """
    Phase 1: Train multimodal alignment.
    
    This trains the adapters, perceiver, and alignment projector
    to align vision and text embeddings in a shared space.

    Args:
        model: The multimodal alignment model to train (modified in place).
        encoders: Frozen encoder container used by the collate function.
        cfg: Training configuration.
        resume_from: Optional checkpoint path; if it exists, model and
            optimizer state are restored and training continues from the
            stored epoch. The LR scheduler and grad scaler are not part of
            the checkpoint, so they start fresh on resume.

    Returns:
        The ``TrainingLogger`` holding the logged metric history.

    Raises:
        ValueError: If the dataloader yields no batches (dataset smaller than
            one batch while ``drop_last=True``).
    """
    print("\n" + "="*70)
    print("üéØ PHASE 1: ALIGNMENT TRAINING")
    print("="*70)
    
    # Dataset
    dataset = ImageTextDataset(encoders, max_samples=cfg.phase1_max_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.phase1_batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        # Encoding happens inside the collate function.
        collate_fn=lambda b: collate_phase1(b, encoders),
        drop_last=True,
    )
    
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this check the epoch summary would divide by zero.
        raise ValueError(
            f"Phase 1 dataloader is empty: {len(dataset)} samples with "
            f"batch_size={cfg.phase1_batch_size} and drop_last=True"
        )
    
    # Optimizer (only Phase 1 params)
    optimizer = torch.optim.AdamW(
        model.get_phase1_params(),
        lr=cfg.phase1_lr,
        weight_decay=cfg.phase1_weight_decay,
    )
    
    # Scheduler
    total_steps = num_batches * cfg.phase1_epochs
    warmup_steps = int(total_steps * cfg.phase1_warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    
    # Mixed precision
    scaler = GradScaler() if cfg.use_amp else None
    
    # Logger
    logger = TrainingLogger(use_wandb=cfg.use_wandb)
    
    # Resume if specified
    start_epoch = 0
    global_step = 0
    if resume_from and Path(resume_from).exists():
        start_epoch, global_step = load_checkpoint(model, optimizer, resume_from)
        print(f"   Resumed from epoch {start_epoch}, step {global_step}")
    
    # Training loop
    model.train()
    best_loss = float('inf')
    # Defined up front so the final save below still works when the epoch
    # loop does not run at all (e.g. resuming from a finished run).
    avg_loss = float('nan')
    
    for epoch in range(start_epoch, cfg.phase1_epochs):
        epoch_loss = 0.0
        epoch_metrics = defaultdict(float)
        
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.phase1_epochs}")
        
        for batch_idx, batch in enumerate(pbar):
            # Get features
            vision_feats = batch["vision_features"].to(device)
            text_feats = batch["text_features"].to(device)
            text_mask = batch["text_mask"].to(device)
            
            # Forward pass
            with autocast(enabled=cfg.use_amp):
                z_vision, _ = model.encode_vision(vision_feats)
                z_text, _ = model.encode_text(text_feats, text_mask)
                
                # Compute losses
                loss_mrl = matryoshka_loss(z_vision, z_text, cfg.mrl_dims, cfg.temperature)
                loss_clip = contrastive_loss(z_vision, z_text, cfg.temperature)
                
                loss = cfg.mrl_weight * loss_mrl + cfg.clip_weight * loss_clip
            
            # Backward pass
            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
                optimizer.step()
            
            scheduler.step()
            global_step += 1
            
            # Metrics
            epoch_loss += loss.item()
            epoch_metrics["loss_mrl"] += loss_mrl.item()
            epoch_metrics["loss_clip"] += loss_clip.item()
            
            # Update progress bar
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}",
            })
            
            # Logging
            if global_step % cfg.log_every_n_steps == 0:
                metrics = compute_alignment_metrics(z_vision.detach(), z_text.detach())
                logger.log({
                    "phase1/loss": loss.item(),
                    "phase1/loss_mrl": loss_mrl.item(),
                    "phase1/loss_clip": loss_clip.item(),
                    "phase1/R@1": metrics["R@1"],
                    "phase1/lr": scheduler.get_last_lr()[0],
                }, global_step)
            
            # Checkpointing
            if global_step % cfg.save_every_n_steps == 0:
                save_checkpoint(
                    model, optimizer, epoch, global_step, loss.item(),
                    f"{cfg.checkpoint_dir}/phase1/step_{global_step}.pt"
                )
        
        # End of epoch
        avg_loss = epoch_loss / num_batches
        print(f"\n   Epoch {epoch+1} Summary:")
        print(f"   Loss: {avg_loss:.4f}")
        print(f"   MRL: {epoch_metrics['loss_mrl']/num_batches:.4f}")
        print(f"   CLIP: {epoch_metrics['loss_clip']/num_batches:.4f}")
        
        # Save best model (by mean training loss of the epoch)
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(
                model, optimizer, epoch, global_step, avg_loss,
                f"{cfg.checkpoint_dir}/phase1/best.pt"
            )
    
    # Final save
    save_checkpoint(
        model, optimizer, cfg.phase1_epochs, global_step, avg_loss,
        f"{cfg.checkpoint_dir}/phase1/final.pt"
    )
    
    print("\n‚úÖ Phase 1 Training Complete!")
    return logger

In [None]:
# Run Phase 1 Training
# Trains `model` in place and writes checkpoints under
# {cfg.checkpoint_dir}/phase1; the returned logger holds the metric history.
phase1_logger = train_phase1(model, encoders, cfg)

In [None]:
# Plot training curves
# Metrics are recorded every cfg.log_every_n_steps steps; plot() silently
# does nothing for a metric that was never logged.
phase1_logger.plot("phase1/loss")
phase1_logger.plot("phase1/R@1")

---

# 🧠 PHASE 2: LLM Integration Training

**Objective**: Enable LLM to understand and generate from multimodal inputs.

**Trainable Components**:
- LLM Projector (maps perceiver latents to LLM embedding space)
- Optionally: LoRA adapters on LLM

**Frozen**:
- All Phase 1 components (adapters, perceiver, alignment projector)
- LLM base weights

**Loss**: Language Modeling (next-token prediction)

---

In [None]:
# ============================================================
# LOAD LLM FOR PHASE 2
# ============================================================

print("\nüì¶ Loading LLM for Phase 2...")

llm_tokenizer = AutoTokenizer.from_pretrained(cfg.llm_model_name)
if llm_tokenizer.pad_token is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

llm_model = AutoModelForCausalLM.from_pretrained(
    cfg.llm_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Freeze LLM
for p in llm_model.parameters():
    p.requires_grad = False

print(f"   ‚úì LLM: {cfg.llm_model_name}")
print(f"   Parameters: {sum(p.numel() for p in llm_model.parameters()):,}")

In [None]:
# ============================================================
# PHASE 2: LLM INTEGRATION TRAINING
# ============================================================

def train_phase2(model, llm_model, llm_tokenizer, encoders, cfg):
    """
    Phase 2: Train LLM integration.

    Trains only ``model.llm_proj`` so that perceiver latents, once projected
    by ``model.project_to_llm``, can be fed to the LLM as a prefix in front of
    the caption token embeddings. The objective is the LLM's own
    language-modeling loss; prefix positions and caption padding positions
    are excluded from it.

    Args:
        model: Alignment model. Its ``vision_adapter``, ``audio_adapter``,
            ``text_adapter``, ``perceiver`` and ``alignment_proj`` parameters
            are frozen here; ``llm_proj`` parameters are made trainable.
        llm_model: Causal LM called with ``inputs_embeds``, ``attention_mask``
            and ``labels``. This function does not change its
            ``requires_grad`` flags.
        llm_tokenizer: Tokenizer used to pad/truncate captions (needs a pad
            token because ``padding=True`` is used).
        encoders: Object providing ``encode_images``.
        cfg: Configuration; reads ``phase2_max_samples``, ``phase2_lr``,
            ``phase2_epochs``, ``phase2_batch_size``, ``use_wandb``,
            ``use_amp``, ``max_seq_length``, ``max_grad_norm``,
            ``log_every_n_steps`` and ``checkpoint_dir``.

    Returns:
        The ``TrainingLogger`` that received the ``"phase2/loss"`` values.
    """
    print("\n" + "="*70)
    print("üß† PHASE 2: LLM INTEGRATION TRAINING")
    print("="*70)

    # Freeze Phase 1 components
    phase1_modules = (
        model.vision_adapter,
        model.audio_adapter,
        model.text_adapter,
        model.perceiver,
        model.alignment_proj,
    )
    for module in phase1_modules:
        for p in module.parameters():
            p.requires_grad = False

    # Only train LLM projector
    for p in model.llm_proj.parameters():
        p.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Trainable parameters: {trainable:,}")

    # Dataset
    dataset = CaptioningDataset(encoders, llm_tokenizer, max_samples=cfg.phase2_max_samples)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.get_phase2_params(),
        lr=cfg.phase2_lr,
        weight_decay=0.01,
    )

    # Logger
    logger = TrainingLogger(use_wandb=cfg.use_wandb)

    # Mixed precision
    scaler = GradScaler() if cfg.use_amp else None

    # Get LLM device and dtype from the input embedding layer. Going through
    # get_input_embeddings() avoids hard-coding the `.model.embed_tokens`
    # attribute path, which only some architectures expose.
    embed_layer = llm_model.get_input_embeddings()
    llm_device = embed_layer.weight.device
    llm_dtype = embed_layer.weight.dtype

    # Training loop
    model.train()
    global_step = 0
    best_loss = float('inf')
    # Pre-defined so the final save below works even when phase2_epochs == 0.
    avg_loss = float('inf')

    for epoch in range(cfg.phase2_epochs):
        epoch_loss = 0.0
        completed_batches = 0  # batches that actually contributed to epoch_loss

        # Manual batching for captioning
        indices = list(range(len(dataset)))
        random.shuffle(indices)

        pbar = tqdm(
            range(0, len(indices), cfg.phase2_batch_size),
            desc=f"Epoch {epoch+1}/{cfg.phase2_epochs}"
        )

        for batch_start in pbar:
            batch_indices = indices[batch_start:batch_start + cfg.phase2_batch_size]

            # Collect batch
            items = [dataset[idx] for idx in batch_indices]
            images = [item["image"] for item in items]
            captions = [item["caption"] for item in items]

            try:
                # Encode images
                with torch.no_grad():
                    vision_feats = encoders.encode_images(images)

                # Get perceiver latents and project to LLM space
                with autocast(enabled=cfg.use_amp):
                    _, latents = model.encode_vision(vision_feats)
                    prefix_embeds = model.project_to_llm(latents)  # (B, K, D_llm)

                # Move to LLM device/dtype
                prefix_embeds = prefix_embeds.to(device=llm_device, dtype=llm_dtype)

                # Tokenize captions
                text_inputs = llm_tokenizer(
                    captions,
                    padding=True,
                    truncation=True,
                    max_length=cfg.max_seq_length,
                    return_tensors="pt",
                ).to(llm_device)

                # Get text embeddings
                text_embeds = embed_layer(text_inputs.input_ids)

                # Concatenate: [prefix_embeds, text_embeds]
                combined_embeds = torch.cat([prefix_embeds, text_embeds], dim=1)

                # Labels: -100 (ignored by the loss) for the prefix and for
                # caption padding, real token ids elsewhere. Padding must be
                # masked explicitly because the pad token may be the EOS
                # token, which would otherwise be trained on at every padded
                # position.
                prefix_len = prefix_embeds.size(1)
                caption_labels = text_inputs.input_ids.masked_fill(
                    text_inputs.attention_mask == 0, -100
                )
                prefix_labels = torch.full(
                    (caption_labels.size(0), prefix_len),
                    -100,
                    dtype=caption_labels.dtype,
                    device=llm_device,
                )
                labels = torch.cat([prefix_labels, caption_labels], dim=1)

                # Attention mask: prefix positions are always attended to.
                # Built with the tokenizer mask's dtype so the concatenation
                # does not rely on implicit type promotion.
                prefix_mask = torch.ones(
                    (prefix_embeds.size(0), prefix_len),
                    dtype=text_inputs.attention_mask.dtype,
                    device=llm_device,
                )
                combined_mask = torch.cat([prefix_mask, text_inputs.attention_mask], dim=1)

                # Forward through LLM
                outputs = llm_model(
                    inputs_embeds=combined_embeds,
                    attention_mask=combined_mask,
                    labels=labels,
                )
                loss = outputs.loss

                # Backward
                optimizer.zero_grad()
                if scaler:
                    scaler.scale(loss).backward()
                    # Unscale before clipping so the norm threshold applies
                    # to the true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.llm_proj.parameters(), cfg.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.llm_proj.parameters(), cfg.max_grad_norm)
                    optimizer.step()

                loss_value = loss.item()
                global_step += 1
                completed_batches += 1
                epoch_loss += loss_value

                pbar.set_postfix({"loss": f"{loss_value:.4f}"})

                # Logging
                if global_step % cfg.log_every_n_steps == 0:
                    logger.log({"phase2/loss": loss_value}, global_step)

            except Exception as e:
                # Best-effort training: a failing batch is reported and
                # skipped rather than aborting the whole run.
                print(f"   ‚ö†Ô∏è Batch error: {e}")
                continue

        # End of epoch: average over the batches that actually ran, so the
        # trailing partial batch is counted and skipped batches are not.
        if completed_batches:
            avg_loss = epoch_loss / completed_batches
        else:
            # No batch succeeded; inf keeps this epoch from being saved as "best".
            avg_loss = float('inf')
        print(f"\n   Epoch {epoch+1} Summary:")
        print(f"   Loss: {avg_loss:.4f}")

        # Save best
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(
                model, optimizer, epoch, global_step, avg_loss,
                f"{cfg.checkpoint_dir}/phase2/best.pt"
            )

    # Final save
    save_checkpoint(
        model, optimizer, cfg.phase2_epochs, global_step, avg_loss,
        f"{cfg.checkpoint_dir}/phase2/final.pt"
    )

    print("\n‚úÖ Phase 2 Training Complete!")
    return logger

In [None]:
# Run Phase 2 Training
# train_phase2 returns its logger; it is kept in a module-level name so the
# following cell can plot the recorded loss curve.
phase2_logger = train_phase2(model, llm_model, llm_tokenizer, encoders, cfg)

In [None]:
# Plot Phase 2 training curve
# "phase2/loss" is the key train_phase2 logs its per-step loss under.
phase2_logger.plot("phase2/loss")

---

# 🧪 Evaluation & Inference

---

In [None]:
# ============================================================
# INFERENCE FUNCTIONS
# ============================================================

@torch.no_grad()
def encode_image(image: Image.Image, model, encoders):
    """Return the L2-normalised aligned embedding for one image.

    Switches ``model`` to eval mode (it is left in eval mode afterwards),
    encodes the image as a batch of one via ``encoders.encode_images`` and
    normalises the first output of ``model.encode_vision`` along the last
    dimension. Runs without gradient tracking.
    """
    model.eval()
    single_image_batch = [image]
    features = encoders.encode_images(single_image_batch)
    embedding, _latents = model.encode_vision(features)
    return F.normalize(embedding, dim=-1)


@torch.no_grad()
def encode_text_query(text: str, model, encoders):
    """Return the L2-normalised aligned embedding for one text query.

    Switches ``model`` to eval mode (it is left in eval mode afterwards),
    encodes the text as a batch of one via ``encoders.encode_texts`` and
    normalises the first output of ``model.encode_text`` along the last
    dimension. Runs without gradient tracking.
    """
    model.eval()
    single_text_batch = [text]
    features, attention = encoders.encode_texts(single_text_batch)
    embedding, _latents = model.encode_text(features, attention)
    return F.normalize(embedding, dim=-1)


@torch.no_grad()
def generate_caption(image: Image.Image, model, encoders, llm_model, llm_tokenizer, max_new_tokens=50):
    """Generate a caption for an image.

    The image is encoded, its perceiver latents are projected into the LLM
    embedding space with ``model.project_to_llm``, and the result is placed
    in front of the embedded prompt ``"Describe this image:"``. The LLM then
    samples a continuation (temperature 0.7, top-p 0.9).

    Args:
        image: Image to caption.
        model: Alignment model providing ``encode_vision`` and
            ``project_to_llm``; it is switched to eval mode and left there.
        encoders: Object providing ``encode_images``.
        llm_model: Causal LM supporting ``generate(inputs_embeds=...)``.
        llm_tokenizer: Tokenizer matching ``llm_model``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded first generated sequence with special tokens removed.
    """
    model.eval()

    # Get LLM device/dtype from the input embedding layer. Using
    # get_input_embeddings() avoids the architecture-specific
    # `.model.embed_tokens` attribute path.
    embed_layer = llm_model.get_input_embeddings()
    llm_device = embed_layer.weight.device
    llm_dtype = embed_layer.weight.dtype

    # Encode image
    vision_feats = encoders.encode_images([image])
    _, latents = model.encode_vision(vision_feats)
    prefix_embeds = model.project_to_llm(latents)
    prefix_embeds = prefix_embeds.to(device=llm_device, dtype=llm_dtype)

    # Prompt
    prompt = "Describe this image:"
    prompt_tokens = llm_tokenizer(prompt, return_tensors="pt").to(llm_device)
    prompt_embeds = embed_layer(prompt_tokens.input_ids)

    # Concatenate
    combined = torch.cat([prefix_embeds, prompt_embeds], dim=1)

    # Every position (prefix and prompt) is real input, so the mask is all
    # ones. Passing it explicitly keeps generate() from having to infer a
    # mask when only embeddings are supplied.
    attention_mask = torch.ones(combined.shape[:2], dtype=torch.long, device=llm_device)

    # Generate
    outputs = llm_model.generate(
        inputs_embeds=combined,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=llm_tokenizer.eos_token_id,
    )

    return llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# ============================================================
# TEST RETRIEVAL
# ============================================================
# Smoke test: four solid-colour images are matched against four colour
# descriptions; text i is expected to retrieve image i.

print("\nüß™ Testing Image-Text Retrieval...")

# Create test data
test_images = [Image.new("RGB", (224, 224), color=c) for c in ["red", "green", "blue", "yellow"]]
test_texts = ["a red image", "a green image", "a blue image", "a yellow image"]

# Encode
model.eval()
# Inference only. Without no_grad the similarity matrix would carry autograd
# history whenever a contributing parameter still requires grad, and
# Tensor.numpy() below would then raise.
with torch.no_grad():
    vision_feats = encoders.encode_images(test_images)
    text_feats, text_mask = encoders.encode_texts(test_texts)

    z_vision, _ = model.encode_vision(vision_feats)
    z_text, _ = model.encode_text(text_feats, text_mask)

    z_vision = F.normalize(z_vision, dim=-1)
    z_text = F.normalize(z_text, dim=-1)

    # Compute similarity (rows: texts, columns: images)
    sim_matrix = z_text @ z_vision.T

print("\nSimilarity Matrix (Text ‚Üí Image):")
print(sim_matrix.cpu().numpy().round(3))

# Retrieval accuracy
preds = sim_matrix.argmax(dim=-1)
# Build targets on the predictions' own device instead of relying on the
# global `device` happening to match where the model placed its outputs.
targets = torch.arange(len(test_texts), device=preds.device)
accuracy = (preds == targets).float().mean().item() * 100
print(f"\nRetrieval Accuracy: {accuracy:.1f}%")

In [None]:
# ============================================================
# TEST GENERATION
# ============================================================
# Smoke test of the captioning path using a synthetic image.

print("\nüß™ Testing Image Captioning...")

# Test with a sample image
test_image = Image.new("RGB", (224, 224), color="blue")

# Best-effort: any failure during generation is reported instead of
# stopping the notebook run.
try:
    caption = generate_caption(test_image, model, encoders, llm_model, llm_tokenizer)
    print(f"\nGenerated Caption: {caption}")
except Exception as e:
    print(f"‚ö†Ô∏è Generation error: {e}")

---

# 💾 Save Final Model

---

In [None]:
# ============================================================
# SAVE COMPLETE MODEL
# ============================================================

def save_complete_model(model, cfg, path):
    """Save the complete trained model with config.

    Writes one checkpoint dict to ``path`` containing the full config
    (``asdict(cfg)``), the model ``state_dict`` and a small ``architecture``
    summary copied from ``cfg``. Missing parent directories of ``path`` are
    created first.

    Args:
        model: Module whose ``state_dict()`` is stored.
        cfg: Dataclass instance; must expose ``perceiver_dim``,
            ``num_latents``, ``num_perceiver_layers``, ``d_align`` and
            ``llm_hidden_size``.
        path: Destination file path.
    """
    # torch.save does not create directories; make sure the target folder
    # exists so saving does not depend on an earlier checkpoint having
    # created it.
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "config": asdict(cfg),
        "model_state_dict": model.state_dict(),
        # Duplicates a few config fields so the model shape can be read
        # without reconstructing the full config object.
        "architecture": {
            "perceiver_dim": cfg.perceiver_dim,
            "num_latents": cfg.num_latents,
            "num_perceiver_layers": cfg.num_perceiver_layers,
            "d_align": cfg.d_align,
            "llm_hidden_size": cfg.llm_hidden_size,
        },
    }
    torch.save(checkpoint, path)
    print(f"\nüíæ Saved complete model: {path}")


# Save
# Written directly under cfg.checkpoint_dir; the load cell reads this same path back.
save_complete_model(model, cfg, f"{cfg.checkpoint_dir}/multimodal_align_complete.pt")

In [None]:
# ============================================================
# LOAD MODEL FOR INFERENCE
# ============================================================

def load_complete_model(path, device="cuda"):
    """Restore a model and its config from a complete checkpoint file.

    Reads the checkpoint at ``path`` (tensors mapped to ``device``), rebuilds
    ``TrainingConfig`` from the stored ``"config"`` dict, builds a
    ``MultimodalAlignmentModel`` from it on ``device``, loads the stored
    ``"model_state_dict"`` and switches the model to eval mode.

    Returns:
        ``(model, cfg)`` — the restored model and the rebuilt config.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    saved = torch.load(path, map_location=device)

    # Config first: the model constructor takes it as its only argument.
    restored_cfg = TrainingConfig(**saved["config"])

    restored_model = MultimodalAlignmentModel(restored_cfg)
    restored_model = restored_model.to(device)
    restored_model.load_state_dict(saved["model_state_dict"])
    restored_model.eval()

    return restored_model, restored_cfg


# Test loading
# Round-trip check: reads back the file written by the save cell, using
# load_complete_model's default device ("cuda").
loaded_model, loaded_cfg = load_complete_model(f"{cfg.checkpoint_dir}/multimodal_align_complete.pt")
print("‚úÖ Model loaded successfully!")

---

# 📊 Training Summary

---

In [None]:
# Print a human-readable recap of the run. Every value is read from `cfg`;
# nothing is computed or modified here.
print("\n" + "="*70)
print("üìä TRAINING SUMMARY")
print("="*70)

print(f"""
Experiment: {cfg.experiment_name}

Architecture:
  ‚Ä¢ Perceiver: {cfg.num_latents} latents √ó {cfg.num_perceiver_layers} layers
  ‚Ä¢ Alignment dim: {cfg.d_align}
  ‚Ä¢ MRL dims: {cfg.mrl_dims}

Phase 1 (Alignment):
  ‚Ä¢ Epochs: {cfg.phase1_epochs}
  ‚Ä¢ Batch size: {cfg.phase1_batch_size}
  ‚Ä¢ Learning rate: {cfg.phase1_lr}
  ‚Ä¢ Samples: {cfg.phase1_max_samples}

Phase 2 (LLM Integration):
  ‚Ä¢ Epochs: {cfg.phase2_epochs}
  ‚Ä¢ Batch size: {cfg.phase2_batch_size}
  ‚Ä¢ Learning rate: {cfg.phase2_lr}
  ‚Ä¢ Samples: {cfg.phase2_max_samples}

Checkpoints saved to: {cfg.checkpoint_dir}/
  ‚Ä¢ phase1/best.pt - Best alignment model
  ‚Ä¢ phase2/best.pt - Best LLM-integrated model  
  ‚Ä¢ multimodal_align_complete.pt - Final complete model
""")

print("="*70)
print("‚úÖ Training Complete!")
print("="*70)