# Coding Assignment 3: Build a Tiny Transformer

**Learning goals**
- Add **positional information** with **sinusoidal positional encodings**.
- Understand and implement **self-attention**.
- Implement **multi-head self-attention** by splitting/combining heads.
- Train or sample from a **tiny character-level LM** on CPU.


In [4]:
# Minimal setup — CPU only, deterministic runs
import math
import random
import textwrap
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

SEED = 476
random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cpu")
print("PyTorch:", torch.__version__, "| Device:", DEVICE)



PyTorch: 2.2.2 | Device: cpu


## Dataset & Character Vocabulary

In this section, we load a tiny text corpus (`shakespeare.txt`) and build a **character-level** vocabulary so our model can learn to predict the next **character**.

In [5]:
with open("shakespeare.txt", "r") as f:
    NANO_SHAKESPEARE = f.read()

# --- Build a tiny char-level vocab ---
chars = sorted(list(set(NANO_SHAKESPEARE)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}
vocab_size = len(stoi)
print("Unique chars:", vocab_size)

def encode(s: str):
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)

def decode(ids: torch.Tensor):
    return ''.join(itos[int(i)] for i in ids)

# train/val split (character-level)
data = encode(NANO_SHAKESPEARE)
split = int(0.9 * len(data))
train_ids, val_ids = data[:split], data[split:]
len(train_ids), len(val_ids)


Unique chars: 65


(1003853, 111540)

## Background (5-minute primer)

**Self-attention (one layer).** For each token, we compute a weighted sum over the value vectors of all tokens in the sequence (itself included). The weights come from a **similarity** between that token’s **query** (Q) and every token’s **key** (K):  
`attention(Q,K,V) = softmax(QKᵀ / √dₖ) V`  
Scaling by √dₖ keeps the dot products in a numerically friendly range.  
We use **values** (V) as the information that actually flows forward.
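
To make the role of the √dₖ scaling concrete, here is a minimal sketch. It assumes, purely for illustration, that query and key entries are independent standard-normal values; real activations will differ, but the trend is the same: raw dot products spread out like √dₖ, while scaled ones stay near unit spread.

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
n = 10_000  # number of random (query, key) pairs

# Illustrative assumption: entries are i.i.d. standard normal.
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)        # n dot products, spread grows like sqrt(d_k)
scaled = raw / math.sqrt(d_k)    # spread stays close to 1

print(f"raw std:    {raw.std().item():.2f}")
print(f"scaled std: {scaled.std().item():.2f}")
```

Large unscaled scores would push softmax toward a one-hot distribution, which leaves very small gradients for the other positions.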

**Multi-head self-attention (MHA).** Instead of one attention, we project the inputs into multiple Q/K/V sets (“heads”), attend in parallel, then **concatenate** and project back. Each head can focus on different patterns.

**Why positional encodings?** Attention alone doesn’t know token order. We **add** a position vector to token embeddings so the model can represent order. In this assignment, you’ll build the **sinusoidal** version.
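
The claim that attention alone ignores order can be checked directly. The sketch below uses a simplified single-head attention with identity projections (an assumption for brevity, not the model used later): shuffling the input tokens only shuffles the outputs in the same way, so nothing in the result depends on where a token sits.

```python
import math
import torch

torch.manual_seed(0)
S, D = 5, 8
x = torch.randn(S, D)  # token vectors with no positional information

def plain_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention with identity projections (q = k = v = x)."""
    scores = x @ x.T / math.sqrt(x.size(-1))
    return torch.softmax(scores, dim=-1) @ x

perm = torch.randperm(S)
out = plain_attention(x)
out_shuffled = plain_attention(x[perm])

# Shuffling the inputs only shuffles the outputs the same way.
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))
```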

**Transformer LM and Transformer block.** We build a transformer LM as a stack of transformer blocks. Each block has two sublayers, each wrapped in a residual connection:
1) LayerNorm → MHA → residual add  
2) LayerNorm → Feed-forward (GELU) → residual add

### LayerNorm 

LayerNorm is a normalization technique that stabilizes training. We have implemented it for you. Here's an introduction if you are interested:
- **What it does:** For each token’s feature vector (the last dimension), LayerNorm makes the values have roughly **mean 0** and **variance 1**, then applies a learned **scale (γ)** and **shift (β)**.  
- **Why it helps:** Keeps activations in a stable range so **training is smoother**, gradients flow better, and deep Transformer stacks don’t blow up or stall.  
- **How it differs from BatchNorm:** **BatchNorm** normalizes *across the batch* (depends on batch size/order). **LayerNorm** normalizes *within each token* (independent of batch), so it works well with **small/variable batches** and **variable sequence lengths**.  
- **Where it goes (pre-norm):** Common today is **LayerNorm → sublayer (MHA/FFN) → residual add**. Putting LN first makes optimization more robust.  
- **Mental model:** Think of it as “standardizing” each token’s features so every sublayer sees inputs with comparable scale, making the network easier to train.
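
As a rough illustration of the bullets above, the sketch below reproduces `nn.LayerNorm` by hand on a small random tensor. The tensor shape and the scale/offset applied to the input are arbitrary choices for the demo.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 8) * 5 + 2      # [B, S, D] with an arbitrary scale and offset

ln = nn.LayerNorm(8)                   # gamma starts at 1, beta at 0
y = ln(x)

# Manual version: normalize over the last dimension only (one token at a time).
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # LayerNorm uses the biased variance
y_manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(torch.allclose(y, y_manual, atol=1e-5))
print(y.mean(dim=-1))   # each token's mean is close to 0
```

Because the statistics are computed per token, the result for one example does not change when other examples are added to or removed from the batch.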

### Causal Masks

For **autoregressive** language modeling (next-token prediction), training must mimic generation: the model should never use future tokens to predict the current one. We enforce this by applying a causal mask to the attention scores so that, at position i, forbidden positions (positions > i) are replaced with a **large negative** number (≈ −1e9). After softmax, those entries become (near) **0**. We provide a helper function `make_causal_mask` that returns a boolean mask in which `True` marks allowed positions and `False` marks forbidden ones. Later in this assignment you will use it to replace the forbidden positions with a large negative number (more details in Task B).
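
A small sketch of the effect, using equal scores everywhere so the result is easy to read. The 4×4 size and the variable names are illustrative only.

```python
import torch

S = 4
# True on and below the diagonal: position i may attend to positions 0..i.
allowed = torch.tril(torch.ones(S, S, dtype=torch.bool))

scores = torch.zeros(S, S)                     # equal scores everywhere
scores = scores.masked_fill(~allowed, -1e9)    # block the future
attn = torch.softmax(scores, dim=-1)

print(attn)
# Row i spreads its weight evenly over positions 0..i,
# so the diagonal holds 1, 1/2, 1/3, 1/4.
```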

In [6]:
@dataclass
class ModelConfig:
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 2
    ff_mult: int = 4
    max_seq_len: int = 128
    vocab_size: int = vocab_size
    pe_type: str = "sinusoidal"  # "none" | "sinusoidal" | "learned"

def make_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Returns a boolean mask of shape [1, 1, seq_len, seq_len]
    where True means "allowed" and False means "masked out".
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.unsqueeze(0).unsqueeze(0)  # [1,1,S,S]

def batched_random_chunks(ids: torch.Tensor, block_size: int, batch_size: int):
    """
    Simple CPU-friendly batcher for quick tests/training.
    """
    ix = torch.randint(0, len(ids) - block_size - 1, (batch_size,))
    x = torch.stack([ids[i:i+block_size] for i in ix])
    y = torch.stack([ids[i+1:i+block_size+1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)


## Task A — Sinusoidal Positional Encoding

### What you’re building
You will return a tensor `pe` of shape `[max_len, d_model]`. For each position `pos = 0..max_len-1` and each pair of feature columns `(2i, 2i+1)`:

- `pe[pos, 2*i]   = sin( pos * ω_i )`
- `pe[pos, 2*i+1] = cos( pos * ω_i )`
- where `ω_i = 10000 ** ( -2*i / d_model )`

**Intuition:** Even columns use sine, odd columns use cosine, with frequencies that shrink as `i` grows. This gives a smooth, multi-frequency code for positions without learning extra parameters.

**Example:** Let `d_model = 8` (so 4 sine/cosine pairs) and `pos = 5`. The encoding of this position is a single 8-dimensional vector.
Frequencies use ω_i = 10000^(-2i / d_model):

- ω_0 = 10000^{-0/8} = 1
- ω_1 = 10000^{-2/8} = 10000^{-0.25} = 0.1
- ω_2 = 10000^{-4/8} = 10000^{-0.5} = 0.01
- ω_3 = 10000^{-6/8} = 10000^{-0.75} = 0.001

Therefore, the position vector is

[
    sin(5 * 1), cos(5 * 1), sin(5 * 0.1), cos(5 * 0.1), sin(5 * 0.01), cos(5 * 0.01),  sin(5 * 0.001), cos(5 * 0.001)
]

which evaluates to

[
-0.958924, 0.283662,
0.479426, 0.877583,
0.049979, 0.998750,
0.005000, 0.999988
]
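
The numbers above can be reproduced with plain Python. This is only a check of the worked example for one position, not the vectorized function the task asks for.

```python
import math

d_model, pos = 8, 5

vector = []
for i in range(d_model // 2):
    omega = 10000 ** (-2 * i / d_model)   # 1, 0.1, 0.01, 0.001
    vector.append(math.sin(pos * omega))  # even column 2i
    vector.append(math.cos(pos * omega))  # odd column 2i + 1

print([round(v, 6) for v in vector])
```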


### PyTorch tips:

Slicing notation you’ll use
- `pe[:, 0::2]` → all rows, **even** columns (0, 2, 4, …)  
- `pe[:, 1::2]` → all rows, **odd** columns (1, 3, 5, …)

This lets you fill all even columns with `sin(...)` and all odd columns with `cos(...)` in one shot.

`torch.arange(start, end, step)` gives you a sequence of position indices.
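
A quick sketch of both tips on a toy tensor. The values and the `freq` tensor are made up for the demo; the point is the slicing pattern and how a `[3, 1]` column of positions broadcasts against a row of per-column values.

```python
import torch

pe = torch.zeros(3, 6)
pe[:, 0::2] = 1.0     # even columns 0, 2, 4
pe[:, 1::2] = -1.0    # odd columns 1, 3, 5
print(pe[0])          # alternating 1 and -1

pos = torch.arange(0, 3, 1, dtype=torch.float32)   # positions 0, 1, 2
col = pos.unsqueeze(1)                             # shape [3, 1]
freq = torch.tensor([1.0, 0.1, 0.01])              # one value per even column
print((col * freq).shape)                          # broadcasts to [3, 3]
```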

In [11]:
# ===== TODO (Student) =====
# Implement sinusoidal_positional_encoding(max_len, d_model)
# Return a tensor of shape [max_len, d_model] (float32).
# pe[pos, 0::2] = sin(pos / 10000^(2i/d_model))
# pe[pos, 1::2] = cos(pos / 10000^(2i/d_model))

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """
    Build sinusoidal positional encodings (no batch).
    Args:
        max_len: maximum sequence length
        d_model: model hidden size (must be >=1)
    Returns:
        pe: [max_len, d_model] float tensor
    Example:
        >>> pe = sinusoidal_positional_encoding(3, 4)
        >>> pe.shape
        torch.Size([3, 4])
        >>> torch.allclose(pe[0, 0::2], torch.zeros(2))  # sin(0)=0 on even columns
        True
    """
    pe = torch.zeros(max_len, d_model, dtype=torch.float32)
    # Positions 0..max_len-1 as a column vector so they broadcast against the frequencies: [max_len, 1]
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    # div_term[i] = ω_i = 10000^(-2i / d_model), one frequency per (sin, cos) column pair
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    div_term = torch.exp(two_i * (-math.log(10000.0) / d_model))
    angles = position * div_term  # [max_len, ceil(d_model / 2)]
    # Even columns get sin, odd columns get cos
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # the slice handles odd d_model
    return pe

In [12]:
# Quick, deterministic tests
_pe = sinusoidal_positional_encoding(16, 8)
assert _pe.shape == (16, 8), "Shape must be [L, D]"
# sin(0)=0 on even dims; cos(0)=1 on odd dims
assert torch.allclose(_pe[0, 0::2], torch.zeros(4)), "sin(0) should be 0 on even dims"
assert torch.allclose(_pe[0, 1::2], torch.ones(4)), "cos(0) should be 1 on odd dims"
# Smoothness: position 1 should be closer to 2 than to 8
d12 = torch.dist(_pe[1], _pe[2])
d18 = torch.dist(_pe[1], _pe[8])
assert d12 < d18, "Nearby positions should be more similar than far-away"
print("✅ Task A tests passed")


✅ Task A tests passed

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.tok = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.cfg = cfg
        if cfg.pe_type == "learned":
            self.pos = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        elif cfg.pe_type == "sinusoidal":
            pe = sinusoidal_positional_encoding(cfg.max_seq_len, cfg.d_model)  # student fn
            self.register_buffer("pe_table", pe)  # [L,D]
        else:
            self.pos = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, S] token ids
        returns: [B, S, D]
        """
        B, S = x.shape
        tok = self.tok(x)
        if self.cfg.pe_type == "none":
            return tok
        elif self.cfg.pe_type == "learned":
            pos_ids = torch.arange(S, device=x.device).unsqueeze(0)  # [1,S]
            return tok + self.pos(pos_ids)
        else:  # sinusoidal
            return tok + self.pe_table[:S].unsqueeze(0)


## Task B — Implement `scaled_dot_product_attention(q, k, v, mask=None)`

**Notations**
- **B** = batch size  
- **H** = number of heads  
- **S** = sequence length (number of tokens per example)  
- **D_head** = head dimension (usually `D_model / H`)

Per head, we work with:
- **q, k, v**: `[B, H, S, D_head]` The query, key, and value tensors, produced by applying learned projection matrices to the input.
- **scores**: `[B, H, S, S]`: It is the per-head **similarity matrix** between every query position *i* and every key position *j*. To get the individual score:
    - Take the **query** vector for token `i`: `q_vec = q[b, h, i, :]`  (length = `D_head`)
    - Take the **key** vector for token `j`: `k_vec = k[b, h, j, :]`   (length = `D_head`)
    - Compute their **dot product**, then **scale** by `1 / sqrt(D_head)`: scores[b,h,i,j] = (q_vec . k_vec) / sqrt(D_head)
    - In tensor form: scores = ( q @ k.transpose(-2, -1) ) / sqrt(D_head)

- **attn_probs**: `[B, H, S, S]` It is the attention weight that represents how much token i should "look at" token j when building its new representation. From the attention scores we just computed:
    - (Optional) apply a mask: set forbidden entries to a large negative value (e.g., `-1e9`) so they become ~0 after softmax.
    - Apply **softmax over the last dim** (keys `j`) to turn scores into probabilities, so that each query position `i` gets a probability distribution over all key positions `j`.

- **output**: `[B, H, S, D_head]` It is the new representation. For example out[b, h, i, :] is the vector representation of token i after attending to all tokens (including itself). 
    - `attn_probs[b,h,i,j]` is how much token `i` attends to token `j`.  
    - `v[b,h,j,:]` is the value vector at position `j`. 
    - The output is the weighted average of values across all `j`: out[b, h, i, :] = sum(attn_probs[b, h, i, j] * v[b, h, j, :] for j=0,1,...,S-1)
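
To connect the per-element formulas with the tensor form, here is a minimal sketch on random inputs without a mask. The sizes and the chosen indices `(b, h, i, j)` are arbitrary; it checks one entry of `scores` and one output row against the element-wise definitions.

```python
import math
import torch

torch.manual_seed(0)
B, H, S, D_head = 2, 3, 4, 5
q = torch.randn(B, H, S, D_head)
k = torch.randn(B, H, S, D_head)
v = torch.randn(B, H, S, D_head)

# Tensor form: every (i, j) dot product at once.
scores = q @ k.transpose(-2, -1) / math.sqrt(D_head)   # [B, H, S, S]
attn_probs = torch.softmax(scores, dim=-1)
out = attn_probs @ v                                    # [B, H, S, D_head]

# Element form for one arbitrary (b, h, i, j).
b, h, i, j = 1, 2, 3, 0
one_score = torch.dot(q[b, h, i], k[b, h, j]) / math.sqrt(D_head)
one_out = sum(attn_probs[b, h, i, jj] * v[b, h, jj] for jj in range(S))

print(torch.isclose(scores[b, h, i, j], one_score))
print(torch.allclose(out[b, h, i], one_out, atol=1e-6))
```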

**What to write.**
1) Compute `scores = (q @ kᵀ) / √dₖ` → shape `[B, H, S, S]`.  
2) If a boolean `mask` is provided (`True = keep`, `False = block`), set masked positions to a large negative number (e.g., `-1e9`) **before** softmax.  
3) Softmax over the last dimension → `[B, H, S, S]`.  
4) Multiply by `v` to get outputs → `[B, H, S, D_head]`. Return both as `(out, attn_weights)`.

### PyTorch Tips:
`tensor.masked_fill(mask, value)` takes a boolean tensor (broadcastable to the shape of `tensor`) and returns a copy in which elements are set to `value` wherever `mask` is `True`.

To assign a large negative number to the blocked elements (where `mask` is `False`), invert the mask: `scores = scores.masked_fill(~mask, -1e9)`.
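
The sketch below focuses on the shapes: a keep-mask with singleton batch and head dimensions is broadcast across a full `[B, H, S, S]` score tensor. The all-zero `scores` tensor is a placeholder for real attention scores.

```python
import torch

B, H, S = 2, 3, 4
scores = torch.zeros(B, H, S, S)

# Keep-mask with singleton batch and head dims, True = keep.
mask = torch.tril(torch.ones(S, S, dtype=torch.bool)).view(1, 1, S, S)

# masked_fill writes where its argument is True, so invert the keep-mask.
masked = scores.masked_fill(~mask, -1e9)

print(masked.shape)   # same shape as scores
print(masked[1, 2])   # every (batch, head) slice gets the same pattern
```

Note that `masked_fill` (without the trailing underscore) leaves `scores` itself unchanged, which is why the result is assigned to a new name.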

**After you code:** run the tiny deterministic test. Look for **“✅ Tiny SDPA test passed”**.


In [None]:
# ===== TODO (Student) =====
# Implement scaled_dot_product_attention with an optional boolean mask.
# Tensor shapes (per head):
#   q, k, v: [B, H, S, D_head]  -- we compute attention over the last two dims
#   mask:   [B, 1, S, S] or [1, 1, S, S] (broadcast over batch/heads) -- True means "keep", False means "mask out"
# Return:
#   out:  [B, H, S, D_head] 
#   attn: [B, H, S, S]  (softmax weights)

def scaled_dot_product_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    >>> B,H,S,D = 1,1,2,2
    >>> q = torch.tensor([[[[10., 0.],[0., 10.]]]])  # sharp self-similarity
    >>> k = q.clone()
    >>> v = torch.tensor([[[[1., 2.],[3., 4.]]]])
    >>> out, attn = scaled_dot_product_attention(q, k, v, None)
    >>> attn.shape, out.shape
    (torch.Size([1, 1, 2, 2]), torch.Size([1, 1, 2, 2]))
    """
    # q,k,v: [B,H,S,D]
    d_head = q.size(-1)
    scores = ...  #TODO: compute attention scores from q, k, size should be [B,H,S,S]
    if mask is not None:
        # mask: True = keep, False = mask out
        # TODO: replace the forbidden positions with a large negative number (-1e9)

        pass
    
    scores = scores - scores.amax(dim=-1, keepdim=True) # to make scores numerically stable
    attn = torch.softmax(scores, dim=-1)  # obtain the attention weights, shape [B,H,S,S]
    out = torch.matmul(attn, v)           # obtain the new representation, shape [B,H,S,D]
    return out, attn

In [None]:
def tiny_test_sdpa():
    # Shapes: B=1, H=1, S=2, D=1  → super small and inspectable
    B,H,S,D = 1,1,2,1

    # Choose q,k = zeros so QK^T/√d_k = [[0,0],[0,0]]
    # -> softmax rows become [0.5, 0.5] without a mask (by symmetry).
    q = torch.zeros(B, H, S, D, dtype=torch.float32)
    k = torch.zeros(B, H, S, D, dtype=torch.float32)

    # Pick easy values for v so the weighted sums are trivial:
    # token0 value = 2, token1 value = 0
    v = torch.tensor([[[[2.0], [0.0]]]], dtype=torch.float32)  # [1,1,2,1]

    # --- Case 1: No mask ---
    out, attn = scaled_dot_product_attention(q, k, v, mask=None)

    # Expected: every row = [0.5, 0.5] (since scores are all zeros)
    exp_attn = torch.tensor([[[[0.5, 0.5],
                               [0.5, 0.5]]]], dtype=torch.float32)

    # Output = attn @ v = 0.5*2 + 0.5*0 = 1 for each row
    exp_out  = torch.tensor([[[[1.0],
                               [1.0]]]], dtype=torch.float32)

    assert attn.shape == (B,H,S,S) and out.shape == (B,H,S,D)
    assert torch.allclose(attn, exp_attn, atol=1e-7), f"Unmasked attn mismatch:\n{attn}\n!=\n{exp_attn}"
    assert torch.allclose(out,  exp_out,  atol=1e-7), f"Unmasked out mismatch:\n{out}\n!=\n{exp_out}"

    # --- Case 2: With causal mask (True=keep, False=mask) ---
    # Row 0 may only attend to token 0; Row 1 may attend to {0,1}.
    mask = torch.tensor([[[[ True, False],
                           [ True,  True]]]], dtype=torch.bool)  # [1,1,2,2]

    out_m, attn_m = scaled_dot_product_attention(q, k, v, mask=mask)

    # Expected attention:
    #   row0: [1, 0]  (only token 0 allowed)
    #   row1: [0.5, 0.5]  (both allowed, equal scores)
    exp_attn_m = torch.tensor([[[[1.0, 0.0],
                                 [0.5, 0.5]]]], dtype=torch.float32)
    # Expected outputs:
    #   row0: 1*2 + 0*0 = 2
    #   row1: 0.5*2 + 0.5*0 = 1
    exp_out_m  = torch.tensor([[[[2.0],
                                 [1.0]]]], dtype=torch.float32)

    assert attn_m.shape == (B,H,S,S) and out_m.shape == (B,H,S,D)
    assert torch.allclose(attn_m, exp_attn_m, atol=1e-7), f"Masked attn mismatch:\n{attn_m}\n!=\n{exp_attn_m}"
    assert torch.allclose(out_m,  exp_out_m,  atol=1e-7), f"Masked out mismatch:\n{out_m}\n!=\n{exp_out_m}"

    print("✅ Tiny SDPA test passed")

# Run it:
tiny_test_sdpa()


## Task C — Split/Combine Heads 

### 1) Why do we split into **multiple heads**?
**Multi-Head Self-Attention (MHA)** computes **several attention maps in parallel**. Each “head”:
- looks at the sequence through a **different learned projection** (its own k,q,v matrices),
- produces its **own attention distribution** (its own softmax over positions),
- returns a **head-specific context vector**.

This lets the model **attend to different things at the same time** (e.g., one head tracks next-token agreement, another tracks long-range rhymes). With a **single head**, every token has only **one** attention distribution, so it must compromise between competing patterns. With MHA, each head computes its **own** scaled dot-product attention and softmax. Finally we **concatenate** all head outputs and mix them with a single output projection.


To do that, we need two shape transforms:
- `split_heads`: **[B, S, D_model] → [B, H, S, D_head]** (prepare per-head tensors)
- `combine_heads`: **[B, H, S, D_head] → [B, S, D_model]** (stitch heads back together)

This is *only* a rearrangement of the same numbers—no math, just shape moves—so attention can run independently per head.

### 2) PyTorch ops you’ll use (and why)
- **`view`**: change tensor dimensions **without** changing data.  
  - `view` is fast but **requires contiguous memory**; it may error if the tensor is not contiguous.
- **`transpose(dim0, dim1)`**: swap **two** dimensions (no data copy; just a view).  
  For MHA we want to go from `[B, S, H, D_head]` to `[B, H, S, D_head]` (swap axes 1 and 2), so we use `transpose(1, 2)`.
- **`contiguous()`**: In PyTorch, a tensor is contiguous if its elements are laid out in memory in the same order as its indices (row-major, with no gaps). Some operations, such as `view()`, require a contiguous input. `tensor.contiguous()` returns the tensor itself if it is already contiguous and a contiguous copy otherwise, so that a subsequent `view(...)` is valid.  
  After `transpose`, the tensor is often **not** contiguous; call `.contiguous()` before `view`.
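
A short sketch of the contiguity issue described above, on a toy tensor with arbitrary sizes. It shows `view` failing right after `transpose` and succeeding once `.contiguous()` is called.

```python
import torch

B, H, S, D_head = 2, 2, 4, 4
x = torch.arange(B * H * S * D_head, dtype=torch.float32).view(B, H, S, D_head)

t = x.transpose(1, 2)            # [B, S, H, D_head], same storage, new strides
print(t.is_contiguous())

try:
    t.view(B, S, H * D_head)     # view cannot merge H and D_head in this layout
    view_failed = False
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)

merged = t.contiguous().view(B, S, H * D_head)   # copy, then view works
print(merged.shape)
```

`reshape` is a common alternative: it returns a view when possible and silently copies otherwise.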


### 3) Shape recipes (step-by-step)

- `split_heads(x: [B, S, D_model]) -> [B, H, S, D_head]`
1. Check that `D_model % H == 0`, and let `D_head = D_model // H`.
2. First, reshape the last dim into `(H, D_head)`:  
   `x = x.view(B, S, H, D_head)` 
3. Bring `H` in front of `S` so heads sit right after batch:  
   `x = x.transpose(1, 2)` → shape `[B, H, S, D_head]`
4. Ensure contiguity: `x = x.contiguous()` (safe to add here).

- `combine_heads(x: [B, H, S, D_head]) -> [B, S, D_model]`
1. Swap back the `H` and `S` dims:  
   `x = x.transpose(1, 2)` → `[B, S, H, D_head]`
2. Make it contiguous (required if you’ll `view`):  
   `x = x.contiguous()`
3. Flatten the last two dims:  
   `x = x.view(B, S, H * D_head)` (or `reshape`) → `[B, S, D_model]`
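
To see what the two recipes do to the data, the sketch below applies them inline to a small tensor (sizes chosen to match the test cell, but otherwise arbitrary). Each head ends up holding one contiguous chunk of the feature columns, and combining undoes the split exactly.

```python
import torch

B, S, D_model, H = 2, 4, 8, 2
D_head = D_model // H
x = torch.arange(B * S * D_model, dtype=torch.float32).view(B, S, D_model)

# Recipe for splitting, written inline.
heads = x.view(B, S, H, D_head).transpose(1, 2).contiguous()   # [B, H, S, D_head]

# Head h holds columns h*D_head .. (h+1)*D_head - 1 of the original features.
for h in range(H):
    assert torch.equal(heads[:, h], x[:, :, h * D_head:(h + 1) * D_head])

# Recipe for combining, written inline; it undoes the split exactly.
back = heads.transpose(1, 2).contiguous().view(B, S, H * D_head)
print(torch.equal(back, x))
```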

**After you code:** run the test — you should see **“✅ Task C tests passed”**.


In [None]:
# ===== TODO (Student) =====
# Implement split_heads and combine_heads round-trip.

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    x: [B, S, D_model] -> [B, H, S, D_head]
    """
    # TODO: Implement split_heads
    # - Check that D_model % H == 0
    # - Let D_head = D_model // H
    # - Reshape the last dim into (H, D_head)
    # - Bring H in front of S
    # - Ensure contiguity
    # - Return the transposed tensor
    pass

def combine_heads(x: torch.Tensor) -> torch.Tensor:
    """
    x: [B, H, S, D_head] -> [B, S, D_model]
    """
    # TODO: Implement combine_heads
    # - Swap back the H and S dims
    # - Make it contiguous
    # - Flatten the last two dims
    # - Return the reshaped tensor
    pass


In [None]:
B,S,D_model,H = 2,4,8,2
x = torch.arange(B*S*D_model, dtype=torch.float32).reshape(B,S,D_model)
xh = split_heads(x, H)
xc = combine_heads(xh)
assert xh.shape == (B,H,S,D_model//H), "split_heads wrong shape"
assert xc.shape == (B,S,D_model), "combine_heads wrong shape"
assert torch.allclose(x, xc), "combine(split(x)) must equal x"
print("✅ Task C tests passed")


## Task D — Implement `MultiHeadSelfAttention.forward(x, mask)`

**Steps.**
1) Linear projections: `Q = Wq x`, `K = Wk x`, `V = Wv x`.  
2) `split_heads` for each.  
3) Call your `scaled_dot_product_attention(Q, K, V, mask)`.  
4) `combine_heads` and project with `Wo`. Return `(y, attn_map)`.

**Sanity check.** With `n_heads=1` and all projection matrices set to identity, MHA must match the single-head SDPA on the same input (we test this for you). Expect **“✅ Task D tests passed”**.
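
The four steps can be traced end to end with a functional sketch. Everything here is an illustrative stand-in: the weights `Wq`, `Wk`, `Wv`, `Wo` are random matrices rather than `nn.Linear` layers, and PyTorch's built-in `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later) replaces the function from Task B, so only the output is returned and no attention map. The check at the end uses causality: changing the last token should leave all earlier output positions untouched.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D, H = 1, 6, 16, 4
D_head = D // H

# Hypothetical random projection weights standing in for Wq, Wk, Wv, Wo.
Wq, Wk, Wv, Wo = (torch.randn(D, D) / D ** 0.5 for _ in range(4))

def toy_mha(x: torch.Tensor) -> torch.Tensor:
    """Causal multi-head self-attention following the four steps above."""
    def split(t: torch.Tensor) -> torch.Tensor:      # [B,S,D] -> [B,H,S,D_head]
        return t.view(B, S, H, D_head).transpose(1, 2)
    q, k, v = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    # Built-in attention used as a stand-in; it returns only the output.
    ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    ctx = ctx.transpose(1, 2).contiguous().view(B, S, D)   # combine heads
    return ctx @ Wo

x = torch.randn(B, S, D)
y = toy_mha(x)

# Changing the last token must not change any earlier output position.
x2 = x.clone()
x2[:, -1] += 1.0
y2 = toy_mha(x2)
print(y.shape, torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-5))
```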


In [None]:
# ===== TODO (Student) =====
# Implement the forward pass of MHA:
# 1) Project X -> Q,K,V
# 2) split to heads
# 3) scaled_dot_product_attention (with mask)
# 4) combine heads -> output projection

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model, self.n_heads = d_model, n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x: [B,S,D], mask: [B,1,S,S] or None
        returns: y: [B,S,D], attn: [B,H,S,S]
        """
        B,S,D = x.shape
        # TODO: Implement the forward pass of MHA:
        # 1) Project X -> Q,K,V
        # 2) split to heads
        # 3) scaled_dot_product_attention (with mask)
        # 4) combine heads and output projection
        # 5) return output and attention weights
        pass
        


In [None]:
# Sanity test with 1 head and identity projections → attention reduces to single-head SDPA
B,S,D = 1,3,4
x = torch.randn(B,S,D)
mha = MultiHeadSelfAttention(d_model=D, n_heads=1)

# Set projections to identity
with torch.no_grad():
    mha.q_proj.weight.copy_(torch.eye(D))
    mha.k_proj.weight.copy_(torch.eye(D))
    mha.v_proj.weight.copy_(torch.eye(D))
    mha.o_proj.weight.copy_(torch.eye(D))

mask = make_causal_mask(S)
y, attn = mha(x, mask=mask)
assert y.shape == (B,S,D) and attn.shape == (B,1,S,S)
# Compare with manual SDPA on the same x
q = x.unsqueeze(1)  # [B,1,S,D]
k = x.unsqueeze(1)
v = x.unsqueeze(1)
manual, manual_attn = scaled_dot_product_attention(q, k, v, mask=mask)
assert torch.allclose(y, manual.squeeze(1), atol=1e-5), "MHA(1 head) must match SDPA"
print("✅ Task D tests passed")


## Quick Run / Inspect

We have provided the building blocks for our transformer LM. You can just run the following code and see the logits, loss, and attention maps.

We keep defaults very small so it runs in minutes on CPU:
- Model: `d_model=128`, `n_heads=4`, `n_layers=2`, GELU FFN, no dropout.
- Batch/bptt sizes are tiny in the toy trainer.

You can train for a few hundred steps. Then try the `generate(...)` helper:
- **Temperature** < 1.0 → safer/peakier choices; > 1.0 → more diverse.
- **Top-k** keeps only the k most likely tokens.
- **Top-p** (nucleus) keeps the smallest set whose cumulative probability ≥ p.

If you skip training, you can still run `generate(...)` on the untrained model to see random-ish outputs.
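
The three sampling knobs can be tried on a toy distribution in plain Python. This is a simplified sketch with made-up numbers and hypothetical helper names (`softmax`, `top_k_filter`, `top_p_filter`); it is not the `generate(...)` helper itself.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Keep the k largest probabilities and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

def top_p_filter(probs, p):
    """Keep the smallest high-probability prefix whose mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = [], 0.0
    for i in order:
        keep.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]

logits = [2.0, 1.0, 0.5, -1.0]           # made-up scores for four tokens
cool = softmax(logits, temperature=0.5)  # sharper than temperature 1
warm = softmax(logits, temperature=2.0)  # flatter than temperature 1
print(max(cool), max(warm))

probs = [0.5, 0.3, 0.15, 0.05]
print(top_k_filter(probs, 2))            # only the two largest survive
print(top_p_filter(probs, 0.9))          # 0.5 + 0.3 + 0.15 reaches 0.9
```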


In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, ff_mult: int = 4, p_drop: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadSelfAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_mult * d_model),
            nn.GELU(),
            nn.Linear(ff_mult * d_model, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]):
        attn_out, attn = self.mha(self.ln1(x), mask=mask)
        x = x + self.dropout(attn_out)
        ff_out = self.ff(self.ln2(x))
        x = x + self.dropout(ff_out)
        return x, attn

class TinyTransformerLM(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = TokenAndPositionEmbedding(cfg)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.ff_mult) for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None):
        # x: [B,S]
        B,S = x.shape
        h = self.embed(x)  # [B,S,D]
        mask = make_causal_mask(S).to(x.device)
        attn_maps = []
        for blk in self.blocks:
            h, attn = blk(h, mask)
            attn_maps.append(attn)  # list of [B,H,S,S]
        h = self.ln_f(h)
        logits = self.lm_head(h)  # [B,S,V]
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, attn_maps


In [None]:
cfg = ModelConfig(d_model=128, n_heads=4, n_layers=2, max_seq_len=128, pe_type="sinusoidal")
model = TinyTransformerLM(cfg).to(DEVICE)

# One tiny batch
x, y = batched_random_chunks(train_ids, block_size=64, batch_size=2)
logits, loss, attn_maps = model(x, y)
print("logits:", logits.shape, "| loss:", float(loss))
print("attn maps:", [a.shape for a in attn_maps])


## Task E — Train & Sample from the Tiny Transformer LM

**What you do:** We provide the training loop. Just run the training cell, then use the `generate(...)` helper to sample text. Training should take under 5 minutes on CPU.

Expected outcomes:
- Val loss curve: decreasing quickly at first, then flattening. If val loss rises while train loss keeps falling, stop earlier. Both the training and validation losses should end below ~1.9.
- Sampled texts: the samples should have "Shakespeare-ish structure" (uppercase speaker tags, dialogues, line breaks) and plausible (fake) words. It's normal that the sentences are not semantically consistent. 
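
For a sense of scale, it can help to compare the ~1.9 target with the loss of a model that guesses uniformly over the vocabulary. The sketch below assumes the 65-character vocabulary printed earlier and natural-log cross-entropy (what `F.cross_entropy` reports); an untrained model will usually start somewhere near this baseline, though the exact value depends on initialization.

```python
import math

vocab_size = 65                       # "Unique chars" printed earlier in the notebook
uniform_loss = math.log(vocab_size)   # cross-entropy of a uniform guess, in nats
print(f"uniform baseline: {uniform_loss:.3f}")

target_loss = 1.9
print(f"perplexity at loss {target_loss}: {math.exp(target_loss):.2f}")
```

A loss of 1.9 corresponds to being roughly as uncertain as a choice among 6 to 7 equally likely characters, compared with 65 at the uniform baseline.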

In [None]:
def train_tiny(model: nn.Module, steps=300, batch_size=64, block_size=128, lr=5e-3):
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95))
    for t in range(steps):
        x, y = batched_random_chunks(train_ids, block_size, batch_size)
        logits, loss, _ = model(x, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if (t+1) % 50 == 0:
            with torch.no_grad():
                vx, vy = batched_random_chunks(val_ids, block_size, batch_size)
                _, vloss, _ = model(vx, vy)
            print(f"step {t+1:4d} | train {loss.item():.3f} | val {vloss.item():.3f}")
    return model

# Run it (comment this line out to skip training if you are short on CPU time)
model = train_tiny(model, steps=500)

In [None]:
@torch.no_grad()
def generate(model, prompt: str, max_new_tokens=200, temperature=1.0, top_k=None, top_p=None):
    model.eval()
    ids = encode(prompt).unsqueeze(0).to(DEVICE)  # [1,S]
    for _ in range(max_new_tokens):
        ids_cond = ids[:, -model.cfg.max_seq_len:]
        logits, _, _ = model(ids_cond, None)
        logits = logits[:, -1, :] / max(1e-6, temperature)
        probs = F.softmax(logits, dim=-1)

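        if top_k is not None:
            k = min(top_k, probs.size(-1))  # guard against top_k larger than the vocab
            v, _ = torch.topk(probs, k=k, dim=-1)
            thresh = v[:, -1].unsqueeze(-1)
            probs = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if top_p is not None:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            # Keep a token while the mass before it is still below top_p; this is the
            # smallest prefix whose cumulative probability reaches top_p.
            keep = (cum - sorted_probs) < top_p
            keep[..., 0] = True  # always keep the most likely token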
            filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
            filtered = filtered / filtered.sum(dim=-1, keepdim=True)
            # sample from filtered
            next_id = torch.multinomial(filtered, num_samples=1)
            next_token = sorted_idx.gather(-1, next_id)
        else:
            next_token = torch.multinomial(probs, num_samples=1)

        ids = torch.cat([ids, next_token], dim=1)
    return decode(ids[0].cpu())

# Try different knobs after (optionally) training:
print(generate(model, "ROMEO:\n", max_new_tokens=200, temperature=0.8, top_k=50))
print("--------------------------------")
print(generate(model, "JULIET:\n", max_new_tokens=200, temperature=1.0, top_p=0.9))


## Mini Reflection

- Switch `pe_type` between `"none"` and `"sinusoidal"`. After a short run, which produced the better validation loss and the more coherent samples?
- Try a couple of `temperature` and `top_k`/`top_p` settings in `generate`. What happens to diversity vs repetition?

Write 2–4 bullet points on what you observe (1–2 sentences each).


In [None]:
# Write down your observations here.