## **Multi-Head Attention, Multi-Query Attention, and Grouped-Query Attention**
These are different **attention mechanisms** designed to balance **performance, speed, and memory efficiency** in transformer models.

---

# **1. Multi-Head Attention (MHA)**
### **What is Multi-Head Attention?**
**Multi-Head Attention (MHA)** allows a Transformer to **learn multiple attention patterns simultaneously** by splitting the input into multiple attention heads.

### **Mathematical Definition**
For an input sequence \( X \), we compute:

1. Compute multiple **Query (Q), Key (K), and Value (V) matrices**:
   \[
   Q_h = X W_Q^h, \quad K_h = X W_K^h, \quad V_h = X W_V^h
   \]
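   for each attention head \( h = 1, \dots, H \), where \( d \) is the model dimension and each head has dimension \( d_k = d / H \) (e.g., the original Transformer base model uses \( H = 8 \) heads with \( d_k = 64 \)).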

2. Compute **attention per head**:
   \[
   A_h = \text{softmax} \left( \frac{Q_h K_h^T}{\sqrt{d_k}} \right) V_h
   \]

3. Concatenate all attention outputs and apply a linear transformation:
   \[
   \text{MultiHead}(X) = \text{Concat}(A_1, A_2, \dots, A_H) \, W_O
   \]
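The three steps can be sketched in NumPy as follows. This is a minimal illustration, assuming row-vector inputs `X` of shape `(N, d)`, per-head weights stacked along a leading head axis, and no masking or bias terms. Because rows of `X` are tokens, the output projection `W_O` is applied on the right of the concatenated heads. The shapes and names are illustrative choices, not any particular library's API.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    """X: (N, d); W_Q, W_K, W_V: (H, d, d_k) stacked per-head weights; W_O: (H * d_k, d)."""
    H, _, d_k = W_Q.shape
    heads = []
    for h in range(H):
        Q_h, K_h, V_h = X @ W_Q[h], X @ W_K[h], X @ W_V[h]   # each (N, d_k)
        scores = Q_h @ K_h.T / np.sqrt(d_k)                  # (N, N)
        heads.append(softmax(scores) @ V_h)                  # A_h: (N, d_k)
    return np.concatenate(heads, axis=-1) @ W_O              # Concat(A_1..A_H) W_O: (N, d)

rng = np.random.default_rng(0)
N, d, H = 5, 16, 4
d_k = d // H
X = rng.standard_normal((N, d))
W_Q, W_K, W_V = (rng.standard_normal((H, d, d_k)) for _ in range(3))
W_O = rng.standard_normal((H * d_k, d))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape)  # (5, 16)
```

The explicit loop mirrors the per-head equations; practical implementations usually batch all heads into a single tensor operation.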

### **Pros of Multi-Head Attention**
✅ **Enhances model expressiveness** – Each head captures different relationships.  
✅ **Retains global context** – Multiple heads allow attention over various parts of the input.  
✅ **Works well for large models** – Used in **BERT, GPT, LLaMA, and T5**.

### **Cons of Multi-Head Attention**
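❌ **Large KV cache** – During autoregressive decoding, each of the \( H \) heads caches its own keys and values, so the cache grows as \( O(HNd_k) \) per layer.  
❌ **Memory-bandwidth-bound decoding** – Reading the full KV cache at every generation step slows real-time inference. (The \( O(HN^2) \) attention-score cost is shared by MQA and GQA, since they keep all \( H \) query heads.)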

---

# **2. Multi-Query Attention (MQA)**
### **What is Multi-Query Attention?**
**Multi-Query Attention (MQA)** is a modification of MHA where **a single Key (K) head and a single Value (V) head are shared across all query heads**, reducing memory usage.

### **Mathematical Definition**
Instead of computing separate **\( K_h, V_h \)** for each head, we **use a shared Key-Value cache**:
\[
Q_h = X W_Q^h, \quad K_{\text{shared}} = X W_K, \quad V_{\text{shared}} = X W_V
\]

Attention computation:
\[
A_h = \text{softmax} \left( \frac{Q_h K_{\text{shared}}^T}{\sqrt{d_k}} \right) V_{\text{shared}}
\]
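A NumPy sketch of the same computation with one shared key head and value head, assuming the same row-vector shapes as for MHA. The per-head queries are batched with `np.einsum`, and the single `K`, `V` are broadcast across all heads by `@`. The function also returns `(K, V)` to show that only one head's worth of keys and values would need to be cached. Names and shapes are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention(X, W_Q, W_K, W_V, W_O):
    """X: (N, d); W_Q: (H, d, d_k); W_K, W_V: (d, d_k) shared by all heads; W_O: (H * d_k, d)."""
    H, _, d_k = W_Q.shape
    N = X.shape[0]
    Q = np.einsum("nd,hdk->hnk", X, W_Q)   # separate queries per head: (H, N, d_k)
    K = X @ W_K                            # one shared key head: (N, d_k)
    V = X @ W_V                            # one shared value head: (N, d_k)
    scores = Q @ K.T / np.sqrt(d_k)        # K broadcast over heads: (H, N, N)
    A = softmax(scores) @ V                # (H, N, d_k)
    concat = A.transpose(1, 0, 2).reshape(N, H * d_k)  # Concat(A_1..A_H): (N, H * d_k)
    return concat @ W_O, (K, V)

rng = np.random.default_rng(0)
N, d, H = 5, 16, 4
d_k = d // H
X = rng.standard_normal((N, d))
W_Q = rng.standard_normal((H, d, d_k))
W_K, W_V = rng.standard_normal((d, d_k)), rng.standard_normal((d, d_k))
W_O = rng.standard_normal((H * d_k, d))
out, (K_cache, V_cache) = multi_query_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape, K_cache.shape)  # (5, 16) (5, 4)
```

The cached keys have shape `(N, d_k)` rather than `(H, N, d_k)`, which is the factor-of-\( H \) KV cache saving.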

### **Pros of Multi-Query Attention**
✅ **Lower memory usage** – Only **one Key-Value pair** is stored, saving **KV cache memory**.  
✅ **Faster inference** – Less KV data is read at each decoding step, which speeds up **large-scale autoregressive text generation**.  
✅ **Reduces KV cache size from \( O(HN) \) to \( O(N) \)**.

### **Cons of Multi-Query Attention**
❌ **Less expressive than MHA** – Each head still has its own queries, but all heads **read from the same keys and values**, which limits the **diversity of information** the heads can retrieve.  
❌ **Possible quality loss** – Sharing a single K/V head can degrade model quality compared to MHA, which motivated the intermediate GQA design.  

🚀 **Example Models Using MQA**: PaLM (Google’s LLM), Falcon-7B, StarCoder

---

# **3. Grouped-Query Attention (GQA)**
### **What is Grouped-Query Attention?**
Grouped-Query Attention (GQA) is a **middle ground** between Multi-Head and Multi-Query Attention.

- Instead of **one Key-Value pair** (MQA) or **separate Key-Value pairs per head** (MHA), GQA **groups heads** to share **K and V**.
- **Example**: With **8 query heads** split into **4 groups**, each group of **2 query heads** shares one Key-Value pair, so only 4 K/V heads are cached instead of 8.
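The grouping in this example can be written as a small index mapping. This sketch assumes consecutive heads are grouped together (a common convention, not the only possible one) and that the number of heads is divisible by the number of groups; `head_to_group` is a hypothetical helper name.

```python
def head_to_group(num_heads, num_groups):
    """Return the key/value group index (0-based) used by each query head.

    Assumes consecutive query heads share a group and that num_heads is
    divisible by num_groups.
    """
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups
    return [h // heads_per_group for h in range(num_heads)]

print(head_to_group(8, 4))  # GQA: [0, 0, 1, 1, 2, 2, 3, 3]
print(head_to_group(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
print(head_to_group(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
```

Setting the number of groups to 1 or to the number of heads recovers MQA and MHA respectively.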

### **Mathematical Definition**
For \( G \) groups:
1. Compute **Query (Q) per head**:
   \[
   Q_h = X W_Q^h
   \]
2. Compute **shared Key-Value pairs per group**:
   \[
   K_g = X W_K^g, \quad V_g = X W_V^g, \qquad g = 1, \dots, G
   \]

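3. Compute **attention per head** using the Key-Value pair of its group \( g = g(h) \), where consecutive heads are grouped so that \( g(h) = \lceil hG/H \rceil \) for \( h = 1, \dots, H \):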
   \[
   A_h = \text{softmax} \left( \frac{Q_h K_g^T}{\sqrt{d_k}} \right) V_g
   \]
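A NumPy sketch of GQA, assuming the same row-vector shapes as the MHA example and consecutive-head grouping. Expanding each group's keys and values to its query heads with `np.repeat` is an implementation choice of this sketch, not something the text prescribes. The final lines check that GQA gives the same result as MHA whose per-head K/V weights are copies of the group weights.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(X, W_Q, W_K, W_V, W_O):
    """X: (N, d); W_Q: (H, d, d_k); W_K, W_V: (G, d, d_k); W_O: (H * d_k, d)."""
    H, _, d_k = W_Q.shape
    G = W_K.shape[0]
    assert H % G == 0, "H must be divisible by G"
    N = X.shape[0]
    Q = np.einsum("nd,hdk->hnk", X, W_Q)   # (H, N, d_k)
    K = np.einsum("nd,gdk->gnk", X, W_K)   # only G key heads would be cached: (G, N, d_k)
    V = np.einsum("nd,gdk->gnk", X, W_V)   # (G, N, d_k)
    # Give each of the H // G query heads in a group a copy of that group's K/V.
    K_rep = np.repeat(K, H // G, axis=0)   # (H, N, d_k)
    V_rep = np.repeat(V, H // G, axis=0)   # (H, N, d_k)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)  # (H, N, N)
    A = softmax(scores) @ V_rep                           # (H, N, d_k)
    return A.transpose(1, 0, 2).reshape(N, H * d_k) @ W_O

rng = np.random.default_rng(0)
N, d, H, G = 6, 32, 8, 4
d_k = d // H
X = rng.standard_normal((N, d))
W_Q = rng.standard_normal((H, d, d_k))
W_K = rng.standard_normal((G, d, d_k))
W_V = rng.standard_normal((G, d, d_k))
W_O = rng.standard_normal((H * d_k, d))
out = grouped_query_attention(X, W_Q, W_K, W_V, W_O)

# Equivalent MHA: every head gets its own (duplicated) copy of its group's K/V weights.
mha_equiv = grouped_query_attention(
    X, W_Q, np.repeat(W_K, H // G, axis=0), np.repeat(W_V, H // G, axis=0), W_O
)
print(out.shape, np.allclose(out, mha_equiv))  # (6, 32) True
```

With `G = H` the function reduces to MHA, and with `G = 1` to MQA; only the number of cached K/V heads changes.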

### **Pros of Grouped-Query Attention**
✅ **Balances expressiveness and efficiency** – Heads within a group share **K-V**, while different groups keep separate K-V heads.  
✅ **Lower KV cache memory than MHA** – Uses **\( O(NG) \) memory instead of \( O(NH) \)**.  
✅ **Higher quality than MQA** – Retains **\( G \) distinct key-value heads** instead of one.

### **Cons of Grouped-Query Attention**
❌ **Slower than MQA** – Still caches \( G \) key-value heads (fewer than MHA's \( H \), but more than MQA's one).  
❌ **Extra hyperparameter** – The **number of groups \( G \)** trades quality against memory and has to be chosen per model.  

🚀 **Example Models Using GQA**: **LLaMA-2 (70B), LLaMA-3, Mistral 7B**

---

# **4. Comparison Table**
| Feature | Multi-Head Attention (MHA) | Multi-Query Attention (MQA) | Grouped-Query Attention (GQA) |
|---------|----------------------------|----------------------------|----------------------------|
| **Attention Compute** | \( O(HN^2 d_k) \) | \( O(HN^2 d_k) \) (same \( H \) query heads) | \( O(HN^2 d_k) \) (same \( H \) query heads) |
| **K/V Heads Cached** | \( H \) | 1 | \( G \) |
| **Memory Usage (KV Cache, per layer)** | High (\( O(HN) \)) | Low (\( O(N) \)) | Medium (\( O(GN) \)) |
| **Decoding Speed** | Slowest | Fastest | Faster than MHA, slower than MQA |
| **Expressiveness** | Highest (separate K, V per head) | Lowest (all heads share K, V) | Middle ground |
| **Best For** | General NLP tasks (BERT, T5) | Fast large-scale decoding (PaLM, Falcon-7B) | Efficient LLaMA-style models (LLaMA-3, Mistral) |
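To make the memory row concrete, here is a back-of-the-envelope KV cache calculator. The configuration is a hypothetical example, not taken from the text: 32 layers, 32 query heads of dimension 128, a 4,096-token context, batch size 1, and 16-bit (2-byte) storage. `kv_cache_bytes` is an illustrative helper name.

```python
def kv_cache_bytes(num_layers, seq_len, num_kv_heads, head_dim, batch=1, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence batch.

    The leading factor 2 counts both K and V; bytes_per_elem=2 assumes
    fp16/bf16 storage.
    """
    return 2 * num_layers * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem

# Hypothetical 32-layer model with 32 query heads of size 128 and a 4,096-token context.
for name, kv_heads in [("MHA", 32), ("GQA (G=8)", 8), ("MQA", 1)]:
    mib = kv_cache_bytes(32, 4096, kv_heads, 128) / 2**20
    print(f"{name:10s} {mib:8.1f} MiB")
# MHA 2048.0 MiB, GQA (G=8) 512.0 MiB, MQA 64.0 MiB
```

The cache scales linearly with the number of cached K/V heads, which is exactly the \( H \) vs. \( G \) vs. 1 difference in the table.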

---

# **5. Final Summary**
- **Multi-Head Attention (MHA)** → Best for accuracy but **slow and memory-heavy**.
- **Multi-Query Attention (MQA)** → **Faster and memory-efficient** but **loses diversity** in attention.
- **Grouped-Query Attention (GQA)** → **Balances efficiency and accuracy**, used in **modern LLaMA models**.

🚀 **For large-scale decoder models (e.g., PaLM, LLaMA-2 70B, LLaMA-3, Mistral), MQA and GQA are preferred because they shrink the KV cache and speed up decoding.**

# **Multi-Head Latent Attention (MHLA)**
### **1. What is Multi-Head Latent Attention?**
Multi-Head Latent Attention (MHLA) is an **efficient attention mechanism** that introduces **latent variables** to **reduce the computational complexity** of self-attention. It is primarily used to make **transformers more efficient** for handling **long sequences**.

**Key Idea:**  
Instead of computing **self-attention over all tokens**, we introduce a **latent representation** \( L \) that acts as a **compressed attention space**. This **reduces memory and computation costs** while still capturing important features.

---

## **2. How Multi-Head Latent Attention Works**
MHLA modifies standard **Multi-Head Attention (MHA)** by introducing a **set of latent vectors** that interact with input tokens instead of computing attention directly between all tokens.

### **Key Components**
1. **Latent Representations \( L \)**:  
   - These are a fixed number of learnable vectors that **replace direct token-to-token attention**.
   - Instead of every token attending directly to all \( N \) tokens, information is routed through a **smaller number \( M \) of latent vectors**.
   - This reduces complexity from **\( O(N^2) \)** to **\( O(NM) \)**, where \( M \ll N \).

2. **Two-Step Attention Mechanism**:  
   - **Step 1: Input → Latent Attention**  
     - The **latent vectors attend to the input tokens**, compressing \( N \) tokens into \( M \) summary vectors.  
   - **Step 2: Latent → Output Attention**  
     - The **input tokens attend to the updated latent vectors**, reading the summary back into one output per token.

This structure acts as an **intermediate compression step**: the model first **summarizes the input** into the latents and then **redistributes** that summary to every token.

---

## **3. Mathematical Formulation of MHLA**
### **Step 1: Compute Query, Key, and Value Matrices**
For an input sequence **\( X \in \mathbb{R}^{N \times d} \)**:
- \( N \) = Number of input tokens.
- \( d \) = Embedding dimension.

For a **latent representation** **\( L \in \mathbb{R}^{M \times d} \)**:
- \( M \) = Number of latent vectors (**\( M \ll N \)**).

Compute **Q, K, V for input tokens**:
\[
Q_X = X W_Q, \quad K_X = X W_K, \quad V_X = X W_V
\]
Compute **queries for the latent vectors**:
\[
Q_L = L W_Q^L
\]

### **Step 2: Input → Latent Attention**
Instead of **self-attention between tokens**, the \( M \) latent queries attend **over the \( N \) input tokens**, compressing the sequence into \( M \) summary vectors:
\[
Z = \text{softmax} \left( \frac{Q_L K_X^T}{\sqrt{d_k}} \right) V_X \in \mathbb{R}^{M \times d_k}
\]
where:
- The score matrix \( Q_L K_X^T \) has shape \( M \times N \), so this step costs \( O(MNd) \).
- \( Z \) is a **compressed, input-dependent summary** of the sequence.

### **Step 3: Latent → Output Attention**
Now, we project **the latent summary back to the token space**: \( Z \) is mapped to keys and values, and each input token queries them:
\[
K_Z = Z W_K^Z, \quad V_Z = Z W_V^Z
\]
\[
A = \text{softmax} \left( \frac{Q_X K_Z^T}{\sqrt{d_k}} \right) V_Z \in \mathbb{R}^{N \times d_k}
\]
where:
- The score matrix \( Q_X K_Z^T \) has shape \( N \times M \), so this step costs \( O(NMd) \).
- Every token can draw on information from every other token, but only through the \( M \) latent summaries.

### **Step 4: Compute Final Output**
Running Steps 2 and 3 in each of the \( H \) heads gives \( A_1, \dots, A_H \). As in MHA, they are concatenated and projected:
\[
\text{Output} = \text{Concat}(A_1, \dots, A_H) \, W_O \in \mathbb{R}^{N \times d}
\]

This **compresses global attention into a smaller latent space**, reducing **computational costs**.
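A single-head NumPy sketch of the two-step scheme, chaining the steps so that the latents first attend to the tokens (building a summary `Z`) and the tokens then attend to `Z`; this is the same pattern as the Set Transformer's induced set attention block. The projection names in the dictionary `W` (e.g. `"Q_L"`, `"K_Z"`) are hypothetical, the latents `L` are random here rather than learned, and multi-head splitting, residual connections, and normalization are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(Q, K, V):
    """Scaled dot-product attention: (n_q, d_k), (n_kv, d_k), (n_kv, d_v) -> (n_q, d_v)."""
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def latent_attention(X, L, W):
    """Single-head latent attention sketch.

    X: (N, d) input tokens; L: (M, d) latent vectors; W: dict of (d, d) projections.
    Step 1: latents attend to tokens -> Z: (M, d), scores are M x N.
    Step 2: tokens attend to latents -> out: (N, d), scores are N x M.
    """
    Z = attend(L @ W["Q_L"], X @ W["K_X"], X @ W["V_X"])    # compress N tokens into M summaries
    out = attend(X @ W["Q_X"], Z @ W["K_Z"], Z @ W["V_Z"])  # broadcast summaries back to tokens
    return out, Z

rng = np.random.default_rng(0)
N, M, d = 1000, 16, 32
X = rng.standard_normal((N, d))
L = rng.standard_normal((M, d))  # would be learned parameters in a real model
W = {name: rng.standard_normal((d, d)) / np.sqrt(d)
     for name in ["Q_L", "K_X", "V_X", "Q_X", "K_Z", "V_Z"]}
out, Z = latent_attention(X, L, W)
print(Z.shape, out.shape)  # (16, 32) (1000, 32)
```

The largest score matrices here hold \( 16 \times 1000 \) entries, compared with \( 1000 \times 1000 \) for full self-attention over the same tokens.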

---

## **4. Complexity Analysis**
| **Method**  | **Computational Complexity** |
|------------|----------------------------|
| **Standard Self-Attention** | \( O(N^2 d) \) |
| **Multi-Head Latent Attention (MHLA)** | \( O(NMd) + O(MNd) = O(NM d) \) |

Since \( M \ll N \), MHLA **significantly reduces memory and compute costs**.
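A worked count of query-key dot products per head makes the saving concrete. The sizes \( N = 8{,}192 \) and \( M = 256 \) are hypothetical, and the shared factor \( d \) is ignored; `attention_score_entries` is an illustrative helper name.

```python
def attention_score_entries(N, M=None):
    """Number of query-key scores per head (the common factor d is ignored)."""
    if M is None:
        return N * N        # full self-attention: one N x N score matrix
    return 2 * N * M        # latent attention: M x N (step 1) plus N x M (step 2)

N, M = 8192, 256
full = attention_score_entries(N)
latent = attention_score_entries(N, M)
print(full, latent, full / latent)  # 67108864 4194304 16.0
```

The ratio is \( N / (2M) \), so doubling the sequence length doubles the advantage of the latent scheme.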

---

## **5. Advantages of Multi-Head Latent Attention**
✅ **Faster than traditional self-attention** → Replaces quadratic \( O(N^2) \) computation with a smaller \( O(NM) \).  
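✅ **Lower memory footprint** → Stores two \( N \times M \) score matrices instead of one \( N \times N \) matrix, which matters for **long-sequence tasks**.  
✅ **Retains long-range dependencies** → Any token can influence any other through the latent vectors, which act as a **compressed summary** (at the cost of an \( M \)-vector bottleneck).  

🚀 **Related designs include the Perceiver family (cross-attention into a latent array) and the Set Transformer's induced set attention blocks.** Note that DeepSeek-V2 uses the name *Multi-head Latent Attention (MLA)* for a different technique that compresses keys and values into a low-rank latent vector to shrink the KV cache.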