In [None]:
# ! uv -q pip install transformers datasets accelerate bitsandbytes einops
# ! pip install wandb

### Phase-1 : - Setup & config

In [1]:
import os, re, math, random
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import wandb



In [2]:
class Config:
    # HF bits
    baseline_model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # or "meta-llama/Llama-3.2-1B-Instruct"
    dataset_name = "openai/gsm8k"
    dataset_subset = "main"   # gsm8k has "main", "socratic"

    # Data sizes
    max_train_examples = 5000    # for TRM specialist training
    max_eval_examples  = 300     # for both baseline & TRM

    # TRM model hyperparams (tiny)
    vocab_size = 2000            # if using a BPE tokenizer; or small char vocab if you prefer
    d_model = 256
    n_layers = 2                 # for reasoning block
    n_heads = 4
    max_seq_len = 256
    H_cycles = 4                 # outer cycles (reasoning loops)
    L_cycles = 2                 # inner loops
    halt_max_steps = 4           # max ACT steps per example

    # Training
    batch_size = 32
    lr = 3e-4
    num_epochs = 5
    device = "cuda" if torch.cuda.is_available() else "cpu"

cfg = Config()
device = torch.device(cfg.device)

wandb.init(
    project="trm-math-poc",
    config=vars(cfg),
)


