# TRM-Enhanced Vision-Language Model with MoE

## A Scalable VLM using Tiny Recursive Models, Mixture of Experts, and Multi-GPU Training

This notebook implements a complete Vision-Language Model with:

### Architecture
- **Language Backbone**: Qwen2.5-3B with TRM recursive reasoning
- **MoE Enhancement**: Mixture of Experts layers for increased capacity
- **Vision Encoders**: Configurable SigLIP / DINOv2 (for ablation studies)
- **Projection**: 2-layer MLP connector

### Training
- **Dataset**: PixMo-Cap (712K images, 196-word avg captions)
- **Infrastructure**: 2x H200 GPUs with FSDP
- **Comparison**: Benchmarked against Qwen2.5-VL-7B

### Key TRM Concepts Applied
- Recursive reasoning with tiny networks
- Deep supervision across cycles
- Progressive answer refinement

---
## Part 1: Setup and Dependencies

In [None]:
# # Install dependencies
# ! uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 -q
# ! uv pip install "transformers>=4.45.0" "accelerate>=0.34.0" datasets -q  # quoted so the shell does not treat >= as a redirect
# ! uv pip install bitsandbytes peft flash-attn --no-build-isolation -q
# ! uv pip install einops timm sentencepiece tiktoken -q
# ! uv pip install wandb matplotlib seaborn tqdm pillow requests -q
# # ! uv pip install qwen-vl-utils -q  # For Qwen2.5-VL comparison

In [None]:
import os
import sys
import math
import json
import random
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Dict, Any, Union
from enum import Enum
from io import BytesIO
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from tqdm.auto import tqdm

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    AutoModel,
    AutoImageProcessor,
    PreTrainedModel,
    PretrainedConfig,
    get_cosine_schedule_with_warmup,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils import set_seed

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set seeds
set_seed(42)

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

---
## Part 2: Configuration Classes

In [None]:
class VisionEncoderType(Enum):
    """Supported vision encoder types for ablation studies"""
    SIGLIP = "siglip"
    DINOV2 = "dinov2"
    SIGLIP_DINOV2 = "siglip_dinov2"  # Dual encoder


class LLMBackboneType(Enum):
    """Supported LLM backbone types"""
    QWEN2_5_0_5B = "Qwen/Qwen2.5-0.5B"
    QWEN2_5_1_5B = "Qwen/Qwen2.5-1.5B"
    QWEN2_5_3B = "Qwen/Qwen2.5-3B"
    QWEN2_5_7B = "Qwen/Qwen2.5-7B"


@dataclass
class VisionEncoderConfig:
    """Configuration for vision encoder"""
    encoder_type: VisionEncoderType = VisionEncoderType.SIGLIP
    
    # SigLIP config
    siglip_model_name: str = "google/siglip-so400m-patch14-384"
    siglip_image_size: int = 384
    siglip_hidden_size: int = 1152
    siglip_num_patches: int = 729  # floor(384/14)^2 = 27^2
    
    # DINOv2 config
    dinov2_model_name: str = "facebook/dinov2-large"
    dinov2_image_size: int = 224
    dinov2_hidden_size: int = 1024
    dinov2_num_patches: int = 256  # (224/14)^2
    
    # Shared config
    freeze_vision: bool = True
    use_gradient_checkpointing: bool = True
    
    @property
    def output_dim(self) -> int:
        if self.encoder_type == VisionEncoderType.SIGLIP:
            return self.siglip_hidden_size
        elif self.encoder_type == VisionEncoderType.DINOV2:
            return self.dinov2_hidden_size
        else:  # Dual encoder
            return self.siglip_hidden_size + self.dinov2_hidden_size


@dataclass
class MoEConfig:
    """Configuration for Mixture of Experts"""
    enabled: bool = True
    num_experts: int = 8
    num_experts_per_token: int = 2  # top-k
    num_shared_experts: int = 1  # DeepSeek-style shared experts
    capacity_factor: float = 1.25  # For load balancing
    router_jitter_noise: float = 0.1  # Training stability
    load_balance_loss_weight: float = 0.01
    router_z_loss_weight: float = 0.001
    moe_layer_frequency: int = 2  # Apply MoE every N layers


@dataclass
class TRMConfig:
    """Configuration for TRM recursive reasoning"""
    enabled: bool = True
    n_recursions: int = 4  # Latent reasoning iterations
    t_cycles: int = 2  # Deep supervision cycles
    apply_to_layers: List[int] = field(default_factory=lambda: [-4, -3, -2, -1])  # Last N layers
    residual_scale: float = 0.1  # Scale for residual connections


@dataclass 
class ProjectorConfig:
    """Configuration for vision-language projector"""
    projector_type: str = "mlp"  # 'mlp', 'linear', 'resampler'
    num_layers: int = 2
    activation: str = "gelu"
    dropout: float = 0.0


@dataclass
class TRMVLMConfig:
    """Master configuration for TRM VLM"""
    # Model components
    llm_backbone: LLMBackboneType = LLMBackboneType.QWEN2_5_3B
    vision_config: VisionEncoderConfig = field(default_factory=VisionEncoderConfig)
    moe_config: MoEConfig = field(default_factory=MoEConfig)
    trm_config: TRMConfig = field(default_factory=TRMConfig)
    projector_config: ProjectorConfig = field(default_factory=ProjectorConfig)
    
    # LLM config (will be loaded from pretrained)
    llm_hidden_size: int = 2048  # Qwen2.5-3B default
    
    # Training config
    max_seq_length: int = 2048
    image_token_id: int = 151655  # <|image_pad|> in Qwen
    num_image_tokens: int = 256  # Number of visual tokens
    
    # Precision
    torch_dtype: str = "bfloat16"
    use_flash_attention: bool = True
    
    def __post_init__(self):
        # Set LLM hidden size based on backbone
        llm_hidden_sizes = {
            LLMBackboneType.QWEN2_5_0_5B: 896,
            LLMBackboneType.QWEN2_5_1_5B: 1536,
            LLMBackboneType.QWEN2_5_3B: 2048,
            LLMBackboneType.QWEN2_5_7B: 3584,
        }
        self.llm_hidden_size = llm_hidden_sizes.get(self.llm_backbone, 2048)


@dataclass
class TrainingConfig:
    """Training configuration"""
    # Basic training
    num_epochs: int = 3
    per_device_batch_size: int = 4
    gradient_accumulation_steps: int = 8
    
    # Learning rate
    learning_rate: float = 1e-4
    llm_learning_rate: float = 2e-5  # Lower LR for LLM backbone
    vision_learning_rate: float = 0.0  # Frozen by default
    warmup_ratio: float = 0.03
    weight_decay: float = 0.1
    max_grad_norm: float = 1.0
    
    # Scheduler
    lr_scheduler_type: str = "cosine"
    
    # Checkpointing
    save_steps: int = 500
    eval_steps: int = 250
    logging_steps: int = 10
    output_dir: str = "./outputs/trm_vlm"
    
    # Mixed precision
    bf16: bool = True
    
    # Distributed
    fsdp_sharding_strategy: str = "SHARD_GRAD_OP"  # Shards gradients and optimizer state (ZeRO-2 style); less communication than FULL_SHARD
    
    # Wandb
    use_wandb: bool = True
    wandb_project: str = "trm-vlm"
    wandb_run_name: Optional[str] = None
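
The batch and schedule settings in `TrainingConfig` interact, so it helps to work out the resulting step counts. The following is a minimal sketch of that arithmetic, assuming the 2 GPUs and the approximate 712K-image dataset size quoted in the overview. The helper name `schedule_summary` is hypothetical and not part of the notebook, and the real step count depends on how many images actually load.

```python
import math

def schedule_summary(num_samples, per_device_batch_size, grad_accum_steps,
                     world_size, num_epochs, warmup_ratio):
    """Derive optimizer-step counts from the batch settings (illustrative helper)."""
    # Samples consumed per optimizer step, summed over all devices
    effective_batch = per_device_batch_size * grad_accum_steps * world_size
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    total_steps = steps_per_epoch * num_epochs
    warmup_steps = int(total_steps * warmup_ratio)
    return effective_batch, steps_per_epoch, total_steps, warmup_steps

summary = schedule_summary(
    num_samples=712_000,       # assumption: approximate dataset size from the overview
    per_device_batch_size=4,
    grad_accum_steps=8,
    world_size=2,              # assumption: 2 GPUs, one process each
    num_epochs=3,
    warmup_ratio=0.03,
)
print(summary)  # (64, 11125, 33375, 1001)
```

Under these assumptions each optimizer step sees 64 samples, and the 3% warmup covers roughly the first thousand steps.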

In [None]:
# Create default configuration
model_config = TRMVLMConfig(
    llm_backbone=LLMBackboneType.QWEN2_5_3B,
    vision_config=VisionEncoderConfig(
        encoder_type=VisionEncoderType.SIGLIP,
        freeze_vision=True
    ),
    moe_config=MoEConfig(
        enabled=True,
        num_experts=8,
        num_experts_per_token=2
    ),
    trm_config=TRMConfig(
        enabled=True,
        n_recursions=4,
        t_cycles=2
    )
)

training_config = TrainingConfig(
    num_epochs=3,
    per_device_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-4
)

print("Model Configuration:")
print(f"  LLM Backbone: {model_config.llm_backbone.value}")
print(f"  Vision Encoder: {model_config.vision_config.encoder_type.value}")
print(f"  MoE Enabled: {model_config.moe_config.enabled}")
print(f"  TRM Enabled: {model_config.trm_config.enabled}")
print(f"  LLM Hidden Size: {model_config.llm_hidden_size}")
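
The MoE settings above add a large number of parameters, which is worth estimating before training on two GPUs. The sketch below counts the extra weights from replacing every second FFN with 8 routed experts plus 1 shared expert. The layer count (36) and FFN intermediate size (11008) are assumptions about the Qwen2.5-3B checkpoint and should be checked against the loaded model's config. Both function names are hypothetical.

```python
def swiglu_params(hidden_size: int, intermediate_size: int) -> int:
    """Parameters in one bias-free SwiGLU FFN: gate, up, and down projections."""
    return 3 * hidden_size * intermediate_size

def moe_extra_params(num_layers, hidden_size, intermediate_size,
                     num_experts, num_shared_experts, moe_layer_frequency):
    """Parameters added by swapping every N-th dense FFN for an MoE layer."""
    # Layers 0, N, 2N, ... are converted, matching `i % N == 0`
    num_moe_layers = len(range(0, num_layers, moe_layer_frequency))
    per_expert = swiglu_params(hidden_size, intermediate_size)
    router = hidden_size * num_experts  # bias-free gate
    per_moe_layer = (num_experts + num_shared_experts) * per_expert + router
    # Each MoE layer replaces one dense FFN of the same shape
    return num_moe_layers * (per_moe_layer - per_expert)

extra = moe_extra_params(
    num_layers=36,             # assumption: Qwen2.5-3B depth
    hidden_size=2048,
    intermediate_size=11008,   # assumption: Qwen2.5-3B FFN width
    num_experts=8,
    num_shared_experts=1,
    moe_layer_frequency=2,
)
print(f"{extra / 1e9:.2f}B extra parameters")  # 9.74B extra parameters
```

If those assumptions hold, the upcycled model carries several times the parameters of the dense backbone, so memory is likely the first constraint to check.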

---
## Part 3: Vision Encoder Module

In [None]:
class SigLIPVisionEncoder(nn.Module):
    """SigLIP Vision Encoder wrapper"""
    
    def __init__(self, config: VisionEncoderConfig):
        super().__init__()
        self.config = config
        
        # Load SigLIP model
        self.model = AutoModel.from_pretrained(
            config.siglip_model_name,
            torch_dtype=torch.bfloat16
        ).vision_model
        
        self.processor = AutoProcessor.from_pretrained(config.siglip_model_name)
        
        # Freeze if specified
        if config.freeze_vision:
            for param in self.model.parameters():
                param.requires_grad = False
        
        # Enable gradient checkpointing
        if config.use_gradient_checkpointing:
            if hasattr(self.model, "gradient_checkpointing_enable"):
                self.model.gradient_checkpointing_enable()
            elif hasattr(self.model, "set_grad_checkpointing"):
                self.model.set_grad_checkpointing(True)
            else:
                logger.warning(
                    f"Requested gradient checkpointing for {self.__class__.__name__}, but the encoder backend does not expose a compatible hook."
                )
    
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images: Preprocessed images (B, C, H, W)
        Returns:
            Patch embeddings (B, num_patches, hidden_size)
        """
        outputs = self.model(pixel_values=images)
        # SigLIP has no CLS token, so last_hidden_state holds patch embeddings only
        hidden_states = outputs.last_hidden_state
        return hidden_states
    
    def preprocess(self, images: List[Image.Image]) -> torch.Tensor:
        """Preprocess PIL images"""
        inputs = self.processor(images=images, return_tensors="pt")
        return inputs.pixel_values


class DINOv2VisionEncoder(nn.Module):
    """DINOv2 Vision Encoder wrapper"""
    
    def __init__(self, config: VisionEncoderConfig):
        super().__init__()
        self.config = config
        
        # Load DINOv2 model
        self.model = AutoModel.from_pretrained(
            config.dinov2_model_name,
            torch_dtype=torch.bfloat16
        )
        
        self.processor = AutoImageProcessor.from_pretrained(config.dinov2_model_name)
        
        # Freeze if specified
        if config.freeze_vision:
            for param in self.model.parameters():
                param.requires_grad = False
        
        # Enable gradient checkpointing
        if config.use_gradient_checkpointing:
            if hasattr(self.model, "gradient_checkpointing_enable"):
                self.model.gradient_checkpointing_enable()
            elif hasattr(self.model, "set_grad_checkpointing"):
                self.model.set_grad_checkpointing(True)
            else:
                logger.warning(
                    f"Requested gradient checkpointing for {self.__class__.__name__}, but the encoder backend does not expose a compatible hook."
                )
    
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images: Preprocessed images (B, C, H, W)
        Returns:
            Patch embeddings (B, num_patches, hidden_size)
        """
        outputs = self.model(pixel_values=images)
        # Skip CLS token, return patch embeddings
        hidden_states = outputs.last_hidden_state[:, 1:, :]
        return hidden_states
    
    def preprocess(self, images: List[Image.Image]) -> torch.Tensor:
        """Preprocess PIL images"""
        inputs = self.processor(images=images, return_tensors="pt")
        return inputs.pixel_values


class DualVisionEncoder(nn.Module):
    """Dual SigLIP + DINOv2 encoder for enhanced representations"""
    
    def __init__(self, config: VisionEncoderConfig):
        super().__init__()
        self.config = config
        
        self.siglip = SigLIPVisionEncoder(config)
        self.dinov2 = DINOv2VisionEncoder(config)
        
        # Align spatial dimensions through adaptive pooling or interpolation
        # SigLIP: 729 patches, DINOv2: 256 patches -> use 256
        self.target_num_patches = config.dinov2_num_patches
    
    def forward(
        self, 
        siglip_images: torch.Tensor,
        dinov2_images: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            siglip_images: Images for SigLIP (B, C, 384, 384)
            dinov2_images: Images for DINOv2 (B, C, 224, 224)
        Returns:
            Concatenated features (B, num_patches, siglip_dim + dinov2_dim)
        """
        # Get features from both encoders
        siglip_features = self.siglip(siglip_images)  # (B, 729, 1152)
        dinov2_features = self.dinov2(dinov2_images)  # (B, 256, 1024)
        
        # Interpolate SigLIP features to match DINOv2 patch count
        B, N_sig, D_sig = siglip_features.shape
        H_sig = W_sig = math.isqrt(N_sig)
        side = math.isqrt(self.target_num_patches)  # 16 for 256 patches
        
        siglip_features = siglip_features.reshape(B, H_sig, W_sig, D_sig).permute(0, 3, 1, 2)
        siglip_features = F.interpolate(
            siglip_features, 
            size=(side, side),
            mode='bilinear',
            align_corners=False
        )
        # reshape (not view): the permuted tensor is not guaranteed to be contiguous
        siglip_features = siglip_features.permute(0, 2, 3, 1).reshape(B, side * side, D_sig)
        
        # Concatenate along feature dimension
        combined = torch.cat([siglip_features, dinov2_features], dim=-1)
        return combined


def create_vision_encoder(config: VisionEncoderConfig) -> nn.Module:
    """Factory function to create vision encoder based on config"""
    if config.encoder_type == VisionEncoderType.SIGLIP:
        return SigLIPVisionEncoder(config)
    elif config.encoder_type == VisionEncoderType.DINOV2:
        return DINOv2VisionEncoder(config)
    elif config.encoder_type == VisionEncoderType.SIGLIP_DINOV2:
        return DualVisionEncoder(config)
    else:
        raise ValueError(f"Unknown encoder type: {config.encoder_type}")
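
The dual encoder has to reconcile two different patch grids before concatenating features. The sketch below shows that alignment step on random tensors with the shapes listed in `VisionEncoderConfig`. It does not load either encoder, and `resize_patch_grid` is a hypothetical helper name.

```python
import math
import torch
import torch.nn.functional as F

def resize_patch_grid(features: torch.Tensor, target_num_patches: int) -> torch.Tensor:
    """Bilinearly resize a square grid of patch tokens (B, N, D) to target_num_patches tokens."""
    batch, num_patches, dim = features.shape
    src = math.isqrt(num_patches)
    dst = math.isqrt(target_num_patches)
    assert src * src == num_patches and dst * dst == target_num_patches, "grids must be square"
    grid = features.reshape(batch, src, src, dim).permute(0, 3, 1, 2)  # (B, D, H, W)
    grid = F.interpolate(grid, size=(dst, dst), mode="bilinear", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(batch, dst * dst, dim)

# Random stand-ins for encoder outputs
siglip_tokens = torch.randn(2, 729, 1152)   # 27 x 27 grid
dinov2_tokens = torch.randn(2, 256, 1024)   # 16 x 16 grid
combined = torch.cat([resize_patch_grid(siglip_tokens, 256), dinov2_tokens], dim=-1)
print(combined.shape)  # torch.Size([2, 256, 2176])
```

The concatenated width, 1152 + 1024 = 2176, matches what `VisionEncoderConfig.output_dim` reports for the dual encoder.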

---
## Part 4: Mixture of Experts Layer

In [None]:
class MoERouter(nn.Module):
    """Router for Mixture of Experts with load balancing"""
    
    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 2,
        jitter_noise: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.jitter_noise = jitter_noise
        
        # Router linear layer
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        
    def forward(
        self, 
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)
        Returns:
            router_probs: (batch, seq_len, top_k)
            expert_indices: (batch, seq_len, top_k)
            router_logits: (batch, seq_len, num_experts) - for aux loss
            expert_mask: (batch, seq_len, num_experts, top_k)
        """
        batch_size, seq_len, _ = hidden_states.shape
        
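        # Compute router logits in float32 for stability; cast the weight as well
        # so this still works when the module is held in bfloat16
        router_logits = F.linear(hidden_states.float(), self.gate.weight.float())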
        
        # Add jitter noise during training
        if self.training and self.jitter_noise > 0:
            noise = torch.randn_like(router_logits) * self.jitter_noise
            router_logits = router_logits + noise
        
        # Get top-k experts
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, expert_indices = torch.topk(router_probs, self.top_k, dim=-1)
        
        # Normalize top-k probabilities
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Create expert mask for dispatching
        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(0, 1, 3, 2)  # (batch, seq, num_experts, top_k)
        
        return top_k_probs, expert_indices, router_logits, expert_mask


class MoEExpert(nn.Module):
    """Single expert (SwiGLU FFN)"""
    
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class MoELayer(nn.Module):
    """Mixture of Experts layer replacing standard FFN"""
    
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        config: MoEConfig
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_token
        
        # Router
        self.router = MoERouter(
            hidden_size=hidden_size,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_token,
            jitter_noise=config.router_jitter_noise
        )
        
        # Experts
        self.experts = nn.ModuleList([
            MoEExpert(hidden_size, intermediate_size)
            for _ in range(config.num_experts)
        ])
        
        # Shared experts (DeepSeek-style)
        if config.num_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                MoEExpert(hidden_size, intermediate_size)
                for _ in range(config.num_shared_experts)
            ])
        else:
            self.shared_experts = None
    
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)
        Returns:
            output: (batch, seq_len, hidden_size)

        The decoder layer calls `self.mlp(hidden_states)` and expects a single
        tensor back, so the router outputs are cached on the module
        (`last_router_logits`, `last_expert_indices`) for the auxiliary losses
        instead of being returned.
        """
        # Route tokens
        top_k_probs, expert_indices, router_logits, _ = self.router(hidden_states)
        self.last_router_logits = router_logits
        self.last_expert_indices = expert_indices
        
        # Initialize output
        output = torch.zeros_like(hidden_states)
        
        # Process each expert
        for expert_idx, expert in enumerate(self.experts):
            # Per-token weight for this expert: its normalized top-k probability
            # if selected, 0 otherwise. An expert appears at most once in a
            # token's top-k, so the sum over k picks out that single weight.
            selected = (expert_indices == expert_idx)  # (batch, seq, top_k)
            if not selected.any():
                continue
            expert_probs = (top_k_probs * selected).sum(dim=-1)  # (batch, seq)
            
            # Dense compute: the expert runs on every token, and tokens not
            # routed to it are zeroed by their weight. Simple, but not sparse.
            expert_output = expert(hidden_states)
            
            # Weighted addition, kept in the activation dtype
            output = output + expert_output * expert_probs.unsqueeze(-1).to(expert_output.dtype)
        
        # Add shared expert output
        if self.shared_experts is not None:
            for shared_expert in self.shared_experts:
                output = output + shared_expert(hidden_states) / len(self.shared_experts)
        
        return output
    
    @staticmethod
    def compute_load_balancing_loss(
        router_logits: torch.Tensor,
        expert_indices: torch.Tensor,
        num_experts: int
    ) -> torch.Tensor:
        """Compute load balancing auxiliary loss"""
        # router_logits: (batch, seq, num_experts)
        router_probs = F.softmax(router_logits.float(), dim=-1)
        
        # Expert importance: mean probability assigned to each expert
        expert_importance = router_probs.mean(dim=[0, 1])  # (num_experts,)
        
        # Expert load: fraction of tokens assigned to each expert
        one_hot = F.one_hot(expert_indices, num_classes=num_experts).float()
        expert_load = one_hot.mean(dim=[0, 1, 2])  # (num_experts,)
        
        # Loss encourages uniform distribution
        return num_experts * (expert_importance * expert_load).sum()
    
    @staticmethod
    def compute_router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
        """Compute router z-loss for stability"""
        # Penalize large logits: mean of the squared log-sum-exp over experts
        return torch.logsumexp(router_logits.float(), dim=-1).pow(2).mean()
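
To get a feel for the two auxiliary losses, it helps to evaluate them on hand-built router logits. The sketch below re-implements both as standalone functions (the names are hypothetical) and compares perfectly uniform routing with routing skewed toward one expert. The z-loss is written in its squared form, the mean of the squared log-sum-exp.

```python
import math
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Switch-style balance loss: num_experts * sum(importance * load)."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits.float(), dim=-1)
    _, expert_indices = torch.topk(probs, top_k, dim=-1)
    importance = probs.mean(dim=(0, 1))  # mean router probability per expert
    # Share of all (token, slot) routing assignments that land on each expert
    load = F.one_hot(expert_indices, num_experts).float().mean(dim=(0, 1, 2))
    return num_experts * (importance * load).sum()

def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Mean squared log-partition of the router logits."""
    return torch.logsumexp(router_logits.float(), dim=-1).pow(2).mean()

uniform_logits = torch.zeros(2, 5, 8)   # every expert equally likely
skewed_logits = uniform_logits.clone()
skewed_logits[..., 0] = 5.0             # expert 0 dominates

print(load_balancing_loss(uniform_logits, top_k=2).item())  # approximately 1.0
print(load_balancing_loss(skewed_logits, top_k=2).item())   # larger than 1.0
print(router_z_loss(uniform_logits).item())                 # approximately log(8)^2
```

With uniform probabilities the balance loss is 1 regardless of tie-breaking, because the importance term is constant and the load fractions sum to 1. Concentrating probability on one expert pushes it above 1.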

---
## Part 5: TRM Recursive Reasoning Module

In [None]:
class TRMReasoningBlock(nn.Module):
    """
    TRM Reasoning Block that can wrap any transformer layer.
    Applies recursive reasoning: updates latent z given (x, y, z).
    """
    
    def __init__(
        self,
        original_layer: nn.Module,
        hidden_size: int,
        config: TRMConfig
    ):
        super().__init__()
        self.original_layer = original_layer
        self.config = config
        self.hidden_size = hidden_size
        
        # Learnable initial state for z
        self.z_init = nn.Parameter(torch.randn(hidden_size) * 0.02)
        
        # Projection for combining x, y, z
        self.input_fusion = nn.Linear(hidden_size * 3, hidden_size, bias=False)
        
        # Output gate (controls how much TRM reasoning affects output)
        self.output_gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Sigmoid()
        )
        
        # Layer norm for z updates
        self.z_norm = nn.RMSNorm(hidden_size, eps=1e-6)
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs
    ) -> Tuple[torch.Tensor, ...]:
        """
        Forward with TRM recursive reasoning.
        """
        batch_size, seq_len, _ = hidden_states.shape
        
        # Initialize z (latent reasoning state)
        z = self.z_init.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1)
        z = z.to(hidden_states.dtype)
        
        # Store x (input) and initialize y (answer draft)
        x = hidden_states
        y = hidden_states.clone()
        
        all_outputs = []
        
        # TRM recursive reasoning loop
        for t in range(self.config.t_cycles):
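            # Build an autograd graph only for the last cycle: earlier cycles are
            # detached below, so tracking them would only cost memory. Also respect
            # an enclosing torch.no_grad() (e.g. during evaluation).
            use_grad = torch.is_grad_enabled() and (t == self.config.t_cycles - 1)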
            
            with torch.set_grad_enabled(use_grad):
                # n recursions to update z
                for n in range(self.config.n_recursions):
                    # Combine x, y, z
                    combined = torch.cat([x, y, z], dim=-1)
                    fused = self.input_fusion(combined)
                    
                    # Pass through original transformer layer
                    layer_outputs = self.original_layer(
                        fused,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=None,  # Don't use cache during recursion
                        output_attentions=output_attentions,
                        use_cache=False,
                        **kwargs
                    )
                    
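                    # Update z with residual (the layer may return a tuple or a bare
                    # tensor, depending on the transformers version)
                    z_update = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
                    z = self.z_norm(z + self.config.residual_scale * z_update)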
                
                # Update y (answer) based on refined z
                gate = self.output_gate(torch.cat([y, z], dim=-1))
                y = y + gate * (z - y)
                
                all_outputs.append(y)
            
            # Detach for next cycle
            if t < self.config.t_cycles - 1:
                z = z.detach()
                y = y.detach()
        
        # Final output
        final_output = all_outputs[-1]
        
        # Return in same format as original layer
        outputs = (final_output,)
        if output_attentions:
            outputs += (None,)  # Placeholder for attention weights
        if use_cache:
            outputs += (past_key_value,)  # Pass through cache
        
        return outputs
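
The recursion schedule is easier to follow on a toy model than inside a transformer layer. The sketch below is a minimal stand-in, not the block above: a small MLP replaces the wrapped layer, and the class name `TinyRecursiveCore` is hypothetical. It runs `t_cycles - 1` warm-up cycles without gradients, then one differentiated cycle, each made of `n_recursions` latent updates followed by one answer update.

```python
import torch
import torch.nn as nn

class TinyRecursiveCore(nn.Module):
    """Toy recursive refiner: one small network updates z, then y (illustrative only)."""

    def __init__(self, dim: int, n_recursions: int = 4, t_cycles: int = 2):
        super().__init__()
        self.n_recursions = n_recursions
        self.t_cycles = t_cycles
        self.update_z = nn.Sequential(nn.Linear(3 * dim, dim), nn.Tanh())
        self.update_y = nn.Linear(2 * dim, dim)

    def cycle(self, x, y, z):
        for _ in range(self.n_recursions):
            z = z + self.update_z(torch.cat([x, y, z], dim=-1))  # refine latent state
        y = y + self.update_y(torch.cat([y, z], dim=-1))          # refine answer draft
        return y, z

    def forward(self, x):
        y, z = x.clone(), torch.zeros_like(x)
        # Warm-up cycles improve (y, z) without building an autograd graph
        with torch.no_grad():
            for _ in range(self.t_cycles - 1):
                y, z = self.cycle(x, y, z)
        # Only the final cycle is differentiated
        return self.cycle(x, y, z)

torch.manual_seed(0)
core = TinyRecursiveCore(dim=16)
y, z = core(torch.randn(2, 5, 16))
y.sum().backward()
print(y.shape)                                   # torch.Size([2, 5, 16])
print(core.update_z[0].weight.grad is not None)  # True
```

Because the warm-up cycles run under `torch.no_grad()`, activation memory scales with one cycle rather than with `t_cycles`.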

---
## Part 6: Vision-Language Projector

In [None]:
class VisionLanguageProjector(nn.Module):
    """Projects vision features to language model space"""
    
    def __init__(
        self,
        vision_dim: int,
        llm_dim: int,
        config: ProjectorConfig
    ):
        super().__init__()
        self.config = config
        
        if config.projector_type == "linear":
            self.projector = nn.Linear(vision_dim, llm_dim)
        
        elif config.projector_type == "mlp":
            layers = []
            in_dim = vision_dim
            
            for i in range(config.num_layers):
                out_dim = llm_dim
                layers.append(nn.Linear(in_dim, out_dim))
                
                if i < config.num_layers - 1:  # No activation on last layer
                    if config.activation == "gelu":
                        layers.append(nn.GELU())
                    elif config.activation == "silu":
                        layers.append(nn.SiLU())
                    elif config.activation == "relu":
                        layers.append(nn.ReLU())
                    
                    if config.dropout > 0:
                        layers.append(nn.Dropout(config.dropout))
                
                in_dim = out_dim
            
            self.projector = nn.Sequential(*layers)
        
        elif config.projector_type == "resampler":
            # Perceiver-style resampler
            self.projector = PerceiverResampler(
                dim=vision_dim,
                output_dim=llm_dim,
                num_latents=64,
                depth=2
            )
        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")
    
    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: (batch, num_patches, vision_dim)
        Returns:
            projected: (batch, num_tokens, llm_dim), where num_tokens is num_patches
                for 'linear'/'mlp' and num_latents for 'resampler'
        """
        return self.projector(vision_features)


class PerceiverResampler(nn.Module):
    """Perceiver-style resampler for flexible token compression"""
    
    def __init__(
        self,
        dim: int,
        output_dim: int,
        num_latents: int = 64,
        depth: int = 2,
        num_heads: int = 8
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'cross_attn': nn.MultiheadAttention(dim, num_heads, batch_first=True),
                'cross_norm': nn.LayerNorm(dim),
                'ff': nn.Sequential(
                    nn.Linear(dim, dim * 4),
                    nn.GELU(),
                    nn.Linear(dim * 4, dim)
                ),
                'ff_norm': nn.LayerNorm(dim)
            })
            for _ in range(depth)
        ])
        
        self.output_proj = nn.Linear(dim, output_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
        
        for layer in self.layers:
            # Cross attention
            attn_out, _ = layer['cross_attn'](latents, x, x)
            latents = layer['cross_norm'](latents + attn_out)
            
            # Feed-forward
            ff_out = layer['ff'](latents)
            latents = layer['ff_norm'](latents + ff_out)
        
        return self.output_proj(latents)
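
The resampler compresses tokens because the learned latents act as queries, so the output length follows the number of latents rather than the number of patches. The sketch below shows a single cross-attention step on random data, assuming SigLIP-sized inputs (729 tokens of width 1152). It leaves out the norms, feed-forward layers, and output projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_latents = 1152, 64
latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)   # learned queries
cross_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

patch_tokens = torch.randn(2, 729, dim)                         # stand-in for encoder output
queries = latents.unsqueeze(0).expand(patch_tokens.shape[0], -1, -1)
compressed, attn_weights = cross_attn(queries, patch_tokens, patch_tokens)

print(compressed.shape)    # torch.Size([2, 64, 1152])
print(attn_weights.shape)  # torch.Size([2, 64, 729])
```

Each of the 64 output tokens is a weighted mix of all 729 patch tokens, which is why the `"resampler"` projector hands the LLM far fewer visual tokens than the `"mlp"` projector.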

---
## Part 7: Complete TRM-VLM Model

In [None]:
class TRMVLM(nn.Module):
    """
    Complete TRM Vision-Language Model.
    
    Components:
    - Vision Encoder: SigLIP / DINOv2 / Dual
    - Projector: MLP/Resampler
    - LLM: Qwen2.5 with optional MoE and TRM
    """
    
    def __init__(self, config: TRMVLMConfig):
        super().__init__()
        self.config = config
        
        # 1. Vision Encoder
        logger.info(f"Loading vision encoder: {config.vision_config.encoder_type.value}")
        self.vision_encoder = create_vision_encoder(config.vision_config)
        
        # 2. Vision-Language Projector
        self.projector = VisionLanguageProjector(
            vision_dim=config.vision_config.output_dim,
            llm_dim=config.llm_hidden_size,
            config=config.projector_config
        )
        
        # 3. Load LLM backbone
        logger.info(f"Loading LLM backbone: {config.llm_backbone.value}")
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_backbone.value,
            torch_dtype=getattr(torch, config.torch_dtype),
            # attn_implementation="flash_attention_2" if config.use_flash_attention else "eager",
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.llm_backbone.value,
            trust_remote_code=True
        )
        
        # 4. Apply MoE modifications if enabled
        if config.moe_config.enabled:
            self._apply_moe_modifications()
        
        # 5. Apply TRM modifications if enabled
        if config.trm_config.enabled:
            self._apply_trm_modifications()
        
        # 6. Enable gradient checkpointing
        self.llm.gradient_checkpointing_enable()
        
        # 7. Store auxiliary losses
        self.aux_losses = {}
    
    def _apply_moe_modifications(self):
        """Replace FFN layers with MoE layers"""
        moe_config = self.config.moe_config
        
        # Get model layers
        layers = self.llm.model.layers
        
        for i, layer in enumerate(layers):
            # Apply MoE every N layers
            if i % moe_config.moe_layer_frequency == 0:
                # Get original FFN dimensions
                original_mlp = layer.mlp
                hidden_size = self.config.llm_hidden_size
                
                # Infer intermediate size from original MLP
                if hasattr(original_mlp, 'gate_proj'):
                    intermediate_size = original_mlp.gate_proj.out_features
                else:
                    intermediate_size = hidden_size * 4
                
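                # Create MoE layer
                moe_layer = MoELayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    config=moe_config
                )
                
                # Match the backbone's device and dtype; new modules default to
                # float32. The router stays in float32 for stable routing.
                ref_param = next(original_mlp.parameters())
                moe_layer.to(device=ref_param.device)
                moe_layer.experts.to(dtype=ref_param.dtype)
                if moe_layer.shared_experts is not None:
                    moe_layer.shared_experts.to(dtype=ref_param.dtype)
                
                # Initialize routed experts from original FFN (sparse upcycling).
                # Shared experts keep their random initialization.
                with torch.no_grad():
                    for expert in moe_layer.experts:
                        if hasattr(original_mlp, 'gate_proj'):
                            expert.gate_proj.weight.copy_(original_mlp.gate_proj.weight)
                            expert.up_proj.weight.copy_(original_mlp.up_proj.weight)
                            expert.down_proj.weight.copy_(original_mlp.down_proj.weight)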
                
                # Replace MLP with MoE
                layer.mlp = moe_layer
                logger.info(f"Replaced layer {i} FFN with MoE ({moe_config.num_experts} experts)")
    
    def _apply_trm_modifications(self):
        """Wrap specified layers with TRM reasoning"""
        trm_config = self.config.trm_config
        layers = self.llm.model.layers
        num_layers = len(layers)
        
        # Convert negative indices to positive
        layer_indices = [
            idx if idx >= 0 else num_layers + idx
            for idx in trm_config.apply_to_layers
        ]
        
        for idx in layer_indices:
            if 0 <= idx < num_layers:
                original_layer = layers[idx]
                trm_layer = TRMReasoningBlock(
                    original_layer=original_layer,
                    hidden_size=self.config.llm_hidden_size,
                    config=trm_config
                )
                layers[idx] = trm_layer
                logger.info(f"Wrapped layer {idx} with TRM reasoning block")
    
    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to LLM embedding space.
        
        Args:
            images: Preprocessed images (B, C, H, W)
        Returns:
            image_embeds: (B, num_tokens, llm_dim)
        """
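        # Get vision features (never re-enable grad inside an outer no_grad, e.g. generate())
        track_grad = torch.is_grad_enabled() and not self.config.vision_config.freeze_vision
        with torch.set_grad_enabled(track_grad):
            vision_features = self.vision_encoder(images)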
        
        # Project to LLM space
        image_embeds = self.projector(vision_features)
        
        return image_embeds
    
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        image_positions: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare inputs by inserting image embeddings into text sequence.
        """
        # Get text embeddings
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        
        if images is not None:
            # Get image embeddings
            image_embeds = self.encode_images(images)
            
            # Replace image token positions with image embeddings
            batch_size, seq_len, hidden_dim = text_embeds.shape
            num_image_tokens = image_embeds.shape[1]
            
            # Find image token positions
            image_token_mask = (input_ids == self.config.image_token_id)
            
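            # Insert image embeddings at the placeholder positions
            image_embeds = image_embeds.to(text_embeds.dtype)
            for b in range(batch_size):
                # Local name avoids shadowing the `image_positions` argument
                positions = image_token_mask[b].nonzero(as_tuple=True)[0]
                if len(positions) > 0:
                    # Fill at most as many slots as there are image tokens
                    # (placeholders may have been truncated by the tokenizer)
                    n = min(len(positions), num_image_tokens)
                    text_embeds[b, positions[:n]] = image_embeds[b, :n]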
        
        return {
            'inputs_embeds': text_embeds,
            'attention_mask': (input_ids != self.tokenizer.pad_token_id).long()
        }
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs
    ) -> CausalLMOutputWithPast:
        """
        Forward pass for training.
        """
        # Prepare inputs with image embeddings
        if images is not None:
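            inputs = self.prepare_inputs_for_generation(input_ids, images)
            inputs_embeds = inputs['inputs_embeds']
            if attention_mask is None:
                # Fall back to the padding-based mask built alongside the embeddings
                attention_mask = inputs['attention_mask']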
            
            # Forward through LLM with embeddings
            outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
                **kwargs
            )
        else:
            # Text-only forward
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
                **kwargs
            )
        
        # Compute auxiliary losses (MoE load balancing)
        if self.config.moe_config.enabled and self.training:
            total_aux_loss = self._compute_moe_aux_losses()
            if outputs.loss is not None:
                outputs.loss = outputs.loss + total_aux_loss
        
        return outputs
    
    def _compute_moe_aux_losses(self) -> torch.Tensor:
        """Compute MoE auxiliary losses (placeholder: currently always returns zero)"""
        total_loss = 0.0
        moe_config = self.config.moe_config
        
        for name, module in self.llm.named_modules():
            if isinstance(module, MoELayer):
                # TODO: MoELayer must store its router_logits during forward
                # before a load-balancing term can be accumulated here.
                pass
        
        return torch.tensor(total_loss, device=next(self.parameters()).device)
    
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate text from the model.
        """
        self.eval()
        
        if images is not None:
            inputs = self.prepare_inputs_for_generation(input_ids, images)
            
            return self.llm.generate(
                inputs_embeds=inputs['inputs_embeds'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )
        else:
            return self.llm.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )
    
    def save_pretrained(self, save_path: str):
        """Save model weights and config"""
        os.makedirs(save_path, exist_ok=True)
        
        # Save config
        import pickle
        with open(os.path.join(save_path, "config.pkl"), "wb") as f:
            pickle.dump(self.config, f)
        
        # Save model weights
        torch.save(self.state_dict(), os.path.join(save_path, "model.pt"))
        
        # Save tokenizer
        self.tokenizer.save_pretrained(save_path)
        
        logger.info(f"Model saved to {save_path}")
    
    @classmethod
    def from_pretrained(cls, load_path: str):
        """Load model from saved weights"""
        import pickle
        
        # Load config
        with open(os.path.join(load_path, "config.pkl"), "rb") as f:
            config = pickle.load(f)
        
        # Create model
        model = cls(config)
        
        # Load weights onto CPU first so loading does not depend on the saving device
        state_dict = torch.load(os.path.join(load_path, "model.pt"), map_location="cpu")
        model.load_state_dict(state_dict)
        
        return model
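
`_compute_moe_aux_losses` above is a placeholder that returns zero because the router logits are not stored during the forward pass. As a minimal sketch of what such a term could look like, the block below implements a Switch-Transformer-style load-balancing loss. The function name `load_balancing_loss`, the `(num_tokens, num_experts)` logits layout, and the `top_k` argument are assumptions for illustration; they are not taken from the `MoELayer` defined in this notebook.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, top_k: int = 2) -> torch.Tensor:
    """Hypothetical helper: auxiliary loss for logits of shape (num_tokens, num_experts)."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits.float(), dim=-1)

    # Fraction of routing slots sent to each expert (not differentiable).
    top_idx = probs.topk(top_k, dim=-1).indices                        # (T, k)
    dispatch = F.one_hot(top_idx, num_classes=num_experts).float().sum(dim=1)  # (T, E)
    tokens_per_expert = dispatch.mean(dim=0) / top_k                   # sums to 1

    # Mean router probability per expert (differentiable, carries the gradient).
    prob_per_expert = probs.mean(dim=0)

    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)


# Uniform router probabilities give a loss of 1.
uniform = load_balancing_loss(torch.zeros(8, 4), top_k=2)

# Routing that collapses onto one expert pushes the loss toward num_experts.
collapsed_logits = torch.zeros(8, 4)
collapsed_logits[:, 0] = 20.0
collapsed = load_balancing_loss(collapsed_logits, top_k=1)
```

In a full implementation each MoE layer would cache its logits during `forward`, and the trainer would add the summed term to the language-modeling loss scaled by a small coefficient.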

---
## Part 8: Pixmo-Cap Dataset

In [None]:
class PixmoCapDataset(Dataset):
    """
    Pixmo-Cap dataset for VLM training.
    
    Features:
    - 712K images with dense captions (avg 196 words)
    - URL-based image loading with caching
    - Multiple transcripts per image
    """
    
    def __init__(
        self,
        split: str = "train",
        max_samples: Optional[int] = None,
        image_processor = None,
        tokenizer = None,
        max_length: int = 2048,
        image_size: int = 384,
        cache_dir: str = "./cache/pixmo_images",
        num_image_tokens: int = 256
    ):
        self.split = split
        self.max_length = max_length
        self.image_size = image_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.num_image_tokens = num_image_tokens
        
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        
        # Load dataset
        logger.info("Loading Pixmo-Cap dataset...")
        self.dataset = load_dataset("allenai/pixmo-cap", split=split)
        
        if max_samples is not None:
            self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        
        logger.info(f"Loaded {len(self.dataset)} samples from Pixmo-Cap {split}")
        
        # Define prompt template
        self.prompt_template = """<|im_start|>system
You are a helpful assistant that describes images in detail.<|im_end|>
<|im_start|>user
<image>
Describe this image in detail.<|im_end|>
<|im_start|>assistant
{caption}<|im_end|>"""
    
    def __len__(self) -> int:
        return len(self.dataset)
    
    def _download_image(self, url: str, idx: int) -> Optional[Image.Image]:
        """Download image with caching"""
        cache_path = self.cache_dir / f"{idx}.jpg"
        
        # Check cache first
        if cache_path.exists():
            try:
                return Image.open(cache_path).convert("RGB")
            except Exception:
                cache_path.unlink(missing_ok=True)
        
        # Download image
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            
            # Cache the image
            image.save(cache_path, "JPEG", quality=95)
            
            return image
        except Exception as e:
            logger.warning(f"Failed to download image {idx}: {e}")
            return None
    
    def _process_image(self, image: Image.Image) -> torch.Tensor:
        """Process image for model input"""
        if self.image_processor is not None:
            processed = self.image_processor(images=image, return_tensors="pt")
            return processed.pixel_values.squeeze(0)
        else:
            # Default preprocessing
            from torchvision import transforms
            transform = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            return transform(image)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.dataset[idx]
        
        # Get image
        image_url = item.get('image_url', '')
        image = self._download_image(image_url, idx)
        
        if image is None:
            # Return a dummy sample if image download fails
            return self._get_dummy_sample()
        
        # Process image
        pixel_values = self._process_image(image)
        
        # Get caption
        caption = item.get('caption', '')
        
        # Create full text with prompt
        full_text = self.prompt_template.format(caption=caption)
        
        # Tokenize
        if self.tokenizer is not None:
            # Replace <image> placeholder with image tokens
            image_placeholder = "<image>"
            image_tokens = "<|image_pad|>" * self.num_image_tokens
            full_text = full_text.replace(image_placeholder, image_tokens)
            
            encoding = self.tokenizer(
                full_text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            
            input_ids = encoding.input_ids.squeeze(0)
            attention_mask = encoding.attention_mask.squeeze(0)
            
            # Create labels (mask prompt and padding, only train on completion)
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            
            # Find where assistant response starts
            assistant_token = self.tokenizer.encode("<|im_start|>assistant", add_special_tokens=False)
            response_found = False
            for i in range(len(input_ids) - len(assistant_token) + 1):
                if input_ids[i:i+len(assistant_token)].tolist() == assistant_token:
                    # Mask everything before the response
                    labels[:i+len(assistant_token)] = -100
                    response_found = True
                    break
            if not response_found:
                # Marker was truncated away: do not train on prompt tokens
                labels[:] = -100
            
            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'pixel_values': pixel_values
            }
        else:
            return {
                'text': full_text,
                'pixel_values': pixel_values
            }
    
    def _get_dummy_sample(self) -> Dict[str, torch.Tensor]:
        """Return a dummy sample for failed image downloads"""
        return {
            'input_ids': torch.zeros(self.max_length, dtype=torch.long),
            'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
            'labels': torch.full((self.max_length,), -100, dtype=torch.long),
            'pixel_values': torch.zeros(3, self.image_size, self.image_size)
        }


def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 4,
    distributed: bool = False
) -> DataLoader:
    """Create dataloader with optional distributed sampling"""
    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=shuffle)
        shuffle = False
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )
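
The label construction in `__getitem__` is easiest to check on a toy sequence. The sketch below reproduces the same idea with plain Python lists: everything up to and including the assistant marker is ignored, padding is ignored, and only the response tokens keep their ids. The helper name `mask_prompt_labels` and the token ids are made up for illustration.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss


def mask_prompt_labels(
    input_ids: List[int],
    attention_mask: List[int],
    marker: List[int],
) -> List[int]:
    """Hypothetical helper: keep labels only for tokens after `marker`."""
    # Locate the first occurrence of the marker subsequence.
    start = None
    for i in range(len(input_ids) - len(marker) + 1):
        if input_ids[i:i + len(marker)] == marker:
            start = i + len(marker)
            break

    # If the marker is missing (e.g. truncated), train on nothing.
    if start is None:
        return [IGNORE_INDEX] * len(input_ids)

    labels = list(input_ids)
    for i in range(len(labels)):
        if i < start or attention_mask[i] == 0:
            labels[i] = IGNORE_INDEX
    return labels


# Toy example: ids 90, 91 stand in for the "<|im_start|>assistant" marker.
ids = [5, 6, 90, 91, 7, 8, 0, 0]
mask = [1, 1, 1, 1, 1, 1, 0, 0]
labels = mask_prompt_labels(ids, mask, marker=[90, 91])
print(labels)  # [-100, -100, -100, -100, 7, 8, -100, -100]
```

Using the attention mask to drop padding avoids relying on the pad token id, which matters when the pad token also appears as real text.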

---
## Part 9: Training Loop with FSDP

In [None]:
class TRMVLMTrainer:
    """
    Trainer for TRM-VLM with FSDP support.
    
    Features:
    - FSDP for multi-GPU training
    - Mixed precision (bf16)
    - Gradient checkpointing
    - Learning rate scheduling
    - Wandb logging
    """
    
    def __init__(
        self,
        model: TRMVLM,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset],
        training_config: TrainingConfig,
        model_config: TRMVLMConfig
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_config = training_config
        self.model_config = model_config
        
        # Initialize accelerator for distributed training
        self.accelerator = Accelerator(
            mixed_precision="bf16" if training_config.bf16 else None,
            gradient_accumulation_steps=training_config.gradient_accumulation_steps,
            log_with="wandb" if training_config.use_wandb else None,
            project_dir=training_config.output_dir
        )
        
        # Setup training
        self._setup_training()
    
    def _setup_training(self):
        """Setup optimizer, scheduler, and dataloaders"""
        config = self.training_config
        
        # Create parameter groups with different learning rates
        param_groups = [
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if 'vision_encoder' in n and p.requires_grad],
                'lr': config.vision_learning_rate,
                'name': 'vision_encoder'
            },
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if 'llm' in n and p.requires_grad],
                'lr': config.llm_learning_rate,
                'name': 'llm'
            },
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if 'projector' in n and p.requires_grad],
                'lr': config.learning_rate,
                'name': 'projector'
            },
        ]
        
        # Filter out empty groups
        param_groups = [g for g in param_groups if len(list(g['params'])) > 0]
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            param_groups,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.95)
        )
        
        # Dataloader
        self.train_dataloader = create_dataloader(
            self.train_dataset,
            batch_size=config.per_device_batch_size,
            shuffle=True,
            num_workers=4
        )
        
        if self.eval_dataset is not None:
            self.eval_dataloader = create_dataloader(
                self.eval_dataset,
                batch_size=config.per_device_batch_size,
                shuffle=False,
                num_workers=4
            )
        else:
            self.eval_dataloader = None
        
        # Calculate total steps
        num_update_steps_per_epoch = len(self.train_dataloader) // config.gradient_accumulation_steps
        self.total_steps = num_update_steps_per_epoch * config.num_epochs
        warmup_steps = int(self.total_steps * config.warmup_ratio)
        
        # Scheduler
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=self.total_steps
        )
        
        # Prepare for distributed training
        self.model, self.optimizer, self.train_dataloader, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.scheduler
        )
        
        if self.eval_dataloader is not None:
            self.eval_dataloader = self.accelerator.prepare(self.eval_dataloader)
        
        # Initialize wandb
        if self.training_config.use_wandb and self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                project_name=config.wandb_project,
                config={
                    'model_config': str(self.model_config),
                    'training_config': str(config)
                }
            )
    
    def train(self):
        """Main training loop"""
        config = self.training_config
        
        self.model.train()
        global_step = 0
        best_eval_loss = float('inf')
        
        for epoch in range(config.num_epochs):
            epoch_loss = 0.0
            num_batches = 0
            
            pbar = tqdm(
                self.train_dataloader,
                desc=f"Epoch {epoch + 1}/{config.num_epochs}",
                disable=not self.accelerator.is_main_process
            )
            
            for batch_idx, batch in enumerate(pbar):
                with self.accelerator.accumulate(self.model):
                    # Forward pass
                    outputs = self.model(
                        input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'],
                        labels=batch['labels'],
                        images=batch.get('pixel_values')
                    )
                    
                    loss = outputs.loss
                    
                    # Backward pass
                    self.accelerator.backward(loss)
                    
                    # Gradient clipping
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(),
                            config.max_grad_norm
                        )
                    
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                
                epoch_loss += loss.item()
                num_batches += 1
                
                # Logging
                if self.accelerator.sync_gradients:
                    global_step += 1
                    
                    if global_step % config.logging_steps == 0:
                        avg_loss = epoch_loss / num_batches
                        lr = self.scheduler.get_last_lr()[0]
                        
                        pbar.set_postfix({
                            'loss': f'{avg_loss:.4f}',
                            'lr': f'{lr:.2e}'
                        })
                        
                        if self.training_config.use_wandb:
                            self.accelerator.log({
                                'train/loss': avg_loss,
                                'train/lr': lr,
                                'train/epoch': epoch,
                                'train/global_step': global_step
                            })
                    
                    # Evaluation
                    if global_step % config.eval_steps == 0 and self.eval_dataloader is not None:
                        eval_loss = self.evaluate()
                        
                        if eval_loss < best_eval_loss:
                            best_eval_loss = eval_loss
                            self.save_checkpoint("best_model")
                        
                        self.model.train()
                    
                    # Save checkpoint
                    if global_step % config.save_steps == 0:
                        self.save_checkpoint(f"checkpoint-{global_step}")
            
            # End of epoch logging (guard against an empty dataloader)
            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            logger.info(f"Epoch {epoch + 1} completed. Average loss: {avg_epoch_loss:.4f}")
        
        # Save final model
        self.save_checkpoint("final_model")
        
        if self.training_config.use_wandb:
            self.accelerator.end_training()
    
    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate on validation set"""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        
        for batch in tqdm(self.eval_dataloader, desc="Evaluating", 
                         disable=not self.accelerator.is_main_process):
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
                images=batch.get('pixel_values')
            )
            
            total_loss += outputs.loss.item()
            num_batches += 1
        
        avg_loss = total_loss / max(num_batches, 1)
        
        if self.training_config.use_wandb:
            self.accelerator.log({'eval/loss': avg_loss})
        
        logger.info(f"Evaluation loss: {avg_loss:.4f}")
        return avg_loss
    
    def save_checkpoint(self, name: str):
        """Save training checkpoint"""
        save_path = os.path.join(self.training_config.output_dir, name)
        
        self.accelerator.wait_for_everyone()
        
        if self.accelerator.is_main_process:
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            unwrapped_model.save_pretrained(save_path)
            logger.info(f"Checkpoint saved to {save_path}")
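
The step arithmetic in `_setup_training` and the shape of the warmup-plus-cosine schedule can be checked without any training dependencies. The sketch below mirrors the calculation of total and warmup steps, then evaluates a learning-rate multiplier with linear warmup followed by half-cosine decay, which is the shape `get_cosine_schedule_with_warmup` produces with its default settings. The function names and the example numbers (1000 batches, accumulation of 8, 2 epochs, 3% warmup) are illustrative assumptions.

```python
import math
from typing import Tuple


def schedule_steps(
    batches_per_epoch: int,
    grad_accum_steps: int,
    num_epochs: int,
    warmup_ratio: float,
) -> Tuple[int, int]:
    """Return (total optimizer steps, warmup steps)."""
    updates_per_epoch = batches_per_epoch // grad_accum_steps
    total_steps = updates_per_epoch * num_epochs
    return total_steps, int(total_steps * warmup_ratio)


def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to 1, then cosine decay from 1 to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


total, warmup = schedule_steps(
    batches_per_epoch=1000, grad_accum_steps=8, num_epochs=2, warmup_ratio=0.03
)
print(total, warmup)  # 250 7

for step in (0, warmup, total // 2, total):
    print(step, round(lr_multiplier(step, warmup, total), 4))
```

Each parameter group's configured learning rate is scaled by this multiplier, so the vision encoder, LLM, and projector keep their relative rates throughout training.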

---
## Part 10: Comparison with Qwen2.5-VL-7B

In [None]:
class Qwen2VLComparison:
    """
    Comparison utilities for benchmarking against Qwen2.5-VL-7B.
    """
    
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.qwen_vl_model = None
        self.qwen_vl_processor = None
    
    def load_qwen_vl(self):
        """Load Qwen2.5-VL-7B model"""
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        
        logger.info("Loading Qwen2.5-VL-7B...")
        
        self.qwen_vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            torch_dtype=torch.bfloat16,
            # attn_implementation="flash_attention_2",
            device_map="auto"
        )
        
        self.qwen_vl_processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct"
        )
        
        logger.info("Qwen2.5-VL-7B loaded successfully")
    
    @torch.no_grad()
    def generate_caption_qwen_vl(
        self, 
        image: Image.Image,
        prompt: str = "Describe this image in detail."
    ) -> str:
        """Generate caption using Qwen2.5-VL-7B"""
        if self.qwen_vl_model is None:
            self.load_qwen_vl()
        
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]
        
        text = self.qwen_vl_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        
        # With device_map="auto", follow the model's own device rather than self.device
        inputs = self.qwen_vl_processor(
            text=[text],
            images=[image],
            return_tensors="pt"
        ).to(self.qwen_vl_model.device)
        
        output_ids = self.qwen_vl_model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7
        )
        
        # Decode output
        generated_ids = output_ids[:, inputs['input_ids'].shape[1]:]
        caption = self.qwen_vl_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        
        return caption
    
    @torch.no_grad()
    def generate_caption_trm_vlm(
        self,
        model: TRMVLM,
        image: Image.Image,
        prompt: str = "Describe this image in detail."
    ) -> str:
        """Generate caption using TRM-VLM"""
        model.eval()
        
        # Process image
        if hasattr(model.vision_encoder, 'processor'):
            pixel_values = model.vision_encoder.preprocess([image]).to(self.device)
        else:
            from torchvision import transforms
            transform = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            pixel_values = transform(image).unsqueeze(0).to(self.device)
        
        # Create prompt
        full_prompt = f"""<|im_start|>system
You are a helpful assistant that describes images in detail.<|im_end|>
<|im_start|>user
{"<|image_pad|>" * model.config.num_image_tokens}
{prompt}<|im_end|>
<|im_start|>assistant
"""
        
        # Tokenize
        inputs = model.tokenizer(
            full_prompt,
            return_tensors="pt"
        ).to(self.device)
        
        # Generate
        output_ids = model.generate(
            input_ids=inputs['input_ids'],
            images=pixel_values,
            max_new_tokens=256,
            temperature=0.7
        )
        
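        # Decode. When generate() is driven by inputs_embeds (the image path),
        # the returned ids contain only the newly generated tokens, so the
        # prompt length must not be sliced off here.
        caption = model.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True
        )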
        
        return caption
    
    def compare_models(
        self,
        trm_vlm_model: TRMVLM,
        test_images: List[Image.Image],
        prompts: Optional[List[str]] = None
    ) -> Dict[str, List[str]]:
        """Compare outputs from both models on test images"""
        if prompts is None:
            prompts = ["Describe this image in detail."] * len(test_images)
        
        results = {
            'trm_vlm': [],
            'qwen_vl': []
        }
        
        for image, prompt in tqdm(zip(test_images, prompts), total=len(test_images)):
            # TRM-VLM caption
            trm_caption = self.generate_caption_trm_vlm(trm_vlm_model, image, prompt)
            results['trm_vlm'].append(trm_caption)
            
            # Qwen-VL caption
            qwen_caption = self.generate_caption_qwen_vl(image, prompt)
            results['qwen_vl'].append(qwen_caption)
        
        return results
    
    def compute_metrics(
        self,
        predictions: List[str],
        references: List[str]
    ) -> Dict[str, float]:
        """Compute evaluation metrics"""
        try:
            from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
            from rouge_score import rouge_scorer
        except ImportError:
            logger.warning("nltk or rouge_score not installed. Skipping metrics.")
            return {}
        
        # BLEU scores
        smoothing = SmoothingFunction().method1
        bleu_scores = []
        
        for pred, ref in zip(predictions, references):
            ref_tokens = ref.split()
            pred_tokens = pred.split()
            score = sentence_bleu([ref_tokens], pred_tokens, smoothing_function=smoothing)
            bleu_scores.append(score)
        
        # ROUGE scores
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
        
        for pred, ref in zip(predictions, references):
            scores = scorer.score(ref, pred)
            for key in rouge_scores:
                rouge_scores[key].append(scores[key].fmeasure)
        
        return {
            'bleu': np.mean(bleu_scores),
            'rouge1': np.mean(rouge_scores['rouge1']),
            'rouge2': np.mean(rouge_scores['rouge2']),
            'rougeL': np.mean(rouge_scores['rougeL'])
        }
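
To make the ROUGE-L number returned by `compute_metrics` easier to interpret, here is a minimal pure-Python sketch of the underlying idea: the F-measure of the longest common subsequence (LCS) between reference and prediction tokens. It splits on whitespace and applies no stemming, so its values will generally differ from those of the `rouge_score` package; the function names are hypothetical.

```python
from typing import List


def lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence, using one DP row at a time."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f(reference: str, prediction: str) -> float:
    """Simplified ROUGE-L F-measure on whitespace tokens (no stemming)."""
    ref, pred = reference.split(), prediction.split()
    if not ref or not pred:
        return 0.0
    lcs = lcs_length(ref, pred)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f("a cat sat on the mat", "a cat on the mat")
print(round(score, 4))  # 0.9091
```

Here the LCS has 5 tokens, giving precision 5/5 and recall 5/6, hence F = 10/11. Dense captions of roughly 200 words rarely match token for token, so overlap metrics like this are best read alongside the qualitative side-by-side comparison.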

---
## Part 11: Ablation Study Framework

In [None]:
class AblationStudy:
    """
    Framework for running ablation studies on TRM-VLM.
    
    Ablations:
    1. Vision Encoder: SigLIP vs DINOv2 vs Dual
    2. MoE: Enabled vs Disabled
    3. TRM: Enabled vs Disabled
    4. LLM Backbone: 0.5B vs 1.5B vs 3B
    """
    
    def __init__(self, base_training_config: TrainingConfig):
        self.base_training_config = base_training_config
        self.results = {}
    
    def create_ablation_configs(self) -> Dict[str, TRMVLMConfig]:
        """Create configurations for each ablation"""
        configs = {}
        
        # 1. Vision Encoder Ablations
        for encoder_type in VisionEncoderType:
            name = f"vision_{encoder_type.value}"
            configs[name] = TRMVLMConfig(
                llm_backbone=LLMBackboneType.QWEN2_5_3B,
                vision_config=VisionEncoderConfig(encoder_type=encoder_type),
                moe_config=MoEConfig(enabled=True),
                trm_config=TRMConfig(enabled=True)
            )
        
        # 2. MoE Ablation
        configs["no_moe"] = TRMVLMConfig(
            llm_backbone=LLMBackboneType.QWEN2_5_3B,
            vision_config=VisionEncoderConfig(encoder_type=VisionEncoderType.SIGLIP),
            moe_config=MoEConfig(enabled=False),
            trm_config=TRMConfig(enabled=True)
        )
        
        # 3. TRM Ablation
        configs["no_trm"] = TRMVLMConfig(
            llm_backbone=LLMBackboneType.QWEN2_5_3B,
            vision_config=VisionEncoderConfig(encoder_type=VisionEncoderType.SIGLIP),
            moe_config=MoEConfig(enabled=True),
            trm_config=TRMConfig(enabled=False)
        )
        
        # 4. LLM Backbone Ablations
        for backbone in [LLMBackboneType.QWEN2_5_0_5B, LLMBackboneType.QWEN2_5_1_5B]:
            name = f"backbone_{backbone.value.split('/')[-1].lower()}"
            configs[name] = TRMVLMConfig(
                llm_backbone=backbone,
                vision_config=VisionEncoderConfig(encoder_type=VisionEncoderType.SIGLIP),
                moe_config=MoEConfig(enabled=True),
                trm_config=TRMConfig(enabled=True)
            )
        
        # 5. Baseline (full model)
        configs["full_model"] = TRMVLMConfig(
            llm_backbone=LLMBackboneType.QWEN2_5_3B,
            vision_config=VisionEncoderConfig(encoder_type=VisionEncoderType.SIGLIP),
            moe_config=MoEConfig(enabled=True),
            trm_config=TRMConfig(enabled=True)
        )
        
        return configs
    
    def run_single_ablation(
        self,
        name: str,
        config: TRMVLMConfig,
        train_dataset: Dataset,
        eval_dataset: Dataset
    ) -> Dict[str, float]:
        """Run a single ablation experiment"""
        logger.info(f"Running ablation: {name}")
        
        # Create model
        model = TRMVLM(config)
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        logger.info(f"  Total params: {total_params:,}")
        logger.info(f"  Trainable params: {trainable_params:,}")
        
        # Update training config
        training_config = TrainingConfig(
            **{**self.base_training_config.__dict__,
               'output_dir': f"{self.base_training_config.output_dir}/{name}",
               'wandb_run_name': name}
        )
        
        # Train
        trainer = TRMVLMTrainer(
            model=model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            training_config=training_config,
            model_config=config
        )
        
        trainer.train()
        
        # Final evaluation
        final_loss = trainer.evaluate()
        
        return {
            'name': name,
            'final_loss': final_loss,
            'total_params': total_params,
            'trainable_params': trainable_params
        }
    
    def run_all_ablations(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        ablations_to_run: Optional[List[str]] = None
    ):
        """Run all ablation studies"""
        configs = self.create_ablation_configs()
        
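        if ablations_to_run is not None:
            # Warn about misspelled names instead of silently skipping them
            unknown = sorted(set(ablations_to_run) - set(configs))
            if unknown:
                logger.warning(f"Unknown ablations ignored: {unknown}")
            configs = {k: v for k, v in configs.items() if k in ablations_to_run}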
        
        for name, config in configs.items():
            try:
                result = self.run_single_ablation(
                    name, config, train_dataset, eval_dataset
                )
                self.results[name] = result
            except Exception as e:
                # logger.exception records the traceback along with the message
                logger.exception(f"Ablation {name} failed: {e}")
                self.results[name] = {'error': str(e)}
        
        return self.results
    
    def summarize_results(self) -> str:
        """Generate summary of ablation results"""
        summary = "\n" + "="*60 + "\n"
        summary += "ABLATION STUDY RESULTS\n"
        summary += "="*60 + "\n\n"
        
        # Sort by loss
        sorted_results = sorted(
            [(k, v) for k, v in self.results.items() if 'error' not in v],
            key=lambda x: x[1].get('final_loss', float('inf'))
        )
        
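        for name, result in sorted_results:
            summary += f"{name}:\n"
            summary += f"  Loss: {result['final_loss']:.4f}\n"
            summary += f"  Params: {result['total_params']:,}\n"
            summary += f"  Trainable: {result['trainable_params']:,}\n\n"
        
        # List failed runs so they are not silently dropped from the report
        for name, result in self.results.items():
            if 'error' in result:
                summary += f"{name}: FAILED ({result['error']})\n"
        
        return summary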

---
## Part 12: Main Execution

In [None]:
def main():
    """
    Build the model and training configurations and print a summary.
    """
    # Create configurations
    model_config = TRMVLMConfig(
        llm_backbone=LLMBackboneType.QWEN2_5_3B,
        vision_config=VisionEncoderConfig(
            encoder_type=VisionEncoderType.SIGLIP,
            freeze_vision=True
        ),
        moe_config=MoEConfig(
            enabled=True,
            num_experts=8,
            num_experts_per_token=2,
            num_shared_experts=1
        ),
        trm_config=TRMConfig(
            enabled=True,
            n_recursions=4,
            t_cycles=2
        )
    )
    
    training_config = TrainingConfig(
        num_epochs=3,
        per_device_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        llm_learning_rate=2e-5,
        warmup_ratio=0.03,
        output_dir="./outputs/trm_vlm_full",
        use_wandb=True,
        wandb_project="trm-vlm"
    )
    
    print("Configuration Summary:")
    print(f"  LLM: {model_config.llm_backbone.value}")
    print(f"  Vision: {model_config.vision_config.encoder_type.value}")
    print(f"  MoE: {model_config.moe_config.enabled} ({model_config.moe_config.num_experts} experts)")
    print(f"  TRM: {model_config.trm_config.enabled}")
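    num_gpus = 2  # Matches the 2x H200 setup; update if the GPU count changes
    effective_batch_size = (
        training_config.per_device_batch_size
        * training_config.gradient_accumulation_steps
        * num_gpus
    )
    print(f"  Effective batch size: {effective_batch_size}")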
    
    return model_config, training_config

model_config, training_config = main()

In [None]:
# Create model
print("Creating TRM-VLM model...")
model = TRMVLM(model_config)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Summary:")
print(f"  Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")
print(f"  Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
print(f"  Frozen parameters: {total_params - trainable_params:,}")

In [None]:
# Create datasets
print("Creating Pixmo-Cap datasets...")

# Use smaller subset for testing
MAX_TRAIN_SAMPLES = 10000  # Set to None for full dataset
MAX_EVAL_SAMPLES = 500

train_dataset = PixmoCapDataset(
    split="train",
    max_samples=MAX_TRAIN_SAMPLES,
    tokenizer=model.tokenizer,
    max_length=model_config.max_seq_length,
    num_image_tokens=model_config.num_image_tokens
)

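# Evaluation subset. No explicit eval split is used, so this also reads from "train".
# Caution: if max_samples selects the first N examples, these samples are also in
# train_dataset, and the eval loss will not reflect generalization.
eval_dataset = PixmoCapDataset(
    split="train",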
    max_samples=MAX_EVAL_SAMPLES,
    tokenizer=model.tokenizer,
    max_length=model_config.max_seq_length,
    num_image_tokens=model_config.num_image_tokens
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Eval dataset: {len(eval_dataset)} samples")
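
Because the evaluation subset above is drawn from the same `train` split as the training subset, the two may overlap, depending on how `max_samples` selects examples. One way to avoid this is to split the sample indices once and build each subset from its own index list. The sketch below is a minimal, framework-free illustration: `split_indices` is a hypothetical helper, not part of `PixmoCapDataset`, and how the indices are consumed afterwards (for example through `torch.utils.data.Subset`) is left as an assumption.

```python
import random
from typing import List, Tuple


def split_indices(num_samples: int, num_eval: int, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle indices once and carve off a held-out evaluation slice.

    Hypothetical helper for illustration; returns (train_indices, eval_indices).
    """
    if not 0 <= num_eval <= num_samples:
        raise ValueError("num_eval must be between 0 and num_samples")
    indices = list(range(num_samples))
    # A local RNG keeps the split reproducible without touching global random state
    random.Random(seed).shuffle(indices)
    return indices[num_eval:], indices[:num_eval]


# 10,000 training indices and 500 evaluation indices with no overlap
train_idx, eval_idx = split_indices(num_samples=10_500, num_eval=500)
```

Fixing the seed keeps the held-out set identical across runs, which matters when comparing ablations against each other.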

In [None]:
# Initialize trainer
trainer = TRMVLMTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    training_config=training_config,
    model_config=model_config
)

print("Trainer initialized successfully!")
print(f"Total training steps: {trainer.total_steps}")

In [None]:
# Start training
print("Starting training...")
trainer.train()

---
## Part 13: Evaluation and Comparison

In [None]:
# Load test images for comparison
def load_test_images(num_images: int = 10) -> List[Image.Image]:
    """Load comparison images from the Pixmo-Cap train split (no held-out split is used here)"""
    test_dataset = PixmoCapDataset(
        split="train",
        max_samples=num_images * 2  # Load extra in case some fail
    )
    
    images = []
    for i in range(len(test_dataset)):
        item = test_dataset.dataset[i]
        image_url = item.get('image_url', '')
        image = test_dataset._download_image(image_url, i + 100000)  # Different cache
        if image is not None:
            images.append(image)
        if len(images) >= num_images:
            break
    
    return images

# Load test images
test_images = load_test_images(10)
print(f"Loaded {len(test_images)} test images")

In [None]:
# Compare models
comparison = Qwen2VLComparison()

# Generate captions from both models
results = comparison.compare_models(
    trm_vlm_model=model,
    test_images=test_images
)

# Display results
print("\n" + "="*60)
print("CAPTION COMPARISON")
print("="*60)

def _truncate(text: str, limit: int = 200) -> str:
    """Shorten long captions for display"""
    return f"{text[:limit]}..." if len(text) > limit else text

for i, (trm_cap, qwen_cap) in enumerate(zip(results['trm_vlm'], results['qwen_vl'])):
    print(f"\nImage {i+1}:")
    print(f"  TRM-VLM: {_truncate(trm_cap)}")
    print(f"  Qwen-VL: {_truncate(qwen_cap)}")

In [None]:
# Visualize comparison
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

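# Hide every panel first so unused ones stay blank if fewer than 10 images loaded
for ax in axes.flat:
    ax.axis('off')

for i, (ax, image) in enumerate(zip(axes.flat, test_images[:10])):
    ax.imshow(image)
    ax.set_title(f"Image {i+1}")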

plt.suptitle("Test Images for Comparison", fontsize=14)
plt.tight_layout()
plt.show()

---
## Part 14: Run Ablation Studies (Optional)
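
Each entry returned by `run_single_ablation` carries a `final_loss`, so the effect of removing a component can be read as the difference from the `full_model` run. The sketch below shows one way to compute those differences; `loss_deltas` is a hypothetical helper that is not defined elsewhere in this notebook, and it assumes the result dictionaries have the shape produced by `AblationStudy`.

```python
from typing import Any, Dict


def loss_deltas(
    results: Dict[str, Dict[str, Any]],
    baseline: str = "full_model",
) -> Dict[str, float]:
    """Return final_loss minus the baseline's final_loss for each successful run.

    Hypothetical helper for illustration. A positive value means the variant
    finished with a higher loss than the baseline.
    """
    if baseline not in results or "error" in results[baseline]:
        raise KeyError(f"Baseline run '{baseline}' is missing or failed")
    base = results[baseline]["final_loss"]
    return {
        name: res["final_loss"] - base
        for name, res in results.items()
        if name != baseline and "error" not in res  # Skip failed runs
    }
```

With single-epoch runs on 5,000 samples, small differences may fall within run-to-run noise, so the deltas are best treated as a rough guide.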

In [None]:
# Run ablation studies (this will take a long time)
RUN_ABLATIONS = False  # Set to True to run ablations

if RUN_ABLATIONS:
    ablation_training_config = TrainingConfig(
        num_epochs=1,  # Reduced for ablation
        per_device_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        output_dir="./outputs/ablations",
        use_wandb=True,
        wandb_project="trm-vlm-ablations"
    )
    
    ablation_study = AblationStudy(ablation_training_config)
    
    # Run specific ablations
    ablations_to_run = [
        "vision_siglip",
        "vision_dinov2",
        "no_moe",
        "no_trm",
        "full_model"
    ]
    
    # Create smaller datasets for ablations
    ablation_train = PixmoCapDataset(
        split="train",
        max_samples=5000,
        max_length=1024
    )
    ablation_eval = PixmoCapDataset(
        split="train",
        max_samples=500,
        max_length=1024
    )
    
    # Run ablations (separate name so the caption comparison `results` is not overwritten)
    ablation_results = ablation_study.run_all_ablations(
        train_dataset=ablation_train,
        eval_dataset=ablation_eval,
        ablations_to_run=ablations_to_run
    )
    
    # Print summary
    print(ablation_study.summarize_results())

---
## Part 15: Save Final Model
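
Storing a config with `str(...)` writes its repr as a single string, which is awkward to parse back later. If the config classes are dataclasses (an assumption; their definitions are not shown in this part), `dataclasses.asdict` combined with a small `default` hook for enum values produces structured JSON instead. The `Demo*` classes below are hypothetical stand-ins for `TRMVLMConfig` and its nested configs, used only to keep the sketch self-contained.

```python
import json
from dataclasses import asdict, dataclass, field
from enum import Enum


class DemoEncoder(Enum):  # Hypothetical stand-in for VisionEncoderType
    SIGLIP = "siglip"


@dataclass
class DemoVisionConfig:  # Hypothetical stand-in for VisionEncoderConfig
    encoder_type: DemoEncoder = DemoEncoder.SIGLIP
    freeze_vision: bool = True


@dataclass
class DemoModelConfig:  # Hypothetical stand-in for TRMVLMConfig
    vision_config: DemoVisionConfig = field(default_factory=DemoVisionConfig)
    num_experts: int = 8


def _json_default(obj):
    """Convert enum members to their values; reject anything else unknown."""
    if isinstance(obj, Enum):
        return obj.value
    raise TypeError(f"Cannot serialize {type(obj).__name__}")


def config_to_json(config) -> str:
    """Serialize a (possibly nested) dataclass config to a JSON string."""
    return json.dumps(asdict(config), default=_json_default, indent=2)


serialized = config_to_json(DemoModelConfig())
```

The resulting string can be loaded back with `json.loads` and inspected field by field.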

In [None]:
# Save the trained model
save_path = "./outputs/trm_vlm_final"
model.save_pretrained(save_path)
print(f"Model saved to {save_path}")

# Save training metrics
import json

metrics = {
    'model_config': str(model_config),
    'training_config': str(training_config),
    'total_params': total_params,
    'trainable_params': trainable_params
}

with open(f"{save_path}/metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)

print("Training complete!")

---
## Summary

This notebook implements a complete TRM-enhanced Vision-Language Model with:

### Architecture Components
1. **Vision Encoders**: SigLIP, DINOv2, or Dual (configurable)
2. **Projector**: 2-layer MLP or Perceiver Resampler
3. **LLM Backbone**: Qwen2.5 (0.5B/1.5B/3B/7B configurable)
4. **MoE Enhancement**: 8 experts with top-2 routing, plus 1 shared expert
5. **TRM Reasoning**: Recursive refinement on last N layers
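
To make the top-2 routing in item 4 concrete, the sketch below computes gate weights for a single token in plain Python. It is an illustration under stated assumptions, not this notebook's MoE layer: `top_k_gate` is a hypothetical function, and renormalizing the selected weights so they sum to 1 is one common choice that may differ from the actual implementation. A shared expert would typically be applied to every token in addition to the routed ones; that step is omitted here.

```python
import math
from typing import List, Tuple


def top_k_gate(logits: List[float], k: int = 2) -> List[Tuple[int, float]]:
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    # Softmax over all experts; subtracting the max improves numerical stability
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the k most probable experts
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Renormalize so the selected weights sum to 1 (assumed convention)
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


# Router scores for one token over 8 experts; experts 1 and 5 score highest
gates = top_k_gate([0.1, 2.0, -1.0, 0.5, 0.0, 1.5, -0.3, 0.2], k=2)
```

The token's output would then be the weighted sum of the two selected experts' outputs, using the returned weights.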

### Training Features
- FSDP for 2x H200 GPUs
- Mixed precision (bf16)
- Gradient checkpointing
- Wandb logging
- Pixmo-Cap dataset (712K images)

### Evaluation
- Comparison with Qwen2.5-VL-7B
- BLEU/ROUGE metrics
- Ablation study framework
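
As a rough illustration of the ROUGE side of the metrics listed above, the sketch below computes a simplified ROUGE-L F1 score from the longest common subsequence of whitespace tokens. It is not the official scorer: there is no stemming and no punctuation handling, and `rouge_l_f1` is a hypothetical name that does not refer to a function defined in this notebook.

```python
from typing import List


def _lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence, via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for tok_a in a:
        curr = [0]
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-L F1 on lowercased whitespace tokens."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    if not cand or not ref:
        return 0.0
    lcs = _lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)
```

A generated caption could be scored against its reference caption with `rouge_l_f1(generated, reference)`; for reported numbers, an established scoring package is the safer choice.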

### Key Hyperparameters
| Parameter | Value |
|-----------|-------|
| Effective Batch Size | 64 (4 per device × 8 accumulation steps × 2 GPUs) |
| Learning Rate | 1e-4 (projector), 2e-5 (LLM) |
| MoE Experts | 8 (top-2) |
| TRM Recursions | 4 |
| TRM Cycles | 2 |
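
The effective batch size in the table follows from the training configuration: per-device batch size times gradient accumulation steps times the number of GPUs. The sketch below reproduces that arithmetic and estimates optimizer steps per epoch for the 10,000-sample test subset. Both function names are hypothetical, and the step count is an upper-bound estimate; the trainer's own `total_steps` may differ depending on how it handles the last partial batch.

```python
import math


def effective_batch_size(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    world_size: int,
) -> int:
    """Samples consumed per optimizer step across all devices."""
    return per_device_batch_size * gradient_accumulation_steps * world_size


def optimizer_steps_per_epoch(num_samples: int, effective_batch: int) -> int:
    """Estimated optimizer steps per epoch, counting a final partial batch."""
    return math.ceil(num_samples / effective_batch)


batch = effective_batch_size(4, 8, 2)             # 4 * 8 * 2 = 64
steps = optimizer_steps_per_epoch(10_000, batch)  # ceil(10000 / 64) = 157
```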