# Transformer in JAX (functions only)

In [None]:
import math
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import random, jit, value_and_grad, tree_util


In [None]:
# --- model / data sizes ---
VOCAB = 512       # number of distinct token ids
SEQ = 64          # tokens per training sequence
DMODEL = 256      # residual-stream width (must be divisible by NHEAD)
NHEAD = 8         # attention heads
DFF = 1024        # hidden width of the feed-forward MLP
NLAYERS = 4       # number of transformer blocks

# --- training loop ---
BATCH = 32        # sequences per step
STEPS = 200       # timed steps run after the warm-up/compile step

# --- AdamW hyper-parameters ---
LR = 3e-4
BETA1 = 0.9
BETA2 = 0.999
EPS = 1e-8
WEIGHT_DECAY = 0.0  # 0.0 skips the decay term entirely in adamw_update


In [None]:
def glorot(key, shape):
    """Sample a Glorot/Xavier-uniform weight array.

    Values are drawn from U(-a, a) with a = sqrt(6 / (shape[0] + shape[1])),
    i.e. the first two dimensions are treated as fan-in and fan-out.
    """
    bound = math.sqrt(6.0 / (shape[0] + shape[1]))
    return random.uniform(key, shape, minval=-bound, maxval=bound)

def zeros(shape):
    """Return a float32 array of zeros with the given shape."""
    return jnp.zeros(shape, dtype=jnp.float32)
def ones(shape):
    """Return a float32 array of ones with the given shape."""
    return jnp.ones(shape, dtype=jnp.float32)

In [None]:
def gelu(x):
    """GELU activation, tanh approximation.

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """
    inner = jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + jnp.tanh(inner))

def layernorm(x, gamma, beta, eps=1e-5):
    """Normalize ``x`` over its last axis, then scale by ``gamma`` and shift by ``beta``.

    Uses the biased (population) variance; ``eps`` is added inside the square root.
    """
    centered = x - x.mean(axis=-1, keepdims=True)
    variance = jnp.mean(centered ** 2, axis=-1, keepdims=True)
    normalized = centered / jnp.sqrt(variance + eps)
    return gamma * normalized + beta

