### Looped Transformer Theoretical Model

This notebook estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc. for our looped (depth-recurrent) transformer architecture: **prelude → recur (×n_loop) → coda**.

Key insight: recur block weights are *shared* across loop iterations, so the model has fewer unique parameters than its effective depth suggests.
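
As a minimal sketch of that distinction, the snippet below counts unique layers versus effective depth for the prelude → recur → coda layout. The helper name `depth_summary` is hypothetical and only used for illustration; the numbers mirror the config used later in this notebook.

```python
def depth_summary(n_prelude, n_block, n_coda, n_loop):
    """Return (unique_layers, effective_layers) for prelude -> recur x n_loop -> coda."""
    # Unique layers: every block that owns its own weights, counted once.
    unique = n_prelude + n_block + n_coda
    # Effective depth: the recur block is executed n_loop times per forward pass.
    effective = n_prelude + n_block * n_loop + n_coda
    return unique, effective


unique, effective = depth_summary(n_prelude=2, n_block=2, n_coda=2, n_loop=4)
print(f"unique layers: {unique}, effective depth: {effective}")
```

Raising `n_loop` grows only the second number, which is why the parameter count and the compute are tracked separately in the rest of the notebook.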

In [1]:
from collections import OrderedDict

In [None]:
# Model config (matches GPTConfig defaults in model.py)
block_size = 1024
vocab_size = 50304  # GPT-2 vocab_size of 50257, padded to nearest multiple of 64
n_prelude = 2       # unique layers run once before looping
n_block = 2         # unique layers in the recurrent block (shared across loops)
n_coda = 2          # unique layers run once after looping
n_loop = 4          # number of recurrence iterations
n_head = 12
n_embd = 768
bias = False
input_injection = "inject"  # "inject", "inject_random", or "passthrough"
bptt_k = None               # truncate backprop to last k recurrences (None = full)
assert not bias, "this notebook assumes bias=False just for simplicity"

In [None]:
def params():
    """estimates the number of unique parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out['embedding/position'] = n_embd * block_size
    out['embedding/token'] = n_embd * vocab_size
    out['embedding'] = out['embedding/position'] + out['embedding/token']

    # per-block parameters (same structure for prelude, recur, coda)
    out['attention/ln'] = n_embd  # RMSNorm weight (no bias)
    out['attention/kqv'] = n_embd * 3 * n_embd
    out['attention/proj'] = n_embd ** 2
    out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']

    ffw_size = 4 * n_embd
    out['mlp/ln'] = n_embd  # RMSNorm weight (no bias)
    out['mlp/ffw'] = n_embd * ffw_size
    out['mlp/proj'] = ffw_size * n_embd
    out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']

    out['block'] = out['attention'] + out['mlp']

    # architecture sections — recur weights are shared across loop iterations
    out['prelude'] = n_prelude * out['block']
    out['recur'] = n_block * out['block']
    out['coda'] = n_coda * out['block']
    out['all_blocks'] = out['prelude'] + out['recur'] + out['coda']

    # looped transformer extras
    out['norm_recur'] = n_embd  # RMSNorm after each recurrence
    if input_injection in ("inject", "inject_random"):
        out['inject'] = 2 * n_embd * n_embd  # Linear(2*n_embd, n_embd, bias=False)
    else:
        out['inject'] = 0
    out['ln_f'] = n_embd  # final RMSNorm
    out['lm_head'] = 0  # weight-tied with embedding/token

    # total
    out['total'] = out['embedding'] + out['all_blocks'] + out['norm_recur'] + out['inject'] + out['ln_f'] + out['lm_head']

    return out

p = params()
params_total = p['total']
n_unique_layers = n_prelude + n_block + n_coda
n_effective_layers = n_prelude + n_block * n_loop + n_coda
print(f"unique layers: {n_unique_layers}, effective depth: {n_effective_layers} (looping {n_block} blocks × {n_loop})")
print(f"total parameters: {params_total:,}\n")
print(f"{'name':24s} {'params':>12s} {'ratio (%)':>10s}")
for k, v in p.items():
    print(f"{k:24s} {v:12,d} {v/params_total*100:10.4f}")

In [None]:
# checkpoint size estimate
# params stored in fp32; AdamW keeps 2 additional fp32 buffers per param (momentum + variance)
params_bytes = params_total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size (full optimizer state, fp32): {params_and_buffers_bytes/1e9:.2f} GB")
print(f"est checkpoint size (weights only, bf16):         {params_total*2/1e9:.2f} GB")

We can also estimate the fraction of GPU memory taken up by just the weights and the AdamW optimizer buffers:

In [None]:
gpu_memory = 80e9  # A100-SXM4-80GB
print(f"memory ratio for parameters + optimizer state: {params_and_buffers_bytes / gpu_memory * 100:.2f}%")

That is a small fraction of memory for a model this tiny; most of the memory goes to activations (forward and backward). The picture changes dramatically as models get larger.
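
To get a feel for how quickly that fraction grows, here is a rough sketch. It assumes about `12 * n_embd**2` parameters per unique block (attention plus MLP, ignoring norms, position embeddings, and the inject layer) and the same "fp32 weights + 2 AdamW buffers" accounting as above. The helper names and the three model sizes are hypothetical, chosen only for illustration; gradients and activations are not included.

```python
def approx_params(n_unique_layers, n_embd, vocab_size=50304):
    """Rough unique-parameter count: 12 * n_embd^2 per block plus a tied token embedding."""
    return 12 * n_unique_layers * n_embd ** 2 + vocab_size * n_embd


def weights_and_adamw_bytes(n_params, bytes_per_value=4):
    """fp32 weights plus AdamW's two fp32 moment buffers (3 values per parameter)."""
    return 3 * n_params * bytes_per_value


gpu_memory = 80e9  # bytes, same 80 GB figure as above
for n_layers, width in [(6, 768), (24, 2048), (48, 4096)]:  # hypothetical sizes
    n = approx_params(n_layers, width)
    fraction = weights_and_adamw_bytes(n) / gpu_memory
    print(f"layers={n_layers:3d} n_embd={width:5d} params={n:15,d} "
          f"weights+AdamW = {fraction * 100:7.2f}% of GPU memory")
```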

Let's estimate FLOPs for a single forward+backward pass. In the looped architecture, the recur blocks and inject layer execute `n_loop` times in the forward pass, but backpropagation may only flow through the last `bptt_k` iterations (truncated BPTT).
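
A minimal sketch of that asymmetry, assuming the usual rule of thumb that a backward pass costs about twice the forward pass. The function name `recur_executions` is hypothetical; the clamping of `bptt_k` to `n_loop` mirrors what the cell below does.

```python
def recur_executions(n_loop, bptt_k=None):
    """Return (forward_runs, backward_runs) of the recur block for one training step."""
    # bptt_k=None means full backprop; otherwise only the last k iterations get gradients.
    k = n_loop if bptt_k is None else min(bptt_k, n_loop)
    return n_loop, k


n_loop = 4
for bptt_k in (None, 2, 1):
    fwd_runs, bwd_runs = recur_executions(n_loop, bptt_k)
    # Cost in units of "one forward execution of the recur block".
    cost = fwd_runs + 2 * bwd_runs
    print(f"bptt_k={bptt_k}: forward x{fwd_runs}, backward x{bwd_runs}, relative cost {cost}")
```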

In [None]:
def flops():
    """estimate FLOPs for one forward+backward pass of a single sequence.

    Accounts for looped architecture: recur blocks + inject run n_loop times
    forward, but only bptt_k times backward (truncated BPTT)."""
    out = OrderedDict()
    head_size = n_embd // n_head
    ffw_size = 4 * n_embd
    _bptt_k = min(bptt_k, n_loop) if bptt_k is not None else n_loop

    # per-block FLOPs (for one execution of one block on one sequence)
    block_attn = (
        2 * block_size * (n_embd * 3 * n_embd)                     # kqv projection
        + 2 * block_size * block_size * n_embd                      # Q @ K^T scores
        + 2 * n_head * (block_size * block_size * head_size)        # attn @ V reduce
        + 2 * block_size * (n_embd * n_embd)                        # output projection
    )
    block_mlp = (
        2 * block_size * (n_embd * ffw_size)                        # ffw up
        + 2 * block_size * (ffw_size * n_embd)                      # ffw down
    )
    block_flops = block_attn + block_mlp
    inject_flops = 2 * block_size * (2 * n_embd * n_embd)           # inject linear

    # forward pass
    out['fwd/prelude'] = n_prelude * block_flops
    out['fwd/recur'] = n_block * n_loop * block_flops
    out['fwd/coda'] = n_coda * block_flops
    out['fwd/inject'] = n_loop * inject_flops if input_injection in ("inject", "inject_random") else 0
    out['fwd/dense'] = 2 * block_size * (n_embd * vocab_size)
    out['forward_total'] = sum(v for k, v in out.items() if k.startswith('fwd/'))

    # backward pass (2× forward FLOPs, but recur/inject only through bptt_k iterations)
    out['bwd/prelude'] = 2 * n_prelude * block_flops
    out['bwd/recur'] = 2 * n_block * _bptt_k * block_flops
    out['bwd/coda'] = 2 * n_coda * block_flops
    out['bwd/inject'] = 2 * _bptt_k * inject_flops if input_injection in ("inject", "inject_random") else 0
    out['bwd/dense'] = 2 * out['fwd/dense']
    out['backward_total'] = sum(v for k, v in out.items() if k.startswith('bwd/'))

    out['total'] = out['forward_total'] + out['backward_total']
    return out

f = flops()
print(f"{'name':20s} {'flops':>18s} {'ratio (% of fwd)':>18s}")
for k, v in f.items():
    print(f"{k:20s} {v:18,d} {v/f['forward_total']*100:18.4f}")

In [None]:
def model_flops_per_fwdbwd():
    """estimate total FLOPs per forward+backward pass using the PaLM-style formula,
    adapted for looped architecture. This matches model.estimate_mfu() logic."""
    p = params()
    _bptt_k = min(bptt_k, n_loop) if bptt_k is not None else n_loop
    H, Q, T = n_head, n_embd // n_head, block_size

    # split params into "run once" vs "reused per loop iteration"
    # N = non-position-embedding params (matches model.get_num_params())
    N = p['total'] - p['embedding/position']
    once_params = N - p['recur'] - p['inject']
    reused_params = p['recur'] + p['inject']

    # matmul FLOPs: 2× per param per token fwd, 4× per param per token bwd
    fwd_matmul = 2 * (once_params + reused_params * n_loop)
    bwd_matmul = 4 * (once_params + reused_params * _bptt_k)
    matmul_flops = fwd_matmul + bwd_matmul

    # attention FLOPs (Q@K^T and attn@V, not captured in param count)
    attn_fwd = 2 * (n_prelude + n_block * n_loop + n_coda) * (2 * H * Q * T)
    attn_bwd = 4 * (n_prelude + n_block * _bptt_k + n_coda) * (2 * H * Q * T)
    attn_flops = attn_fwd + attn_bwd

    return (matmul_flops + attn_flops) * T

mf = model_flops_per_fwdbwd()
detailed = flops()['total']
print(f"PaLM-style estimate: {mf:,d}")
print(f"detailed estimate:   {detailed:,d}")
print(f"ratio: {mf/detailed:.4f}")

The two estimates are close, giving confidence in our FLOPs math. The small discrepancy comes from the PaLM-style formula treating all parameters uniformly, so it also counts the RMSNorm weights as if they were matmul weights, while the detailed count only tracks weight-matrix FLOPs.
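
Written out, the per-token estimate computed in `model_flops_per_fwdbwd()` is the following, where $k$ stands for `min(bptt_k, n_loop)` (or `n_loop` when `bptt_k` is `None`), $H$ = `n_head`, $Q$ = `n_embd // n_head`, and $T$ = `block_size`. The symbols $N_{\text{once}}$ and $N_{\text{reused}}$ are shorthand for `once_params` and `reused_params`:

$$
\begin{aligned}
F_{\text{token}} &= 2\,(N_{\text{once}} + n_{\text{loop}}\,N_{\text{reused}}) + 4\,(N_{\text{once}} + k\,N_{\text{reused}}) \\
&\quad + 4\,HQT\,L_{\text{fwd}} + 8\,HQT\,L_{\text{bwd}}, \\
L_{\text{fwd}} &= n_{\text{prelude}} + n_{\text{block}}\,n_{\text{loop}} + n_{\text{coda}}, \\
L_{\text{bwd}} &= n_{\text{prelude}} + n_{\text{block}}\,k + n_{\text{coda}}.
\end{aligned}
$$

The function returns $F_{\text{token}} \cdot T$. With $n_{\text{loop}} = k = 1$ this collapses to the familiar $6N + 12\,L\,H\,Q\,T$ with $N = N_{\text{once}} + N_{\text{reused}}$.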

Now let's estimate model FLOPs utilization (MFU). A100-SXM4-80GB is cited at 312 TFLOPS bfloat16 on tensor cores.

In [None]:
# plug in your measured values here
batch_size = 100  # total batch size (micro_batch × grad_accum)
measured_time = 1.0  # seconds per iteration — update with your measurement!

measured_throughput = batch_size / measured_time
flops_achieved = f['total'] * measured_throughput

a100_flops_promised = 312e12  # A100 bf16 tensor-core peak: 312 TFLOPS, expressed in FLOP/s
print(f"MFU: {flops_achieved / a100_flops_promised * 100:.2f}%")
print("(update batch_size and measured_time with your actual numbers)")

For reference, we'd like to be at roughly 50% or higher, and not just on a single GPU but across an entire DDP run. The MFU printed above only means something once `batch_size` and `measured_time` hold measured values; with the placeholders it says nothing about how far we are from what this GPU can achieve.
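
For a multi-GPU run, one way to compute the same ratio is against the aggregate peak of all devices. The sketch below assumes `sequences_per_iter` is the global batch summed over all GPUs and that every GPU has the same peak; the function name `mfu` and its parameters are hypothetical and not taken from `model.py`.

```python
def mfu(flops_per_sequence, sequences_per_iter, seconds_per_iter,
        num_gpus=1, peak_flops_per_gpu=312e12):
    """Model FLOPs utilization: achieved FLOP/s divided by aggregate peak FLOP/s."""
    achieved = flops_per_sequence * sequences_per_iter / seconds_per_iter
    return achieved / (num_gpus * peak_flops_per_gpu)


# Toy numbers chosen so the arithmetic is easy to check by hand:
# 2 sequences of 312e12 FLOPs each in 1 second on 4 GPUs is half of the aggregate peak.
print(f"MFU: {mfu(312e12, 2, 1.0, num_gpus=4) * 100:.1f}%")
```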

In [None]:
# training cost estimate
tokens_num = 300e9  # dataset size in tokens
num_gpus = 2  # 2×A100-SXM4-80GB
a100_flops = 312e12
assumed_mfu = 0.3
flops_throughput = a100_flops * num_gpus * assumed_mfu

# for looped models, naive 6ND underestimates because each token goes through
# the recur block n_loop times. Use our detailed FLOPs calculation instead.
flops_per_token = flops()['total'] / block_size
total_flops = flops_per_token * tokens_num
time_needed_s = total_flops / flops_throughput
print(f"training time estimate: {time_needed_s/3600/24:.2f} days")
print(f"  ({tokens_num/1e9:.0f}B tokens, {num_gpus}×A100-80GB, {assumed_mfu*100:.0f}% MFU)")

# compare with naive 6ND (ignores weight reuse from looping)
N = params()['total'] - params()['embedding/position']
naive_6nd_flops = 6 * N * tokens_num
naive_time = naive_6nd_flops / flops_throughput
print(f"\nnaive 6ND estimate: {naive_time/3600/24:.2f} days")
print(f"  (underestimates looped models by {total_flops/naive_6nd_flops:.1f}×)")

The 6ND formula assumes each token costs 6 FLOPs per unique parameter (2 in the forward pass, 4 in the backward pass). In a looped model each token passes through the recur block `n_loop` times, so the actual compute per token is higher than the unique parameter count suggests. 6ND also leaves out the attention-score FLOPs, which the detailed count includes. This is the whole point of looping: we get more "effective depth" (compute) from fewer parameters.
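
A loop-aware analogue of 6N can be sketched by weighting the reused parameters by how often they execute. The function name and the toy parameter counts below are hypothetical; like plain 6ND, the sketch ignores attention-score FLOPs.

```python
def looped_matmul_flops_per_token(n_once, n_reused, n_loop, bptt_k=None):
    """Loop-aware analogue of 6N: 2 FLOPs/param forward, 4 FLOPs/param backward.

    n_once:   parameters executed once per token
    n_reused: parameters executed once per loop iteration
    """
    k = n_loop if bptt_k is None else min(bptt_k, n_loop)
    forward = 2 * (n_once + n_reused * n_loop)
    backward = 4 * (n_once + n_reused * k)
    return forward + backward


# Toy parameter counts, chosen only to make the arithmetic easy to follow.
n_once, n_reused = 100, 10
naive = 6 * (n_once + n_reused)
for n_loop in (1, 4, 16):
    looped = looped_matmul_flops_per_token(n_once, n_reused, n_loop)
    print(f"n_loop={n_loop:2d}: looped/naive = {looped / naive:.2f}")
```

With `n_loop=1` and full backprop the function reduces to exactly 6N, so any ratio above 1 is attributable to weight reuse.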

FLOPs are just one constraint; the other one we have to track closely is memory bandwidth. TODO: estimate the LOAD/STORE costs of our model.