
# HRM Transformer Sidecar (Colab-ready)

This notebook implements a **Transformer Sidecar** that adds **latent recurrent refinement** with **ACT-style halting** to a base encoder, inspired by the *Hierarchical Reasoning Model (HRM)* line of work. It focuses on the *L-module* refinement loop and the key idea of **using the halting policy during evaluation**, which the critique shows is crucial for performance.

**Highlights**  
- One-step gradient (no BPTT): the latent state is detached at each refinement step, so gradients do not flow across steps.  
- ACT-like halting: train a Q-head and **use it at inference** (stop once `sigmoid(Q_halt) > 0.5` after at least `Mmin` steps, or when `Mmax` is reached).  
- Pluggable sidecar: it wraps a toy Transformer encoder here; you can adapt it to your own model's hidden states.
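
The halting rule from the second bullet can be sketched in plain Python. This is only an illustration, not the notebook's implementation: the helper name `should_halt` and its argument names are hypothetical, and the `m_min` guard mirrors the `Mmin` setting used later in `SidecarHRM`.

```python
import math

def should_halt(q_halt_logit, step, m_min=2, m_max=6, threshold=0.5):
    """Hypothetical helper: decide whether one sequence stops refining at `step` (1-based)."""
    if step >= m_max:   # step budget exhausted: always stop
        return True
    if step < m_min:    # too early: halting is not allowed yet
        return False
    p_halt = 1.0 / (1.0 + math.exp(-q_halt_logit))  # sigmoid of the Q-head logit
    return p_halt > threshold

# A confident logit cannot stop the loop before m_min, but does afterwards.
print(should_halt(3.0, step=1))   # False
print(should_halt(3.0, step=2))   # True
print(should_halt(-3.0, step=6))  # True (budget reached)
```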

> Demo task: a tiny **sequence editing** problem (sorting digits) that benefits from iterative refinement.


In [None]:

#@title Setup
import math, random, time
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


Device: cpu


In [None]:

#@title Tiny tokenizer for toy tasks
# Tokens 0..9 map to the digits '0'..'9'. No special tokens are needed because every sequence has the same fixed length.
VOCAB = list("0123456789")
stoi = {ch:i for i,ch in enumerate(VOCAB)}
itos = {i:ch for ch,i in stoi.items()}
V = len(VOCAB)

def encode(seq):
    return torch.tensor([stoi[c] for c in seq], dtype=torch.long)

def decode(ids):
    return "".join(itos[int(i)] for i in ids)


In [None]:

#@title Toy dataset: "sort the digits" sequences
# Input: a random ordering of `length` distinct digits drawn from 0..9 (without replacement)
# Target: the digits sorted ascending as a string

class SortDigitsDataset(torch.utils.data.Dataset):
    def __init__(self, n_samples=20000, length=8, seed=0):
        rng = random.Random(seed)  # local RNG: reproducible without reseeding the global `random` state
        self.samples = []
        for _ in range(n_samples):
            digits = rng.sample(list("0123456789"), length)
            x = "".join(digits)
            y = "".join(sorted(digits))
            self.samples.append((x,y))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        x,y = self.samples[idx]
        return encode(x), encode(y)

def collate_batch(batch):
    xs = [x for x,_ in batch]
    ys = [y for _,y in batch]
    X = torch.stack(xs, dim=0)
    Y = torch.stack(ys, dim=0)
    return X, Y

# Small train/val/test splits
train_ds = SortDigitsDataset(n_samples=8000, length=8, seed=1)
val_ds   = SortDigitsDataset(n_samples=1000, length=8, seed=2)
test_ds  = SortDigitsDataset(n_samples=1000, length=8, seed=3)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_batch)
val_dl   = torch.utils.data.DataLoader(val_ds, batch_size=128, shuffle=False, collate_fn=collate_batch)
test_dl  = torch.utils.data.DataLoader(test_ds, batch_size=128, shuffle=False, collate_fn=collate_batch)

len(train_ds), len(val_ds), len(test_ds)


(8000, 1000, 1000)

In [None]:

#@title Base tiny Transformer encoder (token -> hidden)
class TinyTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, num_layers=2, dim_feedforward=512, dropout=0.1, vocab_size=V):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pos = nn.Parameter(torch.randn(1, 32, d_model) * 0.01)  # supports length<=32
        self.d_model = d_model

    def forward(self, x):
        # x: (B, T) tokens
        B, T = x.shape
        h = self.emb(x) + self.pos[:, :T, :]
        h = self.encoder(h)  # (B, T, d)
        return h

# Prediction head
class TokenHead(nn.Module):
    def __init__(self, d_model=128, vocab_size=V):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, vocab_size)
    def forward(self, h):
        return self.proj(self.ln(h))  # (B,T,V)



## Sidecar HRM (L-module only)

- **Refinement block (L-module):** a small Transformer layer that updates the latent state.
- **Q-head:** predicts a *halt* logit at each step. We train it with a simple ACT-style loss: binary cross-entropy against whether the current prediction is fully correct.
- **One-step gradient:** we detach the state at each step so gradients don't flow across steps (no BPTT).
- **Evaluation:** we **use halting** (`sigmoid(Q_halt) > 0.5`, allowed once `Mmin` steps have run) and stop early per sequence.
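
To make the one-step gradient concrete, here is a minimal sketch on a scalar toy problem. The update rule `h + w * h` and the function names are assumptions chosen for illustration and are not part of the sidecar. The sketch compares the gradient with and without detaching the state between two refinement steps.

```python
import torch

def refine(h, w):
    # Toy stand-in for the L-module: a residual update with one learnable scalar
    return h + w * h

def grad_wrt_w(detach_between_steps):
    w = torch.tensor(0.5, requires_grad=True)
    h = torch.tensor(2.0)
    for _ in range(2):
        if detach_between_steps:
            h = h.detach()  # cut the graph: earlier steps become constants
        h = refine(h, w)
    h.backward()
    return w.grad.item()

print(grad_wrt_w(detach_between_steps=True))   # 3.0: only the last step contributes
print(grad_wrt_w(detach_between_steps=False))  # 6.0: gradient flows through both steps
```

In the notebook's `training_step`, a loss is added at every step, so each step still receives its own one-step gradient even though nothing is propagated across steps.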


In [None]:

#@title Sidecar HRM module
class LModule(nn.Module):
    def __init__(self, d_model=128, nhead=4, dim_feedforward=512, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                           dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.block = nn.TransformerEncoder(layer, num_layers=1)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, h):
        return self.ln(self.block(h))

class SidecarHRM(nn.Module):
    def __init__(self, base_encoder, token_head, d_model=128, Mmax=6, Mmin=2):
        super().__init__()
        self.base = base_encoder
        self.head = token_head
        self.l_module = LModule(d_model=d_model)
        # Q-head predicts a single logit "halt"
        self.q_head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, 1))
        self.Mmax = Mmax
        self.Mmin = Mmin

    def forward_once(self, x, h=None):
        # Returns the refined state h2 (B,T,d), token logits (B,T,V), and one pooled halt logit per sequence (B,)
        if h is None:
            h = self.base(x)  # (B,T,d)
        h2 = h + self.l_module(h)
        logits = self.head(h2)  # (B,T,V)
        # Score every token position, then mean-pool the per-token halt logits over the sequence
        q_logit = self.q_head(h2).mean(dim=1)  # (B,1)
        return h2, logits, q_logit.squeeze(1)

    @torch.no_grad()
    def infer(self, x, threshold=0.5):
        self.eval()
        B, T = x.shape
        done = torch.zeros(B, dtype=torch.bool, device=x.device)
        h = None
        steps_taken = torch.zeros(B, dtype=torch.long, device=x.device)
        y_best = None

        for m in range(1, self.Mmax+1):
            h, logits, q = self.forward_once(x, h=h)
            y_hat = logits.argmax(dim=-1)
            if y_best is None:
                y_best = y_hat.clone()
            # sequences that were still running take this step's prediction,
            # including those that halt at this step
            active = ~done
            y_best[active] = y_hat[active]
            # decide halting (only allowed once m >= Mmin)
            will_halt = (torch.sigmoid(q) > threshold) & (m >= self.Mmin)
            now_done = active & will_halt
            steps_taken[now_done] = m
            done = done | now_done
            if done.all():
                break

        # for any not halted, set steps taken to Mmax
        steps_taken[~done] = self.Mmax
        return y_best, steps_taken

    def training_step(self, x, y, ce_weight=1.0, act_weight=0.1):
        # "one-step" gradient: don't backprop through time
        h = None
        total_ce = 0.0
        total_act = 0.0

        for m in range(1, self.Mmax+1):
            if h is not None:
                h = h.detach()  # no BPTT
            h, logits, q = self.forward_once(x, h=h)
            y_hat = logits.argmax(dim=-1)
            ce = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            # simple halt targets: halt if fully correct, else continue
            correct = (y_hat == y).all(dim=1).float()  # (B,)
            if m >= self.Mmin:
                target_halt = correct
            else:
                target_halt = torch.zeros_like(correct)

            act = F.binary_cross_entropy_with_logits(q, target_halt)
            total_ce += ce
            total_act += act

        loss = ce_weight * (total_ce / self.Mmax) + act_weight * (total_act / self.Mmax)
        return loss


In [None]:

#@title Train/eval utilities
@torch.no_grad()
def exact_match_acc(model, dl):
    model.eval()
    n_ok = 0
    n_all = 0
    total_steps = 0
    for X, Y in dl:
        X, Y = X.to(DEVICE), Y.to(DEVICE)
        Y_hat, steps = model.infer(X)
        n_ok += (Y_hat == Y).all(dim=1).sum().item()
        n_all += X.size(0)
        total_steps += steps.sum().item()
    return n_ok / n_all, total_steps / n_all

def train(model, train_dl, val_dl, epochs=8, lr=3e-4):
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    for ep in range(1, epochs+1):
        model.train()
        running = 0.0
        for X, Y in train_dl:
            X, Y = X.to(DEVICE), Y.to(DEVICE)
            loss = model.training_step(X, Y, ce_weight=1.0, act_weight=0.1)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running += loss.item()
        val_acc, val_steps = exact_match_acc(model, val_dl)
        print(f"[ep {ep:02d}] train_loss={running/len(train_dl):.4f}  val_acc={val_acc:.3f}  avg_steps={val_steps:.2f}")
    return model


In [None]:

#@title Initialize and train
base = TinyTransformer(d_model=128, nhead=4, num_layers=2, dim_feedforward=256).to(DEVICE)
head = TokenHead(d_model=128).to(DEVICE)
model = SidecarHRM(base, head, d_model=128, Mmax=6, Mmin=2).to(DEVICE)

model = train(model, train_dl, val_dl, epochs=8, lr=3e-4)
test_acc, test_steps = exact_match_acc(model, test_dl)
print(f"TEST  acc={test_acc:.3f}  avg_steps={test_steps:.2f}")


[ep 01] train_loss=2.0921  val_acc=0.029  avg_steps=6.00
[ep 02] train_loss=0.5695  val_acc=0.991  avg_steps=3.24
[ep 03] train_loss=0.1439  val_acc=1.000  avg_steps=2.04
[ep 04] train_loss=0.0701  val_acc=1.000  avg_steps=2.00
[ep 05] train_loss=0.0427  val_acc=1.000  avg_steps=2.00
[ep 06] train_loss=0.0272  val_acc=1.000  avg_steps=2.00
[ep 07] train_loss=0.0190  val_acc=1.000  avg_steps=2.00
[ep 08] train_loss=0.0115  val_acc=1.000  avg_steps=2.00
TEST  acc=1.000  avg_steps=2.00


In [None]:

#@title Quick qualitative check
@torch.no_grad()
def show_examples(model, k=5):
    model.eval()
    for i, (X,Y) in enumerate(test_dl):
        X,Y = X.to(DEVICE), Y.to(DEVICE)
        Y_hat, steps = model.infer(X)
        for j in range(min(k, X.size(0))):
            x = decode(X[j].cpu())
            y = decode(Y[j].cpu())
            yh = decode(Y_hat[j].cpu())
            st = int(steps[j].cpu())
            print(f"in : {x} -> out: {yh}  (gold: {y})  steps={st}")
        break

show_examples(model, k=8)


in : 38274906 -> out: 02346789  (gold: 02346789)  steps=2
in : 07481532 -> out: 01234578  (gold: 01234578)  steps=2
in : 87651942 -> out: 12456789  (gold: 12456789)  steps=2
in : 60174825 -> out: 01245678  (gold: 01245678)  steps=2
in : 47653890 -> out: 03456789  (gold: 03456789)  steps=2
in : 51083629 -> out: 01235689  (gold: 01235689)  steps=2
in : 64983725 -> out: 23456789  (gold: 23456789)  steps=2
in : 96325087 -> out: 02356789  (gold: 02356789)  steps=2



## Plugging into your own model (sketch)

If your base model exposes a hidden state `h` (shape `(B,T,d)`), you can:
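1. Replace `TinyTransformer` with your encoder. It must map tokens `x` to hidden states of shape `(B,T,d)`, with `d` equal to the sidecar's `d_model`.
2. Keep `TokenHead(d_model=...)` or adapt it to your own head.
3. Feed your tokenized inputs `x` into `SidecarHRM`. The sidecar refines `h` for up to `Mmax` steps, predicts at each step, and **halts early** at inference once the Q-head is confident.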

Key switches to try:
- `Mmax` (max refinement steps)
- `Mmin` (minimum steps before halting allowed)
- ACT weight in `training_step(..., act_weight=0.1)`
- Use a stronger L-module (stack 2–3 layers) or a deeper base encoder.
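
The steps above could look roughly like the sketch below. Everything named here is hypothetical: `StandInEncoder` is a placeholder for your own model, `EncoderAdapter` is one possible way to expose the `(B,T,d)` interface that `SidecarHRM` expects from its base, and `DeepLModule` is a stacked variant of the notebook's `LModule`.

```python
import torch
import torch.nn as nn

class StandInEncoder(nn.Module):
    """Placeholder for your encoder; returns a tuple, as many real encoders do."""
    def __init__(self, vocab_size=10, d_model=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, x):
        return self.rnn(self.emb(x))  # (hidden states (B,T,d), final state)

class EncoderAdapter(nn.Module):
    """Hypothetical adapter exposing forward(x) -> (B,T,d)."""
    def __init__(self, encoder, d_model):
        super().__init__()
        self.encoder = encoder
        self.d_model = d_model

    def forward(self, x):
        out = self.encoder(x)
        h = out[0] if isinstance(out, tuple) else out  # keep only the hidden states
        if h.dim() != 3 or h.size(-1) != self.d_model:
            raise ValueError(f"expected (B, T, {self.d_model}), got {tuple(h.shape)}")
        return h

class DeepLModule(nn.Module):
    """LModule variant with a configurable number of stacked layers."""
    def __init__(self, d_model=64, nhead=4, dim_feedforward=256, dropout=0.1, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           dropout=dropout, batch_first=True)
        self.block = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, h):
        return self.ln(self.block(h))

base = EncoderAdapter(StandInEncoder(), d_model=64)
l_module = DeepLModule(d_model=64, num_layers=2)
x = torch.randint(0, 10, (4, 8))   # (B=4, T=8) token ids
h = base(x)
h2 = h + l_module(h)               # one residual refinement step, as in forward_once
print(h.shape, h2.shape)           # both torch.Size([4, 8, 64])
```

With the notebook's classes, one way to use these would be `SidecarHRM(base, TokenHead(d_model=64), d_model=64)` followed by assigning `model.l_module = DeepLModule(d_model=64)`, provided `d_model` matches everywhere.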



### Why use halting at inference?

The critique observes that **ignoring** the halting policy at evaluation (always running full steps) can **hurt accuracy**; continuing to edit after a correct solution can introduce errors. Enabling ACT halting at inference often **improves** both accuracy and efficiency by stopping when the solution is ready.
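
To check this claim on your own model, one option is to record the prediction and halt logit after every refinement step, then compare the output at the halting step with the output after all steps. The sketch below does this on hand-made tensors. The helper `select_by_halting` and its tensor layout are assumptions for illustration and are not part of the notebook.

```python
import torch

def select_by_halting(preds, q_logits, m_min=2, threshold=0.5):
    """Hypothetical helper.
    preds:    (M, B, T) predicted token ids after each of M refinement steps
    q_logits: (M, B)    halt logits after each step
    Returns each sequence's prediction at its halting step and the 1-based step count."""
    M, B = q_logits.shape
    halt = torch.sigmoid(q_logits) > threshold   # (M, B)
    halt[: m_min - 1] = False                    # no halting before step m_min
    halt[-1] = True                              # the last step always stops
    # Number of steps with no halt so far == 0-based index of the first halting step
    first = (halt.long().cumsum(dim=0) == 0).sum(dim=0)          # (B,)
    chosen = preds[first, torch.arange(B, device=preds.device)]  # (B, T)
    return chosen, first + 1

# Three steps, two sequences: the first wants to halt early, the second never does.
preds = torch.tensor([[[1, 1], [2, 2]],
                      [[3, 3], [4, 4]],
                      [[5, 5], [6, 6]]])
q_logits = torch.tensor([[5.0, -5.0],
                         [5.0, -5.0],
                         [-5.0, -5.0]])

with_halting, steps = select_by_halting(preds, q_logits)
always_full = preds[-1]
print(with_halting.tolist(), steps.tolist())  # [[3, 3], [6, 6]] [2, 3]
print(always_full.tolist())                   # [[5, 5], [6, 6]]
```

Scoring `with_halting` and `always_full` against the gold targets with the same exact-match metric gives a direct comparison of the two evaluation modes.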
