# VLM Training with PixMo QA Dataset - FIXED VERSION

This notebook implements a **Vision-Language Model for Question Answering** using:
1. Pretrained aligned vision encoder (CLIP + projections + MRL) - FROZEN
2. **Pretrained Qwen2.5-7B-Instruct decoder with LoRA** - TRAINABLE
3. PixMo QA dataset with question-answer pairs
4. Optional TRM latent recursion on top of Qwen

## Key Fixes from Original:

✅ **Pretrained Decoder**: Uses pretrained Qwen2.5-7B instead of a randomly initialized 34M-parameter TRM

✅ **Proper Training**: Standard autoregressive loss, with vision tokens fed as prefix embeddings (`prefix_embeds`)

✅ **Baseline Mode**: TRM recursion can be disabled for comparison

✅ **Debug Tools**: Text-only sanity check, parameter audit, first-batch logging

## Architecture:

```
Image (B, 3, 336, 336)
  ↓
Aligned Vision Encoder [FROZEN]
  ↓ (B, 577, 4096)
Vision Projection (4096 → d_qwen) [TRAINABLE]
  ↓ (B, 577, d_qwen)
Qwen2.5-7B Decoder + LoRA [TRAINABLE]
  ↓
Token Layout: [IMG_TOKENS] [QUESTION] [ANSWER]
Loss: Only on answer tokens (vision/question masked with -100)
```

Optional: TRM latent recursion wrapper on top of Qwen

## 1. Setup and Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Make project modules importable from this notebook: prepend both the parent of
# the current working directory and its src/ folder to sys.path, so imports such
# as `config`, `encoders.vision` and `decoders.qwen` below resolve.
# NOTE(review): assumes the kernel's cwd is the notebooks/ directory — TODO confirm.
src_path = Path.cwd().parent / "src"
sys.path.insert(0, str(Path.cwd().parent))
sys.path.insert(0, str(src_path))
print(f"Added to path: {src_path}")

In [None]:
# Standard libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import wandb
from typing import Optional, Dict, List, Tuple
from collections import Counter
import string
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Modular imports from edge_glass_modular
from config import load_config
from encoders.vision import VisionEncoder
from decoders.qwen import QwenDecoder
from decoders.trm import TRMConfig, TRMLayer, RMSNorm
from data.dataset_builder import PixmoQADataset
from data.transforms import get_image_transforms
from models.alignment import MultimodalAlignmentModel


In [None]:
# Set matplotlib style
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Device info
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Load Configuration

In [None]:
# Experiment config for the QA fine-tuning run.
config_path = "../configs/trm_vlm_qa.yaml"
config = load_config(config_path)

print(f"Loaded config: {config.name}")

print(f"\nDataset:")
print(f"  Train: {config.dataset.train_parquet}")
print(f"  Val: {config.dataset.val_parquet}")
print(f"  Image size: {config.dataset.image_size}")
print(f"  Batch size: {config.dataset.batch_size}")

print(f"\nDecoder:")
print(f"  Model: {config.decoder.model_name}")
print(f"  Use LoRA: {config.decoder.use_lora}")
print(f"  Load in 8bit: {config.decoder.load_in_8bit}")

# The checkpoints folder is searched relative to the cwd so the notebook works when
# launched from the repo root or from notebooks/; the first existing path wins.
CKPT_ROOT_CANDIDATES = [
    Path.cwd() / 'checkpoints',
    Path.cwd().parent / 'checkpoints',
    Path.cwd() / 'edge_glass_modular/notebooks/checkpoints',
    Path.cwd().parent / 'edge_glass_modular/notebooks/checkpoints',
]
CKPT_ROOT = None
for candidate_root in CKPT_ROOT_CANDIDATES:
    if candidate_root.exists():
        CKPT_ROOT = candidate_root
        break
if CKPT_ROOT is None:
    raise FileNotFoundError('No checkpoint directory found; expected one of: ' + ', '.join(map(str, CKPT_ROOT_CANDIDATES)))
print(f"Checkpoint root: {CKPT_ROOT}")


## 3. Load Pretrained Aligned Vision Encoder (FROZEN)

In [None]:
# Load alignment config
alignment_config_path = "../configs/pixmo_alignment.yaml"
alignment_config = load_config(alignment_config_path)

# Put the frozen encoder on a second GPU when one exists, so it does not compete
# with the decoder for memory. Otherwise fall back to the notebook's primary
# device (a hard-coded 'cuda:1' crashes on single-GPU and CPU-only machines).
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    vision_device = torch.device('cuda:1')
else:
    vision_device = device

# Build aligned model
print("Loading aligned vision encoder...")
aligned_model = MultimodalAlignmentModel(alignment_config).to(vision_device)

# Load checkpoint directly onto the encoder's device (avoids a cross-device copy).
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
checkpoint_path = CKPT_ROOT / 'pixmo_alignment/checkpoint_best.pt'
checkpoint = torch.load(checkpoint_path, map_location=vision_device, weights_only=False)
aligned_model.load_state_dict(checkpoint['model_state_dict'])
aligned_model.eval()

