# Experiment 1: Vision-Text Alignment with Qwen Decoder

This notebook demonstrates:
1. Loading and configuring the Vision-Text alignment model
2. Training with multimodal alignment losses (contrastive + MRL)
3. Instruction tuning with the Qwen2.5-7B-Instruct decoder
4. Evaluation and caption generation

**Hardware Requirements:**
- 1-2 H200 GPUs
- ~30-40GB GPU memory per GPU

**Dataset:**
- PixMo-Cap: ~20K image-caption pairs (17,665 loaded in the recorded run, split 90/10 into train and validation)
- Open-Orca: 50K instruction samples

## 1. Setup and Imports

In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the project's `src/` directory importable. The location is resolved from
# the current working directory: `<parent of cwd>/src`.
src_dir = Path.cwd().parent / "src"

# Only insert when missing, so re-running this cell does not keep prepending
# duplicate entries to sys.path.
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

# Bare expression so the notebook displays the resolved path.
src_dir

PosixPath('/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src')

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
import wandb
from datetime import datetime


In [4]:
# Import our modules
from config import load_config
from models import MultimodalAlignmentModel
from data import ImageTextDataset, get_image_transforms



In [5]:


# Set up matplotlib
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA H200
GPU Memory: 150.11 GB


## 2. Load Configuration

In [6]:
# Read the experiment definition and echo the settings used by this notebook.
config_path = "../configs/vision_text_qwen.yaml"
config = load_config(config_path)

print(f"Experiment: {config.name}")
print(f"Description: {config.description}")
print(f"\nConfiguration:")

# (label, value) pairs, printed in this order with a two-space indent.
vision_cfg = config.vision_encoder
config_rows = [
    ("Vision Encoder", vision_cfg.model_name),
    ("Text Encoder", config.text_encoder.model_name),
    ("Decoder", config.decoder.model_name),
    ("Use Perceiver", vision_cfg.use_perceiver),
    ("Use MRL", vision_cfg.use_mrl),
    ("MRL Dimensions", vision_cfg.mrl_dimensions),
    ("Batch Size", config.dataset.batch_size),
    ("Learning Rate", config.optimization.learning_rate),
    ("Epochs", config.training.num_epochs),
]
for row_label, row_value in config_rows:
    print(f"  {row_label}: {row_value}")

Experiment: vision_text_qwen
Description: Vision-Text alignment with Qwen-7B decoder for instruction tuning

Configuration:
  Vision Encoder: openai/clip-vit-large-patch14
  Text Encoder: sentence-transformers/all-MiniLM-L6-v2
  Decoder: Qwen/Qwen2.5-7B-Instruct
  Use Perceiver: False
  Use MRL: True
  MRL Dimensions: [512, 256, 128]
  Batch Size: 32
  Learning Rate: 0.0002
  Epochs: 3


## 3. Create Model

In [7]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the model from the experiment config and move it to the chosen device.
print("\nCreating multimodal alignment model...")
model = MultimodalAlignmentModel(config).to(device)

# Parameter-count table, framed above and below by a 60-character rule.
rule = "=" * 60
print("\n" + rule)
model.print_parameter_counts()
print(rule)

Using device: cuda

Creating multimodal alignment model...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 20,185,088 || all params: 7,635,801,600 || trainable%: 0.2643

Component                            Trainable           Total
------------------------------------------------------------
Vision Encoder                       1,049,600     304,229,376
Text Encoder                           394,240      23,107,456
Decoder                             20,185,088   7,635,801,600
------------------------------------------------------------
TOTAL                              873,132,032   8,814,641,536


## 4. Prepare Dataset and DataLoader

In [21]:
# Dataset root: the `data/` folder inside the parent of the current working
# directory. Resolved from the working directory at the time this cell runs.
DATA_DIR = Path.cwd().parent / "data"
# Bare expression so the notebook displays the resolved path.
DATA_DIR

PosixPath('/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/data')

