# SmolLMv3 + TRM Training on GSM8K

This notebook trains SmolLMv3 with Tiny Recursive Model (TRM) on the GSM8K (Grade School Math) dataset.

**Dataset Split:**
- **Training**: 80% of GSM8K train set (~5,978 examples)
- **Test**: 20% of GSM8K train set (~1,495 examples)
- **Final Evaluation**: Original GSM8K test set (1,319 examples) - kept separate

**Features:**
- Self-contained - no external imports needed
- Proper train/test separation
- PyTorch Lightning for clean training
- Weights & Biases logging
- LoRA for efficient fine-tuning
- Latent attention compression
- Automatic checkpointing

**Based on:** *Less is More: Recursive Reasoning with Tiny Networks* by Alexia Jolicoeur-Martineau

## 1. Setup Environment

In [1]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running on Google Colab")
except:
    IN_COLAB = False
    print("Running locally")

# Install dependencies if on Colab
if IN_COLAB:
    !pip install -q torch>=2.0.0
    !pip install -q transformers>=4.30.0
    !pip install -q peft>=0.4.0
    !pip install -q pytorch-lightning>=2.0.0
    !pip install -q wandb>=0.15.0
    !pip install -q datasets>=2.14.0
    !pip install -q tqdm
    print("\nDependencies installed")

Running locally


In [2]:
# Verify PyTorch and GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"MPS available: {torch.backends.mps.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
elif torch.backends.mps.is_available():
    print("Device: Apple Silicon (MPS) - GPU acceleration enabled")
else:
    print("No GPU available - training will be slow")

PyTorch version: 2.8.0
CUDA available: False
MPS available: True
Device: Apple Silicon (MPS) - GPU acceleration enabled


## 2. Login to Weights & Biases

Get your API key from: https://wandb.ai/authorize

In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mneosapien17[0m ([33manonx[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## 3. Import Core Dependencies

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, List, Optional
import json
import re

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset

print("Imports successful")

  from .autonotebook import tqdm as notebook_tqdm


Imports successful


## 4. Load and Split GSM8K Dataset

**Split Strategy:**
- GSM8K training set (7,473 examples) → 80% train / 20% test
- GSM8K test set (1,319 examples) → Reserved for final evaluation only

This gives us:
- **Train**: ~5,978 examples (80% of 7,473)
- **Test**: ~1,495 examples (20% of 7,473)
- **Final Eval**: 1,319 examples (original test set)

In [5]:
# Load GSM8K dataset from HuggingFace
print("Loading GSM8K dataset...")
gsm8k_full = load_dataset("gsm8k", "main", split="train")
gsm8k_final_test = load_dataset("gsm8k", "main", split="test")

print(f"Full training set: {len(gsm8k_full)} examples")
print(f"Final test set: {len(gsm8k_final_test)} examples (reserved for final evaluation)")

# Split the training set into train (80%) and test (20%)
from datasets import Dataset as HFDataset

# Calculate split index
total_examples = len(gsm8k_full)
test_size = int(0.2 * total_examples)
train_size = total_examples - test_size

print(f"\nCreating 80/20 split:")
print(f"  Train: {train_size} examples (80%)")
print(f"  Test: {test_size} examples (20%)")

# Create the split
gsm8k_train = gsm8k_full.select(range(train_size))
gsm8k_test = gsm8k_full.select(range(train_size, total_examples))

print(f"\nFinal splits:")
print(f"  Train: {len(gsm8k_train)}")
print(f"  Test: {len(gsm8k_test)}")
print(f"  Final Eval: {len(gsm8k_final_test)}")

# Show a sample
sample = gsm8k_train[0]
print("\nSample question:")
print(sample['question'])
print("\nSample answer:")
print(sample['answer'])

Loading GSM8K dataset...
Full training set: 7473 examples
Final test set: 1319 examples (reserved for final evaluation)

Creating 80/20 split:
  Train: 5979 examples (80%)
  Test: 1494 examples (20%)

Final splits:
  Train: 5979
  Test: 1494
  Final Eval: 1319

Sample question:
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Sample answer:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


## 5. TRM Core Components

### 5.1 Transformer Block and Tiny Recursive Network

In [6]:
class TransformerBlock(nn.Module):
    """Standard transformer block with self-attention"""
    
    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        
        self.norm1 = nn.RMSNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        
        self.norm2 = nn.RMSNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x)
        x = x + residual
        
        # MLP with residual
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        
        return x


class TinyRecursiveNetwork(nn.Module):
    """The core tiny network for recursive reasoning"""
    
    def __init__(
        self,
        d_model: int,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0
    ):
        super().__init__()
        
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout) 
            for _ in range(n_layers)
        ])
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