# FREEZE all parameters
for param in aligned_model.parameters():
    param.requires_grad = False

print(f"‚úì Loaded aligned model from {checkpoint_path}")
print(f"  Checkpoint epoch: {checkpoint.get('epoch', 'N/A')}")
ckpt_val_loss = checkpoint.get('best_val_loss')
# Report N/A instead of a fake 0.0000 when the checkpoint has no val loss.
print(f"  Val loss: {ckpt_val_loss:.4f}" if ckpt_val_loss is not None else "  Val loss: N/A")
# Weights were copied into aligned_model; drop the raw state dict so its
# device-resident tensors are released.
del checkpoint

# Get vision encoder output dimension
vision_token_dim = alignment_config.vision_encoder.projection_dim
print(f"  Vision output: (B, num_tokens, {vision_token_dim})")

## 4. Vision Encoding Helper Function

In [None]:
@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    """Run the frozen aligned vision encoder and return its token sequence.

    The batch is first moved to the device that holds ``aligned_model``. The
    returned tokens therefore live on that device, which is not necessarily the
    notebook-level ``device``.

    Args:
        images: (B, 3, H, W) image batch.

    Returns:
        (B, num_tokens, vision_token_dim) sequence embeddings.

    Raises:
        ValueError: if the encoder did not produce a sequence output.
    """
    encoder_device = next(aligned_model.parameters()).device
    sequence = aligned_model.vision_encoder(
        images.to(encoder_device), return_sequence=True
    ).sequence
    if sequence is None:
        raise ValueError("Vision encoder did not return sequence embeddings")
    return sequence

# Smoke test: push a random batch through the frozen encoder and check the output shape.
# NOTE(review): 336x336 is hard-coded; presumably it should match
# config.dataset.image_size — TODO confirm.
test_img = torch.randn(2, 3, 336, 336).to(device)
test_vision_tokens = encode_images(test_img)
print(f"Vision tokens shape: {test_vision_tokens.shape}")
print(f"Expected: (2, num_tokens, {vision_token_dim})")

## 5. QwenVLM Model Class

**Main VLM wrapper with optional TRM recursion**

In [None]:
# Replaced inline class with imported refactored class
from models.trm_qwen_vlm import QwenVLM

print("‚úì QwenVLM class imported from src.models.trm_qwen_vlm")

## 6. Initialize Pretrained Qwen Decoder

In [None]:
# ========== CONFIGURATION ==========
# Toggles the optional TRM latent-recursion wrapper inside QwenVLM (False = baseline).
USE_TRM_RECURSION = False  # Start with baseline, then try True
NUM_TRM_LAYERS = 4         # Only used if TRM recursion enabled
NUM_RECURSION_STEPS = 4    # forwarded to QwenVLM (and logged to wandb); semantics defined there
CONFIDENCE_THRESHOLD = 0.75  # forwarded to QwenVLM; presumably an early-exit threshold — TODO confirm
# ====================================

print("="*60)
print("INITIALIZING QWEN DECODER")
print("="*60)

# Load Qwen decoder (model name, LoRA and 8-bit settings come from the experiment config)
print(f"\nLoading: {config.decoder.model_name}")
print(f"  LoRA: {config.decoder.use_lora}")
print(f"  8-bit: {config.decoder.load_in_8bit}")

qwen_decoder = QwenDecoder(
    model_name=config.decoder.model_name,
    load_in_8bit=config.decoder.load_in_8bit,
    load_in_4bit=False,
    use_lora=config.decoder.use_lora,
    # LoRA hyper-parameters fall back to these defaults when absent from the config.
    lora_r=config.decoder.get('lora_r', 32),
    lora_alpha=config.decoder.get('lora_alpha', 64),
    lora_dropout=config.decoder.get('lora_dropout', 0.1),
    # NOTE(review): presumably forwarded to HF/Accelerate to spread layers across GPUs — confirm in QwenDecoder.
    device_map="balanced",
)

print(f"\n‚úì Qwen decoder loaded")
print(f"  Hidden dim: {qwen_decoder.hidden_dim}")
print(f"  Vocab size: {qwen_decoder.vocab_size}")

# Create QwenVLM wrapper: takes vision tokens of width vision_token_dim (from the
# frozen encoder) plus the decoder, with TRM recursion optional.
print(f"\nCreating QwenVLM wrapper")
print(f"  Vision token dim: {vision_token_dim}")
print(f"  Use TRM recursion: {USE_TRM_RECURSION}")

