# VLM Training with PixMo QA Dataset

This notebook implements a **Vision-Language Model for Question Answering** using:
1. Modular vision encoder (CLIP + Perceiver + MRL) from `edge_glass_modular/src/encoders`
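2. Tiny TRM-based decoders (a plain baseline and a latent-recursive variant) built on `decoders.trm` from `edge_glass_modular/src/decoders`; `QwenDecoder` is imported but not used by the training code below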
3. PixMo QA dataset with question-answer pairs
4. Proper modular design following the edge_glass_modular architecture

## Architecture:

```
Image (B, 3, 336, 336)
  ↓
Vision Encoder (frozen aligned model)
  ↓ (B, num_latents, hidden_dim)
Projection to decoder hidden dim
  ↓ (B, num_latents, d_dec)
Tiny decoder: plain baseline or TRM recursive variant (trainable)
  ↓
Token Layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
  ↓
Loss on answer tokens only
```

## Key Features:
- Modular design using imports from `edge_glass_modular/src`
- Frozen aligned vision encoder
- Tiny decoder trained in this notebook, switchable between a plain baseline and a TRM recursive variant via `USE_TRM`
- Real QA dataset (not synthetic)
- Proper configuration management

## 1. Setup and Imports

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the project
# modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Put the sibling `src` directory at the front of the import path so the
# project modules resolve, then echo the directory as the cell's output.
src_dir = Path.cwd().parent / "src"
sys.path.insert(0, str(src_dir))
src_dir

In [None]:
# Standard library
import string
import warnings
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Import modular components from edge_glass_modular
from config import load_config
from encoders.vision import VisionEncoder
from decoders.qwen import QwenDecoder
from data.dataset_builder import PixmoQADataset
from data.transforms import get_image_transforms

# Set up matplotlib
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Load Configuration

Load the experiment configuration from YAML file.

In [None]:
# Read the experiment settings for this run from the YAML file.
config_path = "../configs/trm_vlm_qa.yaml"
config = load_config(config_path)

print(f"Loaded config: {config.name}")

# Dataset section: parquet splits plus image size, length limits and batch size.
print(
    f"\nDataset:",
    f"  Train parquet: {config.dataset.train_parquet}",
    f"  Val parquet: {config.dataset.val_parquet}",
    f"  Test parquet: {config.dataset.test_parquet}",
    f"  Image size: {config.dataset.image_size}",
    f"  Max question length: {config.dataset.max_question_length}",
    f"  Max answer length: {config.dataset.max_answer_length}",
    f"  Batch size: {config.dataset.batch_size}",
    sep="\n",
)

# Decoder section as declared in the config file.
print(
    f"\nDecoder:",
    f"  Type: {config.decoder.type}",
    f"  Model: {config.decoder.model_name}",
    f"  Use LoRA: {config.decoder.use_lora}",
    f"  Load in 8bit: {config.decoder.load_in_8bit}",
    sep="\n",
)

## 3. Load Aligned Vision Encoder

Load the pretrained Perceiver+MRL alignment model and freeze it.

In [None]:
# Build the alignment model from its own config file.
alignment_config_path = "../configs/pixmo_alignment.yaml"
alignment_config = load_config(alignment_config_path)

# NOTE(review): `MultimodalAlignmentModel` is not imported in any cell visible
# in this notebook — confirm which module provides it and import it first.
aligned_model = MultimodalAlignmentModel(alignment_config).to(device)

# Restore the trained weights. `weights_only=False` performs full unpickling,
# so only point this at a checkpoint from a trusted source.
checkpoint_path = "checkpoints/pixmo_alignment/checkpoint_best.pt"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
aligned_model.load_state_dict(checkpoint['model_state_dict'])

# Switch to eval mode and freeze every parameter of the aligned model.
aligned_model.eval()
aligned_model.requires_grad_(False)

print(
    f"Loaded aligned model from {checkpoint_path}",
    f"  Epoch: {checkpoint['epoch']}",
    f"  Val loss: {checkpoint['best_val_loss']:.4f}",
    f"  Vision encoder output: (B, 64, 4096)",
    sep="\n",
)

## 4. Extract Vision Encoder Method

Create a clean interface to get vision embeddings.

In [None]:
@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    """Run an image batch through the aligned model's vision encoder.

    Gradient tracking is disabled for the whole call.

    Args:
        images: Image batch, (B, 3, H, W).

    Returns:
        The ``sequence`` field of the encoder output,
        (B, num_latents, 4096).
    """
    return aligned_model.vision_encoder(images).sequence

# Smoke test: encode a random two-image batch and compare the output shape
# against the expected one by eye.
test_img = torch.randn(2, 3, 336, 336).to(device)
test_vision_tokens = encode_images(test_img)
print(
    f"Vision tokens shape: {test_vision_tokens.shape}",
    f"Expected: (2, 64, 4096)",
    sep="\n",
)

## 5. Implement Plain Tiny Decoder Baseline

First, implement a simple baseline decoder without TRM recursion.

In [None]:
from decoders.trm import TRMConfig, TRMDecoder

class TinyVLMDecoder(nn.Module):
    """Plain tiny decoder baseline for VLM question answering.

    Architecture:
        - Projects vision tokens from ``vision_token_dim`` -> ``hidden_dim``
        - Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
        - The sequence runs through the ``TRMDecoder`` layer stack, its final
          norm and its LM head
        - Loss only on answer tokens

    NOTE(review): layers are called with hidden states only, so any causal
    masking must happen inside the ``TRMDecoder`` layers — confirm.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        vision_token_dim: int = 4096,
        max_seq_len: int = 256,
    ):
        """Create the vision projection and the underlying ``TRMDecoder``.

        Args:
            vocab_size: Size of the token vocabulary.
            hidden_dim: Decoder width.
            num_layers: Number of decoder layers.
            num_heads: Number of attention heads.
            vision_token_dim: Width of the incoming vision tokens.
            max_seq_len: Maximum sequence length passed to ``TRMConfig``.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.vision_token_dim = vision_token_dim

        # Project vision tokens to decoder dim
        self.vision_proj = nn.Linear(vision_token_dim, hidden_dim)

        # The TRM decoder supplies token embeddings, layers, norm and LM head
        trm_config = TRMConfig(
            vocab_size=vocab_size,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            max_seq_len=max_seq_len,
        )
        self.decoder = TRMDecoder(trm_config)

    def _decode(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Run embeddings through the layer stack, final norm and LM head."""
        hidden_states = embeddings
        for layer in self.decoder.layers:
            hidden_states = layer(hidden_states)
        return self.decoder.lm_head(self.decoder.norm(hidden_states))

    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, 4096)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
    ):
        """Compute next-token loss over the answer span.

        Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
        Image and question positions are labelled -100 and therefore ignored
        by the loss.

        Returns:
            Dict with ``'loss'`` (scalar) and ``'logits'`` for the full
            concatenated sequence.
        """
        batch_size = vision_tokens.shape[0]
        num_img_tokens = vision_tokens.shape[1]

        # Project vision tokens
        vision_emb = self.vision_proj(vision_tokens)  # (B, K_img, d_dec)

        # Embed question and answer tokens
        question_emb = self.decoder.embed_tokens(question_ids)  # (B, L_q, d_dec)
        answer_emb = self.decoder.embed_tokens(answer_ids)      # (B, L_a, d_dec)

        # Concatenate: [vision | question | answer]
        full_sequence = torch.cat([vision_emb, question_emb, answer_emb], dim=1)

        # Labels: -100 for image and question tokens, actual IDs for answer.
        # NOTE(review): answer_ids are used verbatim, so any padding positions
        # in the answer also contribute to the loss — confirm against the
        # collate function.
        img_labels = torch.full(
            (batch_size, num_img_tokens),
            fill_value=-100,
            dtype=torch.long,
            device=vision_tokens.device
        )
        question_labels = torch.full_like(question_ids, fill_value=-100)
        full_labels = torch.cat([img_labels, question_labels, answer_ids], dim=1)

        logits = self._decode(full_sequence)

        # Shift so that position t is scored against the token at t + 1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = full_labels[:, 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            'loss': loss,
            'logits': logits,
        }

    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.7,
    ):
        """Greedily decode answer tokens after the image and question.

        Args:
            vision_tokens: Vision tokens from the encoder.
            question_ids: Question token IDs.
            max_new_tokens: Number of tokens to generate; decoding always runs
                for exactly this many steps.
            temperature: Kept for interface compatibility. Decoding is argmax,
                and scaling logits by a positive constant cannot change the
                argmax, so the value is not applied. Skipping the division also
                avoids NaN/inf logits when ``temperature`` is 0.

        Returns:
            LongTensor of generated token IDs, (B, max_new_tokens).
        """
        batch_size = vision_tokens.shape[0]

        # torch.stack cannot handle an empty list, so return an empty result
        if max_new_tokens <= 0:
            return torch.empty(
                batch_size, 0, dtype=torch.long, device=vision_tokens.device
            )

        # Start with image + question
        vision_emb = self.vision_proj(vision_tokens)
        question_emb = self.decoder.embed_tokens(question_ids)
        current_emb = torch.cat([vision_emb, question_emb], dim=1)
        generated_ids = []

        for _ in range(max_new_tokens):
            # Re-run the full prefix and take the last position's logits
            logits = self._decode(current_emb)

            # Greedy choice of the next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            generated_ids.append(next_token)

            # Embed and append
            next_emb = self.decoder.embed_tokens(next_token.unsqueeze(1))
            current_emb = torch.cat([current_emb, next_emb], dim=1)

        return torch.stack(generated_ids, dim=1)

## 6. Implement TRM-Style Recursive Decoder

Now implement the TRM version with latent recursion.

In [None]:
class TRMVLMDecoder(nn.Module):
    """TRM-style recursive decoder for VLM.

    Uses latent recursion for reasoning:
        x = context ([IMG_TOKENS] + [QUESTION_TOKENS]), held fixed
        y = answer embeddings (learned or teacher-forced)
        z = latent reasoning state

    Inner recursion: Repeat n times
        concat = [x, y, z]
        concat' = TinyTransformer(concat)
        x', y', z' = split(concat')
        y, z = y', z'
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 512,
        num_layers: int = 2,  # Small for tiny network
        num_heads: int = 8,
        vision_token_dim: int = 4096,
        max_seq_len: int = 256,
        num_inner_steps: int = 4,  # n = number of inner recursion steps
    ):
        """Create the projection, embeddings, tiny transformer and LM head.

        Args:
            vocab_size: Size of the token vocabulary.
            hidden_dim: Decoder width.
            num_layers: Number of layers in the tiny transformer.
            num_heads: Number of attention heads.
            vision_token_dim: Width of the incoming vision tokens.
            max_seq_len: Maximum sequence length passed to ``TRMConfig``.
            num_inner_steps: Number of inner recursion steps per forward pass.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_inner_steps = num_inner_steps

        # Project vision tokens to decoder dim
        self.vision_proj = nn.Linear(vision_token_dim, hidden_dim)

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)

        # Tiny transformer for recursion. Build ONE TRMDecoder and reuse its
        # layers and final norm, instead of constructing a full decoder per
        # harvested layer and discarding the rest.
        trm_config = TRMConfig(
            vocab_size=vocab_size,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            max_seq_len=max_seq_len,
        )
        backbone = TRMDecoder(trm_config)
        self.tiny_transformer = nn.ModuleList(
            [backbone.layers[i] for i in range(num_layers)]
        )
        self.norm = backbone.norm

        # LM head
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # Tie weights

        # Learned initial z state, broadcast over batch and answer positions
        self.z_init = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)

    def latent_recursion(
        self,
        x: torch.Tensor,  # Context: (B, L_ctx, d)
        y: torch.Tensor,  # Answer: (B, L_ans, d)
        z: torch.Tensor,  # Latent: (B, L_ans, d)
    ):
        """Run the tiny transformer once over [x, y, z] and split the result.

        Returns:
            Tuple ``(x_out, y_out, z_out)`` with the same sequence lengths as
            the corresponding inputs.
        """
        # Concatenate along sequence: [x, y, z]
        concat = torch.cat([x, y, z], dim=1)  # (B, L_ctx + 2*L_ans, d)

        # Pass through tiny transformer
        hidden = concat
        for layer in self.tiny_transformer:
            hidden = layer(hidden)

        # Split back
        L_ctx = x.shape[1]
        L_ans = y.shape[1]

        x_out = hidden[:, :L_ctx, :]
        y_out = hidden[:, L_ctx:L_ctx+L_ans, :]
        z_out = hidden[:, L_ctx+L_ans:, :]

        return x_out, y_out, z_out

    def _recurse(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        """Apply ``num_inner_steps`` recursion steps, keeping the context fixed.

        Only y and z are carried between steps; the transformed context is
        discarded so training and generation follow the same update rule.
        """
        for _ in range(self.num_inner_steps):
            _, y, z = self.latent_recursion(x, y, z)
        return y, z

    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, 4096)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
    ):
        """Teacher-forced forward pass with TRM recursion.

        Returns:
            Dict with ``'loss'`` (scalar) and ``'logits'`` of shape
            (B, L_a, vocab_size).
        """
        batch_size = vision_tokens.shape[0]
        L_ans = answer_ids.shape[1]

        # Project vision tokens
        vision_emb = self.vision_proj(vision_tokens)  # (B, K_img, d)

        # Embed question
        question_emb = self.embed_tokens(question_ids)  # (B, L_q, d)

        # Context x = [vision | question]
        x = torch.cat([vision_emb, question_emb], dim=1)  # (B, L_ctx, d)

        # Teacher-forced answer embeddings
        y = self.embed_tokens(answer_ids)  # (B, L_ans, d)

        # Initialize latent z
        z = self.z_init.expand(batch_size, L_ans, -1)  # (B, L_ans, d)

        # Inner recursion (n steps). The context is NOT overwritten, matching
        # the class docstring and generate().
        y, z = self._recurse(x, y, z)

        # Final answer from y
        y = self.norm(y)
        logits = self.lm_head(y)  # (B, L_ans, vocab_size)

        # Next-token prediction inside the answer span. This scores answer
        # tokens 2..L_a only, while generate() produces its first token from
        # an all-zero y.
        # NOTE(review): confirm this train/inference difference is intended.
        # NOTE(review): no ignore_index is set, so padded answer positions
        # (if any) also contribute to the loss — confirm against the collate
        # function.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = answer_ids[:, 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            'loss': loss,
            'logits': logits,
        }

    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.7,
    ):
        """Greedily generate answer tokens with TRM recursion.

        Args:
            vision_tokens: Vision tokens from the encoder.
            question_ids: Question token IDs.
            max_new_tokens: Number of tokens to generate; decoding always runs
                for exactly this many steps.
            temperature: Kept for interface compatibility. Decoding is argmax,
                and scaling logits by a positive constant cannot change the
                argmax, so the value is not applied. Skipping the division also
                avoids NaN/inf logits when ``temperature`` is 0.

        Returns:
            LongTensor of generated token IDs, (B, max_new_tokens).
        """
        batch_size = vision_tokens.shape[0]

        # torch.stack cannot handle an empty list, so return an empty result
        if max_new_tokens <= 0:
            return torch.empty(
                batch_size, 0, dtype=torch.long, device=vision_tokens.device
            )

        # Context
        vision_emb = self.vision_proj(vision_tokens)
        question_emb = self.embed_tokens(question_ids)
        x = torch.cat([vision_emb, question_emb], dim=1)

        generated_ids = []

        # Generate autoregressively; y and z are rebuilt from scratch each step
        for _ in range(max_new_tokens):
            if not generated_ids:
                # First token: start from an all-zero answer slot
                y = torch.zeros(
                    batch_size, 1, self.hidden_dim, device=x.device, dtype=x.dtype
                )
                z = self.z_init.expand(batch_size, 1, -1)
            else:
                y = self.embed_tokens(torch.stack(generated_ids, dim=1))
                z = self.z_init.expand(batch_size, len(generated_ids), -1)

            # Run recursion
            y, z = self._recurse(x, y, z)

            # Greedy choice from the last answer position
            logits = self.lm_head(self.norm(y)[:, -1, :])
            next_token = torch.argmax(logits, dim=-1)

            generated_ids.append(next_token)

        return torch.stack(generated_ids, dim=1)

## 7. Evaluation Metrics (EM and F1)

In [None]:
def normalize_answer(s: str) -> str:
    """Canonicalise an answer string before scoring.

    ASCII punctuation is deleted, the text is lower-cased and trimmed, the
    words 'a', 'an' and 'the' are dropped, and the remaining words are joined
    with single spaces.
    """
    articles = {'a', 'an', 'the'}
    without_punct = s.translate(str.maketrans('', '', string.punctuation))
    words = without_punct.lower().strip().split()
    return ' '.join(word for word in words if word not in articles)

def compute_exact_match(pred: str, target: str) -> float:
    """Return 1.0 if prediction and target are equal after normalisation, else 0.0."""
    normalized_pred = normalize_answer(pred)
    normalized_target = normalize_answer(target)
    return 1.0 if normalized_pred == normalized_target else 0.0

def compute_f1(pred: str, target: str) -> float:
    """Token-level F1 between a prediction and a target.

    Both strings are normalised and split on whitespace. If either side has no
    tokens, the score is 1.0 when both are empty and 0.0 otherwise. Token
    overlap is counted with multiplicity.
    """
    pred_tokens = normalize_answer(pred).split()
    target_tokens = normalize_answer(target).split()

    # Degenerate case: F1 is only meaningful when both sides have tokens.
    if not pred_tokens or not target_tokens:
        return 1.0 if pred_tokens == target_tokens else 0.0

    # Multiset intersection keeps the smaller count of each shared token.
    shared = Counter(pred_tokens) & Counter(target_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(target_tokens)
    return 2 * precision * recall / (precision + recall)

def evaluate_qa(
    predictions: List[str],
    targets: List[str],
) -> Dict[str, float]:
    """Score QA predictions against targets with exact match and token F1.

    Args:
        predictions: Predicted answer strings.
        targets: Reference answer strings, aligned with ``predictions``.

    Returns:
        Dict with ``'em'`` and ``'f1'``, each a percentage averaged over all
        pairs. Both are 0.0 when there are no pairs to score.

    Raises:
        ValueError: If the two lists differ in length. Pairing them with
            ``zip`` would otherwise drop the surplus items and report metrics
            for a subset without any sign that something is wrong.
    """
    if len(predictions) != len(targets):
        raise ValueError(
            f"predictions and targets differ in length: "
            f"{len(predictions)} vs {len(targets)}"
        )

    # np.mean([]) returns nan and emits a RuntimeWarning; report zeros instead.
    if not predictions:
        return {'em': 0.0, 'f1': 0.0}

    em_scores = [compute_exact_match(p, t) for p, t in zip(predictions, targets)]
    f1_scores = [compute_f1(p, t) for p, t in zip(predictions, targets)]

    return {
        'em': float(np.mean(em_scores) * 100),
        'f1': float(np.mean(f1_scores) * 100),
    }

## 8. Setup Dataset and Dataloaders

In [None]:
# AutoTokenizer is used below but was not imported anywhere in this notebook.
from transformers import AutoTokenizer

# Initialize tokenizer (use GPT2 tokenizer for simplicity)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Get transforms
train_transforms = get_image_transforms(image_size=336, is_training=True)
val_transforms = get_image_transforms(image_size=336, is_training=False)

# Create datasets. The class name matches the spelling imported at the top of
# the notebook (`PixmoQADataset`); the previous `PixMoQADataset` was undefined.
# NOTE(review): confirm the spelling against data/dataset_builder.py.
# NOTE(review): these absolute paths are machine-specific; the loaded config
# exposes `config.dataset.train_parquet` / `val_parquet` — consider using them.
train_dataset = PixmoQADataset(
    parquet_path="/home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo/pixmo_train.parquet",
    tokenizer=tokenizer,
    image_transforms=train_transforms,
    max_answer_length=32,
)

val_dataset = PixmoQADataset(
    parquet_path="/home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo/pixmo_val.parquet",
    tokenizer=tokenizer,
    image_transforms=val_transforms,
    max_answer_length=32,
)

# Create dataloaders
# NOTE(review): `collate_qa_batch` is neither defined nor imported in any cell
# visible in this notebook — confirm where it lives and bring it into scope.
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_qa_batch,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_qa_batch,
    pin_memory=True,
)

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset):,}")
print(f"  Val: {len(val_dataset):,}")
print(f"\nDataLoader info:")
print(f"  Batch size: {batch_size}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## 9. Visualize Sample Batch

In [None]:
# Pull a single batch from the training loader for inspection.
sample_batch = next(iter(train_loader))

# Tensor shapes of the collated fields.
print(
    "Batch contents:",
    f"  Images: {sample_batch['images'].shape}",
    f"  Question IDs: {sample_batch['question_ids'].shape}",
    f"  Answer IDs: {sample_batch['answer_ids'].shape}",
    sep="\n",
)

# Show up to three raw question/answer pairs; answers are cut to 80 characters.
print("\nSample QA pairs:")
num_preview = min(3, len(sample_batch['questions']))
for i in range(num_preview):
    print(
        f"\n  [{i+1}]",
        f"    Q: {sample_batch['questions'][i]}",
        f"    A: {sample_batch['answers'][i][:80]}...",
        sep="\n",
    )

## 10. Initialize Decoder Model

Choose between `TinyVLMDecoder` (baseline) and `TRMVLMDecoder` (recursive) via the `USE_TRM` flag.

In [None]:
# Model selection and size. The recursive variant uses the shallower stack.
USE_TRM = True  # False selects the plain baseline
HIDDEN_DIM = 512
NUM_LAYERS = 2 if USE_TRM else 4
NUM_HEADS = 8
NUM_INNER_STEPS = 4  # Only read by the TRM variant

# Constructor arguments shared by both decoder variants.
decoder_kwargs = dict(
    vocab_size=tokenizer.vocab_size,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
)

if USE_TRM:
    model = TRMVLMDecoder(num_inner_steps=NUM_INNER_STEPS, **decoder_kwargs).to(device)
    print(f"Initialized TRM VLM Decoder (n={NUM_INNER_STEPS} inner steps)")
else:
    model = TinyVLMDecoder(**decoder_kwargs).to(device)
    print("Initialized Plain Tiny Decoder (baseline)")

# Parameter counts, collected in a single pass over the model.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, is_trainable in param_sizes if is_trainable)

print(
    f"\nModel parameters:",
    f"  Total: {total_params:,}",
    f"  Trainable: {trainable_params:,}",
    f"  Hidden dim: {HIDDEN_DIM}",
    f"  Num layers: {NUM_LAYERS}",
    f"  Num heads: {NUM_HEADS}",
    sep="\n",
)

## 11. Training Setup

In [None]:
# Run length and optimisation hyperparameters.
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.05  # fraction of all optimizer steps spent in warmup
MAX_GRAD_NORM = 1.0  # threshold used for gradient-norm clipping
EVAL_EVERY = 100
LOG_EVERY = 20

# AdamW over every parameter of the decoder.
adamw_settings = {
    'lr': LEARNING_RATE,
    'betas': (0.9, 0.95),
    'weight_decay': WEIGHT_DECAY,
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

# Step budget for the scheduler: one step per training batch.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * NUM_EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)

def get_lr_scheduler(optimizer, warmup_steps, total_steps):
    """Build a linear-warmup, cosine-decay learning-rate schedule.

    The multiplier rises linearly from 0 to 1 over ``warmup_steps``, then
    follows a half cosine from 1.0 down to 0.1 at ``total_steps``, and stays
    at 0.1 for any step beyond that.

    Args:
        optimizer: Optimizer whose base learning rates are scaled.
        warmup_steps: Number of warmup steps; 0 disables warmup.
        total_steps: Step at which the decay reaches its floor.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.
    """
    # A run with no decay phase (total_steps <= warmup_steps) would otherwise
    # divide by zero on the first post-warmup step.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        # Clamp so that stepping past total_steps holds the floor instead of
        # letting the cosine wrap around and raise the learning rate again.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return float(0.1 + 0.9 * 0.5 * (1 + np.cos(np.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Attach the warmup/decay schedule to the optimizer.
scheduler = get_lr_scheduler(optimizer, warmup_steps, total_steps)

# Echo the effective training settings.
print(
    f"Training configuration:",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Total steps: {total_steps}",
    f"  Warmup steps: {warmup_steps}",
    f"  Eval every: {EVAL_EVERY} steps",
    sep="\n",
)

## 12. Training Loop

In [None]:
# Initialize wandb
USE_WANDB = True
if USE_WANDB:
    wandb.init(
        project="edge_glass_trm_vlm",
        name=f"{'trm' if USE_TRM else 'baseline'}_decoder_d{HIDDEN_DIM}_l{NUM_LAYERS}",
        config={
            'use_trm': USE_TRM,
            'hidden_dim': HIDDEN_DIM,
            'num_layers': NUM_LAYERS,
            'num_heads': NUM_HEADS,
            'num_inner_steps': NUM_INNER_STEPS if USE_TRM else None,
            'learning_rate': LEARNING_RATE,
            'batch_size': batch_size,
            'total_params': total_params,
        }
    )

# Training state
global_step = 0
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'val_em': [], 'val_f1': []}

# Checkpoint directory (one per decoder variant)
ckpt_dir = Path(f"checkpoints/trm_vlm_{'trm' if USE_TRM else 'baseline'}")
ckpt_dir.mkdir(parents=True, exist_ok=True)

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

model.train()

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    epoch_losses = []

    for batch_idx, batch in enumerate(pbar):
        # Move to device
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answer_ids = batch['answer_ids'].to(device)

        # Encode images (frozen encoder, no gradients)
        vision_tokens = encode_images(images)

        # Forward pass
        outputs = model(vision_tokens, question_ids, answer_ids)
        loss = outputs['loss']

        # Backward pass with gradient clipping; the scheduler advances once
        # per optimizer step.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()

        # Log
        epoch_losses.append(loss.item())
        global_step += 1

        if global_step % LOG_EVERY == 0:
            # Mean over the most recent LOG_EVERY losses of this epoch
            avg_loss = np.mean(epoch_losses[-LOG_EVERY:])
            pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

            if USE_WANDB:
                wandb.log({
                    'train/loss': avg_loss,
                    'train/lr': scheduler.get_last_lr()[0],
                    'step': global_step,
                })

    # Epoch-end summary. Cast to a plain float so the history holds
    # built-in numbers only.
    epoch_train_loss = float(np.mean(epoch_losses))
    print(f"\n  Epoch {epoch+1} average loss: {epoch_train_loss:.4f}")
    history['train_loss'].append(epoch_train_loss)

    # Validation (simple loss for now, full eval is expensive)
    model.eval()
    val_losses = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['images'].to(device)
            question_ids = batch['question_ids'].to(device)
            answer_ids = batch['answer_ids'].to(device)

            vision_tokens = encode_images(images)
            outputs = model(vision_tokens, question_ids, answer_ids)
            val_losses.append(outputs['loss'].item())

    # Plain float rather than numpy.float64: this value is written into the
    # checkpoint, and torch.load with weights_only=True refuses to unpickle
    # NumPy scalars.
    val_loss = float(np.mean(val_losses))
    history['val_loss'].append(val_loss)

    print(f"  Validation loss: {val_loss:.4f}")

    if USE_WANDB:
        wandb.log({
            'val/loss': val_loss,
            'epoch': epoch + 1,
        })

    # Save checkpoint whenever validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'config': {
                'use_trm': USE_TRM,
                'hidden_dim': HIDDEN_DIM,
                'num_layers': NUM_LAYERS,
                'num_heads': NUM_HEADS,
                'num_inner_steps': NUM_INNER_STEPS if USE_TRM else None,
            },
        }, ckpt_dir / "checkpoint_best.pt")
        print(f"  ✓ Saved best checkpoint (val_loss: {best_val_loss:.4f})")

    # Back to training mode for the next epoch
    model.train()

print("\n" + "="*60)
print("TRAINING COMPLETED")
print("="*60)
print(f"Best validation loss: {best_val_loss:.4f}")

## 13. Full Evaluation with EM and F1

In [None]:
# Load best checkpoint. The file was written by this notebook and also holds
# non-tensor entries (optimizer/scheduler state, config), so it is loaded with
# full unpickling, the same way the alignment checkpoint is loaded above.
best_ckpt = torch.load(
    ckpt_dir / "checkpoint_best.pt",
    map_location=device,
    weights_only=False,
)
model.load_state_dict(best_ckpt['model_state_dict'])
model.eval()

print("\nRunning full evaluation on validation set...")

all_predictions = []
all_targets = []

# generate() always emits max_new_tokens tokens, so cut each sequence at its
# first EOS; otherwise text produced after the end of the answer would be
# scored as part of the prediction.
eos_token_id = tokenizer.eos_token_id

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating answers"):
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answers = batch['answers']

        # Encode images
        vision_tokens = encode_images(images)

        # Generate answers
        generated_ids = model.generate(
            vision_tokens,
            question_ids,
            max_new_tokens=32,
            temperature=0.7,
        )

        # Truncate at the first EOS (if any), then decode
        trimmed_ids = []
        for row in generated_ids.tolist():
            if eos_token_id is not None and eos_token_id in row:
                row = row[:row.index(eos_token_id)]
            trimmed_ids.append(row)
        predictions = tokenizer.batch_decode(trimmed_ids, skip_special_tokens=True)

        all_predictions.extend(predictions)
        all_targets.extend(answers)

# Compute metrics
metrics = evaluate_qa(all_predictions, all_targets)

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Exact Match (EM): {metrics['em']:.2f}%")
print(f"Token F1: {metrics['f1']:.2f}%")
print("="*60)

if USE_WANDB:
    wandb.log({
        'val/em': metrics['em'],
        'val/f1': metrics['f1'],
    })

# Show some examples
print("\nSample predictions:")
for i in range(min(10, len(all_predictions))):
    print(f"\n[{i+1}]")
    print(f"  Target: {all_targets[i][:80]}...")
    print(f"  Predicted: {all_predictions[i][:80]}...")
    print(f"  EM: {compute_exact_match(all_predictions[i], all_targets[i])}")
    print(f"  F1: {compute_f1(all_predictions[i], all_targets[i]):.3f}")

## 14. Visualize Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves (one point per epoch)
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Metrics (if available): only series that were actually recorded are drawn
plotted_metric = False
for key, label, marker in (
    ('val_em', 'Exact Match', 'o'),
    ('val_f1', 'Token F1', 's'),
):
    if history.get(key):
        axes[1].plot(history[key], label=label, linewidth=2, marker=marker)
        plotted_metric = True
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Score (%)')
axes[1].set_title('Validation Metrics')
# Calling legend() on an axes without labelled artists makes matplotlib emit
# a warning, so only add it when a curve exists.
if plotted_metric:
    axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(ckpt_dir / "training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nTraining curves saved to {ckpt_dir / 'training_curves.png'}")

## 15. Summary and Next Steps

In [None]:
# Banner
print("\n" + "="*60, "EXPERIMENT SUMMARY", "="*60, sep="\n")

# Model section; the recursion depth line is shown for the TRM variant only.
print(
    f"\nModel Configuration:",
    f"  Type: {'TRM Recursive Decoder' if USE_TRM else 'Plain Tiny Decoder (Baseline)'}",
    f"  Hidden dim: {HIDDEN_DIM}",
    f"  Num layers: {NUM_LAYERS}",
    f"  Num heads: {NUM_HEADS}",
    sep="\n",
)
if USE_TRM:
    print(f"  Inner recursion steps: {NUM_INNER_STEPS}")
print(f"  Total parameters: {total_params:,}")

# Training outcome
print(
    f"\nTraining:",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Best val loss: {best_val_loss:.4f}",
    sep="\n",
)

# Generation-based metrics from the full evaluation
print(
    f"\nEvaluation Results:",
    f"  Exact Match (EM): {metrics['em']:.2f}%",
    f"  Token F1: {metrics['f1']:.2f}%",
    sep="\n",
)

# Artifacts written under the checkpoint directory
print(
    f"\nOutput Files:",
    f"  Checkpoint: {ckpt_dir / 'checkpoint_best.pt'}",
    f"  Training curves: {ckpt_dir / 'training_curves.png'}",
    sep="\n",
)

# Suggested follow-up experiments
print("\n" + "="*60, "NEXT STEPS", "="*60, sep="\n")
print(
    "1. Run ablation: Switch USE_TRM flag and compare baseline vs TRM",
    "2. Try different recursion depths (num_inner_steps = {2, 4, 6, 8})",
    "3. Experiment with hidden_dim = {256, 512, 1024}",
    "4. Add outer deep recursion (T > 1)",
    "5. Evaluate on text-only baseline (no vision tokens)",
    "="*60,
    sep="\n",
)

print("\n✓ Notebook complete!")