print("Transformer and Tiny Recursive Network defined")

Transformer and Tiny Recursive Network defined


### 5.2 Latent Attention Compressor

In [7]:
class LatentAttentionCompressor(nn.Module):
    """Attention-based sequence compression (Perceiver-style)"""
    
    def __init__(
        self,
        hidden_size: int,
        num_latents: int,
        n_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_latents = num_latents
        self.n_heads = n_heads
        
        # Learned latent queries
        self.latent_queries = nn.Parameter(torch.randn(num_latents, hidden_size))
        
        # Cross-attention
        self.compress_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True
        )
        self.compress_norm = nn.LayerNorm(hidden_size)
        self.compress_ff = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout)
        )
        self.compress_ff_norm = nn.LayerNorm(hidden_size)
    
    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        
        # Expand latent queries
        latents = self.latent_queries.unsqueeze(0).expand(batch_size, -1, -1)
        
        # Cross-attention
        # Convert attention_mask to key_padding_mask format (True for positions to ignore)
        key_padding_mask = None

        attention_mask = attention_mask.bool()
        if attention_mask is not None:
            key_padding_mask = (attention_mask == 0)  # True for padding positions
        
        attn_out, _ = self.compress_attn(
            query=latents,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask
        )
        latents = self.compress_norm(latents + attn_out)
        
        # Feed-forward
        ff_out = self.compress_ff(latents)
        latents = self.compress_ff_norm(latents + ff_out)
        
        return latents

print("Latent Attention Compressor defined")

Latent Attention Compressor defined


### 5.3 Recursive Reasoning Base Class

