# Implement Attention from Scratch
### Problem Statement
Multi-Head Attention (MHA) is the core building block of the Transformer architecture. It lets the model **jointly attend** to information from different representation subspaces at different positions.

Your goal is to implement MHA from scratch in PyTorch, reproducing what `torch.nn.MultiheadAttention` computes: project Q, K, and V, split them into heads, compute attention weights per head, apply them to V, concatenate the head outputs, and apply the output projection.

---

### Requirements

1. **Linear Projections for Q, K, V**
   - Project each of the inputs `q`, `k`, `v` with its own linear layer into `d_model` dimensions.
   - Split each projection into `num_heads` heads of size `d_head = d_model // num_heads` (`d_model` must be divisible by `num_heads`).

2. **Scaled Dot-Product Attention per Head**
   - Compute attention scores:  
     `scores = Q @ Kᵀ / sqrt(d_head)`
   - Apply the optional `mask` to the scores before the softmax.
   - Apply softmax over the key positions and use the resulting weights to form a weighted sum of `V`.

3. **Combine the Heads**
   - Concatenate the outputs of all heads.
   - Apply a final linear projection to restore the shape: `(batch_size, seq_len, d_model)`.

4. **Validate Against PyTorch’s Reference**
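   - Test your output against `torch.nn.MultiheadAttention` using the same input tensors and the same projection weights. A freshly constructed module has its own randomly initialized weights, so the outputs can only agree if both implementations share them.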
   - Check for numerical closeness using `torch.allclose()`.
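
The attention step in requirement 2 can be sketched in isolation as follows. This is a minimal illustration, not the reference implementation: the function name and the mask convention (boolean, `True` means "may attend") are assumptions made for this example. Note that `torch.nn.MultiheadAttention` uses the opposite convention for a boolean `attn_mask`, where `True` marks positions that are not allowed to attend.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q, K, V: (batch, num_heads, seq_len, d_head).

    mask: optional bool tensor broadcastable to the score shape;
    True = may attend (convention assumed for this sketch).
    """
    d_head = Q.size(-1)
    # (batch, num_heads, seq_q, seq_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
    if mask is not None:
        # Blocked positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # normalize over key positions
    return weights @ V, weights


torch.manual_seed(0)
Q = torch.rand(3, 2, 4, 4)
K = torch.rand(3, 2, 4, 4)
V = torch.rand(3, 2, 4, 4)

# Causal mask: query position i may attend to key positions <= i.
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
out, weights = scaled_dot_product_attention(Q, K, V, mask=causal)
print(out.shape, weights.shape)
```

With the causal mask, the first query position can only see the first key, so its output row equals the first row of `V` for every batch element and head.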

---

### Constraints

- ✅ Use only PyTorch operations.
- ✅ Make sure all tensors are reshaped properly when splitting and combining heads.
- ✅ Support optional masking.
- ✅ Must match `torch.nn.MultiheadAttention` output when `num_heads`, tensor shapes, and projection weights are the same.
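
To make the last constraint concrete, here is one possible way to line up a hand-written implementation with the reference module. It is a sketch under stated assumptions: `mha_with_weights` is a hypothetical helper (not part of the exercise's required signature) that receives the weight matrices explicitly, there is no mask, and the reference is built with `bias=False` and `batch_first=True`. When query, key, and value share the same embedding size, the module stores the three input projections stacked in `in_proj_weight` with shape `(3 * d_model, d_model)`, and the output projection in `out_proj.weight`.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def mha_with_weights(q, k, v, num_heads, w_q, w_k, w_v, w_o):
    """Hypothetical helper: bias-free MHA with explicit (out, in) weight matrices."""
    batch_size, seq_len, d_model = q.shape
    d_head = d_model // num_heads

    def split_heads(x):
        # (batch, seq, d_model) -> (batch, num_heads, seq, d_head)
        return x.view(batch_size, -1, num_heads, d_head).transpose(1, 2)

    Q = split_heads(F.linear(q, w_q))
    K = split_heads(F.linear(k, w_k))
    V = split_heads(F.linear(v, w_v))

    weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_head), dim=-1)
    # (batch, num_heads, seq, d_head) -> (batch, seq, d_model)
    heads = (weights @ V).transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    return F.linear(heads, w_o)


torch.manual_seed(42)
batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)

reference = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)
reference.eval()

# Rows [0:d_model] project q, the next d_model rows project k, the last project v.
w_q, w_k, w_v = reference.in_proj_weight.chunk(3, dim=0)
w_o = reference.out_proj.weight

with torch.no_grad():
    expected, _ = reference(q, k, v)
    actual = mha_with_weights(q, k, v, num_heads, w_q, w_k, w_v, w_o)

max_abs_diff = (actual - expected).abs().max().item()
print(max_abs_diff)
```

The two results should agree up to float32 rounding; an exact bitwise match is not expected because the reference may order its arithmetic differently.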

---

<details>
  <summary>💡 Hint</summary>

  - Use `.view()` and `.transpose()` to shape Q, K, V to `(batch_size, num_heads, seq_len, d_head)`.
  - Softmax should be applied over the **last dimension** of the score tensor (the key positions), so that each query's attention weights sum to 1.
  - Use `.contiguous().view()` to flatten the multi-head outputs back into `(batch_size, seq_len, d_model)`.
  - To match PyTorch, reuse the reference module's projection weights (`in_proj_weight` and `out_proj.weight`) and construct it with `batch_first=True`.
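
The reshaping hints above can be checked on a small tensor. This is an illustrative sketch using the same sizes as the synthetic data in this notebook; the variable names are chosen for the example only.

```python
import torch

batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads

# Distinct values make it easy to see where each element ends up.
x = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
x = x.view(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, num_heads, seq, d_head).
heads = x.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
# Head h at position t holds features h*d_head:(h+1)*d_head of x[b, t].
same_slice = torch.equal(heads[0, 1, 2], x[0, 2, d_head:2 * d_head])

# Stand-in for the per-head attention output, which is a fresh contiguous tensor.
attended = heads.contiguous()

# Merge: transpose back, then make the memory contiguous before flattening,
# because .view() cannot reinterpret the transposed layout directly.
transposed = attended.transpose(1, 2)
needs_copy = not transposed.is_contiguous()
merged = transposed.contiguous().view(batch_size, seq_len, d_model)

print(heads.shape, merged.shape, same_slice, needs_copy)
```

Splitting and merging are exact inverses here, so `merged` is identical to `x`.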

</details>

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

# The tensors above and the reference module are created on the CPU, so the device is fixed to "cpu".
device = "cpu"

In [None]:
def multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    """
    Implements multi-head attention.
    
    Args:
        q (Tensor): Query tensor of shape (batch_size, seq_len, d_model)
        k (Tensor): Key tensor of shape (batch_size, seq_len, d_model)
        v (Tensor): Value tensor of shape (batch_size, seq_len, d_model)
        num_heads (int): Number of attention heads
        d_model (int): Total embedding dimension
        mask (Tensor, optional): Masking tensor for attention
        
    Returns:
        Tensor: Multi-head attention output of shape (batch_size, seq_len, d_model)
    """
    ...

In [None]:
# Testing on data & compare
output_custom = multi_head_attention(q, k, v, num_heads, d_model)
print(output_custom)

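# Note: the reference module has its own randomly initialized projection weights
# (in_proj_weight, out_proj.weight). The outputs below agree only if
# multi_head_attention uses those same weights.
multihead_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True)
output, _ = multihead_attn(q, k, v)
print(output)

# Check that the two outputs are numerically close.
assert torch.allclose(output_custom, output, atol=1e-08, rtol=1e-05)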
