# CircuitKV Novel Add-ons PoC: ICML 2026

**Goal:** Test CHEAP novel add-ons to RC (Random Circuit) that improve bridge detection.

## Novel Ideas (Zero Extra Cost - Just Modify Walker):

### 1. **Hitting Time Weighting** (RC+HT)
- Weight visits by WHEN in the walk they occur
- Early visits (near query) = high recency importance
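- Late visits (near sink) = high foundational importance in principle. Note that the implemented weight decays linearly with step, `(max_steps - step) / max_steps`, so late visits actually count *less*.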
- Cost: Zero - just change the atomic add weight

### 2. **Escape Probability** (RC+EP)
- Track FIRST visit vs TOTAL visits per token
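- Bridge = walkers pass through once and escape to the sink
- Hub = walkers keep coming back (high revisit ratio)
- Caveat: on a strictly causal graph, where every edge points to an earlier token, a single walker can never revisit a token. First-visit and total-visit counts then coincide, and this signal reduces to plain visit counting.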
- Cost: +1 bit per walker per token (negligible)

### 3. **Absorption Speed** (RC+AS)
- Track how fast walkers get absorbed after visiting each token
- Fast absorption = direct path to sink = true bridge
- Slow absorption = stuck in distractor cluster
- Cost: Zero - just track step count

### 4. **Temperature Mixing** (RC+TM)
- Half walkers use sharp temperature (follow strongest edges)
- Half walkers use soft temperature (explore more)
- Combines local hub detection + global bridge detection
- Cost: Zero - just change softmax scale

**Metric:** Token-level F1 of the actual generated answers (the QA metric used by LongBench)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from tqdm.auto import tqdm
import re
import string
from collections import Counter
import gc
import math

# Report the accelerator up front. The check is guarded so that running this cell
# on a CPU-only machine does not crash inside torch.cuda.get_device_name(0).
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("GPU: none (CUDA not available)")

# Config
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"  # HF hub id passed to from_pretrained
MAX_SEQ_LEN = 2048      # prompt + generation token budget; prompt gets MAX_SEQ_LEN - MAX_NEW_TOKENS
MAX_NEW_TOKENS = 64     # greedy decoding length cap
NUM_SAMPLES = 5         # number of LongBench examples evaluated
BUDGET_RATIO = 0.25     # fraction of prompt KV positions kept after eviction

In [None]:
# F1 Score (LongBench metric)
def normalize_answer(s: str) -> str:
    """Canonicalize an answer string before token-level comparison.

    Steps, in order: lower-case, strip ASCII punctuation, replace the
    articles "a"/"an"/"the" with spaces, then collapse whitespace runs.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between one prediction and one reference answer.

    Both strings are normalized with ``normalize_answer``. If either side
    has no tokens left, the result is 1.0 when both are empty, else 0.0.
    """
    predicted = normalize_answer(prediction).split()
    reference = normalize_answer(ground_truth).split()
    if not predicted or not reference:
        return float(predicted == reference)
    # Multiset intersection: each shared token counts min(#pred, #ref) times.
    overlap = sum((Counter(predicted) & Counter(reference)).values())
    if overlap == 0:
        return 0.0
    p = overlap / len(predicted)
    r = overlap / len(reference)
    return 2 * p * r / (p + r)

def best_f1(prediction: str, answers: List[str]) -> float:
    """Return the best F1 of ``prediction`` against any reference; 0.0 when there are none."""
    if not answers:
        return 0.0
    return max(f1_score(prediction, reference) for reference in answers)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token; generation later passes
# pad_token_id=tokenizer.pad_token_id.
tokenizer.pad_token = tokenizer.eos_token

# Eager attention is requested so that output_attentions=True (used when
# scoring tokens for eviction) can return attention weights.
# NOTE(review): fused SDPA/flash backends in HF generally do not materialize them.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()  # nn.Module.eval() returns the module itself
print(f"Loaded. GPU: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
from datasets import load_dataset

ds = load_dataset("THUDM/LongBench", "narrativeqa", split="test", trust_remote_code=True)

# Take the first NUM_SAMPLES rows; zip() stops as soon as range() is exhausted.
samples = []
for _, item in zip(range(NUM_SAMPLES), ds):
    context, question = item['context'], item['input']
    # Llama-3 chat format written out by hand (user turn, then an open assistant turn).
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Read the following text and answer the question.

Text: {context}

Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    samples.append(dict(prompt=prompt, answers=item["answers"], question=question))
print(f"Loaded {len(samples)} samples")

## Random Walk Simulation (PyTorch - Simulates CUDA Kernel)

In [None]:
def build_sparse_graph(attn_matrix: torch.Tensor, top_k: int = 32) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a causal top-k adjacency graph from a square attention matrix.

    Row ``i`` keeps the ``min(top_k, i)`` largest weights among
    ``attn_matrix[i, :i]``, so every edge points strictly to an earlier token.
    These are sorted in descending order. Unused slots hold index -1 and weight 0.

    Returns: (adj_indices [n, top_k] long, adj_weights [n, top_k] float32)
    """
    n = attn_matrix.shape[0]
    device = attn_matrix.device

    adj_weights = torch.zeros(n, top_k, device=device)
    adj_indices = torch.full((n, top_k), -1, dtype=torch.long, device=device)

    # No row has more than n - 1 past tokens.
    k = min(top_k, n - 1)
    if k <= 0:
        return adj_indices, adj_weights

    # One batched topk instead of a Python loop with one kernel per row.
    # Entry (i, j) is allowed iff j < i (strictly lower triangle).
    past = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=-1)
    masked = attn_matrix.float().masked_fill(~past, float("-inf"))
    top_w, top_i = masked.topk(k, dim=1)

    # Row i has exactly i real candidates. Those sort ahead of the -inf
    # padding, so slot s is valid iff s < i.
    slot = torch.arange(k, device=device).unsqueeze(0)
    row = torch.arange(n, device=device).unsqueeze(1)
    valid = slot < row

    adj_weights[:, :k] = torch.where(valid, top_w, torch.zeros_like(top_w))
    adj_indices[:, :k] = torch.where(valid, top_i, torch.full_like(top_i, -1))
    return adj_indices, adj_weights

In [None]:
def run_random_walks(
    adj_indices: torch.Tensor,
    adj_weights: torch.Tensor,
    source_nodes: List[int],
    num_walkers: int = 256,
    max_steps: int = 100,
    sink_size: int = 4,
    mode: str = "basic",  # basic, hitting_time, escape_prob, absorption_speed, temp_mix
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    Run absorbing random walks from source nodes toward the sink.

    ``num_walkers`` walkers start from EACH source node. A walker stops when
    it reaches a position ``< sink_size`` (absorbed), hits a node with no
    valid neighbours, or runs out of ``max_steps``. Each step samples a
    neighbour with probability proportional to ``weight ** (1 / temp)``.
    All-zero weights fall back to a uniform choice.

    Modes:
    - basic: standard visit counting (current RC).
    - hitting_time: each visit weighted by ``(max_steps - step) / max_steps``,
      so early visits count more.
    - escape_prob: first-visit count / total visits, times normalized visits.
      NOTE: on a strictly causal graph (every edge i -> j has j < i, as
      produced by build_sparse_graph) a walker cannot revisit a node, so this
      equals basic counting.
    - absorption_speed: normalized visits times ``1 / (1 + mean steps from the
      visit until absorption)``. Visits from walks that never got absorbed
      contribute no speed, so such tokens score 0.
    - temp_mix: walkers alternate between temperature 2.0 (soft, since weights
      are square-rooted) and 0.5 (sharp, since weights are squared). The
      alternation is per round over the sources, so every source gets both.

    Returns: scores [n], max-normalized to 1 (all zeros if nothing was
    visited), on ``adj_indices.device``.

    Raises: ValueError for an unknown ``mode`` or a non-positive ``temperature``.
    """
    valid_modes = ("basic", "hitting_time", "escape_prob", "absorption_speed", "temp_mix")
    if mode not in valid_modes:
        raise ValueError(f"unknown walk mode {mode!r}; expected one of {valid_modes}")
    if mode != "temp_mix" and temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    n = adj_indices.shape[0]
    out_device = adj_indices.device
    if n == 0:
        return torch.zeros(0, device=out_device)

    # Each step is a tiny sampling op followed by `.item()`. On a GPU every
    # such call forces a device sync, so walk over host copies instead.
    nbr_table = adj_indices.detach().to(device="cpu")
    wt_table = adj_weights.detach().to(device="cpu", dtype=torch.float32)

    visits = torch.zeros(n)
    first_visits = torch.zeros(n) if mode == "escape_prob" else None
    absorption_times = torch.zeros(n) if mode == "absorption_speed" else None
    absorption_counts = torch.zeros(n) if mode == "absorption_speed" else None

    num_sources = len(source_nodes)
    total_walkers = num_walkers * num_sources

    for walker_id in range(total_walkers):
        pos = int(source_nodes[walker_id % num_sources])

        if mode == "temp_mix":
            # Alternate per full round over the sources. Using `walker_id % 2`
            # would pin each source to a single temperature whenever the
            # number of sources is even.
            temp = 2.0 if (walker_id // num_sources) % 2 == 0 else 0.5
        else:
            temp = temperature

        seen = set()  # escape_prob: nodes already visited by this walker
        path = []     # absorption_speed: (node, step) pairs of this walk

        for step in range(max_steps):
            if mode == "hitting_time":
                visits[pos] += (max_steps - step) / max_steps
            else:
                visits[pos] += 1.0

            if mode == "escape_prob" and pos not in seen:
                first_visits[pos] += 1.0
                seen.add(pos)
            if mode == "absorption_speed":
                path.append((pos, step))

            # Absorbed by the sink.
            if pos < sink_size:
                if mode == "absorption_speed":
                    for node, visited_at in path:
                        absorption_times[node] += step - visited_at
                        absorption_counts[node] += 1.0
                break

            neighbors = nbr_table[pos]
            valid_mask = neighbors >= 0
            if not bool(valid_mask.any()):
                break  # dead end: no outgoing edges
            candidates = neighbors[valid_mask]
            weights = wt_table[pos][valid_mask]

            if temp != 1.0:
                weights = weights ** (1.0 / temp)

            # multinomial rejects all-zero rows (e.g. weights that underflowed
            # after sharpening); fall back to a uniform choice in that case.
            if float(weights.sum()) > 0:
                probs = weights
            else:
                probs = torch.ones_like(weights)

            choice = torch.multinomial(probs, 1).item()
            pos = int(candidates[choice])

    if mode == "escape_prob":
        # High ratio = bridge (pass through once); low ratio = hub (revisited).
        escape_ratio = first_visits / visits.clamp(min=1.0)
        scores = escape_ratio * (visits / visits.max().clamp(min=1e-8))
    elif mode == "absorption_speed":
        mean_time = absorption_times / absorption_counts.clamp(min=1.0)
        speed = torch.where(
            absorption_counts > 0,
            1.0 / (1.0 + mean_time),
            torch.zeros_like(mean_time),
        )
        scores = speed * (visits / visits.max().clamp(min=1e-8))
    else:
        scores = visits

    peak = scores.max()
    if peak > 0:
        scores = scores / peak

    return scores.to(out_device)

## Eviction Strategies

In [None]:
class BaseStrategy:
    """Base class for KV eviction strategies.

    Subclasses implement ``compute_scores``. ``get_keep_indices`` converts
    the scores into the sorted KV positions to retain.
    """
    def __init__(self, budget_ratio=0.25, sink_size=4, local_window=32, top_k=32):
        self.budget_ratio = budget_ratio  # target fraction of seq_len to keep
        self.sink_size = sink_size        # leading positions that are always kept
        self.local_window = local_window  # trailing positions that are always kept
        self.top_k = top_k                # neighbours per row for graph-based subclasses

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        """Return one importance score per key position. Subclasses must override."""
        raise NotImplementedError

    def get_keep_indices(self, scores: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Return sorted indices of the positions to keep.

        Always keeps the first ``sink_size`` and last ``local_window``
        positions. Fills the rest of ``int(seq_len * budget_ratio)`` with the
        highest-scoring remaining positions. If the sink and window alone
        exceed the budget, they are still all kept.
        """
        budget = int(seq_len * self.budget_ratio)
        device = scores.device

        mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
        mask[:self.sink_size] = True
        # `mask[-0:]` would select the WHOLE tensor, so a zero window must be skipped.
        if self.local_window > 0:
            mask[-self.local_window:] = True

        already_kept = int(mask.sum().item())
        # Never request more positions than are still unselected (topk would raise).
        remaining = min(max(0, budget - already_kept), seq_len - already_kept)

        if remaining > 0:
            scores_masked = scores.clone()
            scores_masked[mask] = float('-inf')
            _, top_idx = scores_masked.topk(remaining)
            mask[top_idx] = True

        return mask.nonzero(as_tuple=True)[0]

In [None]:
class H2OStrategy(BaseStrategy):
    """Baseline: H2O (Heavy Hitter Oracle).

    Scores each key position by the attention mass it receives from the last
    (up to) 32 query rows, normalized so the largest score is 1.
    """

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        # attn_matrix is expected to be square [query, key] (seq_len x seq_len).
        n = attn_matrix.shape[0]
        # Use at least one row. For n <= sink_size the unclamped value is <= 0,
        # and `attn_matrix[-0:]` / `attn_matrix[-(-k):]` would select the wrong rows.
        W = max(1, min(32, n - self.sink_size))
        h2o = attn_matrix[-W:, :].sum(dim=0)
        return h2o / h2o.max().clamp(min=1e-8)

In [None]:
class RCStrategy(BaseStrategy):
    """
    RC (Random Circuit): Base absorbing random walk strategy.
    This is the current CircuitKV implementation.

    Score = elementwise max of two signals:
    (a) normalized visit counts from absorbing random walks started at the
        last ``observation_window`` positions;
    (b) an H2O-style attention-mass score over the same query rows.
    """

    def __init__(self, num_walkers=256, observation_window=8, **kwargs):
        super().__init__(**kwargs)
        self.num_walkers = num_walkers                # walkers launched per source node
        self.observation_window = observation_window  # trailing query rows used as sources

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        n = attn_matrix.shape[0]
        # At least one observation row. For n <= sink_size the unclamped value
        # is <= 0, which empties source_nodes and mis-slices the H2O rows.
        W = max(1, min(self.observation_window, n - self.sink_size))

        # Build sparse graph
        adj_indices, adj_weights = build_sparse_graph(attn_matrix, self.top_k)

        # Source nodes = observation window
        source_nodes = list(range(n - W, n))

        # Run walks
        circuit_scores = run_random_walks(
            adj_indices, adj_weights, source_nodes,
            num_walkers=self.num_walkers,
            sink_size=self.sink_size,
            mode="basic",
        )

        # Combine with H2O
        h2o = attn_matrix[-W:, :].sum(dim=0)
        h2o = h2o / h2o.max().clamp(min=1e-8)

        # Union: max of both
        return torch.maximum(h2o, circuit_scores)

In [None]:
class RC_HittingTime(RCStrategy):
    """
    RC + Hitting Time Weighting (RC+HT)

    NOVEL: Weight visits by WHEN in the walk they occur.
    - Early visits (near query) = high recency weight
    - Late visits (near sink) = lower weight

    Insight: Tokens visited early in the walk are closer to the query.
    Tokens visited late might be "detours" through distractors.
    """

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        n = attn_matrix.shape[0]
        # At least one observation row (the unclamped value is <= 0 when n <= sink_size).
        W = max(1, min(self.observation_window, n - self.sink_size))

        adj_indices, adj_weights = build_sparse_graph(attn_matrix, self.top_k)
        source_nodes = list(range(n - W, n))

        # Run walks with hitting time weighting
        circuit_scores = run_random_walks(
            adj_indices, adj_weights, source_nodes,
            num_walkers=self.num_walkers,
            sink_size=self.sink_size,
            mode="hitting_time",  # <-- Novel modification
        )

        h2o = attn_matrix[-W:, :].sum(dim=0)
        h2o = h2o / h2o.max().clamp(min=1e-8)

        return torch.maximum(h2o, circuit_scores)

In [None]:
class RC_EscapeProb(RCStrategy):
    """
    RC + Escape Probability (RC+EP)

    NOVEL: Track first-visit vs total-visit ratio per token.
    - High escape ratio = bridge (pass through once, escape to sink)
    - Low escape ratio = hub (walkers keep returning)

    Insight: True reasoning bridges are "one-way streets" toward sink.
    Distractors are "traps" that walkers revisit.

    NOTE(review): build_sparse_graph only links each token to EARLIER tokens,
    so a walker's position strictly decreases and it can never revisit a
    token. First-visit and total-visit counts therefore coincide, and this
    strategy scores the same as RCStrategy (up to sampling noise). A
    revisit-capable graph is needed for this signal to carry information.
    """

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        n = attn_matrix.shape[0]
        # At least one observation row (the unclamped value is <= 0 when n <= sink_size).
        W = max(1, min(self.observation_window, n - self.sink_size))

        adj_indices, adj_weights = build_sparse_graph(attn_matrix, self.top_k)
        source_nodes = list(range(n - W, n))

        # Run walks with escape probability tracking
        circuit_scores = run_random_walks(
            adj_indices, adj_weights, source_nodes,
            num_walkers=self.num_walkers,
            sink_size=self.sink_size,
            mode="escape_prob",  # <-- Novel modification
        )

        h2o = attn_matrix[-W:, :].sum(dim=0)
        h2o = h2o / h2o.max().clamp(min=1e-8)

        return torch.maximum(h2o, circuit_scores)

In [None]:
class RC_TempMix(RCStrategy):
    """
    RC + Temperature Mixing (RC+TM)

    NOVEL: Run walkers at two "temperatures" and pool their visit counts.
    Inside run_random_walks an edge weight w is reshaped as w ** (1 / temp):
    - temp 0.5 squares the weights -> SHARP: walkers follow the strongest
      edges -> finds local hubs
    - temp 2.0 square-roots the weights -> SOFT: walkers explore broadly
      -> finds global bridges

    Insight: Different temperatures capture different importance patterns.
    Combining them catches both local and global structure.
    """

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        n = attn_matrix.shape[0]
        # At least one observation row (the unclamped value is <= 0 when n <= sink_size).
        W = max(1, min(self.observation_window, n - self.sink_size))

        adj_indices, adj_weights = build_sparse_graph(attn_matrix, self.top_k)
        source_nodes = list(range(n - W, n))

        # Run walks with temperature mixing
        circuit_scores = run_random_walks(
            adj_indices, adj_weights, source_nodes,
            num_walkers=self.num_walkers,
            sink_size=self.sink_size,
            mode="temp_mix",  # <-- Novel modification
        )

        h2o = attn_matrix[-W:, :].sum(dim=0)
        h2o = h2o / h2o.max().clamp(min=1e-8)

        return torch.maximum(h2o, circuit_scores)

In [None]:
class RC_Combined(RCStrategy):
    """
    RC + All Novel Modifications Combined

    Combines: Hitting Time + Escape Probability + Temperature Mixing
    Uses max() to keep tokens important on ANY signal.

    The per-source walker budget is split three ways, with at least one
    walker per mode. NOTE(review): on the strictly causal graph from
    build_sparse_graph, the escape-probability pass behaves like plain visit
    counting, because walkers cannot revisit tokens.
    """

    def compute_scores(self, attn_matrix: torch.Tensor) -> torch.Tensor:
        n = attn_matrix.shape[0]
        # At least one observation row (the unclamped value is <= 0 when n <= sink_size).
        W = max(1, min(self.observation_window, n - self.sink_size))

        adj_indices, adj_weights = build_sparse_graph(attn_matrix, self.top_k)
        source_nodes = list(range(n - W, n))

        # `num_walkers // 3` is 0 for num_walkers < 3, which would silently run no walks.
        walkers_per_mode = max(1, self.num_walkers // 3)

        # Run all modes, then take the union (elementwise max) of their scores.
        mode_scores = [
            run_random_walks(
                adj_indices, adj_weights, source_nodes,
                num_walkers=walkers_per_mode,
                sink_size=self.sink_size,
                mode=walk_mode,
            )
            for walk_mode in ("hitting_time", "escape_prob", "temp_mix")
        ]
        circuit_scores = torch.stack(mode_scores).amax(dim=0)

        h2o = attn_matrix[-W:, :].sum(dim=0)
        h2o = h2o / h2o.max().clamp(min=1e-8)

        return torch.maximum(h2o, circuit_scores)

## Generation with Eviction

In [None]:
def _encode_middle_truncated(tokenizer, prompt: str, max_prompt_tokens: int) -> torch.Tensor:
    """Tokenize ``prompt`` to ``[1, seq]`` ids, dropping tokens from the MIDDLE if it is too long.

    Keeping the head and tail (LongBench's truncation policy) preserves the
    leading chat header and the trailing question + assistant header. Plain
    right-truncation would cut the question off entirely.
    """
    bos = getattr(tokenizer, "bos_token", None)
    # Prompts that already start with the BOS marker as text must not get a
    # second BOS. Assumes the tokenizer prepends BOS by default (true for Llama-3).
    add_special = not (bos and prompt.startswith(bos))
    ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special).input_ids
    if 0 < max_prompt_tokens < ids.shape[1]:
        head = max_prompt_tokens // 2
        tail = max_prompt_tokens - head
        ids = torch.cat([ids[:, :head], ids[:, -tail:]], dim=1)
    return ids


def _stop_token_ids(model, tokenizer) -> set:
    """Collect every token id that should end greedy decoding.

    Includes the model's generation_config eos id(s). Chat models may end a
    turn with a token other than ``tokenizer.eos_token``; e.g. Llama-3-Instruct
    uses <|eot_id|> — verify for other models.
    """
    stop = set()
    gen_cfg = getattr(model, "generation_config", None)
    for candidate in (getattr(gen_cfg, "eos_token_id", None), tokenizer.eos_token_id):
        if candidate is None:
            continue
        if isinstance(candidate, int):
            stop.add(candidate)
        else:
            stop.update(int(c) for c in candidate)
    return stop


def _evict_cache(past_kv, keep_indices: torch.Tensor):
    """Return a KV cache that holds only the positions in ``keep_indices`` for every layer."""
    legacy = past_kv.to_legacy_cache() if hasattr(past_kv, "to_legacy_cache") else past_kv
    pruned = tuple(
        (k[:, :, keep_indices.to(k.device), :], v[:, :, keep_indices.to(v.device), :])
        for k, v in legacy
    )
    try:
        from transformers import DynamicCache
    except ImportError:
        return pruned
    if hasattr(DynamicCache, "from_legacy_cache"):
        return DynamicCache.from_legacy_cache(pruned)
    return pruned


@torch.no_grad()
def generate_with_strategy(
    model, tokenizer, prompt: str,
    strategy: Optional[BaseStrategy] = None,
    max_length: int = 2048,
    max_new_tokens: int = 64,
) -> str:
    """Greedy-decode an answer, optionally evicting part of the prompt KV cache first.

    The prompt is limited to ``max_length - max_new_tokens`` tokens by
    dropping its middle. With ``strategy=None`` the full cache is used (FullKV).
    Otherwise:
    1. Prefill runs with full attention.
    2. ``strategy`` scores tokens from the last layer's first attention head.
    3. Only the kept positions remain in the cache for decoding.

    The first new token comes from the prefill logits. Each later token is
    fed with its true position id (``input_len + step``), so RoPE positions
    are not shifted by the compressed cache length.
    """
    input_ids = _encode_middle_truncated(tokenizer, prompt, max_length - max_new_tokens)
    input_ids = input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)
    input_len = input_ids.shape[1]
    stop_ids = _stop_token_ids(model, tokenizer)

    if strategy is None:
        # FullKV - no eviction
        outputs = model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False, pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

    # Prefill with full attention; attentions are needed for scoring.
    outputs = model(
        input_ids=input_ids, attention_mask=attention_mask,
        output_attentions=True, use_cache=True,
    )

    # Attention matrix from last layer, first head: [seq_len, seq_len]
    attn = outputs.attentions[-1][0, 0, :, :].float()
    scores = strategy.compute_scores(attn)
    keep_indices = strategy.get_keep_indices(scores, input_len)
    past_kv = _evict_cache(outputs.past_key_values, keep_indices)

    # The prefill already predicted the first new token. Re-feeding the last
    # prompt token would duplicate it in the cache.
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    del outputs, attn

    generated = []
    for step in range(max_new_tokens):
        token_id = next_token.item()
        generated.append(token_id)
        if token_id in stop_ids or step == max_new_tokens - 1:
            break
        # Explicit positions: the cache is shorter than input_len after eviction,
        # and positions derived from it would shift RoPE.
        position_ids = torch.tensor([[input_len + step]], device=input_ids.device)
        out = model(
            input_ids=next_token, past_key_values=past_kv,
            position_ids=position_ids, use_cache=True,
        )
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    result = tokenizer.decode(generated, skip_special_tokens=True).strip()

    del past_kv
    torch.cuda.empty_cache()

    return result

## Run Evaluation

In [None]:
# Strategies under test. None means "no eviction" (full KV cache); every
# other entry is built with the shared BUDGET_RATIO. The RC+* variants are
# cheap add-ons to RC.
strategies = {
    "FullKV": None,
    **{
        label: strategy_cls(budget_ratio=BUDGET_RATIO)
        for label, strategy_cls in (
            ("H2O", H2OStrategy),
            ("RC", RCStrategy),
            ("RC+HT", RC_HittingTime),   # Novel: Hitting Time
            ("RC+EP", RC_EscapeProb),    # Novel: Escape Probability
            ("RC+TM", RC_TempMix),       # Novel: Temperature Mixing
            ("RC+ALL", RC_Combined),     # Novel: All combined
        )
    },
}

print(f"Testing {len(strategies)} strategies on {len(samples)} samples")
print("\nNovel modifications (zero extra CUDA kernel cost):")
print("\n".join([
    "- RC+HT: Hitting Time Weighting",
    "- RC+EP: Escape Probability",
    "- RC+TM: Temperature Mixing",
    "- RC+ALL: All combined",
]))

In [None]:
# One F1 list per strategy, in strategy order.
results = {name: [] for name in strategies}

for sample_no, sample in enumerate(tqdm(samples), start=1):
    print(f"\n--- Sample {sample_no} ---")
    print(f"Q: {sample['question'][:80]}...")

    for name, strategy in strategies.items():
        try:
            # Release memory left over from the previous run before the next one.
            gc.collect()
            torch.cuda.empty_cache()

            pred = generate_with_strategy(
                model,
                tokenizer,
                sample["prompt"],
                strategy=strategy,
                max_length=MAX_SEQ_LEN,
                max_new_tokens=MAX_NEW_TOKENS,
            )
            f1 = best_f1(pred, sample["answers"])
            results[name].append(f1)
            print(f"  {name:8s}: F1={f1:.3f} | {pred[:50]}...")
        except Exception as e:
            # Notebook boundary: report the failure and score the sample as 0.0.
            print(f"  {name}: ERROR - {e}")
            results[name].append(0.0)

In [None]:
print("\n" + "="*70)
print("RESULTS: Novel RC Add-ons for ICML 2026")
print("="*70)

# Per-strategy summary; strategies without any recorded scores are omitted.
stats = {
    name: {"mean": np.mean(scores), "std": np.std(scores), "n": len(scores)}
    for name, scores in results.items()
    if scores
}

sorted_names = sorted(stats, key=lambda x: stats[x]["mean"], reverse=True)
fullkv_mean = stats.get("FullKV", {}).get("mean", 0)
h2o_mean = stats.get("H2O", {}).get("mean", 0)
rc_mean = stats.get("RC", {}).get("mean", 0)

# Std is shown because with few samples a mean difference alone is hard to judge.
print(f"\n{'Strategy':<10} {'Mean F1':>10} {'Std':>8} {'vs RC':>10} {'vs H2O':>10} {'Novel?':>8}")
print("-" * 64)

novel_markers = {
    "FullKV": "",
    "H2O": "baseline",
    "RC": "current",
    "RC+HT": "NEW",
    "RC+EP": "NEW",
    "RC+TM": "NEW",
    "RC+ALL": "NEW",
}

# Rows for which the corresponding delta column is not meaningful.
no_rc_delta = {"FullKV", "H2O", "RC"}
no_h2o_delta = {"FullKV", "H2O"}

for name in sorted_names:
    s = stats[name]
    # Choose "---" by row identity, not by the diff value. A novel variant
    # that exactly ties RC must still print +0.000.
    rc_str = "---" if name in no_rc_delta else f"{s['mean'] - rc_mean:+.3f}"
    h2o_str = "---" if name in no_h2o_delta else f"{s['mean'] - h2o_mean:+.3f}"

    print(f"{name:<10} {s['mean']:>10.3f} {s['std']:>8.3f} {rc_str:>10} {h2o_str:>10} {novel_markers.get(name, ''):>8}")

# Find best novel strategy. max() keeps the first of equal means, and unlike
# a "> 0" running max it still picks a winner when every novel mean is 0.
novel_strategies = ["RC+HT", "RC+EP", "RC+TM", "RC+ALL"]
available_novel = [name for name in novel_strategies if name in stats]
best_novel = max(available_novel, key=lambda x: stats[x]["mean"]) if available_novel else None
best_novel_score = stats[best_novel]["mean"] if best_novel else 0

print("\n" + "="*70)
print("ICML 2026 RECOMMENDATION")
print("="*70)

if best_novel:
    improvement_over_rc = best_novel_score - rc_mean
    improvement_over_h2o = best_novel_score - h2o_mean
    gap_to_full = fullkv_mean - best_novel_score

    print(f"\nBest NOVEL add-on: {best_novel}")
    print(f"F1 score: {best_novel_score:.3f}")
    print(f"Samples: {stats[best_novel]['n']}")
    print(f"Improvement over RC (current): {improvement_over_rc:+.3f}")
    print(f"Improvement over H2O: {improvement_over_h2o:+.3f}")
    print(f"Gap to FullKV: {gap_to_full:.3f}")

    if improvement_over_rc > 0.02:
        print(f"\n=> IMPLEMENT {best_novel} in CUDA kernel!")
        print("   This is a zero-cost improvement worth pursuing.")
    elif improvement_over_rc > 0:
        print(f"\n=> MARGINAL: {best_novel} shows slight improvement. Run more samples.")
    else:
        print("\n=> Current RC is already optimal. Focus on other innovations.")

In [None]:
# Cleanup
# Drop this notebook's reference to the model first, then run the garbage
# collector so its parameters can actually be freed. Only after that does
# releasing PyTorch's cached CUDA blocks return the weights' memory.
# NOTE(review): per PyTorch docs, empty_cache() only releases cached blocks
# not occupied by live tensors.
del model
gc.collect()
torch.cuda.empty_cache()
print("Done!")