model = QwenVLM(
    qwen_decoder=qwen_decoder,
    vision_token_dim=vision_token_dim,
    use_trm_recursion=USE_TRM_RECURSION,
    num_trm_layers=NUM_TRM_LAYERS,
    num_recursion_steps=NUM_RECURSION_STEPS,
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

print(f"\n‚úì QwenVLM model created")
print("="*60)

In [None]:
# 5. Dataset and Data Loader
print("\n" + "="*60)
print("INITIALIZING DATASETS")
print("="*60)

# Token budgets for each sample's tokenized question / answer.
MAX_QUESTION_LENGTH = 128
MAX_ANSWER_LENGTH = 256

# Build each transform pipeline once and hand the same objects to the datasets,
# so these globals are exactly what training/validation use. The probe cell
# further down applies val_transforms to a single image.
train_transforms = get_image_transforms(config.dataset.image_size, is_training=True)
val_transforms = get_image_transforms(config.dataset.image_size, is_training=False)

train_dataset = PixmoQADataset(
    parquet_path=config.dataset.train_parquet,
    tokenizer=qwen_decoder.tokenizer,
    image_transforms=train_transforms,
    max_question_length=MAX_QUESTION_LENGTH,
    max_answer_length=MAX_ANSWER_LENGTH,
)

val_dataset = PixmoQADataset(
    parquet_path=config.dataset.val_parquet,
    tokenizer=qwen_decoder.tokenizer,
    image_transforms=val_transforms,
    max_question_length=MAX_QUESTION_LENGTH,
    max_answer_length=MAX_ANSWER_LENGTH,
)

print(f"\nTrain dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")

# Collate function
def collate_fn(batch, pad_token_id: Optional[int] = None):
    """Collate PixmoQADataset items into one right-padded batch.

    Args:
        batch: list of dataset items with keys 'image', 'question_ids',
            'answer_ids', 'question_mask', 'answer_mask' (1-D tensors) and
            'answer' (raw string). If every item also has 'question', the raw
            questions are returned under 'questions'.
        pad_token_id: id used to pad question/answer ids. Defaults to the Qwen
            tokenizer's pad id, falling back to its EOS id when no pad token is
            configured (pad_sequence cannot pad with None).

    Returns:
        dict with 'images' (stacked), padded 'question_ids' / 'answer_ids'
        (B, max_len), padded 0/1 masks, 'answers' (list of str) and, when
        available, 'questions'.
    """
    from torch.nn.utils.rnn import pad_sequence

    if pad_token_id is None:
        tokenizer = qwen_decoder.tokenizer
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

    def _pad(key, value):
        # Right-pad the per-item 1-D tensors stored under `key` to a (B, max_len) tensor.
        return pad_sequence([item[key] for item in batch], batch_first=True, padding_value=value)

    collated = {
        'images': torch.stack([item['image'] for item in batch]),
        'question_ids': _pad('question_ids', pad_token_id),
        'answer_ids': _pad('answer_ids', pad_token_id),
        'question_mask': _pad('question_mask', 0),
        'answer_mask': _pad('answer_mask', 0),
        'answers': [item['answer'] for item in batch],
    }
    # Raw questions are only forwarded when the dataset provides them (debug helpers read them).
    if all('question' in item for item in batch):
        collated['questions'] = [item['question'] for item in batch]
    return collated

# Data Loaders: both splits share worker/pinning/collate settings and differ only
# in the dataset and whether batches are shuffled.
loader_kwargs = dict(
    batch_size=config.dataset.batch_size,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train loader: {len(train_loader)} batches")


## 7. Text-Only Sanity Check (BEFORE Training)

**Critical test**: Verify pretrained Qwen can generate coherent English

In [None]:
@torch.no_grad()
def text_only_sanity_check(decoder, prompts, max_tokens=20):
    """Sample short continuations from the bare decoder, with no vision prefix.

    Each prompt is tokenized onto the notebook ``device``, sampled with
    temperature 0.7 for up to ``max_tokens`` new tokens, and the decoded
    sequence (prompt included) is printed.

    Args:
        decoder: object exposing ``tokenizer`` and a HF-style ``generate``.
        prompts: iterable of prompt strings.
        max_tokens: maximum number of new tokens per prompt.
    """
    banner = "="*60
    print("\n" + banner)
    print("TEXT-ONLY SANITY CHECK (Before Training)")
    print(banner)

    tokenizer = decoder.tokenizer
    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors='pt').to(device)
        sampled = decoder.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
        )
        text = tokenizer.decode(sampled[0], skip_special_tokens=True)

        print(f"\nPrompt: {prompt}")
        print(f"Generated: {text}")
        print("-" * 60)

    print("\n‚úì If coherent English ‚Üí Decoder works!")
    print("‚úó If garbage ‚Üí Decoder loading issue")
    print(banner)

# Test prompts: two in the "Question: ... Answer:" format plus one plain completion.
# A correctly loaded pretrained decoder should continue each with fluent English.
test_prompts = [
    "Question: What is 2 + 2? Answer:",
    "Question: What color is the sky? Answer:",
    "The capital of France is",
]

text_only_sanity_check(qwen_decoder, test_prompts)

## 8. Parameter Audit

Verify which parameters are frozen vs trainable

In [None]:
print("="*60)
print("PARAMETER AUDIT")
print("="*60)

# (name, numel) for every parameter, split by requires_grad. The frozen aligned
# encoder is audited as well (names prefixed with "aligned_model.").
named_for_audit = list(model.named_parameters()) + [
    (f"aligned_model.{name}", param) for name, param in aligned_model.named_parameters()
]
trainable_params = [(name, param.numel()) for name, param in named_for_audit if param.requires_grad]
frozen_params = [(name, param.numel()) for name, param in named_for_audit if not param.requires_grad]
total_trainable = sum(count for _, count in trainable_params)
total_frozen = sum(count for _, count in frozen_params)

