# CircuitKV Initialization Strategy Comparison
**Goal:** Find the optimal prefill initialization strategy to close the gap with H2O on summarization.

In [None]:
!pip install -q transformers accelerate "datasets<3.0"

In [None]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm

# Experiment configuration — the knobs a reader is most likely to tune.
DEVICE = "cuda"        # device string the tokenized inputs are moved to
TOP_K_PERCENT = 0.20   # fraction of positions kept when comparing top-k sets
N_SAMPLES = 3          # number of dataset rows to evaluate

# Report the GPU at index 0 and its total memory in (decimal) gigabytes.
gpu_name = torch.cuda.get_device_name(0)
gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
print(f"GPU: {gpu_name}")
print(f"VRAM: {gpu_total_bytes / 1e9:.1f} GB")

In [None]:
# Load the tokenizer and the model weights in fp16 (half precision).
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# evaluate_sample reads `outputs.attentions`, so the model is configured to
# return attention weights and to use the "eager" attention implementation.
load_kwargs = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",
    output_attentions=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
model.eval()
print("Model loaded (fp16).")

In [None]:
# Fetch the NarrativeQA test split of LongBench and keep the first N_SAMPLES rows.
print("Loading LongBench NarrativeQA...")

dataset = load_dataset("THUDM/LongBench", "narrativeqa", split="test", trust_remote_code=True)

samples = []
for row_idx in range(N_SAMPLES):
    samples.append(dataset[row_idx])

# Quick sanity report based on the first sample.
first_sample = samples[0]
print(f"Loaded {len(samples)} samples")
print(f"Sample keys: {first_sample.keys()}")
print(f"Context length (chars): {len(first_sample['context'])}")

In [None]:
def normalize(x):
    """Min-max rescale ``x``.

    The minimum is subtracted so the smallest entry becomes 0, then the result
    is divided by its maximum. A 1e-10 epsilon is added to the divisor so a
    constant input yields zeros instead of dividing by zero.
    """
    shifted = x - x.min()
    peak = shifted.max()
    return shifted / (peak + 1e-10)

def jaccard_topk(scores, oracle, k_percent=0.20):
    """Jaccard overlap between the top-k index sets of ``scores`` and ``oracle``.

    Args:
        scores: 1-D tensor of candidate scores, one per position.
        oracle: 1-D tensor of reference scores over the same positions.
        k_percent: fraction of positions to keep; k is ``int(len(scores) * k_percent)``,
            clamped to at least 1.

    Returns:
        ``|A & B| / |A | B|`` as a float, where A and B are the indices of the
        k largest entries of ``scores`` and ``oracle``. Returns 0.0 when
        ``scores`` is empty.
    """
    n_positions = len(scores)
    if n_positions == 0:
        # Nothing to rank; avoid dividing by an empty union.
        return 0.0

    # int() truncates, so short inputs would give k == 0, two empty sets and a
    # ZeroDivisionError below. Always keep at least one index.
    k = max(1, int(n_positions * k_percent))

    top_scores = set(torch.argsort(scores, descending=True)[:k].tolist())
    top_oracle = set(torch.argsort(oracle, descending=True)[:k].tolist())
    return len(top_scores & top_oracle) / len(top_scores | top_oracle)

def compute_strategies(h2o, query_attn):
    """Build the candidate importance scores compared in this notebook.

    Both inputs are first rescaled with ``normalize``. The returned dict maps a
    strategy name to its score tensor:

    * "H2O (Blind)"    — the normalized ``h2o`` scores alone
    * "Additive"       — equal-weight average of the two normalized inputs
    * "Max"            — element-wise maximum of the two normalized inputs
    * "Multiplicative" — square root of their product (plus a 1e-10 epsilon)
    """
    h2o_scaled, query_scaled = normalize(h2o), normalize(query_attn)

    blended = 0.5 * h2o_scaled + 0.5 * query_scaled
    elementwise_max = torch.maximum(h2o_scaled, query_scaled)
    geometric = torch.sqrt(h2o_scaled * query_scaled + 1e-10)

    return {
        "H2O (Blind)": h2o_scaled,
        "Additive": blended,
        "Max": elementwise_max,
        "Multiplicative": geometric,
    }

