# SmolLMv3 + TRM Training

Train SmolLMv3-3B with Tiny Recursive Model (TRM) for enhanced reasoning.

**Architecture:**
- SmolLMv3-3B with LoRA adapters
- Perceiver-style latent attention compression (sequence compressed to 256 learned latents)
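- TRM recursive reasoning (2 layers; effective depth 672 = 2 layers × (6 + 1) network passes × 3 deep recursions × 16 supervision steps, using the paper's N_sup = 16; the default of 8 supervision steps in this notebook gives 336)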
- Sliding window output

**Based on:**
- [Less is More: Recursive Reasoning with Tiny Networks](https://arxiv.org/abs/2510.04871) by Alexia Jolicoeur-Martineau
- [SmolLM3 Blog](https://huggingface.co/blog/smollm3)

## 1. Setup Environment

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running on Google Colab")
except:
    IN_COLAB = False
    print("Running locally")

# Install dependencies if on Colab
if IN_COLAB:
    !pip install -q torch>=2.0.0
    !pip install -q transformers>=4.30.0
    !pip install -q peft>=0.4.0
    !pip install -q pytorch-lightning>=2.0.0
    !pip install -q wandb>=0.15.0
    !pip install -q datasets>=2.14.0
    print("\nDependencies installed")

In [None]:
# Report the installed PyTorch build and which accelerator backends it can use
import torch

print(f"PyTorch version: {torch.__version__}")

# Query each backend once and reuse the answer for the device summary below
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
_mps_ok = torch.backends.mps.is_available()
print(f"MPS available: {_mps_ok}")

# CUDA takes precedence over MPS when naming the device
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif _mps_ok:
    print("Device: Apple Silicon (MPS)")

In [None]:
# Imports
import json
from typing import Tuple, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType

print("Imports successful")

## 2. TRM Core Components

From `src/models/trm.py`

In [None]:
class TransformerBlock(nn.Module):
    """Standard transformer block with self-attention"""

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.0):
        super().__init__()

        self.norm1 = nn.RMSNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )

        self.norm2 = nn.RMSNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x)
        x = x + residual

        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual

        return x


class TinyRecursiveNetwork(nn.Module):
    """Small stack of transformer blocks that the TRM applies recursively.

    Args:
        d_model: Feature width of every block.
        n_layers: Number of stacked blocks (defaults to 2).
        n_heads: Attention heads per block.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(
        self,
        d_model: int,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0
    ):
        super().__init__()

        blocks = [
            TransformerBlock(d_model, n_heads, dropout)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden


class RecursiveReasoningBase(nn.Module):
    """Recursion routines shared by TRM-style models.

    This base class sets no attributes itself. The methods below read the
    following from ``self``, so a subclass has to define them:
    ``net``, ``n_latent_steps``, ``n_deep_recursions`` and ``halt_head``.
    """

    def latent_recursion(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refine ``z`` ``n_latent_steps`` times, then refresh ``y`` once.

        Returns:
            The updated ``(y, z)`` pair.
        """
        # Latent updates see the input, the current answer and the latent
        for _ in range(self.n_latent_steps):
            z = self.net(x + y + z)

        # The answer update deliberately leaves the input x out
        y = self.net(y + z)

        return y, z

    def run_deep_recursion(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        with_gradients: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``n_deep_recursions`` recursions; only the last may track gradients.

        The first ``n_deep_recursions - 1`` recursions always run under
        ``torch.no_grad()``. The final one runs in the caller's grad mode
        unless ``with_gradients`` is false, in which case it is also wrapped
        in ``torch.no_grad()``.
        """
        untracked_passes = self.n_deep_recursions - 1
        if untracked_passes > 0:
            with torch.no_grad():
                for _ in range(untracked_passes):
                    y, z = self.latent_recursion(x, y, z)

        if not with_gradients:
            with torch.no_grad():
                return self.latent_recursion(x, y, z)

        return self.latent_recursion(x, y, z)

    def compute_halt_probability(self, y: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of ``halt_head`` applied to ``y`` averaged over dim 1."""
        pooled = y.mean(dim=1)
        return self.halt_head(pooled).sigmoid()

print("TRM core components defined")

## 3. Latent Attention Compressor

From `src/models/compression.py` - Perceiver-style compression

In [None]:
class LatentAttentionCompressor(nn.Module):
    """Perceiver-style compressor mapping ``[B, L, D]`` to ``[B, M, D]``.

    ``M`` learned query vectors cross-attend over the input sequence; the
    result goes through a post-norm residual and a post-norm feed-forward
    residual.

    Args:
        hidden_size: Feature width ``D``.
        num_latents: Number of learned queries ``M``.
        n_heads: Heads used by the cross-attention.
        dropout: Dropout probability for the attention and the feed-forward.
    """

    def __init__(
        self,
        hidden_size: int,
        num_latents: int,
        n_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_latents = num_latents
        self.n_heads = n_heads

        # One learned query per output latent, drawn from a standard normal
        self.latent_queries = nn.Parameter(torch.randn(num_latents, hidden_size))

        # Cross-attention: latents are the queries, the sequence is key/value
        self.compress_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.compress_norm = nn.LayerNorm(hidden_size)

        # Feed-forward with a 4x hidden expansion
        ff_width = hidden_size * 4
        self.compress_ff = nn.Sequential(
            nn.Linear(hidden_size, ff_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_width, hidden_size),
            nn.Dropout(dropout),
        )
        self.compress_ff_norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compress ``x`` into ``num_latents`` vectors per batch element.

        Args:
            x: Input of shape ``[B, L, D]``.
            attention_mask: Optional mask; positions equal to 0 are excluded
                from the attention keys.

        Returns:
            Tensor of shape ``[B, M, D]``.
        """
        batch_size, _, _ = x.shape

        # Share the learned queries across the batch without copying them
        queries = self.latent_queries.unsqueeze(0).expand(batch_size, -1, -1)

        # MultiheadAttention expects True where a key must be ignored
        ignored_keys = None if attention_mask is None else (attention_mask == 0)

        attended, _ = self.compress_attn(
            query=queries,
            key=x,
            value=x,
            key_padding_mask=ignored_keys,
        )
        latents = self.compress_norm(queries + attended)

        return self.compress_ff_norm(latents + self.compress_ff(latents))

print("Latent Attention Compressor defined")

## 4. Hidden State TRM

From `src/models/smollm.py` - TRM for LLM hidden states with sliding window

In [None]:
class HiddenStateTRM(RecursiveReasoningBase):
    """TRM that refines LLM hidden states and returns a sliding-window sequence.

    The input sequence is compressed to ``num_latents`` vectors. An answer
    state ``y`` and a latent state ``z`` (both starting at zero) are refined by
    repeated deep recursions, and the final ``y`` is appended to the input
    with its first ``num_latents`` positions dropped.

    Args:
        hidden_size: Feature width of the hidden states.
        num_latents: Number of latent vectors ``M`` produced by the compressor
            and appended to the output.
        n_layers: Number of transformer blocks in the recursive network.
        n_heads: Attention heads in the recursive network.
        compression_heads: Attention heads in the compressor.
        n_latent_steps: Latent (``z``) updates per recursion.
        n_deep_recursions: Recursions per supervision step; all but the last
            run without gradients.
        n_supervision_steps: Maximum number of supervision steps.
        dropout: Dropout probability passed to the compressor and the network.
    """

    def __init__(
        self,
        hidden_size: int = 3072,
        num_latents: int = 256,
        n_layers: int = 2,
        n_heads: int = 8,
        compression_heads: int = 8,
        n_latent_steps: int = 6,
        n_deep_recursions: int = 3,
        n_supervision_steps: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_latents = num_latents
        self.n_latent_steps = n_latent_steps
        self.n_deep_recursions = n_deep_recursions
        self.n_supervision_steps = n_supervision_steps

        self.compressor = LatentAttentionCompressor(
            hidden_size=hidden_size,
            num_latents=num_latents,
            n_heads=compression_heads,
            dropout=dropout
        )

        self.net = TinyRecursiveNetwork(
            d_model=hidden_size,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout
        )

        # NOTE(review): nothing in this class computes a loss for halt_head;
        # it only gates early stopping in eval mode. Confirm where it is trained.
        self.halt_head = nn.Linear(hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_all_steps: bool = False
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the supervision loop and build the sliding-window output.

        Args:
            hidden_states: Input of shape ``[B, L, D]``.
            attention_mask: Optional mask forwarded to the compressor.
            return_all_steps: If true, return one sliding-window tensor per
                executed supervision step instead of only the last one.

        Returns:
            ``hidden_states[:, M:, :]`` concatenated with the ``M`` answer
            states along dim 1 (length ``max(L - M, 0) + M``), or a list of
            such tensors when ``return_all_steps`` is true. The answer states
            stay attached to the autograd graph of the step that produced them.
        """
        # Compress: [B, L, D] -> [B, M, D]
        x_compressed = self.compressor(hidden_states, attention_mask=attention_mask)

        y = torch.zeros_like(x_compressed)
        z = torch.zeros_like(x_compressed)

        # Sliding window: the part of the input that survives dropping the
        # first M positions. Empty when L <= M.
        kept_context = hidden_states[:, self.num_latents:, :]

        # Graph-attached handle on the newest answer state. The loop detaches
        # y and z so the next step starts from constants, but the tensor that
        # is returned must remain connected to this module's parameters;
        # returning the detached copy would leave the compressor and the
        # recursive network without any gradient.
        answer = y

        all_outputs = []

        for _ in range(self.n_supervision_steps):
            y, z = self.run_deep_recursion(x_compressed, y, z, with_gradients=True)
            answer = y

            if return_all_steps:
                all_outputs.append(torch.cat([kept_context, answer], dim=1))

            # Early halting is only applied outside training
            if not self.training:
                halt_prob = self.compute_halt_probability(answer)
                if halt_prob.mean() > 0.5:
                    break

            # Cut the graph between supervision steps
            y = y.detach()
            z = z.detach()

        if return_all_steps:
            return all_outputs

        # Sliding window: drop first M, append M TRM states
        return torch.cat([kept_context, answer], dim=1)

print("Hidden State TRM defined")

## 5. SmolLMv3 + TRM Integration

Full integration with LoRA adapters

In [None]:
class SmolLMv3WithTRM(nn.Module):
    """SmolLMv3 causal LM (optionally LoRA-adapted) with an auxiliary TRM loss.

    Args:
        model_name: Hugging Face model id to load.
        use_lora: Wrap the base model in LoRA adapters on the attention
            projections. When false, all base-model parameters are frozen.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout probability.
        num_latents: Number of TRM latents ``M``.
        trm_kwargs: Extra keyword arguments forwarded to ``HiddenStateTRM``.
        trm_loss_weight: Multiplier applied to the auxiliary TRM loss before
            it is added to the base model loss.
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        use_lora: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        num_latents: int = 256,
        trm_kwargs: Optional[dict] = None,
        trm_loss_weight: float = 0.3
    ):
        super().__init__()

        self.trm_loss_weight = trm_loss_weight

        print(f"Loading {model_name}...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        if use_lora:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                bias="none"
            )
            self.base_model = get_peft_model(self.base_model, lora_config)
            print("\nLoRA adapters applied:")
            self.base_model.print_trainable_parameters()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Register <think>; the embedding matrix only needs resizing when the
        # tokenizer did not already know the token.
        special_tokens = {"additional_special_tokens": ["<think>"]}
        num_added = self.tokenizer.add_special_tokens(special_tokens)
        if num_added > 0:
            self.base_model.resize_token_embeddings(len(self.tokenizer))

        self.think_token_id = self.tokenizer.convert_tokens_to_ids("<think>")

        config = self.base_model.config
        hidden_size = config.hidden_size

        trm_kwargs = trm_kwargs or {}
        print(f"\nInitializing TRM with {num_latents} latents...")
        self.trm = HiddenStateTRM(
            hidden_size=hidden_size,
            num_latents=num_latents,
            **trm_kwargs
        )

        # Without LoRA the base model is used as a frozen feature extractor
        if not use_lora:
            for param in self.base_model.parameters():
                param.requires_grad = False

        print(f"\nModel initialized")
        print(f"  <think> token ID: {self.think_token_id}")
        print(f"  TRM parameters: {sum(p.numel() for p in self.trm.parameters())/1e6:.2f}M")

    def _trm_auxiliary_loss(
        self,
        trm_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Next-token cross-entropy over the label window, or None if it has no targets."""
        # Labels for the positions that survive the TRM's sliding window
        window_labels = labels[:, self.trm.num_latents:]

        # NOTE(review): this slice keeps only the leading positions, which
        # HiddenStateTRM copies through from `hidden_states`. The TRM states
        # appended at the tail have no labels in the batch and are dropped
        # here, so this term does not back-propagate into `self.trm`. Confirm
        # whether a different target alignment was intended.
        window_logits = trm_logits[:, :window_labels.size(1), :]

        # Position t predicts token t + 1, matching the shift the base causal
        # LM loss applies; without it the states would be trained to
        # reproduce their own input token.
        pred_logits = window_logits[:, :-1, :]
        targets = window_labels[:, 1:].to(pred_logits.device)

        # Mean-reduced cross-entropy is NaN when there is nothing to average
        # (empty window or every target ignored), so skip the term instead.
        if targets.numel() == 0 or not bool((targets != -100).any()):
            return None

        loss_fct = nn.CrossEntropyLoss()
        return loss_fct(
            pred_logits.reshape(-1, pred_logits.size(-1)),
            targets.reshape(-1)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_trm: bool = True
    ):
        """Run the base model and, in training, add the weighted TRM loss.

        The TRM branch is taken only when ``use_trm`` is true, the module is
        in training mode, ``labels`` is given and ``input_ids`` contains the
        ``<think>`` token. In every other case the base model output is
        returned unchanged.

        Returns:
            The base model output object; on the TRM branch its ``loss`` has
            ``trm_loss_weight`` times the auxiliary loss added whenever the
            auxiliary loss has at least one supervised target.
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )

        # The TRM only contributes a loss term, so it is skipped whenever
        # that term could not be used.
        if not use_trm or not self.training or labels is None:
            return outputs

        if not bool((input_ids == self.think_token_id).any()):
            return outputs

        hidden_states = outputs.hidden_states[-1]
        shifted_states = self.trm(hidden_states, attention_mask=attention_mask)

        trm_logits = self.base_model.lm_head(shifted_states)

        trm_loss = self._trm_auxiliary_loss(trm_logits, labels)
        if trm_loss is not None:
            outputs.loss = outputs.loss + self.trm_loss_weight * trm_loss

        return outputs

print("SmolLMv3WithTRM defined")

## 6. Dataset and Training Utilities

In [None]:
class ReasoningDataset(Dataset):
    """Question/answer dataset rendered as a prompt with an optional <think> token.

    Args:
        data: Records with ``'question'`` and ``'answer'`` keys.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape ``[1, max_length]``.
        max_length: Length every example is padded or truncated to.
        add_think_token: Insert ``<think>`` between ``Answer:`` and the answer.
    """

    def __init__(
        self,
        data: List[Dict],
        tokenizer,
        max_length: int = 512,
        add_think_token: bool = True
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.add_think_token = add_think_token

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize record ``idx`` and return input ids, attention mask and labels.

        Labels are a copy of the input ids with every padding position
        (``attention_mask == 0``) replaced by -100.
        """
        item = self.data[idx]

        question = item['question']
        answer = item['answer']

        if self.add_think_token:
            text = f"Question: {question}\nAnswer: <think> {answer}"
        else:
            text = f"Question: {question}\nAnswer: {answer}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Mask by attention mask rather than by pad token id: comparing with
        # the pad id would also blank out real tokens that share that id
        # (for example when a tokenizer reuses EOS as its pad token).
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }


def create_sample_dataset() -> List[Dict]:
    """Return five hard-coded arithmetic question/answer records for smoke tests."""
    qa_pairs = (
        ("What is 15 x 23?", "345"),
        ("What is 48 + 76?", "124"),
        ("What is 100 - 37?", "63"),
        ("What is 12 x 12?", "144"),
        ("What is 256 / 8?", "32"),
    )
    return [{"question": prompt, "answer": result} for prompt, result in qa_pairs]

print("Dataset utilities defined")

## 7. PyTorch Lightning Module

In [None]:
class SmolLMTRMLightningModule(pl.LightningModule):
    """Lightning wrapper that trains ``SmolLMv3WithTRM``.

    Optimisation uses AdamW over the trainable parameters with a linear
    warmup/decay schedule stepped once per optimizer step.

    Args:
        model_name: Hugging Face model id passed to ``SmolLMv3WithTRM``.
        use_lora: Whether the wrapped model uses LoRA adapters.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        num_latents: Number of TRM latents.
        learning_rate: Peak AdamW learning rate.
        weight_decay: AdamW weight decay.
        warmup_steps: Number of linear warmup steps.
        trm_kwargs: Extra keyword arguments for the TRM.
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        use_lora: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        num_latents: int = 256,
        learning_rate: float = 2e-4,
        weight_decay: float = 0.01,
        warmup_steps: int = 100,
        trm_kwargs: Optional[Dict] = None
    ):
        super().__init__()
        # Called before anything else touches the arguments
        self.save_hyperparameters()

        # Optimiser settings read back in configure_optimizers
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps

        self.model = SmolLMv3WithTRM(
            model_name=model_name,
            use_lora=use_lora,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            num_latents=num_latents,
            trm_kwargs=trm_kwargs or {},
        )

    def forward(self, input_ids, attention_mask, labels):
        """Delegate to the wrapped model with the TRM branch requested."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            use_trm=True,
        )

    def _batch_loss(self, batch):
        """Run one batch through the model and return its loss."""
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss; log it per step and per epoch."""
        loss = self._batch_loss(batch)

        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/perplexity", torch.exp(loss), on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; log it per epoch only."""
        loss = self._batch_loss(batch)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/perplexity", torch.exp(loss), on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Build AdamW over trainable parameters plus a per-step linear schedule."""
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        lr_scheduler = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

print("Lightning Module defined")

## 8. Training Configuration

Configure and run training (update paths as needed)

In [None]:
# Training configuration: run-level settings consumed by the training cell
config = dict(
    model_name="HuggingFaceTB/SmolLM3-3B",
    batch_size=2,
    num_epochs=3,
    learning_rate=2e-4,
    num_latents=256,
    accumulate_grad_batches=4,
    precision="bf16-mixed",
    output_dir="./checkpoints",
)

# Echo every setting, one indented "key: value" line each
print("Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in config.items()))

In [None]:
# Build the toy dataset (swap in a real dataset for actual training)
sample_data = create_sample_dataset()

# Show how many records there are and what one looks like
_n_examples = len(sample_data)
_first_example = sample_data[0]
print(f"Sample data: {_n_examples} examples")
print(f"Example: {_first_example}")

In [None]:
# Uncomment to run training
#
# NOTE: `wandb.login()` is called below, but no `logger=` is passed to
# `pl.Trainer(...)`; add `logger=WandbLogger(...)` if W&B tracking is wanted.
# NOTE: `config["output_dir"]` is not used by this snippet and no
# `ModelCheckpoint` callback is attached to the trainer.
# NOTE: only a training dataloader is passed to `trainer.fit`, so no
# validation loop runs.
#
# import wandb
# wandb.login()
#
# pl_module = SmolLMTRMLightningModule(
#     model_name=config["model_name"],
#     num_latents=config["num_latents"],
#     learning_rate=config["learning_rate"],
#     trm_kwargs={
#         "n_layers": 2,
#         "n_latent_steps": 6,
#         "n_deep_recursions": 3,
#         "n_supervision_steps": 8,
#     }
# )
#
# train_dataset = ReasoningDataset(sample_data, pl_module.model.tokenizer)
# train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
#
# trainer = pl.Trainer(
#     max_epochs=config["num_epochs"],
#     accelerator="auto",
#     precision=config["precision"],
#     accumulate_grad_batches=config["accumulate_grad_batches"],
#     gradient_clip_val=1.0,
# )
#
# trainer.fit(pl_module, train_dataloaders=train_loader)

print("Training code ready - uncomment to run")

## 9. Notes

### Training Phases (see src/train/)

1. **Phase 1**: Compressor pretraining (identity + CoT)
2. **Phase 2**: TRM iteration training (hidden_pre -> hidden_post)
3. **Phase 3**: GRPO training (freeze LLM, train TRM + compressor)

### Key Hyperparameters (from TRM paper)

- `n_layers=2` (more layers → overfitting)
- `n_latent_steps=6` (n in paper)
- `n_deep_recursions=3` (T in paper)
- `n_supervision_steps=16` (N_sup in paper; the code in this notebook defaults to 8)
- `ema_decay=0.999` (critical for stability; EMA is not implemented in this notebook)

### References

- TRM Paper: papers/less-is-more-TRM/paper.tex
- SmolLM3: https://huggingface.co/blog/smollm3