# List the first few entries of each group; the totals above cover all of them.
for header, entries, shown in (
    (f"\nüü¢ TRAINABLE ({len(trainable_params)} groups):", trainable_params, 20),
    (f"\nüî¥ FROZEN ({len(frozen_params)} groups):", frozen_params, 10),
):
    print(header)
    for name, count in entries[:shown]:
        print(f"  {name}: {count:,}")
    if len(entries) > shown:
        print(f"  ... and {len(entries) - shown} more")

total = total_trainable + total_frozen
print(f"\nüìä SUMMARY:")
print(f"  Total: {total:,}")
print(f"  Trainable: {total_trainable:,} ({100*total_trainable/total:.2f}%)")
print(f"  Frozen: {total_frozen:,} ({100*total_frozen/total:.2f}%)")

# Substring checks over the trainable names confirm the expected modules are unfrozen.
trainable_names = [n for n, _ in trainable_params]
print(f"\n‚úì VERIFICATION:")
print(f"  Aligned model frozen: {all(not p.requires_grad for p in aligned_model.parameters())}")
print(f"  Vision proj trainable: {'vision_proj' in str(trainable_names)}")

if USE_TRM_RECURSION:
    print(f"  TRM layers trainable: {'trm_layers' in str(trainable_params)}")
    print(f"  z_init trainable: {'z_init' in str(trainable_names)}")
else:
    print(f"  Qwen LoRA trainable: {'lora' in str(trainable_params).lower()}")

print("="*60)

## 9. Alternative Collate Function (manual right-padding)

In [None]:
def collate_qa_batch(batch, pad_token_id: Optional[int] = None):
    """Alternative collate: right-pads question/answer ids and rebuilds masks.

    Unlike ``collate_fn`` it does not need per-item masks from the dataset; both
    masks are derived from the unpadded sequence lengths.

    Args:
        batch: list of dataset items with 'image' and 1-D LongTensor
            'question_ids' / 'answer_ids'. If every item has 'answer', the raw
            answers are returned under 'answers'.
        pad_token_id: padding id. Defaults to the Qwen tokenizer's pad id, or
            its EOS id if no pad token is set. (The previous version referenced
            an undefined global ``tokenizer`` and raised NameError.)

    Returns:
        dict with 'images' (stacked when tensors, else the list), 'question_ids'
        (B, Lq), 'question_mask' (B, Lq), 'answer_ids' (B, La) and 'answer_mask'
        (B, La); masks are 1 on real tokens and 0 on padding.
    """
    if pad_token_id is None:
        tok = qwen_decoder.tokenizer
        pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    images = [item["image"] for item in batch]
    # Stack if already tensors (from dataset transform)
    if isinstance(images[0], torch.Tensor):
        images = torch.stack(images)

    questions = [item["question_ids"] for item in batch]
    answers = [item["answer_ids"] for item in batch]

    bs = len(batch)
    max_q_len = max(q.size(0) for q in questions)
    max_a_len = max(a.size(0) for a in answers)

    padded_questions = torch.full((bs, max_q_len), pad_token_id, dtype=torch.long)
    padded_answers = torch.full((bs, max_a_len), pad_token_id, dtype=torch.long)
    question_mask = torch.zeros((bs, max_q_len), dtype=torch.long)
    answer_mask = torch.zeros((bs, max_a_len), dtype=torch.long)

    # Right padding is standard for teacher-forced training with attention masks.
    for i, (q, a) in enumerate(zip(questions, answers)):
        padded_questions[i, :q.size(0)] = q
        question_mask[i, :q.size(0)] = 1
        padded_answers[i, :a.size(0)] = a
        answer_mask[i, :a.size(0)] = 1

    collated = {
        "images": images,
        "question_ids": padded_questions,
        "question_mask": question_mask,
        "answer_ids": padded_answers,
        "answer_mask": answer_mask,
    }
    if all("answer" in item for item in batch):
        collated["answers"] = [item["answer"] for item in batch]
    return collated

## 10. Evaluation Metrics

In [None]:
def normalize_answer(s: str) -> str:
    """Normalize an answer for EM/F1 comparison.

    Removes ASCII punctuation, lower-cases, drops the articles "a"/"an"/"the"
    and collapses runs of whitespace to single spaces.
    """
    s = ''.join(ch for ch in s if ch not in string.punctuation)
    s = s.lower().strip()
    return ' '.join(w for w in s.split() if w not in {'a', 'an', 'the'})

def compute_exact_match(pred: str, target: str) -> float:
    """Return 1.0 if the normalized prediction equals the normalized target, else 0.0."""
    return float(normalize_answer(pred) == normalize_answer(target))

