In [18]:
import exploretinyrm as m
print(m.__version__)

%load_ext autoreload
%autoreload 2

import torch
from exploretinyrm.utils import compute_tensor_summary  


0.1.0
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:

import os, math, random
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader


import sys
sys.path.append("src")

from exploretinyrm.trm import TRM, TRMConfig

def set_seed(seed: int = 123):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [20]:
# dataset (puzzle, solution)

SIDE = 4
BASE = 2
INPUT_PAD = 0               # 0 marks blank in the INPUT ONLY
INPUT_TOKENS = SIDE + 1     # {0..4} for inputs
OUTPUT_TOKENS = SIDE        # {0..3} for outputs (represents digits 1..4)

BASE_SOLUTION = np.array([
    [1, 2, 3, 4],
    [3, 4, 1, 2],
    [2, 1, 4, 3],
    [4, 3, 2, 1],
], dtype=np.int64)

def permute_solution(board: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """random legal permutations: rows/cols within bands/stacks, swap bands/stacks, digit perm"""
    b = BASE; s = SIDE
    # rows within bands
    row_idx = []
    bands = [list(range(g*b, (g+1)*b)) for g in range(b)]
    for band in bands:
        rng.shuffle(band); row_idx.extend(band)
    board = board[row_idx, :]
    # cols within stacks
    col_idx = []
    stacks = [list(range(g*b, (g+1)*b)) for g in range(b)]
    for stack in stacks:
        rng.shuffle(stack); col_idx.extend(stack)
    board = board[:, col_idx]
    # swap bands
    band_order = list(range(b)); rng.shuffle(band_order)
    row_idx = []
    for g in band_order: row_idx.extend(list(range(g*b, (g+1)*b)))
    board = board[row_idx, :]
    # swap stacks
    stack_order = list(range(b)); rng.shuffle(stack_order)
    col_idx = []
    for g in stack_order: col_idx.extend(list(range(g*b, (g+1)*b)))
    board = board[:, col_idx]
    # permute digits 1..SIDE
    digits = np.arange(1, s+1); rng.shuffle(digits)
    mapping = {i+1: digits[i] for i in range(s)}
    return np.vectorize(lambda v: mapping[v])(board)

def make_puzzle(solution: np.ndarray, p_blank: float, rng: np.random.Generator) -> np.ndarray:
    """mask entries with probability p_blank to form the puzzle"""
    mask = rng.random(solution.shape) < p_blank
    puzzle = solution.copy()
    puzzle[mask] = INPUT_PAD
    return puzzle


class Sudoku4x4(Dataset):
    def __init__(self, n_samples: int, p_blank: float = 0.5, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.samples = []
        for _ in range(n_samples):
            sol = permute_solution(BASE_SOLUTION, self.rng)     # digits in 1..4
            puz = make_puzzle(sol, p_blank=p_blank, rng=self.rng)
            x_tokens = puz.reshape(-1).astype(np.int64)         # [16], values in {0..4}
            y_digits = sol.reshape(-1).astype(np.int64)         # [16], values in {1..4}
            y_tokens = (y_digits - 1)                           # map to {0..3} for CE
            self.samples.append((x_tokens, y_tokens))

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.from_numpy(x), torch.from_numpy(y)

def get_loaders(n_train=512, n_val=128, batch_size=16, p_blank=0.45, seed=123):
    # slightly easier p_blank for the first sanity run; you can return to 0.50 after it learns
    ds_tr = Sudoku4x4(n_train, p_blank=p_blank, seed=seed)
    ds_va = Sudoku4x4(n_val,   p_blank=p_blank, seed=seed+1)
    return (
        DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True),
        DataLoader(ds_va, batch_size=batch_size, shuffle=False, pin_memory=True)
    )

train_loader, val_loader = get_loaders(
    n_train=2048, #4096
    n_val=512,
    batch_size=16,
    p_blank=0.50,
    seed=123
)


In [21]:

# ensure labels are in 0 to 3 and inputs are in 0to4
bx, by = next(iter(train_loader))
assert bx.min().item() >= 0 and bx.max().item() <= SIDE
assert by.min().item() >= 0 and by.max().item() < SIDE
print("Sanity OK: x in [0,SIDE], y in [0,SIDE-1]")


Sanity OK: x in [0,SIDE], y in [0,SIDE-1]


In [22]:
# keep SIDE, SEQ_LEN, USE_ATT as before
D_MODEL = 14-28
SEQ_LEN = SIDE * SIDE
N_SUP   = 16          # paper uses "up to 16" for Sudoku
N       = 6  # of z-updates inside a recursion process
T       = 3  # of full recursion processes per supervision step
USE_ATT = False # false for SUDUKU ONLY !!!!!!!


cfg = TRMConfig(
    input_vocab_size=INPUT_TOKENS,  # 5
    output_vocab_size=OUTPUT_TOKENS,  # 4
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=2,   
    use_attention=USE_ATT, # token MLP for SUDUKU
    n_heads=8,        # ignored when use_attention=False; harmless but see note below
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    n=N,
    T=T,
    k_last_ops=None,
    stabilize_input_sums = True #False to be like the paper but here produces NaN !!!
)

model = TRM(cfg).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)



