In [None]:
import numpy as np
import pandas as pd
import tiktoken as tk
import torch
import torch.nn as nn
import torch.nn.functional as F

## Input/Output for MultiHeadAttentionNew

- Project Q, K, and V for all heads at once, using the full-width linear layers `self.W_q`, `self.W_k`, `self.W_v`.
- Reshape to (B, num_heads, L, d_head) to separate heads.
- Compute scaled dot-product attention for all heads in one batched matmul: $\text{scores} = QK^\top / \sqrt{d_{head}}$.
- An optional mask blocks attention to selected positions: entries where `mask == 0` are excluded (e.g. a causal mask for autoregressive tasks).
- Multiply attention weights by V, then merge heads back: (B, L, d_model).
- Final linear W_o mixes information across heads.

## Input/Output for MultiHeadAttentionOld

- Implements multi-head attention **naively by creating independent SingleHeadAttention objects** for each head.  
- `num_heads` separate SingleHeadAttention instances, each mapping `d_model -> d_head`.  
- Forward pass:
  - Loops through each head, runs forward separately: **slow, non-batched**.  
  - Concatenates outputs along feature dimension → shape `(B, L, d_model)`.  
  - Final linear `W_o` mixes all heads back together.  


In [None]:
class MultiHeadAttentionOld(nn.Module):
    """Naive multi-head attention built from independent single-head modules.

    Each head is a separate ``SingleHeadAttention`` instance that is called
    one at a time in ``forward``, so the heads are not batched together.
    ``MultiHeadAttentionNew`` is the batched equivalent.

    Args:
        d_model: model (embedding) width; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head is constructed with
            width ``d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # One independent module per head, each built as
        # SingleHeadAttention(d_model, d_head). SingleHeadAttention is defined
        # elsewhere — presumably it maps d_model -> d_head features; confirm.
        self.heads = nn.ModuleList(
            [SingleHeadAttention(d_model, self.d_head) for _ in range(num_heads)]
        )

        # Output projection that mixes the concatenated head outputs.
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k=None, v=None, mask=None):
        """Run every head on (q, k, v, mask), concatenate, and project.

        Args:
            q: query input.
            k: key input; defaults to ``q`` (self-attention), matching
                ``MultiHeadAttentionNew.forward``.
            v: value input; defaults to ``q``.
            mask: optional mask, passed unchanged to every head.

        Returns:
            ``W_o`` applied to the head outputs concatenated on the last axis.
        """
        # Default to self-attention so both MHA variants share one call signature.
        if k is None:
            k = q
        if v is None:
            v = q

        # Each head runs in its own Python-level call: slow, non-batched.
        # Prefer MultiHeadAttentionNew for real use.
        out_per_head = [head(q, k, v, mask) for head in self.heads]

        # Concatenate along the feature axis. Assuming each head returns d_head
        # features, the result is d_model wide, matching W_o's input.
        concat = torch.cat(out_per_head, dim=-1)
        return self.W_o(concat)

In [None]:
class MultiHeadAttentionNew(nn.Module):
    """Batched multi-head attention: all heads are computed with shared matmuls.

    Args:
        d_model: model (embedding) width; must be divisible by ``num_heads``.
        num_heads: number of heads; each head gets ``d_model // num_heads``
            features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # One full-width projection per role. forward() splits the d_model
        # output features into num_heads chunks of d_head.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Final output mixer across heads.
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k=None, v=None, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: query input of shape (B, L_q, d_model).
            k: key input of shape (B, L_kv, d_model); defaults to ``q``.
            v: value input of shape (B, L_kv, d_model); defaults to ``q``.
                Its sequence length must match ``k``'s.
            mask: optional tensor broadcastable to (B, num_heads, L_q, L_kv).
                Score entries where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (B, L_q, d_model).
        """
        if k is None:
            k = q
        if v is None:
            v = q

        B, L_q, _ = q.shape

        # (B, L, d_model) -> (B, num_heads, L, d_head). Keys/values keep their
        # own sequence length (-1), so cross-attention with L_kv != L_q works.
        Q = self.W_q(q).view(B, -1, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_k(k).view(B, -1, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_v(v).view(B, -1, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention for all heads at once:
        # scores has shape (B, num_heads, L_q, L_kv).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9, which is out of range for float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, V)  # (B, num_heads, L_q, d_head)

        # Merge heads back to (B, L_q, d_model).
        out = out.transpose(1, 2).contiguous().view(B, L_q, -1)
        return self.W_o(out)