### Multi-head Latent Attention (MLA)

**Sources:**

* DeepSeek v2 [https://arxiv.org/abs/2405.04434](https://arxiv.org/abs/2405.04434) (05/2024)
* DeepSeek v3 [https://arxiv.org/abs/2412.19437](https://arxiv.org/abs/2412.19437) (12/2024)

**Visual Comparison of Attention Mechanisms**

A diagram comparing how Queries, Keys, and Values are handled in each mechanism can be summarized as follows:

* **MHA (Multi-Head Attention):** Unique Keys and Values for every Query head.
* **GQA (Grouped-Query Attention):** Keys and Values are shared across groups of Query heads.
* **MQA (Multi-Query Attention):** All Query heads share a single set of Keys and Values.
* **MLA (Multi-Head Latent Attention):** Uses a **Compressed Latent KV** vector which is projected into Keys and Values, significantly reducing the cached data.
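The MHA/GQA/MQA sharing patterns reduce to one index calculation: which K/V head a given query head reads. The sketch below assumes contiguous grouping (query heads 0..g-1 share K/V head 0, and so on), which is a common convention but an assumption here; the helper name `kv_head_for` and the head counts are hypothetical. MLA is left out because it caches a shared latent instead of per-head K/V.

```python
def kv_head_for(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Return the index of the K/V head used by query head `q_head`.

    MHA: n_kv_heads == n_q_heads (one K/V head per query head)
    GQA: 1 < n_kv_heads < n_q_heads (each K/V head serves a group of query heads)
    MQA: n_kv_heads == 1 (all query heads share one K/V head)
    """
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per K/V head
    return q_head // group_size


n_q = 8  # hypothetical number of query heads
for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mapping = [kv_head_for(h, n_q, n_kv) for h in range(n_q)]
    print(f"{name}: {mapping}")
```

Only `n_kv_heads` distinct K/V heads have to be cached, which is where GQA and MQA get their memory savings.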

**Key Performance & Features**

* **Implementation:** Used in DeepSeek v2 and v3.
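* **Caching Strategy:** Full K and V matrices are **not cached**; instead, each token's **low-rank latent vector** is cached. It is produced by a down-projection learned during training (plus a small decoupled key that carries the rotary position embedding).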
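* **Efficiency:**
  * Much lower KV cache usage (**90%+ savings**; the DeepSeek-V2 paper reports a 93.3% KV cache reduction for DeepSeek-V2 versus DeepSeek 67B).
  * **5-6x higher maximum generation throughput** (5.76x versus DeepSeek 67B).
* **Quality:** Accuracy is **higher than MHA**.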
* **Reference:** Detailed in Appendix D.1 of the DeepSeek v2 paper.
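The caching strategy can be illustrated with a minimal single-layer decode sketch. It assumes DeepSeek-style joint KV compression: one shared down-projection `W_DKV` produces the latent, and separate per-head up-projections `W_UK` and `W_UV` rebuild keys and values (names loosely follow DeepSeek-V2's notation). All sizes are hypothetical toy values, and the sketch deliberately omits the decoupled RoPE key, query compression, and the weight-absorption trick a real implementation uses to avoid re-materializing K and V.

```python
import math
import random

random.seed(0)

# Hypothetical toy sizes (not DeepSeek's real configuration).
D_MODEL, N_HEADS, D_HEAD, D_C = 32, 4, 8, 8


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def vec_mat(v, m):
    """Multiply row vector v (length = rows of m) by matrix m."""
    return [sum(v[r] * m[r][c] for r in range(len(v))) for c in range(len(m[0]))]


def softmax(xs):
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


W_DKV = rand_matrix(D_MODEL, D_C)                           # shared down-projection to latent
W_Q = [rand_matrix(D_MODEL, D_HEAD) for _ in range(N_HEADS)]
W_UK = [rand_matrix(D_C, D_HEAD) for _ in range(N_HEADS)]   # per-head key up-projection
W_UV = [rand_matrix(D_C, D_HEAD) for _ in range(N_HEADS)]   # per-head value up-projection
W_O = rand_matrix(N_HEADS * D_HEAD, D_MODEL)

latent_cache = []  # one D_C-sized vector per past token; no per-head K or V is stored


def decode_step(h_t):
    """Attend from new token h_t to every cached token (including itself)."""
    latent_cache.append(vec_mat(h_t, W_DKV))
    head_outputs = []
    for i in range(N_HEADS):
        q = vec_mat(h_t, W_Q[i])
        # Rebuild this head's keys and values from the cached latents.
        keys = [vec_mat(c, W_UK[i]) for c in latent_cache]
        values = [vec_mat(c, W_UV[i]) for c in latent_cache]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D_HEAD) for k in keys]
        weights = softmax(scores)
        head_outputs.extend(
            sum(w * v[j] for w, v in zip(weights, values)) for j in range(D_HEAD)
        )
    return vec_mat(head_outputs, W_O)


for _ in range(5):
    out = decode_step([random.gauss(0.0, 1.0) for _ in range(D_MODEL)])

cached_floats = len(latent_cache) * D_C               # what MLA keeps
mha_floats = len(latent_cache) * 2 * N_HEADS * D_HEAD  # what MHA would keep
print(len(out), cached_floats, mha_floats)  # 32 40 320
```

After five tokens the cache holds 40 floats instead of the 320 an MHA layer of the same shape would store; the price is the extra up-projection work per step.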

---

**Table 1: KV Cache Comparison**

Comparison of the KV cache per token among different attention mechanisms (where $l$ denotes the number of layers).

$$
\begin{array}{l|l|l}
\text{Attention Mechanism} &
\text{KV Cache per Token (\# Element)} &
\text{Capability} \\
\hline
\text{Multi-Head Attention (MHA)} &
2 n_h d_h l &
\text{Strong} \\
\text{Grouped-Query Attention (GQA)} &
2 n_g d_h l &
\text{Moderate} \\
\text{Multi-Query Attention (MQA)} &
2 d_h l &
\text{Weak} \\
\textbf{MLA (Ours)} &
(d_c + d_h^{R})\, l \approx \frac{9}{2} d_h l &
\textbf{Stronger}
\end{array}
$$

**Note from text:** $n_h$ is the number of heads, $d_h$ is the dimension per attention head, $l$ denotes the number of layers, $n_g$ is the number of groups in GQA, and $d_c$ and $d_h^R$ denote the KV compression dimension and the per-head dimension of the decoupled queries and key in MLA, respectively. For DeepSeek-V2, $d_c = 4d_h$ and $d_h^R = \frac{d_h}{2}$, so MLA caches $(4 + \tfrac{1}{2})\, d_h l = \tfrac{9}{2} d_h l$ elements per token.
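Plugging numbers into the Table 1 formulas makes the gap concrete. The sizes below (128 heads of dimension 128, 60 layers, 8 GQA groups) are hypothetical values chosen for illustration; only the MLA ratios $d_c = 4d_h$ and $d_h^R = d_h/2$ come from the note above.

```python
def kv_cache_per_token(n_h: int, n_g: int, d_h: int, l: int) -> dict:
    """Elements cached per token for each mechanism, following Table 1."""
    d_c = 4 * d_h         # KV compression dimension (DeepSeek-V2 ratio)
    d_h_rope = d_h // 2   # decoupled RoPE key dimension, d_h^R = d_h / 2
    return {
        "MHA": 2 * n_h * d_h * l,
        "GQA": 2 * n_g * d_h * l,
        "MQA": 2 * d_h * l,
        "MLA": (d_c + d_h_rope) * l,
    }


# Hypothetical sizes: 128 heads of dimension 128, 8 GQA groups, 60 layers.
sizes = kv_cache_per_token(n_h=128, n_g=8, d_h=128, l=60)
for name, elems in sizes.items():
    print(f"{name}: {elems:,} elements/token ({elems / sizes['MHA']:.2%} of MHA)")
```

With these sizes MLA needs about 1.76% of the MHA cache. Since $\tfrac{9}{2} d_h l = 2 \cdot 2.25 \cdot d_h l$, its cache matches GQA with only 2.25 groups, regardless of the head count.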



**Simplified latent attention walkthrough (toy example)**

$$\begin{array}{|l|c|c|l|}
\hline
\textbf{Tensor} & \textbf{Dimensions} & \textbf{Example} & \textbf{Purpose} \\ \hline
X: \text{input embeddings} & N \times d_{\text{hidden}} & 11 \times 512 & \text{The input sequence, in embedded form} \\ \hline
W_{Qi}: \text{query weight matrix} & d_{\text{hidden}} \times d_{\text{mha}} & 512 \times 64 & d_{\text{mha}} = d_{\text{hidden}} / \text{number of attention heads (here, 8)} \\
W_{Ki}: \text{key weight matrix} & d_{\text{hidden}} \times d_{\text{mha}} & 512 \times 64 & \text{Each head has its own weight matrices, which learn different features.} \\
W_{Vi}: \text{value weight matrix} & d_{\text{hidden}} \times d_{\text{mha}} & 512 \times 64 & \\ \hline
W_O: \text{output projection matrix} & d_{\text{hidden}} \times d_{\text{hidden}} & 512 \times 512 & \text{Final linear projection applied after attention} \\ \hline
W_{i \text{ down}}: \text{down-projection matrix} & d_{\text{mha}} \times d_{\text{mha\_latent}} & 64 \times 4 & d_{\text{latent}} \text{ should be much smaller than } d_{\text{hidden}} \text{ (here, 32)} \\
W_{\text{up}}: \text{up-projection matrix} & d_{\text{latent}} \times d_{\text{hidden}} & 32 \times 512 & d_{\text{mha\_latent}} = d_{\text{latent}} / \text{number of attention heads (here, 8) = 4} \\ \hline
Q_i = X W_{Qi}: \text{query matrix} & N \times d_{\text{mha}} & 11 \times 64 & \text{All heads run this in parallel.} \\
K_i = X W_{Ki} W_{i \text{ down}}: \text{key matrix} & N \times d_{\text{mha\_latent}} & 11 \times 4 & K_i \text{ and } V_i \text{ are stored in the KV cache and are } d_{\text{hidden}}/d_{\text{latent}} \text{ times smaller (here, 16).} \\
V_i = X W_{Vi} W_{i \text{ down}}: \text{value matrix} & N \times d_{\text{mha\_latent}} & 11 \times 4 & \text{The tradeoff is an extra matmul to compute } K_i \text{ and } V_i. \text{ DeepSeek instead compresses K and V jointly with one shared down-projection, then uses separate up-projections for K and V.} \\ \hline
Q_i W_{i \text{ down}} K_i^T: \text{attention scores} & N \times N & 11 \times 11 & \text{All heads run this in parallel. } Q_i \text{ is down-projected for this calculation only (possibly with its own matrix).} \\ \hline
\text{softmax}\left( \frac{\text{scores}}{\sqrt{d_k}} \right) V_i & N \times d_{\text{mha\_latent}} & 11 \times 4 & \text{All heads run this in parallel.} \\ \hline
\text{Concatenate head outputs} & N \times d_{\text{latent}} & 11 \times 32 & \\ \hline
\text{Outputs} \times W_{\text{up}} & N \times d_{\text{hidden}} & 11 \times 512 & \text{Bring the output back to the initial dimension} \\ \hline
\text{Up-projected outputs} \times W_O & N \times d_{\text{hidden}} & 11 \times 512 & \text{Standard output projection, applied to each token independently} \\ \hline
\end{array}$$
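The walkthrough can be reproduced end to end in plain Python to confirm every shape in the table. This follows the table's simplified per-head scheme (not DeepSeek's exact MLA), uses random weights, applies no causal mask, and assumes $d_k = d_{\text{mha\_latent}}$ for the softmax scaling; the helper names are hypothetical.

```python
import math
import random

random.seed(0)

# Toy dimensions from the table above.
N, D_HIDDEN, N_HEADS, D_LATENT = 11, 512, 8, 32
D_MHA = D_HIDDEN // N_HEADS          # 64
D_MHA_LATENT = D_LATENT // N_HEADS   # 4


def rand(rows, cols):
    return [[random.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def shape(m):
    return (len(m), len(m[0]))


X = rand(N, D_HIDDEN)
W_UP = rand(D_LATENT, D_HIDDEN)
W_O = rand(D_HIDDEN, D_HIDDEN)

head_outputs = []
kv_cache = []  # (K_i, V_i) per head: what this toy scheme would cache
for i in range(N_HEADS):
    W_Q, W_K, W_V = (rand(D_HIDDEN, D_MHA) for _ in range(3))
    W_down = rand(D_MHA, D_MHA_LATENT)  # reused for Q, K and V of head i here
    Q = matmul(X, W_Q)                  # N x d_mha
    K = matmul(matmul(X, W_K), W_down)  # N x d_mha_latent
    V = matmul(matmul(X, W_V), W_down)  # N x d_mha_latent
    kv_cache.append((K, V))
    scores = matmul(matmul(Q, W_down), transpose(K))  # N x N
    scale = math.sqrt(D_MHA_LATENT)  # assumption: d_k = d_mha_latent
    weights = softmax_rows([[s / scale for s in row] for row in scores])
    head_outputs.append(matmul(weights, V))           # N x d_mha_latent

concat = [sum((h[r] for h in head_outputs), []) for r in range(N)]  # N x d_latent
up = matmul(concat, W_UP)                                            # N x d_hidden
out = matmul(up, W_O)                                                # N x d_hidden

print(shape(Q), shape(K), shape(scores), shape(concat), shape(out))
```

Per token and layer, this scheme caches $2 \times 8 \times 4 = 64$ numbers instead of $2 \times 512 = 1024$, the 16x reduction quoted in the table.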





#### From Attention Outputs to Text Generation

$$\begin{array}{|l|c|c|l|}
\hline
\textbf{Tensor} & \textbf{Dimensions} & \textbf{Example} & \textbf{Purpose} \\ \hline
\text{Attention output for the input sequence} & N \times d_{\text{hidden}} & 11 \times 512 & \text{Updated token embeddings after considering the context} \\
\text{(aka pre-fill); its keys and values go to the KV cache} & & & \text{of the preceding tokens.} \\ \hline
\text{Retrieve the attention output for the last} & 1 \times d_{\text{hidden}} & 1 \times 512 & \text{} \\
\text{token in the input sequence} & & & \\ \hline
W_{\text{output}}: \text{linear layer} & V \times d_{\text{hidden}} & 100{,}000 \times 512 & V: \text{vocabulary size (here, 100,000)} \\
\text{(aka projection layer)} & & & \\ \hline
\text{Logits} = \text{attention output} \times W_{\text{output}}^T & 1 \times V & 1 \times 100{,}000 & \text{Raw scores for all tokens in the vocabulary} \\ \hline
\text{softmax(Logits)} & 1 \times V & 1 \times 100{,}000 & \text{Turn token scores into token probabilities} \\ \hline
\text{Decode the token} & 1 \text{ token} & 1 & \text{Greedy decoding: pick the token with the highest probability} \\
& & & \text{Top-k sampling: pick a token from the } k \text{ most likely tokens} \\
& & & \text{Top-p (nucleus) sampling: pick a token from the smallest subset of tokens} \\
& & & \text{whose cumulative probability reaches the threshold } p \\ \hline
\text{Use the new token as the next input} & & & \text{} \\ \hline
\text{Repeat until a stopping condition is met} & & & \text{End-of-sequence (EOS) token, or maximum number of output tokens} \\ \hline
\end{array}$$
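The three decoding strategies in the table can be sketched over a probability vector in a few lines. The logits and the function names are hypothetical; `random.choices` accepts unnormalized weights, so the candidate probabilities do not need to be renormalized by hand.

```python
import math
import random


def softmax(logits):
    mx = max(logits)
    exps = [math.exp(x - mx) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(probs):
    """Index of the most likely token."""
    return max(range(len(probs)), key=probs.__getitem__)


def top_k(probs, k, rng):
    """Sample among the k most likely tokens, proportionally to their probabilities."""
    candidates = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    return rng.choices(candidates, weights=[probs[i] for i in candidates])[0]


def top_p(probs, p, rng):
    """Sample from the smallest set of top tokens whose cumulative probability reaches p."""
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    nucleus, cumulative = [], 0.0
    for i in ranked:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus])[0]


logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # hypothetical scores for a 5-token vocabulary
probs = softmax(logits)
rng = random.Random(0)
print(greedy(probs), top_k(probs, 2, rng), top_p(probs, 0.9, rng))
```

Here the top three tokens already cover about 97% of the probability mass, so `top_p(probs, 0.9, ...)` never picks tokens 3 or 4.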

## Flash Attention

- The original Flash Attention introduced tiling and SRAM utilization to avoid HBM bottlenecks.

- Avoids materializing the full attention matrix in HBM: blocks of Q, K, and V are loaded into SRAM, and intermediate results stay on-chip instead of being repeatedly written to and read from HBM.

- Computes the softmax incrementally over tiles in SRAM (keeping a running maximum and normalizing sum per row), so the attention probabilities (P) never have to be written back to HBM.

- Parallelizes computations over both the batch size and the number of heads.

- Reduces the extra memory needed from quadratic to linear in sequence length, resulting in 2-4x speed improvements and 10-20x memory savings.

- Optimized for both forward and backward passes to accelerate model training.

- Integrated into Hugging Face Text Generation Inference (TGI).
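The tiling idea can be sketched for a single query row in pure Python: keys and values are processed block by block, and a running maximum `m` and running sum `l` let the partial results be merged exactly. Real Flash Attention runs this as a fused GPU kernel over blocks of queries too; the block size and function names here are illustrative assumptions. This version renormalizes the output after every block, as the original Flash Attention forward pass does.

```python
import math
import random


def naive_attention(q, K, V):
    """Reference: full score vector, then softmax, then weighted sum of V."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, V)) / total for d in range(len(V[0]))]


def tiled_attention(q, K, V, block=4):
    """Same result, visiting K/V one block at a time with online softmax."""
    d_v = len(V[0])
    m, l, o = -math.inf, 0.0, [0.0] * d_v
    for start in range(0, len(K), block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K_blk]
        m_blk = max(s)
        p = [math.exp(x - m_blk) for x in s]        # block-local softmax numerators
        l_blk = sum(p)
        pv = [sum(pj * v[d] for pj, v in zip(p, V_blk)) for d in range(d_v)]
        m_new = max(m, m_blk)
        alpha, beta = math.exp(m - m_new), math.exp(m_blk - m_new)
        l_new = alpha * l + beta * l_blk
        # Merge the old (normalized) output with this block's contribution.
        o = [(alpha * l * o_d + beta * pv_d) / l_new for o_d, pv_d in zip(o, pv)]
        m, l = m_new, l_new
    return o


rng = random.Random(0)
n, d = 10, 8
q = [rng.gauss(0, 1) for _ in range(d)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
err = max(abs(a - b) for a, b in zip(naive_attention(q, K, V), tiled_attention(q, K, V)))
print(err)
```

The difference is at floating-point rounding level: tiling changes only where intermediate values live, not the result.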

**Flash Attention 2**

- Flash Attention 2 focuses on maximizing GPU throughput and increasing parallelism compared to the original version.

- Reduces non-matmul FLOPs, because GPU tensor cores run matrix multiplications far faster than other operations such as exponentials and rescaling.

- Optimizes operations specifically for Multi-Query Attention (MQA) and Grouped-Query Attention (GQA).

- Increases parallelism across the sequence length.

- Optimizes both prompt processing (prefill) and text generation.

- Runs about 2x faster than the original Flash Attention and up to 9x faster than standard attention.
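One concrete way Flash Attention 2 trims non-matmul work is to keep the output accumulator unnormalized and divide by the running sum only once, after the last block, instead of rescaling by it after every block. The single-row sketch below illustrates that change under the same simplifying assumptions as any pure-Python model of a GPU kernel; the function names and block size are hypothetical.

```python
import math


def reference_attention(q, K, V):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    return [sum(wi * v[d] for wi, v in zip(w, V)) / sum(w) for d in range(len(V[0]))]


def tiled_attention_deferred(q, K, V, block=4):
    """Online softmax where the accumulator stays unnormalized until the end."""
    m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
    for start in range(0, len(K), block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K_blk]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)            # rescale old statistics to the new max
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        acc = [alpha * a + sum(pj * v[d] for pj, v in zip(p, V_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]                # single division at the very end


q = [0.5, -1.0, 2.0]
K = [[0.1 * i, -0.2 * i, 0.3] for i in range(9)]
V = [[float(i), 1.0 - i] for i in range(9)]
print(max(abs(a - b) for a, b in
          zip(reference_attention(q, K, V), tiled_attention_deferred(q, K, V))))
```

The result matches the reference; the saving is that the per-block loop no longer divides every output element by `l`.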


## Paged Attention

Paged Attention is a memory management technique, inspired by operating systems, that improves GPU memory efficiency for LLM inference.

**Challenges Addressed**
- The KV cache of each inference request grows token by token, and its final size is not known in advance.
- Traditional allocation leads to GPU memory fragmentation, which wastes space and limits batch size scalability.

**Core Mechanism**
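- Paged Attention partitions the KV cache into fixed-size blocks ("pages"), similar to virtual memory in operating systems. A sequence's blocks need not be contiguous in GPU memory; a per-sequence block table maps its logical blocks to physical ones.
- This page-based allocation eliminates external fragmentation and limits internal fragmentation to the last, partially filled block of each sequence.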
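A toy sketch of the block-table bookkeeping, assuming a fixed block size of 16 tokens and tracking only slot positions rather than actual K/V tensors. The class and method names are hypothetical and do not mirror vLLM's internals.

```python
class PagedKVCache:
    """Toy allocator: each sequence maps logical blocks to physical blocks."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # pool of physical block ids
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one more token; allocate a block only when needed."""
        length = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if length % self.block_size == 0:  # no block yet, or last block is full
            if not self.free_blocks:
                raise MemoryError("out of KV cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1
        block = table[length // self.block_size]
        return block, length % self.block_size  # physical block, offset inside it

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=64, block_size=16)
for step in range(20):  # two requests decoding side by side
    cache.append_token("a")
    if step < 5:
        cache.append_token("b")
used = sum(len(t) for t in cache.block_tables.values())
wasted = used * cache.block_size - sum(cache.lengths.values())
print(cache.block_tables, wasted)  # {'a': [63, 61], 'b': [62]} 23
```

Sequence `a` ends up in non-adjacent physical blocks, and the only unused slots are at the tail of each sequence's last block, instead of a large contiguous region reserved up front for the maximum length.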

**Availability**
- It is the core technology implemented in the [**vLLM project**](https://github.com/vllm-project/vllm).
- Paged Attention is also available in **Hugging Face Text Generation Inference (TGI)**.