# Implement Attention from Scratch
### Problem Statement
Multi-Head Attention (MHA) is the bread-and-butter of the Transformer architecture. It enables the model to **jointly attend** to information from different representation subspaces at different positions.

Your goal is to implement MHA from scratch in PyTorch, reproducing what `torch.nn.MultiheadAttention` computes: project Q, K, V and split them into heads, compute the attention weights, apply them to V, concatenate the outputs of all heads, and apply the final output projection.

---

### Requirements

1. **Linear Projections for Q, K, V**
   - Project the inputs `q`, `k`, `v` with learned linear maps, each to `d_model` dimensions.
   - Split each projection into `num_heads` heads of size `d_head = d_model // num_heads` (so `d_model` must be divisible by `num_heads`).

2. **Scaled Dot-Product Attention per Head**
   - Compute attention scores:  
     `scores = Q @ Kᵀ / sqrt(d_head)`
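   - Apply an optional `mask` before the softmax by setting blocked scores to `-inf`, so those positions get zero weight.
   - Use the resulting attention weights to take a weighted sum over `V`.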

3. **Combine the Heads**
   - Concatenate the outputs of all heads.
   - Reshape back to `(batch_size, seq_len, d_model)` and apply a final linear output projection, which mixes information across heads.

4. **Validate Against PyTorch’s Reference**
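   - Test your output against `torch.nn.MultiheadAttention` using the same input tensors **and** the same projection weights (the module's `in_proj_weight` and `out_proj.weight`); a freshly constructed module is randomly initialized.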
   - Check for numerical closeness using `torch.allclose()`.
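Putting requirements 1 to 3 together gives the standard multi-head attention formulation, written here for a single batch element. The symbols are notation assumed for this summary, not names from the code: $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_{head}}$ are the slices of the Q/K/V projections that belong to head $i$, $W^O \in \mathbb{R}^{d_{model} \times d_{model}}$ is the output projection, and $M$ is an additive mask with $M_{ts} = 0$ where query $t$ may attend to key $s$ and $M_{ts} = -\infty$ where it is blocked. With `bias=False` there are no bias terms.

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(Q W_i^Q)(K W_i^K)^\top}{\sqrt{d_{head}}} + M\right) V W_i^V, \qquad i = 1, \dots, h, \qquad d_{head} = \frac{d_{model}}{h}
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
$$

Because `F.linear(x, W)` computes $x W^\top$, PyTorch stores these matrices transposed: each row block of `in_proj_weight` corresponds to one of $W^Q, W^K, W^V$ (all heads stacked), transposed.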

---

### Constraints

- ✅ Use only PyTorch operations.
- ✅ Make sure all tensors are reshaped properly when splitting and combining heads.
- ✅ Support optional masking.
- ✅ Must match the `torch.nn.MultiheadAttention` output when the number of heads, the tensor shapes, and the projection weights are the same.
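A minimal sketch of the masking step for a single head, assuming the convention used by the notebook code below (a mask value of 0 means "blocked"); the causal mask is only an illustrative choice. Filling blocked scores with `-inf` before the softmax gives them exactly zero weight while every row still sums to 1. Note that `torch.nn.MultiheadAttention`'s boolean `attn_mask` uses the opposite convention: `True` marks positions that may *not* be attended to.

```python
import torch

torch.manual_seed(0)
t = 4  # toy sequence length

# Raw attention scores for one head: (tgt_len, src_len)
scores = torch.randn(t, t)

# Keep-mask in the notebook's convention: True/nonzero = may attend, False/0 = blocked.
# A causal (lower-triangular) mask is used purely as an example.
keep = torch.tril(torch.ones(t, t, dtype=torch.bool))

# Blocked positions get -inf, so softmax gives them exactly zero weight.
weights = scores.masked_fill(~keep, float("-inf")).softmax(dim=-1)
print(weights)

# Each row still sums to 1 (up to float rounding).
print(weights.sum(dim=-1))
```

The first query can only see the first key, so its row is `[1, 0, 0, 0]` regardless of the raw scores.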

---

<details>
  <summary>💡 Hint</summary>

  - Use `.view()` and `.transpose()` to shape Q, K, V to `(batch_size, num_heads, seq_len, d_head)`.
  - Apply softmax over the **last dimension**, i.e. across key positions, so each query's weights sum to 1.
  - Use `.contiguous().view()` to flatten the multi-head outputs back into `(batch_size, seq_len, d_model)`.
  - Match PyTorch’s behavior by using `batch_first=True` and reusing the module's own projection weights (`in_proj_weight`, `out_proj.weight`) in your function.

</details>
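The split/merge reshapes from the hint can be checked in isolation. This sketch uses the notebook's toy sizes and shows that splitting into heads and merging back is lossless, and why `.contiguous()` is needed before `.view()` once a transpose has changed the memory layout (`.reshape()` would copy automatically instead).

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads

x = torch.rand(batch_size, seq_len, d_model)

# Split: (b, t, d_model) -> (b, t, h, d_head) -> (b, h, t, d_head)
heads = x.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
print(heads.shape)  # torch.Size([3, 2, 4, 4])

# Per-head results (e.g. from a matmul) are typically contiguous in the
# (b, h, t, d_head) layout; .contiguous() mimics that here.
out = heads.contiguous()

# Merge: swap back to (b, t, h, d_head). This transpose is not contiguous,
# so .view() would raise here without the .contiguous() call.
swapped = out.transpose(1, 2)
print(swapped.is_contiguous())  # False
merged = swapped.contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(merged, x))  # True
```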

---

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

# Forced to CPU; the toy tensors are tiny.
device = "cpu"
```

```
torch.Size([3, 4, 8])
```

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def multi_head_attention(q, k, v, num_heads, d_model, w_in, w_out, mask=None):
    """
    Implements multi-head attention (bias-free, matching `bias=False`).

    Args:
        q (Tensor): Query tensor of shape (batch_size, tgt_len, d_model)
        k (Tensor): Key tensor of shape (batch_size, src_len, d_model)
        v (Tensor): Value tensor of shape (batch_size, src_len, d_model)
        num_heads (int): Number of attention heads
        d_model (int): Total embedding dimension
        w_in (Tensor): Stacked Q/K/V projection weights, shape (3 * d_model, d_model),
            in the same layout as `nn.MultiheadAttention.in_proj_weight`
        w_out (Tensor): Output projection weight, shape (d_model, d_model)
        mask (Tensor, optional): Broadcastable to (batch_size, num_heads, tgt_len, src_len);
            positions where mask == 0 are blocked

    Returns:
        Tensor: Multi-head attention output of shape (batch_size, tgt_len, d_model)
    """
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    b, t, _ = q.shape
    s = k.shape[1]
    h, d_head = num_heads, d_model // num_heads

    # 1. Learned linear projections (reuse the reference module's weights to match it)
    w_q, w_k, w_v = w_in.chunk(3, dim=0)
    q, k, v = F.linear(q, w_q), F.linear(k, w_k), F.linear(v, w_v)

    # 2. Split d_model into heads: (b, len, d_model) -> (b, len, h, d_head)
    q = q.reshape(b, t, h, d_head)
    k = k.reshape(b, s, h, d_head)
    v = v.reshape(b, s, h, d_head)

    # 3. Scaled dot-product scores per head: (b, h, t, s)
    scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(d_head)
    if mask is not None:  # truth-testing a multi-element tensor would raise
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = scores.softmax(dim=-1)

    # 4. Weighted sum of values, then concatenate the heads back to d_model
    output = torch.einsum("bhts,bshd->bthd", weights, v).reshape(b, t, d_model)

    # 5. Final output projection mixes information across heads
    return F.linear(output, w_out)
```

```python
# Testing on data & compare.
# A fresh nn.MultiheadAttention has randomly initialized projections, so the
# custom function must reuse the module's weights for the outputs to agree.
multihead_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)
output, _ = multihead_attn(q, k, v)

output_custom = multi_head_attention(
    q, k, v, num_heads, d_model,
    w_in=multihead_attn.in_proj_weight,
    w_out=multihead_attn.out_proj.weight,
)
print(output_custom)
print(output)

# atol is loosened from the 1e-8 default because the two implementations
# order float32 operations differently.
assert torch.allclose(output_custom, output, atol=1e-6, rtol=1e-5)  # Check if they are close enough.
```
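To test the per-head attention step separately from the projections, it can be compared against `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later). Its boolean `attn_mask` marks positions that *take part* in attention, which matches the mask-is-0-means-blocked convention of the custom function. This sketch assumes inputs already split into heads, `(batch, heads, len, d_head)`; the helper name `sdpa_manual` is hypothetical, and it uses `@` where the notebook uses `einsum` (the two are equivalent here).

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, s, d_head = 3, 2, 4, 4, 4

# Inputs already split into heads: (batch, heads, len, d_head)
q_h = torch.rand(b, h, t, d_head)
k_h = torch.rand(b, h, s, d_head)
v_h = torch.rand(b, h, s, d_head)

# Keep-mask (True = may attend); causal purely as an example
keep = torch.tril(torch.ones(t, s, dtype=torch.bool))


def sdpa_manual(q, k, v, keep=None):
    """Hypothetical helper: scaled dot-product attention on pre-split heads."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if keep is not None:
        scores = scores.masked_fill(~keep, float("-inf"))
    return scores.softmax(dim=-1) @ v


ours = sdpa_manual(q_h, k_h, v_h, keep)
# PyTorch >= 2.0; a boolean attn_mask here means True = takes part in attention
ref = F.scaled_dot_product_attention(q_h, k_h, v_h, attn_mask=keep)

print((ours - ref).abs().max())
assert torch.allclose(ours, ref, atol=1e-6)
```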

```python
output.shape
```

```
torch.Size([3, 4, 8])
```

```python
output_custom.shape
```

```
torch.Size([3, 4, 8])
```
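The comparison above runs without a mask. The self-contained sketch below runs a masked end-to-end check, using a causal mask as an arbitrary example. The helper `mha_manual` is a hypothetical, compact restatement of the same steps (projections with the module's `in_proj_weight` and `out_proj.weight`, head split, scaled dot-product, merge, output projection). The key detail is the mask inversion: the keep-mask goes to the manual version as-is but to `nn.MultiheadAttention` as `~keep`, because that module's boolean `attn_mask` treats `True` as blocked.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads

x_q = torch.rand(batch_size, seq_len, d_model)
x_k = torch.rand(batch_size, seq_len, d_model)
x_v = torch.rand(batch_size, seq_len, d_model)

ref_mha = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)

# Keep-mask: True = may attend (causal, as an example)
keep = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


@torch.no_grad()
def mha_manual(xq, xk, xv, w_in, w_out, keep=None):
    """Hypothetical helper: bias-free multi-head attention, batch-first."""
    b, t, _ = xq.shape
    s = xk.shape[1]
    w_q, w_k, w_v = w_in.chunk(3, dim=0)
    # Project, split into heads, move heads ahead of the sequence axis
    q = F.linear(xq, w_q).view(b, t, num_heads, d_head).transpose(1, 2)
    k = F.linear(xk, w_k).view(b, s, num_heads, d_head).transpose(1, 2)
    v = F.linear(xv, w_v).view(b, s, num_heads, d_head).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    if keep is not None:
        scores = scores.masked_fill(~keep, float("-inf"))
    out = scores.softmax(dim=-1) @ v  # (b, h, t, d_head)
    out = out.transpose(1, 2).contiguous().view(b, t, d_model)
    return F.linear(out, w_out)


ours = mha_manual(x_q, x_k, x_v, ref_mha.in_proj_weight, ref_mha.out_proj.weight, keep)
with torch.no_grad():
    # nn.MultiheadAttention's boolean attn_mask means the OPPOSITE: True = blocked
    ref, _ = ref_mha(x_q, x_k, x_v, attn_mask=~keep)

print((ours - ref).abs().max())
assert torch.allclose(ours, ref, atol=1e-6)
```

Passing `keep` directly as `attn_mask` would block exactly the positions that should be visible, which is an easy way to get a silent mismatch.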