# Pixmo Vision-Text Alignment with 4096-dim Embeddings

This notebook demonstrates the improved alignment training with:
1. **Pixmo Parquet Dataset** with embedded image bytes
2. **4096-dim embeddings** with Matryoshka Representation Learning (MRL) dimensions [2048, 1024, 512, 256, 128]
3. **Learnable attention pooling** instead of CLS/mean pooling
4. **Updated loss weights**: MRL=1.0, CLIP=0.25
5. **Text dropout** for better image reliance
6. **Improved training** with warmup+cosine LR, checkpointing, crash recovery
7. **Comprehensive logging** and visualization

**Hardware Requirements:**
- 1-2 H200 GPUs
- ~40-50GB GPU memory per GPU

## 1. Setup and Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 (reload all imported modules
# before each cell executes), so edits to imported project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the project's `src` directory importable. This assumes the kernel's
# working directory is a sibling of `src` (its parent contains `src/`).
src_dir = Path.cwd().parent / "src"
# Guard the insert so re-running this cell does not stack duplicate entries
# at the front of sys.path.
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))
# Echo the resolved path as the cell output for a quick sanity check.
src_dir

PosixPath('/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src')

In [3]:
import random
from datetime import datetime
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import wandb
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Import our modules
from config import load_config
from models import MultimodalAlignmentModel
from data.dataset_builder import build_image_datasets_from_parquet
from data.transforms import get_image_transforms
from training.improved_trainer import ImprovedMultimodalTrainer
from utils.visualization import TrainingVisualizer

In [None]:
# Set up matplotlib: render figures inline in the notebook output.
%matplotlib inline
# NOTE(review): the 'seaborn-v0_8-*' style names are only available in
# matplotlib >= 3.6; older releases name this style 'seaborn-darkgrid'.
# Confirm the pinned matplotlib version.
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
# A single named seed drives every RNG so a fresh "Restart & Run All" starts
# from the same random state each time.
SEED = 42
random.seed(SEED)        # Python stdlib RNG (was previously left unseeded)
np.random.seed(SEED)     # NumPy legacy global RNG
torch.manual_seed(SEED)  # PyTorch default generator

# Report the runtime environment; GPU details are printed only when CUDA is usable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; divide by 1e9 to show (decimal) gigabytes.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Load Configuration

In [None]:
# Load experiment configuration
# NOTE: the path is relative to the kernel's working directory, so the
# notebook must run from a directory that sits next to `configs/`.
config_path = "../configs/pixmo_alignment.yaml"
config = load_config(config_path)

In [None]:


# Summarise the loaded experiment configuration, one titled section at a time.
# Each entry is (section title, [(label, value), ...]); the printed text is
# the same as a hand-written sequence of print calls.
config_report = [
    ("Dataset Configuration", [
        ("Train Parquet", config.dataset.train_parquet),
        ("Val Parquet", config.dataset.val_parquet),
        ("Batch Size", config.dataset.batch_size),
        ("Text Dropout", config.dataset.text_dropout_prob),
    ]),
    ("Model Configuration", [
        ("Vision Encoder", config.vision_encoder.model_name),
        ("Projection Dim", config.vision_encoder.projection_dim),
        ("MRL Dimensions", config.vision_encoder.mrl_dimensions),
        ("Attention Pooling", config.vision_encoder.use_attention_pooling),
        ("Pooling Type", config.vision_encoder.pooling_type),
    ]),
    ("Loss Configuration", [
        ("CLIP Weight", config.losses.contrastive),
        ("MRL Weight", config.losses.mrl),
        ("Sample Single MRL Dim", config.losses.sample_single_mrl_dim),
    ]),
    ("Optimization Configuration", [
        ("Learning Rate", config.optimization.lr),
        ("Weight Decay", config.optimization.weight_decay),
        ("Max Grad Norm", config.optimization.max_grad_norm),
        ("Warmup Ratio", config.optimization.warmup_ratio),
    ]),
]

print(f"Experiment: {config.name}")
for section_title, section_rows in config_report:
    # A blank line separates each section from the previous one.
    print(f"\n{section_title}:")
    for row_label, row_value in section_rows:
        print(f"  {row_label}: {row_value}")

## 3. Create Model

In [None]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Build the alignment model from the loaded config and move it to the device.
print("\nCreating multimodal alignment model with 4096-dim embeddings...")
model = MultimodalAlignmentModel(config).to(device)

# Parameter summary, framed by separator rules.
separator = "=" * 60
print("\n" + separator)
model.print_parameter_counts()
print(separator)

## 4. Load Pixmo Parquet Dataset

