In [1]:
import os
import sys
from pathlib import Path
# Make the project root importable so the local `dataset` / `models` packages
# imported below resolve. Assumes the kernel's cwd is the notebook's folder
# (Jupyter's default), so its parent is the project root — TODO confirm.
# The path is inserted at the FRONT so the project packages win over any
# installed package with the same name, and any existing copies are removed
# first so re-running this cell does not keep growing sys.path.
PROJECT_DIR = Path.cwd().parent
_project_dir_str = str(PROJECT_DIR)
sys.path[:] = [p for p in sys.path if p != _project_dir_str]
sys.path.insert(0, _project_dir_str)

In [2]:
import torch
from torch.utils.data import DataLoader

from dataset.synthetic_tasks import AddMul, Difficulty, make_collate_fn, Tokenizer
from models.transformer import GPT, TransformerBlockConfig

In [3]:
# Tiny, seeded AddMul dataset (10 samples) for eyeballing the problem/answer format.
dataset = AddMul(
    num_samples=10,
    max_operands=4,
    max_digits=3,
    seed=42,
)

In [4]:
# Peek at the first (input tokens, answer tokens) sample produced by the dataset.
try:
    data = next(iter(dataset))
except StopIteration:
    pass  # empty dataset: print nothing, exactly like a for-loop that never runs
else:
    print(data)


(['-', '4', '+', '2', '+', '7', '0', '4', '+', '1', '3', '='], ['7', '1', '5'])


In [None]:
# Switch the dataset to an easier setting: at most 2 operands of 1 digit,
# multiplication only. Note: this mutates `dataset` in place, so the sample
# printed by the earlier peek cell is no longer representative after this runs.
difficulty = Difficulty(
    max_operands=2,
    max_digits=1,
    operations=['*'],
)
dataset.set_difficulty(difficulty)

In [6]:
# Peek again at the first sample, now generated under the new difficulty.
try:
    data = next(iter(dataset))
except StopIteration:
    pass  # empty dataset: nothing to show
else:
    print(data)


(['9', '*', '-', '4', '='], ['-', '3', '6'])


In [35]:
# Runtime configuration: device, seed, matmul precision, tokenizer and the
# training data pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
torch.set_float32_matmul_precision("high")

tokenizer = Tokenizer()
max_seq_len = 128
batch_size = 64

# Large synthetic training set: 50M samples, far more than 20k steps x 64 consume.
ds = AddMul(num_samples=50_000_000, max_operands=6, max_digits=6, seed=0)
collate_fn = make_collate_fn(tokenizer, max_seq_len)
# The training loop copies batches with `.to(device, non_blocking=True)`.
# Such host->GPU copies only run asynchronously from pinned memory, so pin
# when a CUDA device is in use. Pinning is pointless (and warns) on CPU-only
# machines.
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0,
                collate_fn=collate_fn, drop_last=True,
                pin_memory=(device == "cuda"))

In [49]:
# ACT-enabled GPT. Per the module printout, n_layer=2 with
# n_layers_per_block=2 yields a single AdaptiveBlock of 2 transformer Blocks
# plus a HaltingUnit (presumably n_layer // n_layers_per_block blocks — TODO confirm).
cfg = TransformerBlockConfig(
    # context length / vocabulary
    sequence_len=max_seq_len,
    vocab_size=len(tokenizer),
    # attention heads and model width
    n_head=4, n_kv_head=4, n_embd=128,
    # depth
    n_layer=2,
    # adaptive computation (pondering) settings
    use_adaptive_computation=True,
    n_layers_per_block=2,
    max_pondering_steps=5,
    act_threshold=0.99,
    halting_penalty=0.01,
)
model = GPT(cfg)
model = model.to(device)  # nn.Module.to returns the same module
model.init_weights()