RuntimeError: Trying to create tensor with negative dimension -14: [5, -14]

In [23]:


from torch.amp import autocast, GradScaler
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_


def exact_match_from_logits(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # logits: [B, L, V], y_true: [B, L]
    preds = logits.argmax(dim=-1)                      # [B, L]
    return (preds == y_true).all(dim=1).float()        # [B]

def token_ce_loss(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # flatten to (B*L, V) vs (B*L)
    B, L, V = logits.shape
    return F.cross_entropy(logits.reshape(B*L, V), y_true.reshape(B*L))


def train_one_epoch(model: TRM, loader: DataLoader, optimizer, scaler, epoch: int, use_amp: bool = True):
    model.train()
    total_ce, total_halt, total_em, total_steps = 0.0, 0.0, 0.0, 0

    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device, non_blocking=True)
        y_true   = y_true.to(device,   non_blocking=True)   # tokens in {0..3}

        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)

        for _ in range(N_SUP):
            optimizer.zero_grad(set_to_none=True)
            #with autocast("cuda", enabled=(device.type == "cuda" and use_amp)):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None  # truncation helps early training (paper said it is failed idea)
            )
            # stable losses in fp32
            loss_ce   = token_ce_loss(logits.float(), y_true)
            with torch.no_grad():
                em = exact_match_from_logits(logits, y_true)

            loss_halt = F.binary_cross_entropy_with_logits(halt_logit.float(), em)
            loss = loss_ce + loss_halt

            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer) # unscale before clipping
                clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            total_ce   += loss_ce.detach().item()
            total_halt += loss_halt.detach().item()
            total_em   += em.mean().item()
            total_steps += 1

            # if (halt_logit > 0).all():
            #    break

    print(f"Epoch {epoch:02d} | CE {total_ce/max(1,total_steps):.4f} | HaltBCE {total_halt/max(1,total_steps):.4f} | Exact-match {total_em/max(1,total_steps):.3f}")




@torch.no_grad()
def evaluate(model: TRM, loader: DataLoader, n_sup_eval: int = N_SUP):
    model.eval()
    em_list, cell_acc_list = [], []

    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)

        for _ in range(n_sup_eval):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None  # make eval consistent
            )

        preds = logits.argmax(dim=-1)
        em = (preds == y_true).all(dim=1).float()
        cell_acc = (preds == y_true).float().mean(dim=1)
        em_list.append(em); cell_acc_list.append(cell_acc)

    em = torch.cat(em_list).mean().item()
    cell_acc = torch.cat(cell_acc_list).mean().item()
    print(f"Validation | Exact-match {em:.3f} | Cell accuracy {cell_acc:.3f}")
    return em, cell_acc




optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.0, betas=(0.9, 0.95))
scaler = GradScaler("cuda", enabled=(device.type == "cuda"))



In [24]:

# one batch
x_tokens, y_true = next(iter(train_loader))
x_tokens = x_tokens.to(device)
y_true   = y_true.to(device)

# forward-only check (no training, no AMP)
model.eval()
with torch.no_grad():
    y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
    y1, z1, logits, halt_logit = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)
print("Forward-only finiteness:",
      "y1", torch.isfinite(y1).all().item(),
      "z1", torch.isfinite(z1).all().item(),
      "logits", torch.isfinite(logits).all().item(),
      "halt_logit", torch.isfinite(halt_logit).all().item())

# single training step in full FP32 (no AMP, tiny LR, no weight decay)
model.train()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
opt.zero_grad(set_to_none=True)

y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
y1, z1, logits, halt_logit = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)

# compute losses in fp32
loss_ce = F.cross_entropy(logits.float().reshape(-1, OUTPUT_TOKENS), y_true.reshape(-1))
em = (logits.argmax(dim=-1) == y_true).all(dim=1).float()
loss_halt = F.binary_cross_entropy_with_logits(halt_logit.float(), em)
loss = loss_ce + loss_halt

print("Pre-backward finiteness:",
      "loss", torch.isfinite(loss).item(),
      "loss_ce", torch.isfinite(loss_ce).item(),
      "loss_halt", torch.isfinite(loss_halt).item())