def sinusoidal_positional_embedding(T, D):
    """Return the fixed sine/cosine positional-encoding table, shape [T, D].

    Column pair k (columns 2k and 2k+1) uses frequency w_k = 10000 ** (-2k / D):
    even columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k).
    Odd ``D`` is supported: the last column is a sine with no cosine partner.
    """
    pos = jnp.arange(T, dtype=jnp.float32)[:, None]             # [T, 1]
    i = jnp.arange(0, D, 2, dtype=jnp.float32)[None, :]          # [1, ceil(D/2)]
    div = jnp.exp(-jnp.log(10000.0) * i / D)                     # [1, ceil(D/2)]
    angles = pos * div                                           # [T, ceil(D/2)]
    pe = jnp.zeros((T, D), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    # There are only floor(D/2) odd columns; for odd D that is one fewer than
    # the number of frequencies, so drop the last one for the cosine half.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : D // 2]))
    return pe  # [T, D]

def init_block_params(key, d_model, n_heads, d_ff):
    """Initialise the parameters of a single transformer block.

    Returns a dict with:
      "attn": fused QKV projection Wqkv [D, 3D] and output projection Wo [D, D]
      "mlp":  W1 [D, Dff], b1 [Dff], W2 [Dff, D], b2 [D]
      "ln1"/"ln2": LayerNorm scale "g" (ones) and bias "b" (zeros), each [D]

    ``n_heads`` is accepted but not used: the fused projection shapes do not
    depend on how the width is split into heads.
    """
    k_qkv, k_out, k_up, k_down = random.split(key, 4)
    attn = {
        "Wqkv": glorot(k_qkv, (d_model, 3 * d_model)),
        "Wo": glorot(k_out, (d_model, d_model)),
    }
    mlp = {
        "W1": glorot(k_up, (d_model, d_ff)),
        "b1": zeros((d_ff,)),
        "W2": glorot(k_down, (d_ff, d_model)),
        "b2": zeros((d_model,)),
    }
    ln1 = {"g": ones((d_model,)), "b": zeros((d_model,))}
    ln2 = {"g": ones((d_model,)), "b": zeros((d_model,))}
    return {"attn": attn, "mlp": mlp, "ln1": ln1, "ln2": ln2}

def init_params(key, vocab, d_model, n_heads, d_ff, n_layers):
    """Initialise the full model parameter tree.

    Returns a dict with:
      "tok_emb": [vocab, d_model] embeddings drawn from N(0, 0.02**2)
      "blocks":  list of ``n_layers`` per-block parameter dicts
      "ln_f":    final LayerNorm scale "g" and bias "b"
      "head":    output projection "W" [d_model, vocab] and bias "b" [vocab]
    """
    keys = random.split(key, 3 + n_layers)
    k_tok, k_head = keys[0], keys[1]
    # keys[2] is split off but never consumed; it stays in the split so the
    # per-layer keys (keys[3:]) and hence the sampled weights do not change.
    layer_keys = keys[3:]

    blocks = [init_block_params(k, d_model, n_heads, d_ff) for k in layer_keys]
    return {
        "tok_emb": 0.02 * random.normal(k_tok, (vocab, d_model)),
        "blocks": blocks,
        "ln_f": {"g": ones((d_model,)), "b": zeros((d_model,))},
        "head": {"W": glorot(k_head, (d_model, vocab)), "b": zeros((vocab,))},
    }

def dense(x, W, b=None):
    """Affine map ``x @ W (+ b)``; the bias is skipped when ``b`` is None."""
    if b is None:
        return x @ W
    return x @ W + b

def split_heads(x, n_heads):
    """Reshape [B, T, D] into per-head layout [B, H, T, D // H].

    Head h receives the contiguous feature slice [h * Dh, (h + 1) * Dh).
    """
    batch, seq_len, width = x.shape
    head_dim = width // n_heads
    per_head = x.reshape(batch, seq_len, n_heads, head_dim)  # [B, T, H, Dh]
    return jnp.swapaxes(per_head, 1, 2)                      # [B, H, T, Dh]

def merge_heads(x):
    """Concatenate heads back into features: [B, H, T, Dh] -> [B, T, H * Dh]."""
    batch, n_heads, seq_len, head_dim = x.shape
    time_major = jnp.swapaxes(x, 1, 2)  # [B, T, H, Dh]
    return time_major.reshape(batch, seq_len, n_heads * head_dim)

def mha_causal(x, Wqkv, Wo, n_heads):
    """Causal multi-head self-attention (no projection biases).

    Args:
        x: [B, T, D] input activations.
        Wqkv: [D, 3D] fused query/key/value projection.
        Wo: [D, D] output projection.
        n_heads: number of heads; D must be divisible by it.

    Returns:
        [B, T, D] output in which position t only attends to positions <= t.
    """
    _, seq_len, width = x.shape
    head_dim = width // n_heads

    projected = dense(x, Wqkv)                                   # [B, T, 3D]
    q, k, v = (split_heads(part, n_heads) for part in jnp.split(projected, 3, axis=-1))

    logits = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(head_dim))  # [B, H, T, T]

    # Lower-triangular mask: query t may see keys 0..t. Masked logits get a
    # huge negative value so they vanish after the softmax.
    allowed = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
    logits = jnp.where(allowed, logits, jnp.full_like(logits, -1e30))
    weights = jax.nn.softmax(logits, axis=-1)                    # [B, H, T, T]

    context = merge_heads(weights @ v)                           # [B, T, D]
    return dense(context, Wo)

def transformer_forward(params, tokens, n_heads=None):
    """Run the causal transformer and return per-position vocabulary logits.

    Args:
        params: parameter tree as built by ``init_params``.
        tokens: [B, T] integer token ids.
        n_heads: number of attention heads. ``None`` (the default) uses the
            module-level ``NHEAD``, read at call time, so existing callers see
            the original behavior. The attention weights are stored fused
            ([D, 3D] and [D, D]), so any head count dividing D is valid.

    Returns:
        [B, T, V] unnormalised logits.
    """
    if n_heads is None:
        n_heads = NHEAD
    _, T = tokens.shape
    D = params["tok_emb"].shape[1]
    # Token embedding lookup plus fixed positional encoding ([T, D] broadcasts over B).
    x = params["tok_emb"][tokens] + sinusoidal_positional_embedding(T, D)

    for blk in params["blocks"]:
        # Sub-layer 1: LayerNorm -> causal attention -> residual add.
        h = layernorm(x, blk["ln1"]["g"], blk["ln1"]["b"])
        x = x + mha_causal(h, blk["attn"]["Wqkv"], blk["attn"]["Wo"], n_heads)
        # Sub-layer 2: LayerNorm -> GELU MLP -> residual add.
        mlp = blk["mlp"]
        h = layernorm(x, blk["ln2"]["g"], blk["ln2"]["b"])
        h = dense(gelu(dense(h, mlp["W1"], mlp["b1"])), mlp["W2"], mlp["b2"])
        x = x + h

    x = layernorm(x, params["ln_f"]["g"], params["ln_f"]["b"])
    return dense(x, params["head"]["W"], params["head"]["b"])  # [B, T, V]

def lm_loss(params, tokens):
    """Mean next-token cross-entropy over all B * (T - 1) predicted positions.

    Logits at position t are scored against the token at position t + 1; the
    final position's logits have no target and are dropped.
    """
    predictive = transformer_forward(params, tokens)[:, :-1, :]   # [B, T-1, V]
    log_probs = jax.nn.log_softmax(predictive, axis=-1)
    next_ids = tokens[:, 1:, None]                                 # [B, T-1, 1]
    target_logp = jnp.take_along_axis(log_probs, next_ids, axis=-1)[..., 0]  # [B, T-1]
    nll = -target_logp
    return nll.mean()

