# Hands-on: LLM Inference Optimization with Batching and Caching

**The Challenge:** Deploying Large Language Models (LLMs) in production is often bottlenecked by high latency and prohibitive GPU costs. Every second spent waiting for a response costs money and degrades user experience.

**The Solution:** This notebook implements and benchmarks two techniques that cut the cost of LLM serving: inference batching, which raises GPU throughput, and query caching (exact-match and semantic), which skips generation entirely for repeated questions.

Let's dive in and unlock the hidden performance of your LLM infrastructure!

## Setup

This initial step reads your Hugging Face token from Google Colab's Secrets (stored under the name `HF_TOKEN` and accessed through `google.colab.userdata`) and logs you in, so the Qwen model can be downloaded for the optimization benchmarks that follow.

In [None]:
from google.colab import userdata
from huggingface_hub import login
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
hf_token = userdata.get('HF_TOKEN')
login(hf_token, add_to_git_credential=True)
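The cell above only works inside Colab. As a minimal sketch, assuming you may also run this notebook elsewhere, a hypothetical helper `get_hf_token` can fall back to an `HF_TOKEN` environment variable:

```python
import os


def get_hf_token():
    """Hypothetical helper: prefer Colab Secrets, fall back to the HF_TOKEN env var."""
    try:
        from google.colab import userdata  # importable only inside Colab
        return userdata.get("HF_TOKEN")
    except ImportError:
        # Outside Colab: read the token from the environment (None if unset)
        return os.environ.get("HF_TOKEN")
```

With this helper, `login(get_hf_token())` works in both environments, provided the token is set in one of the two places.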

## Core Setup: Initializing the Qwen LLM
We use the `Qwen/Qwen2.5-1.5B-Instruct` model, a compact decoder-only instruction-tuned LLM, loaded in `torch.float16` precision to halve GPU memory use compared with float32. Setting `tokenizer.padding_side = "left"` matters for batched generation later: a decoder-only model appends new tokens on the right, so padding has to go on the left. The `generate_single` function establishes our baseline by processing one prompt at a time, like a server that handles requests strictly in sequence. The later optimizations (batched inference, exact-match caching, and semantic caching) are compared against this sequential baseline to evaluate improvements in latency and throughput.

In [None]:
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotar

In [None]:
# Single prompt generation: the sequential baseline
def generate_single(prompt):
    # Wrap the prompt in Qwen's chat template and append the assistant turn header
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    # Drop the prompt tokens so only the newly generated continuation is decoded
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

## Inference Batching
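Inference batching is fundamental to modern LLM serving. Instead of running prompts one after another, we bundle multiple prompts into a single `generate` call. At batch size 1, autoregressive decoding is typically limited by memory bandwidth rather than compute: every decoding step streams all of the model's weights from GPU memory just to produce one token. A batch reuses each weight load for every sequence in it, so throughput (requests or tokens per second) rises sharply. Per-request latency does not fall, though: each request in a batch finishes only when the whole batch does, which can take somewhat longer than serving that request alone. Below we compare sequential inference (batch size 1) against a single static batch of all four prompts.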
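A back-of-envelope estimate shows why batch size matters when each decoding step must read every weight from GPU memory. It assumes roughly 1.54B parameters for Qwen2.5-1.5B, 2 bytes per float16 weight, and a hypothetical 300 GB/s of memory bandwidth, and it ignores KV-cache traffic and compute, so treat the numbers as rough bounds, not predictions:

```python
# Back-of-envelope model of memory-bound decoding. All constants are assumptions.
PARAMS = 1.54e9        # approximate parameter count of Qwen2.5-1.5B
BYTES_PER_PARAM = 2    # float16 weights
BANDWIDTH = 300e9      # hypothetical GPU memory bandwidth in bytes/s

weight_bytes = PARAMS * BYTES_PER_PARAM
# Lower bound on the time of one decoding step: stream every weight once.
step_time = weight_bytes / BANDWIDTH


def tokens_per_second(batch_size):
    """Upper bound on decode throughput from weight traffic alone.

    Ignores KV-cache reads and arithmetic, which grow with batch size.
    """
    return batch_size / step_time


for bs in (1, 4, 16):
    print(f"batch={bs:>2}: step >= {step_time * 1e3:.1f} ms, "
          f"<= {tokens_per_second(bs):,.0f} tokens/s")
```

The bound grows linearly with batch size because one weight read is shared by every sequence in the batch; in practice, compute and KV-cache traffic eventually cap the gain.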

In [None]:
def generate_batch(prompts, max_new_tokens=512):
    # Step 1: Apply chat template to each prompt
    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True
        )
        for p in prompts
    ]

    # Step 2: Tokenize all prompts as a batch
    model_inputs = tokenizer(
        texts,
        padding=True,
        return_tensors="pt"
    ).to(model.device)

    # Step 3: Generate outputs
    with torch.no_grad():
        outputs = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Step 4: Extract generated tokens per prompt.
    # With left padding, every row's prompt occupies the same padded width,
    # so the new tokens always start at input_ids.shape[1].
    prompt_width = model_inputs["input_ids"].shape[1]
    results = []
    for i in range(len(prompts)):
        generated_tokens = outputs[i][prompt_width:]
        decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        results.append(decoded)

    return results
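Step 4 relies on left padding: every row of `model_inputs["input_ids"]` has the same padded width, and `generate` appends new tokens after that width. The toy sketch below uses made-up token IDs, a hypothetical `left_pad` helper, and a pretend two-token generation to show that slicing at the padded width isolates the new tokens, whereas slicing at each row's unpadded length (`attention_mask.sum()`) leaks the tail of shorter prompts into the output:

```python
def left_pad(sequences, pad_id=0):
    """Hypothetical helper: left-pad token-ID lists; return (ids, attention_mask)."""
    width = max(len(s) for s in sequences)
    ids = [[pad_id] * (width - len(s)) + s for s in sequences]
    mask = [[0] * (width - len(s)) + [1] * len(s) for s in sequences]
    return ids, mask


toy_prompts = [[11, 12, 13, 14], [21, 22, 23]]   # made-up prompt token IDs
ids, mask = left_pad(toy_prompts)
# Pretend the model appended the same two new tokens to every row.
toy_outputs = [row + [901, 902] for row in ids]

width = len(ids[0])  # padded prompt width, shared by all rows
correct = [row[width:] for row in toy_outputs]
wrong = [row[sum(m):] for row, m in zip(toy_outputs, mask)]  # unpadded length per row

print(correct)  # [[901, 902], [901, 902]]
print(wrong)    # [[901, 902], [23, 901, 902]]  <- prompt token 23 leaks in
```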

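### Comparison of Sequential vs. Batched Latency
The test below runs the same four prompts twice, first one by one through `generate_single` and then as a single batch through `generate_batch`, to show how combining LLM calls cuts total processing time. Keep in mind that a static batch keeps running until its longest response finishes (or reaches `max_new_tokens`), and response lengths can differ from run to run, so exact timings will vary.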
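One-off costs on the first `generate` call (for example, GPU kernel selection and memory allocation) can inflate whichever measurement runs first, and a single timing is noisy. A minimal sketch of a fairer harness, using a hypothetical `benchmark` helper that performs warm-up calls and reports the median of several `time.perf_counter()` timings:

```python
import statistics
import time


def benchmark(fn, *args, warmup=1, repeats=3):
    """Hypothetical helper: run fn(*args) `warmup` times untimed, then
    return the median wall-clock time (seconds) of `repeats` timed calls."""
    for _ in range(warmup):
        fn(*args)  # absorbs one-off start-up costs; result discarded
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

For example, `benchmark(generate_batch, prompts)` and `benchmark(lambda ps: [generate_single(p) for p in ps], prompts)` would time both strategies under the same conditions. Both functions decode their outputs to strings, which forces the GPU work to finish before the timer stops.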

In [None]:

prompts = [
    "Explain quantum computing in simple terms.",
    "What are the benefits of meditation?",
    "Give me a summary of climate change effects.",
    "Define reinforcement learning."
]

start = time.time()
single_results = [generate_single(p) for p in prompts]
end = time.time()

print(f"\n=== Single Prompt Latency ===")
print(f"Total time for {len(prompts)} prompts: {end - start:.3f}s")
print(f"Average per prompt: {(end - start)/len(prompts):.3f}s\n")

start = time.time()
batch_results = generate_batch(prompts)
end = time.time()

print(f"=== Batched Latency ===")
print(f"Total batch time: {end - start:.3f}s")




=== Single Prompt Latency ===
Total time for 4 prompts: 48.334s
Average per prompt: 12.083s

=== Batched Latency ===
Total batch time: 15.777s
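The reported timings translate directly into throughput. The values below are copied from the single run above and will differ on other hardware or runs:

```python
# Timings copied from the output above (one run; not a general benchmark).
n_prompts = 4
sequential_s = 48.334  # "Total time for 4 prompts"
batched_s = 15.777     # "Total batch time"

seq_throughput = n_prompts / sequential_s    # requests per second, sequential
batch_throughput = n_prompts / batched_s     # requests per second, batched
speedup = sequential_s / batched_s           # ratio of total wall-clock times

print(f"sequential: {seq_throughput:.3f} req/s")
print(f"batched:    {batch_throughput:.3f} req/s")
print(f"speedup:    {speedup:.2f}x")
```

Batching delivered roughly 3x the throughput here (about 3.06x). Per-request latency tells a different story: a request served alone took about 12.1 s on average, while every batched request completed only when the whole batch did, after about 15.8 s.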


## Query Caching

### Exact Caching

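For highly repetitive queries, such as FAQ systems or internal developer tools, running the LLM again on an input it has already answered is wasted work. Exact-match caching stores each result keyed by the function's arguments; here `functools.lru_cache` provides a dictionary-backed cache that evicts the least recently used entry once it holds `maxsize` items. A repeated, character-for-character identical prompt returns almost instantly, skipping generation entirely, while any difference in spelling, casing, or whitespace is a miss. The cache also pins the first response generated for each prompt, so repeated calls always return the same text.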
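Because `lru_cache` matches only identical arguments, trivial variations such as extra spaces miss. As a minimal sketch, assuming light normalization is acceptable for your prompts, a hand-rolled dictionary cache can collapse whitespace and case before lookup. `normalize` and `ExactCache` are hypothetical names, and unlike `lru_cache(maxsize=128)` this version never evicts entries:

```python
import re


def normalize(prompt: str) -> str:
    """Hypothetical key normalization: trim, collapse whitespace runs, lowercase."""
    return re.sub(r"\s+", " ", prompt.strip()).lower()


class ExactCache:
    """Minimal exact-match cache keyed on the normalized prompt (no eviction)."""

    def __init__(self, generate_fn):
        self.generate_fn = generate_fn  # e.g. generate_single
        self.store = {}
        self.hits = 0
        self.misses = 0

    def __call__(self, prompt: str) -> str:
        key = normalize(prompt)
        if key in self.store:
            self.hits += 1
            return self.store[key]
        self.misses += 1
        response = self.generate_fn(prompt)  # only misses reach the model
        self.store[key] = response
        return response


# Stand-in generator so the sketch runs without a GPU.
exact_cache = ExactCache(lambda p: f"answer to: {p}")
exact_cache("Define caching.")
exact_cache("  define   CACHING. ")  # same key after normalization -> hit
print(exact_cache.hits, exact_cache.misses)  # 1 1
```

Lowercasing is a judgment call: it would merge prompts where case carries meaning, such as code identifiers.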

In [None]:
from functools import lru_cache
import time

# Simple exact-match cache
@lru_cache(maxsize=128)
def generate_cached(prompt):
    return generate_single(prompt)

prompt = "Define caching."

# ---- Cold cache ----
start = time.time()
cold_output = generate_cached(prompt)
cold_time = time.time() - start

print("Cold output:", cold_output)
print(f"Cold cache latency: {cold_time:.3f}s\n")

# ---- Warm cache ----
start = time.time()
warm_output = generate_cached(prompt)
warm_time = time.time() - start

print("Warm output:", warm_output)
print(f"Warm cache latency: {warm_time:.6f}s ")


Cold output: Cache is a type of data storage that stores frequently accessed data to speed up access times and reduce the load on main memory or disk drives. It is used in computer systems, networks, and other computing environments where performance can be improved by reducing the time required to retrieve information from slower storage devices.

In more technical terms, caching involves storing copies of recently accessed data in a faster, more efficient form (such as in RAM instead of hard disk) so that subsequent requests for that same data do not need to be retrieved again from the slower source. This reduces the overall processing time and increases system responsiveness.

Caches can be implemented at various levels within a computer system, including:

1. Hardware-level caches: These are built into the processor's hardware and store frequently used instructions and data.
2. Operating System level caches: These are managed by the operating system and store data that is frequentl

### Semantic Caching

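Semantic caching goes beyond exact matching by comparing what queries mean rather than how they are spelled, which makes it attractive for user-facing applications where people phrase the same question in many ways. If a user rephrases a question (e.g., "What is reinforcement learning?" instead of "Explain reinforcement learning."), the cache can still recognize the two as similar and return the stored response immediately. This requires an embedding model to convert text into comparable vectors, plus a similarity threshold: set it too low and unrelated questions receive the wrong cached answer; set it too high and genuine paraphrases miss.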

**How it works:**
1. Embeddings: Each query is converted into a numerical vector (embedding).
2. Cosine Similarity: The incoming query's vector is compared with the cached vectors.
3. Cache Hit: If the cosine similarity reaches the threshold ($\ge 0.8$ by default here), the query is treated as a semantic match and the cached response is returned; otherwise the LLM is called and the new query-response pair is added to the cache.
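A worked example of the similarity test, using made-up 3-dimensional vectors in place of real 384-dimensional MiniLM embeddings:

```python
import numpy as np


def cosine_similarity(a, b):
    """cos(a, b) = (a . b) / (||a|| * ||b||)"""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


THRESHOLD = 0.8  # same cut-off as the notebook's default

a = np.array([1.0, 2.0, 2.0])  # "incoming query" (toy vector)
b = np.array([2.0, 1.0, 2.0])  # cached query pointing in a similar direction
c = np.array([1.0, 0.0, 0.0])  # cached query pointing elsewhere

for name, v in (("b", b), ("c", c)):
    sim = cosine_similarity(a, v)
    print(f"cos(a, {name}) = {sim:.3f} -> {'hit' if sim >= THRESHOLD else 'miss'}")
# cos(a, b) = 0.889 -> hit   (8 / (3 * 3))
# cos(a, c) = 0.333 -> miss  (1 / (3 * 1))
```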

**Setup & Embedding Model Initialization**

In [None]:
!pip install -q sentence-transformers numpy

import time
from sentence_transformers import SentenceTransformer
import numpy as np

# Global list to store queries, embeddings, and responses
semantic_cache = []

# Initialize the embedding model (small and fast model, suitable for caching)
# We choose a fast model to ensure the embedding process doesn't negate the latency savings.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

print("Libraries installed and embedding model initialized.")
print(f"Embedding dimensions: {embedding_model.get_sentence_embedding_dimension()}")

Libraries installed and embedding model initialized.
Embedding dimensions: 384


**1. Define Core Cache Logic (`query` and `add` functions)**

These functions manage the cache:
* `semantic_cache_query`: Embeds the new query, then scans the cache in insertion order and computes the **cosine similarity** (the **dot product** divided by the product of the two vector norms) against each stored embedding. It returns the response of the first entry whose score meets the threshold, or `None` if no entry does.
* `add_to_semantic_cache`: Converts the new query into an embedding and stores it with the response.

In [None]:
def semantic_cache_query(query, threshold=0.8):
    # Embed the incoming query
    query_emb = embedding_model.encode([query])[0]

    for item in semantic_cache:
        # Cosine Similarity Calculation: (A . B) / (||A|| * ||B||)
        sim = np.dot(query_emb, item['embedding']) / (np.linalg.norm(query_emb) * np.linalg.norm(item['embedding']))

        if sim >= threshold:
            print(f"\n[CACHE HIT!] Similarity: {sim:.4f}")
            return item['response']

    return None

def add_to_semantic_cache(query, response):
    query_emb = embedding_model.encode([query])[0]
    semantic_cache.append({'query': query, 'embedding': query_emb, 'response': response})
    print("[Cache added]")
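The loop above returns the *first* cached entry over the threshold, not necessarily the *best* one, and recomputes every stored norm on each lookup. A minimal sketch of an alternative, assuming embeddings are normalized once when added so that a single matrix-vector product yields all cosine similarities. `VectorSemanticCache` and its `embed_fn` parameter are hypothetical; `SentenceTransformer.encode(..., normalize_embeddings=True)` could perform the normalization instead:

```python
import numpy as np


class VectorSemanticCache:
    """Sketch of a best-match semantic cache (hypothetical, not a library API)."""

    def __init__(self, embed_fn, threshold=0.8):
        self.embed_fn = embed_fn      # callable: str -> 1-D array-like embedding
        self.threshold = threshold
        self.embeddings = []          # unit-length vectors, one per cached query
        self.responses = []

    def _embed(self, text):
        vec = np.asarray(self.embed_fn(text), dtype=np.float32)
        return vec / np.linalg.norm(vec)  # unit length, so dot product == cosine

    def query(self, text):
        """Return (response, similarity) for the best match, or (None, best_sim)."""
        if not self.embeddings:
            return None, None
        q = self._embed(text)
        # Re-stacking on every lookup keeps the sketch short; a real store
        # would keep a preallocated matrix or a vector index instead.
        sims = np.stack(self.embeddings) @ q
        best = int(np.argmax(sims))
        best_sim = float(sims[best])
        if best_sim >= self.threshold:
            return self.responses[best], best_sim
        return None, best_sim

    def add(self, text, response):
        self.embeddings.append(self._embed(text))
        self.responses.append(response)


# Toy embeddings stand in for the MiniLM model.
toy_vectors = {"a": [1, 2, 2], "b": [2, 1, 2], "c": [1, 0, 0]}
vc = VectorSemanticCache(lambda t: toy_vectors[t])
vc.add("c", "response C")
vc.add("b", "response B")
best_response, best_score = vc.query("a")
print(best_response, round(best_score, 3))
```

In the notebook, `VectorSemanticCache(embedding_model.encode)` would plug in the MiniLM model. For large caches, a dedicated vector index library such as FAISS would replace the linear scan.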

**2. Define Generation with Semantic Cache**

In [None]:
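def generate_with_semantic_cache(query):
    # Return the stored response if a semantically similar query was seen before
    cached = semantic_cache_query(query)
    if cached is not None:
        return cached

    # Cache miss: call the LLM, then store the new query/response pair
    response = generate_single(query)
    add_to_semantic_cache(query, response)
    return response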


**3. Test and Results: Cold vs. Warm Cache**

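We define four queries made of two paraphrase pairs:
* Q1 ("Explain reinforcement learning.") and Q3 ("What is reinforcement learning?") ask the same question in different words.
* Q2 ("Describe benefits of meditation.") and Q4 ("Tell me about the advantages of meditating.") form the second pair.

We will run the queries twice:
1.  **Cold Cache:** The cache starts empty. Q1 and Q2 miss and call the LLM; Q3 and Q4 already hit the entries created by Q1 and Q2, so this run's latency is dominated by the two generations.
2.  **Warm Cache:** Every query now hits a stored entry, so no generation is needed and latency is near zero.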

In [None]:
queries = [
    "Explain reinforcement learning.",        # Q1 (Miss, Cache: 1)
    "Describe benefits of meditation.",        # Q2 (Miss, Cache: 2)
    "What is reinforcement learning?",         # Q3 (Hit Q1)
    "Tell me about the advantages of meditating." # Q4 (Hit Q2)
]

print("--- COLD CACHE RUN---")
start_cold = time.time()
for q in queries:
    out = generate_with_semantic_cache(q)
    print(f"Query: {q}\nResponse: {out[:70]}...\n")
end_cold = time.time()
print(f"Cold cache total latency: {end_cold - start_cold:.3f}s")


print("\n--- WARM CACHE RUN---")
start_warm = time.time()
for q in queries:
    out = generate_with_semantic_cache(q)
    # The function prints [CACHE HIT!] itself; here we show the query and a response preview
    print(f"Query: {q}\nResponse: {out[:70]}...\n")
end_warm = time.time()
print(f"Warm cache total latency: {end_warm - start_warm:.6f}s")

--- COLD CACHE RUN---
[Cache added]
Query: Explain reinforcement learning.
Response: Reinforcement Learning (RL) is an area of machine learning that focuse...

[Cache added]
Query: Describe benefits of meditation.
Response: Meditation has numerous benefits for both physical and mental health. ...


[CACHE HIT!] Similarity: 0.8803
Query: What is reinforcement learning?
Response: Reinforcement Learning (RL) is an area of machine learning that focuse...


[CACHE HIT!] Similarity: 0.8332
Query: Tell me about the advantages of meditating.
Response: Meditation has numerous benefits for both physical and mental health. ...

Cold cache total latency: 56.302s

--- WARM CACHE RUN---

[CACHE HIT!] Similarity: 1.0000
Query: Explain reinforcement learning.
Response: Reinforcement Learning (RL) is an area of machine learning that focuse...


[CACHE HIT!] Similarity: 1.0000
Query: Describe benefits of meditation.
Response: Meditation has numerous benefits for both physical and mental health. ...


[C