# Implement Attention from Scratch
### Problem Statement
Multi-Head Attention (MHA) is the bread-and-butter of the Transformer architecture. It enables the model to **jointly attend** to information from different representation subspaces at different positions.

Your goal is to implement MHA from scratch using PyTorch, reproducing what `torch.nn.MultiheadAttention` does: projecting Q, K, V for each head, computing attention weights, applying them to V, and concatenating the outputs across all heads.

---

### Requirements

1. **Linear Projections for Q, K, V**
   - Project the inputs `q`, `k`, `v` with one linear layer each, every projection producing `d_model` output features.
   - Split each projection into `num_heads` heads of size `d_head = d_model // num_heads`.

2. **Scaled Dot-Product Attention per Head**
   - Compute attention scores:  
     `scores = Q @ Kᵀ / sqrt(d_head)`
   - Apply an optional `mask` before softmax.
   - Use the scores to weight `V`.

3. **Combine the Heads**
   - Concatenate the outputs of all heads.
   - Apply a final linear projection to restore the shape: `(batch_size, seq_len, d_model)`.

4. **Validate Against PyTorch's Reference**
   - Test your output against `torch.nn.MultiheadAttention` using the same input tensors.
   - Check for numerical closeness using `torch.allclose()`.
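
As a rough illustration of requirement 2, the sketch below computes scaled dot-product attention by hand and compares it with PyTorch's built-in `F.scaled_dot_product_attention`. The tensor sizes, the causal mask, and the names `allowed`, `manual`, and `builtin` are assumptions chosen for this example; they are not part of the exercise. It assumes PyTorch 2.0 or newer and inputs that are already split into heads.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, d_head = 2, 2, 4, 4  # illustrative sizes only
Q = torch.rand(B, H, L, d_head)
K = torch.rand(B, H, L, d_head)
V = torch.rand(B, H, L, d_head)

# Boolean mask where True means "this key position may be attended to".
# A lower-triangular (causal) mask is used purely as an example.
allowed = torch.tril(torch.ones(L, L, dtype=torch.bool))

# scores = Q @ K^T / sqrt(d_head), shape (B, H, L, L)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
# Masked positions get -inf so that softmax assigns them zero weight.
scores = scores.masked_fill(~allowed, float("-inf"))
# Normalise over the last dimension (the key positions).
weights = torch.softmax(scores, dim=-1)
# Weighted sum of the value vectors, shape (B, H, L, d_head)
manual = weights @ V

builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=allowed)
print(torch.allclose(manual, builtin, atol=1e-6))
```

Each row of `weights` sums to 1, so every output vector is a convex combination of the value vectors that the mask leaves visible.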

---

### Constraints

- ✅ Use only PyTorch operations.
- ✅ Make sure all tensors are reshaped properly when splitting and combining heads.
- ✅ Support optional masking.
- ✅ Must match `torch.nn.MultiheadAttention` output when heads and shape are aligned.

---

<details>
  <summary>💡 Hint</summary>

  - Use `.view()` and `.transpose()` to shape Q, K, V to `(batch_size, num_heads, seq_len, d_head)`.
  - Softmax should be applied over the **last dimension** (attention scores across sequence).
  - Use `.contiguous().view()` to flatten the multi-head outputs back into `(batch_size, seq_len, d_model)`.
  - Match PyTorch's behavior using the same projections and batch-first format.

</details>
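
The reshaping steps in the hint can be sketched on their own, without any attention math. The example below uses a small tensor whose values encode their position, so the effect of splitting and merging heads is easy to check. The sizes and the variable names `heads`, `merged`, and `wrong` are illustrative assumptions.

```python
import torch

batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads

# Values 0, 1, 2, ... make it easy to see where each element ends up.
x = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
x = x.view(batch_size, seq_len, d_model)

# Split: (B, L, d_model) -> (B, L, H, d_head) -> (B, H, L, d_head)
heads = x.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)

# Head h holds features [h * d_head, (h + 1) * d_head) of every token.
assert torch.equal(heads[:, 1], x[:, :, d_head:])

# Merge: (B, H, L, d_head) -> (B, L, H, d_head) -> (B, L, d_model)
# In general the transposed tensor is not contiguous, hence .contiguous().
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
assert torch.equal(merged, x)

# A common mistake: viewing directly as (B, H, L, d_head) has the right
# shape but mixes features of different tokens into the same head.
wrong = x.view(batch_size, num_heads, seq_len, d_head)
print(torch.equal(wrong, heads))
```

Using `.reshape()` instead of `.contiguous().view()` is an equivalent choice for the merge step; it copies the data only when it has to.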

---

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"  # override: every tensor in this notebook stays on the CPU

torch.Size([3, 4, 8])


In [None]:

# def multi_head_attention(q, k, v, num_heads, d_model, mask=None):
#     """
#     Implements multi-head attention.
    
#     Args:
#         q (Tensor): Query tensor of shape (batch_size, seq_len, d_model)
#         k (Tensor): Key tensor of shape (batch_size, seq_len, d_model)
#         v (Tensor): Value tensor of shape (batch_size, seq_len, d_model)
#         num_heads (int): Number of attention heads
#         d_model (int): Total embedding dimension
#         mask (Tensor, optional): Masking tensor for attention
        
#     Returns:
#         Tensor: Multi-head attention output of shape (batch_size, seq_len, d_model)
#     """
#     batch_size, seq_len, d_q = q.shape

#     # original paper doesn't mention presence of bias term in the projection layer
#     Wq = torch.nn.Linear(in_features=q.shape[-1], out_features=d_model, bias=False)
#     Wk = torch.nn.Linear(in_features=k.shape[-1], out_features=d_model, bias=False)
#     Wv = torch.nn.Linear(in_features=v.shape[-1], out_features=d_model, bias=False)
#     Wc = torch.nn.Linear(in_features=d_model, out_features=d_model, bias=False)

#     # # projections for Q, K, V matrices
#     # q = q @ Wq # (B, L_q, d_model)
#     # k = k @ Wk # (B, L_k, d_model)
#     # v = v @ Wv # (B, L_v, d_model)

#     # projections for Q, K, V matrices
#     q = Wq(q) # (B, L_q, d_model)
#     k = Wk(k) # (B, L_k, d_model)
#     v = Wv(v) # (B, L_v, d_model)

#     d_mha = d_model // num_heads
#     print(f"d_model={d_model} | num_heads={num_heads} | d_mha={d_mha}")

#     assert d_mha*num_heads == d_model, "incompatible num_heads and d_model combination"

#     # split Q, K, V into multiple heads
#     q = q.view(batch_size, seq_len, num_heads, d_mha) # (B, L_k, H, d_mha)
#     k = k.view(batch_size, seq_len, num_heads, d_mha) # (B, L_k, H, d_mha)
#     v = v.view(batch_size, seq_len, num_heads, d_mha) # (B, L_k, H, d_mha)

#     # reshape to move number of heads to 2nd axis
#     q = q.transpose(1,2) # (B, H, L_k, d_mha)
#     k = k.transpose(1,2) # (B, H, L_k, d_mha)
#     v = v.transpose(1,2) # (B, H, L_k, d_mha)

#     # apply attention
#     attention_per_head = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H, L_k, d_mha)
#     attention_per_head = attention_per_head.permute(0,2,1,3) # (B, L_k, H, d_mha)
#     print(f"attention_per_head.shape = {attention_per_head.shape}")
#     print(f"(batch_size, seq_len, d_model) = {(batch_size, seq_len, d_model)}")
#     attention_concatenated = attention_per_head.reshape(batch_size, seq_len, d_model) # (B, L_k, d_model)
    
#     # mha = attention_concatenated @ Wc # (B, L_k, d_model)
#     mha = Wc(attention_concatenated) # (B, L_k, d_model)
#     return mha

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomMultiHeadAttention(torch.nn.Module):
    def __init__(self, num_heads, d_model, bias):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.bias = bias
        self.d_mha = self.d_model // self.num_heads
        assert self.d_mha*self.num_heads == self.d_model, "incompatible num_heads and d_model combination"
        print(f"d_model={self.d_model} | num_heads={self.num_heads} | d_mha={self.d_mha}")

        self.Wq = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
        self.Wk = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
        self.Wv = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
        self.Wc = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
    
    def forward(self, q, k, v, mask=None):
        # q: (B, L_q, d_model); k and v: (B, L_k, d_model)
        batch_size, seq_len, _ = q.shape

        # linear projections for the Q, K, V matrices
        q = self.Wq(q) # (B, L_q, d_model)
        k = self.Wk(k) # (B, L_k, d_model)
        v = self.Wv(v) # (B, L_k, d_model)

        # split Q, K, V into multiple heads
        q = q.view(batch_size, q.shape[-2], self.num_heads, self.d_mha) # (B, L_q, H, d_mha)
        k = k.view(batch_size, k.shape[-2], self.num_heads, self.d_mha) # (B, L_k, H, d_mha)
        v = v.view(batch_size, v.shape[-2], self.num_heads, self.d_mha) # (B, L_k, H, d_mha)

        # move the head axis to position 1
        q = q.transpose(1,2) # (B, H, L_q, d_mha)
        k = k.transpose(1,2) # (B, H, L_k, d_mha)
        v = v.transpose(1,2) # (B, H, L_k, d_mha)

        # apply attention; a boolean attn_mask here uses True = "may attend",
        # the opposite of nn.MultiheadAttention's attn_mask convention
        attention_per_head = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H, L_q, d_mha)
        attention_per_head = attention_per_head.permute(0,2,1,3) # (B, L_q, H, d_mha)
        print(f"attention_per_head.shape = {attention_per_head.shape}")
        print(f"(batch_size, seq_len, d_model) = {(batch_size, seq_len, self.d_model)}")
        # concatenate the heads along the feature axis
        attention_concatenated = attention_per_head.reshape(batch_size, seq_len, self.d_model) # (B, L_q, d_model)
        
        # final output projection
        mha = self.Wc(attention_concatenated) # (B, L_q, d_model)
        return mha

In [7]:
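# Testing on data & compare (bias-free variant)
torch.manual_seed(2025)
multihead_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)
output, _ = multihead_attn(q, k, v)
print(output)

# A freshly initialised custom module has its own random weights, so copy
# the reference weights over before comparing the two outputs.
custom_attn = CustomMultiHeadAttention(num_heads=num_heads, d_model=d_model, bias=False)
with torch.no_grad():
    W_q, W_k, W_v = multihead_attn.in_proj_weight.chunk(3, dim=0)
    custom_attn.Wq.weight.copy_(W_q)
    custom_attn.Wk.weight.copy_(W_k)
    custom_attn.Wv.weight.copy_(W_v)
    custom_attn.Wc.weight.copy_(multihead_attn.out_proj.weight)
output_custom = custom_attn(q, k, v)

assert torch.allclose(output_custom, output, atol=1e-06, rtol=1e-05) # Check if they are close enough.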

In [8]:
# # # --- 2. The Setup ---
# # d_model = 512
# # num_heads = 8
# # seq_len = 10
# # batch_size = 2

# # Synthetic data
# torch.manual_seed(42)
# batch_size = 3
# seq_len = 4
# d_model = 8
# num_heads = 2

# # Input Data (Random)
# x = torch.randn(batch_size, seq_len, d_model)

# --- 3. Instantiate Both Models ---
# PyTorch Native Implementation
# We use batch_first=True to match your shape convention (Batch, Seq, Feature)
pytorch_mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True, bias=True)

# Your Custom Implementation
custom_mha = CustomMultiHeadAttention(d_model=d_model, num_heads=num_heads, bias=True)

# --- 4. THE MAGIC STEP: Copy Weights ---
# PyTorch stores Q, K, V weights in one giant matrix called `in_proj_weight`
# shape: [3 * d_model, d_model]. We chunk it into 3 parts.
with torch.no_grad():
    # 1. Slice the weights
    W_q_check, W_k_check, W_v_check = pytorch_mha.in_proj_weight.chunk(3, dim=0)
    b_q_check, b_k_check, b_v_check = pytorch_mha.in_proj_bias.chunk(3, dim=0)

    # 2. Copy to your custom layers
    custom_mha.Wq.weight.copy_(W_q_check)
    custom_mha.Wq.bias.copy_(b_q_check)
    
    custom_mha.Wk.weight.copy_(W_k_check)
    custom_mha.Wk.bias.copy_(b_k_check)
    
    custom_mha.Wv.weight.copy_(W_v_check)
    custom_mha.Wv.bias.copy_(b_v_check)
    
    # 3. Copy the Output Projection weights (PyTorch calls it `out_proj`)
    custom_mha.Wc.weight.copy_(pytorch_mha.out_proj.weight)
    custom_mha.Wc.bias.copy_(pytorch_mha.out_proj.bias)

# --- 5. Run Comparison ---
pytorch_mha.eval() # Disable dropout
custom_mha.eval()  # Disable dropout

# PyTorch Forward
# The module returns (attn_output, attn_output_weights); with need_weights=False the weights are None, so keep only the first item
pytorch_out, _ = pytorch_mha(q,k,v, need_weights=False)

# Custom Forward
custom_out = custom_mha(q,k,v)

# --- 6. Verify ---
# We check if the difference is negligible (e.g., less than 1e-5)
diff = (pytorch_out - custom_out).abs().max()
print(f"Maximum difference: {diff.item()}")

if torch.allclose(pytorch_out, custom_out, atol=1e-5):
    print("✅ SUCCESS: Your implementation matches PyTorch exactly!")
else:
    print("❌ FAILURE: The outputs diverge.")

d_model=8 | num_heads=2 | d_mha=4
attention_per_head.shape = torch.Size([3, 4, 2, 4])
(batch_size, seq_len, d_model) = (3, 4, 8)
Maximum difference: 0.0
✅ SUCCESS: Your implementation matches PyTorch exactly!
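
The comparison above runs without a mask, so the "optional masking" constraint is not exercised. The sketch below is one way to check the masked case under stated assumptions. The helper `mha_from_reference` is hypothetical and written only for this example: it redoes the reference module's computation with explicit operations, reusing the reference weights directly. The point to watch is that the two APIs use opposite boolean conventions: for `F.scaled_dot_product_attention`, `True` means a position may be attended to, while for `nn.MultiheadAttention`'s `attn_mask`, `True` means the position is blocked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, L, d_model, H = 2, 5, 8, 2  # illustrative sizes only
q = torch.rand(B, L, d_model)
k = torch.rand(B, L, d_model)
v = torch.rand(B, L, d_model)

ref = nn.MultiheadAttention(d_model, H, batch_first=True).eval()


def mha_from_reference(ref, q, k, v, allowed=None):
    """Hypothetical helper: recompute ref's output with explicit ops."""
    # The reference stores the Q, K, V projections stacked along dim 0.
    Wq, Wk, Wv = ref.in_proj_weight.chunk(3, dim=0)
    bq, bk, bv = ref.in_proj_bias.chunk(3, dim=0)
    d_head = ref.embed_dim // ref.num_heads

    def split(t):  # (B, L, d_model) -> (B, H, L, d_head)
        return t.view(t.shape[0], t.shape[1], ref.num_heads, d_head).transpose(1, 2)

    qh = split(F.linear(q, Wq, bq))
    kh = split(F.linear(k, Wk, bk))
    vh = split(F.linear(v, Wv, bv))
    out = F.scaled_dot_product_attention(qh, kh, vh, attn_mask=allowed)
    # (B, H, L, d_head) -> (B, L, d_model), then the output projection.
    out = out.transpose(1, 2).reshape(q.shape[0], q.shape[1], ref.embed_dim)
    return ref.out_proj(out)


# Causal mask in the scaled_dot_product_attention convention (True = attend).
allowed = torch.tril(torch.ones(L, L, dtype=torch.bool))
# nn.MultiheadAttention expects the inverse (True = blocked).
blocked = ~allowed

with torch.no_grad():
    ref_out, _ = ref(q, k, v, attn_mask=blocked, need_weights=False)
    own_out = mha_from_reference(ref, q, k, v, allowed=allowed)

max_diff = (ref_out - own_out).abs().max().item()
print(f"Maximum difference with a causal mask: {max_diff}")
```

Passing the same boolean mask to both implementations without inverting it would make the outputs diverge, which is an easy mistake to misread as a reshaping bug.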