In [50]:
# Inspect the module hierarchy of the freshly initialised model.
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 128)
    (h): ModuleList(
      (0): AdaptiveBlock(
        (layers): ModuleList(
          (0-1): 2 x Block(
            (attn): CausalSelfAttention(
              (c_q): Linear(in_features=128, out_features=128, bias=False)
              (c_k): Linear(in_features=128, out_features=128, bias=False)
              (c_v): Linear(in_features=128, out_features=128, bias=False)
              (c_proj): Linear(in_features=128, out_features=128, bias=False)
            )
            (mlp): MLP(
              (c_fc): Linear(in_features=128, out_features=512, bias=False)
              (c_proj): Linear(in_features=512, out_features=128, bias=False)
            )
          )
        )
        (halting_unit): HaltingUnit(
          (halting_linear): Linear(in_features=128, out_features=1, bias=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=128, out_features=16, bias=False)
)

In [51]:
# Build the optimizer from the model's own parameter grouping; the call itself
# prints the AdamW LR scaling factor it applies.
opt = model.setup_optimizers()

Scaling the LR for the AdamW parameters ∝1/√(128/768) = 2.449490


In [86]:
# Smoke test: one forward pass (bf16 autocast) on the first batch to check the
# model runs end-to-end and to look at the untrained loss.
# autocast uses `device` rather than a hardcoded "cuda": on a CPU-only machine,
# device_type="cuda" just warns and silently disables autocast.
batch = {k: v.to(device) for k, v in next(iter(dl)).items()}
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    loss = model(**batch)
print(loss)

tensor(2.5276, device='cuda:0', grad_fn=<AddBackward0>)


In [87]:
# Expected pondering steps stored on the model, presumably by the most recent
# forward pass (the `last_` prefix suggests so — TODO confirm).
model.last_expected_steps

tensor(1., device='cuda:0')

In [88]:
# Backprop the smoke-test loss. This frees the graph's saved activations, so
# `loss.backward()` cannot be called again on this same loss.
loss.backward()

In [89]:
# Mean gradient on the halting unit's linear weight: quick check that the ACT
# halting path receives gradient at all.
model.transformer.h[0].halting_unit.halting_linear.weight.grad.mean()

tensor(-0.0007, device='cuda:0')

In [90]:
# Apply one optimizer step using the gradients from the smoke-test backward.
opt.step()

In [None]:
# NOTE(review): calling `loss.backward()` a second time on the same graph raises
# "Trying to backward through the graph a second time". The earlier backward
# already freed the saved activations (no retain_graph=True).
# This cell now acts on the gradients that already exist: it clips them at 1.0
# (the previously commented-out intent) and displays the pre-clip total norm.
total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
total_grad_norm


In [46]:
# Shape of the input ids in whichever `batch` was loaded last (smoke test or training loop).
batch['idx'].shape

torch.Size([64, 58])

In [47]:
# Targets should line up one-to-one with the input ids.
batch['targets'].shape

torch.Size([64, 58])

In [48]:
# Per-position ponder mask passed to the model; same shape as the ids.
batch['ponder_mask'].shape

torch.Size([64, 58])

In [38]:
# Confirm parameters stay float32: bf16 autocast casts ops, not the stored weights.
model.transformer.h[0].layers[0].attn.c_q.weight.data.dtype

torch.float32

In [37]:
# Re-display the module hierarchy.
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 128)
    (h): ModuleList(
      (0): AdaptiveBlock(
        (layers): ModuleList(
          (0-1): 2 x Block(
            (attn): CausalSelfAttention(
              (c_q): Linear(in_features=128, out_features=128, bias=False)
              (c_k): Linear(in_features=128, out_features=128, bias=False)
              (c_v): Linear(in_features=128, out_features=128, bias=False)
              (c_proj): Linear(in_features=128, out_features=128, bias=False)
            )
            (mlp): MLP(
              (c_fc): Linear(in_features=128, out_features=512, bias=False)
              (c_proj): Linear(in_features=512, out_features=128, bias=False)
            )
          )
        )
        (halting_unit): HaltingUnit(
          (halting_linear): Linear(in_features=128, out_features=1, bias=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=128, out_features=16, bias=False)
)

In [None]:
# Main training loop for the ACT-enabled GPT.
# NOTE(review): the saved log shows loss = 2.7754 = 2.7726 + 0.0028, i.e.
# ln(16) (cross-entropy of a uniform prediction over the 16-token vocab) plus
# act_penalty. The model does not appear to be learning yet. TODO investigate
# (LR scaling, halting-unit gradient magnitude, target / ponder_mask handling).
steps = 20000
log_every = 50
grad_clip = 1.0
model.train()
# zip() exhausts the range first, so no extra batch is pulled from `dl` after the last step.
for step, batch in zip(range(1, steps + 1), dl):
    idx = batch["idx"].to(device, non_blocking=True)
    targets = batch["targets"].to(device, non_blocking=True)
    ponder_mask = batch["ponder_mask"].to(device, non_blocking=True)

    opt.zero_grad(set_to_none=True)
    # autocast wraps only the forward pass + loss. PyTorch advises running
    # backward (and hence clipping / the optimizer step) outside the autocast region.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        loss = model(idx, targets=targets, kv_cache=None,
                     loss_reduction="mean", ponder_mask=ponder_mask)
    loss.backward()
    # Returns the total (pre-clip) gradient norm; logged below to spot blow-ups.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    opt.step()

    if step % log_every == 0:
        pc = getattr(model, "last_ponder_cost", float("nan"))
        es = getattr(model, "last_expected_steps", float("nan"))
        ap = getattr(model, "last_act_penalty", float("nan"))
        print(f"step {step:5d}  loss={loss.item():.4f}  ponder_cost={pc:.3f}  "
              f"exp_steps={es:.3f}  act_penalty={ap:.4f}  "
              f"grad_norm={grad_norm.item():.3f}")

step    50  loss=2.7754  ponder_cost=0.282  exp_steps=2.460  act_penalty=0.0028
step   100  loss=2.7755  ponder_cost=0.290  exp_steps=2.445  act_penalty=0.0029


KeyboardInterrupt: 

In [48]:
# Halting-unit weight gradient after the (interrupted) training run. Compare it
# with the far smaller value seen right after the single smoke-test backward.
model.transformer.h[0].halting_unit.halting_linear.weight.grad.mean()

tensor(-1.6659, device='cuda:0')

In [30]:
# Hand check: 0.01 matches cfg.halting_penalty, and 0.315 looks like a
# ponder_cost reading. Presumably this confirms act_penalty == halting_penalty * ponder_cost
# (the training log agrees: 0.01 * 0.282 ~= 0.0028) — TODO confirm.
0.01 * 0.315

0.00315

In [38]:
# Greedy (temperature=0) decode of the first training problem, printed next to
# the reference answer.
# NOTE(review): the saved output is 40 `<pad>` tokens, i.e. max_tokens was hit
# without `tokenizer.eos_id` appearing. This is consistent with the training
# log's loss ~= ln(16) (uniform logits; argmax would then pick the first token
# id, presumably <pad>) — TODO confirm.
model.eval()
input_seq, output_seq = ds[0]
print("Problem:", "".join(input_seq), "Answer:", "".join(output_seq))
prompt_ids = tokenizer.encode(input_seq)  # includes '='
outs = []
# Only generation needs inference mode + autocast. Use `device` rather than a
# hardcoded "cuda" so the cell also works on CPU-only machines.
with torch.inference_mode(), torch.autocast(device_type=device, dtype=torch.bfloat16):
    for tok in model.generate(tokens=prompt_ids, max_tokens=40, temperature=0.0):
        outs.append(tok)
        if tok == tokenizer.eos_id:
            break
print("Model output:", "".join(tokenizer.decode(outs)))

Problem: -965393*93*69*-28= Answer: 173457952668
Model output: <pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
