# H2O (Heavy-Hitter Oracle) KV-Cache Eviction: Latency Experiment

## Imports

In [None]:
import torch
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

## Load model

In [None]:
# --- Configuration ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name = "gpt2-large"
generation_length = 512  # tokens generated per benchmark run

print(f"Using device: {device}")
print(f"Loading model: {model_name} (with attentions)...")

# --- Load Model and Tokenizer ---
# The H2O scoring reads per-layer attention weights, so the model is loaded with
# output_attentions=True to make every forward pass return them.
# NOTE(review): depending on the installed transformers version the default
# attention backend (SDPA / flash) may not produce attention weights; confirm
# that outputs.attentions is actually populated.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
model = model.to(device)

print("Setup complete.")

## Benchmark function (baseline vs. simplified H2O)

In [None]:
def _h2o_keep_indices(scores, window_size, heavy_hitters_to_keep):
    """Return the cache-slot indices to keep: top-scoring older slots plus the recent window.

    ``scores`` holds one accumulated attention score per cache slot in
    chronological order and must be longer than ``window_size``. Exactly
    ``window_size`` indices are returned, sorted ascending.
    """
    cache_len = scores.shape[0]
    num_recent = window_size - heavy_hitters_to_keep
    recent_start = cache_len - num_recent
    recent = torch.arange(recent_start, cache_len, device=scores.device)
    if heavy_hitters_to_keep == 0:
        return recent
    # Pick heavy hitters only among slots *outside* the recent window so the two
    # sets never overlap and the cache is trimmed to exactly window_size slots.
    _, heavy = torch.topk(scores[:recent_start], heavy_hitters_to_keep)
    # Chronological order keeps the invariant "most recent tokens are the last
    # slots", which the next eviction relies on.
    return torch.cat([torch.sort(heavy).values, recent])


def analyze_inference_h2o(strategy, window_size=None, heavy_hitters_to_keep=4):
    """Greedily generate tokens and record the latency of every generation step.

    Uses the module-level ``model``, ``tokenizer``, ``device`` and
    ``generation_length``. Starting from a single BOS token, it generates
    ``generation_length`` tokens one at a time with the KV cache enabled.

    Args:
        strategy: ``'h2o'`` enables the simplified Heavy-Hitter Oracle cache
            eviction; any other value runs the baseline with the full cache.
        window_size: Maximum number of KV-cache slots retained under ``'h2o'``
            (required for that strategy, ignored otherwise).
        heavy_hitters_to_keep: How many of the retained slots are reserved for
            the older tokens with the highest accumulated attention; the
            remaining ``window_size - heavy_hitters_to_keep`` slots hold the
            most recent tokens.

    Returns:
        list[float]: Wall-clock milliseconds per generation step, including the
        H2O bookkeeping and eviction.

    Raises:
        ValueError: For ``'h2o'`` when ``window_size`` is missing or < 1, or when
            ``heavy_hitters_to_keep`` is outside ``[0, window_size]``.
    """
    import warnings

    print(f"\n--- Running test for strategy: '{strategy}' ---")

    use_h2o = strategy == 'h2o'
    if use_h2o:
        if window_size is None or window_size < 1:
            raise ValueError("window_size must be a positive integer for the 'h2o' strategy")
        if not 0 <= heavy_hitters_to_keep <= window_size:
            raise ValueError("heavy_hitters_to_keep must be between 0 and window_size")

    # CUDA kernels run asynchronously: without a sync the timer would only
    # measure kernel launches, not the work each step actually does.
    on_cuda = str(device).startswith("cuda")

    def _sync():
        """Block until queued CUDA work has finished (no-op on CPU)."""
        if on_cuda:
            torch.cuda.synchronize()

    start_id = tokenizer.bos_token_id
    if start_id is None:
        start_id = tokenizer.eos_token_id
    next_token_id = torch.tensor([[start_id]], dtype=torch.long, device=device)
    past_key_values = None
    timings = []
    # One accumulated attention score per *cache slot* (not per absolute token
    # position). It is trimmed together with the cache so indices stay aligned.
    slot_scores = torch.zeros(0, device=device)
    warned_no_attentions = False

    _sync()
    with torch.no_grad():
        for i in range(generation_length):
            start_time = time.perf_counter()

            # Pass the true absolute position explicitly: after an eviction the
            # cache length no longer equals the number of tokens seen so far, so
            # a position inferred from the cache length would be wrong.
            position_ids = torch.tensor([[i]], dtype=torch.long, device=device)
            outputs = model(
                input_ids=next_token_id,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids,
            )
            next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
            past_key_values = outputs.past_key_values

            if use_h2o and past_key_values is not None:
                # The token just processed now occupies a new final cache slot; it
                # gets no credit for attending to itself.
                slot_scores = torch.cat([slot_scores, slot_scores.new_zeros(1)])
                if outputs.attentions is not None:
                    for layer_attention in outputs.attentions:
                        # assumes layout (batch, heads, query, key) — TODO confirm.
                        # Sum over heads the attention from the newest query to
                        # every earlier cache slot.
                        slot_scores[:-1] += layer_attention[0, :, -1, :-1].sum(dim=0)
                elif not warned_no_attentions:
                    warnings.warn(
                        "model returned no attention weights; H2O scores stay zero "
                        "and heavy hitters are chosen arbitrarily",
                        RuntimeWarning,
                    )
                    warned_no_attentions = True

                if past_key_values[0][0].shape[2] > window_size:
                    keep = _h2o_keep_indices(slot_scores, window_size, heavy_hitters_to_keep)
                    past_key_values = tuple(
                        (
                            layer_past[0].index_select(2, keep),
                            layer_past[1].index_select(2, keep),
                        )
                        for layer_past in past_key_values
                    )
                    slot_scores = slot_scores.index_select(0, keep)

            _sync()
            timings.append((time.perf_counter() - start_time) * 1000)

            if (i + 1) % 100 == 0:
                print(f"Generated {i + 1}/{generation_length} tokens...")

    print(f"'{strategy}' test complete.")
    return timings

## Run the experiments

In [None]:
# --- Run the Two Experiments ---
# H2O budget: total KV-cache slots kept, and how many of those slots are
# reserved for the highest-attention ("heavy hitter") past tokens.
window, heavy_hitters = 128, 32

baseline_timings = analyze_inference_h2o('baseline')
h2o_timings = analyze_inference_h2o(
    'h2o',
    window_size=window,
    heavy_hitters_to_keep=heavy_hitters,
)

## Plotting

In [None]:
# Per-token latency of both runs on one set of axes.
fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(baseline_timings, label='Baseline (Full KV Cache)')
ax.plot(h2o_timings, label=f'Simplified H2O (Window={window}, Heavy Hitters={heavy_hitters})')

ax.set_xlabel('Generated Token Number (Sequence Length)')
ax.set_ylabel('Time per Token (ms)')
ax.set_title(f'KV Cache Performance: Baseline vs. Simplified H2O on {model_name}')
ax.legend()
ax.grid(True)
plt.show()