# Multimodal Alignment Training

This notebook demonstrates a clean, modular implementation of:

1. **Phase 1**: Vision-Text Alignment (CLIP-style contrastive learning)
2. **Phase 2**: Connect the aligned embeddings to an LLM decoder

## Architecture Overview

```
Phase 1: Alignment
==================
Vision Encoder (CLIP, frozen) → MLP Adapter (trainable) → z_vision
Text Encoder (frozen)         → MLP Adapter (trainable) → z_text
                                      ↓
                              Contrastive Loss (MRL + CLIP)

Phase 2: LLM Integration
========================
z_vision → Vision-to-LLM Projector (trainable) → prefix tokens
                                                      ↓
                                               LLM Decoder → Generated text
```

In [1]:
# Auto-reload modules during development
%load_ext autoreload
%autoreload 2

In [4]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
from pathlib import Path

# Add module path if needed
# sys.path.insert(0, str(Path.cwd()))

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu128
CUDA available: True
CUDA device: NVIDIA H200


---
## Phase 0: Setup & Configuration

In [6]:
from imports.core import AlignmentConfig, get_device, set_seed

# Create configuration
cfg = AlignmentConfig(
    # Models
    vision_model_name="openai/clip-vit-base-patch32",
    text_model_name="sentence-transformers/all-MiniLM-L6-v2",
    llm_model_name="Qwen/Qwen2.5-1.5B-Instruct",  # Use a smaller model for testing
    
    # Architecture
    d_align=512,
    adapter_hidden_factor=2.0,
    dropout=0.1,
    
    # Training
    batch_size=32,
    learning_rate=1e-4,
    weight_decay=0.01,
    num_epochs=3,
    warmup_ratio=0.1,
    
    # Loss
    mrl_dims=(128, 256, 512),
    mrl_temperature=0.07,
    clip_temperature=0.07,
    
    # Misc
    seed=42,
    log_every=20,
)

# Set device and dtype
cfg.device = get_device()
cfg.dtype = torch.float32  # Use float32 for stability (float16 for GPU if memory is tight)

# Set seed for reproducibility
set_seed(cfg.seed)

print(f"Device: {cfg.device}")
print(f"Dtype: {cfg.dtype}")



Device: cuda
Dtype: torch.float32
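
The `mrl_dims=(128, 256, 512)` setting points to a Matryoshka-style (MRL) objective, in which the contrastive loss is also evaluated on leading slices of each embedding so that shorter prefixes stay usable on their own. How `imports.core` implements this is not shown in this notebook. The following is a minimal, dependency-free sketch of the truncate-and-renormalize step under that assumption; `prefix_views` is a hypothetical helper, and the local `l2_normalize` is a list-based stand-in for the notebook's tensor helper.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def prefix_views(embedding, dims=(128, 256, 512)):
    """Return {k: unit-norm copy of the first k coordinates} for each k in dims."""
    return {k: l2_normalize(embedding[:k]) for k in dims}

# Toy 512-d embedding; in the notebook this would be one row of z_vision or z_text.
embedding = [math.sin(i) for i in range(512)]
views = prefix_views(embedding)

for k, v in views.items():
    # Each view has length k and (up to rounding) unit squared norm.
    print(k, len(v), round(sum(x * x for x in v), 6))
```

An MRL loss would then combine the contrastive loss computed at each prefix length. How the lengths are weighted is a design choice that the config shown here does not expose.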


---
## Phase 1: Vision-Text Alignment

### 1.1 Create the Alignment Model

In [7]:
from imports.core import VisionTextAligner, count_parameters

# Create alignment model
model = VisionTextAligner(cfg)

# Count parameters
params = count_parameters(model)
print(f"\nParameter count:")
print(f"  Total: {params['total']:,}")
print(f"  Trainable: {params['trainable']:,}")
print(f"  Frozen: {params['frozen']:,}")

[VisionEncoder] Loaded openai/clip-vit-base-patch32, hidden_size=768
[TextEncoder] Loaded sentence-transformers/all-MiniLM-L6-v2, hidden_size=384
[VisionTextAligner] d_vision=768, d_text=384, d_align=512

Parameter count:
  Total: 112,829,056
  Trainable: 2,659,840
  Frozen: 110,169,216
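
The trainable count can be sanity-checked by hand. The adapter internals are not shown in this notebook, so the layout below is an assumption: a LayerNorm on the input followed by a two-layer MLP whose hidden width is `adapter_hidden_factor` times the input width. That layout reproduces the reported 2,659,840 trainable parameters exactly, though other layouts could as well. `mlp_adapter_params` is a hypothetical helper.

```python
def mlp_adapter_params(d_in, d_out, hidden_factor=2.0):
    """Parameter count of LayerNorm(d_in) -> Linear(d_in, hidden) -> act -> Linear(hidden, d_out)."""
    hidden = int(d_in * hidden_factor)
    layer_norm = 2 * d_in                 # scale and shift
    fc1 = d_in * hidden + hidden          # weights + bias
    fc2 = hidden * d_out + d_out          # weights + bias
    return layer_norm + fc1 + fc2

vision = mlp_adapter_params(d_in=768, d_out=512)   # d_vision -> d_align
text = mlp_adapter_params(d_in=384, d_out=512)     # d_text -> d_align

print(f"vision adapter: {vision:,}")          # vision adapter: 1,969,664
print(f"text adapter:   {text:,}")            # text adapter:   690,176
print(f"total:          {vision + text:,}")   # total:          2,659,840
```

Either way, only about 2.4% of the 112.8M parameters receive gradients in Phase 1.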


### 1.2 Test with a Quick Forward Pass

In [8]:
from PIL import Image
import requests
from io import BytesIO

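# Load a test image
def load_test_image(url: str) -> Image.Image:
    response = requests.get(url, timeout=10)
    response.raise_for_status()  # fail early on HTTP errors instead of inside Image.open
    return Image.open(BytesIO(response.content)).convert("RGB")

# Test images
# picsum.photos URLs have the form /id/<id>/<width>/<height>,
# so these load 200x1 and 200x2 images (see the sizes printed below).
test_urls = [
    "https://picsum.photos/id/1/200/1",
    "https://picsum.photos/id/2/200/2",
]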

test_images = [load_test_image(url) for url in test_urls]
test_texts = [
    "A cat sitting and looking at the camera",
    "A yellow labrador retriever dog",
]

print(f"Loaded {len(test_images)} test images")
print(f"Image sizes: {[img.size for img in test_images]}")

Loaded 2 test images
Image sizes: [(200, 1), (200, 2)]


In [9]:
# Quick forward pass test
with torch.no_grad():
    output = model(test_images, test_texts)

print(f"Loss: {output['loss'].item():.4f}")
print(f"MRL Loss: {output['loss_mrl']:.4f}")
print(f"CLIP Loss: {output['loss_clip']:.4f}")
print(f"z_vision shape: {output['z_vision'].shape}")
print(f"z_text shape: {output['z_text'].shape}")

Loss: 1.3732
MRL Loss: 0.6922
CLIP Loss: 0.6811
z_vision shape: torch.Size([2, 512])
z_text shape: torch.Size([2, 512])
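
A useful reference point for these numbers: a symmetric CLIP-style (InfoNCE) loss over a batch of N pairs equals ln N when all similarities are equal, that is, when the model cannot tell pairs apart. With the two test pairs that is ln 2 ≈ 0.693, close to both reported loss terms (0.6922 and 0.6811), which is what one would expect from untrained adapters. The sketch below is a dependency-free illustration of that loss, not the code in `imports.core`; `clip_loss` and `log_softmax_row` are hypothetical names.

```python
import math

def log_softmax_row(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def clip_loss(sims, temperature=0.07):
    """Symmetric InfoNCE over an N x N similarity matrix; pair i matches column i."""
    n = len(sims)
    logits = [[s / temperature for s in row] for row in sims]
    logits_t = [list(col) for col in zip(*logits)]          # text-to-image direction
    i2t = -sum(log_softmax_row(logits[i])[i] for i in range(n)) / n
    t2i = -sum(log_softmax_row(logits_t[i])[i] for i in range(n)) / n
    return 0.5 * (i2t + t2i)

uniform = [[0.0, 0.0], [0.0, 0.0]]   # no pair stands out
aligned = [[1.0, 0.0], [0.0, 1.0]]   # matching pairs on the diagonal

print(round(clip_loss(uniform), 4))  # 0.6931
print(clip_loss(aligned) < 1e-4)     # True
```

The same yardstick applies during training: with `batch_size=32`, chance level is ln 32 ≈ 3.47 per contrastive term.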


### 1.3 Load Dataset

You can use:
- **Option A**: Pre-extracted features (faster, uses FeatureDataset)
- **Option B**: On-the-fly loading from HuggingFace (more flexible)

In [10]:
# Option B: Load from HuggingFace dataset (on-the-fly)
from datasets import load_dataset
from imports.data import ImageTextDataset, create_dataloader, collate_images

print("Loading PIXMO dataset...")

# Load a small subset for quick testing
# Other image-caption datasets (e.g. conceptual_captions) can be substituted here.
try:
    hf_dataset = load_dataset(
        "allenai/pixmo-cap",
        split="train",
        # trust_remote_code=True,
    )
    
    # Take a subset for quick testing
    subset_size = 1000
    if len(hf_dataset) > subset_size:
        hf_dataset = hf_dataset.shuffle(seed=cfg.seed).select(range(subset_size))
    
    print(f"Dataset size: {len(hf_dataset)}")
    print(f"Columns: {hf_dataset.column_names}")
    
    # Create dataset wrapper
    train_dataset = ImageTextDataset(
        hf_dataset,
        image_column="image_url",
        text_column="caption",  # adjust to the caption column of your dataset
    )
    
    USE_HF_DATASET = True
except Exception as e:
    print(f"Could not load HF dataset: {e}")
    print("\nFalling back to synthetic data for demo...")
    USE_HF_DATASET = False

Loading PIXMO dataset...
Dataset size: 1000
Columns: ['image_url', 'caption', 'transcripts']
[ImageTextDataset] Using columns: image=image_url, text=caption


In [None]:
# Fallback: create a synthetic dataset for the demo if the HF load failed
if not USE_HF_DATASET:
    # Module path follows the other imports in this notebook
    from imports.data import SimpleImageTextDataset
    import numpy as np

    # Create random images and dummy captions
    n_samples = 200

    synthetic_images = [
        # randint's upper bound is exclusive, so 256 covers the full uint8 range
        Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        for _ in range(n_samples)
    ]
    synthetic_texts = [f"This is synthetic image number {i}" for i in range(n_samples)]

    train_dataset = SimpleImageTextDataset(synthetic_images, synthetic_texts)
    print(f"Created synthetic dataset with {len(train_dataset)} samples")

In [11]:
# Create DataLoader
train_loader = create_dataloader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=0,  # Use 0 for debugging, increase for speed
    collate_fn=collate_images,
    drop_last=True,
)

print(f"Train batches: {len(train_loader)}")

Train batches: 31
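
The 31 batches follow from `drop_last=True`: 1000 samples at `batch_size=32` give 31 full batches, and the remaining 8 samples are skipped each epoch. The sketch below works through that arithmetic. It also derives a warmup length on the assumption that `warmup_ratio` applies to the total step count with one optimizer step per batch; how `train_alignment` actually schedules warmup is not shown here. `training_schedule` is a hypothetical helper.

```python
def training_schedule(num_samples, batch_size, num_epochs, warmup_ratio, drop_last=True):
    """Return (batches per epoch, total optimizer steps, warmup steps)."""
    if drop_last:
        batches = num_samples // batch_size          # incomplete final batch is discarded
    else:
        batches = -(-num_samples // batch_size)      # ceiling division
    total_steps = batches * num_epochs
    warmup_steps = int(total_steps * warmup_ratio)   # assumed rounding: floor
    return batches, total_steps, warmup_steps

batches, total_steps, warmup_steps = training_schedule(
    num_samples=1000, batch_size=32, num_epochs=3, warmup_ratio=0.1
)
print(batches, total_steps, warmup_steps)  # 31 93 9
```

With so few steps, the warmup phase under this assumption would cover only the first nine updates.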


### 1.4 Train Phase 1 (Alignment)

In [None]:
from imports.train import train_alignment

# Train!
history = train_alignment(
    model=model,
    train_loader=train_loader,
    val_loader=None,  # Add validation loader if you have one
    num_epochs=cfg.num_epochs,
    log_every=cfg.log_every,
    save_dir="./checkpoints/phase1",
    use_features=False,  # We're using on-the-fly images
)


Starting Training
  Epochs: 3
  Train batches: 31
  LR: 0.0001
  Device: cuda

Epoch 0:   0%|          | 0/31 [00:00<?, ?it/s]

Epoch 0 - Train Loss: 6.2982

Epoch 1:   0%|          | 0/31 [00:00<?, ?it/s]

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history["train_loss"], marker='o')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.grid(True, alpha=0.3)

if history.get("R@1"):  # skip this panel if no retrieval metrics were recorded
    plt.subplot(1, 2, 2)
    plt.plot(history["R@1"], label="R@1", marker='o')
    plt.plot(history["R@5"], label="R@5", marker='s')
    plt.plot(history["R@10"], label="R@10", marker='^')
    plt.xlabel("Epoch")
    plt.ylabel("Recall")
    plt.title("Retrieval Accuracy")
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 1.5 Test Retrieval

In [None]:
from imports.core import compute_retrieval_metrics, l2_normalize

# Encode our test images and texts
with torch.no_grad():
    z_img = model.encode_vision(test_images)
    z_txt = model.encode_text(test_texts)

# Compute similarity
z_img_norm = l2_normalize(z_img)
z_txt_norm = l2_normalize(z_txt)
sims = z_img_norm @ z_txt_norm.T

print("Similarity matrix (images × texts):")
print(sims.cpu().numpy().round(3))
print("\nExpected: High values on diagonal (matching pairs)")
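
The R@1/R@5/R@10 values plotted earlier are recall-at-k numbers computed from a similarity matrix like this one. `compute_retrieval_metrics` is imported above, but its implementation is not shown, so the following is a minimal pure-Python sketch of image-to-text recall@k under the usual convention that row i matches column i. `recall_at_k` is a hypothetical name.

```python
def recall_at_k(sims, k):
    """Fraction of rows whose matching column (same index) ranks in the top k."""
    hits = 0
    for i, row in enumerate(sims):
        # Column indices sorted by descending similarity
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        hits += i in ranked[:k]
    return hits / len(sims)

sims = [
    [0.9, 0.1, 0.3],   # image 0 ranks its own text first
    [0.2, 0.4, 0.8],   # image 1 ranks its own text second
    [0.1, 0.2, 0.7],   # image 2 ranks its own text first
]
print(round(recall_at_k(sims, 1), 3))  # 0.667
print(recall_at_k(sims, 2))            # 1.0
```

With only two test pairs, as in the cell above, R@1 is the only informative value; larger k needs a held-out validation set.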

---
## Phase 2: LLM Integration

Now we connect the aligned vision embeddings to an LLM decoder.

In [None]:
from imports.llm_integration import LLMConfig, MultimodalLLM

# Configure LLM integration
llm_cfg = LLMConfig(
    model_name="Qwen/Qwen2.5-1.5B-Instruct",  # Use a smaller model for testing
    max_new_tokens=128,
    temperature=0.7,
    num_prefix_tokens=8,
    freeze_llm=True,  # Phase 2a: only train projector
)

# Create multimodal model
mm_model = MultimodalLLM(
    aligner=model,
    llm_config=llm_cfg,
)

# Count trainable params
trainable = sum(p.numel() for p in mm_model.get_trainable_params())
print(f"\nPhase 2 trainable parameters: {trainable:,}")
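
To make the projector's role concrete, here is a shape walk-through for one plausible design: a single linear layer that maps each `d_align`-dimensional vision embedding to `num_prefix_tokens` pseudo-token embeddings of the LLM's hidden width. This is an assumption, since the projector inside `MultimodalLLM` is not shown and may be deeper. The value `d_llm=1536` is likewise assumed and should be read from the loaded model's config. `prefix_projector_shapes` is a hypothetical helper.

```python
def prefix_projector_shapes(batch_size, d_align, num_prefix_tokens, d_llm):
    """Shapes and parameter count for a single-linear-layer projector (assumed design)."""
    flat = (batch_size, num_prefix_tokens * d_llm)     # output of Linear(d_align, P * d_llm)
    prefix = (batch_size, num_prefix_tokens, d_llm)    # reshaped into P pseudo-token embeddings
    n_params = d_align * num_prefix_tokens * d_llm + num_prefix_tokens * d_llm  # weights + bias
    return flat, prefix, n_params

# d_llm=1536 is an assumed hidden size, not a value taken from this notebook.
flat, prefix, n_params = prefix_projector_shapes(
    batch_size=2, d_align=512, num_prefix_tokens=8, d_llm=1536
)
print(flat, prefix, f"{n_params:,}")  # (2, 12288) (2, 8, 1536) 6,303,744
```

The prefix tensor is then concatenated in front of the prompt's token embeddings, so the LLM sees 8 extra positions per image.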

### 2.1 Test Generation (Before Training)

In [None]:
# Test generation before any Phase 2 training
print("Testing generation (before Phase 2 training)...\n")

for i, img in enumerate(test_images):
    output = mm_model.generate(
        images=[img],
        prompt="Describe this image in detail:",
        max_new_tokens=64,
        temperature=0.7,
    )
    print(f"Image {i+1}:")
    print(f"  Ground truth: {test_texts[i]}")
    print(f"  Generated: {output}")
    print()

### 2.2 Phase 2 Training (Caption Generation)

In [None]:
from torch.optim import AdamW
from imports.llm_integration import create_caption_labels, train_multimodal_step
from tqdm.auto import tqdm

# Create optimizer for Phase 2 (projector only)
optimizer_p2 = AdamW(
    mm_model.get_trainable_params(),
    lr=1e-4,
    weight_decay=0.01,
)

# Simple training loop
num_steps = 100  # Adjust based on your data
mm_model.projector.train()

print("Phase 2 Training (projector only)...")
running_loss = 0.0

for step in tqdm(range(num_steps)):
    # Demo only: reuse the two test pairs as the batch on every step.
    # For real training, draw (images, captions) batches from train_loader instead.
    batch_images = test_images
    batch_texts = test_texts
    
    # Create labels for caption training
    labels = create_caption_labels(
        mm_model.llm.tokenizer,
        batch_texts,
        max_length=128,
        device=cfg.device,
    )
    
    batch = {
        "images": batch_images,
        **labels,
    }
    
    metrics = train_multimodal_step(mm_model, batch, optimizer_p2)
    running_loss += metrics["loss"]
    
    if (step + 1) % 20 == 0:
        avg_loss = running_loss / 20
        print(f"Step {step+1}: loss = {avg_loss:.4f}")
        running_loss = 0.0

print("\nPhase 2 training complete!")
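
`create_caption_labels` is not shown in this notebook. A common convention for prefix-conditioned captioning, assumed here, is to compute the loss only on caption tokens: positions that hold the visual prefix or padding get the label -100, which PyTorch's cross-entropy loss ignores by default. The sketch below illustrates that masking on plain lists; `build_labels` and the token ids are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss

def build_labels(caption_ids, num_prefix_tokens, max_length, pad_id):
    """Return (input_ids, labels) for one caption, padded/truncated to max_length.

    Labels cover the prefix slots plus the caption; prefix and padding
    positions are set to IGNORE_INDEX so they contribute no loss.
    """
    ids = caption_ids[:max_length]
    n_pad = max_length - len(ids)
    input_ids = ids + [pad_id] * n_pad
    labels = [IGNORE_INDEX] * num_prefix_tokens + ids + [IGNORE_INDEX] * n_pad
    return input_ids, labels

input_ids, labels = build_labels([11, 42, 7], num_prefix_tokens=2, max_length=5, pad_id=0)
print(input_ids)  # [11, 42, 7, 0, 0]
print(labels)     # [-100, -100, 11, 42, 7, -100, -100]
```

The labels are longer than `input_ids` by the number of prefix tokens because the prefix embeddings are prepended to the text embeddings before the LLM forward pass.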

### 2.3 Test Generation (After Training)

In [None]:
# Test generation after Phase 2 training
mm_model.projector.eval()

print("Testing generation (after Phase 2 training)...\n")

for i, img in enumerate(test_images):
    output = mm_model.generate(
        images=[img],
        prompt="Describe this image in detail:",  # same prompt as the pre-training test, for a fair comparison
        max_new_tokens=64,
        temperature=0.7,
    )
    print(f"Image {i+1}:")
    print(f"  Ground truth: {test_texts[i]}")
    print(f"  Generated: {output}")
    print()

---
## Summary

### What we built:

1. **Phase 1 - Alignment**
   - Froze CLIP vision encoder
   - Froze text encoder (sentence-transformers)
   - Trained MLP adapters to align vision & text embeddings
   - Used MRL + CLIP contrastive loss

2. **Phase 2 - LLM Integration**
   - Froze alignment model (from Phase 1)
   - Added Vision-to-LLM projector
   - Connected to Qwen LLM
   - Trained projector for caption generation

### Next steps:

- Add an audio encoder (e.g. Whisper) as a third modality
- Add a Perceiver module for better handling of variable-length sequences
- Use LoRA for efficient LLM fine-tuning
- Scale up to larger datasets

In [None]:
# Save final model
Path("./checkpoints").mkdir(parents=True, exist_ok=True)  # torch.save does not create directories
torch.save({
    "config": cfg,
    "llm_config": llm_cfg,
    "vision_adapter": model.vision_adapter.state_dict(),
    "text_adapter": model.text_adapter.state_dict(),
    "projector": mm_model.projector.state_dict(),
}, "./checkpoints/multimodal_final.pt")

print("Model saved!")