def compute_f1(pred: str, target: str) -> float:
    """Token-level F1 between normalized prediction and target.

    If either side normalizes to no tokens, the score is 1.0 only when both do.
    """
    pred_tokens = normalize_answer(pred).split()
    target_tokens = normalize_answer(target).split()

    if not pred_tokens or not target_tokens:
        return float(pred_tokens == target_tokens)

    # Multiset intersection counts each shared token at most min(count_pred, count_target) times.
    num_common = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    if num_common == 0:
        return 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(target_tokens)
    return 2 * precision * recall / (precision + recall)

def evaluate_qa(predictions: List[str], targets: List[str]) -> Dict[str, float]:
    """Mean EM and F1 (both in percent) over paired predictions/targets.

    Raises:
        ValueError: if the two lists differ in length (``zip`` would otherwise
            silently drop the unmatched tail and skew the metrics).

    Returns:
        {'em': float, 'f1': float}; both NaN when there are no examples.
    """
    if len(predictions) != len(targets):
        raise ValueError(
            f"evaluate_qa got {len(predictions)} predictions but {len(targets)} targets"
        )
    if not predictions:
        return {'em': float('nan'), 'f1': float('nan')}

    em_scores = [compute_exact_match(p, t) for p, t in zip(predictions, targets)]
    f1_scores = [compute_f1(p, t) for p, t in zip(predictions, targets)]

    return {'em': float(np.mean(em_scores)) * 100, 'f1': float(np.mean(f1_scores)) * 100}

print("‚úì Evaluation metrics defined")

## 11. Debug Helper Functions

In [None]:
def debug_first_batch(batch, vision_tokens, outputs, tokenizer):
    """Print shapes, one decoded example and token accounting for a training batch.

    Args:
        batch: collated batch dict. Requires 'images', 'question_ids' and
            'answer_ids'. Raw strings under 'questions' / 'answers' are used
            when present; otherwise the first example is decoded from its ids.
        vision_tokens: encoder output; dim 1 is taken as the image-token count.
        outputs: model forward output, either a mapping or an attribute-style
            object; 'loss' and optional 'logits' / 'confidence' are read.
        tokenizer: object with ``decode`` and ``pad_token_id``.
    """
    def _field(name):
        # The training loop reads ``outputs.loss`` (attribute style); accept both forms.
        if isinstance(outputs, dict):
            return outputs.get(name)
        return getattr(outputs, name, None)

    first_q_ids = batch['question_ids'][0]
    first_a_ids = batch['answer_ids'][0]
    raw_questions = batch.get('questions')
    raw_answers = batch.get('answers')
    q_text = raw_questions[0] if raw_questions else tokenizer.decode(first_q_ids, skip_special_tokens=True)
    a_text = raw_answers[0] if raw_answers else tokenizer.decode(first_a_ids, skip_special_tokens=True)

    print("\n" + "="*60)
    print("FIRST BATCH DEBUG")
    print("="*60)
    
    print(f"\nüì¶ Shapes:")
    print(f"  Images: {batch['images'].shape}")
    print(f"  Vision tokens: {vision_tokens.shape}")
    print(f"  Question IDs: {batch['question_ids'].shape}")
    print(f"  Answer IDs: {batch['answer_ids'].shape}")
    
    print(f"\nüìù First example:")
    print(f"  Q: {q_text[:100]}...")
    print(f"  A: {a_text[:100]}...")
    
    print(f"\nüî§ Decoded tokens:")
    q_dec = tokenizer.decode(first_q_ids, skip_special_tokens=False)
    a_dec = tokenizer.decode(first_a_ids, skip_special_tokens=False)
    print(f"  Question: {q_dec[:150]}...")
    print(f"  Answer: {a_dec[:100]}...")
    
    print(f"\nüìä Outputs:")
    loss = _field('loss')
    if loss is not None:
        print(f"  Loss: {float(loss.item() if hasattr(loss, 'item') else loss):.4f}")
    logits = _field('logits')
    if logits is not None:
        print(f"  Logits: {logits.shape}")
    confidence = _field('confidence')
    if confidence is not None:
        print(f"  Confidence: {confidence.mean().item():.3f}")
    
    # Token counts for the first example (padding excluded)
    pad_id = tokenizer.pad_token_id
    num_q = (first_q_ids != pad_id).sum().item()
    num_a = (first_a_ids != pad_id).sum().item()
    num_img = vision_tokens.shape[1]
    total_tokens = num_img + num_q + num_a
    
    print(f"\nüìè Token counts:")
    print(f"  Vision: {num_img}")
    print(f"  Question (non-pad): {num_q}")
    print(f"  Answer (non-pad): {num_a}")
    print(f"  Supervised: {num_a} ({100*num_a/max(total_tokens, 1):.1f}%)")
    
    print(f"\n‚úì Expected:")
    print(f"  - Loss < 10 (pretrained range)")
    print(f"  - Supervised = answer only")
    print("="*60 + "\n")

print("‚úì Debug helpers defined")

## 12. Training Setup

In [None]:
# Training config
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.05   # fraction of total_steps spent in linear warmup
MAX_GRAD_NORM = 1.0   # intended gradient-clipping threshold
LOG_EVERY = 20        # intended logging interval, in optimizer steps