In [22]:
# Image transforms: one pipeline built with is_training=True, one with
# is_training=False (presumably augmenting vs. deterministic; verify in
# get_image_transforms).
train_transforms = get_image_transforms(
    image_size=config.dataset.image_size,
    is_training=True
)

val_transforms = get_image_transforms(
    image_size=config.dataset.image_size,
    is_training=False
)

# NOTE(review): the recorded run of the visualisation cell failed with
# "No such file or directory: 'data/pixmo/pixmo_0008667.jpg'", i.e. an image
# path was opened relative to the working directory. It looks like the paths in
# metadata.json are relative to the project root; confirm and resolve them in
# the dataset or the metadata.
print("Loading datasets...")
metadata_path = f"{str(DATA_DIR)}/pixmo/metadata.json"


def _build_image_text_dataset(image_transforms):
    """Create an ImageTextDataset over the PixMo metadata with the given transforms."""
    return ImageTextDataset(
        metadata_path=metadata_path,
        image_transforms=image_transforms,
        max_text_length=config.dataset.max_text_length,
    )


# Two views over the same samples, so each split gets its own transforms.
# Previously `val_transforms` was never used and validation samples went
# through the training transforms.
train_source = _build_image_text_dataset(train_transforms)
val_source = _build_image_text_dataset(val_transforms)