In [None]:
# One transform pipeline per split; only the training pipeline is requested
# with is_training=True.
dataset_cfg = config.dataset
train_transforms = get_image_transforms(image_size=dataset_cfg.image_size, is_training=True)
val_transforms = get_image_transforms(image_size=dataset_cfg.image_size, is_training=False)

# Build every split from its parquet file in a single call.
print("Loading Pixmo datasets from parquet files...")
datasets = build_image_datasets_from_parquet(
    cfg=config,
    train_parquet_path=dataset_cfg.train_parquet,
    val_parquet_path=dataset_cfg.val_parquet,
    test_parquet_path=dataset_cfg.test_parquet,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    max_text_length=dataset_cfg.max_text_length,
    text_dropout_prob=dataset_cfg.text_dropout_prob,
)

# The train and val splits are required below; the test split is optional.
train_dataset, val_dataset = datasets['train'], datasets['val']

print(f"\nTrain samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
if 'test' in datasets:
    print(f"Test samples: {len(datasets['test'])}")

In [None]:
# Settings shared by both loaders come straight from the dataset config.
loader_cfg = config.dataset
shared_loader_kwargs = dict(
    batch_size=loader_cfg.batch_size,
    num_workers=loader_cfg.num_workers,
    pin_memory=loader_cfg.pin_memory,
)

# persistent_workers is only requested when worker processes are in use;
# with num_workers == 0 it is forced to False.
if loader_cfg.num_workers > 0:
    keep_workers_alive = loader_cfg.persistent_workers
else:
    keep_workers_alive = False

# Training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    persistent_workers=keep_workers_alive,
    **shared_loader_kwargs,
)

val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 5. Visualize Sample Data

In [None]:
# Get a sample batch
sample_batch = next(iter(train_loader))
print(f"Batch keys: {sample_batch.keys()}")
print(f"Image tensor shape: {sample_batch['image'].shape}")
print(f"Number of captions: {len(sample_batch['text'])}")

# Per-channel statistics used to undo normalisation for display.
# NOTE(review): these are the standard ImageNet mean/std; this assumes
# get_image_transforms normalises with the same values — confirm.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
max_title_chars = 60

# Visualize up to the first 4 images with captions
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()
n_shown = min(len(axes), len(sample_batch['image']))

for idx in range(n_shown):
    # Denormalize image, clamp to the displayable range, and move channels last.
    img = sample_batch['image'][idx].cpu()
    img = torch.clamp(img * norm_std + norm_mean, 0, 1)
    img = img.permute(1, 2, 0).numpy()

    axes[idx].imshow(img)
    # An empty caption is shown as a placeholder (text dropout can blank it).
    caption_text = sample_batch['text'][idx] if sample_batch['text'][idx] else "[DROPPED TEXT]"
    # Only mark the title as truncated when it actually was truncated.
    if len(caption_text) > max_title_chars:
        caption_text = caption_text[:max_title_chars] + "..."
    axes[idx].set_title(caption_text, fontsize=10)
    axes[idx].axis('off')

