# Implement Attention from Scratch
### 🧠 Problem Statement
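Standard Multi-Head Attention (MHA) gives every attention head its own query, key, and value projection. That is not always the most efficient choice, because the key/value projections and the cached keys and values grow with the number of heads.

**Grouped Query Attention (GQA)** uses more query heads than key/value heads: the query heads are split into groups, and all heads in a group share one key/value head. This shrinks the K/V projections and the cached keys and values, reducing compute and memory costs while still allowing fine-grained query specialization.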

Your task is to **implement GQA from scratch** and validate it against PyTorch’s `MultiheadAttention` under the special case where GQA behaves identically to MHA (i.e., when `num_query_heads == num_query_groups`).
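To make the memory argument concrete, here is a rough back-of-the-envelope sketch. The helper `kv_footprint` and the sizes used (`d_model=512`, 8 query heads, 2 key/value heads, 1024 tokens) are illustrative assumptions, not part of the exercise; it only counts bias-free K/V projection weights and cached K/V elements for a single attention layer.

```python
def kv_footprint(d_model, num_query_heads, num_kv_heads, seq_len, batch_size=1):
    """Count K/V projection weights and cached K/V elements for one attention layer.

    Hypothetical helper for illustration only.
    """
    d_head = d_model // num_query_heads
    kv_dim = num_kv_heads * d_head                    # total width of K (and of V)
    proj_params = 2 * d_model * kv_dim                # W_k and W_v, no bias
    cache_elems = 2 * batch_size * seq_len * kv_dim   # cached K and V
    return proj_params, cache_elems


mha = kv_footprint(d_model=512, num_query_heads=8, num_kv_heads=8, seq_len=1024)
gqa = kv_footprint(d_model=512, num_query_heads=8, num_kv_heads=2, seq_len=1024)

print("MHA:", mha)  # (524288, 1048576)
print("GQA:", gqa)  # (131072, 262144)
```

With 8 query heads sharing 2 key/value heads, both counts drop by a factor of 4 (the ratio `num_query_heads / num_kv_heads`), while the query and output projections are unchanged.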

---

### ✅ Requirements

1. **Define the GQA Mechanism**
   - Create a function `grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None)`.
   - Project `q`, `k`, and `v` using linear layers:
     - Q projection → all query heads.
     - K/V projection → one key/value head per group, shared by all query heads in that group.
   - Use `repeat_interleave()` to expand grouped K/V heads to match the number of Q heads.

2. **Compute Attention**
   - Apply scaled dot-product attention: compute scores as `Q @ Kᵀ / sqrt(d_head)`, apply softmax over the key dimension, and multiply by `V`.
   - Support optional masking.
   - Return output by concatenating heads and applying the output projection.

3. **Validate Against MHA**
   - Test your implementation using synthetic tensors.
   - Compare your output to `torch.nn.MultiheadAttention` where GQA degenerates to MHA (`num_query_heads == num_query_groups`), with both implementations using the same projection weights.
   - Assert that both outputs match within floating-point tolerance.
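The K/V expansion step in requirement 1 can be sketched on a toy tensor. In this illustrative example (the sizes and the variable name `head_to_group` are assumptions), each "key" is labelled with its group index so the head-to-group mapping produced by `repeat_interleave()` is visible, and it is contrasted with `repeat()`, which tiles the groups instead.

```python
import torch

num_query_heads, num_query_groups = 4, 2
repeat_factor = num_query_heads // num_query_groups

# One toy "key" per group, labelled by its group index:
# shape (batch=1, groups=2, seq=1, d_head=1)
K = torch.arange(num_query_groups, dtype=torch.float32).view(1, num_query_groups, 1, 1)

# repeat_interleave keeps copies of the same group adjacent,
# so query head h reads from group h // repeat_factor.
K_expanded = K.repeat_interleave(repeat_factor, dim=1)  # (1, 4, 1, 1)
head_to_group = K_expanded.flatten().long().tolist()
print(head_to_group)  # [0, 0, 1, 1]

# repeat tiles the whole group axis instead, giving a different mapping.
K_tiled = K.repeat(1, repeat_factor, 1, 1)
tiled_mapping = K_tiled.flatten().long().tolist()
print(tiled_mapping)  # [0, 1, 0, 1]
```

Either mapping is a valid way to share K/V heads, but the two are not interchangeable once weights are fixed, so the expansion should match whatever head layout the query projection assumes.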

---

### 📏 Constraints

- ✅ Use only PyTorch (no external libraries like xformers or HuggingFace).
- ✅ Output shape must be `(batch_size, seq_len, d_model)`.
- ✅ Support optional attention masking.
- ✅ Validate output against `torch.nn.MultiheadAttention` for correctness.
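For the masking constraint, one possible convention is sketched below: a mask where `1` means "may attend" and `0` means "blocked", shaped so that it broadcasts against scores of shape `(batch, heads, seq, seq)`. The causal mask and the all-zero placeholder scores are assumptions chosen only to make the effect easy to read.

```python
import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len = 3, 2, 4

# Lower-triangular causal mask: 1 = may attend, 0 = blocked.
# Shape (1, 1, seq, seq) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

# Placeholder scores; in real use these are Q @ K^T / sqrt(d_head).
scores = torch.zeros(batch_size, num_heads, seq_len, seq_len)

# Blocked positions get -inf so softmax assigns them zero weight.
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
attn_weights = F.softmax(scores, dim=-1)

print(attn_weights[0, 0])
```

Each row still sums to 1, with the weight spread only over allowed positions. Note that a row in which every position is blocked would produce NaN after the softmax, so masks should leave at least one allowed key per query.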

---

<details>
  <summary>💡 Hint</summary>

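  - Use `nn.Linear(d_model, d_model)` for projecting `q`, and `nn.Linear(d_model, num_query_groups * d_head)` for projecting `k` and `v`, where `d_head = d_model // num_query_heads`.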
  - When `num_query_heads > num_query_groups`, use `.repeat_interleave()` to duplicate each group’s `K`/`V` to match query head count.
  - Final output: reshape the multi-head outputs to `(batch_size, seq_len, d_model)` and apply the output projection layer.
  - Test with `num_query_heads == num_query_groups` to confirm it behaves like MHA.

</details>
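The shapes implied by the hints can be traced with a short sketch. The sizes below (4 query heads, 2 groups, `d_model=8`) and the names `q_proj` and `kv_proj` are assumptions for illustration; the attention computation itself is skipped, and `Q` stands in for the per-head outputs when demonstrating the final merge.

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model = 3, 4, 8
num_query_heads, num_query_groups = 4, 2
d_head = d_model // num_query_heads  # 2

q_proj = nn.Linear(d_model, num_query_heads * d_head, bias=False)    # 8 -> 8
kv_proj = nn.Linear(d_model, num_query_groups * d_head, bias=False)  # 8 -> 4

x = torch.rand(batch_size, seq_len, d_model)

# Split the projected features into heads: (batch, heads, seq, d_head)
Q = q_proj(x).view(batch_size, seq_len, num_query_heads, d_head).transpose(1, 2)
K_grouped = kv_proj(x).view(batch_size, seq_len, num_query_groups, d_head).transpose(1, 2)

# Duplicate each group's keys so there is one K head per query head
K = K_grouped.repeat_interleave(num_query_heads // num_query_groups, dim=1)

# Merging heads back: (batch, heads, seq, d_head) -> (batch, seq, d_model)
merged = Q.transpose(1, 2).reshape(batch_size, seq_len, d_model)

print(Q.shape, K_grouped.shape, K.shape, merged.shape)
```

The key point is that K and V are reshaped with the same `d_head` as Q; only the number of heads along dimension 1 differs before the expansion.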

---

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [26]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

device = "cpu"  # small CPU tensors are enough for this check

torch.Size([3, 4, 8])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None,
                            num_query_heads=None, weights=None):
    """
    Implements Grouped Query Attention (GQA).

    Args:
        q (Tensor): Query tensor of shape (batch_size, q_len, d_model)
        k (Tensor): Key tensor of shape (batch_size, kv_len, d_model)
        v (Tensor): Value tensor of shape (batch_size, kv_len, d_model)
        num_query_groups (int): Number of key/value heads (at most num_query_heads)
        d_model (int): Total embedding dimension
        mask (Tensor, optional): Broadcastable to (batch_size, num_query_heads, q_len, kv_len);
            positions where mask == 0 are not attended to
        num_query_heads (int, optional): Number of query heads. Defaults to
            num_query_groups, in which case GQA reduces to MHA
        weights (tuple, optional): (W_q, W_k, W_v, W_out) in nn.Linear layout
            (out_features, in_features). If omitted, fresh random projections
            are created on every call

    Returns:
        Tensor: GQA output of shape (batch_size, q_len, d_model)
    """
    if num_query_heads is None:
        num_query_heads = num_query_groups  # MHA special case
    assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
    assert num_query_heads % num_query_groups == 0, "num_query_heads must be divisible by num_query_groups"

    d_head = d_model // num_query_heads
    d_kv = num_query_groups * d_head  # total width of the K/V projections
    batch_size, q_len, _ = q.shape
    kv_len = k.shape[1]

    if weights is None:
        # No weights supplied: draw random bias-free projections (output differs between calls)
        weights = [nn.Linear(d_model, out_dim, bias=False).weight
                   for out_dim in (d_model, d_kv, d_kv, d_model)]
    W_q, W_k, W_v, W_out = (w.to(q.device) for w in weights)

    # Project and split into heads: (batch, heads or groups, seq, d_head)
    Q = F.linear(q, W_q).view(batch_size, q_len, num_query_heads, d_head).transpose(1, 2)
    K = F.linear(k, W_k).view(batch_size, kv_len, num_query_groups, d_head).transpose(1, 2)
    V = F.linear(v, W_v).view(batch_size, kv_len, num_query_groups, d_head).transpose(1, 2)

    # Repeat each K/V group so that query head h uses group h // repeat_factor
    repeat_factor = num_query_heads // num_query_groups
    K = K.repeat_interleave(repeat_factor, dim=1)  # (batch, heads, kv_len, d_head)
    V = V.repeat_interleave(repeat_factor, dim=1)

    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)  # (batch, heads, q_len, kv_len)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)  # (batch, heads, q_len, d_head)

    # Concatenate heads and apply the output projection
    output = output.transpose(1, 2).contiguous().view(batch_size, q_len, d_model)
    return F.linear(output, W_out)


In [None]:
# Match the reference module's head count
num_query_heads = num_heads
num_query_groups = num_heads  # => GQA behaves like MHA

# Use same d_model and input
multihead_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)

# Reuse the reference module's weights; otherwise the two outputs cannot match.
# in_proj_weight stacks W_q, W_k, W_v along dim 0.
W_q, W_k, W_v = multihead_attn.in_proj_weight.chunk(3, dim=0)
W_out = multihead_attn.out_proj.weight

with torch.no_grad():
    output_ref, _ = multihead_attn(q, k, v)
    output_custom = grouped_query_attention(
        q, k, v,
        num_query_groups=num_query_groups,
        d_model=d_model,
        num_query_heads=num_query_heads,
        weights=(W_q, W_k, W_v, W_out),
    )

# Compare (float32 arithmetic needs a looser absolute tolerance than 1e-8)
assert torch.allclose(output_custom, output_ref, atol=1e-6, rtol=1e-5)



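The MHA comparison only exercises the degenerate case. As a further sanity check for true grouping (`num_query_heads > num_query_groups`), one option is to compare against an `nn.MultiheadAttention` whose K/V weights are the grouped weights with each group's rows duplicated. The sketch below is self-contained and makes several assumptions: the helper names `gqa_forward` and `expand`, the sizes (4 query heads, 2 groups), and the randomly drawn weights are all illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gqa_forward(x_q, x_k, x_v, W_q, W_k, W_v, W_out, num_heads, num_groups):
    """Compact GQA forward pass with explicit weights (hypothetical helper)."""
    b, s_q, d_model = x_q.shape
    s_kv = x_k.shape[1]
    d_head = d_model // num_heads
    Q = F.linear(x_q, W_q).view(b, s_q, num_heads, d_head).transpose(1, 2)
    K = F.linear(x_k, W_k).view(b, s_kv, num_groups, d_head).transpose(1, 2)
    V = F.linear(x_v, W_v).view(b, s_kv, num_groups, d_head).transpose(1, 2)
    r = num_heads // num_groups
    K, V = K.repeat_interleave(r, dim=1), V.repeat_interleave(r, dim=1)
    attn = F.softmax(Q @ K.transpose(-2, -1) / d_head ** 0.5, dim=-1)
    out = (attn @ V).transpose(1, 2).reshape(b, s_q, d_model)
    return F.linear(out, W_out)


torch.manual_seed(0)
d_model, num_heads, num_groups = 8, 4, 2
d_head = d_model // num_heads
r = num_heads // num_groups
x = torch.rand(3, 4, d_model)

# Random bias-free weights in nn.Linear layout (out_features, in_features)
W_q = torch.randn(d_model, d_model) / d_model ** 0.5
W_k = torch.randn(num_groups * d_head, d_model) / d_model ** 0.5
W_v = torch.randn(num_groups * d_head, d_model) / d_model ** 0.5
W_out = torch.randn(d_model, d_model) / d_model ** 0.5


def expand(W):
    # Duplicate each group's d_head rows r times so every query head gets its own copy
    return W.view(num_groups, d_head, d_model).repeat_interleave(r, dim=0).reshape(d_model, d_model)


mha = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([W_q, expand(W_k), expand(W_v)], dim=0))
    mha.out_proj.weight.copy_(W_out)
    out_ref, _ = mha(x, x, x)
    out_gqa = gqa_forward(x, x, x, W_q, W_k, W_v, W_out, num_heads, num_groups)

assert torch.allclose(out_gqa, out_ref, atol=1e-5, rtol=1e-4)
```

This works because sharing one K/V head across a group is mathematically the same as running MHA with identical K/V weights for every head in that group; the grouped version just stores and computes those projections once.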