In [None]:
import math, numpy as np

# Seed the global NumPy RNG so initialisation, batching and sampling repeat.
np.random.seed(42)

# -----------------------
# Hyperparameters
# -----------------------
T = 64        # context length (tokens per window)
d = 64        # model / embedding width
d_ff = 128    # hidden width of the MLP
V = None      # vocabulary size; assigned once the vocab is built
lr = 0.01     # step size handed to sgd()
eps = 1e-5    # added under the square root in RMSNorm

# -----------------------
# Data: tiny character corpus
# -----------------------
# Training corpus: three short sentences, each followed by a single space.
text = (
    "tiny gpt built with pure numpy. "
    "this is a tiny demo to learn the math of transformers. "
    "it can overfit a short text. "
)

# Character-level vocabulary: ids follow the sorted order of the unique chars.
chars = sorted(set(text))
stoi = dict(zip(chars, range(len(chars))))   # char -> id
itos = dict(enumerate(chars))                # id -> char
V = len(chars)

def encode(s):
    """Map every character of `s` to its vocabulary id.

    Returns a 1-D int64 array; a character missing from `stoi` raises KeyError.
    """
    ids = [stoi[ch] for ch in s]
    return np.array(ids, dtype=np.int64)
def decode(ids):
    """Turn an iterable of token ids back into the string they spell."""
    return "".join(itos[int(tok)] for tok in ids)

# Whole corpus as a 1-D array of token ids, and its length in tokens.
data = encode(text)
n = data.shape[0]

# batches: random contiguous chunks
def get_batch(bs=16):
    """Draw `bs` random contiguous training windows from the corpus.

    Args:
        bs: number of windows in the batch.

    Returns:
        (x, y): two (bs, T) int arrays. `x` holds T consecutive token ids and
        `y` holds the same window shifted one position to the right, i.e. the
        next-character target for every position of `x`.

    Raises:
        ValueError: if the corpus has fewer than T + 1 tokens, so that no
            (input, target) window fits.
    """
    # A start index i is valid iff the target slice data[i+1 : i+T+1] fits,
    # i.e. i <= n - T - 1.  randint's upper bound is exclusive, so pass n - T;
    # passing n - T - 1 would never sample the last window of the corpus.
    n_starts = n - T
    if n_starts <= 0:
        raise ValueError(
            f"corpus has {n} tokens but at least {T + 1} are needed for T={T}"
        )
    ix = np.random.randint(0, n_starts, size=bs)
    positions = ix[:, None] + np.arange(T)[None, :]   # (bs, T) absolute indices
    x = data[positions]                                # (B,T)
    y = data[positions + 1]                            # next chars
    return x, y

# -----------------------
# Parameters (single block GPT)
# -----------------------
def randn(shape, scale=0.02):
    """Return a float32 array of the given shape drawn from N(0, 1) times `scale`.

    Draws come from the global NumPy RNG, so results depend on np.random.seed.
    """
    noise = np.random.randn(*shape)
    noise32 = noise.astype(np.float32)
    return noise32 * scale

# Token embedding table and learnable absolute position embeddings
E = randn((V, d))
P = randn((T, d))

# Attention projections (all d x d), drawn in the order Q, K, V, O
W_Q, W_K, W_V, W_O = (randn((d, d)) for _ in range(4))

# RMSNorm gains, initialised to one.
# NOTE(review): g_attn does not appear to be referenced anywhere else in the
# code visible here (only g_mlp is) -- confirm before relying on or removing it.
g_attn = np.ones(d, dtype=np.float32)
g_mlp = np.ones(d, dtype=np.float32)

# Two-layer MLP: d -> d_ff -> d
W1 = randn((d, d_ff))
W2 = randn((d_ff, d))

# LM head: projects the final hidden state to vocabulary logits
WLM = randn((d, V))

