# Training Learned Projection Heads for Embedding Alignment

This notebook trains projection heads to align CLIP, CLAP, and text embeddings in a shared semantic space.

**Purpose**: Improve cross-modal similarity by learning to map different embedding spaces together.

**Requirements**:
- GPU runtime (T4 or better)
- ~2-4 hours training time
- Training data: aligned (text, image, audio) triplets

## 1. Setup

In [None]:
# Install dependencies (Colab)
!pip install torch sentence-transformers transformers laion-clap

In [None]:
# Clone repository (if running on Colab)
# !git clone https://github.com/your-repo/MultiModal-Coherence-AI.git
# %cd MultiModal-Coherence-AI

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from pathlib import Path

# Report the PyTorch build and, when CUDA is usable, the name of GPU 0.
has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Training Data

We need aligned (text, image, audio) triplets. Options:
1. Use generated experiment data from runs/
2. Use AudioCaps + LAION with synthetic pairing
3. Create custom aligned dataset

In [None]:
# Option 1: Load from experiment runs
import json
from glob import glob

def load_triplets_from_runs(runs_dir: str, max_samples: int = 5000):
    """Load aligned (text, image, audio) triplets from experiment bundles.

    Recursively finds every ``bundle.json`` under ``runs_dir``. A bundle
    contributes a triplet when its ``outputs.text`` field is non-empty and
    both ``image/output.png`` and ``audio/output.wav`` exist next to it.
    Bundles that cannot be read or parsed are skipped and counted.

    Args:
        runs_dir: Root directory to search for ``bundle.json`` files.
        max_samples: Maximum number of triplets to return.

    Returns:
        Three parallel lists ``(texts, images, audios)``: the bundle text and
        the image / audio file paths as strings.
    """
    # glob order is arbitrary; sorting makes the selected subset reproducible.
    bundle_paths = sorted(glob(f"{runs_dir}/**/bundle.json", recursive=True))

    texts, images, audios = [], [], []
    n_unreadable = 0

    for bundle_path in bundle_paths:
        # Cap the triplets actually loaded rather than the bundles examined,
        # so incomplete runs do not eat into the sample budget.
        if len(texts) >= max_samples:
            break

        run_dir = Path(bundle_path).parent
        try:
            with open(bundle_path, encoding="utf-8") as f:
                bundle = json.load(f)

            text = bundle.get('outputs', {}).get('text', '')
            img_path = run_dir / 'image' / 'output.png'
            aud_path = run_dir / 'audio' / 'output.wav'
            is_complete = bool(text) and img_path.exists() and aud_path.exists()
        except (OSError, ValueError, AttributeError):
            # Unreadable file, invalid JSON/encoding, or an unexpected JSON
            # structure: stay best-effort, but count it instead of hiding it.
            n_unreadable += 1
            continue

        if is_complete:
            texts.append(text)
            images.append(str(img_path))
            audios.append(str(aud_path))

    print(f"Loaded {len(texts)} triplets from runs")
    if n_unreadable:
        print(f"Skipped {n_unreadable} unreadable or malformed bundles")
    return texts, images, audios

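# Uncomment to use. Note: this returns raw text and image/audio file paths,
# which still have to be embedded before building the datasets in section 3.
# texts, images, audios = load_triplets_from_runs('../runs')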

In [None]:
# Demo fallback: random synthetic embeddings (this is NOT the AudioCaps + LAION option listed above)
# In practice, you'd load real aligned data

def create_synthetic_demo_data(
    n_samples: int = 1000,
    *,
    seed: int = 42,
    latent_dim: int = 256,
    embed_dim: int = 512,
    noise_scale: float = 0.3,
):
    """Create random, partially aligned embeddings for demonstration.

    A shared Gaussian base of width ``latent_dim`` is drawn once; each
    modality is that base plus independent Gaussian noise scaled by
    ``noise_scale``, zero-padded on the right to ``embed_dim`` columns.

    With the default arguments the values are identical to seeding the
    legacy global NumPy RNG with 42, but the global RNG state is left
    untouched.

    Args:
        n_samples: Number of triplets (rows) to generate.
        seed: Seed for the private random stream.
        latent_dim: Width of the shared (non-padded) part of each embedding.
        embed_dim: Final embedding width after zero-padding.
        noise_scale: Standard deviation of the per-modality noise.

    Returns:
        Tuple ``(text_emb, image_emb, audio_emb)`` of float arrays, each of
        shape ``(n_samples, embed_dim)``.

    Raises:
        ValueError: If ``embed_dim`` is smaller than ``latent_dim``.
    """
    if embed_dim < latent_dim:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be >= latent_dim ({latent_dim})"
        )

    # A private RandomState yields the same stream as np.random.seed(seed)
    # without reseeding the global NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Simulate embeddings with some alignment
    base = rng.randn(n_samples, latent_dim)
    pad_width = ((0, 0), (0, embed_dim - latent_dim))

    def noisy_padded_copy():
        # One modality: shared base + independent noise, zero-padded.
        noisy = base + rng.randn(n_samples, latent_dim) * noise_scale
        return np.pad(noisy, pad_width)

    # Draw order (text, image, audio) fixes which noise each modality gets.
    text_emb = noisy_padded_copy()
    image_emb = noisy_padded_copy()
    audio_emb = noisy_padded_copy()

    return text_emb, image_emb, audio_emb

# For demonstration
text_emb, image_emb, audio_emb = create_synthetic_demo_data(2000)
print(f"Shapes: text={text_emb.shape}, image={image_emb.shape}, audio={audio_emb.shape}")

## 3. Create Dataset and Model

In [None]:
from src.training.learned_projection import LearnedProjection
from src.training.contrastive_trainer import (
    MultimodalTripletDataset,
    ContrastiveTrainer,
    TrainingConfig,
)

# Split the triplets sequentially: first 90% for training, the rest for validation.
n_train = int(len(text_emb) * 0.9)
train_rows = slice(None, n_train)
val_rows = slice(n_train, None)

# Build both datasets the same way; only the row range differs.
train_dataset, val_dataset = (
    MultimodalTripletDataset(
        text_embeddings=text_emb[rows],
        image_embeddings=image_emb[rows],
        audio_embeddings=audio_emb[rows],
    )
    for rows in (train_rows, val_rows)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# Build the projection model from an explicit settings mapping.
projection_kwargs = {
    "text_dim": 512,
    "image_dim": 512,
    "audio_dim": 512,
    "shared_dim": 256,
    "hidden_dim": 384,
    "dropout": 0.1,
}
model = LearnedProjection(**projection_kwargs)

# Total number of elements across all model parameters.
n_params = 0
for param in model.parameters():
    n_params += param.numel()
print(f"Model parameters: {n_params:,}")

## 4. Training

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    train_device = "cuda"
else:
    train_device = "cpu"

# Hyperparameters for the contrastive training run.
config = TrainingConfig(
    batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-5,
    n_epochs=10,
    temperature=0.07,
    eval_every=100,
    save_every=500,
    device=train_device,
)

print(f"Training on: {config.device}")

In [None]:
# Directory handed to the trainer as its output_dir.
projection_dir = Path("../models/projection")

# Wire the model and the training configuration into the trainer.
trainer = ContrastiveTrainer(model=model, config=config, output_dir=projection_dir)

In [None]:
# Run training; the returned model is what the evaluation and save cells use.
trained_model = trainer.train(train_dataset=train_dataset, val_dataset=val_dataset)

## 5. Evaluation

In [None]:
import matplotlib.pyplot as plt

# Plot validation loss per recorded step (nothing is drawn when history is empty).
history = trainer.history
if history:
    steps = list(map(lambda record: record['step'], history))
    losses = list(map(lambda record: record['val_loss'], history))

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(steps, losses)
    ax.set_xlabel('Step')
    ax.set_ylabel('Validation Loss')
    ax.set_title('Training Progress')
    ax.grid(True)
    plt.show()

In [None]:
# Test projection quality on up to 100 held-out triplets
trained_model.eval()

# Inputs must be on the same device as the model weights: after training on a
# GPU the model may still live on CUDA, while the NumPy arrays start on the CPU.
# NOTE(review): assumes trained_model is a torch.nn.Module — confirm.
first_param = next(trained_model.parameters(), None)
model_device = first_param.device if first_param is not None else torch.device("cpu")

# Rows taken from the start of the validation split.
eval_rows = slice(n_train, n_train + 100)

with torch.no_grad():
    # Get some test samples
    test_text = torch.tensor(text_emb[eval_rows], dtype=torch.float32, device=model_device)
    test_image = torch.tensor(image_emb[eval_rows], dtype=torch.float32, device=model_device)
    test_audio = torch.tensor(audio_emb[eval_rows], dtype=torch.float32, device=model_device)

    # Project
    projected = trained_model(test_text, test_image, test_audio)

    # Compute similarities after projection
    p_text = projected['text']
    p_image = projected['image']
    p_audio = projected['audio']

    # Row-wise dot product of matching rows (each sample with its own
    # counterpart in the other modality), averaged over the batch.
    # NOTE(review): this equals cosine similarity only if the projections are
    # L2-normalized — confirm against LearnedProjection.
    ti_sim = torch.sum(p_text * p_image, dim=-1).mean()
    ta_sim = torch.sum(p_text * p_audio, dim=-1).mean()
    ia_sim = torch.sum(p_image * p_audio, dim=-1).mean()

    print("\nPost-projection similarities:")
    print(f"  Text-Image: {ti_sim:.4f}")
    print(f"  Text-Audio: {ta_sim:.4f}")
    print(f"  Image-Audio: {ia_sim:.4f}")

## 6. Save Model

In [None]:
# Save final model
save_path = Path("../models/projection/learned_projection.pt")
# Create the target directory first: nothing visible in this notebook
# guarantees that the trainer has already created it.
save_path.parent.mkdir(parents=True, exist_ok=True)
trained_model.save(save_path)
print(f"Model saved to: {save_path}")

In [None]:
# Round-trip check: reload the saved file and show the restored configuration.
loaded_model = LearnedProjection.load(save_path)
loaded_config = loaded_model.config
print(f"Model loaded successfully. Config: {loaded_config}")

## 7. Integration with MSCI

To use the trained projections with MSCI:

In [None]:
# Example: Using ProjectedEmbedder
# NOTE(review): the import below is only exercised by the commented-out
# example; nothing else in this cell uses it.
from src.training.learned_projection import ProjectedEmbedder

# Load your base embedder
# from src.embeddings.aligned_embeddings import AlignedEmbedder
# base_embedder = AlignedEmbedder()

# Wrap with projection (loaded_model is the model reloaded in section 6)
# projected_embedder = ProjectedEmbedder(base_embedder, loaded_model)

# Use for MSCI computation
# (distinct names so the training arrays text_emb / image_emb / audio_emb
# from the earlier cells are not overwritten)
# proj_text_emb = projected_embedder.embed_text("A peaceful forest")
# proj_image_emb = projected_embedder.embed_image("path/to/image.png")
# proj_audio_emb = projected_embedder.embed_audio("path/to/audio.wav")

print("See src/training/learned_projection.py for ProjectedEmbedder usage")

## Next Steps

1. Train on real aligned data (AudioCaps, LAION, experiment runs)
2. Evaluate improvement in MSCI-human correlation
3. Tune hyperparameters (temperature, learning rate, architecture)
4. Compare MSCI v1 vs v2 with projected embeddings