In [8]:
class RecursiveReasoningBase(nn.Module):
    """Base class with core recursion logic"""
    
    def latent_recursion(
        self, 
        x: torch.Tensor, 
        y: torch.Tensor, 
        z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update latent z
        for _ in range(self.n_latent_steps):
            combined = x + y + z
            z = self.net(combined)
        
        # Update prediction y
        combined = y + z
        y = self.net(combined)
        
        return y, z
    
    def run_deep_recursion(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        with_gradients: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Run T-1 recursions without gradients
        if self.n_deep_recursions > 1:
            with torch.no_grad():
                for _ in range(self.n_deep_recursions - 1):
                    y, z = self.latent_recursion(x, y, z)
        
        # Final recursion with gradients
        if with_gradients:
            y, z = self.latent_recursion(x, y, z)
        else:
            with torch.no_grad():
                y, z = self.latent_recursion(x, y, z)
        
        return y, z
    
    def compute_halt_probability(self, y: torch.Tensor) -> torch.Tensor:
        halt_logits = self.halt_head(y.mean(dim=1))
        return torch.sigmoid(halt_logits)

print("Recursive Reasoning Base defined")

Recursive Reasoning Base defined


### 5.4 Hidden State TRM

In [9]:
class HiddenStateTRM(RecursiveReasoningBase):
    """TRM for processing LLM hidden states with sliding window"""
    
    def __init__(
        self,
        hidden_size: int = 3072,
        num_latents: int = 256,
        n_layers: int = 2,
        n_heads: int = 8,
        compression_heads: int = 8,
        n_latent_steps: int = 6,
        n_deep_recursions: int = 3,
        n_supervision_steps: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.num_latents = num_latents
        self.n_latent_steps = n_latent_steps
        self.n_deep_recursions = n_deep_recursions
        self.n_supervision_steps = n_supervision_steps
        
        self.compressor = LatentAttentionCompressor(
            hidden_size=hidden_size,
            num_latents=num_latents,
            n_heads=compression_heads,
            dropout=dropout
        )
        
        self.net = TinyRecursiveNetwork(
            d_model=hidden_size,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout
        )
        
        self.halt_head = nn.Linear(hidden_size, 1)
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_all_steps: bool = False
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        
        # Compress
        x_compressed = self.compressor(hidden_states, attention_mask=attention_mask)
        
        # Initialize
        y = torch.zeros_like(x_compressed)
        z = torch.zeros_like(x_compressed)
        
        all_outputs = []
        
        # Deep supervision loop
        for step in range(self.n_supervision_steps):
            y, z = self.run_deep_recursion(x_compressed, y, z, with_gradients=True)
            
            if return_all_steps:
                shifted = torch.cat([
                    hidden_states[:, self.num_latents:, :],
                    y
                ], dim=1)
                all_outputs.append(shifted)
            
            if not self.training:
                halt_prob = self.compute_halt_probability(y)
                if halt_prob.mean() > 0.5:
                    break
            
            y = y.detach()
            z = z.detach()
        
        # Sliding window
        shifted_states = torch.cat([
            hidden_states[:, self.num_latents:, :],
            y
        ], dim=1)
        
        if return_all_steps:
            return all_outputs
        return shifted_states

print("Hidden State TRM defined")

Hidden State TRM defined


## 6. SmolLMv3 + TRM Integration

In [10]:
class SmolLMv3WithTRM(nn.Module):
    """SmolLMv3 with TRM for enhanced reasoning"""
    
    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        use_lora: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        num_latents: int = 256,
        trm_kwargs: Optional[dict] = None
    ):
        super().__init__()
        
        print(f"Loading {model_name}...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        
        if use_lora:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                bias="none"
            )
            self.base_model = get_peft_model(self.base_model, lora_config)
            print("\nLoRA adapters applied:")
            self.base_model.print_trainable_parameters()
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Add <think> token
        special_tokens = {"additional_special_tokens": ["<think>"]}
        num_added = self.tokenizer.add_special_tokens(special_tokens)
        if num_added > 0:
            self.base_model.resize_token_embeddings(len(self.tokenizer))
        
        self.think_token_id = self.tokenizer.convert_tokens_to_ids("<think>")
        
        config = self.base_model.config
        hidden_size = config.hidden_size
        
        trm_kwargs = trm_kwargs or {}
        print(f"\nInitializing TRM with {num_latents} latents...")
        self.trm = HiddenStateTRM(
            hidden_size=hidden_size,
            num_latents=num_latents,
            **trm_kwargs
        )
        
        if not use_lora:
            for param in self.base_model.parameters():
                param.requires_grad = False
        
        print(f"\nModel initialized")
        print(f"  <think> token ID: {self.think_token_id}")
        print(f"  TRM parameters: {sum(p.numel() for p in self.trm.parameters())/1e6:.2f}M")
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_trm: bool = True
    ):
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )
        
        if not use_trm or not self.training:
            return outputs
        
        think_positions = (input_ids == self.think_token_id).nonzero(as_tuple=True)
        
        if len(think_positions[0]) == 0:
            return outputs
        
        hidden_states = outputs.hidden_states[-1]
        shifted_states = self.trm(hidden_states, attention_mask=attention_mask)
        
        trm_logits = self.base_model.lm_head(shifted_states)
        
        if labels is not None:
            # Ensure dimensions match: TRM logits should match the shifted labels
            shifted_labels = labels[:, self.trm.num_latents:]
            # Make sure TRM logits and shifted labels have the same sequence length
            if trm_logits.size(1) != shifted_labels.size(1):
                # If dimensions don't match, adjust TRM logits to match shifted labels
                trm_logits = trm_logits[:, :shifted_labels.size(1), :]
            
            loss_fct = nn.CrossEntropyLoss()
            trm_loss = loss_fct(
                trm_logits.reshape(-1, trm_logits.size(-1)),
                shifted_labels.reshape(-1)
            )
            outputs.loss = outputs.loss + 0.3 * trm_loss
        
        return outputs
    
    def generate_with_thinking(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        do_sample: bool = True
    ) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.base_model.device)
        
        with torch.no_grad():
            outputs = self.base_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        return self.tokenizer.decode(outputs[0], skip_special_tokens=False)

print("SmolLMv3WithTRM defined")

SmolLMv3WithTRM defined


## 7. GSM8K Dataset Class

In [11]:
class GSM8KDataset(Dataset):
    """Dataset for GSM8K with <think> token support"""
    
    def __init__(
        self,
        dataset,
        tokenizer,
        max_length: int = 512,
        add_think_token: bool = True
    ):
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.add_think_token = add_think_token
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        
        question = item['question']
        answer = item['answer']
        
        # Format with <think> token to trigger TRM reasoning
        if self.add_think_token:
            text = f"Question: {question}\nAnswer: <think> {answer}"
        else:
            text = f"Question: {question}\nAnswer: {answer}"
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        
        # Create labels
        labels = input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

