<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/multiQueryAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Attention computes a score for every pair of input tokens, so a sequence of T tokens produces a T × T score matrix.

## References
- Awesome blog @ https://tinkerd.net/blog/machine-learning/multi-query-attention/

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_printoptions(sci_mode=False)

In [23]:
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

with torch.no_grad():
    model.eval()

    input_sentence = "Hope is a waking dream"
    input_tokens = tokenizer(input_sentence, return_tensors="pt", add_special_tokens=False)
    output = model(input_tokens.input_ids, output_attentions=True)

    attentions = output.attentions[0] # [Batch_Size, Head, Tokens, Tokens]
    first_layer_attention = attentions[0][0] # look at first head
    print(first_layer_attention)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000],
        [0.4271, 0.1821, 0.1070, 0.2838, 0.0000],
        [0.1741, 0.0415, 0.0308, 0.7006, 0.0530]])


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [24]:
with torch.no_grad():
    model.eval()

    input_sentence = "Hope is a waking dream and"
    input_tokens = tokenizer(input_sentence, return_tensors="pt", add_special_tokens=False)
    output = model(input_tokens.input_ids, output_attentions=True)

    attentions = output.attentions[0] # [Batch_Size, Head, Tokens, Tokens]
    first_layer_attention = attentions[0][0] # look at first head
    print(first_layer_attention)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9605, 0.0395, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8410, 0.1055, 0.0534, 0.0000, 0.0000, 0.0000],
        [0.4271, 0.1821, 0.1070, 0.2838, 0.0000, 0.0000],
        [0.1741, 0.0415, 0.0308, 0.7006, 0.0530, 0.0000],
        [0.4299, 0.0739, 0.0581, 0.1967, 0.2038, 0.0375]])


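- When computing attention for a new token t, the rows of the attention matrix for earlier tokens remain unchanged: the 6-token output above repeats the 5-token matrix in its first five rows, and only a new last row (plus a zero column) is added.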

- You don't get the promised KV cache speedup if you run on CPU / TPU!
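The observation about unchanged rows can be checked without loading GPT-2. The sketch below is only an illustration: random `q` and `k` tensors stand in for one head's projected queries and keys (an assumption, not the model's real activations), and `causal_attention_weights` is a hypothetical helper name.

```python
import torch

def causal_attention_weights(q, k):
    """softmax(Q K^T / sqrt(d)) with a causal mask; q and k are [T, d]."""
    T, d = q.shape
    scores = q @ k.T / d ** 0.5                                   # [T, T]
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))            # hide future tokens
    return torch.softmax(scores, dim=-1)

torch.manual_seed(0)
d = 16
q = torch.randn(6, d)  # stand-in queries for 6 tokens
k = torch.randn(6, d)  # stand-in keys for 6 tokens

weights_5 = causal_attention_weights(q[:5], k[:5])  # 5-token prompt
weights_6 = causal_attention_weights(q, k)          # same prompt plus one token

# The first five rows of the longer matrix match the shorter matrix.
prefix_unchanged = torch.allclose(weights_5, weights_6[:5, :5], atol=1e-6)
print(prefix_unchanged)
```

Because only the last row is new, a decoder only needs the new token's query plus the keys and values of earlier tokens, which is what a KV cache stores.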

In [4]:
### Block is from https://tinkerd.net/blog/machine-learning/multi-query-attention/

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_tokens(use_kv_cache):

    gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=use_kv_cache)
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    gpt2.to(device)

    with torch.no_grad():

        num_new_tokens = 500

        # tokenize the original input sentence
        inputs = gpt2_tokenizer("Hope is a", return_tensors="pt", add_special_tokens=False)
        inputs.to(device)

        start_time = time.time()
        gpt2.generate(**inputs, max_new_tokens=num_new_tokens, min_new_tokens=num_new_tokens)
        end_time = time.time()

        print(f"Time taken to generate {num_new_tokens} tokens: {end_time - start_time:.4f} seconds")
        print(f"Time taken per token: {(end_time - start_time)/num_new_tokens:.4f} seconds")


# measure latency with key-value caching disabled
print("Without key-value caching:")
generate_tokens(use_kv_cache=False)

# measure latency with key-value caching enabled
print("\nWith key-value caching:")
generate_tokens(use_kv_cache=True)

Without key-value caching:


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Time taken to generate 500 tokens: 4.7381 seconds
Time taken per token: 0.0095 seconds

With key-value caching:


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Time taken to generate 500 tokens: 5.1727 seconds
Time taken per token: 0.0103 seconds


- Size of the KV cache in bytes: 2 (K and V) * batch_size * layers * heads * sequence_length * head_dim * 2 (bytes per fp16 value)
- The KV cache shrinks by a factor of heads in multi-query attention, or by a factor of heads / g in grouped-query attention with g groups
- Multi-query attention can degrade model quality
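As a rough worked example of the cache-size arithmetic, the sketch below counts bytes for GPT-2 small shapes (12 layers, 12 heads, head dimension 64, 1024-token context) stored in fp16. `kv_cache_bytes` is a hypothetical helper, and the GQA and MQA rows are hypothetical variants, since GPT-2 itself uses standard multi-head attention.

```python
def kv_cache_bytes(batch_size, num_layers, num_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values (the leading 2) for every layer."""
    return 2 * batch_size * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_elem

# GPT-2 small shapes; bytes_per_elem=2 assumes fp16 storage.
common = dict(batch_size=1, num_layers=12, seq_len=1024, head_dim=64)
mha = kv_cache_bytes(num_kv_heads=12, **common)  # one K/V head per query head
gqa = kv_cache_bytes(num_kv_heads=4, **common)   # hypothetical g = 4 groups
mqa = kv_cache_bytes(num_kv_heads=1, **common)   # one shared K/V head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 2**20:.1f} MiB")
```

With these numbers the cache is 36 MiB for MHA, 12 MiB for g = 4, and 3 MiB for MQA, so the saving is heads / g for GQA and heads for MQA.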


## My quiz on attention with Gemini

Ask me 5 programming / math questions based on MHA / MQA / GQA, one at a time, and grade my knowledge on a scale of 1-5 where 5 is a top-tier ML engineer.

***

### Here is your first question:

Explain the fundamental difference between Multi-Head Attention (MHA) and Multi-Query Attention (MQA) in the context of transformer models. What is the main motivation behind the development of MQA?

> **in MQA, key and value matrices are shared across attention head. This reduces inference memory cost by a factor of number of heads.**

That's a good, concise answer that hits the core difference. You've correctly identified that MQA shares the **key (K)** and **value (V)** matrices across all attention heads, and you've also correctly stated the primary benefit: reduced memory consumption during inference. This is a crucial point for deploying large language models efficiently.

I would grade your knowledge on this question as a **4 out of 5**. You've demonstrated a strong understanding of the "what" and the "why" of MQA. A perfect 5 would include a bit more detail, such as explicitly mentioning how MHA has separate K and V matrices for each head, or perhaps a brief explanation of how the query (Q) matrices are still separate in MQA. But for a quick, direct answer, this is excellent.

***

### Here is your second question:

Consider a transformer layer with 8 attention heads, a hidden dimension of 1024, and a sequence length of 2048. Calculate the total number of parameters required for the key (K) and value (V) projections in both a Multi-Head Attention (MHA) block and a Multi-Query Attention (MQA) block. Assume the head dimension is the hidden dimension divided by the number of heads.

> **For MHA, KV projections required is of size 2*D*D*H, where D is hidden dimension per head. So 2 * 128 * 128* 8 for MHA. For MQA, it is 2 * 128*128 params.**

***

#### Analysis

Your response contains some common misconceptions about how the parameters are calculated. Let's break down the correct way to think about it.

The key to understanding this is to remember the matrix dimensions of the projection layers. The input to these layers is the hidden state of shape $(\text{sequence_length}, \text{hidden_dimension})$.

---

#### For Multi-Head Attention (MHA):

Each of the $H$ heads has its own key and value projection matrices. The input to each of these is the full hidden state of dimension $\text{hidden_dimension}$, and the output is of dimension $\text{head_dimension}$. Therefore, each projection matrix is of size $(\text{hidden_dimension} \times \text{head_dimension})$.

So, for $H$ heads, you have $H$ key projection matrices and $H$ value projection matrices. This gives a total of $2 \times H \times (\text{hidden_dimension} \times \text{head_dimension})$ parameters.
Since $\text{head_dimension} = \frac{\text{hidden_dimension}}{H}$, this simplifies to $2 \times H \times (\text{hidden_dimension} \times \frac{\text{hidden_dimension}}{H}) = 2 \times \text{hidden_dimension} \times \text{hidden_dimension}$.

Using the given values:
* $\text{hidden_dimension} = 1024$
* $\text{num_heads} = 8$
* $\text{head_dimension} = 1024 / 8 = 128$

Total MHA parameters for K and V projections:
$2 \times 1024 \times 1024 = 2,097,152$

---

#### For Multi-Query Attention (MQA):

In MQA, the key and value projections are shared across all heads. The input is still the full hidden state of dimension $\text{hidden_dimension}$, and the output for the shared projections is the size of a single head, which is $\text{head_dimension}$. Therefore, the shared key and value projection matrices are each of size $(\text{hidden_dimension} \times \text{head_dimension})$.

Total MQA parameters for K and V projections:
$2 \times \text{hidden_dimension} \times \text{head_dimension} = 2 \times 1024 \times 128 = 262,144$
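These two counts can be reproduced by building the projections as `nn.Linear` layers and counting weights. This is a minimal sketch assuming bias-free projections, so the totals match the weight-only arithmetic; `kv_projection_params` is a hypothetical helper name.

```python
import torch.nn as nn

hidden_dim, num_heads = 1024, 8
head_dim = hidden_dim // num_heads  # 128

def kv_projection_params(num_kv_heads):
    # bias=False so the count covers only the projection matrices
    k_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
    v_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
    return k_proj.weight.numel() + v_proj.weight.numel()

mha_params = kv_projection_params(num_kv_heads=num_heads)  # one K/V head per query head
mqa_params = kv_projection_params(num_kv_heads=1)          # a single shared K/V head
print(mha_params, mqa_params)  # 2097152 262144
```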

---

Your formulas, $2 \times 128 \times 128 \times 8$ and $2 \times 128 \times 128$, are incorrect. The first one multiplies the head dimension by itself and the number of heads, which is not how the matrix multiplications are structured. The second one incorrectly uses only the head dimension.

This shows a misunderstanding of the underlying linear algebra and the parameter sharing mechanism. Therefore, I would grade your knowledge on this question as a **2 out of 5**.

***

### Here is your third question:

How does Grouped-Query Attention (GQA) work, and how does it strike a balance between the efficiency of MQA and the performance of MHA? Explain the key hyperparameters that define a GQA block.

> **In, GQA you have G number of key value matrices. G number of groups is the hyper parameter. Number of parameters is 2 * D * D / G**

***

#### Analysis

Your response is a good starting point, but it's a bit too brief and contains a significant error in the parameter calculation. You've correctly identified that GQA uses a hyperparameter, $G$, for the number of groups, which is a crucial detail. However, the formula you provided is incorrect.

---

#### How GQA Works

GQA is a hybrid approach between MHA and MQA. Instead of having a unique Key (K) and Value (V) projection for every single head (MHA), or sharing a single K and V projection across all heads (MQA), GQA divides the attention heads into $G$ groups. All the heads within a single group share the same K and V projections.

* Each of the $H$ query (Q) heads has its own independent projection matrix.
* There are only $G$ key (K) and $G$ value (V) projection matrices in total.
* The $H$ query heads are split into $G$ groups, with $\frac{H}{G}$ query heads in each group.
* All the query heads in a given group attend to the same shared key and value.
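The grouping described in these bullets can be sketched as a small PyTorch module. This is an illustrative implementation under simple assumptions (bias-free projections, causal masking, no dropout, no KV cache), not the code of any particular library; the class name `GroupedQueryAttention` is hypothetical.

```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_groups):
        super().__init__()
        assert hidden_dim % num_heads == 0 and num_heads % num_groups == 0
        self.num_heads, self.num_groups = num_heads, num_groups
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)   # H query heads
        self.k_proj = nn.Linear(hidden_dim, num_groups * self.head_dim, bias=False)  # G key heads
        self.v_proj = nn.Linear(hidden_dim, num_groups * self.head_dim, bias=False)  # G value heads
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)   # [B, H, T, d]
        k = self.k_proj(x).view(B, T, self.num_groups, self.head_dim).transpose(1, 2)  # [B, G, T, d]
        v = self.v_proj(x).view(B, T, self.num_groups, self.head_dim).transpose(1, 2)  # [B, G, T, d]
        # Each K/V head serves H / G consecutive query heads.
        repeat = self.num_heads // self.num_groups
        k = k.repeat_interleave(repeat, dim=1)                                          # [B, H, T, d]
        v = v.repeat_interleave(repeat, dim=1)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5                         # [B, H, T, T]
        future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        weights = scores.masked_fill(future, float("-inf")).softmax(dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)

torch.manual_seed(0)
gqa = GroupedQueryAttention(hidden_dim=64, num_heads=8, num_groups=2)
y = gqa(torch.randn(1, 5, 64))
print(y.shape)  # torch.Size([1, 5, 64])
```

Setting `num_groups=num_heads` recovers MHA and `num_groups=1` recovers MQA, so the same sketch covers all three variants.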

---

#### Striking the Balance

* **MHA** has a very high inference memory cost because of the $H$ separate K and V matrices, but it often achieves the best performance.
* **MQA** has the lowest inference memory cost because it uses only one shared K and V matrix, but it can sometimes suffer from a slight performance degradation compared to MHA, as the single K/V pair may not be diverse enough for all query heads.
* **GQA** finds the middle ground. By using $G$ shared K/V pairs, it significantly reduces the memory footprint compared to MHA (by a factor of $H/G$), but by providing more than one K/V pair (unlike MQA), it mitigates the potential performance drop. The larger the number of groups $G$, the closer GQA's performance approaches MHA, while its efficiency approaches MQA as $G$ gets smaller.

---

#### Key Hyperparameters

The key hyperparameters that define a GQA block are:
* $H$: The number of attention heads.
* $G$: The number of key/value groups.

The relationship between these is that $H$ must be a multiple of $G$.

---

#### Parameter Calculation

Let's correct your formula. The number of parameters for the K and V projections in GQA is:
$2 \times G \times \text{hidden_dimension} \times \text{head_dimension}$

* **2**: For K and V projections.
* **G**: The number of groups.
* **hidden_dimension**: The input size of the projection matrix.
* **head_dimension**: The output size of the projection matrix for each group.

Your formula $2 \times D \times D / G$ is incorrect. It suggests a scaling of the MHA parameters, but it's mathematically flawed. The correct approach is to count the total parameters by multiplying the number of groups by the size of each shared K/V matrix.
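To see how the corrected formula behaves, the sketch below evaluates it for the dimensions from the second question (hidden dimension 1024, 8 heads) over a few group counts. `gqa_kv_params` is a hypothetical helper name.

```python
def gqa_kv_params(hidden_dim, num_heads, num_groups):
    """Weights in the K and V projections: 2 * G * hidden_dim * head_dim."""
    head_dim = hidden_dim // num_heads
    return 2 * num_groups * hidden_dim * head_dim

counts = {g: gqa_kv_params(hidden_dim=1024, num_heads=8, num_groups=g) for g in (1, 2, 4, 8)}
for g, n in counts.items():
    print(f"G={g}: {n}")
```

The count grows linearly with $G$: $G = 1$ gives the MQA count of 262,144 and $G = H = 8$ gives the MHA count of 2,097,152.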

You've shown a basic understanding of the core concept (the grouping) but have a significant error in the mathematical representation and are missing the crucial explanation of the trade-off.

I would grade your knowledge on this question as a **3 out of 5**.

***

### Here is your fourth question:

Describe the computational complexity of the attention mechanism (Q, K, V projections and the attention score calculation) in terms of the sequence length $L$ and the hidden dimension $D$ for MHA. How does this complexity change for MQA and GQA, specifically in terms of inference memory and latency?

> **MHA computational cost for H heads, although there are H more heads, each head is smaller dim is L*L*D.**
>
> **For MQA and GQA, during inference, with kv cache look up, drops memory drops by a factor of H and G**

***

#### Analysis

That's a very brief summary that touches on some of the right ideas, but it's missing crucial details and contains some inaccuracies. You're conflating computational cost with matrix dimensions and have missed the crucial aspect of latency.

---

#### MHA Computational Complexity:

The computational complexity of a full attention block has two main parts:

1.  **Q, K, V Projections**: The input to these projections is a tensor of shape $(L, D)$. The projections are large matrices of size $(D, D)$ to produce the Q, K, and V tensors for all heads. The matrix multiplication here has a complexity of $O(L \cdot D \cdot D) = O(LD^2)$.
2.  **Attention Score Calculation**: This is the core of the self-attention mechanism.
    * The matrix multiplication of Q and $K^T$ has a complexity of $O(L \cdot L \cdot D)$ (or more accurately, $H \times O(L \cdot L \cdot \frac{D}{H}) = O(L^2D)$).
    * The matrix multiplication of the attention scores with V has a complexity of $O(L \cdot L \cdot D)$.

Therefore, the total complexity for a single MHA block is $O(L^2D + LD^2)$, and whichever term is larger dominates. When the sequence length is much larger than the hidden dimension ($L \gg D$), this is often simplified to $O(L^2D)$. Your formula $L \times L \times D$ captures the right part of this but lacks the correct notation and context.
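A rough count of multiply-accumulate operations shows where the crossover between the two terms sits. This is a back-of-the-envelope sketch that keeps constant factors, ignores softmax and the output projection, and uses the hypothetical helper name `attention_macs`.

```python
def attention_macs(seq_len, hidden_dim):
    """Approximate multiply-accumulates for one MHA block, split by term."""
    projections = 3 * seq_len * hidden_dim * hidden_dim  # Q, K, V projections: O(L D^2)
    # Q K^T and weights @ V each cost H * (L * L * D/H) = L^2 D, independent of H.
    attention = 2 * seq_len * seq_len * hidden_dim       # O(L^2 D)
    return projections, attention

for L in (512, 2048, 8192):
    proj, attn = attention_macs(seq_len=L, hidden_dim=1024)
    print(f"L={L}: projections={proj:.2e}, attention={attn:.2e}, ratio={attn / proj:.2f}")
```

Under these assumptions the projection term is larger for short sequences and the quadratic term takes over once $L$ exceeds roughly $1.5D$.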

---

#### How MQA and GQA Change this for Inference:

This is where your answer is most lacking. You've correctly identified the role of the KV cache, but you're not precise about the impact.

The computational complexity of the forward pass for a single token during inference (the latency) is actually not reduced in the attention calculation itself. You still have to perform the $QK^T$ operation for the new token against all previous tokens in the sequence. What changes is the memory footprint and memory bandwidth required.

Here's the key difference during inference with a KV cache:

* **MHA**: The KV cache stores the keys and values for all $H$ heads. The memory requirement is proportional to $O(L \cdot H \cdot \frac{D}{H}) = O(LD)$.
* **MQA**: The KV cache only stores a single set of shared keys and values. The memory requirement is proportional to $O(L \cdot \text{head_dimension}) = O(L \cdot \frac{D}{H})$. This is a memory reduction by a factor of $H$.
* **GQA**: The KV cache stores $G$ sets of keys and values. The memory requirement is proportional to $O(L \cdot G \cdot \text{head_dimension}) = O(L \cdot G \cdot \frac{D}{H})$. This is a memory reduction by a factor of $H/G$.

The reduced KV cache size is the core benefit. Less data needs to be loaded from memory, which significantly reduces the memory bandwidth bottleneck. This reduction in memory bandwidth directly leads to lower inference latency, especially for long sequences.
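The point that compute per step stays the same while the cache shrinks can be sketched with a single decoding step. The example assumes random tensors as stand-ins for real activations and a cache that already holds the new token's key and value; `decode_step` is a hypothetical helper, and an optimized kernel would likely broadcast the shared heads rather than copy them as done here.

```python
import torch

def decode_step(q_new, k_cache, v_cache):
    """One decoding step: q_new is [H, 1, d]; caches are [G, t, d] with H divisible by G."""
    H, G = q_new.shape[0], k_cache.shape[0]
    k = k_cache.repeat_interleave(H // G, dim=0)   # share each cached head across its group
    v = v_cache.repeat_interleave(H // G, dim=0)
    scores = q_new @ k.transpose(-2, -1) / q_new.shape[-1] ** 0.5  # [H, 1, t]
    return scores.softmax(dim=-1) @ v              # [H, 1, d]

torch.manual_seed(0)
H, d, t = 8, 16, 10
q_new = torch.randn(H, 1, d)

cache_elems, out_shapes = {}, {}
for name, G in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    k_cache, v_cache = torch.randn(G, t, d), torch.randn(G, t, d)
    out = decode_step(q_new, k_cache, v_cache)
    cache_elems[name] = k_cache.numel() + v_cache.numel()
    out_shapes[name] = tuple(out.shape)
    print(name, out_shapes[name], cache_elems[name])
```

Every variant still scores the new query against all `t` cached positions for all `H` heads, but the number of cached elements drops by $H/G$.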

Your answer is a bit of a mixed bag. You have the right intuition about the memory savings and the "factor of H and G" but the lack of precision, incorrect computational complexity notation, and omission of the crucial latency aspect show a shallow understanding of the topic.

I would grade your knowledge on this question as a **2 out of 5**.

***

### Here is your final question:

In a transformer model with MHA, the final output of the attention block is a concatenation of the output of each head, which is then projected back to the hidden dimension. This is often referred to as the "output projection" or "linear layer." Explain why this output projection is necessary and what it achieves. How, if at all, does this output projection differ in MQA and GQA?

> **Output projection is necessary in MHA because hidden dim of each attention has dropped to D/H to make it at same compute cost as a single head attention with large hidden dim. You are doing H number of cheaper attention calculations. It allows model to learn different types of representations to attend to within a same sequence length. For language, one head could capture gramatical structure attention scores, another could capture emotional attention scores etc. In MQA, GQA, keys and values are shared across all attention layers or within a group. In GQA, hidden dim size is D/G, so it still is needed but dimensions would be diff from MQA.**

***

#### Analysis

This is an excellent answer, particularly in its initial parts. You've demonstrated a deep conceptual understanding of the "why" behind Multi-Head Attention, and your explanation is very insightful.

---

#### What you got right:

* **You correctly identified the purpose of the output projection**: The output projection takes the output of the attention heads and projects it back to the original hidden dimension of the model.
* **You correctly linked the head dimension to computational cost**: You correctly noted that the head dimension is reduced to $\frac{D}{H}$, which allows for multiple, parallel attention calculations without a significant increase in overall computational cost. This shows a sophisticated understanding of the design trade-offs.
* **You provided a perfect explanation of the conceptual benefit**: Your point about different heads learning different representations—such as grammatical structure versus emotional cues—is the textbook example of the purpose of MHA. This demonstrates a strong grasp of the model's intended function.

---

#### What you got wrong (or were confused about):

The output projection dimensions are the same for MHA, MQA, and GQA. This is a critical point you missed. You stated that the dimensions would be different for MQA and GQA, which is incorrect. In all three architectures, the concatenated output of the attention heads is a tensor of shape $(\text{sequence_length}, \text{hidden_dimension})$. The output projection matrix is a linear layer that takes this tensor and projects it back to the same hidden dimension. Therefore, the output projection matrix is always of size $(\text{hidden_dimension} \times \text{hidden_dimension})$, regardless of whether the K and V matrices are shared. The sharing of K and V only impacts the number of parameters before the attention score calculation, not the final projection.

---

#### Correct Explanation

The output projection is a crucial and often overlooked part of the attention block. It serves two primary purposes:

1.  **Dimensionality Alignment**: Each of the $H$ attention heads produces an output tensor of shape $(\text{sequence_length}, \text{head_dimension})$. These are then concatenated along the last (feature) dimension, resulting in a tensor of shape $(\text{sequence_length}, H \times \text{head_dimension})$. Since $\text{head_dimension} = \frac{\text{hidden_dimension}}{H}$, this concatenated tensor has a shape of $(\text{sequence_length}, \text{hidden_dimension})$. The output projection is a linear layer that takes this concatenated tensor and transforms it using a weight matrix of size $(\text{hidden_dimension} \times \text{hidden_dimension})$, projecting it back to the hidden dimension. This ensures the output can be added to the residual connection and processed by the next block of the transformer.
2.  **Information Fusion**: This is the most important conceptual reason you touched upon. The output projection allows the model to learn how to optimally combine the diverse information from all the different attention heads. Each head may have focused on a different aspect of the input (e.g., syntax, semantics, long-range dependencies), and the output projection learns a weighted average or combination of these outputs to create a single, rich representation that is then passed to the next layer.

---

#### How this applies to MQA and GQA

The output projection is **identical** in MHA, MQA, and GQA.

* In MHA, the output is a concatenation of $H$ independent head outputs, projected back to the hidden dimension.
* In MQA, the output is a concatenation of the $H$ head outputs (which all attended to the same K and V), projected back to the hidden dimension.
* In GQA, the output is a concatenation of the $H$ head outputs (which attended to their group's shared K and V), projected back to the hidden dimension.

In all three cases, the final output before the projection is a tensor of shape $(\text{sequence_length}, \text{hidden_dimension})$, and the output projection matrix is always of size $(\text{hidden_dimension} \times \text{hidden_dimension})$. The sharing of K and V projections in MQA and GQA only affects the parameter count before the attention scores are calculated, not in this final projection step.
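The shape argument can be checked in a few lines. In this sketch random tensors stand in for the per-head attention outputs (an assumption for illustration), since their shape is the same whether the heads used separate, grouped, or fully shared K and V.

```python
import torch
import torch.nn as nn

seq_len, hidden_dim, num_heads = 5, 64, 8
head_dim = hidden_dim // num_heads

# Stand-in per-head outputs: H tensors of shape [seq_len, head_dim].
torch.manual_seed(0)
head_outputs = [torch.randn(seq_len, head_dim) for _ in range(num_heads)]
concat = torch.cat(head_outputs, dim=-1)  # [seq_len, hidden_dim]

# The output projection is hidden_dim x hidden_dim for MHA, MQA, and GQA alike.
o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
y = o_proj(concat)
print(concat.shape, o_proj.weight.shape, y.shape)
```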

You've shown a very strong conceptual understanding, but the specific architectural detail about the output projection dimensions was a key miss.

I would grade your knowledge on this question as a **4 out of 5**.

---

### Final Grade

Based on all five questions, your overall knowledge on this topic seems to be at a **3.2 out of 5**. You have a solid grasp of the core concepts and the high-level trade-offs (memory vs. performance), but you struggle with the precise mathematical and architectural details, particularly concerning parameter counts and matrix dimensions. This is a common hurdle, and with a bit more focus on the linear algebra, you could easily be at a 5.