# Optimizer over the trainable subset only (presumably vision projection / LoRA /
# TRM). Frozen parameters never receive gradients, so listing them only adds
# bookkeeping. An empty list makes AdamW raise, which flags a broken freeze early.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_parameters,
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=WEIGHT_DECAY,
)

# Scheduler
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)

def get_lr_scheduler(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Linear warmup followed by cosine decay down to ``min_lr_ratio`` x base LR.

    LR multiplier:
      * ``step / warmup_steps`` while ``step < warmup_steps``;
      * afterwards ``min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * progress))``.
        Here ``progress`` runs from 0 to 1 between ``warmup_steps`` and
        ``total_steps``. It is clamped at 1, so stepping past ``total_steps``
        holds the LR at the floor instead of letting the cosine climb back up.

    The decay length is floored at 1 step, so ``total_steps == warmup_steps``
    (e.g. an empty loader) no longer divides by zero.

    Args:
        optimizer: torch optimizer to schedule.
        warmup_steps: number of linear-warmup steps (0 disables warmup).
        total_steps: step at which the cosine reaches its floor.
        min_lr_ratio: final LR as a fraction of the base LR (default 0.1).

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1 + np.cos(np.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Warmup + cosine LR multiplier sized for the full NUM_EPOCHS schedule.
# NOTE(review): the training-loop cell below creates its own optimizer and never
# calls scheduler.step(), so this scheduler does not drive that run.
scheduler = get_lr_scheduler(optimizer, warmup_steps, total_steps)

print(f"Training config:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  LR: {LEARNING_RATE}")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")

## 13. Training Loop

In [None]:
# Initialize wandb
USE_WANDB = True  # later cells check this flag before calling wandb
if USE_WANDB:
    # Run name distinguishes baseline and TRM runs within the same project.
    run_name = f"qwen_vlm_{'trm' if USE_TRM_RECURSION else 'baseline'}"
    wandb.init(
        project="edge_glass_qwen_vlm",
        name=run_name,
        config={
            'use_trm_recursion': USE_TRM_RECURSION,
            'num_trm_layers': NUM_TRM_LAYERS if USE_TRM_RECURSION else 0,
            'num_recursion_steps': NUM_RECURSION_STEPS if USE_TRM_RECURSION else 0,
            'learning_rate': LEARNING_RATE,
            'batch_size': config.dataset.batch_size,
            'decoder': config.decoder.model_name,
        }
    )

# Training state
# NOTE(review): nothing in this notebook updates best_val_loss or
# history['val_loss'], and first_batch_debugged is never read.
global_step = 0
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': []}
first_batch_debugged = False

# Checkpoint dir: one sub-folder per mode under CKPT_ROOT, created if missing.
ckpt_dir = CKPT_ROOT / f"qwen_vlm_qa_{'trm' if USE_TRM_RECURSION else 'baseline'}"
ckpt_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# 4. Training Loop — short verification run
#
# Runs `max_steps` optimizer steps over train_loader at a constant learning rate.
# The warmup+cosine `scheduler` from the Training Setup cell is sized for
# NUM_EPOCHS full epochs (its warmup alone would keep the LR near zero for a
# 100-step run), so it is deliberately not stepped here.
max_steps = 100        # Short run for verification
SAVE_LAST_EVERY = 50   # steps between periodic checkpoint_last.pt writes

# Fresh optimizer (so re-running the cell starts clean) over trainable params only,
# using the hyper-parameters from the Training Setup cell.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_parameters,
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=WEIGHT_DECAY,
)


def _trainable_snapshot():
    """CPU copy of every trainable parameter, keyed by its named_parameters() name."""
    return {
        name: param.detach().to('cpu', copy=True)
        for name, param in model.named_parameters()
        if param.requires_grad
    }


def _save_checkpoint(path, step, loss_value, trainable_override=None):
    """Save the full model state_dict + optimizer state (loadable with strict=True).

    If ``trainable_override`` is given, its tensors replace the matching keys of
    the current state_dict (used to write the best snapshot). The optimizer state
    saved is always the current one.
    """
    state = model.state_dict()
    if trainable_override:
        state.update({k: v for k, v in trainable_override.items() if k in state})
    torch.save({
        'step': step,
        'model_state_dict': state,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }, path)


steps = 0
best_loss = float('inf')
best_step = None
best_trainable_state = None
current_loss = float('nan')  # defined even if the loader yields no batches

model.train()
progress_bar = tqdm(total=max_steps)
print(f"Training started for {max_steps} steps...")

