# KV **Cache**

In [12]:
import time
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

# ─── CONFIG ─────────────────────────────────────────────────────────────────────
MODEL_NAME = "gpt2"
PROMPT     = "The quick brown fox jumps over the lazy dog"
NUM_TOKENS = 500

TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.9
REPETITION_PENALTY = 1.2

device    = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
model     = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

tokenizer.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(device)

def sample_token(logits, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):

    if repetition_penalty != 1.0:
        logits = logits / repetition_penalty

    # Apply temperature
    if temperature != 1.0:
        logits = logits / temperature

    probs = torch.softmax(logits, dim=-1)

    if top_k is not None:
        top_probs, top_indices = torch.topk(probs, k=top_k)
        probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)

    if top_p is not None:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        probs[:, indices_to_remove] = 0

    probs = probs / probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

with torch.no_grad():
    _ = model(input_ids)

def time_generation_with_cache():
    """Efficient generation using KV cache with sampling"""
    generated = input_ids.clone()
    past_key_values = None

    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()

    with torch.no_grad():
        for _ in range(NUM_TOKENS):
            input_to_model = generated if past_key_values is None else generated[:, -1:]

            outputs = model(
                input_to_model,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            next_token = sample_token(outputs.logits[:, -1, :])
            generated = torch.cat([generated, next_token], dim=1)

    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()

    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return text, t1 - t0

def time_generation_without_cache():
    """Inefficient generation - reprocesses all tokens each time"""
    generated = input_ids.clone()

    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()

    with torch.no_grad():
        for _ in range(NUM_TOKENS):
            outputs = model(
                generated,
                use_cache=False,
            )

            next_token = sample_token(outputs.logits[:, -1, :])
            generated = torch.cat([generated, next_token], dim=1)

    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()

    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return text, t1 - t0

print(f"Device: {device}")
print(f"Generating {NUM_TOKENS} tokens...")
print(f"Sampling params: temp={TEMPERATURE}, top_k={TOP_K}, top_p={TOP_P}, rep_penalty={REPETITION_PENALTY}\n")

text_c, t_c = time_generation_with_cache()
text_nc, t_nc = time_generation_without_cache()

print("=== With KV Cache ===")
print(text_c[:200] + "..." if len(text_c) > 200 else text_c)
print(f"Time taken: {t_c:.3f} s")
print(f"Tokens/sec: {NUM_TOKENS/t_c:.1f}\n")

print("=== Without KV Cache ===")
print(text_nc[:200] + "..." if len(text_nc) > 200 else text_nc)
print(f"Time taken: {t_nc:.3f} s")
print(f"Tokens/sec: {NUM_TOKENS/t_nc:.1f}\n")

if t_c > 0:
    print(f"Speed-up: {t_nc / t_c:.2f}× faster with KV cache")
    print(f"Time saved: {t_nc - t_c:.3f} seconds")
else:
    print("Speed-up: ∞ (cache run was instantaneous?)")

print(f"\nWithout cache: O(n²) growth - processed {sum(len(input_ids[0]) + i for i in range(NUM_TOKENS))} total tokens")
print(f"With cache: O(n) growth - processed {len(input_ids[0]) + NUM_TOKENS} total tokens")

Device: cuda
Generating 500 tokens...
Sampling params: temp=0.7, top_k=50, top_p=0.9, rep_penalty=1.2

=== With KV Cache ===
The quick brown fox jumps over the lazy dog with his other hand, and makes a gesture to let it follow its owner.


The fox then quickly begins chasing after the lazy dog.


The fox makes a slight effo...
Time taken: 4.752 s
Tokens/sec: 105.2

=== Without KV Cache ===
The quick brown fox jumps over the lazy dog to greet him.

The lazy dog does the same thing, a smile fades across his face and he can tell he's laughing.

"Hey," the lazy dog says, and just then the f...
Time taken: 13.410 s
Tokens/sec: 37.3

Speed-up: 2.82× faster with KV cache
Time saved: 8.657 seconds

Without cache: O(n²) growth - processed 129250 total tokens
With cache: O(n) growth - processed 509 total tokens
