# On the Limits of Learned Importance Scoring for KV Cache Compression

## Evaluation Experiments Notebook

This notebook contains all experiments for our negative results paper investigating learned KV cache compression.

**Key Finding:** Despite architectural sophistication (multi-horizon lookahead, cross-attention, confidence weighting), our learned scorer (SIP, 1.7M parameters) does **not** outperform simple heuristics, and random selection often performs comparably.

### Main Results Summary
- **Position-Heuristic** wins at aggressive compression (10%, 25%)
- **Prefill-Attn** wins at moderate compression (50%, 75%)  
- **SIP ≈ Random** (no statistically significant difference)

### Experiments Included
1. Multi-seed perplexity evaluation (5 seeds, 95% CI, paired t-tests)
2. Task-based evaluation (QA accuracy, Needle-in-Haystack)
3. Component ablation study
4. Lookahead prediction accuracy
5. Confidence calibration analysis
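
The paired comparison in experiment 1 can be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's evaluation code: the perplexity values are made up, `paired_t_summary` is a hypothetical helper, and 2.776 is the two-sided 95% critical value of Student's t for 4 degrees of freedom (5 paired seeds).

```python
import math
import statistics

# Two-sided 95% critical value of Student's t with 4 degrees of freedom (5 paired seeds).
T_CRIT_DF4 = 2.776


def paired_t_summary(method_a, method_b, t_crit=T_CRIT_DF4):
    """Paired t statistic and 95% CI for the mean per-seed difference (a - b)."""
    if len(method_a) != len(method_b):
        raise ValueError("paired samples must have equal length")
    diffs = [a - b for a, b in zip(method_a, method_b)]
    mean_diff = statistics.mean(diffs)
    # Standard error of the mean difference (sample standard deviation / sqrt(n)).
    sem = statistics.stdev(diffs) / math.sqrt(len(diffs))
    t_stat = mean_diff / sem
    ci95 = (mean_diff - t_crit * sem, mean_diff + t_crit * sem)
    return {
        "mean_diff": mean_diff,
        "t": t_stat,
        "ci95": ci95,
        "significant": abs(t_stat) > t_crit,
    }


# Illustrative (made-up) perplexities for two methods evaluated on the same 5 seeds.
ppl_sip = [12.4, 12.9, 12.1, 12.7, 12.5]
ppl_random = [12.5, 12.7, 12.3, 12.6, 12.6]
summary = paired_t_summary(ppl_sip, ppl_random)
print(summary)
```

Pairing by seed removes seed-to-seed variance from the comparison, which matters when the gap between methods is small. With SciPy available, `scipy.stats.ttest_rel(ppl_sip, ppl_random)` returns the same t statistic together with a p-value.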

**Hardware:** A100 GPU recommended  
**Runtime:** ~2-4 hours for full evaluation

In [None]:
!pip install -q torch transformers datasets accelerate
!pip install -q scipy scikit-learn matplotlib seaborn pandas tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import deque
from tqdm.auto import tqdm
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math
import json
import os
import warnings
warnings.filterwarnings('ignore')

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 1. Configuration

In [None]:
@dataclass
class SIPConfig:
    hidden_dim: int = 2048
    head_dim: int = 64
    num_heads: int = 32
    num_kv_heads: int = 4
    lookahead_steps: int = 8
    speculation_horizon: int = 16
    encoder_hidden_dim: int = 256
    encoder_layers: int = 2
    predictor_hidden_dim: int = 128
    num_predictor_heads: int = 4
    dropout: float = 0.1
    use_confidence_weighting: bool = True
    confidence_temperature: float = 1.0
    discount_factor: float = 0.95


TRAINING_CONFIG = {
    "num_train_samples": 5000,
    "num_val_samples": 500,
    "batch_size": 4,
    "epochs": 10,
    "lr": 1e-4,
    "weight_decay": 0.01,
    "warmup_steps": 500,
    "gradient_clip": 1.0,
    "accumulation_steps": 4,
    "rollout_length": 32,
    "initial_lookahead": 1,
    "max_lookahead": 8,
    "curriculum_warmup": 5,
    "kl_weight": 1.0,
    "ranking_weight": 0.5,
    "confidence_weight": 0.3,
    "temporal_weight": 0.1,
}

## 2. Improved Loss Function

In [None]:
class ImprovedSIPLoss(nn.Module):
    def __init__(
        self,
        kl_weight: float = 1.0,
        ranking_weight: float = 0.5,
        confidence_weight: float = 0.3,
        temporal_weight: float = 0.1,
        focal_gamma: float = 2.0,
        temperature: float = 10.0,
    ):
        super().__init__()
        self.kl_weight = kl_weight
        self.ranking_weight = ranking_weight
        self.confidence_weight = confidence_weight
        self.temporal_weight = temporal_weight
        self.focal_gamma = focal_gamma
        self.temperature = temperature

    def forward(
        self,
        predicted_importance: torch.Tensor,
        confidence: torch.Tensor,
        future_attention: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        losses = {}
        batch, heads, seq_len, lookahead = predicted_importance.shape
        device = predicted_importance.device

        rollout_steps = future_attention.shape[2]
        actual_lookahead = min(lookahead, rollout_steps)

        future_attention_truncated = future_attention[:, :, :actual_lookahead, :]
        target_soft_ranks = F.softmax(future_attention_truncated * self.temperature, dim=-1)
        target_soft_ranks = target_soft_ranks.permute(0, 1, 3, 2)

        predicted_truncated = predicted_importance[:, :, :, :actual_lookahead]
        confidence_truncated = confidence[:, :, :, :actual_lookahead]

        pred_log_probs = F.log_softmax(predicted_truncated * self.temperature, dim=2)
        kl_loss = F.kl_div(pred_log_probs, target_soft_ranks, reduction='batchmean')
        losses['kl'] = self.kl_weight * kl_loss

        pred_flat = predicted_truncated[:, :, :, 0].reshape(-1, seq_len)
        target_flat = target_soft_ranks[:, :, :, 0].reshape(-1, seq_len)

        k = min(32, seq_len // 4)
        if k > 1:
            _, top_idx = target_flat.topk(k, dim=-1)
            _, bot_idx = target_flat.topk(k, dim=-1, largest=False)
            pred_top = pred_flat.gather(1, top_idx).mean(dim=1)
            pred_bot = pred_flat.gather(1, bot_idx).mean(dim=1)
            margin = 0.2
            ranking_loss = F.relu(margin - (pred_top - pred_bot)).mean()
            losses['ranking'] = self.ranking_weight * ranking_loss
        else:
            losses['ranking'] = torch.tensor(0.0, device=device)

        pred_probs = F.softmax(predicted_truncated * self.temperature, dim=2)
        pred_error = (pred_probs - target_soft_ranks).abs().mean(dim=-1)
        target_confidence = (1.0 - pred_error).clamp(0, 1)

        conf_mean = confidence_truncated.mean(dim=-1)
        pt = conf_mean * target_confidence + (1 - conf_mean) * (1 - target_confidence)
        focal_weight = (1 - pt) ** self.focal_gamma
        conf_loss = (focal_weight * F.mse_loss(conf_mean, target_confidence.detach(), reduction='none')).mean()
        losses['confidence'] = self.confidence_weight * conf_loss

        if actual_lookahead > 1:
            temporal_diff = (predicted_truncated[:, :, :, 1:] - predicted_truncated[:, :, :, :-1]).abs()
            losses['temporal'] = self.temporal_weight * temporal_diff.mean()
        else:
            losses['temporal'] = torch.tensor(0.0, device=device)

        losses['total'] = sum(v for k, v in losses.items() if k != 'total')
        return losses

In [None]:
class ListMLELoss(nn.Module):
    def __init__(self, temperature: float = 1.0, eps: float = 1e-10):
        super().__init__()
        self.temperature = temperature
        self.eps = eps
    
    def forward(
        self,
        predicted_scores: torch.Tensor,
        target_scores: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if predicted_scores.dim() == 3:
            predicted_scores = predicted_scores.mean(dim=1)
            target_scores = target_scores.mean(dim=1)
        
        batch_size, seq_len = predicted_scores.shape
        device = predicted_scores.device
        
        _, sorted_indices = target_scores.sort(dim=-1, descending=True)
        sorted_preds = predicted_scores.gather(1, sorted_indices) / self.temperature
        
        if mask is not None:
            if mask.dim() == 1:
                mask = mask.unsqueeze(0).expand(batch_size, -1)
            sorted_mask = mask.gather(1, sorted_indices).bool()
            # Use a large finite negative rather than -inf (assumes float32 scores):
            # -inf would turn into NaN when multiplied by the zero mask below.
            sorted_preds = sorted_preds.masked_fill(~sorted_mask, -1e9)
        
        max_pred = sorted_preds.max(dim=-1, keepdim=True)[0]
        shifted = sorted_preds - max_pred
        
        # Suffix log-sum-exp: suffix_lse[:, i] = log(sum_{j >= i} exp(shifted[:, j]))
        suffix_lse = torch.logcumsumexp(shifted.flip(-1), dim=-1).flip(-1)
        
        log_probs = shifted - suffix_lse
        
        if mask is not None:
            loss = -(log_probs * sorted_mask.float()).sum(dim=-1) / (sorted_mask.float().sum(dim=-1) + self.eps)
        else:
            loss = -log_probs.mean(dim=-1)
        
        return loss.mean()


class PairwiseRankingLoss(nn.Module):
    def __init__(self, margin: float = 0.1, num_pairs: int = 256):
        super().__init__()
        self.margin = margin
        self.num_pairs = num_pairs
    
    def forward(
        self,
        predicted_scores: torch.Tensor,
        target_scores: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if predicted_scores.dim() == 3:
            predicted_scores = predicted_scores.mean(dim=1)
            target_scores = target_scores.mean(dim=1)
        
        batch_size, seq_len = predicted_scores.shape
        device = predicted_scores.device
        num_pairs = min(self.num_pairs, seq_len * (seq_len - 1) // 2)
        
        total_loss = 0.0
        valid_pairs = 0
        
        for b in range(batch_size):
            idx1 = torch.randint(0, seq_len, (num_pairs,), device=device)
            idx2 = torch.randint(0, seq_len, (num_pairs,), device=device)
            
            pred1 = predicted_scores[b, idx1]
            pred2 = predicted_scores[b, idx2]
            target1 = target_scores[b, idx1]
            target2 = target_scores[b, idx2]
            
            target_diff = target1 - target2
            sign = torch.sign(target_diff)
            pred_diff = pred1 - pred2
            pair_loss = F.relu(self.margin - sign * pred_diff)
            
            meaningful = target_diff.abs() > 0.01
            if meaningful.any():
                total_loss += pair_loss[meaningful].mean()
                valid_pairs += 1
        
        if valid_pairs == 0:
            return torch.tensor(0.0, device=device)
        
        return total_loss / valid_pairs


class EnhancedSIPLoss(nn.Module):
    def __init__(
        self,
        listmle_weight: float = 1.0,
        pairwise_weight: float = 0.5,
        topk_weight: float = 0.3,
        confidence_weight: float = 0.2,
        focal_gamma: float = 2.0,
        top_k_ratio: float = 0.25,
    ):
        super().__init__()
        self.listmle_weight = listmle_weight
        self.pairwise_weight = pairwise_weight
        self.topk_weight = topk_weight
        self.confidence_weight = confidence_weight
        self.focal_gamma = focal_gamma
        self.top_k_ratio = top_k_ratio
        
        self.listmle = ListMLELoss()
        self.pairwise = PairwiseRankingLoss()
    
    def forward(
        self,
        predicted_importance: torch.Tensor,
        confidence: torch.Tensor,
        future_attention: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        losses = {}
        batch, heads, seq_len, lookahead = predicted_importance.shape
        device = predicted_importance.device
        
        rollout_steps = future_attention.shape[2]
        actual_lookahead = min(lookahead, rollout_steps)
        
        pred = predicted_importance[:, :, :, 0]
        target = future_attention[:, :, 0, :]
        target = target / (target.sum(dim=-1, keepdim=True) + 1e-8)
        
        # Flatten heads and positions into a single ranking list per batch element
        pred_flat = pred.reshape(batch, -1)
        target_flat = target.reshape(batch, -1)
        losses['listmle'] = self.listmle_weight * self.listmle(pred_flat, target_flat)
        losses['pairwise'] = self.pairwise_weight * self.pairwise(pred_flat, target_flat)
        
        k = max(1, int(seq_len * self.top_k_ratio))
        _, pred_topk = pred_flat.topk(k * heads, dim=-1)
        _, target_topk = target_flat.topk(k * heads, dim=-1)
        
        pred_mask = torch.zeros(batch, heads * seq_len, device=device)
        target_mask = torch.zeros(batch, heads * seq_len, device=device)
        pred_mask.scatter_(1, pred_topk, 1.0)
        target_mask.scatter_(1, target_topk, 1.0)
        
        intersection = (pred_mask * target_mask).sum(dim=-1)
        recall = intersection / (target_mask.sum(dim=-1) + 1e-8)
        # NOTE: recall is built from hard top-k indices, so no gradient flows through
        # this term; it shifts the reported total but does not train the scorer.
        losses['topk_recall'] = self.topk_weight * (1.0 - recall.mean())
        
        conf = confidence[:, :, :, 0]
        pred_normalized = pred / (pred.max(dim=-1, keepdim=True)[0] + 1e-8)
        error = (pred_normalized - target).abs().mean(dim=1)
        accuracy = (1.0 - error).clamp(0, 1)
        
        conf_mean = conf.mean(dim=1)
        pt = conf_mean * accuracy + (1 - conf_mean) * (1 - accuracy)
        focal_weight = (1 - pt) ** self.focal_gamma
        conf_loss = (focal_weight * F.mse_loss(conf_mean, accuracy.detach(), reduction='none')).mean()
        losses['confidence'] = self.confidence_weight * conf_loss
        
        losses['total'] = sum(v for k, v in losses.items() if k != 'total')
        return losses
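
The ListMLE term above is the negative log-likelihood of the target ordering under a Plackett-Luce model: items are taken in descending target order, and each one contributes the log-probability of being picked first among the items not yet placed. The sketch below is a pure-Python illustration of that definition for a single list; `listmle_nll` is a hypothetical name and is not part of the notebook.

```python
import math


def listmle_nll(pred_scores, target_scores):
    """Mean Plackett-Luce negative log-likelihood of the target ordering."""
    # Arrange predicted scores in descending order of the target scores.
    order = sorted(range(len(target_scores)), key=lambda i: target_scores[i], reverse=True)
    s = [pred_scores[i] for i in order]

    # Suffix log-sum-exp, accumulated right to left starting from log(0) = -inf.
    suffix_lse = [0.0] * len(s)
    running = -math.inf
    for i in range(len(s) - 1, -1, -1):
        hi = max(running, s[i])
        running = hi + math.log(math.exp(running - hi) + math.exp(s[i] - hi))
        suffix_lse[i] = running

    # log P(item i is chosen first among items i..n-1) = s[i] - suffix_lse[i]
    return -sum(si - lse for si, lse in zip(s, suffix_lse)) / len(s)


target = [0.5, 0.3, 0.2]
loss_aligned = listmle_nll([2.0, 1.0, 0.0], target)   # predictions agree with the target order
loss_reversed = listmle_nll([0.0, 1.0, 2.0], target)  # predictions invert the target order
print(loss_aligned, loss_reversed)
```

The recursion has to start from `-inf` (an empty sum); starting it from `0.0` would silently add `exp(0) = 1` to every suffix sum. The last item always contributes zero, since it is the only remaining candidate.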

## 3. SIP Architecture

In [None]:
class TemporalPositionEncoding(nn.Module):
    def __init__(self, dim: int, max_len: int = 8192, max_lookahead: int = 64):
        super().__init__()
        self.dim = dim
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

        relative = torch.arange(-max_len, max_len).unsqueeze(1)
        rel_pe = torch.zeros(2 * max_len, dim)
        rel_pe[:, 0::2] = torch.sin(relative * div_term)
        rel_pe[:, 1::2] = torch.cos(relative * div_term)
        self.register_buffer('rel_pe', rel_pe)

        lookahead = torch.arange(max_lookahead).unsqueeze(1)
        look_pe = torch.zeros(max_lookahead, dim // 2)
        look_div = torch.exp(torch.arange(0, dim // 2, 2) * (-math.log(10000.0) / (dim // 2)))
        look_pe[:, 0::2] = torch.sin(lookahead * look_div)
        look_pe[:, 1::2] = torch.cos(lookahead * look_div)
        self.register_buffer('look_pe', look_pe)

    def forward(self, positions: torch.Tensor, current_step: int, lookahead_step: int = 0) -> torch.Tensor:
        batch, seq_len = positions.shape
        abs_enc = self.pe[positions.clamp(0, self.pe.shape[0] - 1)]
        rel_positions = current_step - positions + self.rel_pe.shape[0] // 2
        rel_positions = rel_positions.clamp(0, self.rel_pe.shape[0] - 1)
        rel_enc = self.rel_pe[rel_positions]
        look_enc = self.look_pe[min(lookahead_step, self.look_pe.shape[0] - 1)]
        look_enc = look_enc.unsqueeze(0).unsqueeze(0).expand(batch, seq_len, -1)
        combined = abs_enc + rel_enc
        combined[:, :, :look_enc.shape[-1]] = combined[:, :, :look_enc.shape[-1]] + look_enc
        return combined


class KeyValueEncoder(nn.Module):
    def __init__(self, config: SIPConfig):
        super().__init__()
        self.config = config
        self.key_encoder = nn.Sequential(
            nn.Linear(config.head_dim, config.encoder_hidden_dim),
            nn.LayerNorm(config.encoder_hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.encoder_hidden_dim, config.encoder_hidden_dim),
        )
        self.value_encoder = nn.Sequential(
            nn.Linear(config.head_dim, config.encoder_hidden_dim),
            nn.LayerNorm(config.encoder_hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.encoder_hidden_dim, config.encoder_hidden_dim),
        )
        self.kv_attention = nn.MultiheadAttention(
            embed_dim=config.encoder_hidden_dim,
            num_heads=4,
            dropout=config.dropout,
            batch_first=True,
        )
        self.output_proj = nn.Linear(config.encoder_hidden_dim * 2, config.encoder_hidden_dim)

    def forward(self, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        batch, heads, seq_len, head_dim = keys.shape
        # reshape (not view): cached keys/values are not guaranteed to be contiguous
        keys_flat = keys.reshape(batch * heads, seq_len, head_dim)
        values_flat = values.reshape(batch * heads, seq_len, head_dim)
        key_enc = self.key_encoder(keys_flat)
        value_enc = self.value_encoder(values_flat)
        kv_combined, _ = self.kv_attention(key_enc, value_enc, value_enc)
        combined = torch.cat([key_enc, kv_combined], dim=-1)
        output = self.output_proj(combined)
        return output.view(batch, heads, seq_len, -1)


class SpeculativePredictor(nn.Module):
    def __init__(self, config: SIPConfig):
        super().__init__()
        self.config = config
        self.temporal_encoding = TemporalPositionEncoding(
            config.encoder_hidden_dim, max_lookahead=config.speculation_horizon
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.encoder_hidden_dim,
            nhead=config.num_predictor_heads,
            dim_feedforward=config.predictor_hidden_dim * 4,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.encoder_layers)
        self.lookahead_queries = nn.Parameter(
            torch.randn(config.lookahead_steps, config.encoder_hidden_dim) * 0.02
        )
        self.importance_head = nn.Sequential(
            nn.LayerNorm(config.encoder_hidden_dim),
            nn.Linear(config.encoder_hidden_dim, config.predictor_hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.predictor_hidden_dim, 1),
        )
        self.confidence_head = nn.Sequential(
            nn.LayerNorm(config.encoder_hidden_dim),
            nn.Linear(config.encoder_hidden_dim, config.predictor_hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.predictor_hidden_dim, 1),
        )
        self.head_bias = nn.Parameter(torch.zeros(config.num_kv_heads))
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(
        self, kv_features: torch.Tensor, positions: torch.Tensor,
        current_step: int, lookahead_steps: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if lookahead_steps is None:
            lookahead_steps = self.config.lookahead_steps
        batch, heads, seq_len, hidden = kv_features.shape
        dtype = kv_features.dtype
        all_importance, all_confidence = [], []

        for step in range(lookahead_steps):
            temp_enc = self.temporal_encoding(positions, current_step, step)
            temp_enc = temp_enc.unsqueeze(1).expand(-1, heads, -1, -1)
            features = kv_features + temp_enc.to(dtype)
            query = self.lookahead_queries[min(step, len(self.lookahead_queries)-1)]
            query = query.view(1, 1, 1, -1).expand(batch, heads, seq_len, -1)
            features = features + query.to(dtype)
            features_flat = features.view(batch * heads, seq_len, hidden)
            transformed = self.transformer(features_flat.float()).to(dtype)
            importance = self.importance_head(transformed.float()).squeeze(-1)
            confidence = self.confidence_head(transformed.float()).squeeze(-1)
            importance = importance.view(batch, heads, seq_len)
            confidence = confidence.view(batch, heads, seq_len)
            importance = importance + self.head_bias.view(1, -1, 1).to(dtype)
            importance = torch.sigmoid(importance / self.temperature)
            confidence = torch.sigmoid(confidence)
            all_importance.append(importance)
            all_confidence.append(confidence)

        return torch.stack(all_importance, dim=-1), torch.stack(all_confidence, dim=-1)


class SpeculativeImportanceScorer(nn.Module):
    def __init__(self, config: SIPConfig):
        super().__init__()
        self.config = config
        self.kv_encoder = KeyValueEncoder(config)
        self.predictor = SpeculativePredictor(config)

    def forward(
        self, keys: torch.Tensor, values: torch.Tensor, positions: torch.Tensor,
        current_step: Optional[int] = None, return_all_lookahead: bool = False,
    ) -> torch.Tensor:
        batch, heads, seq_len, head_dim = keys.shape
        if current_step is None:
            current_step = seq_len
        kv_features = self.kv_encoder(keys.float(), values.float())
        importance_pred, confidence = self.predictor(kv_features, positions, current_step)
        if return_all_lookahead:
            return importance_pred
        weights = F.softmax(confidence / self.config.confidence_temperature, dim=-1)
        importance = (importance_pred * weights).sum(dim=-1)
        return importance
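
The final step of `SpeculativeImportanceScorer.forward` collapses the per-horizon predictions into one score per token by weighting each lookahead step with `softmax(confidence / T)`. A minimal pure-Python sketch of that aggregation for a single token follows; `aggregate_lookahead` is a hypothetical helper written for illustration only.

```python
import math


def aggregate_lookahead(importance, confidence, temperature=1.0):
    """Combine per-horizon importance scores using softmax(confidence / T) weights."""
    scaled = [c / temperature for c in confidence]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    weights = [e / total for e in exps]
    score = sum(w * s for w, s in zip(weights, importance))
    return score, weights


# One token, three lookahead horizons, equal confidence: reduces to a plain mean.
score_uniform, w_uniform = aggregate_lookahead([0.9, 0.5, 0.1], [0.8, 0.8, 0.8])

# Extreme confidences for sigmoid outputs (1.0 vs 0.0) at temperature 1.
score_skewed, w_skewed = aggregate_lookahead([0.9, 0.1], [1.0, 0.0])
print(score_uniform, w_uniform, score_skewed, w_skewed)
```

Because the confidences are sigmoid outputs in (0, 1), at temperature 1 the ratio between any two weights is bounded by e, so the aggregation stays close to a uniform average unless `confidence_temperature` is lowered.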

## 4. KV Cache Compression
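
The eviction policy in the cell below always keeps the attention sinks and a recent window, then spends the remaining budget on the highest-scoring middle tokens. The sketch here restates that policy for a single sequence in pure Python; `select_keep_indices` is a hypothetical helper, and its defaults mirror the `compression_config` values used in this notebook (4 sinks, 32 recent tokens, minimum cache size 64).

```python
def select_keep_indices(scores, retention_ratio, sink_size=4, recent_window=32, min_cache_size=64):
    """Indices to keep: attention sinks + recent window + top-scoring middle tokens."""
    seq_len = len(scores)
    # The budget is floored at min_cache_size and capped at the sequence length.
    budget = min(max(int(seq_len * retention_ratio), min_cache_size), seq_len)

    # Always-kept positions: the first `sink_size` and the last `recent_window` tokens.
    reserved = set(range(min(sink_size, seq_len)))
    reserved |= set(range(max(0, seq_len - recent_window), seq_len))

    # Spend whatever budget is left on the highest-scoring remaining positions.
    remaining = max(0, budget - len(reserved))
    candidates = [i for i in range(seq_len) if i not in reserved]
    candidates.sort(key=lambda i: scores[i], reverse=True)
    return sorted(reserved | set(candidates[:remaining]))


# 200 tokens with arbitrary deterministic scores.
scores = [((i * 37) % 101) / 101 for i in range(200)]
keep_half = select_keep_indices(scores, retention_ratio=0.5)   # 4 sinks + 32 recent + 64 scored
keep_tenth = select_keep_indices(scores, retention_ratio=0.1)  # floored at min_cache_size
print(len(keep_half), len(keep_tenth))
```

Two consequences are worth keeping in mind when reading the results. On short sequences the `min_cache_size` floor dominates: a nominal 10% retention of 200 tokens keeps 64 tokens, or 32%. And since 36 of the kept positions are fixed regardless of scorer, the scorers only differ on the remaining part of the budget.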

In [None]:
from transformers.cache_utils import DynamicCache

def get_cache_keys_values(cache, layer_idx: int = 0):
    if isinstance(cache, DynamicCache):
        try:
            return cache[layer_idx]
        except (IndexError, TypeError):
            pass
        if hasattr(cache, 'key_cache') and cache.key_cache:
            return cache.key_cache[layer_idx], cache.value_cache[layer_idx]
    elif isinstance(cache, (list, tuple)) and len(cache) > 0:
        return cache[layer_idx]
    raise ValueError(f"Unknown cache format: {type(cache)}")


def get_cache_length(cache) -> int:
    """Return the number of layers in the cache (0 for unrecognized types)."""
    if isinstance(cache, (DynamicCache, list, tuple)):
        return len(cache)
    return 0


@dataclass
class CompressionConfig:
    attention_sink_size: int = 4
    recent_window_size: int = 64
    min_cache_size: int = 128


class HardEvictionCache:
    def __init__(self, config: Optional[CompressionConfig] = None):
        self.config = config or CompressionConfig()

    def _get_keep_indices(self, importance, seq_len, retention_ratio, device, batch_size=1):
        budget = max(int(seq_len * retention_ratio), self.config.min_cache_size)
        budget = min(budget, seq_len)

        base_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
        sink_end = min(self.config.attention_sink_size, seq_len)
        base_mask[:sink_end] = True
        recent_start = max(0, seq_len - self.config.recent_window_size)
        base_mask[recent_start:] = True

        reserved = base_mask.sum().item()
        remaining_budget = max(0, budget - reserved)

        if importance.dim() == 1:
            keep_mask = base_mask.clone()
            if remaining_budget > 0:
                importance_masked = importance.clone()
                importance_masked[base_mask] = -float('inf')
                available = (~base_mask).sum().item()
                k = min(remaining_budget, available)
                if k > 0:
                    _, top_indices = importance_masked.topk(k)
                    keep_mask[top_indices] = True
            keep_indices = keep_mask.nonzero(as_tuple=True)[0]
        else:
            batch = importance.shape[0]
            keep_masks = base_mask.unsqueeze(0).expand(batch, -1).clone()
            
            middle_len = seq_len - self.config.attention_sink_size - self.config.recent_window_size
            k = min(remaining_budget, max(0, middle_len))
            
            if k > 0:
                for b in range(batch):
                    scores = importance[b].clone()
                    scores[base_mask] = -float('inf')
                    _, top_indices = scores.topk(k, dim=-1)
                    keep_masks[b, top_indices] = True
            
            union_mask = keep_masks.any(dim=0)
            keep_indices = union_mask.nonzero(as_tuple=True)[0]

        return keep_indices.sort()[0]

    def compress_cache(self, past_key_values, importance_scores, retention_ratio):
        num_layers = get_cache_length(past_key_values)
        if num_layers == 0:
            return past_key_values

        keys, values = get_cache_keys_values(past_key_values, 0)
        seq_len = keys.shape[2]
        batch_size = keys.shape[0]
        device = keys.device

        if importance_scores.dim() == 3:
            avg_importance = importance_scores.mean(dim=1)
        elif importance_scores.dim() == 2:
            avg_importance = importance_scores
        else:
            avg_importance = importance_scores.unsqueeze(0)
        avg_importance = avg_importance.to(device)

        keep_indices = self._get_keep_indices(avg_importance, seq_len, retention_ratio, device, batch_size)

        new_cache = DynamicCache()
        for layer_idx in range(num_layers):
            keys, values = get_cache_keys_values(past_key_values, layer_idx)
            new_keys = keys.index_select(2, keep_indices)
            new_values = values.index_select(2, keep_indices)
            new_cache.update(new_keys, new_values, layer_idx)

        return new_cache


class H2OImportanceScorer:
    def __init__(self, decay: float = 0.9):
        self.decay = decay
        self.accumulated_attention = None

    def reset(self):
        self.accumulated_attention = None

    def __call__(self, keys, values, attentions):
        if attentions.dim() == 4:
            current = attentions[:, :, -1, :]
        else:
            current = attentions
        if self.accumulated_attention is None:
            self.accumulated_attention = current
        else:
            if self.accumulated_attention.shape[-1] < current.shape[-1]:
                pad_len = current.shape[-1] - self.accumulated_attention.shape[-1]
                self.accumulated_attention = F.pad(self.accumulated_attention, (0, pad_len))
            self.accumulated_attention = self.decay * self.accumulated_attention + current
        return self.accumulated_attention


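class SnapKVImportanceScorer:
    # Sums the attention that the last `window_size` queries place on each KV position.
    # NOTE: `kernel_size` is stored but unused; no pooling is applied to the summed scores.
    def __init__(self, window_size=32, kernel_size=5):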
        self.window_size = window_size
        self.kernel_size = kernel_size

    def __call__(self, keys, values, attentions):
        if attentions.dim() == 4:
            batch, heads, query_len, kv_len = attentions.shape
            obs_start = max(0, query_len - self.window_size)
            obs_attn = attentions[:, :, obs_start:, :]
            summed_attn = obs_attn.sum(dim=2)
        else:
            summed_attn = attentions
        return summed_attn


class RecentImportanceScorer:
    def __call__(self, keys, values, attentions):
        seq_len = keys.shape[2]
        device = keys.device
        importance = torch.arange(seq_len, device=device, dtype=torch.float32) / seq_len
        batch, heads = keys.shape[0], keys.shape[1]
        return importance.unsqueeze(0).unsqueeze(0).expand(batch, heads, -1)


class RandomImportanceScorer:
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, _ = keys.shape
        return torch.rand(batch, heads, seq_len, device=keys.device)


class StreamingLLMScorer:
    def __init__(self, sink_size=4, window_size=252):
        self.sink_size = sink_size
        self.window_size = window_size

    def reset(self):
        pass

    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, _ = keys.shape
        device = keys.device
        importance = torch.zeros(batch, heads, seq_len, device=device)
        importance[:, :, :self.sink_size] = 1.0
        if seq_len > self.sink_size:
            window_start = max(self.sink_size, seq_len - self.window_size)
            importance[:, :, window_start:] = 1.0
        return importance


class ExpectedAttentionScorer:
    def __init__(self, query_history_size=32, temperature=1.0):
        self.query_history_size = query_history_size
        self.temperature = temperature
        self.query_mean = None
        self.query_cov_diag = None

    def reset(self):
        self.query_mean = None
        self.query_cov_diag = None

    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype

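        # Query states are not passed to scorers, so the most recent keys serve as a
        # proxy for the query distribution (mean and diagonal variance).
        recent_start = max(0, seq_len - self.query_history_size)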
        recent_keys = keys[:, :, recent_start:, :].float()
        
        self.query_mean = recent_keys.mean(dim=2)
        self.query_cov_diag = recent_keys.var(dim=2) + 1e-8

        keys_float = keys.float()
        mu_term = torch.einsum('bhd,bhsd->bhs', self.query_mean, keys_float) / (head_dim ** 0.5)
        quad_term = ((keys_float ** 2) * self.query_cov_diag.unsqueeze(2)).sum(dim=-1) / (2 * head_dim)
        
        log_importance = (mu_term + quad_term) / self.temperature
        importance = torch.softmax(log_importance, dim=-1)
        
        return importance.to(dtype)


class TRIMKVScorer:
    def __init__(self, decay_base=0.95):
        self.decay_base = decay_base

    def reset(self):
        pass

    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype

        key_norms = keys.float().norm(dim=-1)
        base_importance = key_norms / (key_norms.max(dim=-1, keepdim=True)[0] + 1e-8)

        positions = torch.arange(seq_len, device=device).float()
        relative_age = (seq_len - 1 - positions) / max(seq_len, 1)
        decay = self.decay_base ** relative_age
        decay = decay.unsqueeze(0).unsqueeze(0)

        importance = base_importance * decay
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)

        return importance.to(dtype)


class WriteGatedKVScorer:
    def __init__(self, attention_weight=0.3):
        self.attention_weight = attention_weight

    def reset(self):
        pass

    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype

        key_norms = keys.float().norm(dim=-1)
        value_norms = values.float().norm(dim=-1)
        
        kv_importance = (key_norms + value_norms) / 2
        kv_importance = kv_importance / (kv_importance.max(dim=-1, keepdim=True)[0] + 1e-8)

        if attentions is not None:
            if attentions.dim() == 4:
                attn = attentions[:, :, -1, :]
            else:
                attn = attentions
            
            if attn.shape[-1] >= seq_len:
                attn = attn[:, :, :seq_len]
            else:
                attn = F.pad(attn, (0, seq_len - attn.shape[-1]))
            
            attn_norm = attn / (attn.max(dim=-1, keepdim=True)[0] + 1e-8)
            importance = (1 - self.attention_weight) * kv_importance + self.attention_weight * attn_norm.to(kv_importance.dtype)
        else:
            importance = kv_importance

        return importance.to(dtype)


def create_sip_importance_scorer(sip_model):
    class SIPWrapper:
        def __init__(self, model):
            self.model = model
            self.model.eval()

        def reset(self):
            pass

        def __call__(self, keys, values, attentions):
            with torch.no_grad():
                positions = torch.arange(keys.shape[2], device=keys.device).unsqueeze(0)
                importance = self.model(keys, values, positions)
            return importance

    return SIPWrapper(sip_model)


SCORER_REGISTRY = {
    'h2o': H2OImportanceScorer,
    'snapkv': SnapKVImportanceScorer,
    'streamingllm': StreamingLLMScorer,
    'recent': RecentImportanceScorer,
    'random': RandomImportanceScorer,
    'expected_attention': ExpectedAttentionScorer,
    'trimkv': TRIMKVScorer,
    'write_gated': WriteGatedKVScorer,
}

def get_scorer(name, **kwargs):
    if name not in SCORER_REGISTRY:
        raise ValueError(f"Unknown scorer: {name}. Available: {list(SCORER_REGISTRY.keys())}")
    return SCORER_REGISTRY[name](**kwargs)


compression_config = CompressionConfig(
    attention_sink_size=4,
    recent_window_size=32,
    min_cache_size=64,
)
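
The `ExpectedAttentionScorer` above scores each key by a closed-form expected (unnormalized) attention weight. Assuming a query distributed as a Gaussian with mean $\mu$ and diagonal covariance $\operatorname{diag}(\sigma^2)$, the Gaussian moment-generating function gives, for a key $k$ with head dimension $d$:

$$
\mathbb{E}_{q \sim \mathcal{N}(\mu,\, \operatorname{diag}(\sigma^2))}\left[\exp\!\left(\frac{q^\top k}{\sqrt{d}}\right)\right]
= \exp\!\left(\frac{\mu^\top k}{\sqrt{d}} + \frac{1}{2d}\sum_{j=1}^{d} \sigma_j^2 k_j^2\right)
$$

The two terms in the exponent correspond to `mu_term` and `quad_term` in the code, which then divides by the temperature and applies a softmax over positions. Estimating $\mu$ and $\sigma^2$ from recent keys rather than from queries is an approximation made by this implementation, not part of the identity.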

In [None]:
class PyramidKVScorer:
    def __init__(self, num_layers=22, base_ratio=0.5):
        self.num_layers = num_layers
        self.base_ratio = base_ratio
        self.layer_ratios = self._compute_pyramid_ratios()
    
    def _compute_pyramid_ratios(self):
        ratios = []
        for i in range(self.num_layers):
            layer_ratio = 0.3 + 0.4 * (i / max(1, self.num_layers - 1))
            ratios.append(layer_ratio * self.base_ratio)
        return ratios
    
    def reset(self):
        pass
    
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype
        
        if attentions is not None:
            if attentions.dim() == 4:
                importance = attentions[:, :, -1, :seq_len]
            else:
                importance = attentions[:, :, :seq_len]
        else:
            importance = torch.ones(batch, heads, seq_len, device=device)
            recency = torch.linspace(0.5, 1.0, seq_len, device=device)
            importance = importance * recency.unsqueeze(0).unsqueeze(0)
        
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance.to(dtype)


class AdaKVScorer:
    def __init__(self, min_ratio=0.25, max_ratio=0.9):
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
    
    def reset(self):
        pass
    
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype
        
        if attentions is not None:
            if attentions.dim() == 4:
                attn = attentions[:, :, -1, :seq_len]
            else:
                attn = attentions[:, :, :seq_len]
            
            attn_probs = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)
            entropy = -(attn_probs * (attn_probs + 1e-8).log()).sum(dim=-1)
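            # Guard seq_len == 1, where log(1) = 0 would cause a division by zero.
            max_entropy = max(math.log(seq_len), 1e-8)
            normalized_entropy = entropy / max_entropy
            
            difficulty = normalized_entropy.mean()
            adaptive_weight = self.min_ratio + (self.max_ratio - self.min_ratio) * difficulty
            
            # NOTE: adaptive_weight is a single scalar (positive with the default
            # ratios), so the max-normalization below cancels it and the final
            # ranking is that of the raw last-query attention.
            importance = attn * adaptive_weight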
        else:
            importance = torch.ones(batch, heads, seq_len, device=device)
        
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance.to(dtype)


class DMCScorer:
    def __init__(self, num_heads=32, num_kv_heads=4):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_weights = torch.ones(num_kv_heads)
    
    def reset(self):
        pass
    
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype
        
        key_norms = keys.float().norm(dim=-1)
        value_norms = values.float().norm(dim=-1)
        
        head_weights = self.head_weights.to(device).unsqueeze(0).unsqueeze(-1)
        importance = (key_norms + value_norms) / 2 * head_weights
        
        if attentions is not None:
            if attentions.dim() == 4:
                attn = attentions[:, :, -1, :seq_len]
            else:
                attn = attentions[:, :, :seq_len]
            importance = 0.5 * importance + 0.5 * attn.to(importance.dtype)
        
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance.to(dtype)


class KIVIScorer:
    def __init__(self):
        pass
    
    def reset(self):
        pass
    
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype
        
        keys_float = keys.float()
        values_float = values.float()
        
        key_var = keys_float.var(dim=-1)
        value_var = values_float.var(dim=-1)
        
        key_mag = keys_float.norm(dim=-1)
        value_mag = values_float.norm(dim=-1)
        
        importance = (key_var + value_var) * (key_mag + value_mag)
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        
        return importance.to(dtype)


class TOVAScorer:
    def __init__(self, window_size=64):
        self.window_size = window_size
        self.attention_history = None
    
    def reset(self):
        self.attention_history = None
    
    def __call__(self, keys, values, attentions):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        dtype = keys.dtype
        
        if attentions is not None:
            if attentions.dim() == 4:
                query_len = attentions.shape[2]
                window_start = max(0, query_len - self.window_size)
                attn_window = attentions[:, :, window_start:, :seq_len]
                importance = attn_window.sum(dim=2)
            else:
                importance = attentions[:, :, :seq_len]
            
            if self.attention_history is None:
                self.attention_history = importance
            else:
                if self.attention_history.shape[-1] < seq_len:
                    pad = seq_len - self.attention_history.shape[-1]
                    self.attention_history = F.pad(self.attention_history, (0, pad))
                self.attention_history = 0.9 * self.attention_history + 0.1 * importance
                importance = self.attention_history
        else:
            importance = torch.ones(batch, heads, seq_len, device=device)
        
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance.to(dtype)


SCORER_REGISTRY.update({
    'pyramidkv': PyramidKVScorer,
    'adakv': AdaKVScorer,
    'dmc': DMCScorer,
    'kivi': KIVIScorer,
    'tova': TOVAScorer,
})
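
The entropy-based weighting in `AdaKVScorer` can be illustrated without tensors. The sketch below is a plain-Python restatement, not the scorer itself: `normalized_entropy` and `adaptive_weight` are hypothetical helper names, and the two example distributions are made up. It shows that diffuse attention maps to a weight near `max_ratio` and peaked attention to a weight near `min_ratio`.

```python
import math

def normalized_entropy(weights):
    """Shannon entropy of a non-negative weight vector, scaled to [0, 1]."""
    total = sum(weights)
    probs = [w / total for w in weights]
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    max_entropy = math.log(len(probs))  # entropy of the uniform distribution
    return entropy / max_entropy if max_entropy > 0 else 0.0

def adaptive_weight(weights, min_ratio=0.25, max_ratio=0.9):
    """Interpolate between min_ratio and max_ratio by normalized entropy."""
    return min_ratio + (max_ratio - min_ratio) * normalized_entropy(weights)

uniform = [0.25, 0.25, 0.25, 0.25]  # diffuse attention (illustrative)
peaked = [0.97, 0.01, 0.01, 0.01]   # attention concentrated on one token

print(round(adaptive_weight(uniform), 3))
print(round(adaptive_weight(peaked), 3))
```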

## 5. Dataset with GQA Support

In [None]:
class RolloutDataset(Dataset):
    def __init__(self, texts, model, tokenizer, config, max_seq_length=512, rollout_length=32, device='cuda', cache_dir=None):
        self.config = config
        self.device = device
        self.cache_dir = cache_dir
        
        self.num_attention_heads = model.config.num_attention_heads
        self.num_kv_heads = getattr(model.config, 'num_key_value_heads', model.config.num_attention_heads)
        self.head_group_size = self.num_attention_heads // self.num_kv_heads

        self.data = []
        
        if cache_dir and os.path.exists(os.path.join(cache_dir, 'rollout_data.pt')):
            self.data = torch.load(os.path.join(cache_dir, 'rollout_data.pt'))
            print(f"Loaded {len(self.data)} cached samples")
            return
        
        for text in tqdm(texts, desc="Generating rollouts"):
            tokens = tokenizer(text, return_tensors='pt', max_length=max_seq_length, truncation=True)
            if tokens.input_ids.shape[1] < 128:
                continue
            
            item = self._generate_rollout(tokens.input_ids.squeeze(0), model, rollout_length)
            if item is not None:
                self.data.append(item)
        
        print(f"Dataset: {len(self.data)} sequences")
        
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
            torch.save(self.data, os.path.join(cache_dir, 'rollout_data.pt'))

    def __len__(self):
        return len(self.data)

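    def _group_attention_heads(self, attn):
        """Average query-head attention within each KV-head group (GQA).

        Accepts (heads, seq) or (batch, heads, seq) and returns the same layout
        with the head dimension reduced to num_kv_heads.
        """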
        if self.head_group_size == 1:
            return attn
        if attn.dim() == 2:
            h, s = attn.shape
            return attn.view(self.num_kv_heads, self.head_group_size, s).mean(dim=1)
        else:
            b, h, s = attn.shape
            return attn.view(b, self.num_kv_heads, self.head_group_size, s).mean(dim=2)

    @torch.no_grad()
    def _generate_rollout(self, input_ids, model, rollout_length):
        try:
            input_ids = input_ids.unsqueeze(0).to(self.device)
            outputs = model(input_ids, use_cache=True, output_attentions=True, return_dict=True)
            past_kv = outputs.past_key_values
            initial_seq_len = input_ids.shape[1]

            all_future_attention = []
            for step in range(min(rollout_length, self.config.lookahead_steps * 2)):
                next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
                outputs = model(next_token, past_key_values=past_kv, use_cache=True, output_attentions=True, return_dict=True)
                past_kv = outputs.past_key_values
                layer_attns = [a.squeeze(2) for a in outputs.attentions]
                avg_attn = torch.stack(layer_attns).mean(dim=0).squeeze(0)
                grouped = self._group_attention_heads(avg_attn)
                all_future_attention.append(grouped)

            if all_future_attention:
                max_seq = max(a.shape[-1] for a in all_future_attention)
                padded = [F.pad(a, (0, max_seq - a.shape[-1])) for a in all_future_attention]
                future_attention = torch.stack(padded, dim=1)[:, :, :initial_seq_len]
            else:
                future_attention = torch.zeros(self.num_kv_heads, 1, initial_seq_len, device=self.device)

            keys = past_kv[0][0][:, :, :initial_seq_len, :].squeeze(0)
            values = past_kv[0][1][:, :, :initial_seq_len, :].squeeze(0)
            positions = torch.arange(initial_seq_len, device=self.device)

            return {
                'keys': keys.cpu(),
                'values': values.cpu(),
                'positions': positions.cpu(),
                'future_attention': future_attention.cpu(),
            }
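        except Exception as e:
            # Skip sequences that fail (e.g. out-of-memory), but report why
            # instead of dropping them silently.
            print(f"Skipping rollout: {type(e).__name__}: {e}")
            return None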

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    max_seq = max(item['keys'].shape[1] for item in batch)
    max_rollout = max(item['future_attention'].shape[1] for item in batch)
    keys_list, values_list, positions_list, future_attention_list, masks = [], [], [], [], []
    for item in batch:
        seq_len = item['keys'].shape[1]
        rollout_len = item['future_attention'].shape[1]
        keys_list.append(F.pad(item['keys'], (0, 0, 0, max_seq - seq_len)))
        values_list.append(F.pad(item['values'], (0, 0, 0, max_seq - seq_len)))
        positions_list.append(F.pad(item['positions'], (0, max_seq - seq_len)))
        future_attention_list.append(F.pad(item['future_attention'], (0, max_seq - seq_len, 0, max_rollout - rollout_len)))
        mask = torch.zeros(max_seq)
        mask[:seq_len] = 1.0
        masks.append(mask)
    return {
        'keys': torch.stack(keys_list),
        'values': torch.stack(values_list),
        'positions': torch.stack(positions_list),
        'future_attention': torch.stack(future_attention_list),
        'mask': torch.stack(masks),
    }
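
The head grouping used for GQA can be sketched in plain Python. This is an illustrative restatement of the `view(...).mean(...)` step in `_group_attention_heads`, assuming (as that reshape does) that query heads sharing a KV head are contiguous. The function name and the toy scores are hypothetical.

```python
def group_attention_heads(attn, num_kv_heads):
    """Average consecutive query heads that share one KV head.

    attn: list of per-query-head score rows, each of length seq_len.
    Returns num_kv_heads rows of length seq_len.
    """
    num_heads = len(attn)
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    grouped = []
    for g in range(num_kv_heads):
        rows = attn[g * group_size:(g + 1) * group_size]
        # Column-wise mean over the query heads in this group.
        grouped.append([sum(col) / group_size for col in zip(*rows)])
    return grouped

# Four query heads, two KV heads, three cached positions (toy values).
attn = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0],
]
grouped = group_attention_heads(attn, num_kv_heads=2)
print(grouped)
```

For TinyLlama the same reduction maps 32 query heads onto 4 KV heads, so each target row is a mean over 8 query heads.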

## 6. Curriculum Learning Scheduler

In [None]:
class CurriculumScheduler:
    def __init__(self, initial_lookahead=1, max_lookahead=8, warmup_epochs=3):
        self.initial = initial_lookahead
        self.max = max_lookahead
        self.warmup = warmup_epochs

    def get_lookahead(self, epoch):
        if epoch < self.warmup:
            return self.initial
        progress = min(1.0, (epoch - self.warmup) / max(1, self.warmup * 2))
        return min(self.max, self.initial + int(progress * (self.max - self.initial)))


class TemperatureScaling(nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, logits):
        return logits / self.temperature

    # Not decorated with torch.no_grad(): the closure below calls loss.backward().
    def fit(self, confidences, accuracies, num_iters=50):
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=num_iters)
        def closure():
            optimizer.zero_grad()
            scaled = torch.sigmoid(confidences / self.temperature)
            loss = F.binary_cross_entropy(scaled, accuracies)
            loss.backward()
            return loss
        optimizer.step(closure)
        return self.temperature.item()
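
To see what the curriculum produces, the schedule can be tabulated with a standalone copy of the `get_lookahead` logic. The sketch uses the class defaults (`initial_lookahead=1`, `max_lookahead=8`, `warmup_epochs=3`); the values in `TRAINING_CONFIG` may differ, and the 12-epoch horizon is only for illustration.

```python
def get_lookahead(epoch, initial=1, maximum=8, warmup=3):
    """Standalone copy of CurriculumScheduler.get_lookahead."""
    if epoch < warmup:
        return initial
    # Ramp linearly over 2 * warmup epochs after the warmup phase.
    progress = min(1.0, (epoch - warmup) / max(1, warmup * 2))
    return min(maximum, initial + int(progress * (maximum - initial)))

schedule = [get_lookahead(epoch) for epoch in range(12)]
print(schedule)  # [1, 1, 1, 1, 2, 3, 4, 5, 6, 8, 8, 8]
```

Because `int()` truncates, the ramp is not perfectly even: with these defaults the lookahead jumps from 6 to 8, and the first post-warmup epoch still uses the initial lookahead.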

## 7. Load Model and Data

In [None]:
print("Loading TinyLlama...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",
)
model.eval()
tokenizer.pad_token = tokenizer.eos_token

In [None]:
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
train_texts = [t for t in dataset["train"]["text"] if len(t.strip()) > 200][:TRAINING_CONFIG["num_train_samples"]]
val_texts = [t for t in dataset["validation"]["text"] if len(t.strip()) > 200][:TRAINING_CONFIG["num_val_samples"]]
test_texts = [t for t in dataset["test"]["text"] if len(t.strip()) > 200][:200]
print(f"Train: {len(train_texts)}, Val: {len(val_texts)}, Test: {len(test_texts)}")

In [None]:
sip_config = SIPConfig(
    hidden_dim=model.config.hidden_size,
    num_heads=model.config.num_attention_heads,
    num_kv_heads=getattr(model.config, 'num_key_value_heads', model.config.num_attention_heads),
    head_dim=model.config.hidden_size // model.config.num_attention_heads,
    lookahead_steps=TRAINING_CONFIG["max_lookahead"],
)

sip_scorer = SpeculativeImportanceScorer(sip_config).to(device)
print(f"SIP Parameters: {sum(p.numel() for p in sip_scorer.parameters()):,}")

In [None]:
train_dataset = RolloutDataset(
    train_texts, model, tokenizer, sip_config, 
    device=device, 
    cache_dir='./cache/train'
)
val_dataset = RolloutDataset(
    val_texts, model, tokenizer, sip_config, 
    device=device,
    cache_dir='./cache/val'
)

train_loader = DataLoader(train_dataset, batch_size=TRAINING_CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=TRAINING_CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## 8. Training with Curriculum Learning

In [None]:
loss_fn = ImprovedSIPLoss(
    kl_weight=TRAINING_CONFIG["kl_weight"],
    ranking_weight=TRAINING_CONFIG["ranking_weight"],
    confidence_weight=TRAINING_CONFIG["confidence_weight"],
    temporal_weight=TRAINING_CONFIG["temporal_weight"],
)

optimizer = AdamW(sip_scorer.parameters(), lr=TRAINING_CONFIG["lr"], weight_decay=TRAINING_CONFIG["weight_decay"])
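# The scheduler is stepped once per optimizer update (every accumulation_steps
# batches), so count optimizer updates rather than batches.
steps_per_epoch = len(train_loader) // TRAINING_CONFIG["accumulation_steps"]
total_steps = max(1, steps_per_epoch * TRAINING_CONFIG["epochs"])
warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=TRAINING_CONFIG["warmup_steps"])
main_scheduler = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - TRAINING_CONFIG["warmup_steps"]))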
scheduler = SequentialLR(optimizer, [warmup_scheduler, main_scheduler], milestones=[TRAINING_CONFIG["warmup_steps"]])

curriculum = CurriculumScheduler(
    initial_lookahead=TRAINING_CONFIG["initial_lookahead"],
    max_lookahead=TRAINING_CONFIG["max_lookahead"],
    warmup_epochs=TRAINING_CONFIG["curriculum_warmup"],
)
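
The warmup-then-cosine schedule built above has a simple closed form, sketched here in plain Python so the shape can be checked without an optimizer. `lr_at` is a hypothetical helper, the numbers (`base_lr=1e-4`, 100 warmup steps, 1000 total steps) are made up, and steps are counted in optimizer updates. Unlike `CosineAnnealingLR`, the sketch clamps after the last step instead of continuing the cosine.

```python
import math

def lr_at(step, base_lr, warmup_steps, total_steps, start_factor=0.1):
    """Learning rate after `step` optimizer updates (closed form)."""
    if step < warmup_steps:
        # Linear warmup from start_factor * base_lr up to base_lr.
        factor = start_factor + (1.0 - start_factor) * step / warmup_steps
        return base_lr * factor
    # Cosine annealing from base_lr down to 0 over the remaining steps.
    t_max = max(1, total_steps - warmup_steps)
    t = min(step - warmup_steps, t_max)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t / t_max))

for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at(step, 1e-4, 100, 1000):.2e}")
```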

In [None]:
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'val_corr': [], 'lookahead': []}
accumulation_steps = TRAINING_CONFIG["accumulation_steps"]

for epoch in range(TRAINING_CONFIG["epochs"]):
    current_lookahead = curriculum.get_lookahead(epoch)
    history['lookahead'].append(current_lookahead)
    
    print(f"Epoch {epoch+1}/{TRAINING_CONFIG['epochs']} | Lookahead: {current_lookahead}")

    sip_scorer.train()
    train_losses = []
    optimizer.zero_grad()
    
    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        keys = batch['keys'].to(device).float()
        values = batch['values'].to(device).float()
        positions = batch['positions'].to(device)
        future_attention = batch['future_attention'].to(device).float()

        kv_features = sip_scorer.kv_encoder(keys, values)
        pred, conf = sip_scorer.predictor(kv_features, positions, keys.shape[2], current_lookahead)
        
        losses = loss_fn(pred.float(), conf.float(), future_attention[:, :, :current_lookahead, :])
        loss = losses['total'] / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(sip_scorer.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        train_losses.append(losses['total'].item())
        pbar.set_postfix({'loss': f"{losses['total'].item():.4f}"})

    sip_scorer.eval()
    val_losses, val_corrs = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            keys = batch['keys'].to(device).float()
            values = batch['values'].to(device).float()
            positions = batch['positions'].to(device)
            future_attention = batch['future_attention'].to(device).float()

            kv_features = sip_scorer.kv_encoder(keys, values)
            pred, conf = sip_scorer.predictor(kv_features, positions, keys.shape[2], current_lookahead)
            losses = loss_fn(pred.float(), conf.float(), future_attention[:, :, :current_lookahead, :])
            val_losses.append(losses['total'].item())

            p = pred[:, :, :, 0].flatten()
            t = future_attention[:, :, 0, :]
            t = (t / (t.sum(dim=-1, keepdim=True) + 1e-8)).flatten()
            if p.numel() > 1:
                c = torch.corrcoef(torch.stack([p, t]))[0, 1]
                if not torch.isnan(c):
                    val_corrs.append(c.item())

    avg_train = np.mean(train_losses)
    avg_val = np.mean(val_losses)
    avg_corr = np.mean(val_corrs) if val_corrs else 0
    
    history['train_loss'].append(avg_train)
    history['val_loss'].append(avg_val)
    history['val_corr'].append(avg_corr)

    print(f"Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f} | Correlation: {avg_corr:.4f}")

    if avg_val < best_val_loss:
        best_val_loss = avg_val
        torch.save(sip_scorer.state_dict(), 'best_sip_improved.pt')
        print("Saved best model")

print(f"Training complete | Best Val Loss: {best_val_loss:.4f}")
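
The training loop calls `optimizer.step()` only when `(batch_idx + 1) % accumulation_steps == 0` and zeroes gradients at the start of each epoch, so trailing batches that do not fill an accumulation window contribute no update. A small sketch of the resulting step count, with a hypothetical helper name and made-up numbers:

```python
def count_optimizer_steps(num_batches, accumulation_steps, epochs):
    """Return (total optimizer updates, batches per epoch whose gradients are discarded)."""
    steps_per_epoch = num_batches // accumulation_steps
    leftover_batches = num_batches % accumulation_steps
    return steps_per_epoch * epochs, leftover_batches

total_updates, leftover = count_optimizer_steps(
    num_batches=10, accumulation_steps=4, epochs=3
)
print(total_updates, leftover)  # 6 2
```

The first number is also how often `scheduler.step()` runs, which is the quantity the cosine schedule's `T_max` should be measured in.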

## 9. Evaluation WITH Actual Compression

In [None]:
sip_scorer.load_state_dict(torch.load('best_sip_improved.pt', weights_only=True))
sip_scorer.eval()

In [None]:
class PrefillAttentionScorer:
    def __init__(self):
        self.prefill_attention = None
    
    def reset(self):
        self.prefill_attention = None
    
    def __call__(self, keys, values, prefill_attention=None):
        if prefill_attention is not None:
            if prefill_attention.dim() == 4:
                self.prefill_attention = prefill_attention.mean(dim=2)
            else:
                self.prefill_attention = prefill_attention
        return self.prefill_attention


class PositionHeuristicScorer:
    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, _ = keys.shape
        importance = torch.ones(batch, heads, seq_len, device=keys.device)
        importance[:, :, :4] = 10.0
        importance[:, :, -32:] = 5.0
        return importance


class SIPPreGenerationScorer:
    def __init__(self, sip_model):
        self.model = sip_model
        self.model.eval()
    
    def reset(self):
        pass
    
    def __call__(self, keys, values, attention=None):
        with torch.no_grad():
            positions = torch.arange(keys.shape[2], device=keys.device).unsqueeze(0)
            return self.model(keys.float(), values.float(), positions)


class ExpectedAttentionPreGenScorer:
    def __init__(self):
        self.scorer = ExpectedAttentionScorer()
    
    def reset(self):
        self.scorer.reset()
    
    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, _ = keys.shape
        dummy_attn = torch.ones(batch, heads, seq_len, device=keys.device) / seq_len
        return self.scorer(keys, values, dummy_attn)


class TRIMKVPreGenScorer:
    def __init__(self):
        self.scorer = TRIMKVScorer()
    
    def reset(self):
        self.scorer.reset()
    
    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, _ = keys.shape
        dummy_attn = torch.ones(batch, heads, seq_len, device=keys.device) / seq_len
        return self.scorer(keys, values, dummy_attn)


class WriteGatedPreGenScorer:
    def __init__(self):
        self.scorer = WriteGatedKVScorer()
    
    def reset(self):
        self.scorer.reset()
    
    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, _ = keys.shape
        dummy_attn = torch.ones(batch, heads, seq_len, device=keys.device) / seq_len
        return self.scorer(keys, values, dummy_attn)

@torch.no_grad()
def compute_perplexity_with_compression_fixed(
    model, tokenizer, texts, scorer, scorer_name,
    retention_ratio=0.5, min_cache=48, max_samples=50, device='cuda',
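):
    """Teacher-forced perplexity on the last quarter of each text after KV cache compression.

    The first three quarters of each text are prefilled. Importance scores come
    from layer-0 keys/values (plus head-grouped last-layer prefill attention for
    PrefillAttentionScorer), are averaged over batch and heads, and yield one set
    of kept positions that is applied to every layer. The first 4 positions
    (attention sinks) and the 16 most recent positions are always kept.
    """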
    model.eval()
    
    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_attention_heads(attn):
        if head_group_size == 1:
            return attn
        batch, heads, *rest = attn.shape
        return attn.view(batch, num_kv_heads, head_group_size, *rest).mean(dim=2)

    total_nll = 0.0
    total_tokens = 0

    for text in tqdm(texts[:max_samples], desc=f"{scorer_name} @ {retention_ratio:.0%}"):
        if hasattr(scorer, 'reset'):
            scorer.reset()

        tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        seq_len = tokens.shape[1]
        if seq_len < 128:
            continue

        split_point = seq_len * 3 // 4
        context_ids = tokens[:, :split_point]
        continuation_ids = tokens[:, split_point:]
        continuation_len = continuation_ids.shape[1]

        outputs = model(context_ids, use_cache=True, output_attentions=True, return_dict=True)
        past_kv = outputs.past_key_values
        
        keys, values = get_cache_keys_values(past_kv, 0)
        prefill_attention = outputs.attentions[-1]
        prefill_attention_grouped = group_attention_heads(prefill_attention)

        if isinstance(scorer, PrefillAttentionScorer):
            importance = scorer(keys, values, prefill_attention_grouped)
        else:
            importance = scorer(keys, values, None)

        if importance.dim() == 3:
            avg_importance = importance.mean(dim=(0, 1))
        elif importance.dim() == 2:
            avg_importance = importance.mean(dim=0)
        else:
            avg_importance = importance.flatten()
        
        budget = max(int(split_point * retention_ratio), min_cache)
        budget = min(budget, split_point)
        
        keep_mask = torch.zeros(split_point, dtype=torch.bool, device=device)
        sink_size = 4
        recent_size = min(16, split_point - sink_size)
        keep_mask[:sink_size] = True
        keep_mask[-recent_size:] = True
        
        reserved = keep_mask.sum().item()
        remaining_budget = max(0, budget - reserved)
        
        if remaining_budget > 0:
            importance_masked = avg_importance.clone()
            importance_masked[keep_mask] = -float('inf')
            k = min(remaining_budget, (~keep_mask).sum().item())
            if k > 0:
                _, top_indices = importance_masked.topk(k)
                keep_mask[top_indices] = True
        
        keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]
        
        compressed_kv = DynamicCache()
        for layer_idx in range(get_cache_length(past_kv)):
            layer_keys, layer_values = get_cache_keys_values(past_kv, layer_idx)
            compressed_kv.update(
                layer_keys.index_select(2, keep_indices),
                layer_values.index_select(2, keep_indices),
                layer_idx
            )

        current_kv = compressed_kv
        current_pos = split_point
        
        for i in range(continuation_len - 1):
            input_token = continuation_ids[:, i:i+1]
            target_token = continuation_ids[:, i+1]
            position_ids = torch.tensor([[current_pos]], device=device)
            
            out = model(
                input_token,
                past_key_values=current_kv,
                position_ids=position_ids,
                use_cache=True,
                return_dict=True,
            )
            
            logits = out.logits[:, -1, :]
            loss = F.cross_entropy(logits, target_token, reduction='sum')
            
            if not torch.isnan(loss):
                total_nll += loss.item()
                total_tokens += 1
            
            current_kv = out.past_key_values
            current_pos += 1

    if total_tokens == 0:
        return {'perplexity': float('inf'), 'total_tokens': 0}

    perplexity = np.exp(total_nll / total_tokens)
    return {'perplexity': perplexity, 'total_tokens': total_tokens}


pregeneration_scorers = {
    'SIP (Ours)': SIPPreGenerationScorer(sip_scorer),
    'Expected-Attn': ExpectedAttentionPreGenScorer(),
    'TRIM-KV': TRIMKVPreGenScorer(),                  
    'Write-Gated': WriteGatedPreGenScorer(),           
    'Prefill-Attn': PrefillAttentionScorer(),
    'Position-Heuristic': PositionHeuristicScorer(),
    'Recent-Only': RecentImportanceScorer(),
    'Random': RandomImportanceScorer(),
}

retention_ratios = [0.10, 0.25, 0.50, 0.75]
pregeneration_results = {}

print("Baseline (Full Cache)")
total_nll = 0
total_tokens = 0
for text in tqdm(test_texts[:50], desc="Baseline"):
    tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
    if tokens.shape[1] < 128:
        continue
    split_point = tokens.shape[1] * 3 // 4
    context = tokens[:, :split_point]
    continuation = tokens[:, split_point:]
    
    with torch.no_grad():
        outputs = model(context, use_cache=True, return_dict=True)
        current_kv = outputs.past_key_values
        current_pos = split_point
        
        for i in range(continuation.shape[1] - 1):
            input_token = continuation[:, i:i+1]
            target_token = continuation[:, i+1]
            position_ids = torch.tensor([[current_pos]], device=device)
            
            out = model(input_token, past_key_values=current_kv, position_ids=position_ids, 
                       use_cache=True, return_dict=True)
            logits = out.logits[:, -1, :]
            loss = F.cross_entropy(logits, target_token, reduction='sum')
            
            if not torch.isnan(loss):
                total_nll += loss.item()
                total_tokens += 1
            
            current_kv = out.past_key_values
            current_pos += 1

baseline_ppl = np.exp(total_nll / total_tokens)
pregeneration_results['Full Cache'] = {1.0: baseline_ppl}
print(f"Baseline Perplexity: {baseline_ppl:.2f}")

for scorer_name, scorer in pregeneration_scorers.items():
    pregeneration_results[scorer_name] = {}
    
    for ratio in retention_ratios:
        result = compute_perplexity_with_compression_fixed(
            model=model,
            tokenizer=tokenizer,
            texts=test_texts,
            scorer=scorer,
            scorer_name=scorer_name,
            retention_ratio=ratio,
            min_cache=48,
            max_samples=50,
            device=str(device),
        )
        
        pregeneration_results[scorer_name][ratio] = result['perplexity']
        print(f"{scorer_name} @ {ratio:.0%}: {result['perplexity']:.2f}")
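
The position-selection rule inside `compute_perplexity_with_compression_fixed` can be restated in plain Python. This is a sketch for reading the results, not the evaluation code: `select_keep_indices` and `effective_retention` are hypothetical names, and the toy importance values are made up. The defaults mirror the function above (`min_cache=48`, 4 sink positions, 16 recent positions).

```python
def select_keep_indices(importance, retention_ratio, min_cache=48,
                        sink_size=4, recent_size=16):
    """Keep sinks and recent positions, then fill the budget by importance."""
    n = len(importance)
    budget = min(max(int(n * retention_ratio), min_cache), n)
    keep = set(range(min(sink_size, n)))
    recent = max(0, min(recent_size, n - sink_size))
    keep.update(range(n - recent, n))
    remaining = max(0, budget - len(keep))
    # Rank the non-reserved positions by descending importance.
    candidates = [i for i in range(n) if i not in keep]
    candidates.sort(key=lambda i: importance[i], reverse=True)
    keep.update(candidates[:remaining])
    return sorted(keep)

def effective_retention(context_len, retention_ratio, min_cache=48):
    """Fraction of the context actually kept once min_cache is applied."""
    budget = min(max(int(context_len * retention_ratio), min_cache), context_len)
    return budget / context_len

toy = [0.1, 0.9, 0.2, 0.8, 0.3, 0.4, 0.5, 0.6, 0.7, 0.0]
print(select_keep_indices(toy, 0.5, min_cache=0, sink_size=1, recent_size=2))
print(effective_retention(384, 0.10), effective_retention(96, 0.10))
```

One consequence worth keeping in mind when reading the 10% column: texts are truncated to 512 tokens and split at three quarters, so the context has at most 384 positions and the `min_cache=48` floor makes the nominal 10% setting keep at least 12.5% of them (50% for the shortest accepted texts). Of those 48 slots, 20 are reserved for sinks and recent positions regardless of the scorer.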

In [None]:
print(f"Baseline (Full Cache): {baseline_ppl:.2f}")

print("\nPerplexity by Method")
header = f"{'Method':<20} |"
for ratio in retention_ratios:
    header += f" {int(ratio*100):>6}% |"
print(header)
print("-"*70)

method_order = ['SIP (Ours)', 'Expected-Attn', 'TRIM-KV', 'Write-Gated', 
                'Prefill-Attn', 'Position-Heuristic', 'Recent-Only', 'Random']

for method in method_order:
    if method in pregeneration_results:
        row = f"{method:<20} |"
        for ratio in retention_ratios:
            if ratio in pregeneration_results[method]:
                ppl = pregeneration_results[method][ratio]
                row += f" {ppl:>7.2f} |"
            else:
                row += f" {'--':>7} |"
        print(row)

print("\nDegradation from Full Cache (% increase in PPL)")
header = f"{'Method':<20} |"
for ratio in retention_ratios:
    header += f" {int(ratio*100):>6}% |"
print(header)
print("-"*70)

for method in method_order:
    if method in pregeneration_results:
        row = f"{method:<20} |"
        for ratio in retention_ratios:
            if ratio in pregeneration_results[method]:
                ppl = pregeneration_results[method][ratio]
                degradation = ((ppl - baseline_ppl) / baseline_ppl) * 100
                row += f" {degradation:>+6.1f}% |"
            else:
                row += f" {'--':>7} |"
        print(row)

print("\nBest Method at Each Retention Level")
for ratio in retention_ratios:
    best_method = None
    best_ppl = float('inf')
    for method in method_order:
        if method in pregeneration_results and ratio in pregeneration_results[method]:
            ppl = pregeneration_results[method][ratio]
            if ppl < best_ppl:
                best_ppl = ppl
                best_method = method
    if best_method:
        degradation = ((best_ppl - baseline_ppl) / baseline_ppl) * 100
        print(f"  {int(ratio*100)}%: {best_method} (PPL: {best_ppl:.2f}, {degradation:+.1f}% from baseline)")
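
For reference, the two quantities reported in the tables above reduce to a few lines. The sketch assumes per-token negative log-likelihoods in nats pooled over all evaluated texts (a token-weighted average, as in the evaluation loop). The helper names and the example numbers are hypothetical.

```python
import math

def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))

def degradation_pct(ppl, baseline_ppl):
    """Relative perplexity increase over the full-cache baseline, in percent."""
    return (ppl - baseline_ppl) / baseline_ppl * 100.0

# A model that assigns probability 1/2 to every target token has perplexity 2.
print(perplexity([math.log(2)] * 3))
print(degradation_pct(11.0, 10.0))
```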

In [None]:
def run_multiseed_evaluation(
    model, tokenizer, texts, scorers, retention_ratios, 
    num_seeds=5, max_samples=30, device='cuda',
):
    all_results = {name: {r: [] for r in retention_ratios} for name in scorers.keys()}

    for seed in range(num_seeds):
        print(f"Seed {seed + 1}/{num_seeds}")
        torch.manual_seed(42 + seed)
        np.random.seed(42 + seed)

        shuffled_texts = texts.copy()
        np.random.shuffle(shuffled_texts)

        for scorer_name, scorer in scorers.items():
            for ratio in retention_ratios:
                if hasattr(scorer, 'reset'):
                    scorer.reset()

                result = compute_perplexity_with_compression_fixed(
                    model=model,
                    tokenizer=tokenizer,
                    texts=shuffled_texts,
                    scorer=scorer,
                    scorer_name=f"{scorer_name} (seed {seed})",
                    retention_ratio=ratio,
                    min_cache=48,
                    max_samples=max_samples,
                    device=device,
                )

                all_results[scorer_name][ratio].append(result['perplexity'])

    stats_results = {}
    for scorer_name in scorers.keys():
        stats_results[scorer_name] = {}
        for ratio in retention_ratios:
            values = all_results[scorer_name][ratio]
            mean = np.mean(values)
            std = np.std(values, ddof=1)
            n = len(values)
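            # With only a handful of seeds, use the Student-t critical value
            # (about 2.776 for n = 5) rather than the normal approximation 1.96.
            t_crit = stats.t.ppf(0.975, df=n - 1) if n > 1 else float('nan')
            ci_95 = t_crit * std / np.sqrt(n)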

            stats_results[scorer_name][ratio] = {
                'mean': mean,
                'std': std,
                'ci_95': ci_95,
                'values': values,
            }

    return stats_results


def compute_significance(results, baseline_name, method_name, ratio):
    baseline_values = results[baseline_name][ratio]['values']
    method_values = results[method_name][ratio]['values']

    if len(baseline_values) < 2 or len(method_values) < 2:
        return None

    t_stat, p_value = stats.ttest_rel(baseline_values, method_values)
    return {'t_stat': t_stat, 'p_value': p_value, 'significant': p_value < 0.05}


multiseed_scorers = {
    'SIP (Ours)': SIPPreGenerationScorer(sip_scorer),
    'Prefill-Attn': PrefillAttentionScorer(),
    'Position-Heuristic': PositionHeuristicScorer(),
    'Expected-Attn': ExpectedAttentionPreGenScorer(),
    'TRIM-KV': TRIMKVPreGenScorer(),
    'Random': RandomImportanceScorer(),
}

multiseed_ratios = [0.10, 0.25, 0.50, 0.75]
NUM_SEEDS = 5

print(f"Running {NUM_SEEDS}-seed evaluation across {len(multiseed_scorers)} methods...")

multiseed_results = run_multiseed_evaluation(
    model=model,
    tokenizer=tokenizer,
    texts=test_texts,
    scorers=multiseed_scorers,
    retention_ratios=multiseed_ratios,
    num_seeds=NUM_SEEDS,
    max_samples=30,
    device=str(device),
)

print("\nMulti-Seed Results (Mean +/- 95% CI)")

header = f"{'Method':<20} |"
for ratio in multiseed_ratios:
    header += f" {int(ratio*100):>12}% |"
print(header)
print("-"*75)

method_order = ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic', 'Expected-Attn', 'TRIM-KV', 'Random']
for method in method_order:
    if method in multiseed_results:
        row = f"{method:<20} |"
        for ratio in multiseed_ratios:
            if ratio in multiseed_results[method]:
                stats_data = multiseed_results[method][ratio]
                # Pad each cell to the header's column width so the table stays aligned.
                cell = f"{stats_data['mean']:.2f}+/-{stats_data['ci_95']:.2f}"
                row += f" {cell:>13} |"
            else:
                row += f" {'--':>13} |"
        print(row)

print("\nStatistical Significance (vs Random baseline)")

for method in ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic', 'Expected-Attn', 'TRIM-KV']:
    if method in multiseed_results:
        row = f"{method:<20} |"
        for ratio in multiseed_ratios:
            sig = compute_significance(multiseed_results, 'Random', method, ratio)
            if sig:
                if sig['p_value'] < 0.001:
                    marker = "***"
                elif sig['p_value'] < 0.01:
                    marker = "**"
                elif sig['p_value'] < 0.05:
                    marker = "*"
                else:
                    marker = "ns"
                row += f" {marker:>12} |"
            else:
                row += f" {'--':>12} |"
        print(row)

print("\nSIP vs Prefill-Attn:")
for ratio in multiseed_ratios:
    sig = compute_significance(multiseed_results, 'Prefill-Attn', 'SIP (Ours)', ratio)
    if sig:
        sip_mean = multiseed_results['SIP (Ours)'][ratio]['mean']
        pfa_mean = multiseed_results['Prefill-Attn'][ratio]['mean']
        diff = sip_mean - pfa_mean
        sig_str = f"p={sig['p_value']:.4f}" + (" *" if sig['significant'] else " ns")
        print(f"  {int(ratio*100):>3}%: SIP {sip_mean:.2f} vs Prefill-Attn {pfa_mean:.2f} (d={diff:+.2f}) {sig_str}")

print("\nSIP vs Position-Heuristic:")
for ratio in multiseed_ratios:
    sig = compute_significance(multiseed_results, 'Position-Heuristic', 'SIP (Ours)', ratio)
    if sig:
        sip_mean = multiseed_results['SIP (Ours)'][ratio]['mean']
        pos_mean = multiseed_results['Position-Heuristic'][ratio]['mean']
        diff = sip_mean - pos_mean
        sig_str = f"p={sig['p_value']:.4f}" + (" *" if sig['significant'] else " ns")
        print(f"  {int(ratio*100):>3}%: SIP {sip_mean:.2f} vs Position-Heuristic {pos_mean:.2f} (d={diff:+.2f}) {sig_str}")

print("\nBest method at each retention level (by mean PPL):")
for ratio in multiseed_ratios:
    best_method = None
    best_mean = float('inf')
    for method in method_order:
        if method in multiseed_results and ratio in multiseed_results[method]:
            mean = multiseed_results[method][ratio]['mean']
            if mean < best_mean:
                best_mean = mean
                best_method = method
    if best_method:
        ci = multiseed_results[best_method][ratio]['ci_95']
        print(f"  {int(ratio*100):>3}%: {best_method} ({best_mean:.2f} +/- {ci:.2f})")
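
The statistics reported above can be illustrated on toy numbers. The sketch below is a minimal, standard-library-only example, not the notebook's implementation: the per-seed perplexities are made up, `ci_half_width` and `paired_t_statistic` are hypothetical helper names, and the critical value 2.776 is the two-sided 95% Student-t value for 4 degrees of freedom (5 seeds). It shows that a t-based interval is noticeably wider than a z-based one at this sample size, and that the paired test operates on per-seed differences.

```python
import math
import statistics

# Hypothetical per-seed perplexities (illustrative numbers, not results from this notebook).
method_a = [12.1, 12.4, 11.9, 12.6, 12.2]
method_b = [12.0, 12.5, 11.8, 12.4, 12.3]

# Two-sided 95% critical values for n = 5 seeds.
T_CRIT_DF4 = 2.776  # Student-t, 4 degrees of freedom
Z_CRIT = 1.96       # normal approximation


def ci_half_width(values, crit):
    """Half-width of a confidence interval: crit * s / sqrt(n)."""
    n = len(values)
    return crit * statistics.stdev(values) / math.sqrt(n)


def paired_t_statistic(a, b):
    """t statistic of the per-seed differences a[i] - b[i]."""
    diffs = [x - y for x, y in zip(a, b)]
    return statistics.mean(diffs) / (statistics.stdev(diffs) / math.sqrt(len(diffs)))


t_half = ci_half_width(method_a, T_CRIT_DF4)
z_half = ci_half_width(method_a, Z_CRIT)
t_stat = paired_t_statistic(method_a, method_b)

print(f"mean = {statistics.mean(method_a):.2f}")
print(f"95% CI half-width: t-based {t_half:.3f} vs z-based {z_half:.3f}")
print(f"paired t = {t_stat:.3f}, significant at 0.05: {abs(t_stat) > T_CRIT_DF4}")
```

With five seeds the t-based half-width is about 1.4 times the z-based one, so a z-based interval understates the uncertainty.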

## 9.5. Task-Based Evaluation (QA Accuracy)

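This section measures whether the model can still answer questions after its KV cache has been compressed.
Unlike perplexity, it measures end-task performance directly: each answer is generated greedily from the compressed cache and counted as correct if it contains any of the accepted answer variations.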
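
The token-selection step used by the evaluation cell below can be sketched without tensors. This is a minimal pure-Python illustration under stated assumptions, not the notebook's code: `select_keep_indices` is a hypothetical helper name, and the toy scores are arbitrary. It mirrors the same rule: always keep the first 4 (sink) and last 16 (recent) positions, then fill the rest of the budget with the highest-scoring remaining positions, where the budget is the retention ratio times the prompt length, floored at `min_cache`.

```python
def select_keep_indices(importance, retention_ratio, min_cache=48,
                        sink_size=4, recent_size=16):
    """Return the sorted positions kept under a budgeted top-k selection."""
    n = len(importance)
    # Budget: a fraction of the prompt, floored at min_cache and capped at n.
    budget = min(max(int(n * retention_ratio), min_cache), n)

    # Reserved positions: attention sinks at the start, recent tokens at the end.
    sink = min(sink_size, n)
    recent = min(recent_size, n - sink)
    keep = set(range(sink)) | set(range(n - recent, n))

    # Fill what is left of the budget with the highest-scoring other positions.
    remaining = max(0, budget - len(keep))
    candidates = [i for i in range(n) if i not in keep]
    candidates.sort(key=lambda i: importance[i], reverse=True)
    keep.update(candidates[:remaining])
    return sorted(keep)


# Deterministic toy importance scores for a 150-token prompt.
scores = [(i * 37) % 101 for i in range(150)]
kept_10 = select_keep_indices(scores, 0.10)
kept_25 = select_keep_indices(scores, 0.25)
kept_50 = select_keep_indices(scores, 0.50)
print(len(kept_10), len(kept_25), len(kept_50))
```

One consequence of the `min_cache=48` floor is worth keeping in mind when reading the table: for any prompt shorter than 192 tokens, the 10% and 25% settings resolve to the same 48-token budget, so the two columns can coincide on short contexts.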

In [None]:
qa_examples = [
    {"context": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower. Constructed from 1887 to 1889, it was initially criticized by some of France's leading artists and intellectuals for its design, but it has become a global cultural icon of France and one of the most recognizable structures in the world. The tower is 330 metres (1,083 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris.",
     "question": "Who designed the Eiffel Tower?", "answer": "Gustave Eiffel", "answer_variations": ["gustave eiffel", "eiffel", "gustave"]},
    {"context": "The Amazon rainforest, also known as Amazonia, is a moist broadleaf tropical rainforest in the Amazon biome that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 km2 (2,700,000 sq mi), of which 5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Bolivia, Ecuador, French Guiana, Guyana, Suriname, and Venezuela.",
     "question": "Which country contains the majority of the Amazon rainforest?", "answer": "Brazil", "answer_variations": ["brazil", "brasil"]},
    {"context": "Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time. Einstein is best known for developing the theory of relativity, but he also made important contributions to quantum mechanics. His mass-energy equivalence formula E = mc^2, which arises from relativity theory, has been called the world's most famous equation. He received the Nobel Prize in Physics in 1921 for his services to theoretical physics, and especially for his discovery of the law of the photoelectric effect.",
     "question": "What year did Einstein receive the Nobel Prize?", "answer": "1921", "answer_variations": ["1921", "in 1921"]},
    {"context": "The Great Wall of China is a series of fortifications that were built across the historical northern borders of ancient Chinese states and Imperial China as protection against various nomadic groups from the Eurasian Steppe. Several walls were built from as early as the 7th century BC, with selective stretches later joined together by Qin Shi Huang, the first emperor of China. The most well-known sections of the wall were built by the Ming dynasty from 1368 to 1644. The wall spans approximately 21,196 kilometers (13,171 miles).",
     "question": "Who first joined the stretches of the wall together?", "answer": "Qin Shi Huang", "answer_variations": ["qin shi huang", "qin", "emperor qin"]},
    {"context": "William Shakespeare was an English playwright, poet, and actor, widely regarded as the greatest writer in the English language and the world's greatest dramatist. He is often called England's national poet and the 'Bard of Avon'. His extant works, including collaborations, consist of some 39 plays, 154 sonnets, three long narrative poems, and a few other verses. His plays have been translated into every major living language and are performed more often than those of any other playwright.",
     "question": "How many plays did Shakespeare write?", "answer": "39", "answer_variations": ["39", "39 plays", "some 39", "thirty-nine"]},
    {"context": "The Mona Lisa is a half-length portrait painting by Italian artist Leonardo da Vinci. Considered an archetypal masterpiece of the Italian Renaissance, it has been described as the best known, the most visited, the most written about, the most sung about, and the most parodied work of art in the world. The painting's novel qualities include the subject's enigmatic expression, monumentality of the composition, the subtle modelling of forms, and the atmospheric illusionism. The painting is held in the Louvre Museum in Paris since 1797.",
     "question": "Who painted the Mona Lisa?", "answer": "Leonardo da Vinci", "answer_variations": ["leonardo da vinci", "da vinci", "leonardo"]},
    {"context": "Mount Everest is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas. The China-Nepal border runs across its summit point. Its elevation of 8,848.86 m (29,031 ft) was most recently established in 2020 by the Chinese and Nepali authorities. Mount Everest attracts many climbers, including highly experienced mountaineers. There are two main climbing routes, one approaching the summit from the southeast in Nepal and the other from the north in Tibet.",
     "question": "What is the height of Mount Everest in meters?", "answer": "8848", "answer_variations": ["8848", "8,848", "8848.86"]},
    {"context": "The Pacific Ocean is the largest and deepest of Earth's five oceanic divisions. It extends from the Arctic Ocean in the north to the Southern Ocean in the south, and is bounded by the continents of Asia and Australia in the west and the Americas in the east. At 165,250,000 square kilometers in area, this largest division of the World Ocean covers about 46% of Earth's water surface and about 32% of its total surface area.",
     "question": "What percentage of Earth's water surface does the Pacific Ocean cover?", "answer": "46%", "answer_variations": ["46", "46%", "forty-six"]},
    {"context": "The human brain is the central organ of the human nervous system, and with the spinal cord makes up the central nervous system. The brain consists of the cerebrum, the brainstem and the cerebellum. It controls most of the activities of the body, processing, integrating, and coordinating the information it receives from the sense organs. The average adult human brain weighs approximately 1.4 kilograms and contains about 86 billion neurons.",
     "question": "How many neurons does the human brain contain?", "answer": "86 billion", "answer_variations": ["86 billion", "86billion", "eighty-six billion", "86"]},
    {"context": "The Nile is a major north-flowing river in northeastern Africa. It flows into the Mediterranean Sea. The Nile is the longest river in Africa and has historically been considered the longest river in the world, though this has been contested by research suggesting that the Amazon River is slightly longer. The Nile's length is approximately 6,650 km (4,130 miles). It passes through eleven countries: Tanzania, Uganda, Rwanda, Burundi, the Democratic Republic of the Congo, Kenya, Ethiopia, Eritrea, South Sudan, Sudan, and Egypt.",
     "question": "How many countries does the Nile pass through?", "answer": "eleven", "answer_variations": ["eleven", "11", "11 countries"]},
    {"context": "Marie Curie was a Polish and naturalized-French physicist and chemist who conducted pioneering research on radioactivity. She was the first woman to win a Nobel Prize, the first person to win a Nobel Prize twice, and the only person to win a Nobel Prize in two scientific fields. She was born Maria Sklodowska in Warsaw in 1867 and later moved to Paris to study.",
     "question": "What city was Marie Curie born in?", "answer": "Warsaw", "answer_variations": ["warsaw"]},
    {"context": "The Taj Mahal is an ivory-white marble mausoleum on the southern bank of the river Yamuna in the Indian city of Agra. It was commissioned in 1632 by the Mughal emperor Shah Jahan to house the tomb of his favourite wife, Mumtaz Mahal. It also houses the tomb of Shah Jahan himself. The tomb is the centrepiece of a 17-hectare complex, which includes a mosque and a guest house.",
     "question": "Who commissioned the construction of the Taj Mahal?", "answer": "Shah Jahan", "answer_variations": ["shah jahan", "shahjahan"]},
    {"context": "The speed of light in vacuum, commonly denoted c, is a universal physical constant that is important in many areas of physics. The speed of light c is exactly equal to 299,792,458 metres per second (approximately 300,000 kilometres per second; 186,000 miles per second). According to the special theory of relativity, c is the upper limit for the speed at which conventional matter or energy can travel through space.",
     "question": "What is the speed of light in kilometers per second?", "answer": "300,000", "answer_variations": ["300000", "300,000", "299792", "300"]},
    {"context": "The Industrial Revolution, which took place from the 18th to 19th centuries, was a period during which predominantly agrarian, rural societies in Europe and America became industrial and urban. Prior to the Industrial Revolution, manufacturing was often done in people's homes, using hand tools or basic machines. Industrialization marked a shift to powered, special-purpose machinery, factories and mass production. The textile industry was the first to use modern production methods.",
     "question": "Which industry was the first to use modern production methods?", "answer": "textile", "answer_variations": ["textile", "textiles", "the textile industry"]},
    {"context": "Photosynthesis is a process used by plants and other organisms to convert light energy into chemical energy that, through cellular respiration, can later be released to fuel the organism's activities. This chemical energy is stored in carbohydrate molecules, such as sugars and starches, which are synthesized from carbon dioxide and water. In most cases, oxygen is released as a waste product. The process occurs primarily in the chloroplasts of plant cells.",
     "question": "Where in plant cells does photosynthesis primarily occur?", "answer": "chloroplasts", "answer_variations": ["chloroplasts", "chloroplast", "the chloroplasts"]},
    {"context": "The Wright brothers, Orville and Wilbur, were two American aviation pioneers generally credited with inventing, building, and flying the world's first successful motor-operated airplane. They made the first controlled, sustained flight of a powered, heavier-than-air aircraft with the Wright Flyer on December 17, 1903, four miles south of Kitty Hawk, North Carolina.",
     "question": "In what year did the Wright brothers make their first flight?", "answer": "1903", "answer_variations": ["1903", "in 1903"]},
    {"context": "The Amazon River is the largest river by discharge volume of water in the world, and the disputed longest. The headwaters of the Apurimac River on Nevado Mismi had been considered for nearly a century as the Amazon's most distant source. The Amazon basin is the largest drainage basin in the world, with an area of approximately 7,000,000 square kilometres.",
     "question": "What is the area of the Amazon basin in square kilometres?", "answer": "7,000,000", "answer_variations": ["7000000", "7,000,000", "7 million"]},
    {"context": "Vincent van Gogh was a Dutch Post-Impressionist painter who posthumously became one of the most famous and influential figures in Western art history. In just over a decade, he created about 2,100 artworks, including around 860 oil paintings. He was born in Groot-Zundert in the Netherlands in 1853 and died in France in 1890 at the age of 37.",
     "question": "How many oil paintings did Van Gogh create?", "answer": "860", "answer_variations": ["860", "around 860", "about 860"]},
    {"context": "DNA, or deoxyribonucleic acid, is a molecule composed of two polynucleotide chains that coil around each other to form a double helix. DNA carries genetic instructions for the development, functioning, growth and reproduction of all known organisms and many viruses. The structure of DNA was first described by James Watson and Francis Crick in 1953, based on X-ray diffraction data collected by Rosalind Franklin.",
     "question": "Who first described the structure of DNA?", "answer": "Watson and Crick", "answer_variations": ["watson and crick", "james watson", "francis crick", "watson", "crick"]},
    {"context": "The Colosseum, also known as the Flavian Amphitheatre, is an oval amphitheatre in the centre of the city of Rome, Italy. Built of travertine limestone, tuff, and brick-faced concrete, it was the largest amphitheatre ever built at the time and held 50,000 to 80,000 spectators. Construction began under Emperor Vespasian in AD 72 and was completed in AD 80 under his successor Titus.",
     "question": "Under which emperor did construction of the Colosseum begin?", "answer": "Vespasian", "answer_variations": ["vespasian", "emperor vespasian"]},
]


@torch.no_grad()
def evaluate_qa_accuracy_with_compression(
    model, tokenizer, qa_examples, scorer, scorer_name,
    retention_ratio=0.5, min_cache=48, device='cuda', use_compression=True,
):
    model.eval()

    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_attention_heads(attn):
        if head_group_size == 1:
            return attn
        batch, heads, *rest = attn.shape
        return attn.view(batch, num_kv_heads, head_group_size, *rest).mean(dim=2)

    correct = 0
    total = 0
    results_detail = []

    for qa in qa_examples:
        if hasattr(scorer, 'reset'):
            scorer.reset()

        prompt = f"""Context: {qa['context']}

Question: {qa['question']}
Answer (one word or phrase):"""

        tokens = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        context_len = tokens.shape[1]

        outputs = model(tokens, use_cache=True, output_attentions=True, return_dict=True)
        past_kv = outputs.past_key_values

        first_token_id = outputs.logits[:, -1, :].argmax(dim=-1).item()

        if use_compression:
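            # Token importance is scored once, from layer-0 keys/values and the last
            # layer's prefill attention; the resulting mask is applied to every layer.
            keys, values = get_cache_keys_values(past_kv, 0)
            prefill_attention = outputs.attentions[-1]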
            prefill_attention_grouped = group_attention_heads(prefill_attention)

            if isinstance(scorer, PrefillAttentionScorer):
                importance = scorer(keys, values, prefill_attention_grouped)
            else:
                importance = scorer(keys, values, None)

            if importance.dim() == 3:
                avg_importance = importance.mean(dim=(0, 1))
            elif importance.dim() == 2:
                avg_importance = importance.mean(dim=0)
            else:
                avg_importance = importance.flatten()

            budget = max(int(context_len * retention_ratio), min_cache)
            budget = min(budget, context_len)

            keep_mask = torch.zeros(context_len, dtype=torch.bool, device=device)
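            # Always keep attention-sink tokens at the start and the most recent tokens.
            sink_size = min(4, context_len)
            recent_size = min(16, context_len - sink_size)
            keep_mask[:sink_size] = True
            if recent_size > 0:  # a slice of [-0:] would select every position
                keep_mask[-recent_size:] = True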

            reserved = keep_mask.sum().item()
            remaining_budget = max(0, budget - reserved)

            if remaining_budget > 0:
                importance_masked = avg_importance.clone()
                importance_masked[keep_mask] = -float('inf')
                k = min(remaining_budget, (~keep_mask).sum().item())
                if k > 0:
                    _, top_indices = importance_masked.topk(k)
                    keep_mask[top_indices] = True

            keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]

            compressed_kv = DynamicCache()
            for layer_idx in range(get_cache_length(past_kv)):
                layer_keys, layer_values = get_cache_keys_values(past_kv, layer_idx)
                compressed_kv.update(
                    layer_keys.index_select(2, keep_indices),
                    layer_values.index_select(2, keep_indices),
                    layer_idx
                )
            current_kv = compressed_kv
        else:
            current_kv = past_kv

        generated_tokens = [first_token_id]
        current_pos = context_len

        for _ in range(19):
            if generated_tokens[-1] == tokenizer.eos_token_id:
                break

            position_ids = torch.tensor([[current_pos]], device=device)
            input_token = torch.tensor([[generated_tokens[-1]]], device=device)

            out = model(
                input_token,
                past_key_values=current_kv,
                position_ids=position_ids,
                use_cache=True,
                return_dict=True,
            )

            next_token = out.logits[:, -1, :].argmax(dim=-1).item()
            generated_tokens.append(next_token)
            current_kv = out.past_key_values
            current_pos += 1

            decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            if '\n' in decoded or (len(decoded) > 5 and decoded.endswith('.')):
                break

        generated_answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        generated_lower = generated_answer.lower()

        is_correct = any(var in generated_lower for var in qa['answer_variations'])

        if is_correct:
            correct += 1
        total += 1

        results_detail.append({
            'question': qa['question'],
            'expected': qa['answer'],
            'generated': generated_answer,
            'correct': is_correct,
        })

    accuracy = correct / total if total > 0 else 0
    return {'accuracy': accuracy, 'correct': correct, 'total': total, 'details': results_detail}


qa_scorers = {
    'SIP (Ours)': SIPPreGenerationScorer(sip_scorer),
    'Prefill-Attn': PrefillAttentionScorer(),
    'Position-Heuristic': PositionHeuristicScorer(),
    'Expected-Attn': ExpectedAttentionPreGenScorer(),
    'TRIM-KV': TRIMKVPreGenScorer(),
    'Random': RandomImportanceScorer(),
}

qa_retention_ratios = [0.10, 0.25, 0.50, 0.75]
qa_results = {}

print("Baseline (Full Cache)")
baseline_result = evaluate_qa_accuracy_with_compression(
    model=model, tokenizer=tokenizer, qa_examples=qa_examples,
    scorer=RandomImportanceScorer(), scorer_name="Baseline",
    retention_ratio=1.0, min_cache=48, device=str(device), use_compression=False,
)
qa_results['Full Cache'] = {1.0: baseline_result['accuracy']}
print(f"Full Cache: {baseline_result['accuracy']*100:.1f}% ({baseline_result['correct']}/{baseline_result['total']})")

for scorer_name, scorer in qa_scorers.items():
    qa_results[scorer_name] = {}

    for ratio in qa_retention_ratios:
        result = evaluate_qa_accuracy_with_compression(
            model=model, tokenizer=tokenizer, qa_examples=qa_examples,
            scorer=scorer, scorer_name=scorer_name,
            retention_ratio=ratio, min_cache=48, device=str(device), use_compression=True,
        )

        qa_results[scorer_name][ratio] = result['accuracy']
        print(f"{scorer_name:<20} @ {int(ratio*100):>2}%: {result['accuracy']*100:.1f}% ({result['correct']}/{result['total']})")

print("\nQA Accuracy Summary")

header = f"{'Method':<20} |"
for ratio in qa_retention_ratios:
    header += f" {int(ratio*100):>6}% |"
header += " Full |"
print(header)
print("-"*75)

method_order = ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic', 'Expected-Attn', 'TRIM-KV', 'Random']
for method in method_order:
    if method in qa_results:
        row = f"{method:<20} |"
        for ratio in qa_retention_ratios:
            if ratio in qa_results[method]:
                acc = qa_results[method][ratio] * 100
                row += f" {acc:>5.0f}% |"
            else:
                row += f" {'--':>6} |"
        row += f" {baseline_result['accuracy']*100:>4.0f}% |"
        print(row)

print("\nBest Method at Each Retention Level")
for ratio in qa_retention_ratios:
    best_method = None
    best_acc = -1
    for method in method_order:
        if method in qa_results and ratio in qa_results[method]:
            acc = qa_results[method][ratio]
            if acc > best_acc:
                best_acc = acc
                best_method = method
    if best_method:
        print(f"  {int(ratio*100):>2}%: {best_method} ({best_acc*100:.0f}%)")
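
The QA loop above scores an answer as correct when any accepted variation appears as a substring of the generated text. The sketch below illustrates how short variations can produce false positives under that rule. It is an illustration only: `substring_match` restates the rule used above, while `token_match` is a hypothetical stricter alternative that is not part of this notebook, and the example answers are made up.

```python
import re


def substring_match(generated, variations):
    """Rule used in the QA loop: any variation occurs as a substring."""
    text = generated.lower()
    return any(var in text for var in variations)


def token_match(generated, variations):
    """Hypothetical stricter rule: a variation must not be glued to other word characters."""
    text = generated.lower()
    return any(
        re.search(rf"(?<!\w){re.escape(var)}(?!\w)", text) for var in variations
    )


variations = ["86 billion", "86billion", "eighty-six billion", "86"]
right_answer = "About 86 billion neurons"
wrong_answer = "About 1860 neurons"

for answer in (right_answer, wrong_answer):
    print(answer, substring_match(answer, variations), token_match(answer, variations))
```

The wrong answer is accepted by the substring rule because "86" occurs inside "1860", whereas the boundary-aware rule rejects it. Accuracies computed with substring matching are therefore best read as an upper bound on strict-match accuracy.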

In [None]:
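def generate_haystack(length_tokens=500, tokenizer=None):
    """Build filler text of at least `length_tokens` tokens by sampling sentences at random."""
    if tokenizer is None:
        raise ValueError("generate_haystack requires a tokenizer to measure length in tokens")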
    filler_sentences = [
        "The weather has been quite pleasant this season with moderate temperatures.",
        "Many scientists are working on new discoveries in various fields of study.",
        "Technology continues to advance at a rapid pace in modern society.",
        "Education remains an important aspect of human development worldwide.",
        "Various animals inhabit different ecosystems across the planet.",
        "Music and art have been part of human culture for thousands of years.",
        "Agriculture provides the foundation for food production globally.",
        "Transportation systems connect cities and countries together.",
        "Healthcare research aims to improve quality of life for everyone.",
        "Architecture reflects the cultural values of different societies.",
    ]

    haystack = ""
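    # Sentences are drawn from NumPy's global RNG, so every call yields a different
    # haystack unless the caller reseeds with np.random.seed beforehand.
    while len(tokenizer.encode(haystack)) < length_tokens:
        haystack += np.random.choice(filler_sentences) + " "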

    return haystack.strip()


def create_needle_test(needle_fact, needle_question, haystack_length=500, needle_position=0.5, tokenizer=None):
    part1_len = int(haystack_length * needle_position)
    part2_len = haystack_length - part1_len

    part1 = generate_haystack(part1_len, tokenizer) if part1_len > 0 else ""
    part2 = generate_haystack(part2_len, tokenizer) if part2_len > 0 else ""

    if part1:
        context = part1 + " " + needle_fact + " " + part2
    else:
        context = needle_fact + " " + part2

    return {
        'context': context.strip(),
        'question': needle_question,
        'answer': needle_fact,
    }


needle_tests = [
    {
        'needle': "The secret code for the vault is XRAY-42-BETA.",
        'question': "What is the secret code for the vault?",
        'answer_key': "XRAY-42-BETA",
    },
    {
        'needle': "Dr. Alexandra Chen discovered the new element on March 15th, 2024.",
        'question': "Who discovered the new element?",
        'answer_key': "Alexandra Chen",
    },
    {
        'needle': "The ancient artifact was hidden in the Temple of the Golden Sun.",
        'question': "Where was the ancient artifact hidden?",
        'answer_key': "Temple of the Golden Sun",
    },
    {
        'needle': "The winning lottery numbers for the grand prize were 7, 23, 45, 62, 89.",
        'question': "What were the winning lottery numbers?",
        'answer_key': "7, 23, 45, 62, 89",
    },
    {
        'needle': "Captain Marcus Webb commanded the first Mars expedition in 2031.",
        'question': "Who commanded the first Mars expedition?",
        'answer_key': "Marcus Webb",
    },
    {
        'needle': "The password to access the mainframe is OMEGA-SEVEN-DELTA.",
        'question': "What is the password to access the mainframe?",
        'answer_key': "OMEGA-SEVEN-DELTA",
    },
    {
        'needle': "Professor Elena Vasquez invented the quantum processor in Helsinki.",
        'question': "Who invented the quantum processor?",
        'answer_key': "Elena Vasquez",
    },
    {
        'needle': "The treasure map coordinates are 47.3N, 122.5W near Seattle.",
        'question': "What are the treasure map coordinates?",
        'answer_key': "47.3N, 122.5W",
    },
]


@torch.no_grad()
def evaluate_needle_in_haystack(
    model, tokenizer, needle_tests, scorer, scorer_name,
    haystack_length=400, retention_ratio=0.5,
    needle_positions=(0.2, 0.5, 0.8), device='cuda',
    use_compression=True,
):
    model.eval()

    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_attention_heads(attn):
        if head_group_size == 1:
            return attn
        batch, heads, *rest = attn.shape
        return attn.view(batch, num_kv_heads, head_group_size, *rest).mean(dim=2)

    results_by_position = {pos: {'correct': 0, 'total': 0} for pos in needle_positions}

    for needle_test in needle_tests:
        for position in needle_positions:
            if hasattr(scorer, 'reset'):
                scorer.reset()

            test_case = create_needle_test(
                needle_fact=needle_test['needle'],
                needle_question=needle_test['question'],
                haystack_length=haystack_length,
                needle_position=position,
                tokenizer=tokenizer,
            )

            prompt = f"""Context: {test_case['context']}

Question: {test_case['question']}
Answer:"""

            tokens = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
            context_len = tokens.shape[1]

            if context_len < 100:
                continue

            outputs = model(tokens, use_cache=True, output_attentions=True, return_dict=True)
            past_kv = outputs.past_key_values

            first_token_id = outputs.logits[:, -1, :].argmax(dim=-1).item()

            if use_compression:
                keys, values = get_cache_keys_values(past_kv, 0)
                prefill_attention = outputs.attentions[-1]
                prefill_attention_grouped = group_attention_heads(prefill_attention)

                if isinstance(scorer, PrefillAttentionScorer):
                    importance = scorer(keys, values, prefill_attention_grouped)
                else:
                    importance = scorer(keys, values, None)

                if importance.dim() == 3:
                    avg_importance = importance.mean(dim=(0, 1))
                elif importance.dim() == 2:
                    avg_importance = importance.mean(dim=0)
                else:
                    avg_importance = importance.flatten()

                budget = max(int(context_len * retention_ratio), 48)
                keep_mask = torch.zeros(context_len, dtype=torch.bool, device=device)
                keep_mask[:4] = True
                keep_mask[-16:] = True

                remaining = max(0, budget - keep_mask.sum().item())
                if remaining > 0:
                    importance_masked = avg_importance.clone()
                    importance_masked[keep_mask] = -float('inf')
                    k = min(remaining, (~keep_mask).sum().item())
                    if k > 0:
                        _, top_idx = importance_masked.topk(k)
                        keep_mask[top_idx] = True

                keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]

                compressed_kv = DynamicCache()
                for layer_idx in range(get_cache_length(past_kv)):
                    lk, lv = get_cache_keys_values(past_kv, layer_idx)
                    compressed_kv.update(lk.index_select(2, keep_indices), lv.index_select(2, keep_indices), layer_idx)

                current_kv = compressed_kv
            else:
                current_kv = past_kv

            generated_tokens = [first_token_id]
            current_pos = context_len

            for _ in range(29):
                if generated_tokens[-1] == tokenizer.eos_token_id:
                    break

                position_ids = torch.tensor([[current_pos]], device=device)
                input_token = torch.tensor([[generated_tokens[-1]]], device=device)

                out = model(input_token, past_key_values=current_kv, position_ids=position_ids, use_cache=True, return_dict=True)
                next_token = out.logits[:, -1, :].argmax(dim=-1).item()

                generated_tokens.append(next_token)
                current_kv = out.past_key_values
                current_pos += 1

                decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
                if '\n' in decoded or len(decoded) > 50:
                    break

            generated = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip().lower()
            answer_key = needle_test['answer_key'].lower()

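            # Lenient match: the full key, or either of its first two whitespace-separated
            # words, as a substring. A short word such as "of" can therefore match unrelated
            # output, so these accuracies are an upper bound on exact retrieval.
            is_correct = answer_key in generated or any(word in generated for word in answer_key.split()[:2])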

            results_by_position[position]['total'] += 1
            if is_correct:
                results_by_position[position]['correct'] += 1

    return {
        pos: data['correct'] / max(1, data['total'])
        for pos, data in results_by_position.items()
    }


needle_scorers = {
    'SIP (Ours)': SIPPreGenerationScorer(sip_scorer),
    'Prefill-Attn': PrefillAttentionScorer(),
    'Position-Heuristic': PositionHeuristicScorer(),
    'Expected-Attn': ExpectedAttentionPreGenScorer(),
    'TRIM-KV': TRIMKVPreGenScorer(),
    'Random': RandomImportanceScorer(),
}

needle_retention_ratios = [0.25, 0.50, 0.75]
needle_positions = [0.2, 0.5, 0.8]
needle_results = {}

print("Baseline (Full Cache)")
baseline_needle = evaluate_needle_in_haystack(
    model=model, tokenizer=tokenizer, needle_tests=needle_tests,
    scorer=RandomImportanceScorer(), scorer_name="Baseline",
    haystack_length=350, retention_ratio=1.0,
    needle_positions=needle_positions, device=str(device),
    use_compression=False,
)
needle_results['Full Cache'] = {1.0: baseline_needle}
print(f"Full Cache: Start={baseline_needle[0.2]*100:.0f}% Middle={baseline_needle[0.5]*100:.0f}% End={baseline_needle[0.8]*100:.0f}%")

for scorer_name, scorer in needle_scorers.items():
    needle_results[scorer_name] = {}

    for ratio in needle_retention_ratios:
        result = evaluate_needle_in_haystack(
            model=model, tokenizer=tokenizer, needle_tests=needle_tests,
            scorer=scorer, scorer_name=scorer_name,
            haystack_length=350, retention_ratio=ratio,
            needle_positions=needle_positions, device=str(device),
            use_compression=True,
        )

        needle_results[scorer_name][ratio] = result

print("\nNeedle-in-Haystack Results")
print(f"Baseline (Full Cache): Start={baseline_needle[0.2]*100:.0f}% Middle={baseline_needle[0.5]*100:.0f}% End={baseline_needle[0.8]*100:.0f}%")

method_order = ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic', 'Expected-Attn', 'TRIM-KV', 'Random']

for ratio in needle_retention_ratios:
    print(f"\n@ {int(ratio*100)}% Retention")
    print(f"{'Method':<20} | {'Start':>8} | {'Middle':>8} | {'End':>8} | {'Avg':>8}")
    print("-"*60)

    for method in method_order:
        if method in needle_results and ratio in needle_results[method]:
            res = needle_results[method][ratio]
            avg = np.mean([res[0.2], res[0.5], res[0.8]])
            print(f"{method:<20} | {res[0.2]*100:>6.0f}% | {res[0.5]*100:>6.0f}% | {res[0.8]*100:>6.0f}% | {avg*100:>6.0f}%")

print("\nBest Method by Retention Level (Average across positions)")

for ratio in needle_retention_ratios:
    best_method = None
    best_avg = -1
    for method in method_order:
        if method in needle_results and ratio in needle_results[method]:
            res = needle_results[method][ratio]
            avg = np.mean([res[0.2], res[0.5], res[0.8]])
            if avg > best_avg:
                best_avg = avg
                best_method = method
    if best_method:
        print(f"  {int(ratio*100):>2}%: {best_method} ({best_avg*100:.0f}% avg)")
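The correctness check used in the needle evaluation is lenient: a generation counts as correct if it contains the full answer key, or if it contains either of the key's first two whitespace-separated words as a substring. The sketch below restates that rule as a standalone function so its behavior on edge cases is easy to inspect. `needle_match` is a hypothetical name introduced here for illustration (the evaluation function inlines the expression), and the example strings are made up rather than taken from `needle_tests`.

```python
def needle_match(generated: str, answer_key: str) -> bool:
    """Restate the lenient matching rule used in the needle evaluation.

    Hypothetical helper: the notebook inlines this expression rather than
    defining a function.
    """
    generated = generated.strip().lower()
    answer_key = answer_key.lower()
    # Full-key substring match, or a substring match on either of the
    # first two words of the key.
    first_two_words = answer_key.split()[:2]
    return answer_key in generated or any(w in generated for w in first_two_words)


# Made-up examples, not taken from the notebook's needle tests.
print(needle_match("The code is 7421.", "7421"))         # full key present
print(needle_match("It is blue.", "blue harbor"))        # one key word present
print(needle_match("I do not know.", "blue harbor"))     # no overlap
print(needle_match("Another answer.", "the red door"))   # "the" occurs inside "another"
```

Because the word-level fallback matches substrings, a short or common first word in an answer key can match unrelated output, as in the last example. Accuracies from this check are therefore an upper bound on what a full-key match alone would report.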

In [None]:
class KeysOnlyScorer:
    def __init__(self):
        pass

    def reset(self):
        pass

    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, head_dim = keys.shape
        key_norms = keys.float().norm(dim=-1)
        importance = key_norms / (key_norms.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance


class ValuesOnlyScorer:
    def __init__(self):
        pass

    def reset(self):
        pass

    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, head_dim = values.shape
        value_norms = values.float().norm(dim=-1)
        importance = value_norms / (value_norms.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance


class KeysValuesScorer:
    def __init__(self):
        pass

    def reset(self):
        pass

    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, head_dim = keys.shape
        key_norms = keys.float().norm(dim=-1)
        value_norms = values.float().norm(dim=-1)

        recent_keys = keys[:, :, -32:, :].float()
        recent_mean = recent_keys.mean(dim=2)

        similarity = torch.einsum('bhsd,bhd->bhs', keys.float(), recent_mean)
        similarity = similarity / (head_dim ** 0.5)

        importance = (key_norms + value_norms) / 2 + 0.5 * torch.sigmoid(similarity)
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance


class KVPositionScorer:
    def __init__(self):
        pass

    def reset(self):
        pass

    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device

        key_norms = keys.float().norm(dim=-1)
        value_norms = values.float().norm(dim=-1)
        kv_score = (key_norms + value_norms) / 2

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        recency = positions / seq_len

        sink_bonus = torch.zeros(seq_len, device=device)
        sink_bonus[:4] = 1.0

        position_score = 0.3 * recency + 0.2 * sink_bonus
        position_score = position_score.unsqueeze(0).unsqueeze(0)

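        # Note: kv_score is an unnormalized norm while position_score stays below 0.5,
        # so the positional term's relative weight depends on the K/V norm scale.
        importance = kv_score + position_score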
        importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
        return importance


class AttentionOnlyScorer:
    def __init__(self):
        self.prefill_attention = None

    def reset(self):
        self.prefill_attention = None

    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, _ = keys.shape
        device = keys.device

        if attention is not None:
            if attention.dim() == 4:
                self.prefill_attention = attention.mean(dim=2)
            else:
                self.prefill_attention = attention

        if self.prefill_attention is not None:
            importance = self.prefill_attention
            importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
            return importance
        else:
            return torch.ones(batch, heads, seq_len, device=device)


class SingleStepSIPScorer:
    def __init__(self, sip_model):
        self.model = sip_model
        self.model.eval()

    def reset(self):
        pass

    def __call__(self, keys, values, attention=None):
        with torch.no_grad():
            positions = torch.arange(keys.shape[2], device=keys.device).unsqueeze(0)
            kv_features = self.model.kv_encoder(keys.float(), values.float())
            pred, conf = self.model.predictor(kv_features, positions, keys.shape[2], lookahead_steps=1)
            return pred[:, :, :, 0]


ablation_scorers = {
    'Random': RandomImportanceScorer(),
    'Attention-Only': AttentionOnlyScorer(),
    'Prefill-Attn': PrefillAttentionScorer(),
    'Position-Heuristic': PositionHeuristicScorer(),
    'Keys-Only': KeysOnlyScorer(),
    'Values-Only': ValuesOnlyScorer(),
    'Keys+Values': KeysValuesScorer(),
    'K+V+Position': KVPositionScorer(),
    'SIP-SingleStep': SingleStepSIPScorer(sip_scorer),
    'SIP-Full': SIPPreGenerationScorer(sip_scorer),
}

ablation_ratios = [0.25, 0.50, 0.75]
ablation_results = {name: {} for name in ablation_scorers.keys()}

num_attention_heads = model.config.num_attention_heads
num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
head_group_size = num_attention_heads // num_kv_heads

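def group_attention_heads(attn):
    """Average attention over the query heads that share each KV head (GQA)."""
    if head_group_size == 1:
        return attn
    batch, heads, *rest = attn.shape
    # reshape (rather than view) also handles non-contiguous attention tensors.
    return attn.reshape(batch, num_kv_heads, head_group_size, *rest).mean(dim=2)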

print("Computing Full Cache Baseline")
baseline_nll = 0.0
baseline_tokens = 0

for text in tqdm(test_texts[:30], desc="Baseline"):
    tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
    seq_len = tokens.shape[1]
    if seq_len < 128:
        continue

    split_point = seq_len * 3 // 4
    context_ids = tokens[:, :split_point]
    continuation_ids = tokens[:, split_point:]

    with torch.no_grad():
        outputs = model(context_ids, use_cache=True, return_dict=True)
        current_kv = outputs.past_key_values
        current_pos = split_point

        for i in range(continuation_ids.shape[1] - 1):
            input_token = continuation_ids[:, i:i+1]
            target_token = continuation_ids[:, i+1]
            position_ids = torch.tensor([[current_pos]], device=device)

            out = model(input_token, past_key_values=current_kv, position_ids=position_ids,
                        use_cache=True, return_dict=True)
            loss = F.cross_entropy(out.logits[:, -1, :], target_token, reduction='sum')

            if not torch.isnan(loss):
                baseline_nll += loss.item()
                baseline_tokens += 1

            current_kv = out.past_key_values
            current_pos += 1

baseline_ppl = np.exp(baseline_nll / baseline_tokens) if baseline_tokens > 0 else float('inf')
print(f"Full Cache Baseline PPL: {baseline_ppl:.2f}")

for scorer_name, scorer in ablation_scorers.items():
    for ratio in ablation_ratios:
        total_nll = 0.0
        total_tokens = 0

        for text in test_texts[:30]:
            if hasattr(scorer, 'reset'):
                scorer.reset()

            tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
            seq_len = tokens.shape[1]
            if seq_len < 128:
                continue

            split_point = seq_len * 3 // 4
            context_ids = tokens[:, :split_point]
            continuation_ids = tokens[:, split_point:]

            with torch.no_grad():
                outputs = model(context_ids, use_cache=True, output_attentions=True, return_dict=True)
                past_kv = outputs.past_key_values

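                # Scores are computed from layer-0 keys/values and, for attention-based
                # scorers, last-layer prefill attention averaged over each GQA head group.
                # The resulting keep set is then applied to every layer.
                keys, values = get_cache_keys_values(past_kv, 0)
                prefill_attention = outputs.attentions[-1]
                prefill_attention_grouped = group_attention_heads(prefill_attention)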

                if isinstance(scorer, (AttentionOnlyScorer, PrefillAttentionScorer)):
                    importance = scorer(keys, values, prefill_attention_grouped)
                else:
                    importance = scorer(keys, values, None)

                if importance.dim() == 3:
                    avg_importance = importance.mean(dim=(0, 1))
                elif importance.dim() == 2:
                    avg_importance = importance.mean(dim=0)
                else:
                    avg_importance = importance.flatten()

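                # Retention budget with a floor of 48 positions; at 25% retention the
                # floor applies whenever the context is shorter than 192 tokens.
                budget = max(int(split_point * ratio), 48)
                keep_mask = torch.zeros(split_point, dtype=torch.bool, device=device)
                # Always keep the first 4 positions (attention sinks) and the last 16 (recent window).
                keep_mask[:4] = True
                keep_mask[-16:] = True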

                remaining = max(0, budget - keep_mask.sum().item())
                if remaining > 0:
                    importance_masked = avg_importance.clone()
                    importance_masked[keep_mask] = -float('inf')
                    k = min(remaining, (~keep_mask).sum().item())
                    if k > 0:
                        _, top_idx = importance_masked.topk(k)
                        keep_mask[top_idx] = True

                keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]

                compressed_kv = DynamicCache()
                for layer_idx in range(get_cache_length(past_kv)):
                    lk, lv = get_cache_keys_values(past_kv, layer_idx)
                    compressed_kv.update(lk.index_select(2, keep_indices), lv.index_select(2, keep_indices), layer_idx)

                current_kv = compressed_kv
                current_pos = split_point

                for i in range(continuation_ids.shape[1] - 1):
                    input_token = continuation_ids[:, i:i+1]
                    target_token = continuation_ids[:, i+1]
                    position_ids = torch.tensor([[current_pos]], device=device)

                    out = model(input_token, past_key_values=current_kv, position_ids=position_ids,
                                use_cache=True, return_dict=True)
                    logits = out.logits[:, -1, :]
                    loss = F.cross_entropy(logits, target_token, reduction='sum')

                    if not torch.isnan(loss):
                        total_nll += loss.item()
                        total_tokens += 1

                    current_kv = out.past_key_values
                    current_pos += 1

        if total_tokens > 0:
            ppl = np.exp(total_nll / total_tokens)
            ablation_results[scorer_name][ratio] = ppl

print("\nAblation Study Results")
print(f"Full Cache Baseline: {baseline_ppl:.2f}")

print(f"\n{'Component':<20} |", end="")
for ratio in ablation_ratios:
    print(f" {int(ratio*100):>6}% |", end="")
print("")
print("-"*55)

method_order = ['Random', 'Attention-Only', 'Prefill-Attn', 'Position-Heuristic',
                'Keys-Only', 'Values-Only', 'Keys+Values', 'K+V+Position',
                'SIP-SingleStep', 'SIP-Full']

for scorer_name in method_order:
    if scorer_name in ablation_results:
        row = f"{scorer_name:<20} |"
        for ratio in ablation_ratios:
            if ratio in ablation_results[scorer_name]:
                ppl = ablation_results[scorer_name][ratio]
                row += f" {ppl:>6.2f} |"
            else:
                row += f" {'--':>6} |"
        print(row)

print("\nComparison @ 50% Retention (vs Random and vs Attention-Only)")

random_ppl_50 = ablation_results.get('Random', {}).get(0.50, float('inf'))
attn_only_ppl_50 = ablation_results.get('Attention-Only', {}).get(0.50, float('inf'))

print(f"\n{'Component':<20} | {'PPL':>8} | {'vs Random':>12} | {'vs Attn-Only':>14}")
print("-"*65)

for scorer_name in method_order:
    if scorer_name in ablation_results and 0.50 in ablation_results[scorer_name]:
        ppl = ablation_results[scorer_name][0.50]
        vs_random = ((random_ppl_50 - ppl) / random_ppl_50) * 100 if random_ppl_50 != float('inf') else 0
        vs_attn = ((attn_only_ppl_50 - ppl) / attn_only_ppl_50) * 100 if attn_only_ppl_50 != float('inf') else 0
        print(f"{scorer_name:<20} | {ppl:>8.2f} | {vs_random:>+10.1f}% | {vs_attn:>+12.1f}%")

sip_full_ppl = ablation_results.get('SIP-Full', {}).get(0.50, float('inf'))
prefill_attn_ppl = ablation_results.get('Prefill-Attn', {}).get(0.50, float('inf'))
pos_heur_ppl = ablation_results.get('Position-Heuristic', {}).get(0.50, float('inf'))

print("\nCritical Comparison: SIP vs Simple Baselines @ 50%")
print(f"  SIP-Full:           {sip_full_ppl:.2f}")
print(f"  Prefill-Attn:       {prefill_attn_ppl:.2f}")
print(f"  Position-Heuristic: {pos_heur_ppl:.2f}")
print(f"  Attention-Only:     {attn_only_ppl_50:.2f}")
print(f"  Random:             {random_ppl_50:.2f}")

print("\nMulti-Horizon Speculation Value @ 50%")
single_ppl = ablation_results.get('SIP-SingleStep', {}).get(0.50, float('inf'))
full_ppl = ablation_results.get('SIP-Full', {}).get(0.50, float('inf'))

if single_ppl != float('inf') and full_ppl != float('inf'):
    multi_horizon_gain = ((single_ppl - full_ppl) / single_ppl) * 100
    print(f"  Single-step SIP: {single_ppl:.2f}")
    print(f"  Full SIP (8-step): {full_ppl:.2f}")
    print(f"  Multi-horizon gain: {multi_horizon_gain:+.1f}%")
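The perplexity loops in this notebook all select cache positions with the same rule: the first 4 positions (attention sinks) and the last 16 (recent window) are always kept, the budget is `max(int(context_len * ratio), 48)`, and the rest of the budget is filled with the highest-scoring unprotected positions. A minimal pure-Python sketch of that rule follows. `select_keep_indices` and its parameter names are hypothetical and not defined in the notebook, which performs the same steps inline with tensor operations; ties here are broken by list order, which may differ from `topk`.

```python
from typing import List, Sequence


def select_keep_indices(
    scores: Sequence[float],
    ratio: float,
    num_sink: int = 4,
    num_recent: int = 16,
    min_budget: int = 48,
) -> List[int]:
    """Return sorted positions to keep under a retention budget.

    Hypothetical helper mirroring the inline selection logic; the defaults
    match the constants used in the evaluation loops.
    """
    n = len(scores)
    budget = max(int(n * ratio), min_budget)

    # Protected positions: attention sinks at the start, recent window at the end.
    keep = set(range(min(num_sink, n))) | set(range(max(0, n - num_recent), n))

    # Fill what is left of the budget with the highest-scoring unprotected positions.
    remaining = max(0, budget - len(keep))
    candidates = [i for i in range(n) if i not in keep]
    candidates.sort(key=lambda i: scores[i], reverse=True)
    keep.update(candidates[:remaining])

    return sorted(keep)


# Example: 200 context positions whose score increases with position.
scores = [i / 200 for i in range(200)]
kept = select_keep_indices(scores, ratio=0.5)
print(len(kept))   # 100 positions kept
print(kept[:5])    # [0, 1, 2, 3, 104]: the sinks, then the first score-selected position
```

With a 96-token context at 25% retention the floor applies and 48 positions are kept, so the effective retention for short contexts is higher than the nominal ratio.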

In [None]:
RUN_TRANSFER_EXPERIMENT = False

class TransferableSIPScorer:
    def __init__(self, sip_model, source_num_kv_heads=4, target_num_kv_heads=None):
        self.model = sip_model
        self.model.eval()
        self.source_num_kv_heads = source_num_kv_heads
        self.target_num_kv_heads = target_num_kv_heads or source_num_kv_heads
    
    def reset(self):
        pass
    
    def _adapt_heads(self, tensor, source_heads, target_heads):
        if source_heads == target_heads:
            return tensor
        
        batch, heads, seq_len = tensor.shape
        
        if target_heads > source_heads:
            repeat_factor = target_heads // source_heads
            tensor = tensor.repeat_interleave(repeat_factor, dim=1)
            remainder = target_heads % source_heads
            if remainder > 0:
                tensor = torch.cat([tensor, tensor[:, :remainder, :]], dim=1)
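        else:
            # Average groups of source heads. Heads beyond an even split are dropped
            # so the reshape cannot fail when source_heads % target_heads != 0.
            group_size = source_heads // target_heads
            usable = target_heads * group_size
            tensor = tensor[:, :usable, :].reshape(batch, target_heads, group_size, seq_len).mean(dim=2)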
        
        return tensor[:, :target_heads, :]
    
    def __call__(self, keys, values, attention=None):
        batch, heads, seq_len, head_dim = keys.shape
        device = keys.device
        
        with torch.no_grad():
            expected_head_dim = self.model.config.head_dim
            if head_dim != expected_head_dim:
                importance = keys.float().norm(dim=-1) + values.float().norm(dim=-1)
                importance = importance / (importance.max(dim=-1, keepdim=True)[0] + 1e-8)
                return importance
            
            if heads != self.source_num_kv_heads:
                group_size = heads // self.source_num_kv_heads
                if group_size > 0:
                    # Average groups of target heads down to the source head count. Heads
                    # beyond an even split are dropped so the reshape cannot fail.
                    usable = self.source_num_kv_heads * group_size
                    keys_adapted = keys[:, :usable].reshape(batch, self.source_num_kv_heads, group_size, seq_len, head_dim).mean(dim=2)
                    values_adapted = values[:, :usable].reshape(batch, self.source_num_kv_heads, group_size, seq_len, head_dim).mean(dim=2)
                else:
                    # Fewer target heads than source heads: repeat and truncate.
                    keys_adapted = keys.repeat_interleave(self.source_num_kv_heads // heads + 1, dim=1)[:, :self.source_num_kv_heads, :, :]
                    values_adapted = values.repeat_interleave(self.source_num_kv_heads // heads + 1, dim=1)[:, :self.source_num_kv_heads, :, :]
            else:
                keys_adapted = keys
                values_adapted = values
            
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            importance = self.model(keys_adapted.float(), values_adapted.float(), positions)
            
            importance = self._adapt_heads(importance, self.source_num_kv_heads, heads)
        
        return importance


@torch.no_grad()
def evaluate_transfer(
    source_sip_model,
    target_model,
    target_tokenizer,
    texts,
    source_config,
    retention_ratio=0.5,
    max_samples=20,
    device='cuda',
):
    target_model.eval()
    
    target_num_heads = target_model.config.num_attention_heads
    target_num_kv_heads = getattr(target_model.config, 'num_key_value_heads', target_num_heads)
    target_head_dim = target_model.config.hidden_size // target_num_heads
    
    transfer_scorer = TransferableSIPScorer(
        source_sip_model,
        source_num_kv_heads=source_config.num_kv_heads,
        target_num_kv_heads=target_num_kv_heads,
    )
    
    random_scorer = RandomImportanceScorer()
    
    results = {'transfer': {'nll': 0, 'tokens': 0}, 'random': {'nll': 0, 'tokens': 0}}
    
    for text in tqdm(texts[:max_samples], desc="Transfer eval"):
        tokens = target_tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        seq_len = tokens.shape[1]
        if seq_len < 100:
            continue
        
        split_point = seq_len * 3 // 4
        context = tokens[:, :split_point]
        continuation = tokens[:, split_point:]
        
        for scorer_name, scorer in [('transfer', transfer_scorer), ('random', random_scorer)]:
            if hasattr(scorer, 'reset'):
                scorer.reset()
            
            outputs = target_model(context, use_cache=True, return_dict=True)
            past_kv = outputs.past_key_values
            
            keys, values = get_cache_keys_values(past_kv, 0)
            importance = scorer(keys, values, None)
            
            if importance.dim() == 3:
                avg_importance = importance.mean(dim=(0, 1))
            else:
                avg_importance = importance.flatten()
            
            budget = max(int(split_point * retention_ratio), 48)
            keep_mask = torch.zeros(split_point, dtype=torch.bool, device=device)
            keep_mask[:4] = True
            keep_mask[-16:] = True
            
            remaining = max(0, budget - keep_mask.sum().item())
            if remaining > 0:
                importance_masked = avg_importance.clone()
                importance_masked[keep_mask] = -float('inf')
                k = min(remaining, (~keep_mask).sum().item())
                if k > 0:
                    _, top_idx = importance_masked.topk(k)
                    keep_mask[top_idx] = True
            
            keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]
            
            compressed_kv = DynamicCache()
            for layer_idx in range(get_cache_length(past_kv)):
                lk, lv = get_cache_keys_values(past_kv, layer_idx)
                compressed_kv.update(lk.index_select(2, keep_indices), lv.index_select(2, keep_indices), layer_idx)
            
            current_kv = compressed_kv
            current_pos = split_point
            
            for i in range(continuation.shape[1] - 1):
                input_token = continuation[:, i:i+1]
                target_token = continuation[:, i+1]
                position_ids = torch.tensor([[current_pos]], device=device)
                
                out = target_model(input_token, past_key_values=current_kv, position_ids=position_ids,
                                   use_cache=True, return_dict=True)
                loss = F.cross_entropy(out.logits[:, -1, :], target_token, reduction='sum')
                
                if not torch.isnan(loss):
                    results[scorer_name]['nll'] += loss.item()
                    results[scorer_name]['tokens'] += 1
                
                current_kv = out.past_key_values
                current_pos += 1
    
    for name in results:
        if results[name]['tokens'] > 0:
            results[name]['ppl'] = np.exp(results[name]['nll'] / results[name]['tokens'])
        else:
            results[name]['ppl'] = float('inf')
    
    return results


if RUN_TRANSFER_EXPERIMENT:
    print("Running Cross-Model Transfer Experiment")
    
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        target_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
        target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="eager",
        )
        target_model.eval()
        target_tokenizer.pad_token = target_tokenizer.eos_token
        
        transfer_results = evaluate_transfer(
            source_sip_model=sip_scorer,
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            texts=test_texts,
            source_config=sip_config,
            retention_ratio=0.5,
            max_samples=20,
            device=str(device),
        )
        
        print(f"\nTarget Model: {target_model_name}")
        print(f"SIP Transfer PPL: {transfer_results['transfer']['ppl']:.2f}")
        print(f"Random PPL: {transfer_results['random']['ppl']:.2f}")
        
        improvement = (transfer_results['random']['ppl'] - transfer_results['transfer']['ppl']) / transfer_results['random']['ppl'] * 100
        print(f"Transfer improvement over random: {improvement:+.1f}%")
        
        del target_model
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Transfer experiment failed: {e}")
else:
    print("Transfer experiment skipped (RUN_TRANSFER_EXPERIMENT = False)")

print("\nSame-Architecture Transfer Test")

transfer_test_texts = val_texts[-50:]

same_arch_results = {}
for scorer_name, scorer in [('SIP', SIPPreGenerationScorer(sip_scorer)), ('Random', RandomImportanceScorer())]:
    total_nll = 0
    total_tokens = 0
    
    for text in tqdm(transfer_test_texts[:20], desc=f"Same-arch transfer ({scorer_name})"):
        if hasattr(scorer, 'reset'):
            scorer.reset()
        
        tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        seq_len = tokens.shape[1]
        if seq_len < 100:
            continue
        
        split_point = seq_len * 3 // 4
        context = tokens[:, :split_point]
        continuation = tokens[:, split_point:]
        
        with torch.no_grad():
            outputs = model(context, use_cache=True, return_dict=True)
            past_kv = outputs.past_key_values
            
            keys, values = get_cache_keys_values(past_kv, 0)
            importance = scorer(keys, values, None)
            
            if importance.dim() == 3:
                avg_importance = importance.mean(dim=(0, 1))
            else:
                avg_importance = importance.flatten()
            
            budget = max(int(split_point * 0.5), 48)
            keep_mask = torch.zeros(split_point, dtype=torch.bool, device=device)
            keep_mask[:4] = True
            keep_mask[-16:] = True
            
            remaining = max(0, budget - keep_mask.sum().item())
            if remaining > 0:
                importance_masked = avg_importance.clone()
                importance_masked[keep_mask] = -float('inf')
                k = min(remaining, (~keep_mask).sum().item())
                if k > 0:
                    _, top_idx = importance_masked.topk(k)
                    keep_mask[top_idx] = True
            
            keep_indices = keep_mask.nonzero(as_tuple=True)[0].sort()[0]
            
            compressed_kv = DynamicCache()
            for layer_idx in range(get_cache_length(past_kv)):
                lk, lv = get_cache_keys_values(past_kv, layer_idx)
                compressed_kv.update(lk.index_select(2, keep_indices), lv.index_select(2, keep_indices), layer_idx)
            
            current_kv = compressed_kv
            current_pos = split_point
            
            for i in range(continuation.shape[1] - 1):
                input_token = continuation[:, i:i+1]
                target_token = continuation[:, i+1]
                position_ids = torch.tensor([[current_pos]], device=device)
                
                out = model(input_token, past_key_values=current_kv, position_ids=position_ids,
                           use_cache=True, return_dict=True)
                loss = F.cross_entropy(out.logits[:, -1, :], target_token, reduction='sum')
                
                if not torch.isnan(loss):
                    total_nll += loss.item()
                    total_tokens += 1
                
                current_kv = out.past_key_values
                current_pos += 1
    
    if total_tokens > 0:
        same_arch_results[scorer_name] = np.exp(total_nll / total_tokens)

print("\nSame-Architecture Transfer Results")
print(f"{'Method':<15} | {'PPL':>10}")
print("-"*30)
for name, ppl in same_arch_results.items():
    print(f"{name:<15} | {ppl:>10.2f}")

if 'SIP' in same_arch_results and 'Random' in same_arch_results:
    improvement = (same_arch_results['Random'] - same_arch_results['SIP']) / same_arch_results['Random'] * 100
    print(f"\nSIP improvement on held-out data: {improvement:+.1f}%")
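The perplexities reported above are computed from a running sum of per-token negative log-likelihoods, and improvements are reported relative to the Random baseline, so a positive percentage means lower perplexity than Random. A small sketch of both calculations, using made-up numbers rather than results from this notebook; `perplexity` and `relative_improvement` are hypothetical helper names:

```python
import math


def perplexity(total_nll: float, total_tokens: int) -> float:
    """exp(mean NLL); returns inf when no tokens were scored."""
    if total_tokens <= 0:
        return float("inf")
    return math.exp(total_nll / total_tokens)


def relative_improvement(baseline_ppl: float, method_ppl: float) -> float:
    """Percentage reduction in perplexity relative to the baseline."""
    return (baseline_ppl - method_ppl) / baseline_ppl * 100.0


# Illustrative values only (not measured results).
ppl_random = perplexity(total_nll=3000.0, total_tokens=1000)   # exp(3.0)
ppl_method = perplexity(total_nll=2900.0, total_tokens=1000)   # exp(2.9)
print(f"{ppl_random:.2f} {ppl_method:.2f}")
print(f"{relative_improvement(ppl_random, ppl_method):+.1f}%")
```

The relative improvement depends only on the gap in mean NLL between the two methods: a gap of `d` nats per token corresponds to `(1 - exp(-d)) * 100` percent.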

## 10. Lookahead Accuracy with GQA Fix
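Under grouped-query attention (GQA), several query heads share one key/value head, so the model returns attention weights per query head while the cache holds fewer KV heads. The `group_heads` helpers in the cell that follows reconcile the two by averaging the attention of each block of `head_group_size` consecutive query heads. A minimal pure-Python sketch of that averaging on nested lists, assuming consecutive query heads share a KV head (the layout implied by the `view(...).mean(dim=2)` call); `group_query_heads` is a hypothetical name:

```python
from typing import List


def group_query_heads(attn: List[List[float]], num_kv_heads: int) -> List[List[float]]:
    """Average per-query-head attention rows into per-KV-head rows.

    attn has one row per query head; each row holds weights over key positions.
    """
    num_heads = len(attn)
    if num_heads % num_kv_heads != 0:
        raise ValueError("number of query heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads

    grouped = []
    for g in range(num_kv_heads):
        block = attn[g * group_size:(g + 1) * group_size]
        # Element-wise mean across the query heads in this group.
        grouped.append([sum(col) / group_size for col in zip(*block)])
    return grouped


# Four query heads sharing two KV heads, three key positions.
attn = [
    [0.5, 0.25, 0.25],
    [0.25, 0.5, 0.25],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
]
print(group_query_heads(attn, num_kv_heads=2))
```

Because each input row sums to 1, each averaged row also sums to 1, so the grouped weights remain a distribution over key positions and can be compared directly with per-KV-head predictions.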

In [None]:
@torch.no_grad()
def evaluate_lookahead_accuracy(
    model, tokenizer, sip_scorer, texts,
    lookahead_steps=(1, 4, 8), num_samples=50, device='cuda'
):
    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_heads(attn):
        if head_group_size == 1:
            return attn
        if attn.dim() == 3:
            b, h, s = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, s).mean(dim=2)
        elif attn.dim() == 4:
            b, h, q, k = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, q, k).mean(dim=2)
        return attn

    results = {s: {'pearson': [], 'spearman': [], 'topk_recall': []} for s in lookahead_steps}

    for text in tqdm(texts[:num_samples], desc="Lookahead Accuracy"):
        tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        if tokens.shape[1] < 128:
            continue

        outputs = model(tokens, use_cache=True, output_attentions=True, return_dict=True)
        past_kv = outputs.past_key_values
        seq_len = tokens.shape[1]

        keys, values = get_cache_keys_values(past_kv, 0)
        positions = torch.arange(seq_len, device=device).unsqueeze(0)

        pred_importance = sip_scorer(keys, values, positions, return_all_lookahead=True)

        actual_attns = []
        current_kv = past_kv
        current_pos = seq_len

        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        for step in range(max(lookahead_steps)):
            position_ids = torch.tensor([[current_pos]], device=device)

            outputs = model(
                next_token,
                past_key_values=current_kv,
                position_ids=position_ids,
                use_cache=True,
                output_attentions=True,
                return_dict=True
            )
            current_kv = outputs.past_key_values
            current_pos += 1

            attn = outputs.attentions[-1]

            if attn.dim() == 4:
                attn = attn[:, :, -1, :]

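            # Keep only attention over the original prefill positions and renormalize,
            # since tokens generated at earlier steps also receive attention mass.
            attn = attn[:, :, :seq_len]
            attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)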
            actual_attns.append(group_heads(attn))

            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        for step in lookahead_steps:
            if step <= len(actual_attns) and step <= pred_importance.shape[-1]:
                pred = pred_importance[:, :, :, step-1].flatten().cpu().numpy()
                actual = actual_attns[step-1].flatten().cpu().numpy()

                if len(pred) == len(actual) and len(pred) > 1:
                    pearson_corr, _ = stats.pearsonr(pred, actual)
                    if not np.isnan(pearson_corr):
                        results[step]['pearson'].append(pearson_corr)

                    spearman_corr, _ = stats.spearmanr(pred, actual)
                    if not np.isnan(spearman_corr):
                        results[step]['spearman'].append(spearman_corr)

                    # Top-k recall with k = 25% of all (head, position) entries.
                    k = max(1, len(pred) // 4)
                    pred_topk = set(np.argsort(pred)[-k:])
                    actual_topk = set(np.argsort(actual)[-k:])
                    recall = len(pred_topk & actual_topk) / k
                    results[step]['topk_recall'].append(recall)

    final_results = {}
    for step in lookahead_steps:
        if results[step]['pearson']:
            final_results[step] = {
                'pearson_mean': np.mean(results[step]['pearson']),
                'pearson_std': np.std(results[step]['pearson']),
                'spearman_mean': np.mean(results[step]['spearman']),
                'spearman_std': np.std(results[step]['spearman']),
                'topk_recall_mean': np.mean(results[step]['topk_recall']),
                'topk_recall_std': np.std(results[step]['topk_recall']),
                'n_samples': len(results[step]['pearson']),
            }

    return final_results


@torch.no_grad()
def evaluate_baseline_lookahead(
    model, tokenizer, texts,
    lookahead_steps=(1, 4, 8), num_samples=30, device='cuda'
):
    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_heads(attn):
        if head_group_size == 1:
            return attn
        if attn.dim() == 3:
            b, h, s = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, s).mean(dim=2)
        elif attn.dim() == 4:
            b, h, q, k = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, q, k).mean(dim=2)
        return attn

    baselines = {
        'Prefill-Attn': [],
        'Recency': [],
        'Random': [],
    }

    results = {name: {s: [] for s in lookahead_steps} for name in baselines.keys()}

    for text in tqdm(texts[:num_samples], desc="Baseline Lookahead"):
        tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        if tokens.shape[1] < 128:
            continue

        outputs = model(tokens, use_cache=True, output_attentions=True, return_dict=True)
        past_kv = outputs.past_key_values
        seq_len = tokens.shape[1]

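        # Prefill-Attn: last-layer attention each key received, averaged over
        # all prefill queries, then pooled from attention heads to KV heads.
        prefill_attn = outputs.attentions[-1]
        if prefill_attn.dim() == 4:
            prefill_attn = prefill_attn.mean(dim=2)
        prefill_attn = group_heads(prefill_attn)
        prefill_pred = prefill_attn.flatten().cpu().numpy()

        # Recency: score grows linearly with position, identical for every KV head.
        recency = torch.arange(seq_len, device=device, dtype=torch.float32) / seq_len
        recency = recency.unsqueeze(0).unsqueeze(0).expand(1, num_kv_heads, -1)
        recency_pred = recency.flatten().cpu().numpy()

        # Random: i.i.d. uniform scores as a chance-level reference.
        random_pred = np.random.rand(num_kv_heads * seq_len)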

        actual_attns = []
        current_kv = past_kv
        current_pos = seq_len
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        for step in range(max(lookahead_steps)):
            position_ids = torch.tensor([[current_pos]], device=device)
            outputs = model(next_token, past_key_values=current_kv, position_ids=position_ids,
                            use_cache=True, output_attentions=True, return_dict=True)
            current_kv = outputs.past_key_values
            current_pos += 1

            attn = outputs.attentions[-1]
            if attn.dim() == 4:
                attn = attn[:, :, -1, :]
            attn = attn[:, :, :seq_len]
            attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)
            actual_attns.append(group_heads(attn))

            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        for step in lookahead_steps:
            if step <= len(actual_attns):
                actual = actual_attns[step-1].flatten().cpu().numpy()

                for name, pred in [('Prefill-Attn', prefill_pred),
                                    ('Recency', recency_pred),
                                    ('Random', random_pred)]:
                    if len(pred) == len(actual):
                        corr, _ = stats.spearmanr(pred, actual)
                        if not np.isnan(corr):
                            results[name][step].append(corr)

    final = {}
    for name in baselines.keys():
        final[name] = {}
        for step in lookahead_steps:
            if results[name][step]:
                final[name][step] = {
                    'mean': np.mean(results[name][step]),
                    'std': np.std(results[name][step]),
                }

    return final


print("SIP Lookahead Prediction Accuracy")
lookahead_results = evaluate_lookahead_accuracy(
    model, tokenizer, sip_scorer, test_texts,
    lookahead_steps=[1, 4, 8], num_samples=50, device=str(device)
)

print(f"\n{'Step':<8} | {'Pearson':>15} | {'Spearman':>15} | {'Top-25% Recall':>15}")
print("-"*62)
for step, metrics in lookahead_results.items():
    print(f"{step}-step   | {metrics['pearson_mean']:.3f} +/- {metrics['pearson_std']:.3f} | "
        f"{metrics['spearman_mean']:.3f} +/- {metrics['spearman_std']:.3f} | "
        f"{metrics['topk_recall_mean']:.3f} +/- {metrics['topk_recall_std']:.3f}")

print("\nBaseline Comparison (Spearman Correlation)")
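# Use the same 50 texts as the SIP evaluation above so that the Spearman
# means in the comparison table are computed on matched samples.
baseline_results = evaluate_baseline_lookahead(
    model, tokenizer, test_texts,
    lookahead_steps=[1, 4, 8], num_samples=50, device=str(device)
)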

# Column width 15 matches the "0.123 +/- 0.045" cells printed below.
print(f"\n{'Method':<15} | {'1-step':>15} | {'4-step':>15} | {'8-step':>15}")
print("-"*69)

for name in ['Prefill-Attn', 'Recency', 'Random']:
    row = f"{name:<15} |"
    for step in [1, 4, 8]:
        if step in baseline_results.get(name, {}):
            m = baseline_results[name][step]
            row += f" {m['mean']:.3f} +/- {m['std']:.3f} |"
        else:
            row += f" {'--':>15} |"
    print(row)

row = f"{'SIP (Ours)':<15} |"
for step in [1, 4, 8]:
    if step in lookahead_results:
        m = lookahead_results[step]
        row += f" {m['spearman_mean']:.3f} +/- {m['spearman_std']:.3f} |"
    else:
        row += f" {'--':>15} |"
print(row)

print("\nAnalysis")
for step in [1, 4, 8]:
    if step in lookahead_results and 'Prefill-Attn' in baseline_results and step in baseline_results['Prefill-Attn']:
        sip_corr = lookahead_results[step]['spearman_mean']
        prefill_corr = baseline_results['Prefill-Attn'][step]['mean']
        diff = sip_corr - prefill_corr
        print(f"  {step}-step: SIP {sip_corr:.3f} vs Prefill-Attn {prefill_corr:.3f} (SIP - Prefill-Attn = {diff:+.3f})")
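
To put the top-25% recall numbers in context, it helps to know the chance level: a scorer that ranks tokens at random recovers about 25% of the true top quartile on average. The sketch below checks this in pure Python. The helper names `topk_recall` and `random_recall` and the toy sizes are illustrative assumptions, not part of the notebook's evaluation code.

```python
import random


def topk_recall(pred, actual, frac=0.25):
    """Fraction of the true top-`frac` items that are also in the predicted top-`frac`."""
    k = max(1, int(len(pred) * frac))

    def top_indices(scores):
        # Indices of the k largest scores (ties broken by index order).
        order = sorted(range(len(scores)), key=lambda i: scores[i])
        return set(order[-k:])

    return len(top_indices(pred) & top_indices(actual)) / k


def random_recall(n=100, trials=2000, seed=0):
    """Average top-25% recall of uniformly random scores against a fixed target."""
    rng = random.Random(seed)
    actual = [rng.random() for _ in range(n)]
    total = 0.0
    for _ in range(trials):
        pred = [rng.random() for _ in range(n)]
        total += topk_recall(pred, actual)
    return total / trials


chance = random_recall()
print(f"random-scorer top-25% recall: {chance:.3f}")
```

A recall close to 0.25 therefore indicates no usable ranking signal, which is the reference point for reading the recall column printed above.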

## 11. Confidence Calibration with Temperature Scaling
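
Expected calibration error (ECE) is the bin-size-weighted average gap between mean confidence and accuracy, and temperature scaling rescales each confidence as sigmoid(logit(p) / T). The sketch below works both through on hand-checkable points. It is pure Python, and the function names and toy values are illustrative assumptions rather than code or data from this notebook.

```python
import math


def expected_calibration_error(confidence, correct, num_bins=10):
    """Size-weighted mean of |accuracy - confidence| over equal-width bins."""
    n = len(confidence)
    ece = 0.0
    for i in range(num_bins):
        lo, hi = i / num_bins, (i + 1) / num_bins
        # Bins are (lo, hi]; the first bin also accepts a confidence of exactly 0.
        members = [j for j, c in enumerate(confidence)
                   if (lo < c <= hi) or (i == 0 and c == lo)]
        if not members:
            continue
        avg_conf = sum(confidence[j] for j in members) / len(members)
        avg_acc = sum(correct[j] for j in members) / len(members)
        ece += len(members) / n * abs(avg_acc - avg_conf)
    return ece


def temperature_scale(p, temperature):
    """Rescale a probability p in (0, 1) as sigmoid(logit(p) / temperature)."""
    logit = math.log(p / (1.0 - p))
    return 1.0 / (1.0 + math.exp(-logit / temperature))


# Toy data: two confident predictions (one wrong), two hesitant ones (both right).
confidence = [0.85, 0.85, 0.55, 0.55]
correct = [1, 0, 1, 1]

# Bin (0.8, 0.9]: |0.50 - 0.85| = 0.35 with weight 0.5.
# Bin (0.5, 0.6]: |1.00 - 0.55| = 0.45 with weight 0.5.
# ECE = 0.175 + 0.225 = 0.4.
ece = expected_calibration_error(confidence, correct)

# logit(0.9) = ln 9; halving it gives ln 3, and sigmoid(ln 3) = 0.75.
softened = temperature_scale(0.9, 2.0)
print(f"ECE = {ece:.3f}, 0.9 at T=2 -> {softened:.3f}")
```

Temperatures above 1 pull confidences toward 0.5 and temperatures below 1 push them toward 0 or 1; T = 1 leaves them unchanged.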

In [None]:
@torch.no_grad()
def evaluate_calibration(
    model, tokenizer, sip_scorer, texts, 
    num_samples=50, num_bins=10, device='cuda'
):
    num_attention_heads = model.config.num_attention_heads
    num_kv_heads = getattr(model.config, 'num_key_value_heads', num_attention_heads)
    head_group_size = num_attention_heads // num_kv_heads

    def group_heads(attn):
        if head_group_size == 1:
            return attn
        if attn.dim() == 3:
            b, h, s = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, s).mean(dim=2)
        elif attn.dim() == 4:
            b, h, q, k = attn.shape
            return attn.view(b, num_kv_heads, head_group_size, q, k).mean(dim=2)
        return attn

    all_conf = []
    all_correct = []
    all_rank_error = []

    for text in tqdm(texts[:num_samples], desc="Calibration"):
        tokens = tokenizer(text, return_tensors='pt', max_length=512, truncation=True).input_ids.to(device)
        seq_len = tokens.shape[1]
        if seq_len < 128:
            continue

        outputs = model(tokens, use_cache=True, output_attentions=True, return_dict=True)
        past_kv = outputs.past_key_values

        keys, values = get_cache_keys_values(past_kv, 0)
        positions = torch.arange(seq_len, device=device).unsqueeze(0)

        kv_features = sip_scorer.kv_encoder(keys.float(), values.float())
        pred, conf = sip_scorer.predictor(kv_features, positions, seq_len)

        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        position_ids = torch.tensor([[seq_len]], device=device)

        outputs = model(
            next_token,
            past_key_values=past_kv,
            position_ids=position_ids,
            output_attentions=True,
            return_dict=True
        )

        actual = outputs.attentions[-1]
        if actual.dim() == 4:
            actual = actual[:, :, -1, :]
        actual = actual[:, :, :seq_len]
        actual = actual / (actual.sum(dim=-1, keepdim=True) + 1e-8)
        actual = group_heads(actual)

        pred_step0 = pred[:, :, :, 0]
        conf_step0 = conf[:, :, :, 0]

        k = max(1, seq_len // 4)

        for h in range(pred_step0.shape[1]):
            pred_h = pred_step0[0, h, :].cpu().numpy()
            actual_h = actual[0, h, :].cpu().numpy()
            conf_h = conf_step0[0, h, :].cpu().numpy()

            pred_topk = set(np.argsort(pred_h)[-k:])
            actual_topk = set(np.argsort(actual_h)[-k:])

            # Normalized ranks are computed once per head; recomputing the
            # double argsort for every token made this loop O(n^2 log n).
            pred_ranks = np.argsort(np.argsort(pred_h)) / seq_len
            actual_ranks = np.argsort(np.argsort(actual_h)) / seq_len

            for i in range(seq_len):
                # "Correct" means predicted and actual top-25% membership agree.
                pred_is_topk = i in pred_topk
                actual_is_topk = i in actual_topk
                is_correct = (pred_is_topk == actual_is_topk)

                all_conf.append(conf_h[i])
                all_correct.append(float(is_correct))
                all_rank_error.append(abs(pred_ranks[i] - actual_ranks[i]))

    return {
        'confidence': np.array(all_conf),
        'correct': np.array(all_correct),
        'rank_error': np.array(all_rank_error),
    }


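def compute_ece(confidence, accuracy, num_bins=10):
    """Expected calibration error: bin-size-weighted mean of |accuracy - confidence|."""
    bins = np.linspace(0, 1, num_bins + 1)
    ece = 0.0
    bin_stats = []

    for i in range(num_bins):
        # Bins are (lower, upper]; the first bin also includes its lower edge
        # so that a confidence of exactly 0.0 is not silently dropped.
        above_lower = (confidence >= bins[i]) if i == 0 else (confidence > bins[i])
        in_bin = above_lower & (confidence <= bins[i+1])
        bin_size = in_bin.sum()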

        if bin_size > 0:
            bin_acc = accuracy[in_bin].mean()
            bin_conf = confidence[in_bin].mean()
            bin_error = abs(bin_acc - bin_conf)
            ece += (bin_size / len(confidence)) * bin_error
            bin_stats.append({
                'bin': f"({bins[i]:.1f}, {bins[i+1]:.1f}]",
                'count': bin_size,
                'avg_conf': bin_conf,
                'avg_acc': bin_acc,
                'error': bin_error,
            })
        else:
            bin_stats.append({
                'bin': f"({bins[i]:.1f}, {bins[i+1]:.1f}]",
                'count': 0,
                'avg_conf': 0,
                'avg_acc': 0,
                'error': 0,
            })

    return ece, bin_stats


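def apply_temperature_scaling(confidence, correct, num_iters=50):
    """Grid-search the temperature T that minimizes ECE of sigmoid(logit(conf) / T).

    `num_iters` is the number of log-spaced candidates in [0.1, 10]. T is
    selected on the same data the ECE is reported on, so the resulting
    "after" ECE is an optimistic estimate.
    """
    best_ece = float('inf')
    best_temp = 1.0

    for temp in np.logspace(-1, 1, num_iters):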
        scaled_conf = 1 / (1 + np.exp(-np.log(confidence / (1 - confidence + 1e-8)) / temp))
        scaled_conf = np.clip(scaled_conf, 0.01, 0.99)
        ece, _ = compute_ece(scaled_conf, correct)
        if ece < best_ece:
            best_ece = ece
            best_temp = temp

    return best_temp, best_ece


print("Collecting Calibration Data")
calib_data = evaluate_calibration(
    model, tokenizer, sip_scorer, test_texts,
    num_samples=50, device=str(device)
)

confidence = calib_data['confidence']
correct = calib_data['correct']
rank_error = calib_data['rank_error']

print(f"Collected {len(confidence)} confidence-accuracy pairs")
print(f"Mean confidence: {confidence.mean():.3f}")
print(f"Mean accuracy: {correct.mean():.3f}")
print(f"Mean rank error: {rank_error.mean():.3f}")

print("\nECE Before Temperature Scaling")
ece_before, bins_before = compute_ece(confidence, correct)
print(f"ECE: {ece_before:.4f}")

print(f"\n{'Bin':<15} | {'Count':>8} | {'Avg Conf':>10} | {'Avg Acc':>10} | {'Error':>8}")
print("-"*60)
for b in bins_before:
    if b['count'] > 0:
        print(f"{b['bin']:<15} | {b['count']:>8} | {b['avg_conf']:>10.3f} | {b['avg_acc']:>10.3f} | {b['error']:>8.3f}")

print("\nTemperature Scaling")
optimal_temp, ece_after = apply_temperature_scaling(confidence, correct)
print(f"Optimal temperature: {optimal_temp:.4f}")
print(f"ECE after scaling: {ece_after:.4f}")
print(f"ECE improvement: {((ece_before - ece_after) / ece_before) * 100:.1f}%")

scaled_conf = 1 / (1 + np.exp(-np.log(confidence / (1 - confidence + 1e-8)) / optimal_temp))
scaled_conf = np.clip(scaled_conf, 0.01, 0.99)
_, bins_after = compute_ece(scaled_conf, correct)

print(f"\n{'Bin':<15} | {'Count':>8} | {'Avg Conf':>10} | {'Avg Acc':>10} | {'Error':>8}")
print("-"*60)
for b in bins_after:
    if b['count'] > 0:
        print(f"{b['bin']:<15} | {b['count']:>8} | {b['avg_conf']:>10.3f} | {b['avg_acc']:>10.3f} | {b['error']:>8.3f}")

print("\nCalibration Summary")
print(f"ECE (before): {ece_before:.4f}")
print(f"ECE (after):  {ece_after:.4f}")
print(f"Optimal T:    {optimal_temp:.4f}")

calib_results = {
    'ece_before': ece_before,
    'ece_after': ece_after,
    'optimal_temperature': optimal_temp,
    'mean_confidence': float(confidence.mean()),
    'mean_accuracy': float(correct.mean()),
}

## 12. Final Results Summary
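
One detail to keep in mind when reloading the saved results: JSON object keys are always strings, so the float retention ratios used as dictionary keys come back as `"0.1"`, `"0.25"`, and so on. The sketch below shows the round trip on a hypothetical `{method: {ratio: value}}` dictionary whose numbers are placeholders, not measured results.

```python
import json

# Hypothetical structure mirroring {method: {ratio: value}}; values are placeholders.
results = {"Random": {0.10: 1.0, 0.25: 2.0}}

# Convert every ratio key with str(), as the serialization helper in this
# section does for dictionary keys.
serialized = json.dumps(
    {method: {str(ratio): value for ratio, value in by_ratio.items()}
     for method, by_ratio in results.items()}
)

loaded = json.loads(serialized)

# Convert ratio keys back to float before indexing with numeric ratios.
restored = {
    method: {float(ratio): value for ratio, value in by_ratio.items()}
    for method, by_ratio in loaded.items()
}

print(list(loaded["Random"]))   # string keys after loading
print(restored["Random"][0.25])  # numeric lookup works again
```

Looking up `loaded["Random"][0.25]` directly would raise a `KeyError`, which is an easy mistake to make when post-processing the JSON file.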

In [None]:
print("Final Results Summary")
print(f"Model: TinyLlama-1.1B")
print(f"SIP Parameters: {sum(p.numel() for p in sip_scorer.parameters()):,}")

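if 'pregeneration_results' in dir():
    # baseline_ppl comes from a separate cell; guard it like the other optional results.
    if 'baseline_ppl' in dir():
        print(f"\nBaseline (Full Cache): {baseline_ppl:.2f}")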
    print(f"\n{'Method':<20} |", end="")
    for ratio in [0.10, 0.25, 0.50, 0.75]:
        print(f" {int(ratio*100):>4}% |", end="")
    print("")
    print("-"*55)

    method_order = ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic',
                    'Expected-Attn', 'TRIM-KV', 'Random']
    for method in method_order:
        if method in pregeneration_results:
            row = f"{method:<20} |"
            for ratio in [0.10, 0.25, 0.50, 0.75]:
                if ratio in pregeneration_results[method]:
                    ppl = pregeneration_results[method][ratio]
                    row += f" {ppl:>5.2f} |"
                else:
                    row += f" {'--':>5} |"
            print(row)

if 'multiseed_results' in dir():
    print(f"\nMulti-Seed Results (5 seeds, 95% CI)")
    print(f"\n{'Method':<20} |", end="")
    for ratio in [0.10, 0.25, 0.50, 0.75]:
        print(f" {int(ratio*100):>11}% |", end="")
    print("")
    print("-"*70)

    for method in ['SIP (Ours)', 'Prefill-Attn', 'Position-Heuristic', 'Random']:
        if method in multiseed_results:
            row = f"{method:<20} |"
            for ratio in [0.10, 0.25, 0.50, 0.75]:
                if ratio in multiseed_results[method]:
                    m = multiseed_results[method][ratio]
                    row += f" {m['mean']:.2f}+/-{m['ci_95']:.2f} |"
                else:
                    row += f" {'--':>12} |"
            print(row)

if 'lookahead_results' in dir():
    print(f"\nLookahead Prediction Accuracy")
    print(f"\n{'Step':<8} | {'Spearman':>12} | {'Top-25% Recall':>15}")
    print("-"*45)
    for step in [1, 4, 8]:
        if step in lookahead_results:
            m = lookahead_results[step]
            spearman = m.get('spearman_mean', m.get('mean', 0))
            spearman_std = m.get('spearman_std', m.get('std', 0))
            recall = m.get('topk_recall_mean', 0)
            recall_std = m.get('topk_recall_std', 0)
            print(f"{step}-step   | {spearman:.3f} +/- {spearman_std:.3f} | {recall:.3f} +/- {recall_std:.3f}")

if 'calib_results' in dir():
    print(f"\nConfidence Calibration")
    print(f"  ECE (before): {calib_results['ece_before']:.4f}")
    print(f"  ECE (after):  {calib_results['ece_after']:.4f}")
    print(f"  Optimal T:    {calib_results['optimal_temperature']:.4f}")
elif 'ece_before' in dir():
    print(f"\nConfidence Calibration")
    print(f"  ECE (before): {ece_before:.4f}")
    print(f"  ECE (after):  {ece_after:.4f}")
    print(f"  Optimal T:    {optimal_temp:.4f}")

def to_serializable(obj):
    if isinstance(obj, dict):
        return {str(k): to_serializable(v) for k, v in obj.items()}
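    elif isinstance(obj, (list, tuple)):
        return [to_serializable(v) for v in obj]
    elif isinstance(obj, np.integer):
        # Keep integer counts as ints instead of coercing them to floats.
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)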
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif hasattr(obj, 'item'):
        return obj.item()
    else:
        return obj

final_results = {
    'model': 'TinyLlama-1.1B',
    'sip_parameters': sum(p.numel() for p in sip_scorer.parameters()),
    'training_config': TRAINING_CONFIG,
}

if 'baseline_ppl' in dir():
    final_results['baseline_perplexity'] = float(baseline_ppl)

if 'pregeneration_results' in dir():
    final_results['perplexity_results'] = to_serializable(pregeneration_results)

if 'multiseed_results' in dir():
    final_results['multiseed_results'] = to_serializable(multiseed_results)

if 'qa_results' in dir():
    final_results['qa_accuracy'] = to_serializable(qa_results)

if 'needle_results' in dir():
    final_results['needle_in_haystack'] = to_serializable(needle_results)

if 'ablation_results' in dir():
    final_results['ablation_study'] = to_serializable(ablation_results)

if 'lookahead_results' in dir():
    final_results['lookahead_accuracy'] = to_serializable(lookahead_results)

if 'calib_results' in dir():
    final_results['calibration'] = to_serializable(calib_results)
elif 'ece_before' in dir():
    final_results['calibration'] = {
        'ece_before': float(ece_before),
        'ece_after': float(ece_after),
        'optimal_temperature': float(optimal_temp),
    }

if 'history' in dir():
    final_results['training_history'] = to_serializable(history)

with open('sip_comprehensive_results.json', 'w') as f:
    json.dump(final_results, f, indent=2)
print("\nResults saved to sip_comprehensive_results.json")

## Conclusions

This notebook provides comprehensive experiments for our negative results paper on learned KV cache compression.

### Key Findings

1. **SIP does not outperform simple heuristics** at any retention level or on any task
2. **Position-based methods** (sinks + recent) win at aggressive compression (10% and 25% retention)
3. **Prefill attention** wins at moderate compression (50% and 75% retention)
4. **SIP ≈ Random**: no statistically significant difference in multi-seed evaluation
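
The "no statistically significant difference" finding rests on paired t-tests over seeds. A minimal sketch of such a test in pure Python follows, assuming the pairing is by seed. The per-seed perplexities are hypothetical placeholders, not results from the paper; 2.776 is the standard two-sided 5% critical value of Student's t with 4 degrees of freedom.

```python
import math
import statistics

# Hypothetical perplexities for 5 seeds at one retention level (placeholder values).
sip_ppl = [12.2, 11.9, 12.3, 12.0, 12.2]
random_ppl = [12.0, 12.1, 12.2, 12.1, 12.1]

# Paired test: work on per-seed differences, not on the two groups separately.
diffs = [s - r for s, r in zip(sip_ppl, random_ppl)]
mean_diff = statistics.mean(diffs)
std_err = statistics.stdev(diffs) / math.sqrt(len(diffs))
t_stat = mean_diff / std_err

T_CRIT = 2.776  # two-sided 95% critical value, Student's t with df = 4
ci_low = mean_diff - T_CRIT * std_err
ci_high = mean_diff + T_CRIT * std_err
significant = abs(t_stat) > T_CRIT

print(f"mean diff = {mean_diff:+.3f}, t = {t_stat:.3f}, "
      f"95% CI = [{ci_low:+.3f}, {ci_high:+.3f}], significant = {significant}")
```

With only five seeds the interval is wide, so a non-significant result shows that the data cannot separate the two methods, not that they are proven identical.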

### Practical Recommendations

For practitioners deploying KV cache compression:
- **Aggressive compression (10% and 25% retention):** Use Position-Heuristic (keep sinks + recent tokens)
- **Moderate compression (50% and 75% retention):** Use Prefill-Attn (attention from prompt processing)
- **Learned methods provide no benefit** in our evaluation setting
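
Both recommended heuristics reduce to choosing which cache positions to keep under a fixed budget. The sketch below is a pure-Python illustration, not the notebook's implementation: the function names, the default of 4 sink tokens, and the toy attention matrix are assumptions made for the example.

```python
def position_heuristic_keep(seq_len, retention, num_sinks=4):
    """Keep the first `num_sinks` positions plus the most recent ones.

    Assumption: sinks are filled first, and the rest of the budget goes to
    the most recent tokens.
    """
    budget = max(1, int(seq_len * retention))
    n_sink = min(num_sinks, budget)
    n_recent = budget - n_sink
    return list(range(n_sink)) + list(range(seq_len - n_recent, seq_len))


def prefill_attn_keep(attn, retention):
    """Keep the keys that received the most attention, averaged over prefill queries.

    `attn[q][k]` is the weight query q places on key k for a single head.
    """
    seq_len = len(attn)
    budget = max(1, int(seq_len * retention))
    scores = [sum(row[k] for row in attn) / seq_len for k in range(seq_len)]
    top = sorted(range(seq_len), key=lambda k: scores[k], reverse=True)[:budget]
    return sorted(top)


# Toy causal attention matrix for a 4-token prompt (each row sums to 1).
toy_attn = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.1, 0.7, 0.0],
    [0.4, 0.1, 0.1, 0.4],
]

kept_by_position = position_heuristic_keep(seq_len=20, retention=0.25)
kept_by_attention = prefill_attn_keep(toy_attn, retention=0.5)
print(kept_by_position, kept_by_attention)
```

Neither selection rule has trainable parameters or looks at future queries.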

### Implications for Research

Our results suggest that for **non-query-aware** importance scoring under fixed-budget compression:
- Simple heuristics capture most of the predictable signal
- The marginal information in KV representations beyond position appears limited
- The circular dependence between future queries and generation may be a fundamental barrier

See the paper for full analysis and discussion of the query-aware vs non-query-aware distinction.