print("GSM8K Dataset class defined")

GSM8K Dataset class defined


## 8. PyTorch Lightning Module

In [12]:
class SmolLMTRMLightningModule(pl.LightningModule):
    """PyTorch Lightning module for training"""
    
    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        use_lora: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        num_latents: int = 256,
        learning_rate: float = 2e-4,
        weight_decay: float = 0.01,
        warmup_steps: int = 100,
        trm_kwargs: Optional[Dict] = None
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = SmolLMv3WithTRM(
            model_name=model_name,
            use_lora=use_lora,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            num_latents=num_latents,
            trm_kwargs=trm_kwargs or {}
        )
        
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.total_steps = None
    
    def forward(self, input_ids, attention_mask, labels):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            use_trm=True
        )
    
    def training_step(self, batch, batch_idx):
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        
        loss = outputs.loss
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/perplexity", torch.exp(loss), on_step=False, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        
        loss = outputs.loss
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/perplexity", torch.exp(loss), on_step=False, on_epoch=True, sync_dist=True)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        
        if self.total_steps is None:
            self.total_steps = self.trainer.estimated_stepping_batches
        
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

print("Lightning Module defined")

Lightning Module defined


## 9. Training Function

In [13]:
def train_gsm8k(
    train_dataset,
    test_dataset,
    model_name: str = "HuggingFaceTB/SmolLM3-3B",
    output_dir: str = "./checkpoints",
    batch_size: int = 2,
    num_epochs: int = 3,
    learning_rate: float = 2e-4,
    num_latents: int = 256,
    accumulate_grad_batches: int = 4,
    val_check_interval: float = 0.25,
    precision: str = "bf16-mixed",
    devices: int = 1,
    wandb_project: str = "smollm-trm-gsm8k",
    wandb_name: Optional[str] = None,
):
    """Train on GSM8K dataset with provided train/test split"""
    
    print(f"Train examples: {len(train_dataset)}")
    print(f"Test examples: {len(test_dataset)}")
    
    # Initialize model
    pl_module = SmolLMTRMLightningModule(
        model_name=model_name,
        use_lora=True,
        lora_r=16,
        lora_alpha=32,
        num_latents=num_latents,
        learning_rate=learning_rate,
        warmup_steps=100,
        trm_kwargs={
            "n_layers": 2,
            "n_latent_steps": 4,
            "n_deep_recursions": 2,
            "n_supervision_steps": 4,
            "compression_heads": 8
        }
    )
    
    # Create datasets
    train_dataset_wrapped = GSM8KDataset(
        train_dataset,
        pl_module.model.tokenizer,
        max_length=512,
        add_think_token=True
    )
    
    test_dataset_wrapped = GSM8KDataset(
        test_dataset,
        pl_module.model.tokenizer,
        max_length=512,
        add_think_token=True
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset_wrapped,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues in Jupyter
        pin_memory=False  # Disable pin_memory for MPS compatibility
    )
    
    test_loader = DataLoader(
        test_dataset_wrapped,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues in Jupyter
        pin_memory=False  # Disable pin_memory for MPS compatibility
    )
    
    # Setup wandb logger
    wandb_logger = WandbLogger(
        project=wandb_project,
        name=wandb_name,
        log_model=True
    )
    
    # Setup callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=output_dir,
        filename="smollm-trm-gsm8k-{epoch:02d}-{val/loss:.4f}",
        monitor="val/loss",
        mode="min",
        save_top_k=3,
        save_last=True,
        verbose=True
    )
    
    early_stop_callback = EarlyStopping(
        monitor="val/loss",
        patience=3,
        mode="min",
        verbose=True
    )
    
    lr_monitor = LearningRateMonitor(logging_interval="step")
    
    # Create trainer
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator="auto",
        devices=devices,
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        gradient_clip_val=1.0,
        val_check_interval=val_check_interval,
        logger=wandb_logger,
        callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
        log_every_n_steps=10,
        enable_progress_bar=True,
        enable_model_summary=True
    )
    
    # Train
    print("\n" + "="*70)
    print("Starting training on GSM8K...")
    print("="*70 + "\n")
    
    trainer.fit(
        pl_module,
        train_dataloaders=train_loader,
        val_dataloaders=test_loader
    )
    
    print(f"\n" + "="*70)
    print("Training complete!")
    print("="*70)
    print(f"Best checkpoint: {checkpoint_callback.best_model_path}")
    print(f"Best test loss: {checkpoint_callback.best_model_score:.4f}")
    
    return trainer, pl_module