# Random 90/10 train/validation split. A dedicated, seeded generator keeps the
# split identical across re-runs regardless of how much global RNG state has
# been consumed earlier in the session.
num_samples = len(train_source)
train_size = int(0.9 * num_samples)
val_size = num_samples - train_size
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(num_samples, generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(train_source, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(val_source, shuffled_indices[train_size:])

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# DataLoader rejects persistent_workers=True when num_workers == 0, so only
# enable it when worker processes are actually used.
num_workers = config.dataset.num_workers
train_loader = DataLoader(
    train_dataset,
    batch_size=config.dataset.batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=config.dataset.pin_memory,
    persistent_workers=bool(config.dataset.persistent_workers) and num_workers > 0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.dataset.batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=config.dataset.pin_memory,
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

Loading datasets...
Train samples: 15898
Validation samples: 1767

Train batches: 497
Validation batches: 56


## 5. Visualize Sample Data

In [23]:
# Pull one batch from the training loader and report its structure.
sample_batch = next(iter(train_loader))
print(f"Batch keys: {sample_batch.keys()}")
print(f"Image tensor shape: {sample_batch['image'].shape}")
print(f"Number of captions: {len(sample_batch['text'])}")

# Per-channel statistics used to undo the input normalisation for display.
# NOTE(review): these look like the ImageNet mean/std; confirm they match what
# get_image_transforms applies.
channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

# 2x2 grid; zip stops at the shorter input, so at most four samples are drawn.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for idx, (ax, image_tensor) in enumerate(zip(axes, sample_batch['image'])):
    pixels = image_tensor.cpu() * channel_std + channel_mean
    pixels = torch.clamp(pixels, 0, 1).permute(1, 2, 0).numpy()  # CHW -> HWC

    ax.imshow(pixels)
    ax.set_title(sample_batch['text'][idx][:60] + "...", fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 416, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src/data/dataset_builder.py", line 44, in __getitem__
    image = read_image(item["image_path"]).float() / 255.0
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torchvision/io/image.py", line 336, in read_image
    data = read_file(path)
           ^^^^^^^^^^^^^^^
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torchvision/io/image.py", line 64, in read_file
    data = torch.ops.image.read_file(str(path))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: [Errno 2] No such file or directory: 'data/pixmo/pixmo_0008667.jpg'


## 6. Setup Training

In [None]:
# AdamW over all model parameters; hyper-parameters come from the config.
optim_cfg = config.optimization
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=optim_cfg.learning_rate,
    weight_decay=optim_cfg.weight_decay,
    betas=optim_cfg.betas,
)

# Schedule length: total number of training batches across all epochs, plus the
# configured warmup length.
num_warmup_steps = optim_cfg.warmup_steps
num_training_steps = config.training.num_epochs * len(train_loader)

from torch.optim.lr_scheduler import LambdaLR

def lr_lambda(current_step):
    """Return the learning-rate multiplier for ``current_step``.

    Linear warmup from 0 to 1 over ``num_warmup_steps``, then a half-cosine
    decay towards 0 at ``num_training_steps``. Both step counts are read from
    the enclosing notebook namespace. The result is clamped to be non-negative.
    """
    if current_step < num_warmup_steps:
        warmup_span = max(1, num_warmup_steps)
        return float(current_step) / float(warmup_span)

    decay_span = max(1, num_training_steps - num_warmup_steps)
    progress = float(current_step - num_warmup_steps) / float(decay_span)
    cosine_factor = 0.5 * (1.0 + np.cos(np.pi * progress))
    return max(0.0, cosine_factor)

# Wrap the optimizer with the warmup + cosine multiplier from `lr_lambda`.
scheduler = LambdaLR(optimizer, lr_lambda)

# Gradient scaler for mixed precision. `torch.cuda.amp.GradScaler()` is
# deprecated; `torch.amp.GradScaler("cuda")` is the equivalent replacement.
# The training function treats `scaler is None` as "mixed precision disabled".
scaler = torch.amp.GradScaler("cuda") if config.optimization.mixed_precision != "no" else None

# Initialize WandB (optional)
use_wandb = False  # Set to True if you want to use WandB
if use_wandb:
    wandb.init(
        project=config.training.wandb_project,
        name=config.training.wandb_run_name,
        config=config.to_dict(),
    )

# Echo the resulting training setup.
print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Learning rate: {config.optimization.learning_rate}")
print(f"Warmup steps: {num_warmup_steps}")
print(f"Total training steps: {num_training_steps}")
print(f"Mixed precision: {config.optimization.mixed_precision}")

## 7. Training Functions

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, scaler, epoch, device):
    """Train ``model`` for one pass over ``dataloader`` and return mean losses.

    Besides its arguments, this function reads two notebook-level globals:
    ``config`` (``optimization.max_grad_norm`` and ``training.logging_steps``)
    and ``use_wandb``.

    Args:
        model: Called as ``model(images=..., texts=...)``. The output must
            expose ``.loss`` and ``.losses`` (a dict of named losses or ``None``).
        dataloader: Yields batches with an ``'image'`` tensor and a ``'text'`` entry.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        scaler: Gradient scaler, or ``None`` to train without autocast and scaling.
        epoch: Zero-based epoch index, used for the progress bar and logging.
        device: Device the image batch is moved to.

    Returns:
        Dict with ``'loss'``, ``'contrastive_loss'`` and ``'mrl_loss'``, each
        averaged over ``len(dataloader)``. ``'mrl_loss'`` tracks only the
        ``'mrl_loss_512'`` entry. All three are ``nan`` for an empty dataloader.
    """
    model.train()
    use_amp = scaler is not None
    max_grad_norm = config.optimization.max_grad_norm

    total_loss = 0.0
    total_contrastive_loss = 0.0
    total_mrl_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for batch_idx, batch in enumerate(pbar):
        # Move to device
        images = batch['image'].to(device)
        texts = batch['text']

        # Forward pass. `torch.cuda.amp.autocast` is deprecated; with
        # enabled=False this context is a no-op, so one code path serves both
        # the mixed-precision and full-precision cases.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
            outputs = model(images=images, texts=texts)
            loss = outputs.loss

        # Backward pass with gradient clipping.
        if use_amp:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        optimizer.zero_grad()
        scheduler.step()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()
        current_lr = scheduler.get_last_lr()[0]

        # Track losses
        total_loss += loss_value
        named_losses = outputs.losses
        if named_losses is not None:
            if 'contrastive_loss' in named_losses:
                total_contrastive_loss += named_losses['contrastive_loss'].item()
            if 'mrl_loss_512' in named_losses:
                total_mrl_loss += named_losses['mrl_loss_512'].item()

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'lr': f"{current_lr:.2e}"
        })

        # Log to WandB
        if use_wandb and batch_idx % config.training.logging_steps == 0:
            wandb.log({
                'train/loss': loss_value,
                'train/lr': current_lr,
                'train/epoch': epoch,
            })

    num_batches = len(dataloader)

    def _mean(total):
        # An empty dataloader has no meaningful average; report nan instead of
        # raising ZeroDivisionError.
        return total / num_batches if num_batches else float('nan')

    return {
        'loss': _mean(total_loss),
        'contrastive_loss': _mean(total_contrastive_loss),
        'mrl_loss': _mean(total_mrl_loss),
    }


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Compute the mean validation loss and image-to-text retrieval recall.

    Args:
        model: Called as ``model(images=..., texts=..., return_embeddings=True)``.
            The output must expose ``.loss``, ``.vision_emb`` and ``.text_emb``
            (each may be ``None``).
        dataloader: Yields batches with an ``'image'`` tensor and a ``'text'`` entry.
        device: Device the image batch is moved to.

    Returns:
        Dict with ``'loss'`` (summed batch losses divided by ``len(dataloader)``,
        or ``nan`` for an empty dataloader). When both embedding types were
        collected it also holds ``'i2t_r1'``, ``'i2t_r5'`` and ``'i2t_r10'``:
        the fraction of images whose own caption (same row index) is among the
        top-k texts ranked by dot-product similarity.
    """
    model.eval()
    total_loss = 0.0

    # Collect embeddings for retrieval metrics
    vision_embeddings = []
    text_embeddings = []

    for batch in tqdm(dataloader, desc="Evaluating"):
        images = batch['image'].to(device)
        texts = batch['text']

        outputs = model(images=images, texts=texts, return_embeddings=True)

        if outputs.loss is not None:
            total_loss += outputs.loss.item()

        # Keep embeddings on the CPU so the full validation set does not
        # accumulate in GPU memory.
        if outputs.vision_emb is not None:
            vision_embeddings.append(outputs.vision_emb.cpu())
        if outputs.text_emb is not None:
            text_embeddings.append(outputs.text_emb.cpu())

    # An empty dataloader has no meaningful average; report nan instead of
    # raising ZeroDivisionError.
    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else float('nan')

    metrics = {'loss': avg_loss}

    if vision_embeddings and text_embeddings:
        vision_embs = torch.cat(vision_embeddings, dim=0)
        text_embs = torch.cat(text_embeddings, dim=0)

        # Dot-product similarity between every image and every text.
        # NOTE(review): this equals cosine similarity only if the model returns
        # L2-normalised embeddings; confirm.
        similarity = torch.matmul(vision_embs, text_embs.t())

        # Image-to-text retrieval: only the ten best-ranked texts per image are
        # needed for R@1/5/10.
        top_ranked = torch.argsort(similarity, dim=1, descending=True)[:, :10]
        correct_indices = torch.arange(len(vision_embs)).unsqueeze(1)
        hits = top_ranked == correct_indices

        def _recall_at(k):
            # Fraction of rows whose matching index appears in the first k ranks.
            return hits[:, :k].any(dim=1).float().mean().item()

        metrics.update({
            'i2t_r1': _recall_at(1),
            'i2t_r5': _recall_at(5),
            'i2t_r10': _recall_at(10),
        })

    return metrics

## 8. Train the Model

In [None]:
# Per-epoch metric history (reset every time this cell runs).
history = {
    'train_loss': [],
    'val_loss': [],
    'val_i2t_r1': [],
    'val_i2t_r5': [],
}

best_val_loss = float('inf')
output_dir = Path(config.training.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# `save_steps` is divided by the number of batches per epoch to get a save
# interval in whole epochs. Clamp to at least 1: when `save_steps` is smaller
# than one epoch the integer division yields 0 and the modulo below would
# raise ZeroDivisionError.
save_every_n_epochs = max(1, config.training.save_steps // max(1, len(train_loader)))


def _save_checkpoint(path, epoch, val_loss):
    """Write model, optimizer and scheduler state plus run metadata to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        # Stored as a Python object, so reading this file back needs
        # torch.load(..., weights_only=False).
        'config': config,
    }, path)