[34m[1mwandb[0m: Currently logged in as: [33mvedaangchopra[0m ([33mvedaangchopra_gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Phase 2 â€“ Load Dataset & Extract Numeric Answers

In [3]:
raw_ds = load_dataset(cfg.dataset_name, cfg.dataset_subset)
raw_train = raw_ds["train"]
raw_test  = raw_ds["test"]

print(raw_train[0])

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}


In [4]:
ANSWER_PATTERN = re.compile(r"(-?\d+\.?\d*)")

def extract_numeric_answer(ans_str: str):
    # GSM8K uses "#### 42" at the end â€” we grab the last number
    matches = ANSWER_PATTERN.findall(ans_str)
    return matches[-1] if matches else None

In [5]:
def filter_and_prepare_split(split):
    data = []
    for ex in split:
        a = extract_numeric_answer(ex["answer"])
        if a is None:
            continue
        data.append({
            "question": ex["question"],
            "answer": a,
        })
    return data

train_data = filter_and_prepare_split(raw_train)
test_data  = filter_and_prepare_split(raw_test)

len(train_data), len(test_data)


(7473, 1319)

In [6]:
random.seed(42)
random.shuffle(train_data)
random.shuffle(test_data)

train_data = train_data[:cfg.max_train_examples]
eval_data  = test_data[:cfg.max_eval_examples]


### Phase-3: - Baseline: Qwen / Llama evaluation

In [7]:
tokenizer = AutoTokenizer.from_pretrained(cfg.baseline_model_name)
model = AutoModelForCausalLM.from_pretrained(
    cfg.baseline_model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
model.eval()


`torch_dtype` is deprecated! Use `dtype` instead!


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2

In [8]:
def make_prompt(q: str) -> str:
    return (
        "You are a math solver.\n"
        "Solve the following problem step by step in your head, "
        "but only output the final numeric answer.\n\n"
        f"Question: {q}\n"
        "Answer:"
    )

@torch.no_grad()
def generate_answer_llm(question: str, max_new_tokens: int = 64) -> str:
    prompt = make_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Take only text after "Answer:"
    if "Answer:" in text:
        text = text.split("Answer:", 1)[1]
    return text.strip()


In [9]:
def eval_llm_on_gsm8k(data, n_eval=200):
    correct = 0
    records = []
    for i, ex in enumerate(data[:n_eval]):
        q, gold = ex["question"], ex["answer"]
        pred_str = generate_answer_llm(q)
        pred_num = extract_numeric_answer(pred_str)
        is_correct = (pred_num == gold)
        correct += int(is_correct)

        records.append({
            "question": q,
            "gold": gold,
            "raw_pred": pred_str,
            "pred_num": pred_num,
            "correct": is_correct,
        })

        if (i + 1) % 10 == 0:
            acc = correct / (i + 1)
            print(f"[LLM] {i+1}/{n_eval} acc={acc:.3f}")
            wandb.log({"baseline/step": i + 1, "baseline/acc": acc})

    acc = correct / min(len(data), n_eval)
    wandb.log({"baseline/final_acc": acc})
    return acc, records

baseline_acc, baseline_records = eval_llm_on_gsm8k(eval_data)
print("Baseline LLM acc:", baseline_acc)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[LLM] 10/200 acc=0.000
[LLM] 20/200 acc=0.000
[LLM] 30/200 acc=0.000
[LLM] 40/200 acc=0.000
[LLM] 50/200 acc=0.040
[LLM] 60/200 acc=0.033
[LLM] 70/200 acc=0.029
[LLM] 80/200 acc=0.025
[LLM] 90/200 acc=0.022
[LLM] 100/200 acc=0.020
[LLM] 110/200 acc=0.018
[LLM] 120/200 acc=0.025
[LLM] 130/200 acc=0.031
[LLM] 140/200 acc=0.036
[LLM] 150/200 acc=0.033
[LLM] 160/200 acc=0.031
[LLM] 170/200 acc=0.029
[LLM] 180/200 acc=0.033
[LLM] 190/200 acc=0.032
[LLM] 200/200 acc=0.030
Baseline LLM acc: 0.03


### Phase-4: - TRM-style tiny model (specialist)

In [10]:
# ---- CHAR VOCAB SETUP (REPLACE YOUR CURRENT CHARSET/VOCAB + ENCODE/DECODE) ----

# Build charset from train_data
charset = set()
for ex in train_data:
    charset.update(ex["question"])
    charset.update(ex["answer"])

special_tokens = ["<pad>", "<bos>", "<eos>"]
itos = special_tokens + sorted(list(charset))
stoi = {ch: i for i, ch in enumerate(itos)}

PAD_ID = stoi["<pad>"]
BOS_ID = stoi["<bos>"]
EOS_ID = stoi["<eos>"]

cfg.vocab_size = len(itos)

# Use a fixed max_seq_len based on your earlier analysis or keep what you had
# Just make sure cfg.max_seq_len is set BEFORE creating the model.
# Example:
# max_q_len = max(len(ex["question"]) for ex in train_data)
# max_a_len = max(len(ex["answer"]) for ex in train_data)
# cfg.max_seq_len = max_q_len + max_a_len + 10  # a bit of slack


def encode_text(s: str, max_len: int) -> list[int]:
    """
    Encode string into fixed-length [BOS, ..., EOS/PAD] with safe clamping.
    """
    # Map chars to ids, unknown -> PAD
    ids = [BOS_ID] + [stoi.get(ch, PAD_ID) for ch in s]

    # Clamp to valid vocab range just in case
    ids = [min(max(i, 0), cfg.vocab_size - 1) for i in ids]

    # Reserve last slot for EOS
    ids = ids[: max_len - 1] + [EOS_ID]

    # Pad if needed
    if len(ids) < max_len:
        ids += [PAD_ID] * (max_len - len(ids))

    return ids


def decode_ids(ids: list[int]) -> str:
    """
    Decode ids back into string, skipping special tokens.
    """
    chars = []
    for idx in ids:
        if idx < 0 or idx >= len(itos):
            continue
        ch = itos[idx]
        if ch in ("<pad>", "<bos>", "<eos>"):
            continue
        chars.append(ch)
    return "".join(chars)


In [11]:
max_q_len = max(len(ex["question"]) for ex in train_data)
max_a_len = max(len(ex["answer"]) for ex in train_data)
cfg.max_seq_len = max_q_len + max_a_len + 4


### Dataset class for TRM

In [12]:
class GSM8KCharDataset(Dataset):
    def __init__(self, data, max_len, answer_only_loss=True):
        self.data = data
        self.max_len = max_len
        self.answer_only_loss = answer_only_loss

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        q = ex["question"]
        a = ex["answer"]
        # You can tweak formatting; simple version:
        prompt = q + " Answer:"
        full = prompt + " " + a

        ids = encode_text(full, self.max_len)
        input_ids = torch.tensor(ids, dtype=torch.long)

        # LM target = next char
        labels = input_ids.clone()
        # Optionally, we can ignore loss on non-answer parts
        if self.answer_only_loss:
            # ignore everything before start of answer
            ans_start = len(encode_text(prompt + " ", self.max_len)) - 1
            labels[:ans_start] = -100  # ignore index
        return {"input_ids": input_ids, "labels": labels}

train_ds = GSM8KCharDataset(train_data, cfg.max_seq_len, answer_only_loss=True)
eval_ds  = GSM8KCharDataset(eval_data,  cfg.max_seq_len, answer_only_loss=True)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True)
eval_loader  = DataLoader(eval_ds,  batch_size=cfg.batch_size, shuffle=False)


#### TRM-style block (tiny transformer)

In [13]:
class TRMBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, expansion: float = 4.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model * expansion)),
            nn.SiLU(),
            nn.Linear(int(d_model * expansion), d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, D]
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x

In [14]:
class TinyMathTRM(nn.Module):
    """
    Tiny TRM-style model with:
      - single latent z
      - H_cycles outer loops, L_cycles inner loops
      - deep recursion with no_grad() on earlier cycles
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg

        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)

        self.blocks = nn.ModuleList(
            [TRMBlock(cfg.d_model, cfg.n_heads) for _ in range(cfg.n_layers)]
        )

        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size)
        self.q_head = nn.Linear(cfg.d_model, 1)  # halting logit from first position

        # Learned initial latent state (broadcast later)
        self.z_init = nn.Parameter(
            torch.randn(1, 1, cfg.d_model) / math.sqrt(cfg.d_model)
        )

    def forward_once(
        self,
        input_ids: torch.Tensor,
        z: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        One 'latent recursion' pass with gradients.
        input_ids: [B, L]
        z:         [B, L, D]
        """
        B, L = input_ids.shape
        tok_emb = self.embed(input_ids)  # [B, L, D]

        # Positional embeddings
        positions = torch.arange(L, device=input_ids.device)
        pos_emb = self.pos_emb(positions)[None, :, :].expand(B, L, -1)  # [B, L, D]

        # Combine: token + position + latent
        x = tok_emb + pos_emb + z  # [B, L, D]

        # Pass through a small stack of transformer blocks
        for blk in self.blocks:
            x = blk(x)

        new_z = x
        logits = self.lm_head(x)                     # [B, L, V]
        q_logit = self.q_head(x[:, 0, :]).squeeze(-1)  # [B]
        return new_z, logits, q_logit

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        carry_z: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """
        TRM-style recursion:
          - If carry_z is None, start from z_init.
          - Run H_cycles - 1 passes under no_grad.
          - Run final H-cycle with grads and compute loss if labels provided.
          - Return new_z (detached), logits, q_logit, loss.
        """
        B, L = input_ids.shape

        # Initialise or fix latent state size
        if carry_z is None or carry_z.shape[0] != B or carry_z.shape[1] != L:
            z = self.z_init.expand(B, L, -1)
        else:
            z = carry_z

        # H_cycles - 1 cycles with no_grad (deep recursion, no backprop)
        with torch.no_grad():
            for _ in range(self.cfg.H_cycles - 1):
                for _ in range(self.cfg.L_cycles):
                    z, _, _ = self.forward_once(input_ids, z)

        # Final cycle with gradients
        for _ in range(self.cfg.L_cycles):
            z, logits, q_logit = self.forward_once(input_ids, z)

        new_z = z.detach()

        loss = None
        if labels is not None:
            # Standard LM loss over characters (answer_only mask handled in labels)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                logits.view(-1, self.cfg.vocab_size),
                labels.view(-1),
            )

        return new_z, logits, q_logit, loss

### Training loop with ACT-style supervision + W&B

In [15]:
# ---- TRAINING LOOP (REPLACE YOUR CURRENT TRAINING LOOP) ----

trm_model = TinyMathTRM(cfg).to(device)
optimizer = torch.optim.AdamW(trm_model.parameters(), lr=cfg.lr)

global_step = 0

def evaluate_trm(model, eval_loader, max_batches: int | None = None):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for b_idx, batch in enumerate(eval_loader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            _, logits, q_logit, loss = model(input_ids, labels=labels, carry_z=None)
            if loss is None:
                continue

            tokens = (labels != -100).sum().item()
            total_loss += loss.item() * tokens
            total_tokens += tokens

            if max_batches is not None and (b_idx + 1) >= max_batches:
                break

    avg_loss = total_loss / max(1, total_tokens)
    ppl = math.exp(avg_loss)
    return {"eval_loss": avg_loss, "eval_ppl": ppl}


for epoch in range(cfg.num_epochs):
    trm_model.train()
    carry_z = None

    for batch in train_loader:
        global_step += 1
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Reset carry_z if shape mismatches
        if carry_z is None or carry_z.shape[0] != input_ids.shape[0] or carry_z.shape[1] != input_ids.shape[1]:
            carry_z = None

        optimizer.zero_grad()
        carry_z, logits, q_logit, loss = trm_model(input_ids, labels=labels, carry_z=carry_z)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trm_model.parameters(), 1.0)
        optimizer.step()

        if global_step % 50 == 0:
            wandb.log({"trm/train_loss": loss.item(), "step": global_step})
            print(f"Epoch {epoch} step {global_step}: loss={loss.item():.4f}")

        if global_step % 500 == 0:
            metrics = evaluate_trm(trm_model, eval_loader, max_batches=50)
            wandb.log({f"trm/{k}": v for k, v in metrics.items()} | {"step": global_step})
            print("Eval:", metrics)


Epoch 0 step 50: loss=0.0011
Epoch 0 step 100: loss=0.0005
Epoch 0 step 150: loss=0.0003
Epoch 1 step 200: loss=0.0002
Epoch 1 step 250: loss=0.0002
Epoch 1 step 300: loss=0.0001
Epoch 2 step 350: loss=0.0001
Epoch 2 step 400: loss=0.0001
Epoch 2 step 450: loss=0.0001
Epoch 3 step 500: loss=0.0001
Eval: {'eval_loss': 5.334195452936304e-05, 'eval_ppl': 1.0000533433772367}
Epoch 3 step 550: loss=0.0000
Epoch 3 step 600: loss=0.0000
Epoch 4 step 650: loss=0.0000
Epoch 4 step 700: loss=0.0000
Epoch 4 step 750: loss=0.0000


#### Adding halting loss (more TRM-like, optional)

In [16]:
### Need to implement halting loss

### Measuring TRM accuracy on GSM8K subset

In [17]:
# ---- GENERATION + EVAL HELPERS (REPLACE YOUR CURRENT VERSIONS) ----

@torch.no_grad()
def generate_answer_trm(model: TinyMathTRM, question: str, max_answer_len: int = 16) -> str:
    """
    Fixed-length generation: we don't change sequence length,
    we fill answer chars into PAD slots in-place.
    """
    model.eval()
    prompt = question + " Answer:"

    # 1) Encode prompt into fixed-length ids
    prompt_ids = encode_text(prompt, cfg.max_seq_len)  # [L]

    # 2) Find first PAD position: where the answer starts
    try:
        pad_idx = prompt_ids.index(PAD_ID)
    except ValueError:
        # Prompt filled everything; reserve tail for answer
        pad_idx = cfg.max_seq_len - max_answer_len

    ans_start = pad_idx

    # 3) Tensor [1, L]
    input_ids = torch.tensor([prompt_ids], device=device)
    carry_z = None

    # 4) Iteratively fill answer positions
    for t in range(max_answer_len):
        carry_z, logits, q_logit, _ = model(input_ids, labels=None, carry_z=carry_z)

        pos = ans_start + t
        if pos >= cfg.max_seq_len:
            break

        prev_pos = max(pos - 1, 0)
        next_id = logits[0, prev_pos, :].argmax().item()

        # ðŸ”’ clamp to valid vocab range to avoid CUDA device-side asserts
        if next_id < 0 or next_id >= cfg.vocab_size:
            next_id = EOS_ID

        input_ids[0, pos] = next_id

        if next_id == EOS_ID:
            break

    # 5) Decode and extract answer substring
    text = decode_ids(input_ids[0].tolist())
    if "Answer:" in text:
        text = text.split("Answer:", 1)[1]
    return text.strip()

In [18]:
def eval_trm_on_gsm8k(model: TinyMathTRM, data, n_eval: int = 200):
    correct = 0
    for i, ex in enumerate(data[:n_eval]):
        q, gold = ex["question"], ex["answer"]
        pred_str = generate_answer_trm(model, q)
        pred_num = extract_numeric_answer(pred_str)
        is_correct = (pred_num == gold)
        correct += int(is_correct)

        if (i + 1) % 20 == 0:
            acc = correct / (i + 1)
            print(f"[TRM] {i+1}/{n_eval} acc={acc:.3f}")
            wandb.log({"trm/acc_eval": acc, "trm/step_eval": i + 1})

    acc = correct / min(len(data), n_eval)
    wandb.log({"trm/final_acc": acc})
    return acc

In [19]:
# Call after training:
trm_acc = eval_trm_on_gsm8k(trm_model, eval_data, n_eval=cfg.max_eval_examples)
print("TRM acc:", trm_acc)

[TRM] 20/300 acc=0.000
[TRM] 40/300 acc=0.000
[TRM] 60/300 acc=0.000
[TRM] 80/300 acc=0.000
[TRM] 100/300 acc=0.000
[TRM] 120/300 acc=0.000
[TRM] 140/300 acc=0.000
[TRM] 160/300 acc=0.000
[TRM] 180/300 acc=0.000
[TRM] 200/300 acc=0.000
[TRM] 220/300 acc=0.000
[TRM] 240/300 acc=0.000
[TRM] 260/300 acc=0.000
[TRM] 280/300 acc=0.000
[TRM] 300/300 acc=0.000
TRM acc: 0.0
