In [1]:
import os
import sys
from pathlib import Path
PROJECT_DIR = Path.cwd().parent
sys.path.append(str(PROJECT_DIR))

In [2]:
import torch
from torch.utils.data import DataLoader

from dataset.synthetic_tasks import AddMul, Difficulty, make_collate_fn, Tokenizer
from models.transformer import GPT, TransformerBlockConfig

In [3]:
dataset = AddMul(num_samples=10, max_operands=4, max_digits=3, seed=42)

In [4]:
for data in dataset:
    print(data)
    break

(['-', '4', '+', '2', '+', '7', '0', '4', '+', '1', '3', '='], ['7', '1', '5'])


In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
torch.set_float32_matmul_precision("high")

tokenizer = Tokenizer()
max_seq_len = 128

ds = AddMul(num_samples=50_000_000, max_operands=2, max_digits=3, seed=0)
collate_fn = make_collate_fn(tokenizer, max_seq_len)
dl = DataLoader(ds, batch_size=256, shuffle=False, num_workers=0,
                collate_fn=collate_fn, drop_last=True)

In [26]:
cfg = TransformerBlockConfig(
    sequence_len=max_seq_len,
    vocab_size=len(tokenizer),
    n_head=4,
    n_kv_head=4,
    n_embd=128,
    n_layer=4,
    use_adaptive_computation=True,
    n_layers_per_block=1,
    max_pondering_steps=3,
    act_threshold=0.99,
    halting_penalty=0.001,
)
model = GPT(cfg).to(device)
model.init_weights()


