In [1]:
# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
# from typing import Optional, Tuple, Dict, List

# class SimpleKVCache:
#     """Simple KV Cache store with prefix matching for LLM inference"""
    
#     def __init__(self, model, tokenizer):
#         self.model = model
#         self.tokenizer = tokenizer
#         self.device = next(model.parameters()).device
        
#         # KV Store: maps token_ids tuple -> DynamicCache
#         self.kv_store: Dict[tuple, DynamicCache] = {}
    
#     def clone_cache(self, cache: DynamicCache) -> DynamicCache:
#         """Deep copy a cache"""
#         new_cache = DynamicCache()
#         for i in range(len(cache)):
#             k, v = cache[i]
#             new_cache.update(k.clone(), v.clone(), i)
#         return new_cache
    
#     def slice_cache(self, cache: DynamicCache, length: int) -> DynamicCache:
#         """Slice cache to specific length"""
#         new_cache = DynamicCache()
#         for i in range(len(cache)):
#             k, v = cache[i]
#             new_cache.update(k[:,:,:length,:].clone(), v[:,:,:length,:].clone(), i)
#         return new_cache
    
#     def find_longest_prefix(self, token_ids: List[int]) -> Tuple[Optional[DynamicCache], int]:
#         """Find longest matching prefix in KV store"""
#         best_cache = None
#         best_len = 0
        
#         for stored_ids, stored_cache in self.kv_store.items():
#             # Find matching length
#             match_len = 0
#             for i in range(min(len(token_ids), len(stored_ids))):
#                 if token_ids[i] == stored_ids[i]:
#                     match_len += 1
#                 else:
#                     break
            
#             # Update best match (require at least 2 tokens to avoid just <bos>)
#             if match_len > best_len and match_len >= 2:
#                 best_len = match_len
#                 best_cache = stored_cache
        
#         if best_cache:
#             return self.slice_cache(best_cache, best_len), best_len
#         return None, 0
    
#     def prefill(self, token_ids: List[int], past_kv: Optional[DynamicCache] = None) -> DynamicCache:
#         """Run prefill and return KV cache"""
#         input_ids = torch.tensor([token_ids], dtype=torch.long).to(self.device)
#         with torch.no_grad():
#             out = self.model(input_ids=input_ids, past_key_values=past_kv, use_cache=True)
#         return out.past_key_values
    
#     def decode(self, token_id: int, past_kv: DynamicCache) -> Tuple[int, DynamicCache]:
#         """Single decode step - returns next token and updated cache"""
#         input_ids = torch.tensor([[token_id]], dtype=torch.long).to(self.device)
#         with torch.no_grad():
#             out = self.model(input_ids=input_ids, past_key_values=past_kv, use_cache=True)
#         next_token = torch.argmax(out.logits[:, -1, :], dim=-1).item()
#         return next_token, out.past_key_values
    
#     def generate(self, prompt: str, max_new_tokens: int = 50) -> str:
#         """
#         Main generation function:
#         1. Tokenize prompt
#         2. Check KV store for prefix match
#         3. Prefill (full or partial)
#         4. Store KV in store
#         5. Decode to generate tokens
#         """
#         # Step 1: Tokenize
#         token_ids = self.tokenizer.encode(prompt, add_special_tokens=True)
#         print(f"\n{'='*60}")
#         print(f"Prompt: '{prompt}'")
#         print(f"Tokens: {token_ids} ({len(token_ids)} tokens)")
        
#         # Step 2: Check KV store for prefix match
#         cached_kv, prefix_len = self.find_longest_prefix(token_ids)
        
#         # Step 3: Prefill
#         if cached_kv and prefix_len > 0:
#             # Partial prefill - only compute remaining tokens
#             remaining = token_ids[prefix_len:]
#             print(f"✓ Cache HIT: {prefix_len} tokens cached, computing {len(remaining)} new tokens")
#             kv_cache = self.prefill(remaining, past_kv=cached_kv)
#         else:
#             # Full prefill - compute all tokens
#             print(f"✗ Cache MISS: computing all {len(token_ids)} tokens")
#             kv_cache = self.prefill(token_ids)
        
