# Modular TRM Training for Sudoku 4x4

This notebook demonstrates a modular approach to training and evaluating a TRM neural network on 4x4 Sudoku puzzles using PyTorch. The code is organized for easy adaptation to other games and datasets.

## 1. Import Libraries and Set Up Environment
Import all required libraries, set random seeds, and configure device (CPU/GPU).

In [1]:
import os, math, random
from dataclasses import dataclass
from typing import Tuple, Any, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sys
import os
sys.path.append(os.path.join("..", "src"))
from exploretinyrm.trm import TRM, TRMConfig
def set_seed(seed: int = 123):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


## 2. AMP and EMA Utilities
Define automatic mixed precision (AMP) and exponential moving average (EMA) utility functions and classes.

In [2]:
# --- AMP setup (PyTorch 2.x forward-compatible) ---
try:
    from torch.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = True
except ImportError:
    from torch.cuda.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = False

def make_grad_scaler(is_cuda: bool):
    if _USE_TORCH_AMP:
        try:
            return _GradScaler("cuda", enabled=is_cuda)
        except TypeError:
            return _GradScaler(enabled=is_cuda)
    else:
        return _GradScaler(enabled=is_cuda)

def amp_autocast(is_cuda: bool, use_amp: bool):
    if _USE_TORCH_AMP:
        try:
            return _autocast(device_type="cuda", enabled=(is_cuda and use_amp))
        except TypeError:
            return _autocast(enabled=(is_cuda and use_amp))
    else:
        return _autocast(enabled=(is_cuda and use_amp))

# --- EMA utility (training-time only) ---
class EMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model: torch.nn.Module) -> None:
        d = self.decay
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                self.shadow[name].mul_(d).add_(param.detach(), alpha=1.0 - d)

    def copy_to(self, model: torch.nn.Module) -> None:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in self.shadow:
                    param.copy_(self.shadow[name])

from contextlib import contextmanager

@contextmanager
def use_ema_weights(model: torch.nn.Module, ema: EMA):
    backup = {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    ema.copy_to(model)
    try:
        yield
    finally:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in backup:
                    param.copy_(backup[name])

## 3. Sudoku 4x4 Dataset Preparation
Implement dataset generation, including solution permutation, puzzle masking, and PyTorch Dataset/DataLoader setup.

In [3]:
class GameDataset(Dataset):
    """Base class for game datasets. Subclass and implement _generate_sample."""
    def __init__(self, n_samples: int, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.samples = [self._generate_sample() for _ in range(n_samples)]
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]
    def _generate_sample(self): raise NotImplementedError()

# --- Sudoku4x4 Dataset ---
SIDE = 4
BASE = 2
INPUT_PAD = 0               # 0 marks blank in the INPUT ONLY
INPUT_TOKENS = SIDE + 1     # {0..4} for inputs
OUTPUT_TOKENS = SIDE        # {0..3} for outputs (represents digits 1..4)

BASE_SOLUTION = np.array([
    [1, 2, 3, 4],
    [3, 4, 1, 2],
    [2, 1, 4, 3],
    [4, 3, 2, 1],
], dtype=np.int64)

def permute_solution(board: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    b = BASE; s = SIDE
    row_idx = []
    bands = [list(range(g*b, (g+1)*b)) for g in range(b)]
    for band in bands:
        rng.shuffle(band); row_idx.extend(band)
    board = board[row_idx, :]
    col_idx = []
    stacks = [list(range(g*b, (g+1)*b)) for g in range(b)]
    for stack in stacks:
        rng.shuffle(stack); col_idx.extend(stack)
    board = board[:, col_idx]
    band_order = list(range(b)); rng.shuffle(band_order)
    row_idx = []
    for g in band_order: row_idx.extend(list(range(g*b, (g+1)*b)))
    board = board[row_idx, :]
    stack_order = list(range(b)); rng.shuffle(stack_order)
    col_idx = []
    for g in stack_order: col_idx.extend(list(range(g*b, (g+1)*b)))
    board = board[:, col_idx]
    digits = np.arange(1, s+1); rng.shuffle(digits)
    mapping = {i+1: digits[i] for i in range(s)}
    return np.vectorize(lambda v: mapping[v])(board)

def make_puzzle(solution: np.ndarray, p_blank: float, rng: np.random.Generator) -> np.ndarray:
    mask = rng.random(solution.shape) < p_blank
    puzzle = solution.copy()
    puzzle[mask] = INPUT_PAD
    return puzzle

class Sudoku4x4(GameDataset):
    def __init__(self, n_samples: int, p_blank: float = 0.5, seed: int = 0):
        self.p_blank = p_blank
        super().__init__(n_samples, seed)
    def _generate_sample(self):
        sol = permute_solution(BASE_SOLUTION, self.rng)
        puz = make_puzzle(sol, p_blank=self.p_blank, rng=self.rng)
        x_tokens = puz.reshape(-1).astype(np.int64)         # [16], values in {0..4}
        y_digits = sol.reshape(-1).astype(np.int64)         # [16], values in {1..4}
        y_tokens = (y_digits - 1)                           # map to {0..3} for CE
        return torch.from_numpy(x_tokens), torch.from_numpy(y_tokens)

def get_loaders(n_train=512, n_val=128, batch_size=16, p_blank=0.45, seed=123):
    ds_tr = Sudoku4x4(n_train, p_blank=p_blank, seed=seed)
    ds_va = Sudoku4x4(n_val,   p_blank=p_blank, seed=seed+1)
    return (
        DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True),
        DataLoader(ds_va, batch_size=batch_size, shuffle=False, pin_memory=True)
    )

train_loader, val_loader = get_loaders(
    n_train=2048,
    n_val=512,
    batch_size=16,
    p_blank=0.50,
    seed=123
)

## 4. Model Configuration and Initialization
Configure TRM model parameters, instantiate the model, optimizer, scaler, and EMA.

In [4]:
D_MODEL = 128
SEQ_LEN = SIDE * SIDE
N_SUP   = 16
N       = 6
T       = 3
USE_ATT = False

cfg = TRMConfig(
    input_vocab_size=INPUT_TOKENS,
    output_vocab_size=OUTPUT_TOKENS,
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=2,
    use_attention=USE_ATT,
    n_heads=8,
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    n=N,
    T=T,
    k_last_ops=None,
    stabilize_input_sums=True
)

model = TRM(cfg).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, weight_decay=0.0, betas=(0.9, 0.95)
)