print("\nStarting training...")
print(f"Output directory: {output_dir}")
print("="*60)

for epoch in range(config.training.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

    # Train
    train_metrics = train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch, device)
    print(f"Train Loss: {train_metrics['loss']:.4f}")
    print(f"  - Contrastive: {train_metrics['contrastive_loss']:.4f}")
    print(f"  - MRL: {train_metrics['mrl_loss']:.4f}")

    # Evaluate
    val_metrics = evaluate(model, val_loader, device)
    has_retrieval_metrics = 'i2t_r1' in val_metrics
    print(f"\nValidation Loss: {val_metrics['loss']:.4f}")
    if has_retrieval_metrics:
        print(f"  - Image→Text R@1:  {val_metrics['i2t_r1']*100:.2f}%")
        print(f"  - Image→Text R@5:  {val_metrics['i2t_r5']*100:.2f}%")
        print(f"  - Image→Text R@10: {val_metrics['i2t_r10']*100:.2f}%")

    # Update history
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    if has_retrieval_metrics:
        history['val_i2t_r1'].append(val_metrics['i2t_r1'])
        history['val_i2t_r5'].append(val_metrics['i2t_r5'])

    # Log to WandB
    if use_wandb:
        wandb.log({
            'epoch': epoch,
            'val/loss': val_metrics['loss'],
            'val/i2t_r1': val_metrics.get('i2t_r1', 0),
            'val/i2t_r5': val_metrics.get('i2t_r5', 0),
        })

    # Save best model (lowest validation loss seen so far)
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        checkpoint_path = output_dir / "best_model.pt"
        _save_checkpoint(checkpoint_path, epoch, val_metrics['loss'])
        print(f"\n✓ Saved best model to {checkpoint_path}")

    # Save periodic checkpoint
    if (epoch + 1) % save_every_n_epochs == 0:
        checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pt"
        _save_checkpoint(checkpoint_path, epoch, val_metrics['loss'])
        print(f"✓ Saved checkpoint to {checkpoint_path}")