# -----------------------
# Utilities
# -----------------------
def gelu(x):
    """GELU activation using the tanh approximation.

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    cubic = x + 0.044715 * x**3
    gate = np.tanh(np.sqrt(2/np.pi) * cubic)
    return 0.5 * x * (1.0 + gate)

def dgelu(x):
    """Elementwise derivative of the tanh-approximated GELU.

    With u = sqrt(2/pi) * (x + 0.044715 x^3) and t = tanh(u):
        d/dx [0.5 x (1 + t)] = 0.5 (1 + t) + 0.5 x * du/dx * (1 - t^2)
    """
    c = np.sqrt(2/np.pi)
    t = np.tanh(c*(x + 0.044715*x**3))
    du_dx = c*(1 + 3*0.044715*x**2)
    sech2 = 1 - t**2
    dt_dx = du_dx * sech2
    return 0.5 * (1 + t) + 0.5 * x * dt_dx

def rmsnorm_fwd(h, g):
    """RMSNorm over the last axis, followed by a per-feature gain.

    Args:
        h: (..., d) activations.
        g: (d,) gain vector.

    Returns:
        (out, cache): `out = g * h / sqrt(mean(h^2) + eps)` and the tuple
        `(h, g, rms, h_hat)` that rmsnorm_bwd needs.
    """
    mean_sq = np.mean(h*h, axis=-1, keepdims=True)
    rms = np.sqrt(mean_sq + eps)  # (...,1)
    h_hat = h / rms
    return h_hat * g, (h, g, rms, h_hat)

def rmsnorm_bwd(dout, cache):
    """Backward pass of rmsnorm_fwd.

    Args:
        dout: gradient w.r.t. the RMSNorm output, same shape as `h`.
        cache: the `(h, g, rms, h_hat)` tuple returned by the forward pass.

    Returns:
        (dh, dg): gradient w.r.t. the input `h` (same shape as `h`) and
        w.r.t. the gain `g` (summed over every leading axis).
    """
    h, g, rms, h_hat = cache

    # out = h_hat * g
    grad_h_hat = dout * g

    # h_hat = h / rms with rms = sqrt(mean(h^2) + eps), hence
    #   dL/dh = grad_h_hat / rms - h * mean(h * grad_h_hat) / rms^3
    h_dot_grad = np.mean(h * grad_h_hat, axis=-1, keepdims=True)
    drms = h_dot_grad / rms
    dh = grad_h_hat / rms - (h * drms) / (rms * rms)

    # The gain is shared by every position: reduce over all axes but the last.
    lead_axes = tuple(range(dout.ndim - 1))
    dg = np.sum(dout * h_hat, axis=lead_axes)
    return dh, dg

def softmax(x):
    """Numerically stable softmax over the last axis; the result is float32."""
    # Subtracting the row maximum leaves the result unchanged but avoids overflow.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted, dtype=np.float32)
    return exps / exps.sum(axis=-1, keepdims=True)

def cross_entropy(logits, targets):
    """Mean cross-entropy between softmax(logits) and integer targets.

    Args:
        logits: (B, T, V) unnormalised scores.
        targets: (B, T) integer class ids.

    Returns:
        (loss, probs): the negative log-probability of the target class
        averaged over all B*T positions, and the (B, T, V) softmax output.
    """
    n_batch, n_steps, _ = logits.shape
    probs = softmax(logits)
    # Broadcast (B,1) row ids against (1,T) column ids to pick probs[b, t, targets[b, t]].
    rows = np.arange(n_batch)[:, None]
    cols = np.arange(n_steps)[None, :]
    picked = probs[rows, cols, targets]
    loss = -np.log(picked).mean()
    return loss, probs

def causal_mask(T):
    """Return a (T, T) boolean matrix that is True strictly above the diagonal.

    Entry [i, j] is True exactly when j > i.
    """
    idx = np.arange(T)
    return idx[None, :] > idx[:, None]

# Precomputed (T, T) boolean mask, True strictly above the diagonal; the
# attention code overwrites the scores at True positions with -1e9.
mask = causal_mask(T)

# -----------------------
# Forward & Backward pass (one block)
# -----------------------
def forward_backward(x, y):
    """Run one forward pass and the hand-derived backward pass of the model.

    The model reads the module-level parameter arrays and computes

        H0     = E[x] + P
        H1     = H0 + softmax(mask(Q K^T / sqrt(d))) Vv W_O   (single head,
                 no normalisation in front of the attention sub-layer)
        H2     = H1 + gelu(RMSNorm(H1; g_mlp) W1) W2          (pre-norm MLP)
        logits = H2 WLM

    Args:
        x: (B, T) integer token ids. The sequence length must equal the global
           T because the position table P and the causal mask are sized by it.
        y: (B, T) integer target ids (next token for every position of x).

    Returns:
        (loss, grads): the cross-entropy averaged over all B*T positions, and
        a dict mapping each parameter name ("E", "P", "W_Q", "W_K", "W_V",
        "W_O", "g_mlp", "W1", "W2", "WLM") to the gradient of the loss with
        respect to that parameter (same shape as the parameter).
    """
    # x,y: (B,T)
    B = x.shape[0]

    # ----------------- Forward -----------------
    # Embedding lookup
    X = E[x]                       # (B,T,d)
    H0 = X + P[None, :, :]         # add positions

    # --- Attention sub-layer
    Q = H0 @ W_Q                   # (B,T,d)
    K = H0 @ W_K                   # (B,T,d)
    Vv = H0 @ W_V                  # (B,T,d)

    # Scaled dot-product scores; future positions get a large negative score
    # so that they receive (numerically) zero attention weight.
    S = (Q @ K.transpose(0,2,1)) / np.sqrt(d)  # (B,T,T)
    S = np.where(mask[None, :, :], -1e9, S)    # causal

    A = softmax(S)                 # (B,T,T)
    O = A @ Vv                     # (B,T,d)
    H_attn = O @ W_O               # (B,T,d)

    H1 = H0 + H_attn               # residual

    # --- MLP block with pre-norm (RMSNorm)
    H1n, cache_rms = rmsnorm_fwd(H1, g_mlp)  # (B,T,d)
    M1 = H1n @ W1                  # (B,T,d_ff)
    M2 = gelu(M1)                  # (B,T,d_ff)
    M3 = M2 @ W2                   # (B,T,d)
    H2 = H1 + M3                   # residual

    # LM head
    logits = H2 @ WLM              # (B,T,V)

    # Loss
    loss, probs = cross_entropy(logits, y)

    # ----------------- Backward -----------------
    # d loss / d logits = (softmax - onehot(y)) / (B*T).
    # probs is not needed afterwards, so it is reused in place as dlogits.
    dlogits = probs
    dlogits[np.arange(B)[:,None], np.arange(T)[None,:], y] -= 1
    dlogits /= (B*T)

    dWLM = (H2.reshape(-1, d).T @ dlogits.reshape(-1, V))
    dH2 = dlogits @ WLM.T
    # H2 = H1 + M3: the gradient flows unchanged into both branches.
    # dH1 is copied because it is accumulated into below.
    dH1 = dH2.copy()
    dM3 = dH2

    # MLP back
    dW2 = (M2.reshape(-1, d_ff).T @ dM3.reshape(-1, d))
    dM2 = dM3 @ W2.T
    dM1 = dM2 * dgelu(M1)
    dW1 = (H1n.reshape(-1, d).T @ dM1.reshape(-1, d_ff))
    dH1n = dM1 @ W1.T

    # RMSNorm back
    dH1_rms, d_g_mlp = rmsnorm_bwd(dH1n, cache_rms)

    dH1 += dH1_rms  # total gradient at H1: skip path + path through the MLP

    # Attention back: H_attn = (A @ Vv) @ W_O
    dH_attn = dH1
    dW_O = (O.reshape(-1, d).T @ dH_attn.reshape(-1, d))
    dO = dH_attn @ W_O.T

    dA = dO @ Vv.transpose(0,2,1)        # (B,T,T)
    dVv = A.transpose(0,2,1) @ dO        # (B,T,d)

    # Softmax backward. Per row the Jacobian is J = diag(p) - p p^T, so
    #   dS = J^T dA = p * (dA - sum(dA * p))
    # Note the leading factor p (= A) multiplies the whole bracket.
    sum_dA_A = np.sum(dA * A, axis=-1, keepdims=True)
    dS = A * (dA - sum_dA_A)

    # Masked scores were replaced by a constant, so no gradient reaches them.
    dS = np.where(mask[None,:,:], 0.0, dS)

    # S = (Q K^T)/sqrt(d)
    dQ = (dS @ K) / np.sqrt(d)          # (B,T,d)
    dK = (dS.transpose(0,2,1) @ Q) / np.sqrt(d)

    # Q = H0 W_Q, etc.
    dW_Q = (H0.reshape(-1, d).T @ dQ.reshape(-1, d))
    dW_K = (H0.reshape(-1, d).T @ dK.reshape(-1, d))
    dW_V = (H0.reshape(-1, d).T @ dVv.reshape(-1, d))

    dH0 = dQ @ W_Q.T
    dH0 += dK @ W_K.T
    dH0 += dVv @ W_V.T

    # Residual from H1 = H0 + H_attn
    dH0 += dH1

    # Back to embeddings/positions: P is shared across the batch, E across
    # every occurrence of a token.
    dP = np.sum(dH0, axis=0)  # (T,d)
    dX = dH0                  # (B,T,d)

    # E lookup scatter-add. np.add.at accumulates repeated indices correctly
    # (plain fancy-index "+=" would count each token id only once).
    dE = np.zeros_like(E)
    np.add.at(dE, x, dX)

    grads = {
        "E": dE, "P": dP,
        "W_Q": dW_Q, "W_K": dW_K, "W_V": dW_V, "W_O": dW_O,
        "g_mlp": d_g_mlp,
        "W1": dW1, "W2": dW2,
        "WLM": dWLM
    }
    return loss, grads

# -----------------------
# Update (SGD)
# -----------------------
def sgd(param, grad, lr):
    """Apply one plain SGD step to `param` in place: param <- param - lr * grad.

    The array object is mutated (not rebound), so every reference to it sees
    the update. Returns None.
    """
    step = lr * grad
    np.subtract(param, step, out=param)

# -----------------------
# Training loop
# -----------------------
def train(steps=500, bs=16, print_every=50):
    """Train the model with plain SGD on random batches.

    Args:
        steps: number of optimisation steps.
        bs: batch size passed to get_batch.
        print_every: the loss is printed whenever the 1-based step number is
            a multiple of this value.
    """
    # (parameter array, key into the gradient dict), in update order.
    # sgd() mutates the arrays in place, so these references stay valid for
    # the whole run.
    schedule = (
        (E, "E"),
        (P, "P"),
        (W_Q, "W_Q"),
        (W_K, "W_K"),
        (W_V, "W_V"),
        (W_O, "W_O"),
        (g_mlp, "g_mlp"),
        (W1, "W1"),
        (W2, "W2"),
        (WLM, "WLM"),
    )

    for step in range(1, steps + 1):
        xb, yb = get_batch(bs)
        loss, grads = forward_backward(xb, yb)

        for param, key in schedule:
            sgd(param, grads[key], lr)

        if step % print_every == 0:
            print(f"step {step}: loss {loss:.4f}")

def sample(prefix="tiny ", max_new=200, temperature=1.0, top_k=None):
    """Autoregressively sample `max_new` characters that continue `prefix`.

    Args:
        prefix: seed text. Every character must be in the vocabulary
            (encode raises KeyError otherwise).
        max_new: number of characters to generate.
        temperature: the logits are divided by max(1e-8, temperature) before
            the softmax.
        top_k: when given and 0 < top_k < vocabulary size, only the top_k
            highest-scoring characters remain eligible. Any other value
            leaves the distribution unfiltered.

    Returns:
        The complete text: `prefix` followed by all `max_new` sampled
        characters. Only the model input is limited to the last T tokens;
        the returned text is never truncated.
    """
    # Full running sequence (prefix + everything generated so far).
    ids = [int(i) for i in encode(prefix)]

    for _ in range(max_new):
        # The model only ever sees the most recent T tokens.
        window = np.array(ids[-T:], dtype=np.int64)
        n_ctx = len(window)

        # Right-pad with token id 0 up to length T. The causal mask keeps the
        # real positions from attending to the padding that follows them.
        x = np.zeros((1, T), dtype=np.int64)
        x[0, :n_ctx] = window

        # Forward until logits (same computation as training, no gradients)
        X = E[x] + P[None, :, :]
        Q = X @ W_Q
        K = X @ W_K
        Vv = X @ W_V
        S = (Q @ K.transpose(0, 2, 1)) / np.sqrt(d)
        S = np.where(mask[None, :, :], -1e9, S)
        A = softmax(S)
        O = A @ Vv
        H1 = X + (O @ W_O)
        # pre-norm MLP
        H1n, _ = rmsnorm_fwd(H1, g_mlp)
        H2 = H1 + gelu(H1n @ W1) @ W2
        logits = H2 @ WLM

        # Logits at the last real token. With an empty prefix n_ctx is 0 and
        # index -1 selects the final position of the all-padding context.
        last = logits[0, n_ctx - 1]

        # temperature (the division also yields a fresh array we may modify)
        logits_t = last / max(1e-8, temperature)

        # top-k: push everything outside the k best logits to -1e9.
        # Skipped when top_k covers the whole vocabulary, which would
        # otherwise make argpartition raise for top_k > vocab size.
        vocab = logits_t.shape[0]
        if top_k is not None and 0 < top_k < vocab:
            keep = np.argpartition(logits_t, -top_k)[-top_k:]
            drop = np.ones(vocab, dtype=bool)
            drop[keep] = False
            logits_t[drop] = -1e9

        probs = softmax(logits_t[None, :])[0]
        next_id = np.random.choice(len(probs), p=probs)
        ids.append(int(next_id))

    return decode(ids)

if __name__ == "__main__":
    # Report the character vocabulary built from the corpus.
    print("Vocab size:", V, "Unique chars:", "".join(chars))
    # 400 SGD steps on batches of 16 windows, printing the loss every 50 steps.
    train(steps=400, bs=16, print_every=50)
    # Sample 200 characters continuing "tiny ", restricted to the 20 most
    # likely characters at each step.
    print(sample(prefix="tiny ", max_new=200, temperature=0.9, top_k=20))

Vocab size: 24 Unique chars:  .abcdefghilmnoprstuvwxy
step 50: loss 3.1769
step 100: loss 3.1761
step 150: loss 3.1752
step 200: loss 3.1747
step 250: loss 3.1740
step 300: loss 3.1721
step 350: loss 3.1717
step 400: loss 3.1708
tiny vmhg  avvmcaexivmnp.huigngnay ld.sytfl mar. ux.drnxh  ousvctvvhnnnesremgtplpdcc ryexpc tpcfucvbimhedrgpvvbpodgeovwrnyplwbbtw.uhpaywhwwixyc.en.aytwygrepcb.oxamplxb godbvbhyrnhlrvpt. tewmhrtrsmxwftwonpb