#         # Step 4: Store in KV store
#         self.kv_store[tuple(token_ids)] = self.clone_cache(kv_cache)
#         print(f"  Stored in KV store (total entries: {len(self.kv_store)})")
        
#         # Step 5: Decode - generate new tokens
#         print(f"  Generating up to {max_new_tokens} tokens...")
#         generated = []
#         current_token = token_ids[-1]
        
#         for _ in range(max_new_tokens):
#             next_token, kv_cache = self.decode(current_token, kv_cache)
#             if next_token == self.tokenizer.eos_token_id:
#                 break
#             generated.append(next_token)
#             current_token = next_token
        
#         # Decode output
#         output_text = self.tokenizer.decode(generated, skip_special_tokens=True)
#         full_response = prompt + output_text
        
#         print(f"{'='*60}")
#         return full_response


# # ============================================================
# # MAIN: Simple demonstration
# # ============================================================
# if __name__ == "__main__":
#     # Load model
#     print("Loading model...")
#     MODEL = "google/gemma-3-1b-it"
    
#     tokenizer = AutoTokenizer.from_pretrained(MODEL)
#     model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16, device_map="cpu")
    
#     # Create KV cache manager
#     kv_cache = SimpleKVCache(model, tokenizer)
    
#     # --------------------------------------------------------
#     # PROMPT 1: First query (no cache)
#     # --------------------------------------------------------
#     prompt1 = "What is machine learning"
#     response1 = kv_cache.generate(prompt1, max_new_tokens=70)
#     print(f"\nResponse 1:\n{response1}\n")
    
#     # --------------------------------------------------------
#     # PROMPT 2: Query with prefix match (reuses "What is machine learning")
#     # --------------------------------------------------------
#     prompt2 = "What is machine learning formula are maths ?"
#     response2 = kv_cache.generate(prompt2, max_new_tokens=70)
#     print(f"\nResponse 2:\n{response2}\n")
    
#     # --------------------------------------------------------
#     # PROMPT 3: Another prefix match
#     # --------------------------------------------------------
#     prompt3 = "What is machine learning used for"
#     response3 = kv_cache.generate(prompt3, max_new_tokens=70)
#     print(f"\nResponse 3:\n{response3}\n")
    
#     # --------------------------------------------------------
#     # Show KV Store contents
#     # --------------------------------------------------------
#     print("\n" + "="*60)
#     print("KV STORE CONTENTS:")
#     print("="*60)
#     for i, (ids, cache) in enumerate(kv_cache.kv_store.items(), 1):
#         text = tokenizer.decode(list(ids), skip_special_tokens=True)
#         seq_len = cache.get_seq_length()
#         print(f"{i}. '{text}' -> {seq_len} cached positions")

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from typing import Dict, List, Tuple, Optional