In [None]:
@torch.no_grad()
def evaluate_sample(model, tokenizer, sample):
    """Score every initialization strategy on one sample.

    Builds a "Context / Question / Answer:" prompt from ``sample["context"]``
    (capped at 12,000 characters) and ``sample["input"]``, runs a single forward
    pass with attention outputs enabled, and compares each strategy's top-k
    positions against the last position's attention row (the "oracle").

    Args:
        model: model called as ``model(input_ids, output_attentions=True)``.
        tokenizer: tokenizer called on the prompt; inputs are truncated to
            at most 4096 tokens.
        sample: mapping with "context" and "input" entries.

    Returns:
        ``(results, seq_len)`` where ``results`` maps strategy name to its
        Jaccard overlap with the oracle and ``seq_len`` is the tokenized
        prompt length.
    """
    truncated_context = sample["context"][:12000]
    prompt = f"Context: {truncated_context}\n\nQuestion: {sample['input']}\nAnswer:"

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    input_ids = encoded["input_ids"].to(DEVICE)
    seq_len = input_ids.shape[1]

    outputs = model(input_ids, output_attentions=True)
    # Last entry of the attentions tuple, first batch item, averaged over dim 0
    # (presumably the head axis — confirm), then moved to CPU as float32.
    last_layer_attn = outputs.attentions[-1]
    attn = last_layer_attn[0].mean(dim=0).float().cpu()

    # Column sums of the averaged matrix serve as the H2O-style score.
    h2o = attn.sum(dim=0)
    # The final position's row is both the query signal and the oracle.
    query_attn = attn[seq_len - 1, :]
    oracle = query_attn.clone()

    results = {
        name: jaccard_topk(scores, oracle, TOP_K_PERCENT)
        for name, scores in compute_strategies(h2o, query_attn).items()
    }
    return results, seq_len

In [None]:
# Run evaluation: collect one Jaccard score per strategy for every sample.
strategy_names = ("H2O (Blind)", "Additive", "Max", "Multiplicative")
all_results = {name: [] for name in strategy_names}

print("Evaluating samples...\n")
for sample_no, sample in enumerate(tqdm(samples), start=1):
    results, seq_len = evaluate_sample(model, tokenizer, sample)
    for name in results:
        all_results[name].append(results[name])
    print(f"Sample {sample_no} (len={seq_len}): {results}")

print("\nDone.")

In [None]:
# Results: mean Jaccard per strategy, with the best strategy marked by a star.
print("\n" + "="*50)
print("        INITIALIZATION STRATEGY COMPARISON")
print("="*50)
print(f"{'Strategy':<20} {'Jaccard@20%':>15}")
print("-"*50)

# Compute every mean exactly once; the best value is then a single lookup
# instead of being recomputed over all strategies on each loop iteration.
final_scores = {
    name: np.mean(all_results[name])
    for name in ["H2O (Blind)", "Additive", "Max", "Multiplicative"]
}
best_score = max(final_scores.values())

for name, mean_score in final_scores.items():
    # "\u2605" is the black star (★); written as an escape so the marker
    # cannot be mangled by a wrong file encoding again.
    marker = " \u2605" if mean_score == best_score else ""
    print(f"{name:<20} {mean_score:>15.4f}{marker}")

print("="*50)

best = max(final_scores, key=final_scores.get)
print(f"\nWinner: {best}")

h2o_score = final_scores["H2O (Blind)"]
mult_score = final_scores["Multiplicative"]
if mult_score > h2o_score:
    if h2o_score > 0:
        gain = (mult_score - h2o_score) / h2o_score * 100
        print(f"Multiplicative vs H2O: +{gain:.1f}% relative improvement")
    else:
        # A zero baseline makes the relative gain undefined (numpy would
        # print inf with a warning); report the absolute difference instead.
        print(f"Multiplicative vs H2O: +{mult_score - h2o_score:.4f} absolute improvement (H2O baseline is 0)")

In [None]:
# Per-sample breakdown: one table row per evaluated sample, one column per strategy.
print("\n" + "="*70)
print("                    PER-SAMPLE BREAKDOWN")
print("="*70)
print(f"{'Sample':<10} {'H2O':<12} {'Additive':<12} {'Max':<12} {'Multiplicative':<12}")
print("-"*70)

# Column order matches the header above.
breakdown_keys = ("H2O (Blind)", "Additive", "Max", "Multiplicative")
for row in range(len(samples)):
    cells = " ".join(f"{all_results[key][row]:<12.4f}" for key in breakdown_keys)
    print(f"{row + 1:<10} {cells}")

print("="*70)

In [None]:
# Visualization: bar chart of the mean Jaccard score per strategy, with the
# best bar outlined in red. The figure is also saved to strategy_comparison.png.
import matplotlib.pyplot as plt

strategies = list(final_scores)
scores = [final_scores[name] for name in strategies]
colors = ['#3498db', '#9b59b6', '#e67e22', '#2ecc71']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(strategies, scores, color=colors, edgecolor='black', linewidth=1.5)

# Highlight the (first) bar holding the highest score.
top_score = max(scores)
best_bar = bars[scores.index(top_score)]
best_bar.set_edgecolor('red')
best_bar.set_linewidth(3)

ax.set_ylabel('Jaccard@20% Overlap with Oracle', fontsize=12)
ax.set_title('Initialization Strategy Comparison (LongBench NarrativeQA)', fontsize=14, fontweight='bold')
# Leave 20% headroom above the tallest bar for the value labels.
ax.set_ylim(0, top_score * 1.2)

# Print each score just above its bar, centred horizontally.
for bar, score in zip(bars, scores):
    label_x = bar.get_x() + bar.get_width() / 2
    ax.text(label_x, bar.get_height() + 0.01,
            f'{score:.4f}', ha='center', fontsize=11, fontweight='bold')

fig.tight_layout()
fig.savefig('strategy_comparison.png', dpi=150)
plt.show()