print("\n" + "="*60)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*60)

## 9. Plot Training Curves

In [None]:
# Two panels: loss curves on the left, retrieval recall on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, recall_ax = axes

# Loss curves
loss_series = (
    ('train_loss', 'Train Loss', 'o'),
    ('val_loss', 'Val Loss', 's'),
)
for history_key, series_label, series_marker in loss_series:
    loss_ax.plot(history[history_key], label=series_label, marker=series_marker)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# Retrieval metrics, drawn only when recall values were recorded.
if history['val_i2t_r1']:
    recall_series = (
        ('val_i2t_r1', 'R@1', 'o'),
        ('val_i2t_r5', 'R@5', 's'),
    )
    for history_key, series_label, series_marker in recall_series:
        percentages = [r*100 for r in history[history_key]]
        recall_ax.plot(percentages, label=series_label, marker=series_marker)
    recall_ax.set_xlabel('Epoch')
    recall_ax.set_ylabel('Recall (%)')
    recall_ax.set_title('Image→Text Retrieval')
    recall_ax.legend()
    recall_ax.grid(True)

curves_path = output_dir / 'training_curves.png'
plt.tight_layout()
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nTraining curves saved to {curves_path}")

## 10. Test Generation on Sample Images

In [None]:
# Load the best checkpoint written by the training loop.
# - weights_only=False: the checkpoint stores the `config` object alongside the
#   tensors, which the restricted unpickler (the default since PyTorch 2.6)
#   refuses to load. Only use this on checkpoints you created yourself, since
#   full unpickling can execute arbitrary code.
# - map_location="cpu": avoids holding a second copy of the weights and the
#   optimizer state on the GPU; load_state_dict copies into the model's
#   existing parameters on their own device.
checkpoint = torch.load(
    output_dir / "best_model.pt",
    map_location="cpu",
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Loaded best model for inference")
# The stored epoch is the zero-based loop index from training.
print(f"Best epoch: {checkpoint['epoch']}")
print(f"Best val loss: {checkpoint['val_loss']:.4f}")

In [None]:
# Take up to four validation samples for qualitative inspection.
sample_batch = next(iter(val_loader))
images = sample_batch['image'][:4].to(device)
ground_truth = sample_batch['text'][:4]

# Generate captions (sampling is enabled, so outputs vary between runs).
print("Generating captions...\n")
with torch.no_grad():
    generated_captions = model.generate(
        images=images,
        prompt="Describe this image in detail:",
        max_new_tokens=50,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

# The batch can hold fewer than four samples (e.g. a small validation set), so
# draw only as many panels as there are results instead of assuming four.
# Assumes generate() returns one caption per input image; verify against the model.
num_shown = min(len(images), len(ground_truth), len(generated_captions))

# Per-channel statistics used to undo the input normalisation for display.
# NOTE(review): these look like the ImageNet mean/std; confirm they match what
# get_image_transforms applies.
channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

# Visualize results
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

# Turn every axis off up front so unused panels stay blank.
for ax in axes:
    ax.axis('off')

for idx in range(num_shown):
    # Denormalize image and convert CHW -> HWC for matplotlib.
    img = images[idx].cpu() * channel_std + channel_mean
    img = torch.clamp(img, 0, 1)
    img = img.permute(1, 2, 0).numpy()

    axes[idx].imshow(img)

    # Add captions (truncated to 80 characters each)
    title = f"Ground Truth: {ground_truth[idx][:80]}...\n\n"
    title += f"Generated: {generated_captions[idx][:80]}..."
    axes[idx].set_title(title, fontsize=9, wrap=True)

plt.tight_layout()
plt.savefig(output_dir / 'generated_captions.png', dpi=150, bbox_inches='tight')
plt.show()

# Print full captions
print("\n" + "="*60)
print("Full Generated Captions:")
print("="*60)
for idx, (gt, gen) in enumerate(zip(ground_truth, generated_captions)):
    print(f"\nImage {idx+1}:")
    print(f"  Ground Truth: {gt}")
    print(f"  Generated:    {gen}")
    print("-"*60)

## 11. Embedding Space Visualization

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Collect embeddings from the first few validation batches.
print("Collecting embeddings for visualization...")
MAX_VIZ_BATCHES = 10  # number of validation batches to embed
MAX_PAIR_LINES = 50   # number of image/text pairs joined by a line

vision_emb_batches = []
text_emb_batches = []
captions = []

model.eval()
with torch.no_grad():
    # Iterate the loader lazily and stop early. The previous
    # `list(val_loader)[:10]` loaded every validation batch into memory just
    # to keep the first ten.
    progress = tqdm(val_loader, total=min(MAX_VIZ_BATCHES, len(val_loader)))
    for batch_idx, batch in enumerate(progress):
        if batch_idx >= MAX_VIZ_BATCHES:
            break

        images = batch['image'].to(device)
        texts = batch['text']

        outputs = model(images=images, texts=texts, return_embeddings=True)

        vision_emb_batches.append(outputs.vision_emb.cpu())
        text_emb_batches.append(outputs.text_emb.cpu())
        captions.extend(texts)

vision_embs = torch.cat(vision_emb_batches, dim=0).numpy()
text_embs = torch.cat(text_emb_batches, dim=0).numpy()

print(f"Vision embeddings: {vision_embs.shape}")
print(f"Text embeddings: {text_embs.shape}")

# Fit one PCA on both modalities together so they share the same 2-D axes.
all_embs = np.vstack([vision_embs, text_embs])
pca = PCA(n_components=2)
all_embs_2d = pca.fit_transform(all_embs)

# Rows were stacked vision-first, so split at the number of vision rows.
vision_embs_2d = all_embs_2d[:len(vision_embs)]
text_embs_2d = all_embs_2d[len(vision_embs):]

# Plot
plt.figure(figsize=(12, 10))
plt.scatter(vision_embs_2d[:, 0], vision_embs_2d[:, 1],
           c='blue', alpha=0.6, s=50, label='Vision')
plt.scatter(text_embs_2d[:, 0], text_embs_2d[:, 1],
           c='red', alpha=0.6, s=50, label='Text')

# Draw lines connecting matching pairs (same row index in both arrays)
for i in range(min(MAX_PAIR_LINES, len(vision_embs_2d))):
    plt.plot([vision_embs_2d[i, 0], text_embs_2d[i, 0]],
            [vision_embs_2d[i, 1], text_embs_2d[i, 1]],
            'gray', alpha=0.2, linewidth=0.5)

plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
plt.title('Vision-Text Embedding Space (PCA Projection)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(output_dir / 'embedding_space.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nEmbedding space visualization saved to {output_dir / 'embedding_space.png'}")

## 12. Summary and Next Steps

In [None]:
rule = "=" * 60


def _print_section(title, rows):
    """Print a blank line, ``title:``, then one ``  - label: value`` line per row."""
    print(f"\n{title}:")
    for label, value in rows:
        print(f"  - {label}: {value}")


print("\n" + rule)
print("EXPERIMENT SUMMARY")
print(rule)

_print_section("Configuration", [
    ("Model", config.name),
    ("Vision Encoder", config.vision_encoder.model_name),
    ("Text Encoder", config.text_encoder.model_name),
    ("Decoder", config.decoder.model_name),
    ("Use MRL", config.vision_encoder.use_mrl),
    ("Use Perceiver", config.vision_encoder.use_perceiver),
])

# Recall rows are included only when retrieval metrics were recorded.
training_rows = [
    ("Epochs", config.training.num_epochs),
    ("Training samples", len(train_dataset)),
    ("Validation samples", len(val_dataset)),
    ("Best validation loss", f"{best_val_loss:.4f}"),
]
if history['val_i2t_r1']:
    training_rows.append(("Best R@1", f"{max(history['val_i2t_r1'])*100:.2f}%"))
    training_rows.append(("Best R@5", f"{max(history['val_i2t_r5'])*100:.2f}%"))
_print_section("Training", training_rows)

trainable_count = model.num_trainable_parameters
total_count = model.num_total_parameters
_print_section("Model Parameters", [
    ("Trainable", f"{trainable_count:,}"),
    ("Total", f"{total_count:,}"),
    ("Trainable %", f"{100*trainable_count/total_count:.2f}%"),
])

_print_section("Output Files", [
    ("Best model", output_dir / 'best_model.pt'),
    ("Training curves", output_dir / 'training_curves.png'),
    ("Generated captions", output_dir / 'generated_captions.png'),
    ("Embedding space", output_dir / 'embedding_space.png'),
])

print("\n" + rule)
print("Next Steps:")
print(rule)
next_steps = [
    "Experiment with different MRL dimensions",
    "Try enabling Perceiver resampler",
    "Test on your own images",
    "Run tri-modal experiment (Notebook 2)",
    "Compare with TRM decoder (Notebook 3)",
]
for step_number, step_text in enumerate(next_steps, start=1):
    print(f"{step_number}. {step_text}")
print(rule)

## 13. Save Final Results

In [None]:
import json

# Persist the per-epoch metric history as indented JSON in the output directory.
history_path = output_dir / 'training_history.json'
history_path.write_text(json.dumps(history, indent=2))

print(f"Training history saved to {history_path}")

# Finish the WandB run if one was started.
if use_wandb:
    wandb.finish()

print("\n✓ Notebook complete!")