loss.backward()

# check grads
all_grads_finite = True
for n, p in model.named_parameters():
    if p.grad is None:
        continue
    if not torch.isfinite(p.grad).all():
        print("Non-finite grad in:", n)
        all_grads_finite = False
        break
print("Gradients finite:", all_grads_finite)

# clip and step
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()

# forward again after one small step
model.eval()
with torch.no_grad():
    y0, z0 = model.init_state(batch_size=x_tokens.size(0), device=device)
    y2, z2, logits2, halt2 = model.forward_step(x_tokens, y=y0, z=z0, n=N, T=T, k_last_ops=None)
print("Post-step forward finiteness:",
      "y2", torch.isfinite(y2).all().item(),
      "z2", torch.isfinite(z2).all().item(),
      "logits2", torch.isfinite(logits2).all().item(),
      "halt2", torch.isfinite(halt2).all().item())



Forward-only finiteness: y1 True z1 True logits True halt_logit True
Pre-backward finiteness: loss True loss_ce True loss_halt True
Gradients finite: True
Post-step forward finiteness: y2 True z2 True logits2 True halt2 True


In [25]:
# train+eval
EPOCHS = 5
for epoch in range(1, EPOCHS+1):
    train_one_epoch(model, train_loader, optimizer, scaler, epoch, use_amp=False)
    evaluate(model, val_loader, n_sup_eval=N_SUP)



Epoch 01 | CE 0.2018 | HaltBCE 0.4367 | Exact-match 0.621
Validation | Exact-match 0.609 | Cell accuracy 0.928
Epoch 02 | CE 0.1962 | HaltBCE 0.4472 | Exact-match 0.610
Validation | Exact-match 0.637 | Cell accuracy 0.929
Epoch 03 | CE 0.1807 | HaltBCE 0.4153 | Exact-match 0.641
Validation | Exact-match 0.621 | Cell accuracy 0.927
Epoch 04 | CE 0.2008 | HaltBCE 0.4212 | Exact-match 0.619
Validation | Exact-match 0.650 | Cell accuracy 0.933
Epoch 05 | CE 0.2081 | HaltBCE 0.4337 | Exact-match 0.624
Validation | Exact-match 0.553 | Cell accuracy 0.917


In [26]:
@torch.no_grad()
def solve_and_show(model: TRM, loader: DataLoader, n_batches: int = 1):
    model.eval()
    shown = 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        preds_tok = logits.argmax(dim=-1).cpu().numpy()
        xs = x_tokens.cpu().numpy()
        ys_tok = y_true.cpu().numpy()
        for i in range(min(4, xs.shape[0])):
            print(f"\nPuzzle {shown+i}:")
            print(xs[i].reshape(4,4))
            print("Pred:")
            print((preds_tok[i] + 1).reshape(4,4))   # tokens -> digits
            print("True:")
            print((ys_tok[i] + 1).reshape(4,4))
        shown += 1
        if shown >= n_batches:
            break

solve_and_show(model, val_loader, n_batches=1)



Puzzle 0:
[[3 0 1 0]
 [0 0 0 2]
 [0 1 2 3]
 [2 0 0 1]]
Pred:
[[3 2 1 4]
 [1 4 3 2]
 [4 1 2 3]
 [2 3 4 1]]
True:
[[3 2 1 4]
 [1 4 3 2]
 [4 1 2 3]
 [2 3 4 1]]

Puzzle 1:
[[0 0 0 4]
 [4 0 3 1]
 [0 4 1 3]
 [0 1 4 2]]
Pred:
[[1 3 2 4]
 [4 2 3 1]
 [2 4 1 3]
 [3 1 4 2]]
True:
[[1 3 2 4]
 [4 2 3 1]
 [2 4 1 3]
 [3 1 4 2]]

Puzzle 2:
[[0 4 0 1]
 [0 0 3 4]
 [0 2 4 3]
 [4 3 1 0]]
Pred:
[[3 4 2 1]
 [2 1 3 4]
 [1 2 4 3]
 [4 3 1 2]]
True:
[[3 4 2 1]
 [2 1 3 4]
 [1 2 4 3]
 [4 3 1 2]]

Puzzle 3:
[[0 0 4 2]
 [0 4 3 1]
 [4 0 1 3]
 [3 0 0 4]]
Pred:
[[1 3 4 2]
 [2 4 3 1]
 [4 2 1 3]
 [3 1 2 4]]
True:
[[1 3 4 2]
 [2 4 3 1]
 [4 2 1 3]
 [3 1 2 4]]