In [27]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 128)
    (h): ModuleList(
      (0-3): 4 x AdaptiveBlock(
        (layers): ModuleList(
          (0): Block(
            (attn): CausalSelfAttention(
              (c_q): Linear(in_features=128, out_features=128, bias=False)
              (c_k): Linear(in_features=128, out_features=128, bias=False)
              (c_v): Linear(in_features=128, out_features=128, bias=False)
              (c_proj): Linear(in_features=128, out_features=128, bias=False)
            )
            (mlp): MLP(
              (c_fc): Linear(in_features=128, out_features=512, bias=False)
              (c_proj): Linear(in_features=512, out_features=128, bias=False)
            )
          )
        )
        (halting_unit): HaltingUnit(
          (halting_linear): Linear(in_features=128, out_features=1, bias=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=128, out_features=16, bias=False)
)

In [28]:
opt = model.setup_optimizers(unembedding_lr=0.04, embedding_lr=0.04, matrix_lr=0.04, halting_lr=0.08)

Scaling the LR for the AdamW parameters ∝1/√(128/768) = 2.449490


In [29]:
import matplotlib.pyplot as plt
import numpy as np
def plot_params_grads_updates():
    param_to_lr = {}
    for group in opt.param_groups:
        for p in group['params']:
            param_to_lr[id(p)] = group['lr']

    plot_data = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            lr = param_to_lr.get(id(param), 0)
            
            p_flat = param.data.detach().float().cpu().flatten().numpy()
            g_flat = param.grad.detach().float().cpu().flatten().numpy()
            update = lr * g_flat
            
            plot_data.append({
                'name': name.replace('transformer.h.', '').replace('.weight', '.w').replace('.bias', '.b'),
                'param': p_flat,
                'grad': g_flat,
                'update': update
            })

    n_params = len(plot_data)
    fig, axes = plt.subplots(n_params, 3, figsize=(15, 2*n_params))

    if n_params == 1:
        axes = [axes]

    for i, data in enumerate(plot_data):
        axes[i][0].hist(data['param'], bins=50, alpha=0.7, color='blue', edgecolor='black')
        axes[i][0].set_title(f"{data['name']}\nParam: μ={data['param'].mean():.4f}, σ={data['param'].std():.4f}")
        axes[i][0].set_ylabel('Count')
        
        axes[i][1].hist(data['grad'], bins=50, alpha=0.7, color='green', edgecolor='black')
        axes[i][1].set_title(f"Grad: μ={data['grad'].mean():.4e}, σ={data['grad'].std():.4e}")
        
        axes[i][2].hist(data['update'], bins=50, alpha=0.7, color='red', edgecolor='black')
        axes[i][2].set_title(f"LR×Grad: μ={data['update'].mean():.4e}, σ={data['update'].std():.4e}")

    plt.tight_layout()
    plt.show()

In [10]:
opt.zero_grad(set_to_none=True)
for batch in dl:
    batch = {k: v.to(device) for k, v in batch.items()}
    #with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(**batch)
    print("Loss:", loss)
    #print("Last Expected Steps:", model.last_expected_steps)
    loss.backward()
    #plot_params_grads_updates()
    opt.step()
    break


Loss: (tensor(2.7726, device='cuda:0', grad_fn=<NllLossBackward0>), tensor(0.0108, device='cuda:0', grad_fn=<DivBackward0>))


AttributeError: 'tuple' object has no attribute 'backward'

In [33]:
model.last_expected_steps

tensor(2.3844, device='cuda:0')

In [23]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(16, 128)
    (h): ModuleList(
      (0-3): 4 x AdaptiveBlock(
        (layers): ModuleList(
          (0): Block(
            (attn): CausalSelfAttention(
              (c_q): Linear(in_features=128, out_features=128, bias=False)
              (c_k): Linear(in_features=128, out_features=128, bias=False)
              (c_v): Linear(in_features=128, out_features=128, bias=False)
              (c_proj): Linear(in_features=128, out_features=128, bias=False)
            )
            (mlp): MLP(
              (c_fc): Linear(in_features=128, out_features=512, bias=False)
              (c_proj): Linear(in_features=512, out_features=128, bias=False)
            )
          )
        )
        (halting_unit): HaltingUnit(
          (halting_linear): Linear(in_features=128, out_features=1, bias=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=128, out_features=16, bias=False)
)

In [24]:
curriculum = [
    Difficulty(max_operands=2, max_digits=2, operations=['+']),
    Difficulty(max_operands=2, max_digits=4, operations=['+']),
    Difficulty(max_operands=3, max_digits=5, operations=['+']),
    Difficulty(max_operands=4, max_digits=6, operations=['+']),
    Difficulty(max_operands=2, max_digits=1, operations=['+', '*']),
    Difficulty(max_operands=2, max_digits=2, operations=['+', '*']),
    Difficulty(max_operands=2, max_digits=3, operations=['+', '*']),
    Difficulty(max_operands=3, max_digits=4, operations=['+', '*']),
    Difficulty(max_operands=3, max_digits=5, operations=['+', '*']),
    Difficulty(max_operands=4, max_digits=5, operations=['+', '*']),
]

In [41]:
steps = 800000
log_every = 200
model.train()
threshold_loss = 0.09
warmup_steps = 2000
penalty_scale = 1.0

#current_curric_ind = -1
#difficulty = curriculum[current_curric_ind]
#ds.set_difficulty(difficulty)
#loss = 100
for step, batch in enumerate(dl, start=1):
    if step > steps:
        break

    #if loss < threshold_loss:
    #    current_curric_ind += 1
    #    if current_curric_ind < len(curriculum):
    #        difficulty = curriculum[current_curric_ind]
    #        ds.set_difficulty(difficulty)
    #        print(f"Step {step}: Advancing to difficulty level:")
    #        print(f"Operands: {difficulty.max_operands}")
    #        print(f"Digits: {difficulty.max_digits}")
    #        print(f"Operations: {difficulty.operations}")

    idx = batch["idx"].to(device, non_blocking=True)
    targets = batch["targets"].to(device, non_blocking=True)
    ponder_mask = batch["ponder_mask"].to(device, non_blocking=True)

    opt.zero_grad(set_to_none=True)

    task_loss, act_penalty = model(idx, targets=targets, kv_cache=None,
                loss_reduction="mean", ponder_mask=ponder_mask)
    if step <= warmup_steps:
        penalty_scale = min(1.0, step / warmup_steps)

    actual_loss = task_loss + penalty_scale * act_penalty
    actual_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    if step % log_every == 0:
        #pc = getattr(model, "last_ponder_cost", float("nan"))
        es = getattr(model, "last_expected_steps", float("nan"))
        print(f"step {step:5d}  loss={actual_loss.item():.6f} exp_steps={es:.6f} ")

step   200  loss=0.780491 exp_steps=1.229530 
step   400  loss=0.849597 exp_steps=1.233716 
step   600  loss=0.947119 exp_steps=1.436782 
step   800  loss=0.798600 exp_steps=1.458546 
step  1000  loss=0.651810 exp_steps=1.450000 
step  1200  loss=0.700909 exp_steps=1.448931 
step  1400  loss=0.624416 exp_steps=1.472511 
step  1600  loss=0.583983 exp_steps=1.452230 
step  1800  loss=0.658984 exp_steps=1.454299 
step  2000  loss=0.588130 exp_steps=1.464978 
step  2200  loss=0.535189 exp_steps=1.463749 
step  2400  loss=0.589774 exp_steps=1.460211 
step  2600  loss=0.545579 exp_steps=1.472758 
step  2800  loss=0.546046 exp_steps=1.467302 
step  3000  loss=0.556914 exp_steps=1.456343 
step  3200  loss=0.568714 exp_steps=1.447917 
step  3400  loss=0.544180 exp_steps=1.448182 
step  3600  loss=0.529519 exp_steps=1.452602 
step  3800  loss=0.523727 exp_steps=1.467005 
step  4000  loss=0.511341 exp_steps=1.463811 
step  4200  loss=0.500484 exp_steps=1.477832 
step  4400  loss=0.524417 exp_step

KeyboardInterrupt: 

In [74]:
ds.set_difficulty(Difficulty(max_operands=2, max_digits=4, operations=['+']))

In [78]:
model.eval()
with torch.inference_mode():
    input_seq, output_seq = ds[0]
    print("Problem:", "".join(input_seq), "Answer:", "".join(output_seq))
    prompt_ids = tokenizer.encode(input_seq)  # includes '='
    ponder_mask = [0] * len(prompt_ids)
    ponder_mask[-1] = 1
    outs = []
    for tok in model.generate(tokens = prompt_ids, ponder_mask = ponder_mask, max_tokens=40, temperature=0.0):
        if tok == tokenizer.eos_id:
            break
        outs.append(tok)

    print("Model output:", "".join(tokenizer.decode(outs)))

Problem: 3071+-4350= Answer: -1279
Model output: 1988864


In [80]:
None == None

True

In [72]:
model.last_expected_steps

tensor(1.4658, device='cuda:0')

In [63]:
model.last_expected_steps

tensor(1.4658, device='cuda:0')

In [64]:
prompt_ids

[13, 7, 6, 9, 14, 13, 6, 11, 2, 15]