class PackedKVCache:
    """
    Combines:
    1. KV Cache Store with prefix matching
    2. Prompt packing for batch processing
    """
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.num_heads = model.config.num_attention_heads
        self.num_layers = model.config.num_hidden_layers
        
        # KV Store: token_ids (tuple) -> list of (key, value) per layer
        self.kv_store: Dict[tuple, List[Tuple[torch.Tensor, torch.Tensor]]] = {}
    
    # ============================================================
    # KV Store Operations
    # ============================================================
    def store_kv(self, token_ids: List[int], kv_list: List[Tuple[torch.Tensor, torch.Tensor]]):
        """Store KV cache for a token sequence"""
        # Clone tensors to avoid mutation
        cloned = [(k.clone(), v.clone()) for k, v in kv_list]
        self.kv_store[tuple(token_ids)] = cloned
    
    def find_prefix(self, token_ids: List[int], min_match: int = 2) -> Tuple[Optional[List], int]:
        """Find longest prefix match in store"""
        best_kv = None
        best_len = 0
        
        for stored_ids, stored_kv in self.kv_store.items():
            match_len = 0
            for i in range(min(len(token_ids), len(stored_ids))):
                if token_ids[i] == stored_ids[i]:
                    match_len += 1
                else:
                    break
            
            if match_len > best_len and match_len >= min_match:
                best_len = match_len
                best_kv = stored_kv
        
        if best_kv:
            # Slice to prefix length
            sliced = [(k[:,:,:best_len,:].clone(), v[:,:,:best_len,:].clone()) for k, v in best_kv]
            return sliced, best_len
        return None, 0
    
    # ============================================================
    # Prompt Packing with Prefix Matching
    # ============================================================
    def process_batch(self, prompts: List[str]) -> List[Tuple[List[int], List]]:
        """
        Process batch of prompts with prefix matching + packing
        Returns: List of (token_ids, kv_cache) for each prompt
        """
        print("=" * 70)
        print("BATCH PROCESSING WITH PREFIX MATCHING + PACKING")
        print("=" * 70)
        
        # Step 1: Tokenize and check prefix matches
        print("\nStep 1: Tokenize & check KV store for prefixes")
        
        prompt_data = []  # (tokens, cached_kv, prefix_len, remaining_tokens)
        
        for i, prompt in enumerate(prompts):
            tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
            cached_kv, prefix_len = self.find_prefix(tokens)
            remaining = tokens[prefix_len:] if prefix_len > 0 else tokens
            
            prompt_data.append({
                'idx': i,
                'prompt': prompt,
                'tokens': tokens,
                'cached_kv': cached_kv,
                'prefix_len': prefix_len,
                'remaining': remaining
            })
            
            status = f"✓ HIT ({prefix_len} cached)" if prefix_len > 0 else "✗ MISS"
            print(f"  [{i}] '{prompt[:40]}...' -> {status}, {len(remaining)} to compute")
        
        # Step 2: Pack remaining tokens for batch processing
        print("\nStep 2: Pack remaining tokens")
        
        packed_tokens = []
        prompt_ranges = []  # (start, end) in packed sequence
        current_pos = 0
        sep_token = self.tokenizer.eos_token_id
        
        for i, data in enumerate(prompt_data):
            remaining = data['remaining']
            if len(remaining) == 0:
                prompt_ranges.append(None)  # Fully cached
                continue
            
            start = current_pos
            end = current_pos + len(remaining)
            prompt_ranges.append((start, end))
            
            packed_tokens.extend(remaining)
            
            # Add separator between prompts
            if i < len(prompt_data) - 1:
                packed_tokens.append(sep_token)
                current_pos = end + 1
            else:
                current_pos = end
        
        print(f"  Packed sequence length: {len(packed_tokens)} tokens")
        print(f"  Ranges: {prompt_ranges}")
        
        # Step 3: Build block-diagonal attention mask
        print("\nStep 3: Build attention mask")
        
        seq_len = len(packed_tokens)
        if seq_len > 0:
            attn_mask = torch.zeros((seq_len, seq_len), dtype=torch.float, device=self.device)
            
            for rng in prompt_ranges:
                if rng is not None:
                    start, end = rng
                    attn_mask[start:end, start:end] = 1.0
            
            # Expand for heads: [1, heads, seq, seq]
            attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).repeat(1, self.num_heads, 1, 1)
            print(f"  Mask shape: {attn_mask.shape}")
        
        # Step 4: Forward pass on packed sequence
        print("\nStep 4: Forward pass (packed)")
        
        if seq_len > 0:
            packed_input = torch.tensor([packed_tokens], device=self.device)
            
            with torch.no_grad():
                outputs = self.model(
                    input_ids=packed_input,
                    attention_mask=attn_mask,
                    use_cache=True
                )
            
            packed_kv = outputs.past_key_values
            print(f"  ✓ Computed KV for {seq_len} packed tokens")
        else:
            packed_kv = None
            print("  All prompts fully cached!")
        
        # Step 5: Extract and combine KV caches
        print("\nStep 5: Extract & combine KV caches")
        
        results = []
        
        for i, data in enumerate(prompt_data):
            tokens = data['tokens']
            cached_kv = data['cached_kv']
            prefix_len = data['prefix_len']
            rng = prompt_ranges[i] if i < len(prompt_ranges) else None
            
            if rng is None and cached_kv is not None:
                # Fully cached - use cached KV directly
                final_kv = cached_kv
                print(f"  [{i}] Using full cache ({prefix_len} tokens)")
            
            elif cached_kv is not None and rng is not None:
                # Partial cache - concatenate cached + new
                start, end = rng
                final_kv = []
                
                for layer_idx in range(self.num_layers):
                    cached_k, cached_v = cached_kv[layer_idx]
                    new_k = packed_kv[layer_idx][0][:, :, start:end, :]
                    new_v = packed_kv[layer_idx][1][:, :, start:end, :]
                    
                    combined_k = torch.cat([cached_k, new_k], dim=2)
                    combined_v = torch.cat([cached_v, new_v], dim=2)
                    final_kv.append((combined_k, combined_v))
                
                print(f"  [{i}] Combined: {prefix_len} cached + {end-start} new = {len(tokens)} total")
            
            else:
                # No cache - extract from packed
                start, end = rng
                final_kv = []
                
                for layer_idx in range(self.num_layers):
                    k = packed_kv[layer_idx][0][:, :, start:end, :]
                    v = packed_kv[layer_idx][1][:, :, start:end, :]
                    final_kv.append((k.clone(), v.clone()))
                
                print(f"  [{i}] Extracted: {end-start} tokens")
            
            # Store in KV store for future use
            self.store_kv(tokens, final_kv)
            results.append((tokens, final_kv))
        
        print(f"\n  KV Store now has {len(self.kv_store)} entries")
        return results
    
    # ============================================================
    # Generation
    # ============================================================
    def generate(self, prompt: str, kv_list: List, max_new: int = 30) -> str:
        """Generate from a KV cache"""
        tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        
        # Convert to DynamicCache
        cache = DynamicCache()
        for layer_idx, (k, v) in enumerate(kv_list):
            cache.update(k, v, layer_idx)
        
        cache_len = cache.get_seq_length()
        
        # Start from last token
        last_token = torch.tensor([[tokens[-1]]], device=self.device)
        position_ids = torch.tensor([[cache_len - 1]], device=self.device)
        
        with torch.no_grad():
            out = self.model(
                input_ids=last_token,
                past_key_values=cache,
                position_ids=position_ids,
                use_cache=True
            )
        
        # Generate tokens
        generated = []
        cache = out.past_key_values
        next_token = torch.argmax(out.logits[:, -1, :], dim=-1)
        generated.append(next_token.item())
        
        current_pos = cache_len
        
        for _ in range(max_new - 1):
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            
            position_ids = torch.tensor([[current_pos]], device=self.device)
            
            with torch.no_grad():
                out = self.model(
                    input_ids=next_token.unsqueeze(0),
                    past_key_values=cache,
                    position_ids=position_ids,
                    use_cache=True
                )
            
            cache = out.past_key_values
            next_token = torch.argmax(out.logits[:, -1, :], dim=-1)
            generated.append(next_token.item())
            current_pos += 1
        
        return self.tokenizer.decode(generated, skip_special_tokens=True)