# Hide panels that received no image (batches smaller than the 2x2 grid).
for unused_ax in axes[n_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

print("\nNote: Some captions may show '[DROPPED TEXT]' due to text dropout.")

## 6. Initialize Trainer

In [None]:
# Flip to True to log this run to Weights & Biases.
use_wandb = False

# Build the trainer around the model and both data loaders.
trainer = ImprovedMultimodalTrainer(
    cfg=config,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    use_wandb=use_wandb,
)

# Report what the trainer resolved at construction time.
trainer_report = (
    ("Checkpoint directory", trainer.ckpt_dir),
    ("Effective batch size", trainer.effective_batch_size),
    ("World size (GPUs)", trainer.world_size),
    ("Starting epoch", trainer.state.epoch),
    ("Starting step", trainer.state.global_step),
    ("Best val loss", trainer.state.best_val_loss),
)
print("\nTrainer initialized with:")
for report_label, report_value in trainer_report:
    print(f"  {report_label}: {report_value}")

## 7. Train the Model

In [None]:
banner = "=" * 60
print("\nStarting training...")
print(banner)

# Run the training loop; the returned history is what later cells plot and summarise.
history = trainer.train()

print("\n" + banner)
print("Training completed!")
print(f"Best validation loss: {trainer.state.best_val_loss:.4f}")
print(banner)

## 8. Visualize Training Results

In [None]:
# Create the output directory (if needed) and point the visualizer at it.
output_dir = Path(config.trainer.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
visualizer = TrainingVisualizer(save_dir=output_dir)

# Locations reported to the reader.
# NOTE(review): file names are assumed to match what TrainingVisualizer writes — confirm.
training_curves_png = output_dir / 'training_curves.png'
loss_components_png = output_dir / 'loss_components.png'
lr_schedule_png = output_dir / 'lr_schedule.png'

# Overall training curves.
visualizer.plot_training_curves(history)
print(f"Training curves saved to {training_curves_png}")

# Per-component loss breakdown.
visualizer.plot_loss_components(history)
print(f"Loss components saved to {loss_components_png}")

# The LR schedule is plotted only when learning rates were recorded.
lr_history = history['lr']
if lr_history:
    visualizer.plot_lr_schedule(lr_history)
    print(f"LR schedule saved to {lr_schedule_png}")

# Show the saved training-curves image inline.
from IPython.display import Image as IPImage, display
display(IPImage(filename=str(training_curves_png)))

## 9. Evaluate and Visualize Embeddings

In [None]:
# Load best model
best_checkpoint_path = trainer.ckpt_dir / "checkpoint_best.pt"
if best_checkpoint_path.exists():
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # written by this project. Recent PyTorch releases default to
    # weights_only=True, which rejects arbitrary pickled objects; if the
    # checkpoint stores non-tensor objects this call may need an explicit
    # weights_only=False — confirm against the checkpoint contents.
    checkpoint = torch.load(best_checkpoint_path, map_location=device)
    # Unwrap DistributedDataParallel so state-dict keys line up with the bare module.
    model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    model_to_load.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from {best_checkpoint_path}")
    print(f"Best epoch: {checkpoint['epoch']}")
    print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
else:
    print("No best checkpoint found, using current model state.")

# Switch to evaluation mode. The trailing semicolon keeps Jupyter from echoing
# the returned module's full repr as the cell output.
model.eval();

In [None]:
# Collect embeddings from validation set
print("Collecting embeddings for visualization...")
max_eval_batches = 10  # only the first few validation batches are embedded
vision_embs = []
text_embs = []
captions = []

# islice stops the loader after `max_eval_batches` batches. Building
# list(val_loader) first would load and hold the entire validation set
# just to keep its first ten batches.
n_batches = min(max_eval_batches, len(val_loader))
with torch.no_grad():
    for batch in tqdm(islice(val_loader, max_eval_batches), total=n_batches):
        images = batch['image'].to(device)
        texts = batch['text']

        outputs = model(images=images, texts=texts, return_embeddings=True)

        # Move to CPU/NumPy per batch so embeddings do not pile up on the device.
        vision_embs.append(outputs.vision_emb.cpu().numpy())
        text_embs.append(outputs.text_emb.cpu().numpy())
        captions.extend(texts)

# Stack the per-batch arrays row-wise into one array per modality.
vision_embs = np.vstack(vision_embs)
text_embs = np.vstack(text_embs)

print(f"Vision embeddings: {vision_embs.shape}")
print(f"Text embeddings: {text_embs.shape}")

In [None]:
# Plot the collected vision/text embeddings with the visualizer's "pca" method.
# NOTE(review): the PNG name is assumed to match what the visualizer writes — confirm.
embedding_space_png = output_dir / 'embedding_space.png'
visualizer.plot_embedding_space(
    vision_embs=vision_embs, text_embs=text_embs, method="pca", n_samples=500,
)
print(f"Embedding space saved to {embedding_space_png}")

# Show the saved figure inline.
display(IPImage(filename=str(embedding_space_png)))

In [None]:
# Plot the vision-vs-text similarity matrix, passing n_samples=50 to the visualizer.
# NOTE(review): the PNG name is assumed to match what the visualizer writes — confirm.
similarity_matrix_png = output_dir / 'similarity_matrix.png'
visualizer.plot_similarity_matrix(
    vision_embs=vision_embs, text_embs=text_embs, n_samples=50,
)
print(f"Similarity matrix saved to {similarity_matrix_png}")

# Show the saved figure inline.
display(IPImage(filename=str(similarity_matrix_png)))

## 10. Test Image-to-Text Retrieval

In [None]:
# Test image-to-text retrieval
def retrieve_top_k(query_emb, database_embs, k=5):
    """Rank database entries by dot-product similarity to a query embedding.

    Args:
        query_emb: Query embedding (a single vector in this notebook).
        database_embs: Candidate embeddings, one per row.
        k: Maximum number of matches to return; fewer are returned when
            the database has fewer than ``k`` rows.

    Returns:
        Tuple ``(indices, scores)``: the row indices of the best-scoring
        candidates in descending score order, and the matching scores.
    """
    scores = np.dot(database_embs, query_emb)
    # argsort is ascending, so reverse it before keeping the first k entries.
    ranked = np.argsort(scores)[::-1]
    best = ranked[:k]
    return best, scores[best]

# Test on a few examples
n_examples = 3
top_k = 5
caption_preview_chars = 80

# Never index past the embeddings that were actually collected.
n_available = min(n_examples, len(captions))
for i in range(n_available):
    print(f"\n{'='*60}")
    print(f"Example {i+1}:")
    print(f"Ground Truth Caption: {captions[i]}")

    # Image-to-text retrieval: image i is the query, all captions are candidates.
    top_indices, scores = retrieve_top_k(vision_embs[i], text_embs, k=top_k)

    print(f"\nTop {top_k} Retrieved Captions:")
    for rank, (idx, score) in enumerate(zip(top_indices, scores), 1):
        # A hit is the caption stored at the same index as the query image.
        match = "✓ MATCH" if idx == i else ""
        # Only mark the caption as truncated when it actually was truncated.
        preview = captions[idx]
        if len(preview) > caption_preview_chars:
            preview = preview[:caption_preview_chars] + "..."
        print(f"  {rank}. [{score:.3f}] {preview} {match}")

## 11. Summary and Key Metrics

In [None]:
# Final report: echoes the key configuration values, training outcomes,
# parameter counts and output locations gathered by the earlier cells.
print("\n" + "="*60)
print("EXPERIMENT SUMMARY")
print("="*60)

print(f"\nConfiguration:")
print(f"  Model: {config.name}")
print(f"  Vision Encoder: {config.vision_encoder.model_name}")
print(f"  Embedding Dimension: {config.vision_encoder.projection_dim}")
print(f"  MRL Dimensions: {config.vision_encoder.mrl_dimensions}")
print(f"  Attention Pooling: {config.vision_encoder.use_attention_pooling}")
print(f"  Text Dropout: {config.dataset.text_dropout_prob}")

print(f"\nLoss Weights:")
print(f"  CLIP Loss: {config.losses.contrastive}")
print(f"  MRL Loss: {config.losses.mrl}")

print(f"\nTraining:")
print(f"  Epochs: {config.trainer.epochs}")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Best validation loss: {trainer.state.best_val_loss:.4f}")

# Recall figures are printed only when image-to-text R@1 was recorded.
# Each "best" value is the maximum of its own series, so the three numbers
# may come from different evaluations. Values are scaled by 100 for display
# as percentages.
if history['val_i2t_r1']:
    best_r1 = max(history['val_i2t_r1']) * 100
    best_r5 = max(history['val_i2t_r5']) * 100
    best_r10 = max(history['val_i2t_r10']) * 100
    print(f"  Best R@1: {best_r1:.2f}%")
    print(f"  Best R@5: {best_r5:.2f}%")
    print(f"  Best R@10: {best_r10:.2f}%")

print(f"\nModel Parameters:")
print(f"  Trainable: {model.num_trainable_parameters:,}")
print(f"  Total: {model.num_total_parameters:,}")
print(f"  Trainable %: {100*model.num_trainable_parameters/model.num_total_parameters:.2f}%")

# NOTE(review): these paths are printed without checking that the files exist;
# the file names presumably match what the trainer writes — confirm.
print(f"\nOutput Files:")
print(f"  Best model: {trainer.ckpt_dir / 'checkpoint_best.pt'}")
print(f"  Latest checkpoint: {trainer.ckpt_dir / 'checkpoint_latest.pt'}")
print(f"  Training history: {trainer.ckpt_dir / 'training_history.json'}")
print(f"  Visualizations: {output_dir}")

# Static list of suggested follow-ups.
print("\n" + "="*60)
print("Next Steps:")
print("="*60)
print("1. Test MRL performance at different dimensions")
print("2. Evaluate on test set")
print("3. Try different attention pooling strategies")
print("4. Fine-tune with decoder for instruction following")
print("5. Export model for deployment")
print("="*60)

## 12. Save Final Metrics

In [None]:
# Persist the headline metrics, but only when image-to-text R@1 was recorded.
recall_at_1 = history['val_i2t_r1']
if recall_at_1:
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    final_metrics = dict(
        best_val_loss=trainer.state.best_val_loss,
        best_r1=max(recall_at_1),
        best_r5=max(history['val_i2t_r5']),
        best_r10=max(history['val_i2t_r10']),
        # The final losses fall back to 0 when no values were recorded.
        final_train_loss=train_losses[-1] if train_losses else 0,
        final_val_loss=val_losses[-1] if val_losses else 0,
        embedding_dim=config.vision_encoder.projection_dim,
        # Stored as a string so the list of dimensions fits in a single value.
        mrl_dims=str(config.vision_encoder.mrl_dimensions),
        clip_weight=config.losses.contrastive,
        mrl_weight=config.losses.mrl,
    )

    visualizer.save_metrics_table(final_metrics)
    print(f"Final metrics saved to {output_dir / 'metrics.csv'}")

print("\n✓ Notebook complete!")