print("Training function defined")

Training function defined


## 10. Evaluation Functions

Extract numerical answers and compute accuracy.

In [14]:
def extract_answer(text: str) -> Optional[str]:
    """Extract numerical answer from GSM8K format"""
    # GSM8K answers are in format "#### 123"
    match = re.search(r'####\s*([\d,\.]+)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Also try to find last number in text
    numbers = re.findall(r'[\d,]+\.?\d*', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return None


def evaluate_gsm8k(model, test_dataset, max_samples: Optional[int] = None, batch_size: int = 1):
    """
    Evaluate model on GSM8K test set.
    
    Args:
        model: The trained SmolLMv3WithTRM model
        test_dataset: GSM8K test dataset
        max_samples: Number of samples to evaluate (None for all)
        batch_size: Batch size for generation
    
    Returns:
        Dictionary with accuracy and examples
    """
    model.eval()
    model.base_model.eval()
    
    correct = 0
    total = 0
    examples = []
    
    num_to_eval = len(test_dataset) if max_samples is None else min(max_samples, len(test_dataset))
    print(f"Evaluating on {num_to_eval} examples...")
    
    for i in range(num_to_eval):
        item = test_dataset[i]
        question = item['question']
        true_answer = extract_answer(item['answer'])
        
        # Create prompt with <think> token
        prompt = f"Question: {question}\nAnswer: <think>"
        
        # Generate answer
        generated = model.generate_with_thinking(
            prompt,
            max_new_tokens=200,
            temperature=0.1,  # Low temperature for more deterministic answers
            do_sample=True
        )
        
        # Extract predicted answer
        pred_answer = extract_answer(generated)
        
        # Check if correct
        is_correct = (pred_answer == true_answer) if (pred_answer and true_answer) else False
        
        if is_correct:
            correct += 1
        total += 1
        
        # Save first 10 examples
        if i < 10:
            examples.append({
                'question': question,
                'true_answer': true_answer,
                'predicted_answer': pred_answer,
                'generated_text': generated,
                'correct': is_correct
            })
        
        # Progress
        if (i + 1) % 10 == 0:
            acc = correct / total * 100
            print(f"  {i+1}/{num_to_eval} - Accuracy: {acc:.2f}%")
    
    accuracy = correct / total * 100
    
    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'examples': examples
    }

print("Evaluation functions defined")

Evaluation functions defined


## 11. Configure and Start Training

The training will use the 80/20 split we created earlier.

In [15]:
# Training configuration
config = {
    "train_dataset": gsm8k_train,  # 80% of GSM8K train set
    "test_dataset": gsm8k_test,    # 20% of GSM8K train set
    "model_name": "HuggingFaceTB/SmolLM3-3B",
    "batch_size": 2,
    "num_epochs": 3,
    "learning_rate": 2e-4,
    "num_latents": 256,
    "accumulate_grad_batches": 4,
    "precision": "bf16-mixed",
    "devices": 1,
    "wandb_project": "smollm-trm-gsm8k",
    "wandb_name": "gsm8k-80-20-split",
    "output_dir": "./checkpoints",
}

print("Training Configuration:")
print(f"  Train size: {len(config['train_dataset'])} examples")
print(f"  Test size: {len(config['test_dataset'])} examples")
print(f"  Epochs: {config['num_epochs']}")
print(f"  Batch size: {config['batch_size']}")
print(f"  Learning rate: {config['learning_rate']}")
print("\nReady to train! Run the next cell to start.")

Training Configuration:
  Train size: 5979 examples
  Test size: 1494 examples
  Epochs: 3
  Batch size: 2
  Learning rate: 0.0002

Ready to train! Run the next cell to start.


In [None]:
# Start training on GSM8K
trainer, model = train_gsm8k(**config)

Train examples: 5979
Test examples: 1494
Loading HuggingFaceTB/SmolLM3-3B...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.10s/it]