# ============================================================
# MAIN DEMO
# ============================================================
if __name__ == "__main__":
    MODEL = "google/gemma-3-1b-it"
    MAX_NEW_TOKENS = 120  # generation budget per prompt

    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16, device_map="cpu")

    cache_manager = PackedKVCache(model, tokenizer)

    def run_batch(title, prompts):
        """Prefill ``prompts`` through ``cache_manager``, then greedy-decode each one."""
        print("\n" + "=" * 70)
        print(title)
        print("=" * 70)

        results = cache_manager.process_batch(prompts)

        print("\nGenerating responses:")
        for token_ids, kv in results:
            text = tokenizer.decode(token_ids, skip_special_tokens=True)
            response = cache_manager.generate(text, kv, max_new=MAX_NEW_TOKENS)
            print(f"  '{text[:30]}...' -> {response}...")
        return results

    # --------------------------------------------------------
    # BATCH 1: store is empty, so every prompt misses and is packed
    # --------------------------------------------------------
    batch1 = [
        "What is machine learning",
        "Explain neural networks",
        "Tell me about Python"
    ]
    results1 = run_batch("BATCH 1: Initial prompts (no cache)", batch1)

    # --------------------------------------------------------
    # BATCH 2: most prompts extend a prompt stored by batch 1
    # --------------------------------------------------------
    batch2 = [
        "What is machine learning and deep learning",  # expected hit: "What is machine learning"
        "Explain neural networks in simple terms",     # expected hit: "Explain neural networks"
        "What is machine learning used for",           # expected hit: "What is machine learning"
        "Can you write me something about India?"      # no stored prefix -> miss
    ]
    results2 = run_batch("BATCH 2: Prompts with prefix matches", batch2)

    # --------------------------------------------------------
    # Show KV Store
    # --------------------------------------------------------
    print("\n" + "=" * 70)
    print("FINAL KV STORE:")
    print("=" * 70)
    for i, (ids, kv) in enumerate(cache_manager.kv_store.items(), 1):
        text = tokenizer.decode(list(ids), skip_special_tokens=True)
        seq_len = kv[0][0].shape[2]  # Get seq length from first layer key
        print(f"  {i}. '{text[:50]}' ({seq_len} tokens)")

