# Implement Attention from Scratch
### üß† Problem Statement
Standard Multi-Head Attention (MHA) assigns a separate query, key, and value projection to each attention head. But that‚Äôs not always the most efficient approach. 

Enter **Grouped Query Attention (GQA)** ‚Äî a clever mechanism where you use more query heads than key-value heads. This reduces compute/memory costs while still allowing for fine-grained query specialization.

Your task is to **implement GQA from scratch** and validate it against PyTorch‚Äôs `MultiheadAttention` under the special case where GQA behaves identically to MHA (i.e., when `num_query_heads == num_query_groups`).

---

### ‚úÖ Requirements

1. **Define the GQA Mechanism**
   - Create a function `grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None)`.
   - Project `q`, `k`, and `v` using linear layers:
     - Q projection ‚Üí all query heads.
     - K/V projection ‚Üí shared across grouped key/value heads.
   - Use `repeat_interleave()` to expand grouped K/V heads to match the number of Q heads.

2. **Compute Attention**
   - Apply scaled dot-product attention using `Q @ K·µÄ / sqrt(d_head)`.
   - Support optional masking.
   - Return output by concatenating heads and applying the output projection.

3. **Validate Against MHA**
   - Test your implementation using synthetic tensors.
   - Compare your output to `torch.nn.MultiheadAttention` where GQA degenerates to MHA (`num_query_heads == num_query_groups`).
   - Assert that both outputs match numerically.

---

### üìè Constraints

- ‚úÖ Use only PyTorch (no external libraries like xformers or HuggingFace).
- ‚úÖ Output shape must be `(batch_size, seq_len, d_model)`.
- ‚úÖ Support optional attention masking.
- ‚úÖ Validate output against `torch.nn.MultiheadAttention` for correctness.

---

<details>
  <summary>üí° Hint</summary>

  - Use `nn.Linear(d_model, d_model)` for projecting `q`, `k`, and `v`.
  - When `num_query_heads > num_query_groups`, use `.repeat_interleave()` to duplicate each group‚Äôs `K`/`V` to match query head count.
  - Final output: reshape the multi-head outputs to `(batch_size, seq_len, d_model)` and apply the output projection layer.
  - Test with `num_query_heads == num_query_groups` to confirm it behaves like MHA.

</details>

---

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [4]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

torch.Size([3, 4, 8])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def grouped_query_attention(q, k, v, num_query_groups, d_model, mask=None):
    """
    Implements Grouped Query Attention (GQA).

    Args:
        q (Tensor): Query tensor of shape (batch_size, seq_len, d_model)
        k (Tensor): Key tensor of shape (batch_size, seq_len, d_model)
        v (Tensor): Value tensor of shape (batch_size, seq_len, d_model)
        num_query_groups (int): Number of key/value groups (fewer than query groups)
        d_model (int): Total embedding dimension
        mask (Tensor, optional): Masking tensor for attention

    Returns:
        Tensor: GQA output of shape (batch_size, seq_len, d_model)
    """


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomGroupedQueryAttention(torch.nn.Module):
    def __init__(self, num_query_heads, num_query_groups, d_model, bias):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_query_groups = num_query_groups
        self.d_model = d_model
        self.bias = bias
        self.dq_mha = self.d_model // self.num_query_heads
        self.dk_groups = self.dq_mha*self.num_query_groups
        self.q_kv_groups_ratio = self.num_query_heads//self.num_query_groups
        
        assert self.dq_mha*self.num_query_heads == self.d_model, "incompatible num_head and d_model conbination"
        print(f"d_model={self.d_model} | num_heads={self.num_query_heads} | d_mha={self.dq_mha}")

        self.Wq = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
        self.Wk = torch.nn.Linear(in_features=self.d_model, out_features=self.dk_groups, bias=self.bias)
        self.Wv = torch.nn.Linear(in_features=self.d_model, out_features=self.dk_groups, bias=self.bias)
        self.Wc = torch.nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
    
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len_q, d_q = q.shape
        batch_size, seq_len_k, d_k = k.shape
        batch_size, seq_len_v, d_v = v.shape

        # projections for Q, K, V metrics
        q = self.Wq(q) # (B, L_q, d_model)
        k = self.Wk(k) # (B, L_k, dk_groups)
        v = self.Wv(v) # (B, L_v, dk_groups)


        # split Q, K, V into multiple heads
        q = q.view(batch_size, q.shape[-2], self.num_query_heads, self.dq_mha) # (B, L_q, H_q, dq_mha)
        k = k.view(batch_size, k.shape[-2], self.num_query_groups, self.dq_mha) # (B, L_k, H_g, dq_mha)
        v = v.view(batch_size, v.shape[-2], self.num_query_groups, self.dq_mha) # (B, L_v, H_g, dq_mha)

        # reshape to move number of heads to 2nd axis
        q = q.transpose(1,2) # (B, H_q, L_q, dq_mha)
        k = k.transpose(1,2) # (B, H_g, L_k, dq_mha)
        v = v.transpose(1,2) # (B, H_g, L_v, dq_mha)

        k = k.repeat_interleave(repeats=self.q_kv_groups_ratio, dim=-3)
        v = v.repeat_interleave(repeats=self.q_kv_groups_ratio, dim=-3)

        print(f"q.shape = {q.shape} | k.shape = {k.shape} | v.shape = {v.shape}")

        # apply attention
        attention_per_head = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H_q, L_q, dq_mha)
        attention_per_head = attention_per_head.permute(0,2,1,3) # (B, L_q, H_q, d_mha)
        print(f"attention_per_head.shape = {attention_per_head.shape}")
        print(f"(batch_size, seq_len, d_model) = {(batch_size, seq_len, self.d_model)}")
        attention_concatenated = attention_per_head.reshape(batch_size, seq_len, self.d_model) # (B, L_q, d_model)
        
        # mha = attention_concatenated @ Wc # (B, L_k, d_model)
        mha = self.Wc(attention_concatenated) # (B, L_k, d_model)
        return mha, attention_per_head

In [19]:
# Temporarily set
num_query_heads = 4
num_query_groups = 2  # => GQA behaves like MHA

# Use same d_model and input
multihead_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)
output_ref, _ = multihead_attn(q, k, v)
print(f"output_ref.shape = {output_ref.shape}")
# output_custom = grouped_query_attention(q, k, v, num_query_heads=num_query_heads, num_query_groups=num_query_groups, d_model=d_model)
custom_gqa = CustomGroupedQueryAttention(num_query_heads, num_query_groups, d_model, bias=False)
output_custom, _ = custom_gqa(q, k, v)
print(f"output_custom.shape = {output_custom.shape}")

# Compare
assert torch.allclose(output_custom, output_ref, atol=1e-8, rtol=1e-5)



output_ref.shape = torch.Size([3, 4, 8])
d_model=8 | num_heads=4 | d_mha=2
q.shape = torch.Size([3, 4, 4, 2]) | k.shape = torch.Size([3, 4, 4, 2]) | v.shape = torch.Size([3, 4, 4, 2])
attention_per_head.shape = torch.Size([3, 4, 4, 2])
(batch_size, seq_len, d_model) = (3, 4, 8)
output_custom.shape = torch.Size([3, 4, 8])


AssertionError: 