# Chapter 2: Solving the Inference Bottleneck with the Key-Value Cache

## Introduction

The inference bottleneck is one of the most critical challenges in deploying large language models like DeepSeek. This chapter explores how the key-value cache, a fundamental optimization technique, addresses this bottleneck and serves as the foundation for more advanced attention mechanisms.

In autoregressive generation, each new token requires attention computations across all previous tokens in the sequence. Without optimizations, this would lead to:

1. Quadratically increasing computation time as sequence length grows
2. Redundant recomputation of key and value tensors for tokens that have already been processed
3. Prohibitive memory and computational costs for practical applications

The key-value cache solves these issues by storing previously computed key-value pairs, dramatically reducing the computational burden during token generation. This foundational technique enables DeepSeek's impressive performance with long contexts of up to 128K tokens.

## 2.1 Understanding Attention Mechanisms

Before diving into the key-value cache, we need to understand the attention mechanism itself. Attention allows a model to focus on relevant parts of the input when producing each element of the output.

### The Attention Formula

The attention mechanism can be expressed mathematically as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q$ (queries): What we're looking for
- $K$ (keys): What we match against
- $V$ (values): What we extract when we find a match
- $d_k$: The dimension of the keys (used for scaling)

In transformer models like DeepSeek, each attention layer processes:
- **Queries**: Representations of the token we're currently generating
- **Keys/Values**: Representations of all tokens in the context
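
To make the formula concrete, here is a minimal, dependency-free sketch of scaled dot-product attention for a single query vector. It uses plain Python lists rather than tensors so each arithmetic step is visible. The function names `softmax` and `attention` and the toy vectors are illustrative assumptions, not part of any library.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, keys, values):
    """Scaled dot-product attention for one query against all keys/values."""
    d_k = len(query)
    # One score per key: (q . k) / sqrt(d_k)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    weights = softmax(scores)
    # The output is the attention-weighted sum of the value vectors
    d_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)]
    return output, weights

# Toy example: the query points in the same direction as the first key
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

output, weights = attention(query, keys, values)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The first key matches the query best, so it receives the largest weight and the output is pulled toward the first value vector.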

### Multi-Head Attention

DeepSeek uses multi-head attention, which allows the model to attend to information from different representation subspaces simultaneously:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

Where each head is computed as:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

This creates multiple "attention heads" that can focus on different aspects of the input sequence.
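
The split-attend-concatenate data flow can be sketched in a few lines of plain Python. This is a simplified illustration: it omits the learned projections $W_i^Q$, $W_i^K$, $W_i^V$ and $W^O$ and simply slices each vector into per-head chunks, so it shows only how heads operate on separate subspaces. All function names are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(a * b for a, b in zip(query, k)) / scale for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def split_heads(vector, num_heads):
    """Slice a d_model-sized vector into num_heads chunks of size d_head."""
    d_head = len(vector) // num_heads
    return [vector[i * d_head:(i + 1) * d_head] for i in range(num_heads)]

def multi_head(query, keys, values, num_heads):
    q_heads = split_heads(query, num_heads)
    k_heads = [split_heads(k, num_heads) for k in keys]    # indexed [token][head]
    v_heads = [split_heads(v, num_heads) for v in values]
    out = []
    for h in range(num_heads):
        head_out = attend(q_heads[h], [k[h] for k in k_heads], [v[h] for v in v_heads])
        out.extend(head_out)  # Concat(head_1, ..., head_h)
    return out

# d_model = 4, two heads of size 2, two context tokens
query = [1.0, 0.0, 0.0, 1.0]
keys = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
values = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]

out = multi_head(query, keys, values, num_heads=2)
print([round(x, 3) for x in out])
```

In this toy input, head 1 weights the first token more heavily while head 2 favors the second token, which is the sense in which heads can focus on different parts of the sequence.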

## 2.2 Autoregressive Generation: The Root of the Inference Bottleneck

DeepSeek models, like other transformer-based LLMs, generate text autoregressively—one token at a time, where each new token depends on all previous tokens. This creates a computational challenge during inference:

1. For the first token, we compute attention using just the prompt
2. For the second token, we compute attention using the prompt plus the first generated token
3. For the third token, we compute attention using all previous tokens
4. And so on...

As the sequence grows, each new token requires more computation than the last. Without optimization, this would create:

- **O(n²) complexity** in sequence length for each new token, because the full n × n attention matrix is recomputed at every step
- **Redundant calculations** as the same keys and values are recomputed for existing tokens
- **Slow inference speed** for practical applications

The code below demonstrates this autoregressive generation process using a simple GPT-2 model. Pay attention to how we generate one token at a time, repeatedly passing the entire sequence through the model:

In [1]:
#
# SETUP: Imports and Model Loading for the entire notebook
#
import torch
import torch.nn as nn
import time
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("Setting up models...")
# Load pre-trained model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
print("Setup complete.")

#
# LISTING 2.1: Visualizing autoregressive generation with GPT-2
#
prompt = "The next day is bright"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids

print(f"Prompt: '{prompt}'", end="")

# Generate 20 tokens
for _ in range(20):
    # Pass the entire sequence to the model (no gradients are needed for inference)
    with torch.no_grad():
        outputs = model(input_ids)
    logits = outputs.logits

    # Get the logits for the very last token
    next_token_logits = logits[:, -1, :]

    # Get the ID of the most likely next token (greedy decoding)
    next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    # Append the new token ID to the input sequence
    input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    # Decode and print the new token
    new_token = tokenizer.decode(next_token_id[0])
    print(new_token, end="", flush=True)

print("\n")

Setting up models...


Setup complete.
Prompt: 'The next day is bright' and sunny, and the sun is shining. The sun is shining, and the moon is shining.



### Listing 2.1: Visualizing Autoregressive Generation

Listing 2.1 above demonstrates the token-by-token generation process in an autoregressive model. We start with a prompt and then generate 20 additional tokens sequentially. Notice how:

1. The entire sequence is passed through the model at each step
2. We extract only the logits for the last token position
3. We select the most likely next token (greedy decoding)
4. We append this token to our sequence and repeat

This naive implementation clearly shows why autoregressive generation becomes increasingly slow as the sequence grows longer. Each token requires a full forward pass through all previous tokens.
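
A rough way to quantify the redundancy is to count how many token positions are fed through the model over a whole generation. The sketch below compares the naive loop with a scheme that processes each position only once, which is what the key-value cache makes possible. The helper name `positions_processed` is hypothetical, and the prompt length of 5 tokens is an assumption about how the example prompt is tokenized.

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    """Count token positions passed through the model during generation."""
    if use_cache:
        # The prompt is processed once, then one new position per later step
        return prompt_len + (new_tokens - 1)
    # Step i re-processes the prompt plus the i tokens generated so far
    return sum(prompt_len + i for i in range(new_tokens))

prompt_len = 5    # assumed token count for the example prompt
new_tokens = 20   # matches the loop in Listing 2.1

naive = positions_processed(prompt_len, new_tokens, use_cache=False)
cached = positions_processed(prompt_len, new_tokens, use_cache=True)
print(f"positions processed without reuse: {naive}")
print(f"positions processed with reuse:    {cached}")
```

Under these assumptions the naive loop processes 290 positions against 24 with reuse. This count covers the per-position work (projections and feed-forward layers); each new token still attends over all earlier keys and values, so attention cost per step continues to grow linearly with sequence length.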

In [2]:
#
# LISTING 2.2: Demonstrating the Speedup of KV Caching
#
prompt = "The next day is bright"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# --- Timing without KV cache ---
print("Generating without KV Cache...")
start_time_without_cache = time.time()
output_without_cache = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=False, # Explicitly disable the cache
    attention_mask=attention_mask
)
end_time_without_cache = time.time()
duration_without_cache = end_time_without_cache - start_time_without_cache
print(f"Time without KV Cache: {duration_without_cache:.4f} seconds\n")


# --- Timing with KV cache ---
print("Generating with KV Cache...")
start_time_with_cache = time.time()
output_with_cache = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True, # Explicitly enable the cache
    attention_mask=attention_mask
)
end_time_with_cache = time.time()
duration_with_cache = end_time_with_cache - start_time_with_cache
print(f"Time with KV Cache: {duration_with_cache:.4f} seconds\n")


# --- Calculate and print the speedup ---
speedup = duration_without_cache / duration_with_cache
print(f"KV Cache Speedup: {speedup:.2f}x")

Generating without KV Cache...


Time without KV Cache: 57.6215 seconds

Generating with KV Cache...
Time with KV Cache: 6.3934 seconds

KV Cache Speedup: 9.01x


## 2.3 Key-Value Caching: The First Generation Solution

The key-value (KV) cache is the foundational technique for addressing the inference bottleneck in transformer models like DeepSeek. The key insight is simple but powerful:

**Since previous tokens don't change during generation, their key and value projections can be computed once and reused.**

### How KV Caching Works:

1. **Initial Computation**: For the prompt (the first forward pass), compute Q, K, V for every token as usual
2. **Storage**: Save the K, V tensors in a "cache" 
3. **Subsequent Tokens**: 
   - Only compute Q, K, V for the new token
   - Retrieve previous K, V from the cache
   - Concatenate the new K, V with the cached ones
   - Store the expanded K, V in the cache
4. **Result**: Each new token only needs to compute its own key-value pair rather than recomputing for all tokens

This reduces the computational complexity from O(n²) to O(n) per token, where n is the sequence length.
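
The steps above can be sketched for a single attention head in plain Python. The example checks that appending each new token's key and value to a cache gives the same output as recomputing every key and value from scratch. It is a toy, single-layer illustration with random weights; all names (`project`, `attend`, `k_cache`, `v_cache`) are hypothetical, and real implementations store the cache as tensors per layer and per head.

```python
import math
import random

random.seed(0)
D = 4  # toy model dimension (single head)

def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

def project(x, W):
    """Multiply row vector x by matrix W (len(x) rows, any number of columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def attend(q, keys, values):
    """Scaled dot-product attention for one query."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]

W_q, W_k, W_v = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)
tokens = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(6)]

def last_output_no_cache(seq):
    # Recompute K and V for every token in the sequence
    keys = [project(x, W_k) for x in seq]
    values = [project(x, W_v) for x in seq]
    return attend(project(seq[-1], W_q), keys, values)

k_cache, v_cache = [], []

def last_output_with_cache(x_new):
    # Project only the new token, then append its K and V to the cache
    k_cache.append(project(x_new, W_k))
    v_cache.append(project(x_new, W_v))
    return attend(project(x_new, W_q), k_cache, v_cache)

for t in range(len(tokens)):
    full = last_output_no_cache(tokens[: t + 1])
    cached = last_output_with_cache(tokens[t])
    assert all(abs(a - b) < 1e-12 for a, b in zip(full, cached))

print("cached and recomputed outputs match; cache length:", len(k_cache))
```

The uncached path performs a number of K/V projections that grows with the sequence at every step, while the cached path performs exactly one K and one V projection per new token.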

### Memory-Computation Tradeoff:

The KV cache trades increased memory usage for significantly reduced computation:

- **Memory Usage**: Increases linearly with sequence length
- **Computation**: Dramatically reduced for long sequences

This is particularly important for DeepSeek models with their 128K context window capability.
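
The linear growth in memory is easy to estimate. The sketch below computes the cache size for the GPT-2 model used in this chapter's listings (12 layers, 12 heads, head dimension 64), assuming float32 storage and a batch size of 1. The helper name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(num_layers, num_kv_heads, d_head, seq_len, bytes_per_value, batch_size=1):
    """Bytes needed to hold the keys and values of every layer for seq_len tokens."""
    # Factor of 2: one K tensor and one V tensor per layer
    return 2 * num_layers * num_kv_heads * d_head * seq_len * bytes_per_value * batch_size

# GPT-2 (small): 12 layers, 12 heads, head dimension 64; float32 = 4 bytes per value
for seq_len in (128, 512, 1024):
    size = kv_cache_bytes(12, 12, 64, seq_len, bytes_per_value=4)
    print(f"{seq_len:>5} tokens: {size / 2**20:.1f} MiB")
```

Under these assumptions the cache reaches 72 MiB at 1,024 tokens. Doubling the sequence length doubles the cache, which is why cache size becomes a central concern at very long context lengths.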

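### Listing 2.2: Benchmarking KV Cache Performance

Listing 2.2 above measures the speedup from KV caching by generating 100 tokens twice with `model.generate`, once with `use_cache=False` and once with `use_cache=True`. In the run shown, generation took about 57.6 seconds without the cache and 6.4 seconds with it, a speedup of roughly 9x. Exact timings depend on the hardware and on the sequence length.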

## 2.4 From KV Cache to Advanced Attention Mechanisms

While the key-value cache dramatically improves inference speed, models like DeepSeek require further optimizations to handle their massive scale efficiently. The KV cache serves as the foundation for more advanced attention mechanisms:

### The Evolution Path:

1. **KV Cache (First Generation)**: The basic optimization we've just explored
2. **Multi-Query Attention (MQA)**: Shares K/V projections across attention heads
3. **Grouped-Query Attention (GQA)**: A middle ground between standard attention and MQA
4. **Multi-Head Latent Attention**: Advanced techniques used in the latest DeepSeek models

Let's explore these evolved attention mechanisms that build on the KV cache foundation.

### Multi-Query Attention (MQA)

Standard multi-head attention requires unique K, V projections for each attention head, which increases the KV cache size proportionally to the number of heads. Multi-Query Attention addresses this by:

- Using a **single shared K, V projection** across all attention heads
- Maintaining **separate Q projections** for each head

This significantly reduces memory requirements during inference while maintaining most of the model's capacity.

### Grouped-Query Attention (GQA)

While MQA significantly reduces memory usage, it can sometimes degrade model quality. Grouped-Query Attention (GQA), used for example in DeepSeek LLM 67B, is a balanced approach:

- Group attention heads into clusters (e.g., 4-8 groups for 32 heads)
- Each group shares the same K,V projections
- Queries remain separate for each head

GQA offers a favorable trade-off:
- Better quality than MQA (more expressive)
- More efficient than standard multi-head attention
- Well-suited for DeepSeek's massive scale

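Balancing efficiency and quality in this way becomes more important as models grow. Note that DeepSeek's 671B parameter models go a step further and use Multi-Head Latent Attention, the last stage in the evolution path above, rather than GQA.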

### Comparison of Attention Variants

| Feature | Standard Multi-Head | Grouped-Query | Multi-Query |
|---------|---------------------|---------------|-------------|
| K,V Projections | One per head | One per group | One shared |
| Cache Size | Largest | Medium | Smallest |
| Quality | Highest | Good | Lower |
| Memory Efficiency | Low | Medium | High |

DeepSeek models use these optimizations strategically depending on the model size and intended use case.
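
The "Cache Size" row of the table can be made concrete by counting the key and value numbers stored per token. The configuration below (32 layers, 32 query heads, head dimension 128, 8 groups for GQA) is a hypothetical example chosen for round numbers, not the configuration of any particular model, and `kv_values_per_token` is an illustrative helper.

```python
def kv_values_per_token(num_layers, num_kv_heads, d_head):
    """Number of cached K and V values stored for each token."""
    return 2 * num_layers * num_kv_heads * d_head

# Hypothetical configuration for illustration only
num_layers, num_heads, d_head = 32, 32, 128

variants = {
    "Standard MHA": num_heads,   # one K/V head per query head
    "GQA (8 groups)": 8,         # one K/V head per group
    "MQA": 1,                    # a single shared K/V head
}

baseline = kv_values_per_token(num_layers, num_heads, d_head)
for name, kv_heads in variants.items():
    n = kv_values_per_token(num_layers, kv_heads, d_head)
    print(f"{name:<15} {n:>7} values/token  ({n / baseline:.3f} of MHA)")
```

The cache shrinks in direct proportion to the number of K/V heads: with these numbers GQA stores a quarter of what standard multi-head attention stores, and MQA stores 1/32.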

In [3]:
#
# LISTING 2.3: Implementing an MQA layer from scratch
#
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.d_head) # Single projection for K
        self.W_v = nn.Linear(d_model, self.d_head) # Single projection for V
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Using a fixed size mask for demonstration. A dynamic one is better in practice.
        self.register_buffer('mask', torch.triu(torch.ones(1, 1, 1024, 1024), diagonal=1))

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Query: (B, num_heads, seq_len, d_head)
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # Key & Value: (B, 1, seq_len, d_head)
        k = self.W_k(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)

        # Repeat K and V for each query head
        k = k.repeat(1, self.num_heads, 1, 1) # (B, num_heads, seq_len, d_head)
        v = v.repeat(1, self.num_heads, 1, 1) # (B, num_heads, seq_len, d_head)

        attn_scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Apply causal mask
        attn_scores = attn_scores.masked_fill(self.mask[:,:,:seq_len,:seq_len] == 1, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vector = (attn_weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        output = self.W_o(context_vector)
        return output

# --- Usage Example ---
d_model = 512
num_heads = 8
batch_size = 4
seq_len = 64

mqa_layer = MultiQueryAttention(d_model, num_heads)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = mqa_layer(dummy_input)

print("MQA Layer successful!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

MQA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])


### Listing 2.3: Implementing Multi-Query Attention (MQA)

The code above (Listing 2.3) implements a Multi-Query Attention layer from scratch. Notice the key differences from standard multi-head attention:

1. Query projection (`self.W_q`) maps to the full model dimension (`d_model`)
2. Key and value projections (`self.W_k` and `self.W_v`) map to just a single head dimension (`self.d_head`)
3. We use `repeat()` to duplicate the single key and value for all query heads

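This implementation shows how MQA reduces the parameter count and the memory footprint during inference. The saving is largest for the KV cache, which needs to store only a single key-value pair per token in each layer instead of one per attention head. Note that `repeat()` copies the shared K and V for every head, so the listing illustrates the projection structure rather than the memory saving itself.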
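
The parameter reduction can be checked with simple arithmetic. The sketch below counts the weights and biases in the K and V projections for the dimensions used in the usage example of Listing 2.3 (`d_model = 512`, `num_heads = 8`). The helper `linear_params` is hypothetical and mirrors the parameter count of a linear layer with bias.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a linear layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)

d_model, num_heads = 512, 8
d_head = d_model // num_heads  # 64

# Standard multi-head attention: K and V each project d_model -> d_model
mha_kv_params = 2 * linear_params(d_model, d_model)
# Multi-query attention: K and V each project d_model -> d_head
mqa_kv_params = 2 * linear_params(d_model, d_head)

print(f"MHA K/V projection parameters: {mha_kv_params}")
print(f"MQA K/V projection parameters: {mqa_kv_params}")
print(f"Reduction factor: {mha_kv_params / mqa_kv_params:.0f}x")
```

For this configuration the K/V projection parameters drop from 525,312 to 65,664, a factor equal to the number of heads. The query and output projections are unchanged.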

In [4]:
#
# LISTING 2.4: Implementing a GQA layer from scratch
#
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups, dropout=0.0, max_seq_len: int = 1024):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.num_groups * self.d_head) # Grouped projection for K
        self.W_v = nn.Linear(d_model, self.num_groups * self.d_head) # Grouped projection for V
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self._register_mask_buffer(max_seq_len)

    def _register_mask_buffer(self, max_seq_len):
        if max_seq_len > 0:
            mask = torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
            self.register_buffer("causal_mask", mask, persistent=False)
        else:
            self.causal_mask = None

    def _get_causal_mask(self, seq_len, device):
        if self.causal_mask is not None and self.causal_mask.size(-1) >= seq_len:
            return self.causal_mask[:, :, :seq_len, :seq_len]
        # Dynamically create mask if needed
        return torch.triu(torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, x):
        B, T, _ = x.shape

        # Query: (B, num_heads, T, d_head)
        q = self.W_q(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # Key & Value: (B, num_groups, T, d_head)
        k = self.W_k(x).view(B, T, self.num_groups, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.num_groups, self.d_head).transpose(1, 2)

        heads_per_group = self.num_heads // self.num_groups

        # Repeat K and V to match query heads
        k = k.repeat_interleave(heads_per_group, dim=1) # (B, num_heads, T, d_head)
        v = v.repeat_interleave(heads_per_group, dim=1) # (B, num_heads, T, d_head)

        attn_scores = (q @ k.transpose(-2, -1)) * (self.d_head**-0.5)

        causal_mask = self._get_causal_mask(T, x.device)
        attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = (attn_weights @ v).transpose(1, 2).contiguous().view(B, T, self.d_model)

        return self.W_o(context)

# --- Usage Example ---
d_model = 512
num_heads = 32
num_groups = 4 # Illustrative: 4 K/V groups shared by 32 query heads (8 heads per group)
batch_size = 4
seq_len = 64

gqa_layer = GroupedQueryAttention(d_model, num_heads, num_groups)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = gqa_layer(dummy_input)

print("GQA Layer successful!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

GQA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])


### Listing 2.4: Implementing Grouped-Query Attention (GQA)

The code above (Listing 2.4) implements a Grouped-Query Attention layer from scratch. Note the key differences:

1. Query projection (`self.W_q`) still maps to the full model dimension
2. Key and value projections (`self.W_k` and `self.W_v`) map to `self.num_groups * self.d_head`
3. We use `repeat_interleave()` to match each key and value group with its corresponding query heads

This implementation demonstrates how GQA balances between standard multi-head attention and MQA:
- Fewer K,V projections than standard attention (reduced by factor of `num_heads/num_groups`)
- More K,V diversity than MQA (one set per group rather than just one shared)

Models that adopt GQA, such as DeepSeek LLM 67B, use this approach to maintain quality while reducing memory requirements, especially for the KV cache.
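
It can help to see which K/V group each query head ends up using. The sketch below reproduces, with plain integer indices, the head-to-group assignment that `repeat_interleave()` produces in Listing 2.4, and contrasts it with the tiled assignment that a `repeat()`-style expansion would give. Both function names are hypothetical.

```python
def interleaved_mapping(num_heads, num_groups):
    """Group index used by each query head under repeat_interleave-style expansion."""
    heads_per_group = num_heads // num_groups
    return [h // heads_per_group for h in range(num_heads)]

def tiled_mapping(num_heads, num_groups):
    """Group index per query head if the groups were tiled instead (repeat-style)."""
    return [h % num_groups for h in range(num_heads)]

print(interleaved_mapping(8, 4))  # [0, 0, 1, 1, 2, 2, 3, 3]
print(tiled_mapping(8, 4))        # [0, 1, 2, 3, 0, 1, 2, 3]
```

With interleaving, consecutive query heads share a group. Either assignment gives each group the same number of heads; what matters is that one convention is applied consistently, particularly when loading pretrained weights.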

## 2.5 Conclusion: The Key-Value Cache as Foundation

The key-value cache represents the first major breakthrough in addressing the inference bottleneck for transformer models, and it serves as the foundation for all subsequent attention optimizations in DeepSeek models:

1. **Fundamental Optimization**: By storing and reusing key-value pairs, the KV cache dramatically reduces the computational cost of autoregressive generation.

2. **Memory-Computation Tradeoff**: The KV cache exemplifies an essential engineering principle—trading increased memory usage for reduced computation, which is often beneficial for practical applications.

3. **Foundation for Advanced Techniques**: The MQA and GQA techniques build directly on the KV cache foundation, further optimizing memory usage while maintaining model quality.

4. **Enabling Long Context**: DeepSeek's impressive 128K token context window capability would be impossible without these attention optimizations.

Understanding the key-value cache and its evolved forms is essential for grasping how DeepSeek models achieve their remarkable balance of quality and efficiency at scale. As we'll see in the next chapter, these attention optimizations work in concert with DeepSeek's Mixture of Experts (MoE) architecture to enable its massive 671B parameter scale while keeping activated parameters at just 37B.