## Solving the Inference Bottleneck with the Key-Value Cache


## Introduction
* The inference bottleneck is one of the most critical challenges in deploying LLMs such as DeepSeek and ChatGPT. This notebook explores how the key-value (KV) cache, a fundamental technique, addresses this bottleneck and serves as the foundation for more advanced attention mechanisms.

* In autoregressive generation, each new token requires attention computations over all previous tokens in the sequence. Without optimizations, this would lead to:
  1. Quadratically increasing computation time as the sequence length grows.
  2. Redundant recomputation of key and value tensors for tokens that have already been processed.
  3. Prohibitive memory and computational costs for practical applications.

* The key-value cache addresses these issues by storing the key and value tensors of previously processed tokens so that each is computed only once, dramatically reducing the computational burden during token generation. This foundational technique underpins long-context inference, such as DeepSeek's support for context windows of up to 128K tokens.
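The core idea can be illustrated with a small NumPy sketch. It is a minimal single-head, single-layer illustration under simplifying assumptions (no batching, random matrices standing in for trained projection weights, random vectors standing in for token embeddings); `KVCache` and `causal_attention` are hypothetical helper names, not a library API. At each step only the newest token's key and value are projected and appended, and the final check confirms that this gives the same attention outputs as recomputing full causal attention over the whole sequence.

```python
import numpy as np


class KVCache:
    """Minimal single-head KV cache (illustrative sketch, not a library API).

    Stores one key row and one value row per processed token so later
    decoding steps can reuse them instead of recomputing them.
    """

    def __init__(self, d_head: int):
        self.keys = np.empty((0, d_head))    # shape: (tokens_so_far, d_head)
        self.values = np.empty((0, d_head))  # shape: (tokens_so_far, d_head)

    def append(self, k: np.ndarray, v: np.ndarray) -> None:
        # k, v: (1, d_head) for the newest token only
        self.keys = np.concatenate([self.keys, k], axis=0)
        self.values = np.concatenate([self.values, v], axis=0)

    def attend(self, q: np.ndarray) -> np.ndarray:
        # q: (1, d_head). The newest token attends to every cached token,
        # which is exactly the causal-attention row for the last position.
        scores = q @ self.keys.T / np.sqrt(q.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ self.values  # (1, d_head)


def causal_attention(q, k, v):
    """Full (uncached) causal attention over the whole sequence."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)  # True = future token
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


# Assumption: random weights and embeddings stand in for a trained model
rng = np.random.default_rng(0)
d_head, seq_len = 8, 5
W_q, W_k, W_v = (rng.standard_normal((d_head, d_head)) for _ in range(3))
x = rng.standard_normal((seq_len, d_head))

cache = KVCache(d_head)
cached_outputs = []
for t in range(seq_len):
    x_t = x[t:t + 1]                    # only the newest token
    cache.append(x_t @ W_k, x_t @ W_v)  # K/V projected once, then reused
    cached_outputs.append(cache.attend(x_t @ W_q))
cached_outputs = np.concatenate(cached_outputs, axis=0)

full_outputs = causal_attention(x @ W_q, x @ W_k, x @ W_v)
print(np.allclose(cached_outputs, full_outputs))  # True
```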


## Multi-Head Attention
* DeepSeek's attention builds on multi-head attention, which allows the model to attend to information from different representation subspaces simultaneously:

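$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$
* Each head applies scaled dot-product attention to its own learned projections, where $d_k$ is the per-head key dimension:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V), \qquad \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$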
* This creates multiple "attention heads" that can focus on different aspects of the input sequence.
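As a rough illustration of these formulas, here is a NumPy sketch of multi-head self-attention. It assumes the per-head projection matrices are packed as equal column blocks of single `W_Q`, `W_K`, `W_V` matrices (a common layout, but an assumption here), uses random weights, and omits the causal mask and biases that a decoder such as GPT-2 would add; `multi_head_attention` is a hypothetical helper name.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    """Self-attention: Concat(head_1, ..., head_h) W_O.

    x: (seq_len, d_model); all weight matrices: (d_model, d_model).
    Assumption: head i uses the i-th block of d_model // num_heads columns
    of W_Q, W_K and W_V as its projection matrices.
    No causal mask is applied (decoder models add one).
    """
    seq_len, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return t.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)

    # Scaled dot-product attention, computed independently for each head
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)  # (h, n, n)
    heads = softmax(scores) @ v                           # (h, n, d_head)

    # Concat(head_1, ..., head_h) back to (seq_len, d_model), then project
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O


rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 16, 4, 6
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
x = rng.standard_normal((seq_len, d_model))

out = multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads)
print(out.shape)  # (6, 16)
```

With `num_heads=1` the function reduces to plain scaled dot-product attention followed by the output projection, which makes a convenient sanity check.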

## Autoregressive Generation: The Root of the Inference Bottleneck
* DeepSeek models, like other transformer-based LLMs, generate text autoregressively: one token at a time, where each new token depends on all previous tokens.
* This creates a computational challenge during inference:
  1. For the first token, we compute attention using just the prompt.
  2. For the second token, we compute attention using the prompt plus the first generated token.
  3. For the third token, we compute attention using all previous tokens.
  4. And so on.

* As the sequence grows, each new token requires more computation than the last. Without optimization, this would create:
   1. O(n^2) complexity in the sequence length for each new token.
   2. Redundant calculations as the same keys and values are recomputed for existing tokens.
   3. Slow inference speed for practical applications.
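A quick count makes the redundancy concrete. The sketch below tallies, per layer and per head, how many key/value rows are projected and how many query-key scores are computed during greedy generation, with and without a cache. The function names and the sizes (a 5-token prompt, 20 new tokens) are hypothetical, and constant factors such as the head dimension are ignored.

```python
def kv_rows_computed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Key/value rows projected (per layer, per head) while greedily
    generating `new_tokens` tokens after a prompt of `prompt_len` tokens."""
    if use_cache:
        # Prefill projects the prompt once; each later step adds one new row
        return prompt_len + (new_tokens - 1)
    # Without a cache, every step re-projects the whole sequence so far
    return sum(prompt_len + i for i in range(new_tokens))


def qk_scores_computed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Query-key dot products (per layer, per head) under causal attention."""
    total = 0
    for step in range(new_tokens):
        m = prompt_len + step  # tokens in the sequence at this step
        if use_cache and step > 0:
            total += m  # one new query against m keys (m - 1 cached + 1 new)
        else:
            total += m * (m + 1) // 2  # full causal triangle over m tokens
    return total


# Hypothetical sizes: a 5-token prompt and 20 generated tokens
print(kv_rows_computed(5, 20, use_cache=False), kv_rows_computed(5, 20, use_cache=True))
# 290 24
print(qk_scores_computed(5, 20, use_cache=False), qk_scores_computed(5, 20, use_cache=True))
# 2580 300
```

Without the cache, the key/value count grows quadratically and the score count cubically in the number of generated tokens; with the cache they grow linearly and quadratically. These counts cover only the attention inputs; an uncached forward pass also reruns every other per-token layer on all earlier positions.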

* The following code demonstrates the autoregressive generation process using the small pretrained GPT-2 model.

```python
# Import required libraries
import time  # used by the timing comparison in the next cell

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("Models are being set up....")
# Load the pretrained GPT-2 model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
print("Setup completed successfully")

# Visualize autoregressive GPT-2 generation
prompt = "The Mona Lisa was painted"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids

print(f"Prompt: '{prompt}'", end="")

# Generate 20 tokens, one at a time
with torch.no_grad():  # inference only, so skip gradient tracking
    for _ in range(20):
        # Pass the entire sequence so far to the model; use_cache=False makes
        # it explicit that no keys/values are reused between steps
        outputs = model(input_ids, use_cache=False)
        logits = outputs.logits

        # Keep only the logits for the last position
        next_token_logits = logits[:, -1, :]

        # Pick the most likely next token (greedy decoding); shape (1, 1)
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        # Append the new token ID to the input sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Decode and print the new token
        new_token = tokenizer.decode(next_token_id[0])
        print(new_token, end="", flush=True)
print("\n")
```

```
Models are being set up....
Setup completed successfully
Prompt: 'The Mona Lisa was painted' in the same way as the original, but with a different color scheme.

The Mona
```

```python
prompt = "The Mona Lisa was painted by"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# --- Timing without KV cache ---
print("Generating without KV Cache...")
start_time_without_cache = time.perf_counter()  # high-resolution interval timer
output_without_cache = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=False,  # explicitly disable the cache
    attention_mask=attention_mask,
)
duration_without_cache = time.perf_counter() - start_time_without_cache
print(f"Time without KV Cache: {duration_without_cache:.4f} seconds\n")

# --- Timing with KV cache ---
print("Generating with KV Cache...")
start_time_with_cache = time.perf_counter()
output_with_cache = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True,  # explicitly enable the cache
    attention_mask=attention_mask,
)
duration_with_cache = time.perf_counter() - start_time_with_cache
print(f"Time with KV Cache: {duration_with_cache:.4f} seconds\n")

# --- Calculate and print the speedup ---
speedup = duration_without_cache / duration_with_cache
print(f"KV Cache Speedup: {speedup:.2f}x")
```


```
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generating without KV Cache...
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Time without KV Cache: 21.4571 seconds

Generating with KV Cache...
Time with KV Cache: 4.7570 seconds

KV Cache Speedup: 4.51x
```


The first code cell demonstrates the autoregressive generation process. It loads a pretrained GPT-2 model and tokenizer, then generates 20 tokens one at a time from the prompt using greedy decoding. At every step it feeds the entire sequence generated so far back into the model to predict the next token, so each step processes one more token than the previous one. This highlights how the computation grows with the sequence length.

The second code cell compares the inference time of generating 100 new tokens with and without the key-value (KV) cache. It calls `model.generate` with `use_cache=False` to disable the cache and `use_cache=True` to enable it, measures the elapsed time of each run, and prints the resulting speedup. In the run recorded above, enabling the cache cut generation time from about 21.5 s to about 4.8 s, a 4.51x speedup; the exact numbers depend on the hardware and vary between runs.