
# NumPy Transformer Attention: Causal Masking, KV Cache, RoPE, GQA, and FlashAttention-Style Streaming Softmax

This notebook builds a minimal-yet-complete transformer attention stack **purely in NumPy**, suitable for demos and experiments:

1. **Core attention (softmax, masking)**
2. **Causal self-attention & KV cache**
3. **LayerNorm, MLP, Transformer block**
4. **Rotary Positional Embeddings (RoPE)**
5. **Grouped Query Attention (GQA)**
6. **FlashAttention-style streaming softmax**
7. **Integration into MHA + Decoder + simple LLM.generate**

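Sanity tests for the individual components are collected in the "Sanity tests" section. You can toggle RoPE, GQA, and the FlashAttention-style path through `AttnConfig` in the configuration section near the bottom.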


In [1]:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from tqdm import tqdm


ModuleNotFoundError: No module named 'tqdm'

## Core utilities: softmax and causal mask

In [None]:

def softmax(x, axis=-1):
    """
    Numerically stable softmax along `axis`.

    The maximum along `axis` is subtracted before exponentiating so that large
    inputs do not overflow; the result sums to 1 along `axis`.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=axis, keepdims=True)
    return weights / normaliser

def causal_mask(q_len, kv_len, dtype=bool):
    """
    Build a (q_len, kv_len) causal mask where True marks a disallowed key.

    The queries are aligned with the END of the key sequence: the last query
    may see every key, the one before it every key but the last, and so on.
    This covers plain self-attention (kv_len == q_len) as well as KV caching
    (kv_len >= q_len).
    """
    rows = np.arange(q_len).reshape(-1, 1)        # (q, 1) query index
    cols = np.arange(kv_len).reshape(1, -1)       # (1, kv) key index
    aligned_cols = cols - (kv_len - q_len)        # key index in query coordinates
    future = aligned_cols > rows                  # key lies after the query
    return future.astype(dtype)


## Rotary Positional Embeddings (RoPE)

In [None]:

def _rope_freqs(head_dim, seq_len, base=10000.0, dtype=np.float32):
    """
    Build the RoPE cos/sin lookup tables for positions 0 .. seq_len - 1.

    Returns (cos, sin), each of shape (seq_len, head_dim). Frequency i is
    base ** (-i / (head_dim / 2)) and is stored twice, in the adjacent
    columns 2i and 2i + 1, matching the pair layout of `_rope_rotate_half`.
    """
    assert head_dim % 2 == 0, "RoPE requires even head_dim"
    n_freq = head_dim // 2
    exponents = np.arange(0, n_freq, dtype=dtype) / n_freq
    inv_freq = 1.0 / (base ** exponents)
    positions = np.arange(seq_len, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]          # (seq, half)

    # Duplicate every column so columns (2i, 2i+1) share one frequency.
    cos_table = np.repeat(np.cos(angles), 2, axis=1).astype(dtype, copy=False)
    sin_table = np.repeat(np.sin(angles), 2, axis=1).astype(dtype, copy=False)
    return cos_table, sin_table

def _rope_rotate_half(x):
    """
    Map every adjacent feature pair (a, b) on the last axis to (-b, a).

    This is the 90-degree rotation partner used by RoPE: combined as
    x * cos + _rope_rotate_half(x) * sin it rotates each pair by the angle.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Stack as (..., half, 2) and flatten the pair axis back into the features.
    rotated_pairs = np.stack((-odds, evens), axis=-1)
    return rotated_pairs.reshape(x.shape)

def apply_rope(q, k, pos_q, pos_k=None, base=10000.0):
    """
    Apply rotary positional embeddings (RoPE) to q and k.

    Every adjacent feature pair (2i, 2i+1) of a vector at position p is
    rotated by the angle p * base ** (-i / (head_dim / 2)).

    The angles are computed directly from the supplied positions rather than
    looked up in a table sized by the tensor length, so absolute positions may
    exceed q_len / kv_len. This is required for KV-cached decoding, where a
    single query (q_len == 1) sits at a large absolute position.

    q: (n_head, q_len, head_dim)
    k: (n_head, kv_len, head_dim)
    pos_q: (q_len,) absolute positions of the query rows
    pos_k: (kv_len,) absolute positions of the key rows
           (defaults to range(kv_len))
    base: RoPE frequency base
    Returns rotated (q, k) with the same shapes as the inputs.
    Raises AssertionError for an odd head_dim and ValueError when a position
    vector does not have one entry per sequence row.
    """
    def _rotate(x, positions):
        seq_len, head_dim = x.shape[-2], x.shape[-1]
        assert head_dim % 2 == 0, "RoPE requires even head_dim"
        positions = np.asarray(positions)
        if positions.shape != (seq_len,):
            raise ValueError(
                f"expected {seq_len} positions, got shape {positions.shape}"
            )
        # Keep the input's float precision; integer inputs are promoted so the
        # rotation is not truncated.
        dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
        half = head_dim // 2
        inv_freq = 1.0 / (base ** (np.arange(0, half, dtype=dtype) / half))
        angles = positions.astype(dtype)[:, None] * inv_freq[None, :]  # (seq, half)
        cos = np.cos(angles)
        sin = np.sin(angles)

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        out = np.empty(x.shape, dtype=dtype)
        # 2-D rotation of each (even, odd) pair by its angle.
        out[..., 0::2] = x_even * cos - x_odd * sin
        out[..., 1::2] = x_odd * cos + x_even * sin
        return out

    if pos_k is None:
        pos_k = np.arange(k.shape[1], dtype=np.int64)

    return _rotate(q, pos_q), _rotate(k, pos_k)


## Attention primitives: dense, FlashAttention-style, and GQA

In [None]:

def single_headed_attention(q, k, v):
    """
    Causal scaled dot-product attention for one head.

    The queries are treated as the LAST q rows of the key sequence, so the
    same routine serves plain self-attention and KV-cached decoding.
    q: (q, d) k: (kv, d) v: (kv, d)
    returns: (q, d)
    """
    n_query, head_dim = q.shape[0], q.shape[-1]
    blocked = causal_mask(n_query, k.shape[0])           # (q, kv), True = hidden
    logits = (q @ k.T) / np.sqrt(head_dim)               # (q, kv)
    # A large negative logit gives masked keys ~zero weight after softmax.
    logits = np.where(blocked, -1e9, logits)
    weights = softmax(logits, axis=-1)                   # (q, kv)
    return weights @ v                                   # (q, d)

def multi_headed_attention(q, k, v):
    """
    Causal scaled dot-product attention over all heads at once.

    q: (n, q, d)  k: (n, kv, d)  v: (n, kv, d)
    returns: (n, q, d)
    """
    _, n_query, head_dim = q.shape
    n_key = k.shape[1]
    blocked = causal_mask(n_query, n_key)[None, :, :]                  # (1, q, kv)
    logits = np.einsum('nqd,nkd->nqk', q, k) / np.sqrt(head_dim)       # (n, q, kv)
    # The single mask broadcasts over the head axis.
    weights = softmax(np.where(blocked, -1e9, logits), axis=-1)
    return np.einsum('nqk,nkd->nqd', weights, v)

def attention_flash(q, k, v, block_size=128):
    """
    Single-head causal attention with streaming softmax (FlashAttention style).

    The keys/values are consumed in blocks of `block_size` rows. For every
    query row a running maximum `m`, a running normaliser `l` and an
    un-normalised output accumulator `out` are kept; whenever the maximum
    grows, the previously accumulated terms are rescaled by
    exp(m_old - m_new). All block terms are exponentiated relative to the
    running maximum, so `l` and `out` always share the same reference point.
    (Exponentiating relative to the block-local maximum without rescaling the
    value contribution makes the result disagree with dense attention.)

    The queries are aligned with the END of the key sequence, matching the
    dense implementations in this file.

    q: (q, d) k: (kv, d) v: (kv, d)   (kv is expected to be >= 1)
    block_size: number of key/value rows processed per step, must be >= 1
    returns: (q, d)
    raises: ValueError if block_size < 1
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    qlen, d = q.shape
    kvlen = k.shape[0]
    scale = np.sqrt(d)

    m = np.full((qlen,), -np.inf, dtype=q.dtype)   # running row maximum
    l = np.zeros((qlen,), dtype=q.dtype)           # running softmax denominator
    out = np.zeros((qlen, d), dtype=q.dtype)       # running un-normalised output

    shift = kvlen - qlen                           # aligns last query with last key
    q_idx = np.arange(qlen)[:, None]

    for start in range(0, kvlen, block_size):
        end = min(start + block_size, kvlen)
        k_blk = k[start:end]                       # (b, d)
        v_blk = v[start:end]                       # (b, d)

        scores = (q @ k_blk.T) / scale             # (q, b)
        k_idx_blk = np.arange(start, end)[None, :] - shift
        mask = k_idx_blk > q_idx                   # (q, b), True = future key
        scores = np.where(mask, -1e9, scores)

        m_new = np.maximum(m, np.max(scores, axis=-1))

        # Rescale what was accumulated so far to the new maximum. On the first
        # block m is -inf, so alpha is exactly 0 and the zero state is kept.
        alpha = np.exp(m - m_new)
        # Weights relative to the RUNNING maximum: masked entries of a row that
        # already saw a real key collapse to exp(-1e9 - m_new) == 0.
        p_blk = np.exp(scores - m_new[:, None])

        l = alpha * l + np.sum(p_blk, axis=-1)
        out = alpha[:, None] * out + p_blk @ v_blk

        m = m_new

    return out / l[:, None]

def multi_headed_attention_flash(q, k, v, block_size=128):
    """
    Multi-head causal attention that runs `attention_flash` once per head.

    q: (n, q, d)  k: (n, kv, d)  v: (n, kv, d)
    returns: (n, q, d), with the dtype of q
    """
    n_head, _, _ = q.shape
    result = np.empty_like(q)
    for head in range(n_head):
        result[head] = attention_flash(
            q[head], k[head], v[head], block_size=block_size
        )
    return result

def multi_headed_attention_gqa(q, k, v):
    """
    Grouped-query causal attention (dense softmax).

    Consecutive runs of n_q // n_kv query heads share one key/value head:
    query heads [g * group, (g + 1) * group) attend to key/value head g.
    q: (n_q, q, d) ; k: (n_kv, kv, d) ; v: (n_kv, kv, d)
    Returns: (n_q, q, d)
    """
    n_q, qlen, d = q.shape
    n_kv, kvlen, _ = k.shape
    assert n_q % n_kv == 0, "n_q must be divisible by n_kv"
    heads_per_kv = n_q // n_kv

    blocked = causal_mask(qlen, kvlen)[None, :, :]       # (1, q, kv)
    scale = np.sqrt(d)
    result = np.empty_like(q)
    for kv_head in range(n_kv):
        members = slice(kv_head * heads_per_kv, (kv_head + 1) * heads_per_kv)
        shared_k = k[kv_head:kv_head + 1]                # (1, kv, d)
        shared_v = v[kv_head:kv_head + 1]                # (1, kv, d)
        # The size-1 'n' axis is summed away, leaving one score map per query head.
        logits = np.einsum('gqd,nkd->gqk', q[members], shared_k) / scale
        weights = softmax(np.where(blocked, -1e9, logits), axis=-1)
        result[members] = np.einsum('gqk,nkd->gqd', weights, shared_v)
    return result

def multi_headed_attention_flash_gqa(q, k, v, block_size=128):
    """
    Grouped-query causal attention using the streaming `attention_flash` kernel.

    Query head h uses key/value head h // (n_q // n_kv).
    q: (n_q, q, d) ; k: (n_kv, kv, d) ; v: (n_kv, kv, d)
    Returns: (n_q, q, d)
    """
    n_q, _, _ = q.shape
    n_kv, _, _ = k.shape
    assert n_q % n_kv == 0
    heads_per_kv = n_q // n_kv

    result = np.empty_like(q)
    for head in range(n_q):
        kv_head = head // heads_per_kv
        result[head] = attention_flash(
            q[head], k[kv_head], v[kv_head], block_size=block_size
        )
    return result


## Split/Merge heads and basic layers

In [None]:

def split_heads(x, n_head):
    """
    Split the feature axis into heads.

    x: (seq, hidden) -> (n_head, seq, hidden // n_head)
    Head h receives the contiguous feature slice [h * d, (h + 1) * d).
    """
    seq, hidden = x.shape
    per_head = x.reshape(seq, n_head, hidden // n_head)   # (seq, n_head, d)
    return per_head.swapaxes(0, 1)                        # (n_head, seq, d)

def merge_heads(x):
    """
    Concatenate heads back along the feature axis (inverse of `split_heads`).

    x: (n_head, seq, head_dim) -> (seq, n_head * head_dim)
    """
    n_head, seq, head_dim = x.shape
    seq_major = x.swapaxes(0, 1)                          # (seq, n_head, d)
    return seq_major.reshape(seq, n_head * head_dim)

def gelu(x):
    """GELU activation, tanh approximation (element-wise)."""
    inner = x + 0.044715 * x**3
    gate = np.tanh(np.sqrt(2/np.pi) * inner)
    return 0.5 * x * (1 + gate)

def layer_norm(x, g, b, eps: float = 1e-5):
    """
    Layer normalisation over the last axis.

    Each row is shifted to zero mean and scaled to unit (biased) variance,
    then multiplied by the gain `g` and offset by the bias `b`. `eps` keeps
    the division finite for constant rows.
    """
    centred = x - np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return g * centred / spread + b

def linear(x, w, b):
    """Affine map: project `x` with weight matrix `w`, then add bias `b`."""
    projected = x @ w
    return projected + b

def ffn(x, c_fc, c_proj):
    """
    Position-wise feed-forward network: linear -> GELU -> linear.

    `c_fc` and `c_proj` are keyword dicts forwarded to `linear`
    (i.e. they carry the 'w' and 'b' of each projection).
    """
    hidden = linear(x, **c_fc)
    activated = gelu(hidden)
    return linear(activated, **c_proj)


## Multi-Head Attention (MHA) with flags: RoPE, GQA, Flash

In [None]:

@dataclass
class AttnConfig:
    """
    Switches selecting the attention variant used by `mha_dispatch`.

    Raises ValueError at construction time when the head counts are
    inconsistent, instead of failing later inside the attention kernels.
    """
    # Number of query heads.
    n_head: int
    # Number of key/value heads; None => standard MHA, otherwise GQA
    # (n_head must be divisible by n_kv_head).
    n_kv_head: Optional[int] = None
    # Apply rotary positional embeddings to q/k.
    use_rope: bool = False
    # Frequency base for RoPE.
    rope_base: float = 10000.0
    # Use the FlashAttention-style streaming softmax kernels.
    use_flash: bool = False
    # Key/value block size for the streaming kernels.
    flash_block: int = 128

    def __post_init__(self):
        if self.n_head < 1:
            raise ValueError(f"n_head must be >= 1, got {self.n_head}")
        if self.n_kv_head is not None:
            if self.n_kv_head < 1:
                raise ValueError(f"n_kv_head must be >= 1, got {self.n_kv_head}")
            if self.n_head % self.n_kv_head != 0:
                raise ValueError(
                    f"n_head ({self.n_head}) must be divisible by "
                    f"n_kv_head ({self.n_kv_head})"
                )

def _pool_kv_heads(heads, n_kv_head):
    """Mean-pool (n_head, kv, d) heads into n_kv_head groups of consecutive heads."""
    n_head, kv_len, head_dim = heads.shape
    if n_head % n_kv_head != 0:
        raise ValueError(
            f"n_head ({n_head}) must be divisible by n_kv_head ({n_kv_head})"
        )
    group = n_head // n_kv_head
    # Group g averages heads [g * group, (g + 1) * group) -- the same query
    # heads that share key/value head g in the GQA attention kernels.
    return heads.reshape(n_kv_head, group, kv_len, head_dim).mean(axis=1)


def mha_dispatch(q, k_cache, v_cache, c_attn, c_proj, attn_cfg: AttnConfig, pos_offset: int = 0):
    """
    x-projection, KV-cache update, split heads, optional RoPE, and attention variant.

    q: (seq, hidden)  (this is the input x to the attn block, already pre-LN)
    k_cache/v_cache: (kv, hidden) caches for this layer; the new (un-rotated)
                     keys/values are appended along axis 0
    c_attn: keyword dict for `linear` producing the fused (seq, 3*hidden) q/k/v
    c_proj: keyword dict for `linear` applied to the merged attention output
    attn_cfg: selects MHA vs GQA, RoPE and the streaming (flash) kernels
    pos_offset: absolute position of the first row of `q` (used by RoPE)
    returns: (seq, hidden), k_cache, v_cache

    GQA note: the fused projection always yields full-width keys/values, i.e.
    one key/value head per query head. Splitting that width into n_kv_head
    heads would change the head dimension and break the q.k product, so the
    n_head key/value heads are mean-pooled into n_kv_head groups instead.
    Because the projection is linear this equals GQA with mean-pooled
    key/value projection weights.
    """
    x = linear(q, **c_attn)                     # (seq, 3H)
    q_, k_, v_ = np.split(x, 3, axis=-1)        # (seq,H) each

    # update caches
    k_cache = np.concatenate([k_cache, k_], axis=0)
    v_cache = np.concatenate([v_cache, v_], axis=0)

    n_q = attn_cfg.n_head
    n_kv = attn_cfg.n_kv_head
    use_gqa = n_kv is not None

    # Every tensor is split with the QUERY head count so q, k and v share one
    # head dimension.
    qh = split_heads(q_, n_q)                   # (n_q, q, d)
    kh = split_heads(k_cache, n_q)              # (n_q, kv, d)
    vh = split_heads(v_cache, n_q)              # (n_q, kv, d)
    if use_gqa:
        kh = _pool_kv_heads(kh, n_kv)           # (n_kv, kv, d)
        vh = _pool_kv_heads(vh, n_kv)           # (n_kv, kv, d)

    if attn_cfg.use_rope:
        # The cache holds un-rotated keys, so all cached positions are rotated
        # on every call; queries start at their absolute offset.
        q_pos = np.arange(pos_offset, pos_offset + qh.shape[1], dtype=np.int64)
        k_pos = np.arange(0, kh.shape[1], dtype=np.int64)
        qh, kh = apply_rope(qh, kh, q_pos, k_pos, base=attn_cfg.rope_base)

    if use_gqa:
        if attn_cfg.use_flash:
            ah = multi_headed_attention_flash_gqa(qh, kh, vh, block_size=attn_cfg.flash_block)
        else:
            ah = multi_headed_attention_gqa(qh, kh, vh)
    elif attn_cfg.use_flash:
        ah = multi_headed_attention_flash(qh, kh, vh, block_size=attn_cfg.flash_block)
    else:
        ah = multi_headed_attention(qh, kh, vh)

    out = linear(merge_heads(ah), **c_proj)
    return out, k_cache, v_cache


## Transformer Block and Decoder

In [None]:

def transformer_block(x, mlp, attn, ln_1, ln_2, attn_cfg: AttnConfig, k_cache, v_cache, pos_offset: int):
    """
    One pre-LayerNorm transformer block: attention and MLP, each wrapped in a
    residual connection.

    x: block input; `ln_1` / `ln_2` are keyword dicts for `layer_norm`,
    `attn` holds 'c_attn' and 'c_proj', `mlp` is the keyword dict for `ffn`.
    returns: (block output, updated k_cache, updated v_cache)
    """
    attn_input = layer_norm(x, **ln_1)
    attn_out, new_k_cache, new_v_cache = mha_dispatch(
        q=attn_input,
        k_cache=k_cache,
        v_cache=v_cache,
        c_attn=attn['c_attn'],
        c_proj=attn['c_proj'],
        attn_cfg=attn_cfg,
        pos_offset=pos_offset,
    )
    after_attn = x + attn_out                                  # residual 1
    mlp_out = ffn(layer_norm(after_attn, **ln_2), **mlp)
    return after_attn + mlp_out, new_k_cache, new_v_cache      # residual 2

def decoder(input_ids, seq_len, wte, wpe, blocks, ln_f, attn_cfg: AttnConfig, k_cache, v_cache):
    """
    Run the decoder stack on a prompt or on one newly generated token.

    input_ids: a 1-D sequence of token ids (list, tuple or ndarray) for a
               prompt pass placed at positions 0..len-1, or a single integer
               id for one decoding step.
    seq_len:   total sequence length including the current token; only used
               for single-token steps, where the token sits at seq_len - 1.
    wte, wpe:  token / position embedding tables, indexed along axis 0.
    blocks:    per-layer parameter dicts with keys 'mlp', 'attn', 'ln_1', 'ln_2'.
    ln_f:      keyword dict for the final `layer_norm`.
    k_cache, v_cache: per-layer lists of caches; entry i is replaced in place
               with the grown cache of layer i.
    returns:   (n_tokens, vocab) next-token probabilities, one row per input token.
    raises:    ValueError for empty or multi-dimensional input_ids;
               IndexError when a position falls outside the `wpe` table.
    """
    ids = np.asarray(input_ids)
    if ids.ndim == 0:
        # single token generation: it is the last token of the sequence
        pos_offset = seq_len - 1
        token_ids = ids.reshape(1)
    elif ids.ndim == 1:
        if ids.size == 0:
            raise ValueError("input_ids must contain at least one token")
        # prompt pass: tokens occupy positions 0..len-1
        pos_offset = 0
        token_ids = ids
    else:
        raise ValueError(
            f"input_ids must be a scalar or a 1-D sequence, got ndim={ids.ndim}"
        )

    # pos_offset is the absolute position of the first token in this chunk.
    n_tokens = token_ids.shape[0]
    if pos_offset < 0 or pos_offset + n_tokens > len(wpe):
        # Without this check a negative position would silently wrap around.
        raise IndexError(
            f"positions {pos_offset}..{pos_offset + n_tokens - 1} fall outside "
            f"the position table of size {len(wpe)}"
        )
    positions = np.arange(pos_offset, pos_offset + n_tokens)
    x = wte[token_ids] + wpe[positions]

    for i, block in enumerate(blocks):
        x, k_cache_i, v_cache_i = transformer_block(
            x, mlp=block['mlp'], attn=block['attn'],
            ln_1=block['ln_1'], ln_2=block['ln_2'],
            attn_cfg=attn_cfg,
            k_cache=k_cache[i], v_cache=v_cache[i],
            pos_offset=pos_offset
        )
        k_cache[i] = k_cache_i
        v_cache[i] = v_cache_i

    # Output projection is tied to the token embedding matrix.
    logits = layer_norm(x, **ln_f) @ wte.T
    probs = softmax(logits, axis=-1)
    return probs


## Tiny tokenizer and random parameters (for demo purposes)

In [None]:

class TinyTokenizer:
    """
    Character-level tokenizer for the demo.

    Ids 0 and 1 are the special tokens '<pad>' and '<unk>'; the remaining ids
    follow the order of the character set below. Characters outside the set
    encode to 1 ('<unk>'); ids outside the vocabulary decode to '?'.
    """

    def __init__(self):
        # simple charset (extend as needed)
        specials = ['<pad>', '<unk>']
        alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,!?'-")
        vocabulary = specials + alphabet
        self.stoi = {symbol: index for index, symbol in enumerate(vocabulary)}
        self.itos = {index: symbol for symbol, index in self.stoi.items()}

    def encode(self, s: str):
        """Map each character of `s` to its id (1 for unknown characters)."""
        lookup = self.stoi.get
        return [lookup(ch, 1) for ch in s]

    def decode(self, ids: List[int]):
        """Concatenate the symbols for `ids` ('?' for unknown ids)."""
        symbols = [self.itos.get(i, '?') for i in ids]
        return ''.join(symbols)

def init_params(vocab_size, hidden_dim, max_ctx, n_layer, n_head, n_kv_head=None, seed=42):
    """
    Create randomly initialised float32 parameters for the demo model.

    Weight matrices are drawn from N(0, 0.02**2) using a generator seeded with
    `seed`; biases are zero, layer-norm gains one. `n_head` and `n_kv_head`
    are accepted but do not influence any shape.
    Returns a dict with keys 'wte' (vocab_size, hidden_dim),
    'wpe' (max_ctx, hidden_dim), 'blocks' (list of n_layer dicts) and 'ln_f'.
    """
    rng = np.random.default_rng(seed)

    def random_matrix(rows, cols):
        return rng.normal(0, 0.02, size=(rows, cols)).astype(np.float32)

    def dense(n_in, n_out):
        # Weight is drawn first so the random stream order stays fixed.
        return {'w': random_matrix(n_in, n_out), 'b': np.zeros((n_out,), dtype=np.float32)}

    def norm():
        return {
            'g': np.ones((hidden_dim,), dtype=np.float32),
            'b': np.zeros((hidden_dim,), dtype=np.float32),
        }

    wte = random_matrix(vocab_size, hidden_dim)
    wpe = random_matrix(max_ctx, hidden_dim)

    blocks = []
    for _ in range(n_layer):
        # Draw order per layer: c_attn, attention c_proj, c_fc, MLP c_proj.
        attention = {
            'c_attn': dense(hidden_dim, 3 * hidden_dim),
            'c_proj': dense(hidden_dim, hidden_dim),
        }
        mlp = {
            'c_fc': dense(hidden_dim, 4 * hidden_dim),
            'c_proj': dense(4 * hidden_dim, hidden_dim),
        }
        blocks.append({'attn': attention, 'mlp': mlp, 'ln_1': norm(), 'ln_2': norm()})

    return {'wte': wte, 'wpe': wpe, 'blocks': blocks, 'ln_f': norm()}


## LLM wrapper with KV cache and generation

In [None]:

class LLM:
    """
    Thin wrapper tying tokenizer, parameters and attention config together.

    tokenizer:    object with encode(str) -> list of ids and decode(ids) -> str
    params:       dict with 'wte', 'wpe', 'blocks' and 'ln_f'
    model_config: dict; 'n_layer' and 'n_embd' are read to size the KV caches
    attn_cfg:     attention variant switches forwarded to `decoder`
    """

    def __init__(self, tokenizer, params, model_config, attn_cfg: AttnConfig):
        self.tokenizer = tokenizer
        self.params = params
        self.model_config = model_config
        self.attn_cfg = attn_cfg

    def _init_cache(self):
        """Return empty per-layer key and value caches of shape (0, n_embd)."""
        num_layers = self.model_config['n_layer']
        hidden_dim = self.model_config['n_embd']
        k_cache = [np.zeros((0, hidden_dim), dtype=np.float32) for _ in range(num_layers)]
        v_cache = [np.zeros((0, hidden_dim), dtype=np.float32) for _ in range(num_layers)]
        return k_cache, v_cache

    def _decode(self, input_ids, seq_len, k_cache, v_cache):
        """Run `decoder` with this model's parameters; the caches grow in place."""
        return decoder(
            input_ids, seq_len,
            wte=self.params['wte'], wpe=self.params['wpe'],
            blocks=self.params['blocks'], ln_f=self.params['ln_f'],
            attn_cfg=self.attn_cfg, k_cache=k_cache, v_cache=v_cache
        )

    def forward_prompt(self, input_ids):
        """
        Run the whole prompt through fresh caches.

        Returns (probs, k_cache, v_cache); probs[-1] is the distribution of
        the token following the prompt.
        """
        k_cache, v_cache = self._init_cache()
        probs = self._decode(input_ids, len(input_ids), k_cache, v_cache)
        return probs, k_cache, v_cache

    def generate(self, prompt, max_new_tokens=20, greedy=True):
        """
        Generate up to `max_new_tokens` tokens after `prompt` and return them
        decoded (the prompt itself is not included).

        greedy=True picks the arg-max token; otherwise tokens are sampled from
        the predicted distribution with the global NumPy random state.
        Raises ValueError if the prompt encodes to no tokens or if the prompt
        plus the generated tokens would not fit in the position table.
        """
        input_ids = self.tokenizer.encode(prompt)
        if not input_ids:
            raise ValueError("prompt must encode to at least one token")

        # The last generated token is never fed back, hence the "- 1".
        fed_tokens = len(input_ids) + max(max_new_tokens - 1, 0)
        max_ctx = len(self.params['wpe'])
        if fed_tokens > max_ctx:
            raise ValueError(
                f"prompt ({len(input_ids)} tokens) plus {max_new_tokens} new "
                f"tokens exceeds the context size of {max_ctx}"
            )

        # The progress bar is optional: fall back to a plain range without tqdm.
        steps = range(max_new_tokens)
        try:
            from tqdm import tqdm as progress_bar
        except ImportError:
            pass
        else:
            steps = progress_bar(steps, desc='Generating', dynamic_ncols=True)

        # The prompt pass both primes the caches and predicts the first new
        # token. Re-feeding the last prompt token here would cache it twice.
        probs, k_cache, v_cache = self.forward_prompt(input_ids)

        output_ids = []
        seq_len = len(input_ids)
        for step in steps:
            next_probs = probs[-1]
            if greedy:
                next_id = int(np.argmax(next_probs))
            else:
                # Renormalise in float64 so the probabilities sum to 1 as
                # np.random.choice requires.
                p = np.asarray(next_probs, dtype=np.float64)
                p = p / p.sum()
                next_id = int(np.random.choice(p.shape[-1], p=p))
            output_ids.append(next_id)

            # Feed the new token back only if another token is still needed.
            if step + 1 < max_new_tokens:
                seq_len += 1
                probs = self._decode(next_id, seq_len, k_cache, v_cache)
        return self.tokenizer.decode(output_ids)


## Sanity tests

In [None]:

def _attn_dense(q, k, v):
    """
    Dense reference for single-head causal attention, used by the tests.

    q: (q, d) k: (kv, d) v: (kv, d) -> (q, d)
    """
    head_dim = q.shape[-1]
    blocked = causal_mask(q.shape[0], k.shape[0])
    logits = (q @ k.T) / np.sqrt(head_dim)
    weights = softmax(np.where(blocked, -1e9, logits), axis=-1)
    return weights @ v

def run_tests():
    """
    Sanity checks for the attention building blocks.

    Beyond shapes, the streaming (flash) kernels are compared numerically with
    their dense counterparts for the single-head, multi-head and GQA paths.
    Raises AssertionError on the first failing check.
    """
    rng = np.random.default_rng(0)

    # softmax: rows sum to one, equal logits give a uniform row (no overflow)
    x = np.array([[1., 2., 3.], [1000., 1000., 1000.]], dtype=np.float32)
    s = softmax(x, axis=-1)
    assert np.allclose(s.sum(-1), 1.0)
    assert np.allclose(s[1], 1.0 / 3.0)

    # causal mask: shapes and content
    assert causal_mask(3,3).shape == (3,3)
    assert causal_mask(1,5).shape == (1,5)
    assert causal_mask(5,1).shape == (5,1)
    assert causal_mask(3,3).tolist() == [
        [False, True, True],
        [False, False, True],
        [False, False, False],
    ]
    assert not causal_mask(1,5).any()          # a lone query sees the whole cache

    # single-head + flash equivalence
    q = rng.normal(size=(7, 16)).astype(np.float32)
    k = rng.normal(size=(23, 16)).astype(np.float32)
    v = rng.normal(size=(23, 16)).astype(np.float32)
    d_dense  = _attn_dense(q, k, v)
    d_flash  = attention_flash(q, k, v, block_size=6)
    assert np.allclose(d_dense, d_flash, atol=1e-5)

    # multi-head: dense and flash must agree in value, not only in shape
    qh = rng.normal(size=(4, 5, 16)).astype(np.float32)
    kh = rng.normal(size=(4, 9, 16)).astype(np.float32)
    vh = rng.normal(size=(4, 9, 16)).astype(np.float32)
    o  = multi_headed_attention(qh, kh, vh)
    of = multi_headed_attention_flash(qh, kh, vh, block_size=4)
    assert o.shape == of.shape == (4,5,16)
    assert np.allclose(o, of, atol=1e-5)

    # GQA: dense and flash must agree
    qg = rng.normal(size=(8, 5, 16)).astype(np.float32)
    kg = rng.normal(size=(2, 9, 16)).astype(np.float32)
    vg = rng.normal(size=(2, 9, 16)).astype(np.float32)
    og  = multi_headed_attention_gqa(qg, kg, vg)
    ogf = multi_headed_attention_flash_gqa(qg, kg, vg, block_size=4)
    assert og.shape == ogf.shape == (8,5,16)
    assert np.allclose(og, ogf, atol=1e-5)
    # with one KV head per query head, GQA reduces to plain MHA
    assert np.allclose(multi_headed_attention_gqa(qh, kh, vh), o, atol=1e-5)

    # RoPE: shapes, identity at position 0, norm preservation
    qh2 = rng.normal(size=(4, 6, 16)).astype(np.float32)
    kh2 = rng.normal(size=(4, 10, 16)).astype(np.float32)
    qrot, krot = apply_rope(qh2, kh2, np.arange(6), np.arange(10))
    assert qrot.shape == qh2.shape and krot.shape == kh2.shape
    assert np.allclose(qrot[:, 0], qh2[:, 0])
    assert np.allclose(np.linalg.norm(krot, axis=-1), np.linalg.norm(kh2, axis=-1), atol=1e-5)

    # split/merge
    X = rng.normal(size=(7, 32)).astype(np.float32)
    assert np.allclose(X, merge_heads(split_heads(X, 4)))

    print("All tests passed ✅")

# Run the sanity checks as soon as this cell executes; a failing check raises
# AssertionError here.
run_tests()


## Configure a tiny model and try generation

In [None]:

# Build the tokenizer first: the vocabulary size is derived from it
tok = TinyTokenizer()

# Model config
model_config = {
    'vocab_size': len(tok.stoi),   # from TinyTokenizer size
    'n_embd': 32,
    'max_ctx': 128,
    'n_layer': 2,
    'n_head': 4,
    # 'n_kv_head': 2,    # uncomment to enable GQA (n_head must be divisible by n_kv_head)
}

# Init random params (sizes are taken straight from the model config)
params = init_params(
    model_config['vocab_size'],
    model_config['n_embd'],
    model_config['max_ctx'],
    model_config['n_layer'],
    model_config['n_head'],
)

# Attention config toggles
attn_cfg = AttnConfig(
    n_head=model_config['n_head'],
    n_kv_head=None,          # set to 2 to enable GQA with 4 query heads -> 2 kv heads
    use_rope=True,           # toggle RoPE
    rope_base=10000.0,
    use_flash=True,          # toggle FlashAttention-style
    flash_block=32,
)

# Instantiate model
llm = LLM(tok, params, model_config, attn_cfg)

# Try generation (random weights => babble)
prompt = "Hello"
print("Prompt:", prompt)
print("---- Generating (greedy) ----")
generated = llm.generate(prompt, max_new_tokens=40, greedy=True)
print(generated)