try:
    for batch in train_loader:
        images = batch["images"].to(device)
        with torch.no_grad():
            vision_tokens = encode_images(images)
        question_ids = batch["question_ids"].to(device)
        answer_ids = batch["answer_ids"].to(device)
        answer_mask = batch["answer_mask"].to(device)

        # Forward pass
        outputs = model(
            vision_tokens=vision_tokens,
            question_ids=question_ids,
            answer_ids=answer_ids,
            answer_mask=answer_mask,
        )
        loss = outputs.loss

        # Backward + clipped update
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        current_loss = loss.item()
        history['train_loss'].append(current_loss)
        global_step += 1

        # "Best" = lowest single-batch training loss. Keep it in memory and write it
        # once at the end; saving a full 7B state_dict on every improvement would
        # dominate the run time.
        if current_loss < best_loss:
            best_loss = current_loss
            best_step = steps
            best_trainable_state = _trainable_snapshot()

        if steps > 0 and steps % SAVE_LAST_EVERY == 0:
            _save_checkpoint(ckpt_dir / 'checkpoint_last.pt', steps, current_loss)

        if USE_WANDB and steps % LOG_EVERY == 0:
            wandb.log({
                'train/loss': current_loss,
                'train/grad_norm': float(grad_norm),
                'train/lr': optimizer.param_groups[0]['lr'],
                'train/global_step': global_step,
            })

        progress_bar.set_description(f"Loss: {current_loss:.4f} | Best: {best_loss:.4f}")
        progress_bar.update(1)

        steps += 1
        if steps >= max_steps:
            break
except KeyboardInterrupt:
    print(f"\nInterrupted after {steps} steps; saving checkpoints for the completed steps.")
finally:
    progress_bar.close()

# Final save (skipped when no step completed, so nothing is written from an untrained state)
if steps == 0:
    print("No training steps completed; no checkpoints written.")
else:
    if best_trainable_state is not None:
        _save_checkpoint(ckpt_dir / 'checkpoint_best.pt', best_step, best_loss, best_trainable_state)
    _save_checkpoint(ckpt_dir / 'checkpoint_last.pt', steps, current_loss)
    print(f"Training complete. Best loss: {best_loss:.4f}")


## 14. Full Evaluation

In [None]:
# Restore the checkpoint_best.pt written by the training cell, if one exists.
ckpt_path = ckpt_dir / 'checkpoint_best.pt'
if ckpt_path.exists():
    try:
        print(f"Loading best checkpoint from {ckpt_path}")
        # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
        best_ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(best_ckpt['model_state_dict'])
        print("‚úì Loaded best checkpoint")
    except RuntimeError as e:
        # load_state_dict reports key/shape mismatches as RuntimeError.
        print(f"‚ö† Error loading checkpoint: {e}")
        print("This is likely due to a size mismatch from an old checkpoint.")
        print("Continuing with current model state...")
else:
    print(f"‚ö† Checkpoint not found at {ckpt_path}. Using current model state.")

model.eval()


In [None]:
# Complex VLM question probe on the curve image
import os
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
import torch

model.eval()
# Override the default with the PROBE_IMAGE_PATH environment variable when running
# outside the original cluster account.
probe_image_path = Path(os.environ.get(
    "PROBE_IMAGE_PATH",
    "/home/hice1/vchopra37/scratch/projects/edge_glass/complex_curve.png",
))
if not probe_image_path.exists():
    raise FileNotFoundError(f"Probe image not found: {probe_image_path} (set PROBE_IMAGE_PATH)")
probe_question = (
    "Provide a short narrative interpretation of the plotted curve: describe its overall trend, where the slope changes,"
    " and what that implies about acceleration, saturation, or decay in the underlying relationship."
)
# Show the image itself; display() on a Path only renders its repr string.
display(Image.open(probe_image_path))

In [None]:
# Preprocess the probe image, encode it, and generate an answer with per-token stats.
# NOTE(review): T.ToTensor() runs before val_transforms. If get_image_transforms
# already contains ToTensor, this will fail on tensor input — TODO confirm what
# val_transforms expects.
probe_image = Image.open(probe_image_path).convert("RGB")
probe_tensor = val_transforms(T.ToTensor()(probe_image)).unsqueeze(0).to(device)
probe_tokens = encode_images(probe_tensor)

probe_question_ids = qwen_decoder.tokenizer(
    probe_question, return_tensors='pt', add_special_tokens=True
).input_ids.to(device)

probe_gen = model.generate(
    vision_tokens=probe_tokens,
    question_ids=probe_question_ids,
    max_new_tokens=96,
    temperature=0.7,
    use_confidence=True,
    return_stats=True,
)

# generate() may return a bare tensor of ids or a dict with optional stats.
confidence_trace = None
recursion_steps = None
if isinstance(probe_gen, torch.Tensor):
    gen_ids = probe_gen[0].detach().cpu().tolist()
else:
    gen_ids = probe_gen["predictions"][0].detach().cpu().tolist()
    raw_confidences = probe_gen.get("confidences")
    if raw_confidences is not None:
        confidence_trace = raw_confidences[0].detach().cpu().tolist()
    raw_recursion = probe_gen.get("recursion_steps")
    if raw_recursion is not None:
        recursion_steps = raw_recursion[0].detach().cpu().tolist()

probe_answer = qwen_decoder.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

print("Question:", probe_question)
print("Answer:", probe_answer)
if confidence_trace is not None:
    print("Confidence trace:", [round(float(c), 3) for c in confidence_trace])
