In [6]:
import os
import sys
from pathlib import Path
PROJECT_DIR = Path.cwd().parent
sys.path.append(str(PROJECT_DIR))

In [7]:
import torch
from models.transformer import GPT

In [3]:
from types import SimpleNamespace

cfg = SimpleNamespace(
    # === Architecture ===
    vocab_size = 256,           # small toy vocab
    n_layer = 4,                # total transformer layers
    n_head = 2,                 # small # of heads
    n_kv_head = 1,              # GQA -> 1 KV head
    n_embd = 32,                # embedding dimension
    sequence_len = 32,          # small context length

    # === Adaptive Computation (ACT) ===
    use_adaptive_computation = True,
    n_layers_per_block = 2,     # → 2 AdaptiveBlocks (each with 2 layers)
    max_pondering_steps = 3,    # can repeat up to 3 times per token
    act_threshold = 0.9,        # typical threshold
    halting_penalty = 0.01,     # small τ for now

    # === Training ===
    dropout = 0.0,              # easier to debug deterministically
    bias = False,               # for linear layers
    vocab_pad_id = 0,           # optional, for padding tests
    dtype = "float32",          # keep stable on CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
)


In [12]:
model = GPT(cfg).to(cfg.device)
print(f"{sum(p.numel() for p in model.parameters())/1e3:.1f}k params")

# dummy batch
B, T = 2, 5
x = torch.randint(0, cfg.vocab_size, (B, T), device=cfg.device)
y = torch.randint(0, cfg.vocab_size, (B, T), device=cfg.device)

loss, aux = model(x, targets=y)
# print(f"Loss: {loss.item():.4f}")
# print(f"Expected Steps: {aux['expected_steps']:.3f}")
# print(f"ACT Penalty: {aux['ponder_cost']:.3f}")
print(loss)
print(aux)

61.5k params
tensor(5.8405, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0556, device='cuda:0', grad_fn=<MulBackward0>)


In [13]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(256, 32)
    (h): ModuleList(
      (0-1): 2 x AdaptiveBlock(
        (layers): ModuleList(
          (0-1): 2 x Block(
            (attn): CausalSelfAttention(
              (c_q): Linear(in_features=32, out_features=32, bias=False)
              (c_k): Linear(in_features=32, out_features=16, bias=False)
              (c_v): Linear(in_features=32, out_features=16, bias=False)
              (c_proj): Linear(in_features=32, out_features=32, bias=False)
            )
            (mlp): MLP(
              (c_fc): Linear(in_features=32, out_features=128, bias=False)
              (c_proj): Linear(in_features=128, out_features=32, bias=False)
            )
          )
        )
        (halting_unit): HaltingUnit(
          (halting_linear): Linear(in_features=32, out_features=1, bias=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=32, out_features=256, bias=False)
)

In [16]:
torch.manual_seed(0)
m = GPT(cfg).train()
opt = m.setup_optimizers(unembedding_lr=0.04, embedding_lr=0.04, matrix_lr=0.04, halting_lr=0.08)
xb = torch.randint(0,cfg.vocab_size,(4,16))
yb = torch.randint(0,cfg.vocab_size,(4,16))
losses = []
for it in range(200):
    opt.zero_grad()
    loss, aux = m(xb, targets=yb)
    loss.backward()
    opt.step()
    losses.append(loss.item())
assert losses[-1] < losses[0]*0.5, f"did not learn: {losses[0]} -> {losses[-1]}"


Scaling the LR for the AdamW parameters ∝1/√(32/768) = 4.898979