scaler = make_grad_scaler(device.type == "cuda")
ema = EMA(model, decay=0.999)

NVIDIA GeForce RTX 5060 Ti with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90 sm_37 compute_37.
If you want to use the NVIDIA GeForce RTX 5060 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



Params (M): 0.397312


## 5. Sanity Checks on Data
Run assertions to verify input and label ranges for the Sudoku dataset.

In [5]:
bx, by = next(iter(train_loader))
assert bx.min().item() >= 0 and bx.max().item() <= SIDE
assert by.min().item() >= 0 and by.max().item() < SIDE
print("Sanity OK: x in [0,SIDE], y in [0,SIDE-1]")

Sanity OK: x in [0,SIDE], y in [0,SIDE-1]


## 6. Training and Evaluation Functions
Define training and evaluation functions, including loss calculation, metric reporting, and EMA evaluation.

In [6]:
def exact_match_from_logits(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    preds = logits.argmax(dim=-1)
    return (preds == y_true).all(dim=1).float()

def token_ce_loss(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    B, L, V = logits.shape
    return F.cross_entropy(logits.reshape(B*L, V), y_true.reshape(B*L))

def train_one_epoch(
    model: TRM,
    loader: DataLoader,
    optimizer,
    scaler,
    epoch: int,
    use_amp: bool = True,
    ema: "EMA | None" = None
):
    model.train()
    total_ce, total_halt, total_em, total_steps = 0.0, 0.0, 0.0, 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device, non_blocking=True)
        y_true   = y_true.to(device,   non_blocking=True)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            optimizer.zero_grad(set_to_none=True)
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
            loss_ce = F.cross_entropy(logits.float().reshape(-1, OUTPUT_TOKENS), y_true.reshape(-1))
            with torch.no_grad():
                em = exact_match_from_logits(logits, y_true)
            loss_halt = F.binary_cross_entropy_with_logits(halt_logit.float(), em)
            loss = loss_ce + loss_halt
            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if ema is not None:
                    ema.update(model)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if ema is not None:
                    ema.update(model)
            total_ce   += loss_ce.detach().item()
            total_halt += loss_halt.detach().item()
            total_em   += em.mean().item()
            total_steps += 1
    print(f"Epoch {epoch:02d} | CE {total_ce/max(1,total_steps):.4f} | HaltBCE {total_halt/max(1,total_steps):.4f} | Exact-match {total_em/max(1,total_steps):.3f}")

@torch.no_grad()
def evaluate(model: TRM, loader: DataLoader, n_sup_eval: int = N_SUP):
    model.eval()
    em_list, cell_acc_list = [], []
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(n_sup_eval):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        preds = logits.argmax(dim=-1)
        em = (preds == y_true).all(dim=1).float()
        cell_acc = (preds == y_true).float().mean(dim=1)
        em_list.append(em); cell_acc_list.append(cell_acc)
    em = torch.cat(em_list).mean().item()
    cell_acc = torch.cat(cell_acc_list).mean().item()
    print(f"Validation | Exact-match {em:.3f} | Cell accuracy {cell_acc:.3f}")
    return em, cell_acc

@torch.no_grad()
def evaluate_with_ema(model: TRM, ema: EMA, loader: DataLoader, n_sup_eval: int = N_SUP):
    with use_ema_weights(model, ema):
        return evaluate(model, loader, n_sup_eval=n_sup_eval)

## 7. Single Batch Forward and Training Step Check
Perform a forward-only check and a single training step to verify model and gradient finiteness.

In [7]:
x_tokens, y_true = next(iter(train_loader))
x_tokens = x_tokens.to(device)
y_true   = y_true.to(device)

# forward-only check (no training, no AMP)
model.eval()
with torch.no_grad():
    y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
    y1, z1, logits, halt_logit = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)
