# Implement Grouped Query Attention from Scratch
### 🧠 Problem Statement
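Standard Multi-Head Attention (MHA) gives each attention head its own query, key, and value projection. That is not always the most efficient approach: during autoregressive decoding, the keys and values of every head must be cached, so memory and bandwidth grow with the number of heads.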

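Enter **Grouped Query Attention (GQA)**: the query heads are split into `num_query_groups` groups, and all query heads in a group share a single key/value head. Because there are fewer key/value heads than query heads, the K/V projections and the KV cache shrink, while each query head keeps its own projection and can still specialize.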
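To make the grouping concrete, here is a small sketch with illustrative sizes (8 query heads, 2 key/value groups, `d_model = 64`; these numbers are assumptions, not part of the task). It assumes consecutive query heads share a group, which is the layout `repeat_interleave()` produces, and counts how many values per token the KV cache would hold in each case.

```python
# Illustrative configuration (assumed values, not from the task)
num_query_heads = 8
num_query_groups = 2          # number of key/value heads
d_model = 64
d_head = d_model // num_query_heads   # 8

# Consecutive query heads share one key/value head.
heads_per_group = num_query_heads // num_query_groups   # 4
kv_head_for_query_head = [h // heads_per_group for h in range(num_query_heads)]
print(kv_head_for_query_head)  # [0, 0, 0, 0, 1, 1, 1, 1]

# Per token, K and V each store one d_head-sized vector per key/value head.
mha_kv_floats = 2 * num_query_heads * d_head    # 2 * 8 * 8
gqa_kv_floats = 2 * num_query_groups * d_head   # 2 * 2 * 8
print(mha_kv_floats, gqa_kv_floats)  # 128 32
```

With these sizes GQA caches a quarter of what MHA does, matching the ratio `num_query_groups / num_query_heads`.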

Your task is to **implement GQA from scratch** and validate it against PyTorch’s `MultiheadAttention` under the special case where GQA behaves identically to MHA (i.e., when `num_query_heads == num_query_groups`).

---

### ✅ Requirements

1. **Define the GQA Mechanism**
   - Create a function `grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None)`.
   - Project `q`, `k`, and `v` using linear layers:
     - Q projection → all query heads.
     - K/V projections → one head per group (`num_query_groups` heads), shared by all query heads in that group.
   - Use `repeat_interleave()` to expand grouped K/V heads to match the number of Q heads.

2. **Compute Attention**
   - Apply scaled dot-product attention: `softmax(Q @ Kᵀ / sqrt(d_head)) @ V`.
   - Support optional masking.
   - Return output by concatenating heads and applying the output projection.

3. **Validate Against MHA**
   - Test your implementation using synthetic tensors.
   - Compare your output to `torch.nn.MultiheadAttention` where GQA degenerates to MHA (`num_query_heads == num_query_groups`).
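   - Assert that both outputs match within floating-point tolerance (e.g., with `torch.allclose`). This only holds if your projections use the same weights as the `MultiheadAttention` module, so copy them over rather than initializing fresh layers.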
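A minimal sketch of the projection and expansion steps from requirement 1, assuming bias-free projections and illustrative sizes (4 query heads, 2 groups). `split_heads` is a hypothetical helper, not part of the task's API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 5, 16
num_query_heads, num_query_groups = 4, 2       # illustrative values
d_head = d_model // num_query_heads            # 4

x = torch.rand(batch_size, seq_len, d_model)

# Q gets one d_head slice per query head; K and V get one per group only.
q_proj = nn.Linear(d_model, num_query_heads * d_head, bias=False)
k_proj = nn.Linear(d_model, num_query_groups * d_head, bias=False)
v_proj = nn.Linear(d_model, num_query_groups * d_head, bias=False)

def split_heads(t, n_heads):
    # (B, L, n_heads * d_head) -> (B, n_heads, L, d_head)
    b, l, _ = t.shape
    return t.view(b, l, n_heads, d_head).transpose(1, 2)

Q = split_heads(q_proj(x), num_query_heads)    # (2, 4, 5, 4)
K = split_heads(k_proj(x), num_query_groups)   # (2, 2, 5, 4)
V = split_heads(v_proj(x), num_query_groups)   # (2, 2, 5, 4)

# Duplicate each group's K/V: query heads 0,1 see group 0; heads 2,3 see group 1.
heads_per_group = num_query_heads // num_query_groups
K = K.repeat_interleave(heads_per_group, dim=1)  # (2, 4, 5, 4)
V = V.repeat_interleave(heads_per_group, dim=1)  # (2, 4, 5, 4)

print(Q.shape, K.shape, V.shape)
```

After the expansion, `Q`, `K`, and `V` all have one slice per query head, so the attention step can be written exactly as in MHA.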

---

### 📏 Constraints

- ✅ Use only PyTorch (no external libraries like xformers or HuggingFace).
- ✅ Output shape must be `(batch_size, seq_len, d_model)`.
- ✅ Support optional attention masking.
- ✅ Validate output against `torch.nn.MultiheadAttention` for correctness.
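For the masking constraint, one possible convention (an assumption; the task doesn't fix one) is the boolean `attn_mask` convention of `nn.MultiheadAttention`, where `True` marks a position that may *not* be attended to. The sketch below implements the attention step on already-split heads and cross-checks it against `F.scaled_dot_product_attention`, which treats boolean `True` as *allowed*, hence the inverted mask. `attend` is a hypothetical helper name.

```python
import math
import torch
import torch.nn.functional as F

def attend(Q, K, V, mask=None):
    """Scaled dot-product attention over (B, H, L, d_head) tensors.

    Assumed mask convention (as in nn.MultiheadAttention's boolean attn_mask):
    mask[i, j] == True means query i may NOT attend to key j.
    """
    d_head = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)   # (B, H, L, L)
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))   # mask broadcasts over B, H
    weights = torch.softmax(scores, dim=-1)
    return weights @ V                                       # (B, H, L, d_head)

torch.manual_seed(0)
Q, K, V = (torch.rand(2, 4, 5, 8) for _ in range(3))

# Causal mask: block attention to future positions (strict upper triangle).
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)

out = attend(Q, K, V, mask=causal)
# scaled_dot_product_attention uses the opposite boolean convention, so invert.
ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=~causal)
print(torch.allclose(out, ref, atol=1e-6))
```

A row whose keys are all masked produces `NaN` after the softmax; the causal mask never does this because each query can always see its own position.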

---

<details>
  <summary>💡 Hint</summary>

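  - Use `nn.Linear(d_model, d_model, bias=False)` to project `q`. The `k` and `v` projections map to `num_query_groups * d_head` features, which equals `d_model` only when `num_query_heads == num_query_groups`.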
  - When `num_query_heads > num_query_groups`, use `.repeat_interleave()` to duplicate each group’s `K`/`V` to match query head count.
  - Final output: reshape the multi-head outputs to `(batch_size, seq_len, d_model)` and apply the output projection layer.
  - Test with `num_query_heads == num_query_groups` to confirm it behaves like MHA.

</details>
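Why the validation needs shared weights: `nn.MultiheadAttention` initializes its own projections, so a from-scratch function with freshly created `nn.Linear` layers won't match it numerically. The sketch below is a reference reconstruction, not the task's solution. It pulls the Q, K, and V weights out of `in_proj_weight` (which stacks them row-wise when `kdim == vdim == embed_dim`) and the output weight out of `out_proj`, then reproduces the module's output by hand. `manual_mha` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(42)
batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads
q, k, v = (torch.rand(batch_size, seq_len, d_model) for _ in range(3))

mha = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)

# in_proj_weight has shape (3 * d_model, d_model): Q, K, V blocks in that order.
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
W_o = mha.out_proj.weight

def manual_mha(q, k, v):
    b, l, _ = q.shape

    def heads(x, W):
        # x @ W.T is what nn.Linear computes; then split into (B, H, L, d_head)
        return (x @ W.T).view(b, -1, num_heads, d_head).transpose(1, 2)

    Q, K, V = heads(q, W_q), heads(k, W_k), heads(v, W_v)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
    out = torch.softmax(scores, dim=-1) @ V                 # (B, H, L, d_head)
    out = out.transpose(1, 2).reshape(b, l, d_model)        # concatenate heads
    return out @ W_o.T                                       # output projection

with torch.no_grad():
    ref, _ = mha(q, k, v)
    custom = manual_mha(q, k, v)
print(torch.allclose(custom, ref, atol=1e-6))
```

A `grouped_query_attention` implementation can load the same four matrices; with `num_query_groups == num_query_heads`, the `repeat_interleave` factor is 1 and the computation reduces to `manual_mha`.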

---

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"  # the tensors are tiny, so CPU is sufficient

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None):
    """
    Implements Grouped Query Attention (GQA).

    Args:
        q (Tensor): Query tensor of shape (batch_size, seq_len, d_model)
        k (Tensor): Key tensor of shape (batch_size, seq_len, d_model)
        v (Tensor): Value tensor of shape (batch_size, seq_len, d_model)
        num_query_groups (int): Number of key/value heads; must divide the number of query heads (equal to it in the MHA case)
        d_model (int): Total embedding dimension
        mask (Tensor, optional): Masking tensor for attention

    Returns:
        Tensor: GQA output of shape (batch_size, seq_len, d_model)
    """
    ...
    

In [None]:
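# Special case: one K/V head per query head, so GQA reduces to MHA.
# Use the same head count as the reference module.
num_query_heads = num_heads
num_query_groups = num_heads  # => GQA behaves like MHA

# Reference: PyTorch MHA on the same d_model and inputs
multihead_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)
output_ref, _ = multihead_attn(q, k, v)

# The outputs can only match if grouped_query_attention uses the same weights:
# Q, K, V are the three (d_model, d_model) row blocks of multihead_attn.in_proj_weight,
# and the output projection is multihead_attn.out_proj.weight.
output_custom = grouped_query_attention(q, k, v, num_query_groups=num_query_groups, d_model=d_model)

# Compare (float32 results from different code paths differ slightly)
assert output_custom.shape == (batch_size, seq_len, d_model)
assert torch.allclose(output_custom, output_ref, atol=1e-6, rtol=1e-5)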