if recursion_steps is not None:
    print("Recursion steps per token:", recursion_steps)


In [None]:
print("\n" + "="*60)
print("FULL EVALUATION")
print("="*60)
print(f"Mode: {'TRM' if USE_TRM_RECURSION else 'Baseline'}")
print("="*60)

all_predictions = []
all_targets = []
all_questions = []  # decoded from each batch so samples always line up with predictions
all_confidences = []
all_recursion_steps = []

model.eval()
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating"):
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answers = batch['answers']
        
        vision_tokens = encode_images(images)
        
        # temperature=0.0 presumably selects greedy decoding in QwenVLM.generate — TODO confirm.
        # NOTE(review): question_ids are right-padded and no question mask is passed;
        # confirm generate() handles padded prompts.
        gen_outputs = model.generate(
            vision_tokens,
            question_ids,
            max_new_tokens=32,
            temperature=0.0,
            return_stats=True,
        )
        
        if isinstance(gen_outputs, dict):
            generated_ids = gen_outputs['predictions']
            if gen_outputs.get('confidences') is not None:
                all_confidences.extend(gen_outputs['confidences'].mean(dim=1).cpu().tolist())
            if gen_outputs.get('recursion_steps') is not None:
                all_recursion_steps.extend(gen_outputs['recursion_steps'].mean(dim=1).cpu().tolist())
        else:
            generated_ids = gen_outputs
        
        predictions = qwen_decoder.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        
        all_predictions.extend(predictions)
        all_targets.extend(answers)
        all_questions.extend(
            qwen_decoder.tokenizer.batch_decode(batch['question_ids'], skip_special_tokens=True)
        )

# Metrics
metrics = evaluate_qa(all_predictions, all_targets)

print("\n" + "="*60)
print("RESULTS")
print("="*60)
print(f"EM: {metrics['em']:.2f}%")
print(f"F1: {metrics['f1']:.2f}%")

if all_confidences:
    print(f"\nConfidence: {np.mean(all_confidences):.3f} ¬± {np.std(all_confidences):.3f}")

if all_recursion_steps:
    print(f"\nRecursion: {np.mean(all_recursion_steps):.2f} steps/seq")

print("="*60)

if USE_WANDB:
    wandb.log({'val/em_final': metrics['em'], 'val/f1_final': metrics['f1']})

# Sample predictions (questions come from the same batches as the predictions,
# so no dependency on dataset internals or loader ordering)
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60)

for i in range(min(10, len(all_predictions))):
    print(f"\n[{i+1}]")
    print(f"  Q: {all_questions[i][:80]}...")
    print(f"  Target: {all_targets[i][:80]}...")
    print(f"  Predicted: {all_predictions[i][:80]}...")
    print(f"  EM: {compute_exact_match(all_predictions[i], all_targets[i])}")
    print(f"  F1: {compute_f1(all_predictions[i], all_targets[i]):.3f}")

print("\n" + "="*60)

## 15. Summary

In [None]:
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

print(f"\nüîß Config:")
print(f"  Decoder: {config.decoder.model_name}")
print(f"  LoRA: {config.decoder.use_lora}")
print(f"  TRM Recursion: {USE_TRM_RECURSION}")

print(f"\nüìä Training:")
print(f"  Epochs: {NUM_EPOCHS}")
# The training cell tracks per-batch training loss (`best_loss`) over `steps`
# steps. Nothing in this notebook updates best_val_loss (it stays inf), so
# report it only when it is finite.
if 'steps' in globals():
    print(f"  Steps run: {steps}")
if 'best_loss' in globals():
    print(f"  Best train-batch loss: {best_loss:.4f}")
if np.isfinite(best_val_loss):
    print(f"  Best val loss: {best_val_loss:.4f}")
else:
    print("  Best val loss: N/A (validation loss not computed)")

print(f"\nüìà Eval:")
print(f"  EM: {metrics['em']:.2f}%")
print(f"  F1: {metrics['f1']:.2f}%")

print(f"\nüì¶ Dataset:")
print(f"  Train: {len(train_dataset):,}")
print(f"  Val: {len(val_dataset):,}")

print(f"\nüíæ Output:")
print(f"  {ckpt_dir / 'checkpoint_best.pt'}")

print("\n" + "="*60)
print("KEY IMPROVEMENTS")
print("="*60)
print("‚úÖ Pretrained Qwen decoder (vs random 34M TRM)")
print("‚úÖ Proper autoregressive training (prefix_embeds + -100 masking)")
print("‚úÖ Baseline mode for comparison")
print("‚úÖ Debug instrumentation (sanity check, param audit, first-batch)")

print("\n" + "="*60)
print("NEXT STEPS")
print("="*60)
print("1. If EM/F1 still low:")
print("   - Check text-only generation")
print("   - Check first-batch debug")
print("   - Increase epochs/LR")
print("\n2. If baseline works:")
print("   - Set USE_TRM_RECURSION = True")
print("   - Compare vs baseline")
print("\n3. Expected baseline: EM > 15%, F1 > 25%")
print("="*60)

print("\n‚úì Complete!")