print("Forward-only finiteness:",
      "y1", torch.isfinite(y1).all().item(),
      "z1", torch.isfinite(z1).all().item(),
      "logits", torch.isfinite(logits).all().item(),
      "halt_logit", torch.isfinite(halt_logit).all().item())

# single training step in full FP32 (no AMP, tiny LR, no weight decay)
model.train()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
opt.zero_grad(set_to_none=True)
y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
y1, z1, logits, halt_logit = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)
loss_ce = F.cross_entropy(logits.float().reshape(-1, OUTPUT_TOKENS), y_true.reshape(-1))
em = (logits.argmax(dim=-1) == y_true).all(dim=1).float()
loss_halt = F.binary_cross_entropy_with_logits(halt_logit.float(), em)
loss = loss_ce + loss_halt
print("Pre-backward finiteness:",
      "loss", torch.isfinite(loss).item(),
      "loss_ce", torch.isfinite(loss_ce).item(),
      "loss_halt", torch.isfinite(loss_halt).item())
loss.backward()
all_grads_finite = True
for n, p in model.named_parameters():
    if p.grad is None:
        continue
    if not torch.isfinite(p.grad).all():
        print("Non-finite grad in:", n)
        all_grads_finite = False
        break
print("Gradients finite:", all_grads_finite)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
model.eval()
with torch.no_grad():
    y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
    y2, z2, logits2, halt2 = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)
print("Post-step forward finiteness:",
      "y2", torch.isfinite(y2).all().item(),
      "z2", torch.isfinite(z2).all().item(),
      "logits2", torch.isfinite(logits2).all().item(),
      "halt2", torch.isfinite(halt2).all().item())

Forward-only finiteness: y1 True z1 True logits True halt_logit True
Pre-backward finiteness: loss True loss_ce True loss_halt True
Gradients finite: True
Post-step forward finiteness: y2 True z2 True logits2 True halt2 True


## 8. Training Loop
Run the main training loop for several epochs, reporting metrics for both raw and EMA weights.

In [8]:
EPOCHS = 2
for epoch in range(1, EPOCHS+1):
    train_one_epoch(model, train_loader, optimizer, scaler, epoch, use_amp=False, ema=ema)
    em_raw, cell_raw = evaluate(model, val_loader, n_sup_eval=N_SUP)
    em_ema, cell_ema = evaluate_with_ema(model, ema, val_loader)
    print(f"Validation (raw) | EM {em_raw:.3f} | Cell {cell_raw:.3f}")
    print(f"Validation (EMA) | EM {em_ema:.3f} | Cell {cell_ema:.3f}")

Epoch 01 | CE 0.4201 | HaltBCE 0.3554 | Exact-match 0.218
Validation | Exact-match 0.373 | Cell accuracy 0.902
Validation | Exact-match 0.373 | Cell accuracy 0.902
Validation | Exact-match 0.209 | Cell accuracy 0.874
Validation (raw) | EM 0.373 | Cell 0.902
Validation (EMA) | EM 0.209 | Cell 0.874
Validation | Exact-match 0.209 | Cell accuracy 0.874
Validation (raw) | EM 0.373 | Cell 0.902
Validation (EMA) | EM 0.209 | Cell 0.874
Epoch 02 | CE 0.2328 | HaltBCE 0.4928 | Exact-match 0.477
Epoch 02 | CE 0.2328 | HaltBCE 0.4928 | Exact-match 0.477
Validation | Exact-match 0.473 | Cell accuracy 0.918
Validation | Exact-match 0.473 | Cell accuracy 0.918
Validation | Exact-match 0.543 | Cell accuracy 0.930
Validation (raw) | EM 0.473 | Cell 0.918
Validation (EMA) | EM 0.543 | Cell 0.930
Validation | Exact-match 0.543 | Cell accuracy 0.930
Validation (raw) | EM 0.473 | Cell 0.918
Validation (EMA) | EM 0.543 | Cell 0.930