LoRA adapters applied:
trainable params: 7,667,712 || all params: 3,082,766,336 || trainable%: 0.2487

Initializing TRM with 256 latents...


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Model initialized
  <think> token ID: 128002
  TRM parameters: 151.59M

Starting training on GSM8K...



Loading `train_dataloader` to estimate number of stepping batches.
/Users/neosapien/Development/llm-trm/.venv/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/neosapien/Development/llm-trm/.venv/lib/python3.13/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type            | Params | Mode 
--------------------------------------------------
0 | model | SmolLMv3WithTRM | 3.2 B  | train
--------------------------------------------------
159 M     Trainable params
3.1 B     Non-trainable params
3.2 B     Total params
12,937.437Total estimated model params size (MB)
1476   

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/neosapien/Development/llm-trm/.venv/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


                                                                           



Epoch 0:   0%|          | 1/2990 [00:48<40:22:16,  0.02it/s, v_num=eixo, train/loss_step=nan.0]

wandb-core(77045) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77051) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77052) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77057) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77061) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77067) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77068) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77075) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77080) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(77086) MallocStackLogging: can't turn off malloc stack logging because 

## 12. Evaluate on Test Set (20% split)

Test the trained model on our 20% test split.

In [None]:
# Evaluate on our 20% test split
results = evaluate_gsm8k(
    model.model,
    gsm8k_test,
    max_samples=100,  # Evaluate on 100 examples (test set has ~1,495)
)

print("\n" + "="*70)
print("Test Set Results (20% split)")
print("="*70)
print(f"Accuracy: {results['accuracy']:.2f}%")
print(f"Correct: {results['correct']}/{results['total']}")
print("\nExample Predictions:")
print("="*70)

for i, ex in enumerate(results['examples'][:5]):
    print(f"\nExample {i+1}:")
    print(f"Question: {ex['question']}")
    print(f"True Answer: {ex['true_answer']}")
    print(f"Predicted Answer: {ex['predicted_answer']}")
    print(f"Correct: {ex['correct']}")
    print("-"*70)

## 13. Final Evaluation on GSM8K Official Test Set (Optional)

Evaluate on the official GSM8K test set (1,319 examples) - use this for final benchmarking only.

In [None]:
# Evaluate on official GSM8K test set (1,319 examples)
# This is the standard benchmark test set - use sparingly to avoid overfitting

final_results = evaluate_gsm8k(
    model.model,
    gsm8k_final_test,
    max_samples=100,  # Test on 100 examples (full set has 1,319)
)

print("\n" + "="*70)
print("Final Evaluation on Official GSM8K Test Set")
print("="*70)
print(f"Accuracy: {final_results['accuracy']:.2f}%")
print(f"Correct: {final_results['correct']}/{final_results['total']}")
print("\nNote: This is evaluated on the official GSM8K test set")
print("Use this for final benchmarking only to avoid overfitting")

## 14. Test Individual Problems

Try the model on specific math problems.

In [None]:
# Test on custom problems
test_problems = [
    "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
    "A restaurant served 9 pizzas during lunch and 6 during dinner today. How many pizzas were served in total?",
    "If John has 5 apples and buys 7 more, then gives away 3, how many apples does he have left?",
]

print("Testing on custom problems:\n")
for i, problem in enumerate(test_problems):
    prompt = f"Question: {problem}\nAnswer: <think>"
    
    response = model.model.generate_with_thinking(
        prompt,
        max_new_tokens=150,
        temperature=0.1
    )
    
    # Extract just the answer part
    answer_part = response.split("<think>")[-1] if "<think>" in response else response
    
    print(f"Problem {i+1}:")
    print(f"Q: {problem}")
    print(f"A: {answer_part.strip()}")
    print("-"*70)

## 15. Save Trained Model

Save the model for later use.

In [None]:
# Save the model
save_path = "./gsm8k_model"

# Save base model with LoRA
model.model.base_model.save_pretrained(save_path)

# Save TRM weights
torch.save(model.model.trm.state_dict(), f"{save_path}/trm_weights.pt")

# Save tokenizer
model.model.tokenizer.save_pretrained(save_path)

print(f"Model saved to {save_path}")
print("\nTo load later:")
print("  model = SmolLMv3WithTRM(model_name=save_path)")
print(f"  model.trm.load_state_dict(torch.load('{save_path}/trm_weights.pt'))")