# VLM Training with PixMo QA Dataset - FIXED VERSION

This notebook implements a **Vision-Language Model for Question Answering** using:
1. Pretrained aligned vision encoder (CLIP + projections + MRL) - FROZEN
2. **Pretrained Qwen2.5-7B-Instruct decoder with LoRA** - TRAINABLE
3. PixMo QA dataset with question-answer pairs
4. Optional TRM latent recursion on top of Qwen

## Key Fixes from Original:

✅ **Pretrained Decoder**: Uses Qwen2.5-7B instead of random 34M TRM

✅ **Proper Training**: Standard autoregressive loss with prefix_embeds

✅ **Baseline Mode**: Can disable TRM recursion for comparison

✅ **Debug Tools**: Text-only sanity check, parameter audit, first-batch logging

## Architecture:

```
Image (B, 3, 336, 336)
  ↓
Aligned Vision Encoder [FROZEN]
  ↓ (B, 577, 4096)
Vision Projection (4096 → d_qwen) [TRAINABLE]
  ↓ (B, 577, d_qwen)
Qwen2.5-7B Decoder + LoRA [TRAINABLE]
  ↓
Token Layout: [IMG_TOKENS] [QUESTION] [ANSWER]
Loss: Only on answer tokens (vision/question masked with -100)
```
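Here is a minimal sketch of what the -100 masking does, using random logits in place of Qwen. All sizes are arbitrary. `F.cross_entropy` skips every position whose label equals `ignore_index=-100`, so after the causal shift the loss is averaged only over the answer tokens that remain.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 6, 11                      # batch, sequence length, toy vocab size
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))
labels[:, :3] = -100                    # image/question positions are never supervised
labels[1, -1] = -100                    # a padded answer position in the second example

# Causal LM shift: the logits at position t predict the token at position t + 1
shift_logits = logits[:, :-1].reshape(-1, V)
shift_labels = labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# Same value computed explicitly over the supervised positions only
keep = shift_labels != -100
manual = F.cross_entropy(shift_logits[keep], shift_labels[keep])
print(int(keep.sum()), torch.allclose(loss, manual))
```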

Optional: TRM latent recursion wrapper on top of Qwen
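The recursion in `QwenVLM.latent_recursion` concatenates context `x`, answer `y`, and latent `z` along the sequence axis, runs the TRM layers, and then splits the result back into three parts. The toy sketch below reproduces only that shape bookkeeping. It assumes `nn.TransformerEncoderLayer` as a stand-in for `TRMLayer`, whose internals are not shown here. This stand-in attends bidirectionally. If `TRMLayer` also lacks a causal mask, answer position i can see token i + 1 under teacher forcing, which would leak the label. The notebook does not show whether `TRMLayer` applies such a mask.

```python
import torch
import torch.nn as nn

def latent_recursion_step(layers, x, y, z):
    """One joint refinement step over [context | answer | latent] (toy stand-in layers)."""
    hidden = torch.cat([x, y, z], dim=1)
    for layer in layers:
        hidden = layer(hidden)
    L_ctx, L_ans = x.shape[1], y.shape[1]
    return hidden[:, :L_ctx], hidden[:, L_ctx:L_ctx + L_ans], hidden[:, L_ctx + L_ans:]

torch.manual_seed(0)
d = 16
layers = nn.ModuleList(
    [nn.TransformerEncoderLayer(d_model=d, nhead=2, batch_first=True) for _ in range(2)]
)
B, L_ctx, L_ans = 2, 5, 3
x = torch.randn(B, L_ctx, d)            # [vision, question] context
y = torch.randn(B, L_ans, d)            # answer embeddings to refine
z = torch.zeros(1, 1, d).expand(B, L_ans, -1)  # latent state, like z_init.expand(...)

for _ in range(4):                      # num_recursion_steps
    x, y, z = latent_recursion_step(layers, x, y, z)
print(x.shape, y.shape, z.shape)
```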

## 1. Setup and Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Add src to path
src_path = Path.cwd().parent / "src"
sys.path.insert(0, str(src_path))
print(f"Added to path: {src_path}")

Added to path: /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/src


In [3]:
# Standard libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import wandb
from typing import Optional, Dict, List, Tuple
from collections import Counter
import string
import warnings
warnings.filterwarnings('ignore')

# Modular imports from edge_glass_modular
from config import load_config
from encoders.vision import VisionEncoder
from decoders.qwen import QwenDecoder
from decoders.trm import TRMConfig, TRMLayer, RMSNorm
from data.dataset_builder import PixmoQADataset
from data.transforms import get_image_transforms
from models.alignment import MultimodalAlignmentModel

# Set matplotlib style
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Device info
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA H200
GPU Memory: 150.11 GB
Using device: cuda


## 2. Load Configuration

In [4]:
# Load experiment config
config_path = "../configs/trm_vlm_qa.yaml"
config = load_config(config_path)

print(f"Loaded config: {config.name}")
print(f"\nDataset:")
print(f"  Train: {config.dataset.train_parquet}")
print(f"  Val: {config.dataset.val_parquet}")
print(f"  Image size: {config.dataset.image_size}")
print(f"  Batch size: {config.dataset.batch_size}")

print(f"\nDecoder:")
print(f"  Model: {config.decoder.model_name}")
print(f"  Use LoRA: {config.decoder.use_lora}")
print(f"  Load in 8bit: {config.decoder.load_in_8bit}")

# Resolve checkpoint root (works from repo root or notebooks directory)
CKPT_ROOT_CANDIDATES = [
    Path.cwd() / 'checkpoints',
    Path.cwd().parent / 'checkpoints',
    Path.cwd() / 'edge_glass_modular/notebooks/checkpoints',
    Path.cwd().parent / 'edge_glass_modular/notebooks/checkpoints',
]
CKPT_ROOT = next((p for p in CKPT_ROOT_CANDIDATES if p.exists()), None)
if CKPT_ROOT is None:
    raise FileNotFoundError('No checkpoint directory found; expected one of: ' + ', '.join(str(p) for p in CKPT_ROOT_CANDIDATES))
print(f"Checkpoint root: {CKPT_ROOT}")


Loaded config: trm_vlm_qa

Dataset:
  Train: /home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo_alignment/pixmo_qa_mixed_train.parquet
  Val: /home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo_alignment/pixmo_qa_mixed_val.parquet
  Image size: 336
  Batch size: 16

Decoder:
  Model: Qwen/Qwen2.5-7B-Instruct
  Use LoRA: True
  Load in 8bit: False
Checkpoint root: /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/notebooks/checkpoints


## 3. Load Pretrained Aligned Vision Encoder (FROZEN)

In [5]:
# Load alignment config
alignment_config_path = "../configs/pixmo_alignment.yaml"
alignment_config = load_config(alignment_config_path)

# Build aligned model
print("Loading aligned vision encoder...")
aligned_model = MultimodalAlignmentModel(alignment_config).to(device)

# Load checkpoint
checkpoint_path = CKPT_ROOT / 'pixmo_alignment/checkpoint_best.pt'
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
aligned_model.load_state_dict(checkpoint['model_state_dict'])
aligned_model.eval()

# FREEZE all parameters
for param in aligned_model.parameters():
    param.requires_grad = False

print(f"✓ Loaded aligned model from {checkpoint_path}")
print(f"  Checkpoint epoch: {checkpoint.get('epoch', 'N/A')}")
print(f"  Val loss: {checkpoint.get('best_val_loss', 0.0):.4f}")

# Get vision encoder output dimension
vision_token_dim = alignment_config.vision_encoder.projection_dim
print(f"  Vision output: (B, num_tokens, {vision_token_dim})")

Loading aligned vision encoder...


trainable params: 40,370,176 || all params: 7,655,986,688 || trainable%: 0.5273
✓ Loaded aligned model from /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_modular/notebooks/checkpoints/pixmo_alignment/checkpoint_best.pt
  Checkpoint epoch: 0
  Val loss: 0.0000
  Vision output: (B, num_tokens, 4096)


## 4. Vision Encoding Helper Function

In [6]:
@torch.no_grad()
def encode_images(images: torch.Tensor) -> torch.Tensor:
    """Encode images to vision tokens using frozen aligned encoder.
    
    Args:
        images: (B, 3, H, W)
    
    Returns:
        vision_tokens: (B, num_tokens, vision_token_dim)
    """
    vision_output = aligned_model.vision_encoder(images, return_sequence=True)
    if vision_output.sequence is None:
        raise ValueError("Vision encoder did not return sequence embeddings")
    return vision_output.sequence

# Test
test_img = torch.randn(2, 3, 336, 336).to(device)
test_vision_tokens = encode_images(test_img)
print(f"Vision tokens shape: {test_vision_tokens.shape}")
print(f"Expected: (2, num_tokens, {vision_token_dim})")

Vision tokens shape: torch.Size([2, 577, 4096])
Expected: (2, num_tokens, 4096)
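The count of 577 tokens matches a ViT that uses 14-pixel patches at 336 px plus one CLS token, which is the CLIP ViT-L/14-336 layout. The patch size is an assumption here, because the encoder config is not printed. The arithmetic is short:

```python
def num_vision_tokens(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Tokens from a square ViT: (image_size / patch_size)^2 patches, plus an optional CLS."""
    grid = image_size // patch_size
    return grid * grid + (1 if cls_token else 0)

print(num_vision_tokens(336, 14))  # 24 x 24 patches + 1 CLS
```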


## 5. QwenVLM Model Class

**Main VLM wrapper with optional TRM recursion**

In [7]:
class QwenVLM(nn.Module):
    """Qwen-based VLM with optional TRM latent recursion.
    
    Features:
    - Pretrained Qwen2.5 decoder with LoRA
    - Vision token projection to Qwen hidden dim
    - Proper prefix_embeds integration
    - Standard autoregressive loss with -100 masking
    - Optional TRM latent recursion on top
    """
    
    def __init__(
        self,
        qwen_decoder: QwenDecoder,
        vision_token_dim: int = 4096,
        use_trm_recursion: bool = False,
        num_trm_layers: int = 2,
        num_recursion_steps: int = 4,
        confidence_threshold: float = 0.75,
    ):
        super().__init__()
        
        self.qwen = qwen_decoder
        self.hidden_dim = qwen_decoder.hidden_dim
        self.use_trm_recursion = use_trm_recursion
        self.num_recursion_steps = num_recursion_steps
        self.confidence_threshold = confidence_threshold
        
        # Vision projection: 4096 → Qwen hidden dim
        self.vision_proj = nn.Linear(vision_token_dim, self.hidden_dim)
        
        # Optional TRM recursion components
        if use_trm_recursion:
            trm_config = TRMConfig(
                vocab_size=qwen_decoder.vocab_size,
                hidden_dim=self.hidden_dim,
                num_layers=num_trm_layers,
                num_heads=8,
                max_seq_len=2048,
            )
            
            self.trm_layers = nn.ModuleList([
                TRMLayer(trm_config) for _ in range(num_trm_layers)
            ])
            self.trm_norm = RMSNorm(self.hidden_dim)
            self.z_init = nn.Parameter(torch.randn(1, 1, self.hidden_dim) * 0.02)
            
            print(f"  ✓ TRM recursion enabled ({num_trm_layers} layers, {num_recursion_steps} steps)")
        else:
            print(f"  ✓ Baseline mode (no TRM recursion)")
    
    def forward(
        self,
        vision_tokens: torch.Tensor,  # (B, K_img, vision_dim)
        question_ids: torch.Tensor,   # (B, L_q)
        answer_ids: torch.Tensor,     # (B, L_a)
        answer_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass with proper autoregressive training.
        
        Token layout: [IMG_TOKENS] [QUESTION_TOKENS] [ANSWER_TOKENS]
        Labels: [-100 for img/question] [answer_ids]
        """
        batch_size = vision_tokens.shape[0]
        
        # Project vision tokens
        device = vision_tokens.device
        qwen_dtype = self.qwen.model.dtype

        # Align aux modules and inputs with Qwen sharded dtype/device
        if self.vision_proj.weight.device != device or self.vision_proj.weight.dtype != qwen_dtype:
            self.vision_proj = self.vision_proj.to(device=device, dtype=qwen_dtype)
        if self.use_trm_recursion:
            self.trm_layers.to(device=device, dtype=qwen_dtype)
            self.trm_norm.to(device=device, dtype=qwen_dtype)
            if self.z_init.device != device or self.z_init.dtype != qwen_dtype:
                self.z_init.data = self.z_init.data.to(device=device, dtype=qwen_dtype)

        vision_tokens = vision_tokens.to(device=device, dtype=qwen_dtype)

        vision_emb = self.vision_proj(vision_tokens)  # (B, K_img, d_qwen)
        
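        # Prepare input_ids and labels
        input_ids = torch.cat([question_ids, answer_ids], dim=1)  # (B, L_q + L_a)
        
        # Labels: -100 for question tokens and answer padding, real IDs for valid answer tokens
        question_labels = torch.full_like(question_ids, fill_value=-100)
        answer_labels = answer_ids
        if answer_mask is not None:
            answer_labels = answer_ids.masked_fill(answer_mask == 0, -100)
        labels = torch.cat([question_labels, answer_labels], dim=1)  # (B, L_q + L_a)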
        
        if not self.use_trm_recursion:
            # ===== BASELINE: Standard Qwen forward =====
            outputs = self.qwen(
                input_ids=input_ids,
                prefix_embeds=vision_emb,  # Vision tokens as prefix
                labels=labels,
            )
            
            return {
                'loss': outputs.loss,
                'logits': outputs.logits,
                'confidence': None,
            }
        
        else:
            # ===== TRM MODE: Latent recursion on answer embeddings =====
            # Get text embeddings
            text_emb = self.qwen.model.get_input_embeddings()(input_ids)
            
            # Split into question and answer
            L_q = question_ids.shape[1]
            L_a = answer_ids.shape[1]
            question_emb = text_emb[:, :L_q, :]
            answer_emb = text_emb[:, L_q:, :]
            
            # Context: [vision, question]
            x = torch.cat([vision_emb, question_emb], dim=1)  # (B, K_img + L_q, d)
            
            # Answer to be refined
            y = answer_emb  # (B, L_a, d)
            
            # Initialize latent state
            z = self.z_init.expand(batch_size, L_a, -1)  # (B, L_a, d)
            
            # Latent recursion
            for _ in range(self.num_recursion_steps):
                x, y, z = self.latent_recursion(x, y, z)
            
            # Final logits from refined answer
            y = self.trm_norm(y)
            lm_head = self.qwen.model.lm_head
            logits_answer = lm_head(y)  # (B, L_a, vocab)
            
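            # Compute loss on answer only; padded answer positions are ignored
            answer_labels = answer_ids
            if answer_mask is not None:
                answer_labels = answer_ids.masked_fill(answer_mask == 0, -100)
            shift_logits = logits_answer[:, :-1, :].contiguous()
            shift_labels = answer_labels[:, 1:].contiguous()
            
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # Align labels to logits device in sharded setups
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )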
            
            # Compute confidence
            probs = torch.softmax(logits_answer, dim=-1)
            max_probs = torch.max(probs, dim=-1)[0]
            confidence = torch.mean(max_probs, dim=-1)
            
            return {
                'loss': loss,
                'logits': logits_answer,
                'confidence': confidence,
            }
    
    def latent_recursion(
        self,
        x: torch.Tensor,  # Context: (B, L_ctx, d)
        y: torch.Tensor,  # Answer: (B, L_ans, d)
        z: torch.Tensor,  # Latent: (B, L_ans, d)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single TRM latent recursion step."""
        # Concatenate [x, y, z]
        concat = torch.cat([x, y, z], dim=1)
        
        # Pass through TRM layers
        hidden = concat
        for layer in self.trm_layers:
            hidden = layer(hidden)
        
        # Split back
        L_ctx = x.shape[1]
        L_ans = y.shape[1]
        
        x_out = hidden[:, :L_ctx, :]
        y_out = hidden[:, L_ctx:L_ctx+L_ans, :]
        z_out = hidden[:, L_ctx+L_ans:, :]
        
        return x_out, y_out, z_out
    
    @torch.no_grad()
    def generate(
        self,
        vision_tokens: torch.Tensor,
        question_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 0.0,
        use_confidence: bool = False,
        return_stats: bool = False,
    ):
        """Generate answers with optional TRM recursion."""
        # Project vision
        device = vision_tokens.device
        target_dtype = self.qwen.model.dtype

        # Align vision projection with Qwen sharded dtype/device
        if self.vision_proj.weight.device != device or self.vision_proj.weight.dtype != target_dtype:
            self.vision_proj = self.vision_proj.to(device=device, dtype=target_dtype)

        vision_tokens = vision_tokens.to(
            device=self.vision_proj.weight.device,
            dtype=self.vision_proj.weight.dtype,
        )

        vision_emb = self.vision_proj(vision_tokens)
        
        if not self.use_trm_recursion:
            # Baseline: standard Qwen generation
            outputs = self.qwen.generate(
                input_ids=question_ids,
                prefix_embeds=vision_emb,
                max_new_tokens=max_new_tokens,
                temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
                do_sample=temperature > 0,
            )
            
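            # The text-only check in section 7 shows QwenDecoder.generate returning only the
            # continuation, without the prompt. In that case, slicing off vision + question length
            # would leave an empty tensor, so strip the prompt only if it was actually echoed back.
            prompt_len = question_ids.shape[1] + vision_emb.shape[1]
            generated_ids = outputs[:, prompt_len:] if outputs.shape[1] > prompt_len else outputs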
            
            if return_stats:
                return {
                    'predictions': generated_ids,
                    'confidences': torch.ones_like(generated_ids, dtype=torch.float),
                    'recursion_steps': torch.ones_like(generated_ids, dtype=torch.float),
                }
            return generated_ids
        
        else:
            # TRM mode: autoregressive with recursion
            batch_size = vision_tokens.shape[0]
            generated_ids = []
            confidences = []
            recursion_steps_used = []
            
            for step in range(max_new_tokens):
                # Current answer sequence
                if len(generated_ids) == 0:
                    answer_ids = torch.tensor(
                        [[self.qwen.tokenizer.pad_token_id]],
                        device=vision_tokens.device
                    ).expand(batch_size, 1)
                else:
                    answer_ids = torch.stack(generated_ids, dim=1)
                
                # Forward
                outputs = self.forward(vision_tokens, question_ids, answer_ids)
                logits = outputs['logits'][:, -1, :]  # Last position
                
                if temperature > 0:
                    logits = logits / temperature
                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_token = torch.argmax(logits, dim=-1)
                
                generated_ids.append(next_token)
                
                # Track confidence
                probs = torch.softmax(logits, dim=-1)
                conf = torch.max(probs, dim=-1)[0]
                confidences.append(conf)
                
                recursion_steps_used.append(
                    self.num_recursion_steps if self.use_trm_recursion else 1
                )
            
            generated_ids = torch.stack(generated_ids, dim=1)
            
            if return_stats:
                return {
                    'predictions': generated_ids,
                    'confidences': torch.stack(confidences, dim=1),
                    'recursion_steps': torch.tensor(
                        recursion_steps_used
                    ).unsqueeze(0).expand(batch_size, -1).float(),
                }
            return generated_ids

print("✓ QwenVLM class defined")

✓ QwenVLM class defined
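The notebook does not show how `QwenDecoder` consumes `prefix_embeds`. One common approach prepends the projected vision tokens to the token embeddings and pads the labels with -100 for the prefix positions. The total length then becomes K_img + L_q + L_a, which is 577 + L_q + L_a here. The sketch below illustrates that approach under stated assumptions: a toy `nn.Embedding` stands in for Qwen's embedding table, and `build_prefix_inputs` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def build_prefix_inputs(embed, input_ids, prefix_embeds, labels):
    """Hypothetical helper: prepend prefix embeddings and mask their labels with -100."""
    tok_emb = embed(input_ids)                                    # (B, L, d)
    inputs_embeds = torch.cat([prefix_embeds, tok_emb], dim=1)    # (B, K + L, d)
    prefix_labels = torch.full(prefix_embeds.shape[:2], -100,
                               dtype=labels.dtype, device=labels.device)
    full_labels = torch.cat([prefix_labels, labels], dim=1)       # (B, K + L)
    return inputs_embeds, full_labels

torch.manual_seed(0)
vocab, d = 100, 8
embed = nn.Embedding(vocab, d)
B, K, Lq, La = 2, 4, 3, 2
prefix = torch.randn(B, K, d)                    # stands in for vision_proj output
question_ids = torch.randint(0, vocab, (B, Lq))
answer_ids = torch.randint(0, vocab, (B, La))
input_ids = torch.cat([question_ids, answer_ids], dim=1)
labels = torch.cat([torch.full_like(question_ids, -100), answer_ids], dim=1)

inputs_embeds, full_labels = build_prefix_inputs(embed, input_ids, prefix, labels)
print(inputs_embeds.shape, full_labels.shape, int((full_labels != -100).sum()))
```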


## 6. Initialize Pretrained Qwen Decoder

In [8]:
# ========== CONFIGURATION ==========
USE_TRM_RECURSION = False  # Start with baseline, then try True
NUM_TRM_LAYERS = 4         # Only used if TRM recursion enabled
NUM_RECURSION_STEPS = 4
CONFIDENCE_THRESHOLD = 0.75
# ====================================

print("="*60)
print("INITIALIZING QWEN DECODER")
print("="*60)

# Load Qwen decoder
print(f"\nLoading: {config.decoder.model_name}")
print(f"  LoRA: {config.decoder.use_lora}")
print(f"  8-bit: {config.decoder.load_in_8bit}")

qwen_decoder = QwenDecoder(
    model_name=config.decoder.model_name,
    load_in_8bit=config.decoder.load_in_8bit,
    load_in_4bit=False,
    use_lora=config.decoder.use_lora,
    lora_r=config.decoder.get('lora_r', 32),
    lora_alpha=config.decoder.get('lora_alpha', 64),
    lora_dropout=config.decoder.get('lora_dropout', 0.1),
    device_map="auto",
)

print(f"\n✓ Qwen decoder loaded")
print(f"  Hidden dim: {qwen_decoder.hidden_dim}")
print(f"  Vocab size: {qwen_decoder.vocab_size}")

# Create QwenVLM wrapper
print(f"\nCreating QwenVLM wrapper")
print(f"  Vision token dim: {vision_token_dim}")
print(f"  Use TRM recursion: {USE_TRM_RECURSION}")

model = QwenVLM(
    qwen_decoder=qwen_decoder,
    vision_token_dim=vision_token_dim,
    use_trm_recursion=USE_TRM_RECURSION,
    num_trm_layers=NUM_TRM_LAYERS,
    num_recursion_steps=NUM_RECURSION_STEPS,
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

print(f"\n✓ QwenVLM model created")
print("="*60)

INITIALIZING QWEN DECODER

Loading: Qwen/Qwen2.5-7B-Instruct
  LoRA: True
  8-bit: False


trainable params: 40,370,176 || all params: 7,655,986,688 || trainable%: 0.5273

✓ Qwen decoder loaded
  Hidden dim: 3584
  Vocab size: 152064

Creating QwenVLM wrapper
  Vision token dim: 4096
  Use TRM recursion: False
  ✓ Baseline mode (no TRM recursion)

✓ QwenVLM model created


## 7. Text-Only Sanity Check (BEFORE Training)

**Critical test**: Verify pretrained Qwen can generate coherent English

In [9]:
@torch.no_grad()
def text_only_sanity_check(decoder, prompts, max_tokens=20):
    """Test decoder on text-only prompts without vision."""
    print("\n" + "="*60)
    print("TEXT-ONLY SANITY CHECK (Before Training)")
    print("="*60)
    
    for prompt in prompts:
        # Encode
        inputs = decoder.tokenizer(prompt, return_tensors='pt').to(device)
        
        # Generate
        outputs = decoder.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
        )
        
        # Decode
        generated = decoder.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated}")
        print("-" * 60)
    
    print("\n✓ If coherent English → Decoder works!")
    print("✗ If garbage → Decoder loading issue")
    print("="*60)

# Test prompts
test_prompts = [
    "Question: What is 2 + 2? Answer:",
    "Question: What color is the sky? Answer:",
    "The capital of France is",
]

text_only_sanity_check(qwen_decoder, test_prompts)


TEXT-ONLY SANITY CHECK (Before Training)

Prompt: Question: What is 2 + 2? Answer:
Generated:  2 + 2 is equal to 4. This is a basic arithmetic operation in mathematics where
------------------------------------------------------------

Prompt: Question: What color is the sky? Answer:
Generated:  The sky is typically blue during the day, except at sunrise, sunset, and on overcast days
------------------------------------------------------------

Prompt: The capital of France is
Generated:  Paris, and the capital of Germany is Berlin. Which of the following statements is true?
A.
------------------------------------------------------------

✓ If coherent English → Decoder works!
✗ If garbage → Decoder loading issue


## 8. Parameter Audit

Verify which parameters are frozen vs trainable

In [10]:
print("="*60)
print("PARAMETER AUDIT")
print("="*60)

trainable_params = []
frozen_params = []

# Audit QwenVLM model
for name, param in model.named_parameters():
    if param.requires_grad:
        trainable_params.append((name, param.numel()))
    else:
        frozen_params.append((name, param.numel()))

# Audit aligned model
for name, param in aligned_model.named_parameters():
    full_name = f"aligned_model.{name}"
    if param.requires_grad:
        trainable_params.append((full_name, param.numel()))
    else:
        frozen_params.append((full_name, param.numel()))

total_trainable = sum(count for _, count in trainable_params)
total_frozen = sum(count for _, count in frozen_params)

print(f"\n🟢 TRAINABLE ({len(trainable_params)} groups):")
for name, count in trainable_params[:20]:
    print(f"  {name}: {count:,}")
if len(trainable_params) > 20:
    print(f"  ... and {len(trainable_params) - 20} more")

print(f"\n🔴 FROZEN ({len(frozen_params)} groups):")
for name, count in frozen_params[:10]:
    print(f"  {name}: {count:,}")
if len(frozen_params) > 10:
    print(f"  ... and {len(frozen_params) - 10} more")

total = total_trainable + total_frozen
print(f"\n📊 SUMMARY:")
print(f"  Total: {total:,}")
print(f"  Trainable: {total_trainable:,} ({100*total_trainable/total:.2f}%)")
print(f"  Frozen: {total_frozen:,} ({100*total_frozen/total:.2f}%)")

print(f"\n✓ VERIFICATION:")
print(f"  Aligned model frozen: {all(not p.requires_grad for p in aligned_model.parameters())}")
print(f"  Vision proj trainable: {'vision_proj' in str([n for n, _ in trainable_params])}")

# LoRA adapters are trainable in both modes, so check them unconditionally
print(f"  Qwen LoRA trainable: {any('lora' in n.lower() for n, _ in trainable_params)}")
if USE_TRM_RECURSION:
    print(f"  TRM layers trainable: {'trm_layers' in str(trainable_params)}")
    print(f"  z_init trainable: {'z_init' in str([n for n, _ in trainable_params])}")

print("="*60)

PARAMETER AUDIT

🟢 TRAINABLE (226 groups):
  qwen.model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: 229,376
  qwen.model.base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: 229,376
  qwen.model.base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight: 229,376
  qwen.model.base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight: 32,768
  qwen.model.base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: 229,376
  qwen.model.base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: 32,768
  qwen.model.base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight: 229,376
  qwen.model.base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight: 229,376
  qwen.model.base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: 229,376
  qwen.model.base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: 229,376
  qwen.model.base_model.model
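The logged shapes determine the adapter rank. Since 229,376 = 64 × 3584, r = 64. The 32,768-element `lora_B` tensors for `k_proj`/`v_proj` imply a 512-dim output, which is consistent with grouped-query attention in Qwen2.5-7B. The arithmetic below assumes LoRA on q/k/v/o only, in all 28 layers, as the listed names suggest. Under that assumption it reproduces the `trainable params: 40,370,176` line printed when the decoder loaded. It also accounts for the 226 trainable tensors: 28 × 8 LoRA matrices, plus the `vision_proj` weight and bias.

```python
def lora_params(r: int, d_in: int, d_out: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to each adapted Linear layer."""
    return r * d_in + d_out * r

r, hidden, kv_dim, num_layers = 64, 3584, 512, 28   # inferred from the audit shapes above
per_layer = (
    lora_params(r, hidden, hidden)    # q_proj
    + lora_params(r, hidden, kv_dim)  # k_proj
    + lora_params(r, hidden, kv_dim)  # v_proj
    + lora_params(r, hidden, hidden)  # o_proj
)
total_lora = per_layer * num_layers
vision_proj_params = 4096 * hidden + hidden         # nn.Linear(4096, 3584) weight + bias
print(per_layer, total_lora, total_lora + vision_proj_params)
```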

## 9. Setup Dataset and Dataloaders

In [11]:
def collate_qa_batch(batch):
    images = [item["image"] for item in batch]
    questions = [item["question_ids"] for item in batch]
    answers = [item["answer_ids"] for item in batch]
    
    # Process images
    # Stack if already tensors (from dataset transform)
    if isinstance(images[0], torch.Tensor):
        images = torch.stack(images)
    
    # Right-pad questions and answers to the longest sequence in the batch
    max_q_len = max([q.size(0) for q in questions])
    max_a_len = max([a.size(0) for a in answers])
    
    # Create padded tensors
    bs = len(batch)
    padded_questions = torch.full((bs, max_q_len), tokenizer.pad_token_id, dtype=torch.long)
    padded_answers = torch.full((bs, max_a_len), tokenizer.pad_token_id, dtype=torch.long)
    
    # Create answer mask (1 for valid, 0 for pad)
    answer_mask = torch.zeros((bs, max_a_len), dtype=torch.long)
    
    for i in range(bs):
        # Both sequences are right-padded. Because forward() concatenates [question][answer],
        # question padding sits between the two, and the model attends to those pad tokens
        # unless an attention mask excludes them.
        q_len = questions[i].size(0)
        padded_questions[i, :q_len] = questions[i]
        
        a_len = answers[i].size(0)
        padded_answers[i, :a_len] = answers[i]
        
        # Set mask for valid answer tokens
        answer_mask[i, :a_len] = 1
        
    return {
        "images": images,
        "question_ids": padded_questions,
        "answer_ids": padded_answers,
        "answer_mask": answer_mask # Return mask
    }


Loaded 8400 samples from /home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo_alignment/pixmo_qa_mixed_train.parquet
Loaded 1800 samples from /home/hice1/vchopra37/scratch/projects/edge_glass/dataset/final_dataset/pixmo_alignment/pixmo_qa_mixed_val.parquet
Dataset sizes:
  Train: 8,400
  Val: 1,800

DataLoader:
  Batch size: 16
  Train batches: 525
  Val batches: 113
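The collate function returns a mask for answers but not for questions. The sketch below shows how a per-position mask for the full [IMG][Q][A] sequence could be assembled, and where the question padding ends up. It uses toy token IDs, a hypothetical `pad_right` helper, and pad id 0, all of which are assumptions for illustration. The notebook does not show whether `QwenDecoder`'s forward accepts an `attention_mask`, although its `generate` does (section 7).

```python
import torch

def pad_right(seqs, pad_id):
    """Hypothetical helper: right-pad 1-D LongTensors and return (ids, mask)."""
    max_len = max(s.size(0) for s in seqs)
    ids = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, s in enumerate(seqs):
        ids[i, : s.size(0)] = s
        mask[i, : s.size(0)] = 1
    return ids, mask

PAD = 0  # toy pad id for this example only
questions = [torch.tensor([5, 6, 7]), torch.tensor([8])]
answers = [torch.tensor([9]), torch.tensor([10, 11])]
q_ids, q_mask = pad_right(questions, PAD)
a_ids, a_mask = pad_right(answers, PAD)

# Text part of the sequence: the second example has pads between question and answer
text_ids = torch.cat([q_ids, a_ids], dim=1)

num_img = 2  # 577 in the real model
img_mask = torch.ones(len(questions), num_img, dtype=torch.long)
attention_mask = torch.cat([img_mask, q_mask, a_mask], dim=1)
print(text_ids.tolist())
print(attention_mask.tolist())
```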


## 10. Evaluation Metrics

In [12]:
def normalize_answer(s: str) -> str:
    """Normalize answer for evaluation."""
    s = ''.join(ch for ch in s if ch not in string.punctuation)
    s = s.lower().strip()
    s = ' '.join([w for w in s.split() if w not in {'a', 'an', 'the'}])
    return s

def compute_exact_match(pred: str, target: str) -> float:
    """Compute exact match score."""
    return float(normalize_answer(pred) == normalize_answer(target))

def compute_f1(pred: str, target: str) -> float:
    """Compute token-level F1."""
    pred_tokens = normalize_answer(pred).split()
    target_tokens = normalize_answer(target).split()
    
    if len(pred_tokens) == 0 or len(target_tokens) == 0:
        return float(pred_tokens == target_tokens)
    
    common = Counter(pred_tokens) & Counter(target_tokens)
    num_common = sum(common.values())
    
    if num_common == 0:
        return 0.0
    
    precision = num_common / len(pred_tokens)
    recall = num_common / len(target_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    
    return f1

def evaluate_qa(predictions: List[str], targets: List[str]) -> Dict[str, float]:
    """Evaluate QA predictions."""
    em_scores = [compute_exact_match(p, t) for p, t in zip(predictions, targets)]
    f1_scores = [compute_f1(p, t) for p, t in zip(predictions, targets)]
    
    return {
        'em': np.mean(em_scores) * 100,
        'f1': np.mean(f1_scores) * 100,
    }

print("✓ Evaluation metrics defined")

✓ Evaluation metrics defined
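`compute_f1` intersects two `Counter`s, which keeps the minimum count of each token. A repeated word is therefore credited only as many times as it appears in both strings. The worked example below operates on already-normalized token lists. The name `token_f1` is for illustration only.

```python
from collections import Counter

def token_f1(pred_tokens, target_tokens):
    """Token-level F1 with multiset overlap (same formula as compute_f1, minus normalization)."""
    common = Counter(pred_tokens) & Counter(target_tokens)  # min count per shared token
    num_common = sum(common.values())
    if num_common == 0:
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(target_tokens)
    return 2 * precision * recall / (precision + recall)

# "red" overlaps once and "car" overlaps once: P = R = 2/3, so F1 = 2/3
print(token_f1(["red", "red", "car"], ["red", "car", "car"]))
# One shared token out of two on each side: P = R = 1/2, so F1 = 0.5
print(token_f1(["red", "car"], ["red", "truck"]))
```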


## 11. Debug Helper Functions

In [13]:
def debug_first_batch(batch, vision_tokens, outputs, tokenizer):
    """Debug first training batch."""
    print("\n" + "="*60)
    print("FIRST BATCH DEBUG")
    print("="*60)
    
    print(f"\n📦 Shapes:")
    print(f"  Images: {batch['images'].shape}")
    print(f"  Vision tokens: {vision_tokens.shape}")
    print(f"  Question IDs: {batch['question_ids'].shape}")
    print(f"  Answer IDs: {batch['answer_ids'].shape}")
    
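    print(f"\n📝 First example:")
    # collate_qa_batch does not return raw strings, so guard against a KeyError
    if 'questions' in batch and 'answers' in batch:
        print(f"  Q: {batch['questions'][0][:100]}...")
        print(f"  A: {batch['answers'][0][:100]}...")
    else:
        print("  (raw question/answer strings not in batch; see decoded tokens below)")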
    
    print(f"\n🔤 Decoded tokens:")
    q_dec = tokenizer.decode(batch['question_ids'][0], skip_special_tokens=False)
    a_dec = tokenizer.decode(batch['answer_ids'][0], skip_special_tokens=False)
    print(f"  Question: {q_dec[:150]}...")
    print(f"  Answer: {a_dec[:100]}...")
    
    print(f"\n📊 Outputs:")
    print(f"  Loss: {outputs['loss'].item():.4f}")
    if outputs.get('logits') is not None:
        print(f"  Logits: {outputs['logits'].shape}")
    if outputs.get('confidence') is not None:
        print(f"  Confidence: {outputs['confidence'].mean().item():.3f}")
    
    # Token counts
    num_q = (batch['question_ids'][0] != tokenizer.pad_token_id).sum().item()
    num_a = (batch['answer_ids'][0] != tokenizer.pad_token_id).sum().item()
    num_img = vision_tokens.shape[1]
    
    print(f"\n📏 Token counts:")
    print(f"  Vision: {num_img}")
    print(f"  Question (non-pad): {num_q}")
    print(f"  Answer (non-pad): {num_a}")
    print(f"  Supervised: {num_a} ({100*num_a/(num_img+num_q+num_a):.1f}%)")
    
    print(f"\n✓ Expected:")
    print(f"  - Loss well below ln(vocab) ≈ 11.9 (the uniform-guess level)")
    print(f"  - Supervised = answer only")
    print("="*60 + "\n")

print("✓ Debug helpers defined")

✓ Debug helpers defined


## 12. Training Setup

In [14]:
# Training config
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.05
MAX_GRAD_NORM = 1.0
LOG_EVERY = 20

# Optimizer (only trainable params: LoRA adapters, vision_proj, and TRM modules if enabled)
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=WEIGHT_DECAY,
)

# Scheduler: linear warmup, then cosine decay to 10% of the peak LR
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)

def get_lr_scheduler(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.1 + 0.9 * 0.5 * (1 + np.cos(np.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = get_lr_scheduler(optimizer, warmup_steps, total_steps)

print(f"Training config:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  LR: {LEARNING_RATE}")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")

Training config:
  Epochs: 10
  LR: 0.0001
  Total steps: 5250
  Warmup steps: 262
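Evaluating the same multiplier formula at a few milestones for this run (5,250 total steps, 262 warmup steps) shows the schedule's shape. The LR rises linearly during warmup to `LEARNING_RATE`, then follows a cosine decay down to a floor of 10% of it. This standalone version uses `math` instead of NumPy, and its `max(1, ...)` guards are an addition for the zero-warmup case.

```python
import math

def lr_multiplier(step, warmup_steps, total_steps, floor=0.1):
    """Linear warmup to 1.0, then cosine decay from 1.0 down to `floor`."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return floor + (1 - floor) * 0.5 * (1 + math.cos(math.pi * progress))

total_steps, warmup_steps = 5250, 262
for s in [0, 131, 262, 2756, 5250]:
    print(s, round(lr_multiplier(s, warmup_steps, total_steps), 4))
```

Step 2756 is the midpoint of the decay phase, where the multiplier is halfway between 1.0 and the 0.1 floor.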


## 13. Training Loop

In [15]:
# Initialize wandb
USE_WANDB = True
if USE_WANDB:
    run_name = f"qwen_vlm_{'trm' if USE_TRM_RECURSION else 'baseline'}"
    wandb.init(
        project="edge_glass_qwen_vlm",
        name=run_name,
        config={
            'use_trm_recursion': USE_TRM_RECURSION,
            'num_trm_layers': NUM_TRM_LAYERS if USE_TRM_RECURSION else 0,
            'num_recursion_steps': NUM_RECURSION_STEPS if USE_TRM_RECURSION else 0,
            'learning_rate': LEARNING_RATE,
            'batch_size': config.dataset.batch_size,
            'decoder': config.decoder.model_name,
        }
    )

# Training state
global_step = 0
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': []}
first_batch_debugged = False

# Checkpoint dir
ckpt_dir = CKPT_ROOT / f"qwen_vlm_qa_{'trm' if USE_TRM_RECURSION else 'baseline'}"
ckpt_dir.mkdir(parents=True, exist_ok=True)



[34m[1mwandb[0m: Currently logged in as: [33mvedaangchopra[0m ([33mvedaangchopra_gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
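# Training loop (short verification run)
# Reuses the optimizer and warmup/cosine scheduler from Section 12 instead of
# creating a fresh AdamW, so betas, weight decay and the LR schedule take effect.
max_steps = 100  # short run for verification; use total_steps for the full schedule

model.train()
progress_bar = tqdm(total=max_steps)

for batch in train_loader:
    images = batch["images"].to(device)
    question_ids = batch["question_ids"].to(device)
    answer_ids = batch["answer_ids"].to(device)
    answer_mask = batch["answer_mask"].to(device)  # selects the answer tokens that receive loss

    # Forward pass: loss is computed on answer tokens only
    outputs = model(
        images=images,
        question_ids=question_ids,
        answer_ids=answer_ids,
        answer_mask=answer_mask,
    )
    loss = outputs.loss

    # Backward pass with gradient clipping, then advance optimizer and scheduler
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()
    scheduler.step()

    global_step += 1
    loss_value = loss.item()
    history['train_loss'].append(loss_value)
    progress_bar.set_description(f"Loss: {loss_value:.4f}")
    progress_bar.update(1)

    if USE_WANDB and global_step % LOG_EVERY == 0:
        wandb.log({'train/loss': loss_value, 'train/lr': scheduler.get_last_lr()[0]}, step=global_step)

    if global_step >= max_steps:
        break

progress_bar.close()

# Validation loss (assumes val batches carry the same keys as train batches)
model.eval()
val_losses = []
with torch.no_grad():
    for batch in val_loader:
        outputs = model(
            images=batch["images"].to(device),
            question_ids=batch["question_ids"].to(device),
            answer_ids=batch["answer_ids"].to(device),
            answer_mask=batch["answer_mask"].to(device),
        )
        val_losses.append(outputs.loss.item())
val_loss = float(np.mean(val_losses))
history['val_loss'].append(val_loss)
print(f"Step {global_step}: val loss {val_loss:.4f}")
if USE_WANDB:
    wandb.log({'val/loss': val_loss}, step=global_step)

# Save the checkpoint that Section 14 loads
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save({'model_state_dict': model.state_dict(), 'step': global_step, 'val_loss': val_loss},
               ckpt_dir / 'checkpoint_best.pt')
model.train()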
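The loop passes `answer_ids` and `answer_mask` to the model and receives a ready-made loss. The architecture overview states that vision and question positions are masked with -100 so that only answer tokens are supervised. The following NumPy sketch shows what that masking means for a single sequence laid out as `[IMG_TOKENS] [QUESTION] [ANSWER]`, combined with shifted next-token cross-entropy that ignores -100. It is not the `QwenVLM` implementation. The helper names, the toy sizes, and the convention that `answer_mask` is 1 for real answer tokens and 0 for padding are all assumptions.

```python
import numpy as np

IGNORE_INDEX = -100  # labels with this value contribute no loss

def build_labels(num_img: int, question_ids, answer_ids, answer_mask) -> np.ndarray:
    """Labels aligned with [IMG_TOKENS][QUESTION][ANSWER]; only unmasked answer tokens are kept."""
    prefix = [IGNORE_INDEX] * (num_img + len(question_ids))
    answer = [tok if keep else IGNORE_INDEX for tok, keep in zip(answer_ids, answer_mask)]
    return np.array(prefix + answer)

def masked_next_token_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy where position t predicts labels[t + 1], skipping IGNORE_INDEX."""
    logits, targets = logits[:-1], labels[1:]  # shift for next-token prediction
    keep = targets != IGNORE_INDEX
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets))[keep], targets[keep]]
    return float(nll.mean())

# Toy sequence: 3 image tokens, 2 question tokens, 3 answer slots (the last one is padding)
labels = build_labels(3, [11, 12], [5, 6, 0], [1, 1, 0])
print(labels.tolist())  # [-100, -100, -100, -100, -100, 5, 6, -100]

vocab_size = 16
uniform_logits = np.zeros((len(labels), vocab_size))
print(masked_next_token_loss(uniform_logits, labels))  # ln(16), about 2.7726
```

The shift means the last question token is trained to predict the first answer token. None of the 3 + 2 prefix positions contributes a target of its own.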


## 14. Full Evaluation

In [17]:
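# Load the best checkpoint if one was saved; otherwise evaluate the in-memory model
best_ckpt_path = ckpt_dir / 'checkpoint_best.pt'
if best_ckpt_path.exists():
    best_ckpt = torch.load(best_ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(best_ckpt['model_state_dict'])
else:
    print(f"No checkpoint found at {best_ckpt_path}; evaluating the current in-memory weights")
model.eval()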

QwenVLM(
  (qwen): QwenDecoder(
    (model): PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): Qwen2ForCausalLM(
          (model): Qwen2Model(
            (embed_tokens): Embedding(152064, 3584)
            (layers): ModuleList(
              (0-27): 28 x Qwen2DecoderLayer(
                (self_attn): Qwen2Attention(
                  (q_proj): lora.Linear(
                    (base_layer): Linear(in_features=3584, out_features=3584, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=3584, out_features=64, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=64, out_features=3584, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B
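The module tree above shows LoRA adapters on the attention projections. For `q_proj`, the frozen 3584×3584 base layer (with bias) is paired with `lora_A` (3584 → 64) and `lora_B` (64 → 3584), so the rank is r = 64. The following NumPy sketch shows what such a layer computes and how few of its parameters are trained. The `alpha / r` scaling follows the usual LoRA formulation. The value `alpha = 128` is an assumption, because the LoRA config is not shown in this section. `B` is zero-initialized, as LoRA conventionally does, so the adapted layer starts out identical to the base layer.

```python
import numpy as np

d_model, rank = 3584, 64  # q_proj sizes from the printed module tree
alpha = 128               # ASSUMPTION: lora_alpha is not shown in this notebook section

def lora_linear(x, W, b, A, B, alpha, rank):
    """y = x W^T + b + (alpha / rank) * x A^T B^T; W, b frozen, A (r x d) and B (d x r) trainable."""
    return x @ W.T + b + (alpha / rank) * (x @ A.T) @ B.T

# Parameter counts for a single q_proj
base_params = d_model * d_model + d_model      # frozen weight + bias
lora_params = rank * d_model + d_model * rank  # trainable lora_A + lora_B
print(f"frozen: {base_params:,}  trainable: {lora_params:,}  ratio: {lora_params / base_params:.2%}")

# Small-scale check with tiny dimensions so it runs instantly
rng = np.random.default_rng(0)
d, r = 8, 2
W, b = rng.normal(size=(d, d)), rng.normal(size=d)
A = rng.normal(size=(r, d)) * 0.01
B = np.zeros((d, r))  # zero-init B: the adapter is a no-op at the start of training
x = rng.normal(size=(4, d))
print(np.allclose(lora_linear(x, W, b, A, B, alpha, r), x @ W.T + b))  # True
```

Per projection, this is roughly 0.46M trainable parameters against about 12.8M frozen ones, which is why LoRA fine-tuning of the 7B decoder fits alongside the frozen vision encoder.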

In [18]:
# Complex VLM question probe on the curve image
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
import torch

model.eval()
probe_image_path = Path("/home/hice1/vchopra37/scratch/projects/edge_glass/complex_curve.png")
probe_question = (
    "Provide a short narrative interpretation of the plotted curve: describe its overall trend, "
    "where the slope changes, and what that implies about acceleration, saturation, or decay "
    "in the underlying relationship."
)
display(Image.open(probe_image_path))  # show the image itself, not its path

In [19]:
probe_image = Image.open(probe_image_path).convert("RGB")
probe_tensor = val_transforms(T.ToTensor()(probe_image)).unsqueeze(0).to(device)
probe_tokens = encode_images(probe_tensor)

probe_question_ids = qwen_decoder.tokenizer(
    probe_question, return_tensors='pt', add_special_tokens=True
).input_ids.to(device)

probe_gen = model.generate(
    vision_tokens=probe_tokens,
    question_ids=probe_question_ids,
    max_new_tokens=96,
    temperature=0.7,
    use_confidence=True,
    return_stats=True,
)

if isinstance(probe_gen, torch.Tensor):
    gen_ids = probe_gen[0].detach().cpu().tolist()
    confidence_trace = None
    recursion_steps = None
else:
    gen_ids = probe_gen["predictions"][0].detach().cpu().tolist()
    confidence_trace = probe_gen.get("confidences")
    if confidence_trace is not None:
        confidence_trace = confidence_trace[0].detach().cpu().tolist()
    recursion_steps = probe_gen.get("recursion_steps")
    if recursion_steps is not None:
        recursion_steps = recursion_steps[0].detach().cpu().tolist()

probe_answer = qwen_decoder.tokenizer.decode(
    gen_ids,
    skip_special_tokens=True,
).strip()

print("Question:", probe_question)
print("Answer:", probe_answer)
if confidence_trace is not None:
    print("Confidence trace:", [round(float(c), 3) for c in confidence_trace])
if recursion_steps is not None:
    print("Recursion steps per token:", recursion_steps)


Question: Provide a short narrative interpretation of the plotted curve: describe its overall trend, where the slope changes, and what that implies about acceleration, saturation, or decay in the underlying relationship.
Answer: 
Confidence trace: []
Recursion steps per token: []
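The probe samples with `temperature=0.7`, while the full evaluation below decodes with `temperature=0.0`. This section does not show how `QwenVLM.generate` handles temperature internally, so the following is a generic NumPy sketch of temperature-scaled decoding, with hypothetical helper names. Logits are divided by the temperature before the softmax, so lower values sharpen the distribution, and a temperature of 0 is treated as greedy argmax.

```python
import numpy as np

def temperature_probs(logits, temperature: float) -> np.ndarray:
    """softmax(logits / temperature); a lower temperature concentrates mass on the top token."""
    scaled = np.asarray(logits, dtype=float) / temperature
    p = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return p / p.sum()

def next_token(logits, temperature: float, rng: np.random.Generator) -> int:
    """Greedy argmax when temperature == 0, otherwise sample from the tempered softmax."""
    if temperature == 0.0:
        return int(np.argmax(logits))
    probs = temperature_probs(logits, temperature)
    return int(rng.choice(len(probs), p=probs))

logits = np.array([2.0, 1.0, 0.5])
rng = np.random.default_rng(0)
print(temperature_probs(logits, 1.0).round(3))
print(temperature_probs(logits, 0.7).round(3))  # more mass on token 0 than at T = 1.0
print(next_token(logits, 0.0, rng))             # 0: greedy decoding always picks the argmax
```

Greedy decoding makes the evaluation deterministic and comparable between the baseline and TRM runs, while the sampled probe can vary from run to run.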


In [None]:
print("\n" + "="*60)
print("FULL EVALUATION")
print("="*60)
print(f"Mode: {'TRM' if USE_TRM_RECURSION else 'Baseline'}")
print("="*60)

all_predictions = []
all_targets = []
all_confidences = []
all_recursion_steps = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating"):
        images = batch['images'].to(device)
        question_ids = batch['question_ids'].to(device)
        answers = batch['answers']
        
        vision_tokens = encode_images(images)
        
        gen_outputs = model.generate(
            vision_tokens,
            question_ids,
            max_new_tokens=32,
            temperature=0.0,
            return_stats=True,
        )
        
        if isinstance(gen_outputs, dict):
            generated_ids = gen_outputs['predictions']
            if gen_outputs.get('confidences') is not None:
                all_confidences.extend(gen_outputs['confidences'].mean(dim=1).cpu().tolist())
            if gen_outputs.get('recursion_steps') is not None:
                all_recursion_steps.extend(gen_outputs['recursion_steps'].mean(dim=1).cpu().tolist())
        else:
            generated_ids = gen_outputs
        
        predictions = qwen_decoder.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        
        all_predictions.extend(predictions)
        all_targets.extend(answers)

# Metrics
metrics = evaluate_qa(all_predictions, all_targets)

print("\n" + "="*60)
print("RESULTS")
print("="*60)
print(f"EM: {metrics['em']:.2f}%")
print(f"F1: {metrics['f1']:.2f}%")

if all_confidences:
    print(f"\nConfidence: {np.mean(all_confidences):.3f} ± {np.std(all_confidences):.3f}")

if all_recursion_steps:
    print(f"\nRecursion: {np.mean(all_recursion_steps):.2f} steps/seq")

print("="*60)

if USE_WANDB:
    wandb.log({'val/em_final': metrics['em'], 'val/f1_final': metrics['f1']})

# Sample predictions
print("\n" + "="*60)
print("SAMPLE PREDICTIONS")
print("="*60)

for i in range(min(10, len(all_predictions))):
    print(f"\n[{i+1}]")
    print(f"  Q: {val_dataset.df.iloc[i]['question'][:80]}...")
    print(f"  Target: {all_targets[i][:80]}...")
    print(f"  Predicted: {all_predictions[i][:80]}...")
    print(f"  EM: {compute_exact_match(all_predictions[i], all_targets[i])}")
    print(f"  F1: {compute_f1(all_predictions[i], all_targets[i]):.3f}")

print("\n" + "="*60)


FULL EVALUATION
Mode: Baseline


Generating:   0%|          | 0/113 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
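`evaluate_qa`, `compute_exact_match` and `compute_f1` are defined earlier in the notebook, outside this section. The `Counter` and `string` imports at the top suggest SQuAD-style token matching. The following is a sketch under that assumption, not necessarily the notebook's exact normalization. EM compares normalized strings, F1 measures bag-of-tokens overlap, and both are reported as percentages, matching the units printed above. The function names are hypothetical, chosen so they do not shadow the notebook's helpers.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace (SQuAD-style)."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, target: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(target)

def token_f1(prediction: str, target: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(target).split()
    if not pred_tokens or not gold_tokens:
        # Both empty counts as a match; exactly one empty counts as zero overlap
        return float(pred_tokens == gold_tokens)
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def evaluate(predictions, targets) -> dict:
    """Corpus-level EM and F1 as percentages."""
    n = len(predictions)
    em = 100.0 * sum(exact_match(p, t) for p, t in zip(predictions, targets)) / n
    f1 = 100.0 * sum(token_f1(p, t) for p, t in zip(predictions, targets)) / n
    return {"em": em, "f1": f1}

print(evaluate(["The cat sat.", ""], ["cat sat down", "a dog"]))
```

In this example, the partially correct answer earns F1 credit (0.8) but no EM credit. An empty prediction, like the probe's answer earlier, scores zero on both metrics.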


## 15. Summary

In [None]:
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

print(f"\n🔧 Config:")
print(f"  Decoder: {config.decoder.model_name}")
print(f"  LoRA: {config.decoder.use_lora}")
print(f"  TRM Recursion: {USE_TRM_RECURSION}")

print(f"\n📊 Training:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Best val loss: {best_val_loss:.4f}")

print(f"\n📈 Eval:")
print(f"  EM: {metrics['em']:.2f}%")
print(f"  F1: {metrics['f1']:.2f}%")

print(f"\n📦 Dataset:")
print(f"  Train: {len(train_dataset):,}")
print(f"  Val: {len(val_dataset):,}")

print(f"\n💾 Output:")
print(f"  {ckpt_dir / 'checkpoint_best.pt'}")

print("\n" + "="*60)
print("KEY IMPROVEMENTS")
print("="*60)
print("✅ Pretrained Qwen decoder (vs random 34M TRM)")
print("✅ Proper autoregressive training (prefix_embeds + -100 masking)")
print("✅ Baseline mode for comparison")
print("✅ Debug instrumentation (sanity check, param audit, first-batch)")

print("\n" + "="*60)
print("NEXT STEPS")
print("="*60)
print("1. If EM/F1 still low:")
print("   - Check text-only generation")
print("   - Check first-batch debug")
print("   - Increase epochs/LR")
print("\n2. If baseline works:")
print("   - Set USE_TRM_RECURSION = True")
print("   - Compare vs baseline")
print("\n3. Expected baseline: EM > 15%, F1 > 25%")
print("="*60)

print("\n✓ Complete!")