## 9. Model Inference and Visualization
Implement a function to solve and display Sudoku puzzles using the trained model.

In [9]:
@torch.no_grad()
def solve_and_show(model: TRM, loader: DataLoader, n_batches: int = 1):
    model.eval()
    shown = 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        preds_tok = logits.argmax(dim=-1).cpu().numpy()
        xs = x_tokens.cpu().numpy()
        ys_tok = y_true.cpu().numpy()
        for i in range(min(4, xs.shape[0])):
            print(f"\nPuzzle {shown+i}:")
            print(xs[i].reshape(4,4))
            print("Pred:")
            print((preds_tok[i] + 1).reshape(4,4))   # tokens -> digits
            print("True:")
            print((ys_tok[i] + 1).reshape(4,4))
        shown += 1
        if shown >= n_batches:
            break

solve_and_show(model, val_loader, n_batches=1)


Puzzle 0:
[[3 0 1 0]
 [0 0 0 2]
 [0 1 2 3]
 [2 0 0 1]]
Pred:
[[3 2 1 4]
 [1 2 3 2]
 [4 1 2 3]
 [2 3 4 1]]
True:
[[3 2 1 4]
 [1 4 3 2]
 [4 1 2 3]
 [2 3 4 1]]

Puzzle 1:
[[0 0 0 4]
 [4 0 3 1]
 [0 4 1 3]
 [0 1 4 2]]
Pred:
[[1 3 2 4]
 [4 2 3 1]
 [2 4 1 3]
 [3 1 4 2]]
True:
[[1 3 2 4]
 [4 2 3 1]
 [2 4 1 3]
 [3 1 4 2]]

Puzzle 2:
[[0 4 0 1]
 [0 0 3 4]
 [0 2 4 3]
 [4 3 1 0]]
Pred:
[[3 4 2 1]
 [2 1 3 4]
 [1 2 4 3]
 [4 3 1 2]]
True:
[[3 4 2 1]
 [2 1 3 4]
 [1 2 4 3]
 [4 3 1 2]]

Puzzle 3:
[[0 0 4 2]
 [0 4 3 1]
 [4 0 1 3]
 [3 0 0 4]]
Pred:
[[1 3 4 2]
 [2 4 3 1]
 [4 2 1 3]
 [3 2 2 4]]
True:
[[1 3 4 2]
 [2 4 3 1]
 [4 2 1 3]
 [3 1 2 4]]


## 10. Forward Finiteness Probe
Probe the model for non-finite values during forward passes to ensure numerical stability.

In [10]:
@torch.no_grad()
def forward_finiteness_probe(model: TRM, x_tokens: torch.Tensor):
    model.eval()
    y, z = model.init_state(batch_size=x_tokens.size(0), device=x_tokens.device)
    x_h = model.embed_input(x_tokens)
    def check(tag, t):
        if not torch.isfinite(t).all():
            raise RuntimeError(f"Non-finite values detected at {tag}")
    for t in range(T):
        for i in range(N):
            h_z = (x_h + y + z) if not model.cfg.stabilize_input_sums else (x_h + y + z) / math.sqrt(3.0)
            check(f"T{t}-hz{i}", h_z)
            z = model._net(h_z)
            check(f"T{t}-z{i}", z)
        h_y = (y + z) if not model.cfg.stabilize_input_sums else (y + z) / math.sqrt(2.0)
        check(f"T{t}-hy", h_y)
        y = model._net(h_y)
        check(f"T{t}-y", y)
    logits = model.output_head(y)
    halt_logit = model.halt_head(y)
    check("logits", logits)
    check("halt_logit", halt_logit)
    print("Forward finiteness probe passed.")

x_tokens, _ = next(iter(train_loader))
forward_finiteness_probe(model, x_tokens.to(device))

Forward finiteness probe passed.


## 11. Embedding and Input Checks
Check input token ranges, embedding matrix finiteness, and embedding lookup results.

In [11]:
x_tokens, _ = next(iter(train_loader))
x_tokens = x_tokens.to(device)
print("x_tokens range:", int(x_tokens.min()), int(x_tokens.max()), x_tokens.dtype)
w = model.input_emb.weight.data
print("embedding finite?", torch.isfinite(w).all().item(), "max|w|:", float(w.abs().max()))
x_h = model.embed_input(x_tokens)
print("x_h finite?", torch.isfinite(x_h).all().item())

x_tokens range: 0 4 torch.int64
embedding finite? True max|w|: 3.2094664573669434
x_h finite? True