Loading model...

BATCH 1: Initial prompts (no cache)
BATCH PROCESSING WITH PREFIX MATCHING + PACKING

Step 1: Tokenize & check KV store for prefixes
  [0] 'What is machine learning...' -> ✗ MISS, 5 to compute
  [1] 'Explain neural networks...' -> ✗ MISS, 4 to compute
  [2] 'Tell me about Python...' -> ✗ MISS, 5 to compute

Step 2: Pack remaining tokens
  Packed sequence length: 16 tokens
  Ranges: [(0, 5), (6, 10), (11, 16)]

Step 3: Build attention mask
  Mask shape: torch.Size([1, 4, 16, 16])

Step 4: Forward pass (packed)
  ✓ Computed KV for 16 packed tokens

Step 5: Extract & combine KV caches
  [0] Extracted: 5 tokens
  [1] Extracted: 4 tokens
  [2] Extracted: 5 tokens

  KV Store now has 3 entries

Generating responses:
  'What is machine learning...' -> ?

Machine learning is a branch of artificial intelligence (AI) that focuses on enabling computers to learn from data without being explicitly programmed. Instead of providing step-by-step instructions, machine learning algorith

In [5]:
# Contrast with conventional batching: every prompt is padded to the longest one.
batch2 = [
    "What is machine learning and deep learning",  # shares prefix "What is machine learning"
    "Explain neural networks in simple terms",     # shares prefix "Explain neural networks"
    "What is machine learning used for",           # shares prefix "What is machine learning"
    "Can you write me something about India?"      # no shared prefix with earlier prompts
]

MODEL = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL)

output = tokenizer(batch2, return_tensors="pt", padding=True)
n_total = output["input_ids"].numel()
n_real = int(output["attention_mask"].sum())  # attention_mask is 0 on padding positions
print(output["input_ids"].shape)
print(f"Padding tokens: {n_total - n_real} of {n_total}")

torch.Size([4, 9])