def tree_zeros_like(tree):
    """Return a pytree shaped like ``tree`` whose leaves are zeros of matching shape/dtype."""
    return tree_util.tree_map(jnp.zeros_like, tree)


In [None]:

def make_opt_state(params):
    """Create a fresh AdamW state.

    "m" and "v" are zero-filled first/second-moment trees shaped like
    ``params``; "t" is the int32 step counter, starting at 0.
    """
    first_moment = tree_zeros_like(params)
    second_moment = tree_zeros_like(params)
    step = jnp.array(0, dtype=jnp.int32)
    return {"m": first_moment, "v": second_moment, "t": step}

def adamw_update(params, grads, opt_state, lr=LR, beta1=BETA1, beta2=BETA2, eps=EPS, weight_decay=WEIGHT_DECAY):
    """Apply one AdamW step with bias-corrected moment estimates.

    Args:
        params: parameter pytree.
        grads: gradient pytree with the same structure as ``params``.
        opt_state: dict with moment trees "m", "v" and int step counter "t".
        lr, beta1, beta2, eps: Adam hyper-parameters.
        weight_decay: coefficient of the decay term; 0.0 disables it.

    Returns:
        (new_params, new_opt_state) where new_opt_state holds the updated
        moments and the incremented step counter.
    """
    step = opt_state["t"] + 1
    step_f = step.astype(jnp.float32)
    # Bias-correction denominators for the zero-initialised moments.
    m_corr = 1.0 - beta1 ** step_f
    v_corr = 1.0 - beta2 ** step_f

    new_m = tree_util.tree_map(lambda m_old, g: beta1 * m_old + (1.0 - beta1) * g, opt_state["m"], grads)
    new_v = tree_util.tree_map(lambda v_old, g: beta2 * v_old + (1.0 - beta2) * (g * g), opt_state["v"], grads)

    def step_leaf(p, m_leaf, v_leaf):
        m_hat = m_leaf / m_corr
        v_hat = v_leaf / v_corr
        direction = m_hat / (jnp.sqrt(v_hat) + eps)
        if weight_decay != 0.0:
            # Decoupled decay: added to the update direction, not folded into the
            # gradient that feeds the moment estimates.
            direction = direction + weight_decay * p
        return p - lr * direction

    new_params = tree_util.tree_map(step_leaf, params, new_m, new_v)
    return new_params, {"m": new_m, "v": new_v, "t": step}

In [None]:
@jit
def train_step(params, opt_state, rng):
    """Run one compiled training step on a freshly sampled random batch.

    The batch is [BATCH, SEQ] token ids drawn uniformly from [0, VOCAB), so the
    targets carry no learnable structure and the expected loss floor is
    ln(VOCAB); presumably this is meant as a throughput benchmark.

    Returns:
        (new_params, new_opt_state, next_rng, loss)
    """
    next_rng, batch_key = random.split(rng)
    batch = random.randint(batch_key, (BATCH, SEQ), 0, VOCAB, dtype=jnp.int32)

    loss, grads = value_and_grad(lambda p: lm_loss(p, batch))(params)
    params, opt_state = adamw_update(params, grads, opt_state)
    return params, opt_state, next_rng, loss


In [None]:
def main():
    """Initialise the model, compile ``train_step`` and time ``STEPS`` training steps."""
    print(jax.devices())
    key = random.PRNGKey(42)
    params = init_params(key, VOCAB, DMODEL, NHEAD, DFF, NLAYERS)
    opt_state = make_opt_state(params)

    # Warm-up: the first call traces and compiles train_step, then runs one step.
    # perf_counter is monotonic, so wall-clock adjustments cannot skew intervals
    # the way time.time() can.
    t0 = time.perf_counter()
    params, opt_state, key, loss = train_step(params, opt_state, key)
    jax.block_until_ready(loss)
    print(f"Compiled step, initial loss: {float(loss):.4f} (compile took {time.perf_counter()-t0:.2f}s)")

    # Timed loop. JAX dispatches asynchronously, so only the periodic logging
    # points below actually wait for the device.
    t0 = time.perf_counter()
    for s in range(1, STEPS + 1):
        params, opt_state, key, loss = train_step(params, opt_state, key)
        if s % 20 == 0 or s == STEPS:
            jax.block_until_ready(loss)
            print(f"step {s:04d} | loss {float(loss):.4f}")
    # Ensure every queued step, including its parameter update, has finished
    # before the clock is read.
    jax.block_until_ready(params)
    print(f"Done in {time.perf_counter()-t0:.2f}s")
