# VLM Training with PixMo QA Dataset

This notebook implements a **Vision-Language Model for Question Answering** using:
1. Modular vision encoder (CLIP + Perceiver + MRL) from `edge_glass_modular/src/encoders`
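2. A decoder from `edge_glass_modular/src/decoders`: the cells below train a tiny TRM-style recursive decoder (`decoders.trm`); the Qwen decoder with LoRA is imported as an alternative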
3. PixMo QA dataset with question-answer pairs
4. Proper modular design following the edge_glass_modular architecture

## Architecture:

```
Image (B, 3, 336, 336)
  ↓
Vision Encoder (frozen aligned model)
  ↓ (B, num_latents, hidden_dim)
Projection to Qwen hidden dim
  ↓ (B, num_latents, qwen_dim)
Qwen Decoder with LoRA (trainable)
  ↓
Token Layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
  ↓
Loss on answer tokens only
```

## Key Features:
- Modular design using imports from `edge_glass_modular/src`
- Frozen aligned vision encoder
- Qwen2.5 decoder with LoRA fine-tuning
- Real QA dataset (not synthetic)
- Proper configuration management

## 1. Setup and Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the project's `src/` package importable from this notebook's directory.
SRC_DIR = Path.cwd().parent / "src"
sys.path.insert(0, str(SRC_DIR))
SRC_DIR

PosixPath('/storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src')

In [3]:
# Import standard libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import wandb
from pathlib import Path
from typing import Optional, Dict, List
from collections import Counter
import string
import warnings

# Import modular components from edge_glass_modular
from config import load_config
from encoders.vision import VisionEncoder
from decoders.qwen import QwenDecoder
from data.dataset_builder import PixmoParquetImageTextDataset
from data.transforms import get_image_transforms
from models.alignment import MultimodalAlignmentModel

from transformers import AutoTokenizer


# Set up matplotlib
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Seed every RNG this notebook draws from, so runs are repeatable.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

device = torch.device('cuda' if cuda_ok else 'cpu')
print(f"Using device: {device}")



KeyboardInterrupt: 

## 2. Load Configuration

Load the experiment configuration from YAML file.

In [None]:
# Experiment config for the TRM VLM QA run.
config_path = "../configs/trm_vlm_qa.yaml"
config = load_config(config_path)

print(f"Loaded config: {config.name}")

# Fields are read lazily, one per line, in display order.
print(f"\nDataset:")
for label, field in (
    ("Train parquet", "train_parquet"),
    ("Val parquet", "val_parquet"),
    ("Test parquet", "test_parquet"),
    ("Image size", "image_size"),
    ("Max question length", "max_question_length"),
    ("Max answer length", "max_answer_length"),
    ("Batch size", "batch_size"),
):
    print(f"  {label}: {getattr(config.dataset, field)}")

print(f"\nDecoder:")
for label, field in (
    ("Type", "type"),
    ("Model", "model_name"),
    ("Use LoRA", "use_lora"),
    ("Load in 8bit", "load_in_8bit"),
):
    print(f"  {label}: {getattr(config.decoder, field)}")

## 3. Load Aligned Vision Encoder

Load the pretrained Perceiver+MRL alignment model and freeze it.

In [None]:
# Load alignment config
alignment_config_path = "../configs/pixmo_alignment.yaml"
alignment_config = load_config(alignment_config_path)

# Build the aligned model and restore its trained weights.
aligned_model = MultimodalAlignmentModel(alignment_config).to(device)

checkpoint_path = "checkpoints/pixmo_alignment/checkpoint_best.pt"
# weights_only=False unpickles arbitrary Python objects stored in the file;
# only load checkpoints you produced yourself.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
aligned_model.load_state_dict(checkpoint['model_state_dict'])
aligned_model.eval()

# Freeze all parameters: the encoder is used as a fixed feature extractor.
for param in aligned_model.parameters():
    param.requires_grad = False

# Report the encoder's configured output layout rather than a hard-coded shape.
vision_cfg = aligned_model.config.vision_encoder
if getattr(aligned_model.vision_encoder, "use_perceiver", False):
    num_output_tokens = vision_cfg.perceiver_num_latents
else:
    num_output_tokens = "seq_len"

print(f"Loaded aligned model from {checkpoint_path}")
print(f"  Epoch: {checkpoint['epoch']}")
print(f"  Val loss: {checkpoint['best_val_loss']:.4f}")
print(f"  Vision encoder output: (B, {num_output_tokens}, {vision_cfg.projection_dim})")

## 4. Extract Vision Encoder Method

Create a clean interface to get vision embeddings.

In [None]:
@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    """Run the frozen aligned vision encoder and return its token sequence.

    Args:
        images: image batch, presumably (B, 3, H, W) as in the smoke test below.

    Returns:
        The encoder's ``sequence`` output (B, num_tokens, dim).

    Raises:
        ValueError: if the encoder produced no sequence output.
    """
    sequence = aligned_model.vision_encoder(images, return_sequence=True).sequence
    if sequence is None:
        raise ValueError("Vision encoder did not return sequence embeddings")
    return sequence

# Smoke-test the encoder on random images of the training resolution.
test_img = torch.randn(2, 3, 336, 336).to(device)
test_vision_tokens = encode_images(test_img)
print(f"Vision tokens shape: {test_vision_tokens.shape}")

# With a Perceiver the token count is fixed by the config; otherwise echo what we got.
uses_perceiver = getattr(aligned_model.vision_encoder, "use_perceiver", False)
expected_tokens = (
    aligned_model.config.vision_encoder.perceiver_num_latents
    if uses_perceiver
    else test_vision_tokens.shape[1]
)
print(f"Expected: ({test_img.shape[0]}, {expected_tokens}, {aligned_model.config.vision_encoder.projection_dim})")

## 5. Implement Plain Tiny Decoder Baseline

First, implement a simple baseline decoder without TRM recursion.

In [None]:
from decoders.trm import TRMConfig, TRMDecoder

class TinyVLMDecoder(nn.Module):
    """Plain tiny decoder baseline for VLM (no TRM recursion).

    Architecture:
        - Projects vision tokens from ``vision_token_dim`` (4096 by default) -> ``hidden_dim``
        - Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
        - Loss only on answer tokens (image/question positions are labelled -100)

    NOTE(review): no attention mask is passed to ``self.decoder.layers``; causal
    masking is assumed to happen inside the ``TRMDecoder`` layers — confirm in
    ``decoders.trm``. Question padding sits between the question and the answer
    and is attended to like any other token.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        vision_token_dim: int = 4096,
        max_seq_len: int = 256,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.vision_token_dim = vision_token_dim

        # Project vision tokens to decoder dim
        self.vision_proj = nn.Linear(vision_token_dim, hidden_dim)

        # TRM decoder supplies embed_tokens / layers / norm / lm_head
        trm_config = TRMConfig(
            vocab_size=vocab_size,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            max_seq_len=max_seq_len,
        )
        self.decoder = TRMDecoder(trm_config)

    def _run_decoder(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Run the decoder stack on embeddings (B, L, d) and return logits (B, L, vocab)."""
        hidden_states = embeddings
        for layer in self.decoder.layers:
            hidden_states = layer(hidden_states)
        return self.decoder.lm_head(self.decoder.norm(hidden_states))

    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, vision_token_dim)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
        answer_attention_mask: Optional[torch.Tensor] = None,  # (B, L_a), 1 = real token
    ):
        """Teacher-forced forward pass with answer-only next-token loss.

        Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS].

        Args:
            vision_tokens: encoder tokens to condition on.
            question_ids: question token ids.
            answer_ids: answer token ids (the supervised targets).
            answer_attention_mask: optional mask; answer positions where it is 0
                are excluded from the loss (e.g. padding). When None every answer
                position is supervised, as before.

        Returns:
            dict with 'loss' (scalar) and 'logits' (B, K_img + L_q + L_a, vocab).
        """
        batch_size, num_img_tokens = vision_tokens.shape[0], vision_tokens.shape[1]

        # Build the input sequence: [vision | question | answer]
        vision_emb = self.vision_proj(vision_tokens)            # (B, K_img, d_dec)
        question_emb = self.decoder.embed_tokens(question_ids)  # (B, L_q, d_dec)
        answer_emb = self.decoder.embed_tokens(answer_ids)      # (B, L_a, d_dec)
        full_sequence = torch.cat([vision_emb, question_emb, answer_emb], dim=1)

        # Labels: -100 (ignored) for image/question tokens and masked answer tokens.
        answer_labels = answer_ids
        if answer_attention_mask is not None:
            answer_labels = answer_ids.masked_fill(answer_attention_mask == 0, -100)
        img_labels = torch.full(
            (batch_size, num_img_tokens),
            fill_value=-100,
            dtype=torch.long,
            device=vision_tokens.device,
        )
        question_labels = torch.full_like(question_ids, fill_value=-100)
        full_labels = torch.cat([img_labels, question_labels, answer_labels], dim=1)

        logits = self._run_decoder(full_sequence)

        # Position t predicts token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = full_labels[:, 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            'loss': loss,
            'logits': logits,
        }

    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.7,
        do_sample: bool = False,
    ):
        """Generate answer tokens autoregressively.

        Args:
            vision_tokens: encoder tokens to condition on.
            question_ids: question token ids.
            max_new_tokens: number of tokens to emit (no EOS stopping).
            temperature: logit divisor; only affects the result when ``do_sample``.
            do_sample: sample from softmax(logits / temperature) instead of greedy
                argmax. Defaults to greedy, matching the previous behavior.

        Returns:
            (B, max_new_tokens) long tensor of generated ids.
        """
        batch_size = vision_tokens.shape[0]
        if max_new_tokens <= 0:
            return torch.empty(batch_size, 0, dtype=torch.long, device=vision_tokens.device)

        # Start with image + question
        current_emb = torch.cat(
            [self.vision_proj(vision_tokens), self.decoder.embed_tokens(question_ids)], dim=1
        )
        generated_ids = []

        for _ in range(max_new_tokens):
            # No KV cache: the whole prefix is re-encoded each step (simple, O(n^2)).
            next_token_logits = self._run_decoder(current_emb)[:, -1, :] / temperature
            if do_sample:
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                # Greedy decoding; argmax is unaffected by positive temperature.
                next_token = torch.argmax(next_token_logits, dim=-1)

            generated_ids.append(next_token)
            next_emb = self.decoder.embed_tokens(next_token.unsqueeze(1))
            current_emb = torch.cat([current_emb, next_emb], dim=1)

        return torch.stack(generated_ids, dim=1)

## 6. Implement TRM-Style Recursive Decoder

Now implement the TRM version with latent recursion.

In [None]:
class TRMVLMWithConfidence(nn.Module):
    """TRM VLM decoder with confidence-based recursive refinement.

    Key features:
    - Consumes tokens from the frozen aligned vision encoder
    - Projects vision tokens to the TRM hidden dim
    - Latent recursion: the answer stream ``y`` and latent ``z`` are refined
      repeatedly, conditioned on a fixed context ``x`` = [vision | question]
      (the same in ``forward`` and ``generate``)
    - Confidence-based early stopping during generation

    NOTE(review): in ``forward`` position ``i`` of the teacher-forced answer
    predicts ``answer_ids[i + 1]``, so the first answer token is never
    supervised, whereas ``generate`` predicts the first token from a zero
    "blank" embedding. Confirm this train/inference mismatch is acceptable.
    """

    def __init__(
        self,
        vocab_size: int,
        vision_token_dim: int = 4096,
        hidden_dim: int = 512,
        num_layers: int = 2,
        num_heads: int = 8,
        num_inner_steps: int = 4,
        confidence_threshold: float = 0.8,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_inner_steps = num_inner_steps
        self.confidence_threshold = confidence_threshold

        # Project vision tokens to decoder dim
        self.vision_proj = nn.Linear(vision_token_dim, hidden_dim)

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)

        # Tiny transformer for recursion (use TRM components)
        from decoders.trm import TRMConfig, TRMLayer

        trm_config = TRMConfig(
            vocab_size=vocab_size,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            max_seq_len=1024,
        )

        self.tiny_transformer = nn.ModuleList([
            TRMLayer(trm_config) for _ in range(num_layers)
        ])

        from decoders.trm import RMSNorm
        self.norm = RMSNorm(hidden_dim)

        # LM head, weight-tied to the token embeddings
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

        # Learned initial z state
        self.z_init = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)

    def latent_recursion(
        self,
        x: torch.Tensor,  # Context: (B, L_ctx, d)
        y: torch.Tensor,  # Answer: (B, L_ans, d)
        z: torch.Tensor,  # Latent: (B, L_ans, d)
    ):
        """Single step of latent recursion.

        Runs the tiny transformer over the concatenation [x, y, z] and splits the
        output back into the three segments (same lengths as the inputs).
        """
        hidden = torch.cat([x, y, z], dim=1)  # (B, L_ctx + 2*L_ans, d)
        for layer in self.tiny_transformer:
            hidden = layer(hidden)

        L_ctx = x.shape[1]
        L_ans = y.shape[1]
        x_out = hidden[:, :L_ctx, :]
        y_out = hidden[:, L_ctx:L_ctx + L_ans, :]
        z_out = hidden[:, L_ctx + L_ans:, :]
        return x_out, y_out, z_out

    def compute_confidence(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute confidence score from logits.

        Args:
            logits: (B, L, vocab_size)

        Returns:
            confidence: (B,) - mean over positions of the max softmax probability
        """
        probs = torch.softmax(logits, dim=-1)
        max_probs = torch.max(probs, dim=-1)[0]  # (B, L)
        return torch.mean(max_probs, dim=-1)     # (B,)

    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, vision_token_dim)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
        num_recursion_steps: Optional[int] = None,
        answer_attention_mask: Optional[torch.Tensor] = None,  # (B, L_a), 1 = real token
    ):
        """Teacher-forced forward pass with TRM recursion.

        Args:
            vision_tokens: encoder tokens to condition on.
            question_ids: question token ids.
            answer_ids: answer token ids (the supervised targets).
            num_recursion_steps: overrides ``self.num_inner_steps`` when given.
            answer_attention_mask: optional mask; answer positions where it is 0
                are excluded from the loss (e.g. padding). When None every answer
                position is supervised.

        Returns:
            dict with 'loss' (scalar), 'logits' (B, L_a, vocab) and
            'confidence' (B,).
        """
        batch_size = vision_tokens.shape[0]
        answer_len = answer_ids.shape[1]
        n_steps = self.num_inner_steps if num_recursion_steps is None else num_recursion_steps

        # Context x = [vision | question]; it is the fixed conditioning input.
        x = torch.cat([self.vision_proj(vision_tokens), self.embed_tokens(question_ids)], dim=1)

        # Teacher-forced answer embeddings and initial latent.
        y = self.embed_tokens(answer_ids)                      # (B, L_ans, d)
        z = self.z_init.expand(batch_size, answer_len, -1)     # (B, L_ans, d)

        for _ in range(n_steps):
            # Only y and z are refined; re-feeding the transformed x here (as the
            # earlier version did) diverged from generate(), which keeps x fixed.
            _, y, z = self.latent_recursion(x, y, z)

        logits = self.lm_head(self.norm(y))  # (B, L_ans, vocab_size)

        labels = answer_ids
        if answer_attention_mask is not None:
            labels = answer_ids.masked_fill(answer_attention_mask == 0, -100)

        # Position t predicts answer token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            'loss': loss,
            'logits': logits,
            'confidence': self.compute_confidence(logits),
        }

    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.7,
        use_confidence: bool = True,
        do_sample: bool = False,
    ):
        """Generate answer tokens with confidence-gated recursion depth.

        After the first token, one recursion step is run as a probe; if the mean
        batch confidence is below ``confidence_threshold`` the remaining
        ``num_inner_steps - 1`` steps are run as well. No EOS stopping is done.

        Args:
            vision_tokens: encoder tokens to condition on.
            question_ids: question token ids.
            max_new_tokens: number of tokens to emit.
            temperature: logit divisor; only changes the chosen token when
                ``do_sample`` is True.
            use_confidence: enable the confidence-gated early exit.
            do_sample: sample from softmax(logits / temperature) instead of greedy
                argmax (default greedy, matching the previous behavior).

        Returns:
            dict with 'ids' (B, max_new_tokens) long tensor, 'confidence'
            (per-step batch-mean floats) and 'recursion_steps' (per-step ints).
        """
        batch_size = vision_tokens.shape[0]

        # Fixed context
        x = torch.cat([self.vision_proj(vision_tokens), self.embed_tokens(question_ids)], dim=1)

        generated_ids = []
        confidence_scores = []
        recursion_steps_used = []

        for _step in range(max_new_tokens):
            if generated_ids:
                y = self.embed_tokens(torch.stack(generated_ids, dim=1))
            else:
                # First token: start from a zero "blank" answer embedding.
                y = torch.zeros(batch_size, 1, self.hidden_dim, device=x.device, dtype=x.dtype)
            z = self.z_init.expand(batch_size, y.shape[1], -1)

            if use_confidence and generated_ids:
                # Probe with one step, then decide whether deeper recursion is needed.
                _, y, z = self.latent_recursion(x, y, z)
                probe_logits = self.lm_head(self.norm(y)[:, -1, :])
                probe_conf = self.compute_confidence(probe_logits.unsqueeze(1))
                if probe_conf.mean() >= self.confidence_threshold:
                    num_steps = 1
                else:
                    for _ in range(self.num_inner_steps - 1):
                        _, y, z = self.latent_recursion(x, y, z)
                    num_steps = self.num_inner_steps
            else:
                # Fixed recursion depth
                for _ in range(self.num_inner_steps):
                    _, y, z = self.latent_recursion(x, y, z)
                num_steps = self.num_inner_steps
            recursion_steps_used.append(num_steps)

            next_logits = self.lm_head(self.norm(y)[:, -1, :]) / temperature
            if do_sample:
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_token = torch.argmax(next_logits, dim=-1)

            # NOTE(review): the reported confidence uses temperature-scaled logits,
            # while the early-exit gate above uses unscaled ones.
            conf = self.compute_confidence(next_logits.unsqueeze(1))
            confidence_scores.append(conf.mean().item())
            generated_ids.append(next_token)

        if generated_ids:
            ids = torch.stack(generated_ids, dim=1)
        else:
            ids = torch.empty(batch_size, 0, dtype=torch.long, device=x.device)

        return {
            'ids': ids,
            'confidence': confidence_scores,
            'recursion_steps': recursion_steps_used,
        }

print("TRM VLM with Confidence-Based Recursion defined.")

## 7. Evaluation Metrics (EM and F1)

In [None]:
def normalize_answer(s: str) -> str:
    """Normalize answer text for evaluation (SQuAD-style).

    Strips ASCII punctuation, lowercases, drops the articles "a"/"an"/"the",
    and collapses whitespace runs to single spaces.
    """
    no_punct = ''.join(ch for ch in s if ch not in string.punctuation)
    words = no_punct.lower().split()
    return ' '.join(w for w in words if w not in {'a', 'an', 'the'})


def compute_exact_match(pred: str, target: str) -> float:
    """Return 1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(pred) == normalize_answer(target))


def compute_f1(pred: str, target: str) -> float:
    """Compute token-level F1 between normalized prediction and target.

    If either side normalizes to no tokens, returns 1.0 only when both do.
    """
    pred_tokens = normalize_answer(pred).split()
    target_tokens = normalize_answer(target).split()

    if not pred_tokens or not target_tokens:
        return float(pred_tokens == target_tokens)

    # Multiset intersection counts each shared token at most min(count) times.
    num_common = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    if num_common == 0:
        return 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(target_tokens)
    return 2 * precision * recall / (precision + recall)


def evaluate_qa(
    predictions: List[str],
    targets: List[str],
) -> Dict[str, float]:
    """Average exact-match and token-F1 over paired predictions/targets.

    Returns:
        {'em': ..., 'f1': ...} as percentages in [0, 100]; both are NaN when the
        inputs are empty (no numpy mean-of-empty warning).

    Raises:
        ValueError: if the lists differ in length — ``zip`` would otherwise
            silently drop the unmatched tail and skew the averages.
    """
    if len(predictions) != len(targets):
        raise ValueError(
            f"predictions ({len(predictions)}) and targets ({len(targets)}) must have the same length"
        )
    if not predictions:
        return {'em': float('nan'), 'f1': float('nan')}

    em_scores = [compute_exact_match(p, t) for p, t in zip(predictions, targets)]
    f1_scores = [compute_f1(p, t) for p, t in zip(predictions, targets)]

    return {
        'em': float(np.mean(em_scores) * 100),
        'f1': float(np.mean(f1_scores) * 100),
    }

## 8. Setup Dataset and Dataloaders

In [None]:
from transformers import AutoTokenizer

# GPT-2 has no dedicated pad token, so EOS doubles as padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

for label, value in (
    ("Tokenizer vocab size", tokenizer.vocab_size),
    ("Pad token", tokenizer.pad_token),
    ("EOS token", tokenizer.eos_token),
):
    print(f"{label}: {value}")

# Augmented transforms for training, deterministic ones for validation.
train_transforms, val_transforms = (
    get_image_transforms(image_size=336, is_training=flag) for flag in (True, False)
)

In [None]:
# Import QA dataset
from data.dataset_builder import PixmoQADataset

# Define collate function for QA batches
def collate_qa_batch(batch, pad_token_id: Optional[int] = None):
    """Collate PixMo QA samples into a right-padded batch.

    Args:
        batch: list of dataset items with keys 'image' (tensor), 'question_ids'
            and 'answer_ids' (1-D token tensors), 'question' and 'answer' (str).
        pad_token_id: padding id; defaults to the notebook-level
            ``tokenizer.pad_token_id`` (so ``DataLoader(collate_fn=...)`` works
            unchanged).

    Returns:
        dict with stacked 'images'; padded (B, L) long tensors 'question_ids'
        and 'answer_ids'; matching 'question_attention_mask' and
        'answer_attention_mask' (1 = real token, 0 = padding); and the raw
        'questions' / 'answers' strings.

    The masks matter here because this notebook sets the pad token to EOS, so
    padding cannot be told apart from a real EOS using the ids alone.
    """
    from torch.nn.utils.rnn import pad_sequence

    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    def _pad(seqs):
        # Right-pad 1-D sequences and build the matching validity mask.
        lengths = torch.tensor([s.shape[0] for s in seqs])
        padded = pad_sequence([s.long() for s in seqs], batch_first=True, padding_value=pad_token_id)
        mask = (torch.arange(padded.shape[1])[None, :] < lengths[:, None]).long()
        return padded, mask

    question_ids, question_mask = _pad([item['question_ids'] for item in batch])
    answer_ids, answer_mask = _pad([item['answer_ids'] for item in batch])

    return {
        'images': torch.stack([item['image'] for item in batch]),
        'question_ids': question_ids,
        'answer_ids': answer_ids,
        'question_attention_mask': question_mask,
        'answer_attention_mask': answer_mask,
        'questions': [item['question'] for item in batch],
        'answers': [item['answer'] for item in batch],
    }

# Both splits share the tokenizer and length limits; only the file and transforms differ.
qa_dataset_kwargs = dict(
    tokenizer=tokenizer,
    max_question_length=128,
    max_answer_length=32,
)
train_dataset = PixmoQADataset(
    parquet_path=config.dataset.train_parquet,
    image_transforms=train_transforms,
    **qa_dataset_kwargs,
)
val_dataset = PixmoQADataset(
    parquet_path=config.dataset.val_parquet,
    image_transforms=val_transforms,
    **qa_dataset_kwargs,
)

# Shared loader settings; only training data is shuffled.
batch_size = 32
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=4,
    collate_fn=collate_qa_batch,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("\n".join([
    "\nDataset sizes:",
    f"  Train: {len(train_dataset):,}",
    f"  Val: {len(val_dataset):,}",
    "\nDataLoader info:",
    f"  Batch size: {batch_size}",
    f"  Train batches: {len(train_loader)}",
    f"  Val batches: {len(val_loader)}",
]))

## 9. Visualize Sample Batch

In [None]:
# Pull one (shuffled) training batch to sanity-check shapes and a few QA pairs.
sample_batch = next(iter(train_loader))

print("Batch contents:")
print(f"  Images: {sample_batch['images'].shape}")
print(f"  Question IDs: {sample_batch['question_ids'].shape}")
print(f"  Answer IDs: {sample_batch['answer_ids'].shape}")

PREVIEW_CHARS = 80
print("\nSample QA pairs:")
for i, (question, answer) in enumerate(zip(sample_batch['questions'][:3], sample_batch['answers'])):
    # Only mark the answer as truncated when it actually was cut.
    preview = answer if len(answer) <= PREVIEW_CHARS else f"{answer[:PREVIEW_CHARS]}..."
    print(f"\n  [{i+1}]")
    print(f"    Q: {question}")
    print(f"    A: {preview}")

## 10. Initialize Decoder Model

This notebook uses `TRMVLMWithConfidence` (recursive, Section 6); `TinyVLMDecoder` (Section 5) is the non-recursive baseline and can be swapped in here.

In [None]:
# Decoder hyperparameters (this notebook always trains the TRM variant).
USE_TRM = True  # Always use TRM for this notebook
HIDDEN_DIM = 512
NUM_LAYERS = 2  # Small for TRM
NUM_HEADS = 8
NUM_INNER_STEPS = 4
CONFIDENCE_THRESHOLD = 0.75

# Width of the frozen encoder's output tokens, read off its final projector.
vision_token_dim = aligned_model.vision_encoder.projector.out_features  # Should be 4096

model = TRMVLMWithConfidence(
    vocab_size=tokenizer.vocab_size,
    vision_token_dim=vision_token_dim,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    num_inner_steps=NUM_INNER_STEPS,
    confidence_threshold=CONFIDENCE_THRESHOLD,
).to(device)

print(f"Initialized TRM VLM Decoder with Confidence-Based Recursion")

# Single pass over the parameters for both totals.
total_params = trainable_params = 0
for model_param in model.parameters():
    total_params += model_param.numel()
    if model_param.requires_grad:
        trainable_params += model_param.numel()

print("\n".join([
    "\nModel parameters:",
    f"  Total: {total_params:,}",
    f"  Trainable: {trainable_params:,}",
    f"  Hidden dim: {HIDDEN_DIM}",
    f"  Num layers: {NUM_LAYERS}",
    f"  Num heads: {NUM_HEADS}",
    f"  Inner recursion steps: {NUM_INNER_STEPS}",
    f"  Confidence threshold: {CONFIDENCE_THRESHOLD}",
]))

In [None]:
print("="*60)
print("PARAMETER AUDIT")
print("="*60)

# (name, numel) pairs across the decoder and the aligned encoder.
# NOTE: this rebinds `trainable_params` (an int in the previous cell) to a list.
trainable_params = []
frozen_params = []
for prefix, module in (("", model), ("aligned_model.", aligned_model)):
    for name, param in module.named_parameters():
        bucket = trainable_params if param.requires_grad else frozen_params
        bucket.append((f"{prefix}{name}", param.numel()))

total_trainable = sum(count for _, count in trainable_params)
total_frozen = sum(count for _, count in frozen_params)

print(f"\n🟢 TRAINABLE PARAMETERS ({len(trainable_params)} groups):")
for name, count in trainable_params[:20]:  # Show first 20
    print(f"  {name}: {count:,}")
if len(trainable_params) > 20:
    print(f"  ... and {len(trainable_params) - 20} more")

print(f"\n🔴 FROZEN PARAMETERS ({len(frozen_params)} groups):")
for name, count in frozen_params[:10]:
    print(f"  {name}: {count:,}")
if len(frozen_params) > 10:
    print(f"  ... and {len(frozen_params) - 10} more")

print(f"\n📊 SUMMARY:")
total = total_trainable + total_frozen
print(f"  Total parameters: {total:,}")
print(f"  Trainable: {total_trainable:,} ({100*total_trainable/total:.2f}%)")
print(f"  Frozen: {total_frozen:,} ({100*total_frozen/total:.2f}%)")

trainable_names = [name for name, _ in trainable_params]

def _report(label, ok):
    """Print one verification line, marking failures with ✗ instead of ✓."""
    print(f"  {'✓' if ok else '✗'} {label}: {ok}")

print(f"\n✓ VERIFICATION:")
_report("Aligned model frozen", all(not p.requires_grad for p in aligned_model.parameters()))
# Parameter names carry a suffix (e.g. `vision_proj.weight`), so match by prefix.
_report("Vision projection trainable", any(n.startswith("vision_proj.") for n in trainable_names))

# USE_TRM is the flag set in the model-init cell (USE_TRM_RECURSION was never defined).
if USE_TRM:
    # TRMVLMWithConfidence keeps its layers in `tiny_transformer`; TinyVLMDecoder in `decoder.layers`.
    _report(
        "TRM layers trainable",
        any(n.startswith(("tiny_transformer.", "decoder.layers.")) for n in trainable_names),
    )
    _report("z_init trainable", "z_init" in trainable_names)
else:
    _report("Qwen LoRA trainable", any("lora" in n.lower() for n in trainable_names))

print("="*60)

## 📊 Parameter Audit

The cell above verifies which parameters are frozen vs trainable; the next cell runs a text-only sanity check of the decoder.

In [None]:
@torch.no_grad()
def text_only_sanity_check(decoder, prompts, max_tokens=20):
    """Generate from text-only prompts to confirm the decoder works without vision input.

    Args:
        decoder: object exposing ``.tokenizer`` and an HF-style ``.generate``
            (presumably a ``QwenDecoder`` — TODO confirm).
        prompts: iterable of prompt strings.
        max_tokens: forwarded as ``max_new_tokens``; sampling uses temperature 0.7.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TEXT-ONLY SANITY CHECK (Before Training)")
    print(rule)

    for prompt in prompts:
        encoded = decoder.tokenizer(prompt, return_tensors='pt').to(device)
        output_ids = decoder.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
        )
        completion = decoder.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        print(f"\nPrompt: {prompt}")
        print(f"Generated: {completion}")
        print("-" * 60)

    print("\n✓ If generation is coherent English → Decoder works!")
    print("✗ If generation is garbage → Decoder loading issue")
    print(rule)

# Test prompts with short, unambiguous answers
test_prompts = [
    "Question: What is 2 + 2? Answer:",
    "Question: What color is the sky? Answer:",
    "The capital of France is",
]

# The check needs a decoder with `.tokenizer` and an HF-style `.generate`
# (a Qwen decoder). No earlier cell defines `qwen_decoder` — the TRM model is
# used instead — so skip rather than abort Run All with a NameError.
if "qwen_decoder" in globals():
    text_only_sanity_check(qwen_decoder, test_prompts)
else:
    print("Skipping text-only sanity check: `qwen_decoder` is not defined (TRM decoder in use).")

In [None]:
def debug_first_batch(batch, vision_tokens, outputs, qwen_tokenizer):
    """Print shapes, decoded text, loss and token counts for one training batch.

    Meant to be called once on the first training batch. It shows whether the
    question/answer tensors decode to the expected text and whether the loss
    is in a plausible range.

    Args:
        batch: Collated batch dict. Reads ``'images'``, ``'question_ids'`` and
            ``'answer_ids'`` (tensors) plus ``'questions'`` and ``'answers'``
            (sequences of strings).
        vision_tokens: Encoded image tokens; dim 1 is reported as the number of
            vision tokens.
        outputs: Model output dict with a ``'loss'`` tensor and optional
            ``'logits'`` / ``'confidence'`` entries (either may be ``None``).
        qwen_tokenizer: Tokenizer providing ``decode`` and ``pad_token_id``.
    """
    pad_id = qwen_tokenizer.pad_token_id

    def _count_non_pad(ids):
        # Without a pad token, every position counts as a real token.
        if pad_id is None:
            return int(ids.numel())
        return int((ids != pad_id).sum().item())

    print("\n" + "="*60)
    print("FIRST BATCH DEBUG")
    print("="*60)

    print(f"\n📦 Batch shapes:")
    print(f"  Images: {batch['images'].shape}")
    print(f"  Vision tokens: {vision_tokens.shape}")
    print(f"  Question IDs: {batch['question_ids'].shape}")
    print(f"  Answer IDs: {batch['answer_ids'].shape}")

    # Analyze first example
    print(f"\n📝 First example:")
    print(f"  Question: {batch['questions'][0][:100]}...")
    print(f"  Answer: {batch['answers'][0][:100]}...")

    # Decode tokens, keeping special tokens visible to spot template problems
    print(f"\n🔤 Decoded tokens (first example):")
    q_decoded = qwen_tokenizer.decode(batch['question_ids'][0], skip_special_tokens=False)
    a_decoded = qwen_tokenizer.decode(batch['answer_ids'][0], skip_special_tokens=False)
    print(f"  Question tokens: {q_decoded[:150]}...")
    print(f"  Answer tokens: {a_decoded[:100]}...")

    # Check loss and logits
    print(f"\n📊 Training outputs:")
    print(f"  Loss: {outputs['loss'].item():.4f}")
    logits = outputs.get('logits')
    if logits is not None:
        print(f"  Logits shape: {logits.shape}")
    confidence = outputs.get('confidence')
    if confidence is not None:
        print(f"  Confidence: {confidence.mean().item():.3f}")

    # Token counts. "Supervised" counts non-pad answer tokens; whether pad
    # positions are actually excluded from the loss depends on how the model
    # builds its labels.
    num_q_tokens = _count_non_pad(batch['question_ids'][0])
    num_a_tokens = _count_non_pad(batch['answer_ids'][0])
    num_img_tokens = vision_tokens.shape[1]

    print(f"\n📏 Token counts (first example):")
    print(f"  Vision tokens: {num_img_tokens}")
    print(f"  Question tokens (non-pad): {num_q_tokens}")
    print(f"  Answer tokens (non-pad): {num_a_tokens}")
    print(f"  Total context: {num_img_tokens + num_q_tokens}")
    print(f"  Supervised tokens (answer): {num_a_tokens}")

    total_seq_len = num_img_tokens + num_q_tokens + num_a_tokens
    print(f"\n  Total sequence length: {total_seq_len}")
    if total_seq_len > 0:
        print(f"  Supervised %: {100 * num_a_tokens / total_seq_len:.1f}%")
    else:
        print("  Supervised %: n/a (empty sequence)")

    print("\n✓ Expected behavior:")
    print("  - Loss should be < 10 (pretrained LM range)")
    print("  - Supervised tokens should be answer only (~10-30 tokens)")
    print("  - Vision + question tokens should NOT contribute to loss")

    print("="*60 + "\n")

print("✓ Debug helper functions defined")

# --- Experiment tracking (Weights & Biases) ---
USE_WANDB = True
if USE_WANDB:
    run_name = "qwen_vlm_" + ("trm" if USE_TRM_RECURSION else "baseline")
    wandb.init(
        project="edge_glass_qwen_vlm",
        name=run_name,
        config=dict(
            use_trm_recursion=USE_TRM_RECURSION,
            num_trm_layers=NUM_TRM_LAYERS if USE_TRM_RECURSION else 0,
            num_recursion_steps=NUM_RECURSION_STEPS if USE_TRM_RECURSION else 0,
            confidence_threshold=CONFIDENCE_THRESHOLD,
            learning_rate=LEARNING_RATE,
            batch_size=batch_size,
            decoder=config.decoder.model_name,
            use_lora=config.decoder.use_lora,
        ),
    )

# --- Mutable training state, reset every time this cell runs ---
global_step, best_val_loss = 0, float('inf')
history = {key: [] for key in ('train_loss', 'val_loss', 'val_em', 'val_f1')}
first_batch_debugged = False  # lets the first-batch debug dump fire only once

# --- Checkpoint location: separate directories for TRM vs baseline runs ---
ckpt_dir = Path("checkpoints") / ("qwen_vlm_qa_" + ("trm" if USE_TRM_RECURSION else "baseline"))
ckpt_dir.mkdir(parents=True, exist_ok=True)

print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print("Mode: " + ('TRM Recursion' if USE_TRM_RECURSION else 'Baseline (No Recursion)'))
print(f"Checkpoint dir: {ckpt_dir}")
print("=" * 60)

model.train()

# Main fine-tuning loop. Uses notebook globals not defined in this cell
# (NUM_EPOCHS, train_loader, val_loader, device, encode_images, optimizer,
# scheduler, MAX_GRAD_NORM, LOG_EVERY, qwen_decoder) plus the state set up
# in the previous cell (global_step, best_val_loss, history, ckpt_dir).
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    epoch_losses = []
    
    for batch_idx, batch in enumerate(pbar):
        # Move to device
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answer_ids = batch['answer_ids'].to(device)
        
        # Encode images (frozen)
        vision_tokens = encode_images(images)
        
        # Forward pass
        outputs = model(vision_tokens, question_ids, answer_ids)
        loss = outputs['loss']
        
        # 🐛 DEBUG FIRST BATCH
        # first_batch_debugged is a notebook-level flag, so this dump prints once per run of the setup cell.
        if epoch == 0 and batch_idx == 0 and not first_batch_debugged:
            debug_first_batch(batch, vision_tokens, outputs, qwen_decoder.tokenizer)
            first_batch_debugged = True
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        # The scheduler is stepped once per optimizer step (per batch), not per epoch.
        scheduler.step()
        
        # Log
        epoch_losses.append(loss.item())
        global_step += 1
        
        if global_step % LOG_EVERY == 0:
            # Mean over the last LOG_EVERY losses of the current epoch (fewer early in an epoch).
            avg_loss = np.mean(epoch_losses[-LOG_EVERY:])
            pbar_dict = {
                'loss': f'{avg_loss:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            }
            
            # Confidence is only produced in TRM mode (baseline returns None).
            # NOTE(review): despite its name, avg_conf is the mean for the current
            # batch only, not an average over the LOG_EVERY window.
            if outputs.get('confidence') is not None and outputs['confidence'] is not None:
                avg_conf = outputs['confidence'].mean().item()
                pbar_dict['conf'] = f'{avg_conf:.3f}'
            
            pbar.set_postfix(pbar_dict)
            
            if USE_WANDB:
                log_dict = {
                    'train/loss': avg_loss,
                    'train/lr': scheduler.get_last_lr()[0],
                    'step': global_step,
                }
                # Same guard as above, so avg_conf is always defined when read here.
                if outputs.get('confidence') is not None and outputs['confidence'] is not None:
                    log_dict['train/confidence'] = avg_conf
                wandb.log(log_dict)
    
    # Epoch-end evaluation
    print(f"\n  Epoch {epoch+1} average loss: {np.mean(epoch_losses):.4f}")
    history['train_loss'].append(np.mean(epoch_losses))
    
    # Validation (simple loss for now, full eval is expensive)
    model.eval()
    val_losses = []
    val_confidences = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['images'].to(device)
            question_ids = batch['question_ids'].to(device)
            answer_ids = batch['answer_ids'].to(device)
            
            vision_tokens = encode_images(images)
            outputs = model(vision_tokens, question_ids, answer_ids)
            val_losses.append(outputs['loss'].item())
            
            if outputs.get('confidence') is not None and outputs['confidence'] is not None:
                val_confidences.append(outputs['confidence'].mean().item())
    
    # NOTE(review): an empty val_loader makes this NaN, so no checkpoint would be
    # saved and the evaluation cell's checkpoint load would fail.
    val_loss = np.mean(val_losses)
    history['val_loss'].append(val_loss)
    
    print(f"  Validation loss: {val_loss:.4f}")
    if val_confidences:
        val_conf = np.mean(val_confidences)
        print(f"  Validation confidence: {val_conf:.3f}")
    
    if USE_WANDB:
        log_dict = {
            'val/loss': val_loss,
            'epoch': epoch + 1,
        }
        if val_confidences:
            log_dict['val/confidence'] = val_conf
        wandb.log(log_dict)
    
    # Save checkpoint
    # NOTE(review): model.state_dict() contains every registered parameter,
    # presumably including the frozen Qwen weights, so the file may be large.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'config': {
                'use_trm_recursion': USE_TRM_RECURSION,
                'num_trm_layers': NUM_TRM_LAYERS,
                'num_recursion_steps': NUM_RECURSION_STEPS,
                'confidence_threshold': CONFIDENCE_THRESHOLD,
            },
        }, ckpt_dir / "checkpoint_best.pt")
        print(f"  ✓ Saved best checkpoint (val_loss: {best_val_loss:.4f})")
    
    model.train()

print("\n" + "="*60)
print("TRAINING COMPLETED")
print("="*60)
print(f"Best validation loss: {best_val_loss:.4f}")

## ✅ Text-Only Sanity Check (BEFORE Training)

**CRITICAL TEST**: Verify that the pretrained Qwen decoder can generate coherent English BEFORE multimodal training.

- If this fails → decoder loading issue
- If this passes → decoder is working, ready for multimodal training

In [None]:
# Load best checkpoint. weights_only=False permits arbitrary unpickling, which is
# acceptable only because this file was written by this notebook's training cell.
best_ckpt = torch.load(ckpt_dir / "checkpoint_best.pt", map_location=device, weights_only=False)
model.load_state_dict(best_ckpt['model_state_dict'])
model.eval()

print("\n" + "="*60)
print("FULL EVALUATION ON VALIDATION SET")
print("="*60)
print(f"Mode: {'TRM Recursion' if USE_TRM_RECURSION else 'Baseline (No Recursion)'}")
print("="*60)

all_predictions = []
all_targets = []
all_questions = []
all_confidence_scores = []
all_recursion_steps = []

# Decoding is greedy in both modes: the baseline path passes do_sample=False and
# the TRM path takes the argmax. In TRM mode temperature only divides the logits
# used for confidences, so 0.0 would yield inf/NaN there; use 1.0 instead.
eval_temperature = 1.0 if USE_TRM_RECURSION else 0.0

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating answers"):
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answers = batch['answers']

        # Encode images
        vision_tokens = encode_images(images)

        # Generate answers
        gen_outputs = model.generate(
            vision_tokens,
            question_ids,
            max_new_tokens=32,
            temperature=eval_temperature,
            use_confidence=USE_TRM_RECURSION,  # Only for TRM mode
            return_stats=True,
        )

        # Handle different output formats
        if isinstance(gen_outputs, dict):
            generated_ids = gen_outputs['predictions']
            batch_conf = gen_outputs.get('confidences')
            # The baseline path reports placeholder all-ones confidences, so only
            # TRM confidences carry signal.
            if USE_TRM_RECURSION and batch_conf is not None:
                all_confidence_scores.extend(batch_conf.float().mean(dim=1).cpu().tolist())
            batch_steps = gen_outputs.get('recursion_steps')
            if batch_steps is not None:
                # .float(): the TRM path may return an integer tensor, and
                # torch.mean rejects integer dtypes.
                all_recursion_steps.extend(batch_steps.float().mean(dim=1).cpu().tolist())
        else:
            generated_ids = gen_outputs

        # Decode
        predictions = qwen_decoder.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        all_predictions.extend(predictions)
        all_targets.extend(answers)
        # Record each question with its prediction, independent of loader order.
        batch_questions = batch.get('questions')
        if batch_questions is not None:
            all_questions.extend(batch_questions)

# Compute metrics
metrics = evaluate_qa(all_predictions, all_targets)

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Exact Match (EM): {metrics['em']:.2f}%")
print(f"Token F1: {metrics['f1']:.2f}%")

if all_confidence_scores:
    print(f"\n📊 Confidence Statistics:")
    print(f"  Mean: {np.mean(all_confidence_scores):.3f}")
    print(f"  Std: {np.std(all_confidence_scores):.3f}")
    print(f"  Min: {np.min(all_confidence_scores):.3f}")
    print(f"  Max: {np.max(all_confidence_scores):.3f}")

if all_recursion_steps:
    print(f"\n🔁 Recursion Statistics:")
    print(f"  Avg steps per sequence: {np.mean(all_recursion_steps):.2f}")
    print(f"  Std: {np.std(all_recursion_steps):.2f}")
    if USE_TRM_RECURSION:
        recursion_triggered = sum(1 for s in all_recursion_steps if s > 1) / len(all_recursion_steps) * 100
        print(f"  Recursion triggered: {recursion_triggered:.1f}% of sequences")

print("="*60)

if USE_WANDB:
    wandb.log({
        'val/em_final': metrics['em'],
        'val/f1_final': metrics['f1'],
    })

# Show sample predictions
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60)

use_batch_questions = len(all_questions) == len(all_predictions)
for i in range(min(10, len(all_predictions))):
    question_text = (
        all_questions[i] if use_batch_questions
        # Fallback assumes val_loader preserves dataset order (no shuffling).
        else val_dataset.df.iloc[i]['question']
    )
    print(f"\n[{i+1}]")
    print(f"  Q: {question_text[:80]}...")
    print(f"  Target: {all_targets[i][:80]}...")
    print(f"  Predicted: {all_predictions[i][:80]}...")
    print(f"  EM: {compute_exact_match(all_predictions[i], all_targets[i])}")
    print(f"  F1: {compute_f1(all_predictions[i], all_targets[i]):.3f}")

print("\n" + "="*60)
print("\n" + "="*60)

## 🔧 Qwen VLM Wrapper (Pretrained Decoder)

Wrap the pretrained Qwen2.5-7B-Instruct decoder (fine-tuned with LoRA) with a trainable vision-token projection and optional TRM latent recursion.

In [None]:
class QwenVLM(nn.Module):
    """Qwen-based VLM with optional TRM recursion.
    
    This is the FIXED implementation that uses pretrained Qwen instead of random init.
    
    Features:
    - Pretrained Qwen decoder (the wrapped ``QwenDecoder``)
    - Vision token projection to Qwen's hidden dim
    - Vision tokens passed to the decoder as ``prefix_embeds``
    - Standard autoregressive loss: question positions and padded answer
      positions are masked with -100
    - Optional TRM latent recursion on top (for comparison)

    Note: ``confidence_threshold`` is stored but not read by ``forward`` or
    ``generate``. TRM mode always runs ``num_recursion_steps`` recursion steps.
    """
    
    def __init__(
        self,
        qwen_decoder: QwenDecoder,
        vision_token_dim: int = 4096,
        use_trm_recursion: bool = False,
        num_trm_layers: int = 2,
        num_recursion_steps: int = 4,
        confidence_threshold: float = 0.75,
    ):
        super().__init__()
        
        self.qwen = qwen_decoder
        self.hidden_dim = qwen_decoder.hidden_dim
        self.use_trm_recursion = use_trm_recursion
        self.num_recursion_steps = num_recursion_steps
        self.confidence_threshold = confidence_threshold
        
        # Project vision tokens from vision_token_dim -> Qwen hidden dim
        self.vision_proj = nn.Linear(vision_token_dim, self.hidden_dim)
        
        # Optional: TRM recursion components
        if use_trm_recursion:
            from decoders.trm import TRMConfig, TRMLayer, RMSNorm
            
            trm_config = TRMConfig(
                vocab_size=qwen_decoder.vocab_size,
                hidden_dim=self.hidden_dim,
                num_layers=num_trm_layers,
                num_heads=8,
                max_seq_len=2048,
            )
            
            self.trm_layers = nn.ModuleList([
                TRMLayer(trm_config) for _ in range(num_trm_layers)
            ])
            self.trm_norm = RMSNorm(self.hidden_dim)
            self.z_init = nn.Parameter(torch.randn(1, 1, self.hidden_dim) * 0.02)
            
            print(f"  ✓ TRM recursion enabled ({num_trm_layers} layers, {num_recursion_steps} steps)")
        else:
            print(f"  ✓ Baseline mode (no TRM recursion)")
    
    def _answer_labels(self, answer_ids: torch.Tensor) -> torch.Tensor:
        """Return ``answer_ids`` with padding positions replaced by -100.

        Without this, batch padding inside ``answer_ids`` is supervised like a
        real answer token.

        NOTE(review): if the tokenizer's pad token equals its EOS token, a
        trailing EOS is masked as well. Verify against the dataset's collate.
        """
        pad_id = self.qwen.tokenizer.pad_token_id
        if pad_id is None:
            return answer_ids
        return answer_ids.masked_fill(answer_ids == pad_id, -100)
    
    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, vision_token_dim)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
    ):
        """Forward pass with proper autoregressive training.
        
        Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
        Loss only on non-pad answer tokens via -100 masking.

        Returns:
            Dict with ``'loss'``, ``'logits'`` and ``'confidence'``. The
            confidence is ``None`` in baseline mode; in TRM mode it is the
            per-sample mean max-probability over answer positions.
        """
        batch_size = vision_tokens.shape[0]
        
        # Project vision tokens to Qwen hidden dim
        vision_emb = self.vision_proj(vision_tokens)  # (B, K_img, d_qwen)
        
        # Input: [question_ids, answer_ids]
        # Labels: [-100 for question, answer_ids with pads -> -100]
        input_ids = torch.cat([question_ids, answer_ids], dim=1)  # (B, L_q + L_a)
        answer_labels = self._answer_labels(answer_ids)
        question_labels = torch.full_like(question_ids, fill_value=-100)
        labels = torch.cat([question_labels, answer_labels], dim=1)  # (B, L_q + L_a)
        
        if not self.use_trm_recursion:
            # ===== BASELINE MODE: Standard Qwen forward, vision as prefix_embeds =====
            outputs = self.qwen(
                input_ids=input_ids,
                prefix_embeds=vision_emb,
                labels=labels,
            )
            return {
                'loss': outputs.loss,
                'logits': outputs.logits,
                'confidence': None,
            }
        
        # ===== TRM MODE: Latent recursion on top of Qwen embeddings =====
        text_emb = self.qwen.model.get_input_embeddings()(input_ids)  # (B, L_q+L_a, d)
        
        L_q = question_ids.shape[1]
        L_a = answer_ids.shape[1]
        question_emb = text_emb[:, :L_q, :]
        answer_emb = text_emb[:, L_q:, :]
        
        # Context: [vision, question]; y is the answer representation refined by TRM
        x = torch.cat([vision_emb, question_emb], dim=1)  # (B, K_img + L_q, d)
        y = answer_emb  # (B, L_a, d)
        z = self.z_init.expand(batch_size, L_a, -1)  # (B, L_a, d)
        
        for _ in range(self.num_recursion_steps):
            x, y, z = self.latent_recursion(x, y, z)
        
        # Final logits from refined y via Qwen's LM head
        y = self.trm_norm(y)
        logits_answer = self.qwen.model.lm_head(y)  # (B, L_a, vocab)
        
        # Position t predicts answer token t+1; masked pads are ignored.
        shift_logits = logits_answer[:, :-1, :].contiguous()
        shift_labels = answer_labels[:, 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        
        # Confidence: mean over answer positions of the max token probability
        probs = torch.softmax(logits_answer, dim=-1)
        max_probs = probs.max(dim=-1).values
        confidence = max_probs.mean(dim=-1)
        
        return {
            'loss': loss,
            'logits': logits_answer,
            'confidence': confidence,
        }
    
    def latent_recursion(self, x, y, z):
        """Run one TRM step over the concatenated [x, y, z] sequence.

        The three segments are concatenated along the sequence axis, passed
        through every TRM layer, and split back at the original lengths.
        """
        concat = torch.cat([x, y, z], dim=1)
        
        hidden = concat
        for layer in self.trm_layers:
            hidden = layer(hidden)
        
        L_ctx = x.shape[1]
        L_ans = y.shape[1]
        
        x_out = hidden[:, :L_ctx, :]
        y_out = hidden[:, L_ctx:L_ctx+L_ans, :]
        z_out = hidden[:, L_ctx+L_ans:, :]
        
        return x_out, y_out, z_out
    
    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.7,
        use_confidence: bool = True,
        return_stats: bool = False,
    ):
        """Generate answer token IDs, with optional TRM recursion.

        Args:
            vision_tokens: Encoded image tokens, projected with ``vision_proj``.
            question_ids: Question token IDs.
            max_new_tokens: Number of tokens to generate.
            temperature: Forwarded to the Qwen decoder in baseline mode, where
                ``do_sample=False``. In TRM mode tokens are always picked by
                argmax and temperature only rescales the reported confidences.
                Values <= 0 are treated as 1.0 there to avoid dividing by zero.
            use_confidence: Currently unused; accepted for API compatibility.
            return_stats: If True, return a dict with ``'predictions'``,
                ``'confidences'`` (``None`` in baseline mode) and
                ``'recursion_steps'`` (float tensor, (B, num_generated)).

        Returns:
            Generated token IDs, or the stats dict when ``return_stats`` is True.
        """
        vision_emb = self.vision_proj(vision_tokens)
        
        if not self.use_trm_recursion:
            # Baseline: standard Qwen generation
            outputs = self.qwen.generate(
                input_ids=question_ids,
                prefix_embeds=vision_emb,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=False,  # Use greedy for eval
            )
            
            # NOTE(review): assumes self.qwen.generate returns one position per
            # vision-prefix and question token before the new tokens. If the
            # prefix (passed as embeddings, so it has no IDs) is not included
            # in its output, this slice drops generated tokens. Confirm against
            # QwenDecoder.generate.
            prompt_len = question_ids.shape[1] + vision_emb.shape[1]
            generated_ids = outputs[:, prompt_len:]
            
            if return_stats:
                return {
                    'predictions': generated_ids,
                    # No confidence signal in baseline mode; None instead of placeholder ones.
                    'confidences': None,
                    'recursion_steps': torch.ones(generated_ids.shape),  # No recursion
                }
            return generated_ids
        
        # TRM mode: greedy autoregressive generation, re-running forward each step
        batch_size = vision_tokens.shape[0]
        logit_scale = temperature if temperature > 0 else 1.0
        generated_ids = []
        confidences = []
        
        for _ in range(max_new_tokens):
            if generated_ids:
                answer_ids = torch.stack(generated_ids, dim=1)
            else:
                # First step only: seed the answer with a single pad token.
                answer_ids = torch.tensor(
                    [[self.qwen.tokenizer.pad_token_id]],
                    device=vision_tokens.device
                ).expand(batch_size, 1)
            
            outputs = self.forward(vision_tokens, question_ids, answer_ids)
            logits = outputs['logits'][:, -1, :] / logit_scale
            
            next_token = torch.argmax(logits, dim=-1)
            generated_ids.append(next_token)
            
            probs = torch.softmax(logits, dim=-1)
            confidences.append(probs.max(dim=-1).values.cpu())
        
        generated_ids = torch.stack(generated_ids, dim=1)
        
        if return_stats:
            return {
                'predictions': generated_ids,
                'confidences': torch.stack(confidences, dim=1),
                # Fixed-depth recursion for every token; float so callers can .mean() it.
                'recursion_steps': torch.full(
                    (batch_size, generated_ids.shape[1]),
                    float(self.num_recursion_steps),
                ),
            }
        return generated_ids

print("✓ QwenVLM class defined (with baseline and TRM modes)")

## 🔧 NEW: Qwen-Based VLM Decoder (Pretrained)

**CRITICAL FIX**: The original TRM decoder was randomly initialized, attempting to learn language from scratch. 

This new implementation uses **pretrained Qwen2.5-7B** with LoRA, giving us:
- ✅ Pretrained language knowledge (no need to learn grammar/vocabulary)
- ✅ Efficient fine-tuning with LoRA (only ~40M trainable params)
- ✅ Proper multimodal prefix support
- ✅ Standard autoregressive training with -100 label masking

We'll implement:
1. **Baseline mode**: Standard Qwen decoder without TRM recursion
2. **TRM mode**: Qwen + TRM latent recursion on top

## 11. Experiment Summary

In [None]:
# Assemble the whole report first, then emit it with a single print.
rule = "=" * 60
summary_lines = [
    "\n" + rule,
    "EXPERIMENT SUMMARY",
    rule,
    "\n🔧 Model Configuration:",
    f"  Decoder: {config.decoder.model_name}",
    f"  Use LoRA: {config.decoder.use_lora}",
    f"  Use TRM Recursion: {USE_TRM_RECURSION}",
]
if USE_TRM_RECURSION:
    summary_lines += [
        f"  TRM layers: {NUM_TRM_LAYERS}",
        f"  Recursion steps: {NUM_RECURSION_STEPS}",
        f"  Confidence threshold: {CONFIDENCE_THRESHOLD}",
    ]
summary_lines += [
    "\n📊 Training:",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Best val loss: {best_val_loss:.4f}",
    "\n📈 Evaluation Results:",
    f"  Exact Match (EM): {metrics['em']:.2f}%",
    f"  Token F1: {metrics['f1']:.2f}%",
]
if all_recursion_steps and USE_TRM_RECURSION:
    summary_lines += [
        f"  Avg recursion steps: {np.mean(all_recursion_steps):.2f}",
        f"  Recursion triggered: {recursion_triggered:.1f}% of sequences",
    ]
summary_lines += [
    "\n📦 Dataset:",
    f"  Train samples: {len(train_dataset):,}",
    f"  Val samples: {len(val_dataset):,}",
    "\n💾 Output Files:",
    f"  Best checkpoint: {ckpt_dir / 'checkpoint_best.pt'}",
    "\n" + rule,
    "KEY IMPROVEMENTS FROM ORIGINAL",
    rule,
    "✅ 1. Pretrained Qwen Decoder:",
    "   - Replaced random TRM (34M params) with Qwen-7B + LoRA",
    "   - Model now has language prior (no need to learn from scratch)",
    "   - Text-only sanity check validates decoder works",
    "",
    "✅ 2. Proper Autoregressive Training:",
    "   - Uses prefix_embeds for vision tokens",
    "   - Standard -100 label masking for vision/question tokens",
    "   - Loss only computed on answer tokens",
    "",
    "✅ 3. Baseline Mode:",
    "   - Can disable TRM recursion for comparison",
    "   - Isolates decoder issues from recursion issues",
    "   - Establishes performance floor",
    "",
    "✅ 4. Debug Instrumentation:",
    "   - Text-only generation test before training",
    "   - Parameter audit (frozen vs trainable)",
    "   - First-batch debug logging",
    "   - Detailed recursion statistics",
    "",
    rule,
    "📋 NEXT STEPS",
    rule,
    "1. If EM/F1 are still low with baseline:",
    "   - Check first-batch debug output",
    "   - Verify text-only generation works",
    "   - Increase training epochs or learning rate",
    "",
    "2. If baseline works well:",
    "   - Set USE_TRM_RECURSION = True",
    "   - Compare TRM vs baseline performance",
    "   - Tune confidence threshold and recursion steps",
    "",
    "3. Expected performance:",
    "   - Baseline: EM > 15%, F1 > 25% (if Qwen works)",
    "   - With TRM: EM >= baseline (should match or exceed)",
    rule,
    "\n✓ Notebook complete!",
]
print("\n".join(summary_lines))

## 12. Training Loop

In [None]:
# --- Legacy TRM-from-scratch run: experiment tracking ---
USE_WANDB = True
if USE_WANDB:
    wandb.init(
        project="edge_glass_trm_vlm",
        name=(
            f"trm_vlm_d{HIDDEN_DIM}_l{NUM_LAYERS}"
            f"_n{NUM_INNER_STEPS}_conf{CONFIDENCE_THRESHOLD}"
        ),
        config=dict(
            hidden_dim=HIDDEN_DIM,
            num_layers=NUM_LAYERS,
            num_heads=NUM_HEADS,
            num_inner_steps=NUM_INNER_STEPS,
            confidence_threshold=CONFIDENCE_THRESHOLD,
            learning_rate=LEARNING_RATE,
            batch_size=batch_size,
            total_params=total_params,
        ),
    )

# --- Training state: this resets any history/best loss from earlier cells ---
global_step, best_val_loss = 0, float('inf')
history = {name: [] for name in ('train_loss', 'val_loss', 'val_em', 'val_f1')}

# --- Checkpoint directory (redirects ckpt_dir for all later cells) ---
ckpt_dir = Path("checkpoints") / "trm_vlm_qa"
ckpt_dir.mkdir(parents=True, exist_ok=True)

print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)

model.train()

# Legacy training loop. Confidence handling is None-safe: QwenVLM in baseline
# mode returns outputs['confidence'] = None, and calling .mean() on it would
# raise AttributeError.
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    epoch_losses = []

    for batch_idx, batch in enumerate(pbar):
        # Move to device
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answer_ids = batch['answer_ids'].to(device)

        # Encode images (frozen)
        vision_tokens = encode_images(images)

        # Forward pass
        outputs = model(vision_tokens, question_ids, answer_ids)
        loss = outputs['loss']

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()

        # Log
        epoch_losses.append(loss.item())
        global_step += 1

        if global_step % LOG_EVERY == 0:
            avg_loss = np.mean(epoch_losses[-LOG_EVERY:])
            current_lr = scheduler.get_last_lr()[0]
            confidence = outputs.get('confidence')

            postfix = {'loss': f'{avg_loss:.4f}'}
            train_log = {'train/loss': avg_loss}
            if confidence is not None:
                # Mean confidence of the current batch only.
                avg_conf = confidence.mean().item()
                postfix['conf'] = f'{avg_conf:.3f}'
                train_log['train/confidence'] = avg_conf
            postfix['lr'] = f'{current_lr:.2e}'
            pbar.set_postfix(postfix)

            if USE_WANDB:
                train_log['train/lr'] = current_lr
                train_log['step'] = global_step
                wandb.log(train_log)

    # Epoch-end evaluation
    print(f"\n  Epoch {epoch+1} average loss: {np.mean(epoch_losses):.4f}")
    history['train_loss'].append(np.mean(epoch_losses))

    # Validation (simple loss for now, full eval is expensive)
    model.eval()
    val_losses = []
    val_confidences = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['images'].to(device)
            question_ids = batch['question_ids'].to(device)
            answer_ids = batch['answer_ids'].to(device)

            vision_tokens = encode_images(images)
            outputs = model(vision_tokens, question_ids, answer_ids)
            val_losses.append(outputs['loss'].item())
            if outputs.get('confidence') is not None:
                val_confidences.append(outputs['confidence'].mean().item())

    val_loss = np.mean(val_losses)
    val_conf = np.mean(val_confidences) if val_confidences else None
    history['val_loss'].append(val_loss)

    print(f"  Validation loss: {val_loss:.4f}")
    if val_conf is not None:
        print(f"  Validation confidence: {val_conf:.3f}")

    if USE_WANDB:
        val_log = {'val/loss': val_loss}
        if val_conf is not None:
            val_log['val/confidence'] = val_conf
        val_log['epoch'] = epoch + 1
        wandb.log(val_log)

    # Save checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'config': {
                'hidden_dim': HIDDEN_DIM,
                'num_layers': NUM_LAYERS,
                'num_heads': NUM_HEADS,
                'num_inner_steps': NUM_INNER_STEPS,
                'confidence_threshold': CONFIDENCE_THRESHOLD,
            },
        }, ckpt_dir / "checkpoint_best.pt")
        print(f"  ✓ Saved best checkpoint (val_loss: {best_val_loss:.4f})")

    model.train()

print("\n" + "="*60)
print("TRAINING COMPLETED")
print("="*60)
print(f"Best validation loss: {best_val_loss:.4f}")

## 13. Full Evaluation with EM and F1

In [None]:
# Load best checkpoint. weights_only=False permits arbitrary unpickling, which is
# acceptable only because this file was written by the training cell above.
best_ckpt = torch.load(ckpt_dir / "checkpoint_best.pt", map_location=device, weights_only=False)
model.load_state_dict(best_ckpt['model_state_dict'])
model.eval()

print("\nRunning full evaluation on validation set...")
print("Using confidence-based adaptive recursion during generation")

all_predictions = []
all_targets = []
all_confidence_scores = []
all_recursion_steps = []  # one list of per-token step counts per sample

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating answers"):
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answers = batch['answers']

        # Encode images
        vision_tokens = encode_images(images)

        # return_stats=True makes QwenVLM.generate return a dict; without it only
        # the ID tensor comes back and dict lookups fail.
        gen_outputs = model.generate(
            vision_tokens,
            question_ids,
            max_new_tokens=32,
            temperature=0.7,
            use_confidence=True,
            return_stats=True,
        )

        if isinstance(gen_outputs, dict):
            generated_ids = gen_outputs['predictions']
            batch_conf = gen_outputs.get('confidences')
            batch_steps = gen_outputs.get('recursion_steps')
        else:
            generated_ids, batch_conf, batch_steps = gen_outputs, None, None

        # Decode with the decoder's own tokenizer, matching the vocabulary of the generated IDs.
        predictions = qwen_decoder.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        all_predictions.extend(predictions)
        all_targets.extend(answers)
        if batch_conf is not None:
            all_confidence_scores.extend(batch_conf.float().mean(dim=1).cpu().tolist())
        if batch_steps is not None:
            all_recursion_steps.extend(batch_steps.cpu().tolist())



In [None]:
# Compute metrics
metrics = evaluate_qa(all_predictions, all_targets)


def _flatten_recursion_steps(entries):
    """Flatten nested lists/tensors of recursion-step counts into a flat list of floats."""
    flat = []
    for entry in entries:
        if hasattr(entry, "tolist"):  # torch tensors / numpy values
            entry = entry.tolist()
        if isinstance(entry, (list, tuple)):
            flat.extend(_flatten_recursion_steps(entry))
        else:
            flat.append(float(entry))
    return flat


# Analyze recursion statistics. Guarded because an empty list would make np.mean
# warn and the percentage divide by zero.
flat_recursion_steps = _flatten_recursion_steps(all_recursion_steps)
if flat_recursion_steps:
    avg_recursion_steps = float(np.mean(flat_recursion_steps))
    recursion_triggered = sum(1 for s in flat_recursion_steps if s > 1) / len(flat_recursion_steps) * 100
else:
    avg_recursion_steps = float('nan')
    recursion_triggered = float('nan')

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Exact Match (EM): {metrics['em']:.2f}%")
print(f"Token F1: {metrics['f1']:.2f}%")
print(f"\nRecursion Statistics:")
print(f"  Average steps per token: {avg_recursion_steps:.2f}")
print(f"  Recursion triggered: {recursion_triggered:.1f}% of tokens")
print(f"  (Threshold: {CONFIDENCE_THRESHOLD})")
print("="*60)

if USE_WANDB:
    wandb.log({
        'val/em': metrics['em'],
        'val/f1': metrics['f1'],
        'val/avg_recursion_steps': avg_recursion_steps,
        'val/recursion_triggered_pct': recursion_triggered,
    })

# Show some examples
print("\nSample predictions with recursion analysis:")
for i in range(min(10, len(all_predictions))):
    print(f"\n[{i+1}]")
    # NOTE(review): assumes val_loader yields val_dataset in order (no shuffling) — confirm.
    print(f"  Question: {val_dataset.df.iloc[i]['question'][:60]}...")
    print(f"  Target: {all_targets[i][:60]}...")
    print(f"  Predicted: {all_predictions[i][:60]}...")
    print(f"  EM: {compute_exact_match(all_predictions[i], all_targets[i])}")
    print(f"  F1: {compute_f1(all_predictions[i], all_targets[i]):.3f}")

    # Show first few recursion steps for this sample, when recorded per sample
    if i < len(all_recursion_steps):
        sample_steps = all_recursion_steps[i]
        if hasattr(sample_steps, "tolist"):
            sample_steps = sample_steps.tolist()
        if isinstance(sample_steps, (list, tuple)):
            sample_steps = sample_steps[:5]  # First 5 tokens
        print(f"  Recursion steps (first 5 tokens): {sample_steps}")

## 14. Visualize Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, metric_ax = axes

# Loss curves (markers keep single-epoch runs visible)
loss_ax.plot(history['train_loss'], label='Train Loss', linewidth=2, marker='o')
loss_ax.plot(history['val_loss'], label='Val Loss', linewidth=2, marker='o')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Per-epoch metrics. The training loops in this notebook never append to
# history['val_em'] / history['val_f1'], so guard the legend: calling legend()
# with no labelled artists triggers a matplotlib warning.
plotted_metric = False
if history.get('val_em'):
    metric_ax.plot(history['val_em'], label='Exact Match', linewidth=2, marker='o')
    plotted_metric = True
if history.get('val_f1'):
    metric_ax.plot(history['val_f1'], label='Token F1', linewidth=2, marker='s')
    plotted_metric = True
metric_ax.set_xlabel('Epoch')
metric_ax.set_ylabel('Score (%)')
metric_ax.set_title('Validation Metrics')
if plotted_metric:
    metric_ax.legend()
else:
    metric_ax.text(
        0.5, 0.5, 'No per-epoch EM/F1 recorded',
        ha='center', va='center', transform=metric_ax.transAxes,
    )
metric_ax.grid(True, alpha=0.3)

fig.tight_layout()
curves_path = ckpt_dir / "training_curves.png"
fig.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nTraining curves saved to {curves_path}")

## 15. Complex VLM Qualitative Check

Run a manual question on the complex curve plot to sanity check end-to-end VLM reasoning before the final summary.

In [None]:
# Quick VLM spot check on a complex curve image
from pathlib import Path
from PIL import Image
import torchvision.transforms as T

model.eval()
# NOTE: machine-specific absolute path; adjust for other environments.
complex_curve_path = Path("/home/hice1/vchopra37/scratch/projects/edge_glass/complex_curve.png")
raw_image = Image.open(complex_curve_path).convert("RGB")

fig, ax = plt.subplots(figsize=(6, 4))
ax.imshow(raw_image)
ax.axis("off")
ax.set_title("Complex curve prompt image")

image_tensor = val_transforms(T.ToTensor()(raw_image)).unsqueeze(0).to(device)
vision_tokens = encode_images(image_tensor)

question = "Interpret the curve on this image. What does its shape suggest about the underlying relationship?"
# Tokenize with the model decoder's tokenizer so the IDs match its embedding table.
question_ids = qwen_decoder.tokenizer(
    question,
    return_tensors='pt',
    add_special_tokens=True,
).input_ids.to(device)

# return_stats=True makes QwenVLM.generate return a dict; without it only the
# token-ID tensor is returned and the dict lookups below would fail.
generation = model.generate(
    vision_tokens=vision_tokens,
    question_ids=question_ids,
    max_new_tokens=64,
    temperature=0.7,
    use_confidence=True,
    return_stats=True,
)
if isinstance(generation, dict):
    curve_answer_ids = generation['predictions']
    confidence_trace = generation.get('confidences')
    recursion_trace = generation.get('recursion_steps')
else:
    curve_answer_ids, confidence_trace, recursion_trace = generation, None, None

decoded_answer = qwen_decoder.tokenizer.decode(
    curve_answer_ids[0].detach().cpu().tolist(),
    skip_special_tokens=True,
).strip()

print("Question:", question)
print("Predicted answer:", decoded_answer)
if confidence_trace is not None:
    print("Confidence trace:", [round(c, 3) for c in confidence_trace[0].float().tolist()])
else:
    print("Confidence trace: n/a (not reported in this mode)")
if recursion_trace is not None:
    print("Recursion steps per token:", recursion_trace[0].tolist())
else:
    print("Recursion steps per token: n/a")


## 16. Complex VLM Question Probe

Pose a harder, multi-part interpretation question about the same curve plot before the final summary.


In [None]:
# Complex VLM question probe on the curve image.
# Imports repeated here so this cell does not silently depend on the previous one.
from pathlib import Path
from PIL import Image
import torchvision.transforms as T

# NOTE(review): absolute, user-specific path (same image as the spot check) —
# consider a single config-cell constant.
probe_image_path = Path("/home/hice1/vchopra37/scratch/projects/edge_glass/complex_curve.png")
if not probe_image_path.is_file():
    raise FileNotFoundError(f"Probe image not found: {probe_image_path}")

probe_question = (
    "Provide a short narrative interpretation of the plotted curve: describe its overall trend, where the slope changes,"
    " and what that implies about acceleration, saturation, or decay in the underlying relationship."
)

with Image.open(probe_image_path) as img:
    probe_image = img.convert("RGB")

# Seed for the temperature-based generation so the probe is reproducible on re-run.
PROBE_SEED = 0

model.eval()
# Pure inference: no_grad avoids building an autograd graph during generation.
with torch.no_grad():
    probe_tensor = val_transforms(T.ToTensor()(probe_image)).unsqueeze(0).to(device)
    probe_tokens = encode_images(probe_tensor)

    probe_question_ids = tokenizer(
        probe_question, return_tensors='pt', add_special_tokens=True
    ).input_ids.to(device)

    torch.manual_seed(PROBE_SEED)
    probe_gen = model.generate(
        vision_tokens=probe_tokens,
        question_ids=probe_question_ids,
        max_new_tokens=96,
        temperature=0.7,
        use_confidence=True,
    )

probe_answer = tokenizer.decode(
    probe_gen['ids'][0].detach().cpu().tolist(),
    skip_special_tokens=True,
).strip()

print("Question:", probe_question)
print("Answer:", probe_answer)
print("Confidence trace:", [round(c, 3) for c in probe_gen['confidence']])
print("Recursion steps per token:", probe_gen['recursion_steps'])


## 17. Summary and Next Steps

In [None]:
# Final experiment summary: configuration, training/eval results, dataset sizes,
# output artifacts, and suggested follow-ups. Reads globals set by earlier cells.
RULE = "=" * 60

print("\n" + RULE)
print("EXPERIMENT SUMMARY")
print(RULE)

print("\nModel Configuration:")
print("  Type: TRM VLM with Confidence-Based Recursive Refinement")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Num layers: {NUM_LAYERS}")
print(f"  Num heads: {NUM_HEADS}")
print(f"  Inner recursion steps: {NUM_INNER_STEPS}")
print(f"  Confidence threshold: {CONFIDENCE_THRESHOLD}")
print(f"  Total parameters: {total_params:,}")

print("\nTraining:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Best val loss: {best_val_loss:.4f}")

print("\nEvaluation Results:")
print(f"  Exact Match (EM): {metrics['em']:.2f}%")
print(f"  Token F1: {metrics['f1']:.2f}%")
print(f"  Avg recursion steps: {avg_recursion_steps:.2f}")
print(f"  Recursion triggered: {recursion_triggered:.1f}% of tokens")

print("\nDataset:")
print(f"  Train samples: {len(train_dataset):,}")
print(f"  Val samples: {len(val_dataset):,}")

# Report whether the curves file actually exists rather than a vague "(if saved)".
curves_path = ckpt_dir / 'training_curves.png'
print("\nOutput Files:")
print(f"  Best checkpoint: {ckpt_dir / 'checkpoint_best.pt'}")
print(f"  Training curves: {curves_path}{'' if curves_path.exists() else ' (not found)'}")

print("\n" + RULE)
print("KEY INSIGHTS")
print(RULE)
print("1. TRM Recursion vs Baseline:")
print("   - TRM uses latent recursion (n inner steps) for reasoning")
print("   - Confidence-based early stopping saves computation")
print("   - Higher confidence → skip extra recursion (efficient)")
print("   - Lower confidence → trigger full recursion (quality)")
print()
print("2. Confidence Threshold Analysis:")
print(f"   - Set to {CONFIDENCE_THRESHOLD}")
print(f"   - Triggered on {recursion_triggered:.1f}% of tokens")
print("   - Can tune this for speed/quality tradeoff")
print()
print("3. Next Experiments:")
print("   - Ablation: Run with use_confidence=False (fixed recursion)")
print("   - Sweep confidence thresholds: {0.5, 0.6, 0.7, 0.8, 0.9}")
print("   - Try different inner steps: {2, 4, 6, 8}")
print("   - Add outer deep recursion (T > 1)")
print("   - Compare to baseline Tiny VLM (no TRM)")
print(RULE)

print("\n✓ Notebook complete!")