# 0) Prologue: What you'll build

Goal: implement and *understand* attention — first in **NumPy**, then mirror it in **PyTorch**, and finally use it in a tiny training task.

Plan:
1) NumPy: stable softmax → causal mask → Scaled Dot-Product Attention (SDPA) → Multi-Head Attention (MHA).
2) PyTorch: SDPA → compare with `torch.nn.functional.scaled_dot_product_attention` → your own `MyMHA` vs `nn.MultiheadAttention`.
3) Mini task: tiny copy task or bigram LM to verify it learns.

Each code cell contains function stubs that raise `NotImplementedError`. Your job: implement them and run the checks.


In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed both RNGs first so every draw below is reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Tiny demo sizes (feel free to adjust):
#   B = batch size, T = sequence length,
#   d_model = model dimension, n_heads = number of heads.
B, T = 2, 6
d_model, n_heads = 32, 4
d_head = d_model // n_heads  # per-head width

# Dummy inputs: one float32 NumPy array and an independent torch copy of it.
x_np = np.random.standard_normal(size=(B, T, d_model)).astype(np.float32)
x_t = torch.from_numpy(x_np).clone()

print("Shapes:", x_np.shape, x_t.shape)


# 1) NumPy: numerically stable softmax

**Task:** Implement a numerically stable softmax along a given axis:
- subtract the maximum along that axis
- exponentiate
- normalize by the sum

**Checks:** Sum over the axis should be ~1; values in (0, 1).

In [None]:
def softmax_stable(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """
    Exercise stub: numerically stable softmax of `x` along `axis`.

    Intended recipe (to be filled in by the reader):
      1. shift by the maximum taken along `axis`
      2. apply exp
      3. divide by the sum along `axis`

    Until implemented, calling this raises NotImplementedError.
    """
    todo = "Implement stable softmax in NumPy"
    raise NotImplementedError(todo)


# --- Minimal checks (run after implementation) ---
# y = softmax_stable(np.random.randn(2,3,4).astype(np.float32), axis=-1)
# assert y.shape == (2,3,4)
# assert np.allclose(y.sum(axis=-1), 1.0, atol=1e-5)
# assert (y > 0).all() and (y < 1).all()


# 2) NumPy: causal mask (no peeking ahead)

**Task:** Create a causal mask of shape `[T, T]` that forbids attending to the future.
- If `dtype="bool"`: `True` means "masked out".
- If `dtype="float"`: 0.0 for keep, `-inf` for masked positions.

**Hint:** `np.triu` helps.

In [None]:
def causal_mask_np(T: int, dtype: str = "bool") -> np.ndarray:
    """
    Exercise stub: build the [T, T] upper-triangular causal mask hiding j > i.

    `dtype` selects the encoding:
      - "bool"  -> True marks a position to mask out
      - "float" -> 0 where attention is kept, -inf where it is masked

    Until implemented, calling this raises NotImplementedError.
    """
    todo = "Implement causal mask for NumPy"
    raise NotImplementedError(todo)


# --- Minimal checks ---
# m_bool = causal_mask_np(T, dtype="bool")
# assert m_bool.shape == (T, T) and m_bool.dtype == np.bool_
# assert m_bool[0,1] == True and m_bool[1,0] == False
# m_flt = causal_mask_np(T, dtype="float")
# assert m_flt.shape == (T, T)

# 3) NumPy: Scaled Dot-Product Attention (SDPA)

**Task:** Implement `scaled_dot_product_attention_np(Q, K, V, mask=None)`.

Steps:
1) `scores = Q @ K^T / sqrt(d_k)` → shape `[B, T, T]`
2) Apply mask:
   - if float mask: add directly (0.0 or `-inf`)
   - if bool mask: add a large negative number (e.g., `-1e9`) where `True`
3) `probs = softmax_stable(scores, axis=-1)`
4) `out = probs @ V`

Return `(out, probs)`. Shapes: `out [B,T,d]`, `probs [B,T,T]`.

In [None]:
def scaled_dot_product_attention_np(Q: np.ndarray,
                                    K: np.ndarray,
                                    V: np.ndarray,
                                    mask: np.ndarray | None = None):
    """
    Compute scaled dot-product attention in NumPy.
    Shapes:
      Q,K,V: [B, T, d_k]   => scores: [B, T, T]   => out: [B, T, d_v] (here d_v=d_k)
    Return: out, attn_probs
    """
    raise NotImplementedError("Implement NumPy SDPA")


# --- Minimal shape check after implementation ---
# Q = np.random.randn(B, T, d_head).astype(np.float32)
# K = np.random.randn(B, T, d_head).astype(np.float32)
# V = np.random.randn(B, T, d_head).astype(np.float32)
# cm = causal_mask_np(T, dtype="float")
# out, p = scaled_dot_product_attention_np(Q, K, V, mask=cm)
# assert out.shape == (B, T, d_head)
# assert p.shape == (B, T, T)

# 4) NumPy: split/combine heads

**Task:** Implement:
- `split_heads_np([B,T,d_model]) -> [B,n_heads,T,d_head]`
- `combine_heads_np([B,n_heads,T,d_head]) -> [B,T,d_model]`

**Note:** Prefer reshaping/transposing without copies when possible.

In [None]:
def split_heads_np(x: np.ndarray, n_heads: int) -> np.ndarray:
    """
    Exercise stub: turn [B,T,d_model] into [B,n_heads,T,d_head].

    The intended implementation should avoid copying where it can.
    Until implemented, calling this raises NotImplementedError.
    """
    todo = "Implement split_heads for NumPy"
    raise NotImplementedError(todo)


def combine_heads_np(x: np.ndarray) -> np.ndarray:
    """
    Exercise stub: turn [B,n_heads,T,d_head] back into [B,T,d_model].

    Until implemented, calling this raises NotImplementedError.
    """
    todo = "Implement combine_heads for NumPy"
    raise NotImplementedError(todo)


# --- Minimal checks ---
# a = np.random.randn(B, T, d_model).astype(np.float32)
# h = split_heads_np(a, n_heads)
# assert h.shape == (B, n_heads, T, d_head)
# a2 = combine_heads_np(h)
# assert a2.shape == (B, T, d_model)
# assert np.allclose(a, a2)

# 5) NumPy: full Multi-Head Attention

**Task:** Implement `multi_head_attention_np` using linear projections `Wq,Wk,Wv,Wo`.

Steps:
1) `Q = x @ Wq`, `K = x @ Wk`, `V = x @ Wv`
2) split into heads
3) SDPA per head (same mask for all heads)
4) combine heads, then `out = out @ Wo`

Return `(out, attn_probs)` (you may return per-head probabilities or an average).

In [None]:
# Projection matrices, drawn once so later cells see the same weights.
# The generator is consumed left to right, so the random stream is used in
# the order Wq, Wk, Wv, Wo; each matrix is float32 scaled by 1/sqrt(d_model).
Wq, Wk, Wv, Wo = (
    np.random.randn(d_model, d_model).astype(np.float32) / math.sqrt(d_model)
    for _ in range(4)
)

def multi_head_attention_np(x: np.ndarray,
                            Wq: np.ndarray, Wk: np.ndarray, Wv: np.ndarray, Wo: np.ndarray,
                            mask: np.ndarray | None = None):
    """
    Implement Multi-Head Attention with NumPy.
    Return: (out, attn_probs_merged_or_list)
    - You may return per-head probs or an averaged [B,T,T].
    """
    raise NotImplementedError("Implement NumPy MHA")


# --- Minimal checks ---
# cm = causal_mask_np(T, dtype="float")
# out_np, probs_np = multi_head_attention_np(x_np, Wq, Wk, Wv, Wo, mask=cm)
# assert out_np.shape == (B, T, d_model)

# 6) PyTorch: SDPA by hand and compare to reference

**Task:** Implement `scaled_dot_product_attention_torch(Q,K,V, attn_mask=None)` and compare it with
`torch.nn.functional.scaled_dot_product_attention` (no dropout).

**Hints:**
- Manual: `scores = (Q @ K.transpose(-2,-1)) / sqrt(d)` → masked fill → softmax → `@ V`.
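- `attn_mask` can be `bool` (True=mask) or `float` (`-inf` where masked). Note that `F.scaled_dot_product_attention` uses the *opposite* boolean convention (True = position may be attended to), so invert a True=mask boolean mask (`~mask`) before passing it to the reference.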

In [None]:
def scaled_dot_product_attention_torch(Q: torch.Tensor,
                                       K: torch.Tensor,
                                       V: torch.Tensor,
                                       attn_mask: torch.Tensor | None = None):
    """
    Implement SDPA in PyTorch. Return (out, probs).
    Shapes: Q,K,V [B,T,d], attn_mask [T,T] or [B,1,T,T]
    """
    raise NotImplementedError("Implement PyTorch SDPA")


# --- Comparison harness (run after implementation) ---
# Q = torch.randn(B, T, d_head)
# K = torch.randn(B, T, d_head)
# V = torch.randn(B, T, d_head)
# cm_bool = torch.from_numpy(causal_mask_np(T, dtype="bool"))
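# # NOTE: F.scaled_dot_product_attention reads a bool mask as True = *attend*, the opposite of
# # causal_mask_np (True = masked out), so the mask is inverted for the reference call only.
# out_ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=~cm_bool, dropout_p=0.0, is_causal=False)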
# out_my, p_my = scaled_dot_product_attention_torch(Q, K, V, attn_mask=cm_bool)
# assert torch.allclose(out_ref, out_my, atol=1e-5), "Your SDPA doesn't match F.sdpa"

# 7) PyTorch: your own Multi-Head Attention vs `nn.MultiheadAttention`

**Task:** Implement `MyMHA` (Q/K/V/O linear layers; split/combine; SDPA).
Then compare outputs with `nn.MultiheadAttention` **after copying weights**.

**Tip:** `nn.MultiheadAttention` packs QKV in `in_proj_weight`/`in_proj_bias`. You can copy those slices into your module.

In [None]:
class MyMHA(nn.Module):
    def __init__(self, d_model: int, n_heads: int, bias: bool = True):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Define projections: in->Q,K,V, and out->O
        raise NotImplementedError("Define linear layers for Q,K,V,O")

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None):
        """
        x: [B,T,d_model]
        Return: out [B,T,d_model], attn_probs [B, n_heads, T, T] (optional)
        """
        raise NotImplementedError("Implement forward: project -> split -> SDPA -> combine -> out")


# --- Comparison (run after implementation & weight copy) ---
# mha_ref = nn.MultiheadAttention(d_model, n_heads, batch_first=True, bias=True)
# my = MyMHA(d_model, n_heads, bias=True)
# # TODO: copy weights from mha_ref into my
# cm_bool = torch.from_numpy(causal_mask_np(T, dtype="bool"))
# out_ref, _ = mha_ref(x_t, x_t, x_t, attn_mask=cm_bool)          # batch_first=True
# out_my, _ = my(x_t, attn_mask=cm_bool)
# assert torch.allclose(out_ref, out_my, atol=1e-4), "Mismatch between MyMHA and nn.MultiheadAttention"

# 8) Mini task: copy-task or tiny bigram LM

Choose one:

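**Copy task:** input is a sequence over {0,1,2,3}; predict the same sequence shifted by 1 (the target at position *t* is the input token at *t+1*). Because the tokens are i.i.d. random, this is only learnable when attention is *not* causally masked and the model has positional information; under a causal mask the loss cannot drop below ln 4 ≈ 1.39.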
**Bigram LM:** given a tiny text corpus, predict next token.

**Tasks:**
1) Implement `TinyEmbedding`, `TinyBlock` (MHA + FFN), `TinyLMHead`.
2) Write a simple training loop (cross-entropy, Adam).
3) Verify that the loss decreases and predictions improve.

In [None]:
class TinyEmbedding(nn.Module):
    """Exercise stub: token embedding with an optional positional embedding."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        todo = "Create token embedding and optional positional embedding"
        raise NotImplementedError(todo)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Exercise stub: map token ids to vectors.

        idx: [B,T]; intended result: [B,T,d_model].
        """
        todo = "Implement forward for embeddings"
        raise NotImplementedError(todo)


class TinyBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        raise NotImplementedError("Create MyMHA (or nn.MultiheadAttention), LayerNorms, and MLP")

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        raise NotImplementedError("Implement Transformer block forward")


class TinyLMHead(nn.Module):
    """Exercise stub: output projection from model width to vocabulary logits."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        todo = "Create output projection to vocab"
        raise NotImplementedError(todo)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Exercise stub: produce logits.

        x: [B,T,d_model]; intended result: logits [B,T,vocab].
        """
        todo = "Implement forward"
        raise NotImplementedError(todo)


class TinyModel(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_heads: int, n_layers: int):
        super().__init__()
        raise NotImplementedError("Assemble Embedding -> n*Blocks -> Head")

    def forward(self, idx: torch.Tensor, attn_mask: torch.Tensor | None = None):
        raise NotImplementedError("Implement forward to produce logits")


# --- Simple dataset for copy task (example skeleton) ---
# def make_copy_batch(B:int, T:int, vocab:int=4):
#     data = torch.randint(low=0, high=vocab, size=(B,T))
#     x = data.clone()
#     y = data.clone()  # predict same with shift by 1 later in loss
#     return x, y
#
# --- Training loop sketch (fill in) ---
# model = TinyModel(vocab_size=4, d_model=32, n_heads=4, n_layers=2)
# opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# for step in range(1000):
#     x,y = make_copy_batch(B=32, T=16, vocab=4)
#     logits = model(x)  # [B,T,vocab]
#     loss = F.cross_entropy(logits[:,:-1].reshape(-1, 4), y[:,1:].reshape(-1))
#     opt.zero_grad(); loss.backward(); opt.step()
#     if step % 100 == 0:
#         print(step, float(loss))

# 9) Debugging checklist

- **Masks:** Verify that position *t* never attends to *t+1…* under a causal mask.
- **Parity tests:** Match your PyTorch SDPA against `F.scaled_dot_product_attention` on random tensors and multiple sizes.
- **MyMHA vs nn.MultiheadAttention:** After copying weights, outputs should match within a small tolerance.
- **Training dynamics:** If loss doesn't drop, re-check normalization, mask broadcasting, and dimension reshapes/transposes.