In [None]:
!pip install -q "datasets<=3.0" rouge_score transformers accelerate
!huggingface-cli login --token 

In [None]:
import torch, torch.nn.functional as F, math, json, re, string
from collections import Counter
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from tqdm import tqdm
import types

# Hugging Face model id passed to the tokenizer / model loaders.
MODEL = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Values eval_task forwards to patch_model:
#   BUDGET - prefill length at/above which eviction runs, and the number of KV entries kept
#   WINDOW - number of most recent positions that are always kept
#   SINK   - number of leading positions that are always kept
BUDGET, WINDOW, SINK = 2048, 64, 4
# Default number of dataset examples scored by eval_task.
SAMPLES = 20

In [None]:
# Tokenizer first; fall back to EOS as the padding token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model in fp16 with eager attention (FlashAttention disabled) so that
# behaviour is consistent with the hand-written attention used when patching.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    attn_implementation='eager',
    device_map='auto',
    torch_dtype=torch.float16,
)
base_model.eval()

# Report which attention class the first decoder layer ended up with.
first_attn = base_model.model.layers[0].self_attn
print(f"Attention class: {type(first_attn)}")
print("Model loaded successfully")

In [None]:
def atom_h2o(attn, window, sink):
    """Heavy-hitter score: total attention each non-window key receives.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: number of trailing key positions excluded from scoring.
        sink: unused here; accepted so all atom functions share one signature.

    Returns:
        1-D tensor with one score per key position before the trailing window:
        attention summed over every query, then averaged over batch and heads.
    """
    scored_keys = attn[..., :-window]          # drop the trailing window columns
    received = scored_keys.sum(dim=-2)         # accumulate over all query rows
    return received.mean(dim=(0, 1))           # average over batch and heads

def atom_spectral(attn, window, sink):
    """Spectral score: power iteration on A^T A of the head-averaged attention.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: number of trailing key positions excluded from scoring.
        sink: unused here; accepted so all atom functions share one signature.

    Returns:
        1-D tensor (one entry per non-window key) obtained after 10 normalised
        power-iteration steps started from the all-ones vector.
    """
    head_avg = attn.mean(dim=(0, 1))
    A = head_avg[:, :-window]
    vec = torch.ones(A.shape[1], device=A.device, dtype=A.dtype)
    remaining = 10
    while remaining > 0:
        vec = A.T @ (A @ vec)
        vec = vec / (vec.norm() + 1e-8)   # renormalise to keep the iterate bounded
        remaining -= 1
    return vec

def atom_sharpness(attn, window, sink):
    """Sharpness score: low entropy of the attention a key receives scores high.

    For every non-window key, the head-averaged attention it receives from all
    queries is normalised into a distribution over queries; the score is
    ``log(window)`` minus the entropy of that distribution.

    The entropy is computed in float32. In float16 the ``1e-8`` stabilisers
    round to zero, so zero attention entries produced ``0 / 0`` and
    ``0 * log(0)`` and the score became NaN.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: number of trailing key positions excluded from scoring.
        sink: unused here; accepted so all atom functions share one signature.

    Returns:
        1-D tensor with one score per non-window key, in the dtype of ``attn``.
    """
    # NOTE(review): the distribution runs over all query rows, so its maximum
    # entropy is log(number of queries), not log(window); scores can therefore
    # be negative. The offset is kept as-is because it does not affect ranking.
    a = attn.mean(dim=(0, 1))[:, :-window].float()
    a = a / (a.sum(dim=0, keepdim=True) + 1e-8)
    ent = -(a * (a + 1e-8).log()).sum(dim=0)
    return (math.log(window) - ent).to(attn.dtype)

def atom_combined(attn, window, sink):
    """Element-wise maximum of the scaled heavy-hitter and sharpness scores.

    Each score vector is scaled by its largest absolute value. For non-negative
    scores this equals the previous max-scaling; for vectors whose maximum is
    negative (sharpness can be) it no longer flips the ranking. Scaling is done
    in float32 so the ``1e-8`` stabiliser does not round to zero in float16.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: number of trailing key positions excluded from scoring.
        sink: forwarded to the underlying atom functions.

    Returns:
        1-D tensor with one score per non-window key, in the dtype of ``attn``.
    """
    h = atom_h2o(attn, window, sink).float()
    s = atom_sharpness(attn, window, sink).float()
    h = h / (h.abs().max() + 1e-8)
    s = s / (s.abs().max() + 1e-8)
    return torch.maximum(h, s).to(attn.dtype)

def bond_walker(attn, window, sink, walkers=512, steps=100):
    """Visit counts of random walks that start at the last query position.

    Each walker repeatedly jumps from its current position to a key sampled
    from that position's head-averaged attention row (restricted to keys
    0..position). A walk ends after ``steps`` visits, when it lands on a
    position below ``sink``, or when the row carries (almost) no mass.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: number of trailing positions dropped from the returned counts.
        sink: positions below this index absorb the walker.
        walkers: number of independent walks.
        steps: maximum number of visits per walk.

    Returns:
        1-D float tensor of visit counts for every position before the window.
    """
    transition = attn.mean(dim=(0, 1))  # [seq_len, seq_len]
    n_positions = transition.shape[0]
    visits = torch.zeros(n_positions, device=transition.device)
    start = n_positions - 1  # last query position

    for _walker in range(walkers):
        here = start
        budget_left = steps
        while budget_left > 0:
            budget_left -= 1
            visits[here] += 1
            if here < sink:
                break  # absorbed at a sink position

            # Attention from the current position to keys 0..here (causal part).
            row = transition[here, :here + 1]
            mass = row.sum()
            if mass < 1e-8:
                break
            here = torch.multinomial(row / mass, 1).item()

    return visits[:-window]

def bond_multisource(attn, window, sink, walkers=256, steps=100):
    """Visit counts of random walks started from every position of the recent window.

    ``walkers // window`` walks are launched from each of the last ``window``
    positions. Every walk follows the head-averaged attention row of its current
    position (keys 0..position) and stops after ``steps`` visits, on reaching a
    position below ``sink``, or when the row carries (almost) no mass.

    Args:
        attn: attention probabilities indexed as [batch, head, query, key].
        window: size of the recent window; also the number of start positions
            and the number of trailing positions dropped from the result.
        sink: positions below this index absorb the walker.
        walkers: total number of walks, split evenly across start positions.
        steps: maximum number of visits per walk.

    Returns:
        1-D float tensor of visit counts for every position before the window.
    """
    transition = attn.mean(dim=(0, 1))  # [seq_len, seq_len]
    n_positions = transition.shape[0]
    visits = torch.zeros(n_positions, device=transition.device)
    walks_per_start = walkers // window
    first_start = n_positions - window

    for origin in range(first_start, first_start + window):
        for _walk in range(walks_per_start):
            here = origin
            budget_left = steps
            while budget_left > 0:
                budget_left -= 1
                visits[here] += 1
                if here < sink:
                    break  # absorbed at a sink position

                # Attention from the current position to keys 0..here (causal part).
                row = transition[here, :here + 1]
                mass = row.sum()
                if mass < 1e-8:
                    break
                here = torch.multinomial(row / mass, 1).item()

    return visits[:-window]

# Registry of per-token ("atom") scoring functions; all share (attn, window, sink).
ATOMS = {'h2o': atom_h2o, 'spectral': atom_spectral, 'sharpness': atom_sharpness, 'combined': atom_combined}
# Registry of random-walk ("bond") scoring functions with the same signature.
# 'none' yields an all-zero score vector with the same shape, dtype and device as
# the atom scores; it is built from a single sliced row instead of reducing the
# whole attention matrix just to discard the result.
BONDS = {
    'none': lambda a, w, s: torch.zeros_like(a[0, 0, 0, :-w]),
    'walker': bond_walker,
    'multi': bond_multisource,
}

def combine_static(atom, bond, budget, ratio=0.8):
    """Split the budget between atom and bond scores by a fixed ratio.

    ``int(budget * ratio)`` positions are taken from the top atom scores; the
    remainder of the budget is filled with the top bond scores among positions
    not already chosen.

    Scores are scaled in float32 by their largest absolute value. Compared with
    dividing by ``max()`` this is identical for non-negative scores, but it does
    not invert the ranking when every score is negative, and the ``1e-8``
    stabiliser no longer rounds to zero for float16 inputs (which turned an
    all-zero score vector into NaN).

    Args:
        atom: 1-D tensor of atom scores.
        bond: 1-D tensor of bond scores, same length as ``atom``.
        budget: total number of positions to keep.
        ratio: fraction of the budget given to the atom scores.

    Returns:
        1-D bool tensor (on ``atom``'s device) marking the kept positions.
    """
    atom = atom.float()
    bond = bond.float()
    atom = atom / (atom.abs().max() + 1e-8)
    bond = bond / (bond.abs().max() + 1e-8)

    n_atom = int(budget * ratio)
    n_bond = budget - n_atom
    atom_idx = atom.topk(min(n_atom, len(atom))).indices

    # Positions still available for the bond-driven part of the budget.
    available = torch.ones_like(bond, dtype=torch.bool)
    available[atom_idx] = False
    available_pos = available.nonzero(as_tuple=True)[0]
    # Use a plain int for k rather than a 0-dim tensor.
    n_bond = min(n_bond, available_pos.numel())
    bond_idx = bond[available_pos].topk(n_bond).indices

    keep = torch.zeros(len(atom), dtype=torch.bool, device=atom.device)
    keep[atom_idx] = True
    keep[available_pos[bond_idx]] = True
    return keep

def combine_max(atom, bond, budget):
    """Keep the ``budget`` positions with the highest max(atom, bond) score.

    Both score vectors are scaled in float32 by their largest absolute value
    before the element-wise maximum is taken. For non-negative scores this
    matches dividing by ``max()``. It additionally avoids two failures: with
    float16 inputs the ``1e-8`` stabiliser rounded to zero, so an all-zero bond
    vector became NaN and ``torch.maximum`` propagated NaN into every score; and
    a vector whose maximum is negative had its ranking inverted.

    Args:
        atom: 1-D tensor of atom scores.
        bond: 1-D tensor of bond scores, same length as ``atom``.
        budget: number of positions to keep (capped at the vector length).

    Returns:
        1-D bool tensor (on ``atom``'s device) marking the kept positions.
    """
    atom = atom.float()
    bond = bond.float()
    atom = atom / (atom.abs().max() + 1e-8)
    bond = bond / (bond.abs().max() + 1e-8)
    scores = torch.maximum(atom, bond)
    idx = scores.topk(min(budget, len(scores))).indices
    keep = torch.zeros(len(atom), dtype=torch.bool, device=atom.device)
    keep[idx] = True
    return keep

def combine_dynamic(atom, bond, budget):
    """Static combination whose atom/bond ratio adapts to their top-k agreement.

    The ratio handed to ``combine_static`` is ``0.6 + 0.3 * overlap``, where
    ``overlap`` is the share of the budget on which the top-``budget`` atom and
    bond positions coincide (0 when the budget is not positive).

    Args:
        atom: 1-D tensor of atom scores.
        bond: 1-D tensor of bond scores.
        budget: number of positions to keep.

    Returns:
        The bool keep-mask produced by ``combine_static``.
    """
    def _top_positions(scores):
        k = min(budget, len(scores))
        return set(scores.topk(k).indices.tolist())

    shared = _top_positions(atom) & _top_positions(bond)
    overlap = len(shared) / budget if budget > 0 else 0
    return combine_static(atom, bond, budget, 0.6 + 0.3 * overlap)

# Registry of strategies that merge atom and bond scores into a keep-mask.
# Every entry is called as fn(atom_scores, bond_scores, budget).
COMBINES = {
    'static80': lambda atom, bond, budget: combine_static(atom, bond, budget, 0.8),
    'static70': lambda atom, bond, budget: combine_static(atom, bond, budget, 0.7),
    'max': combine_max,
    'dynamic': combine_dynamic,
}

In [None]:
# Retained for backward compatibility with earlier revisions of this notebook.
# The pre-patch forward of every attention module is saved on the module itself
# (attribute `_poc_original_forward`, see patch_model / unpatch_model), so it can
# neither go stale nor keep a previously loaded model alive.
_original_forwards = {}
# Counter used to print the eviction banner only once per patch_model call.
_debug_eviction = {'count': 0}

def patch_model(model, atom_fn, bond_fn, combine_fn, budget=2048, window=64, sink=4):
    """Patch model attention layers for KV eviction. Follows KVCache-Factory structure.

    Every ``self_attn`` module of ``model.model.layers`` receives the scoring
    configuration as ``_poc_*`` attributes and has its ``forward`` replaced by
    ``patched_forward``. The forward that was in place before the first patch is
    saved on the module as ``_poc_original_forward`` so ``unpatch_model`` can
    restore it; calling ``patch_model`` again only refreshes the configuration.

    Args:
        model: causal LM exposing ``model.model.layers[i].self_attn`` and ``config``.
        atom_fn: callable ``(attn, window, sink)`` returning per-token scores.
        bond_fn: callable ``(attn, window, sink)`` returning per-token scores.
        combine_fn: callable ``(atom, bond, n)`` returning a bool keep-mask.
        budget: prefill length at/above which eviction runs; also the number of
            cache entries kept (sink + selected middle + window).
        window: number of most recent positions that are always kept.
        sink: number of leading positions that are always kept.

    Returns:
        The same ``model`` object, patched in place.
    """
    _debug_eviction['count'] = 0

    config = model.config
    
    # Get the rotary embedding from the model (location varies by transformers version)
    if hasattr(model.model.layers[0].self_attn, 'rotary_emb'):
        rotary_emb = model.model.layers[0].self_attn.rotary_emb
    else:
        rotary_emb = model.model.rotary_emb

    # Compute attention dimensions from config
    num_heads = config.num_attention_heads
    num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
    head_dim = config.hidden_size // num_heads
    num_key_value_groups = num_heads // num_key_value_heads
    hidden_size = config.hidden_size

    # Store config on each attention module
    for layer_idx in range(len(model.model.layers)):
        attn = model.model.layers[layer_idx].self_attn
        attn._poc_atom_fn = atom_fn
        attn._poc_bond_fn = bond_fn
        attn._poc_combine_fn = combine_fn
        attn._poc_budget = budget
        attn._poc_window = window
        attn._poc_sink = sink
        attn._poc_kv_seq_len = 0
        attn._poc_rotary_emb = rotary_emb
        # Store computed values
        attn._poc_num_heads = num_heads
        attn._poc_num_key_value_heads = num_key_value_heads
        attn._poc_head_dim = head_dim
        attn._poc_num_key_value_groups = num_key_value_groups
        attn._poc_hidden_size = hidden_size

    def patched_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask = None,
        position_ids = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position = None,
        position_embeddings = None,
        **kwargs,
    ):
        """Attention forward that evicts KV entries once, during a long prefill.

        Prefill with ``q_len >= budget``: attention is computed over the full
        prompt, tokens are scored, and only sink + selected middle + window
        entries are written to the cache. Shorter prefills and generation steps
        update the cache normally.

        NOTE(review): the 3-tuple return value and the ``past_key_value`` keyword
        match the transformers version this notebook was written against --
        confirm against the installed version.
        """
        bsz, q_len, _ = hidden_states.size()
        
        # Get all config from stored attributes
        atom_fn = self._poc_atom_fn
        bond_fn = self._poc_bond_fn
        combine_fn = self._poc_combine_fn
        budget = self._poc_budget
        window = self._poc_window
        sink = self._poc_sink
        rotary_emb = self._poc_rotary_emb
        num_heads = self._poc_num_heads
        num_key_value_heads = self._poc_num_key_value_heads
        head_dim = self._poc_head_dim
        num_key_value_groups = self._poc_num_key_value_groups
        hidden_size = self._poc_hidden_size

        # Project Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

        # Track KV sequence length (the logical length, i.e. including evicted tokens)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self._poc_kv_seq_len != 0:
                kv_seq_len += self._poc_kv_seq_len
            else:
                # get_seq_length replaces the deprecated get_usable_length; for a
                # cache without a maximum length both report the cached length.
                kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

        # Apply RoPE
        if position_embeddings is None:
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        
        # Expand KV for GQA (the cache therefore stores the expanded heads)
        key_states = repeat_kv(key_states, num_key_value_groups)
        value_states = repeat_kv(value_states, num_key_value_groups)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Prefill <=> nothing has been cached for this sequence yet.
            is_prefill = (key_states.shape[-2] == kv_seq_len)
            
            if is_prefill and q_len >= budget:
                # PREFILL WITH EVICTION
                self._poc_kv_seq_len = kv_seq_len
                
                if self.layer_idx == 0 and _debug_eviction['count'] == 0:
                    print(f'[EVICTION] q_len={q_len} >= budget={budget}')
                    _debug_eviction['count'] += 1
                
                # Compute attention for scoring
                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
                
                # Causal mask for window (only the trailing window x window block is masked)
                mask = torch.full((window, window), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
                mask_cond = torch.arange(window, device=attn_weights.device)
                mask.masked_fill_(mask_cond < (mask_cond + 1).view(window, 1), 0)
                attn_weights[:, :, -window:, -window:] += mask[None, None, :, :]
                
                attn_weights_for_score = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                
                # Score tokens
                atom = atom_fn(attn_weights_for_score, window, sink)
                bond = bond_fn(attn_weights_for_score, window, sink)
                
                # Keep: sink + selected middle + window
                atom_middle = atom[sink:]
                bond_middle = bond[sink:]
                keep_n = budget - window - sink
                keep_middle = combine_fn(atom_middle, bond_middle, min(keep_n, len(atom_middle)))
                
                keep_mask = torch.zeros(q_len, dtype=torch.bool, device=key_states.device)
                keep_mask[:sink] = True
                keep_mask[sink:sink+len(keep_middle)] = keep_middle
                keep_mask[-window:] = True
                
                # The same positions are kept for every batch element and head.
                keep_indices = keep_mask.nonzero(as_tuple=True)[0]
                idx = keep_indices.view(1, 1, -1, 1).expand(bsz, num_heads, -1, head_dim)
                
                # Only the compressed K/V go into the cache; this prefill step
                # itself still attends over the full key_states below.
                key_compress = key_states.gather(dim=2, index=idx)
                val_compress = value_states.gather(dim=2, index=idx)
                past_key_value.update(key_compress, val_compress, self.layer_idx, cache_kwargs)
                
            elif is_prefill:
                # PREFILL WITHOUT EVICTION
                self._poc_kv_seq_len = kv_seq_len
                past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            else:
                # GENERATION
                self._poc_kv_seq_len += q_len
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            
            past_key_value._seen_tokens = self._poc_kv_seq_len

        # Compute attention output
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            # Trim the mask to the (possibly compressed) number of keys.
            attn_weights = attn_weights + attention_mask[:, :, :, :key_states.shape[-2]]
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None if not output_attentions else attn_weights, past_key_value

    # Patch all layers, remembering the forward that was in place before the
    # first patch so that unpatch_model can restore it.
    for layer_idx in range(len(model.model.layers)):
        attn = model.model.layers[layer_idx].self_attn
        if not hasattr(attn, '_poc_original_forward'):
            attn._poc_original_forward = attn.forward
        attn.forward = types.MethodType(patched_forward, attn)

    return model

def unpatch_model(model):
    """Restore the attention forwards that were active before patch_model ran.

    Layers that were never patched are left untouched.

    Returns:
        The same ``model`` object.
    """
    for layer in model.model.layers:
        original = getattr(layer.self_attn, '_poc_original_forward', None)
        if original is not None:
            layer.self_attn.forward = original
    return model

def reset_model(model):
    """Zero the per-layer KV length tracker ahead of a new generation.

    Only attention modules that already carry ``_poc_kv_seq_len`` (i.e. patched
    ones) are touched; the attribute is never created here.
    """
    attention_modules = (layer.self_attn for layer in model.model.layers)
    for attn in attention_modules:
        if not hasattr(attn, '_poc_kv_seq_len'):
            continue
        attn._poc_kv_seq_len = 0

In [None]:
@torch.no_grad()
def generate_with_eviction(prompt, model, max_new=128, max_length=7500):
    """Greedy-decode a continuation of ``prompt`` with the (patched) model.

    Prompts longer than ``max_length`` tokens are truncated in the middle: the
    first ``max_length // 2`` tokens and the remaining tail tokens are kept.
    Right-truncation would cut off the end of the prompt, which is where the
    question and the answer cue are placed.

    Args:
        prompt: full prompt text.
        model: model to generate with; inputs are moved to ``model.device``.
        max_new: maximum number of new tokens to generate.
        max_length: maximum number of prompt tokens fed to the model.

    Returns:
        The decoded continuation (prompt tokens and special tokens removed).
    """
    torch.cuda.empty_cache()
    
    # Reset KV sequence tracking before each generation
    reset_model(model)
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    n_tokens = input_ids.shape[1]
    if n_tokens > max_length:
        head = max_length // 2
        tail = max_length - head
        # Explicit start index for the tail avoids the `-0:` slice pitfall.
        input_ids = torch.cat([input_ids[:, :head], input_ids[:, n_tokens - tail:]], dim=1)
    input_ids = input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)
    seq_len = input_ids.shape[1]
    
    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new, 
        do_sample=False, 
        pad_token_id=tokenizer.eos_token_id, 
        use_cache=True
    )
    result = tokenizer.decode(out[0][seq_len:], skip_special_tokens=True)
    torch.cuda.empty_cache()
    return result

def normalize(s):
    """Canonicalise an answer string for token-level comparison.

    Steps, in order: lowercase, delete ASCII punctuation, replace the articles
    a/an/the with a space, collapse runs of whitespace.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punc = ''.join(ch for ch in lowered if ch not in punctuation)
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punc)
    return ' '.join(without_articles.split())

def f1_score_tokens(prediction_tokens, ground_truth_tokens):
    """Token-level F1 between two token lists (multiset overlap).

    Returns 0 when the lists share no token, otherwise the harmonic mean of
    precision and recall.
    """
    overlap = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    matched = sum(overlap.values())
    if not matched:
        return 0
    precision = 1.0 * matched / len(prediction_tokens)
    recall = 1.0 * matched / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)

def qa_f1(pred, gold):
    """Token F1 between a prediction and one reference answer after normalisation.

    Returns 0 if either side has no tokens left once normalised.
    """
    pred_tokens = normalize(pred).split()
    gold_tokens = normalize(gold).split()
    if pred_tokens and gold_tokens:
        return f1_score_tokens(pred_tokens, gold_tokens)
    return 0

def eval_task(task, atom_fn, bond_fn, combine_fn, n=SAMPLES):
    """Score one eviction configuration on the first ``n`` items of a LongBench task.

    ``base_model`` is patched with the given scoring functions (it stays patched
    afterwards), each item is prompted with the task-specific template, and the
    prediction is scored with the best token F1 over the item's answers.

    Args:
        task: LongBench subset name; 'qasper' and 'narrativeqa' have dedicated
            templates, anything else uses the passage-QA template.
        atom_fn, bond_fn, combine_fn: forwarded to ``patch_model``.
        n: maximum number of items to evaluate.

    Returns:
        Mean F1 in percent; 0.0 when no item was evaluated.
    """
    model = patch_model(base_model, atom_fn, bond_fn, combine_fn, BUDGET, WINDOW, SINK)
    
    task2maxlen = {'qasper': 128, 'narrativeqa': 128, 'hotpotqa': 32}
    max_new = task2maxlen.get(task, 64)
    
    ds = load_dataset('THUDM/LongBench', task, split='test', trust_remote_code=True)
    scores = []
    for i, item in enumerate(tqdm(ds, total=min(n, len(ds)))):
        if i >= n: break
        if task == 'qasper':
            prompt = f"You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {item['context']}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {item['input']}\n\nAnswer:"
        elif task == 'narrativeqa':
            prompt = f"You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {item['context']}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {item['input']}\n\nAnswer:"
        else:
            prompt = f"Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{item['context']}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {item['input']}\nAnswer:"
        
        pred = generate_with_eviction(prompt, model, max_new)
        # default=0 keeps an item without reference answers from raising.
        score = max((qa_f1(pred, ans) for ans in item['answers']), default=0)
        scores.append(score * 100)
        
        if i == 0:
            print(f'\n[Sample 0] Pred: {pred[:200]}')
            print(f'  Gold: {item["answers"][:2]}')
            print(f'  F1: {score:.2f}')
    
    # Guard against n == 0 or an empty split.
    if not scores:
        return 0.0
    return sum(scores)/len(scores)

# ============= BASELINE TEST (NO PATCHING) =============
print("=" * 50)
print("BASELINE TEST: No patching (original model)")
print("=" * 50)

@torch.no_grad()
def baseline_test(max_length=7500):
    """Test the base model WITHOUT patching to verify it works.

    Runs the first qasper item through ``base_model`` and prints the prediction,
    the references and the F1. This is only a true baseline if it runs before
    any call to patch_model (patching modifies ``base_model`` in place).

    Prompts longer than ``max_length`` tokens are truncated in the middle (head
    and tail kept) so the question at the end of the prompt is never cut off.

    Returns:
        Best token F1 (0..1) over the item's reference answers.
    """
    ds = load_dataset('THUDM/LongBench', 'qasper', split='test', trust_remote_code=True)
    item = ds[0]
    prompt = f"You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {item['context']}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {item['input']}\n\nAnswer:"
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    n_tokens = input_ids.shape[1]
    if n_tokens > max_length:
        head = max_length // 2
        tail = max_length - head
        input_ids = torch.cat([input_ids[:, :head], input_ids[:, n_tokens - tail:]], dim=1)
    input_ids = input_ids.to(base_model.device)
    attention_mask = torch.ones_like(input_ids)
    seq_len = input_ids.shape[1]
    print(f"Input length: {seq_len} tokens")
    
    out = base_model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id, use_cache=True)
    result = tokenizer.decode(out[0][seq_len:], skip_special_tokens=True)
    
    print(f"Prediction: {result[:300]}")
    print(f"Gold: {item['answers'][:2]}")
    score = max(qa_f1(result, ans) for ans in item['answers'])
    print(f"F1 Score: {score:.2%}")
    return score

baseline_score = baseline_test()
print(f"\nBaseline F1: {baseline_score:.2%}")
print("=" * 50)

In [None]:
# Experiment 1: compare atom scorers on qasper with the bond signal disabled
# and the 80/20 static combination.
exp1 = {}
for name, fn in ATOMS.items():
    exp1[name] = eval_task(
        'qasper',
        fn,
        BONDS['none'],
        COMBINES['static80'],
    )
    print(f'{name}: {exp1[name]:.2f}')

In [None]:
# Experiment 2: keep the best atom scorer from exp1 and compare bond scorers
# using the 'max' combination.
best_atom = max(exp1, key=exp1.get)
exp2 = {}
for name, fn in BONDS.items():
    exp2[name] = eval_task(
        'qasper',
        ATOMS[best_atom],
        fn,
        COMBINES['max'],
    )
    print(f'{name}: {exp2[name]:.2f}')

In [None]:
# Experiment 3: keep the best atom and bond scorers and compare combination rules.
best_bond = max(exp2, key=exp2.get)
exp3 = {}
for name, fn in COMBINES.items():
    exp3[name] = eval_task(
        'qasper',
        ATOMS[best_atom],
        BONDS[best_bond],
        fn,
    )
    print(f'{name}: {exp3[name]:.2f}')

In [None]:
# Experiment 4: evaluate the best overall configuration on three tasks,
# 30 samples each.
best_combine = max(exp3, key=exp3.get)
exp4 = {}
for task in ['qasper', 'narrativeqa', 'hotpotqa']:
    exp4[task] = eval_task(
        task,
        ATOMS[best_atom],
        BONDS[best_bond],
        COMBINES[best_combine],
        30,
    )
    print(f'{task}: {exp4[task]:.2f}')

In [None]:
# Summarise all experiments and persist them as JSON.
print('\n=== RESULTS ===')
print('Exp1 atoms:', exp1)
print('Exp2 bonds:', exp2)
print('Exp3 combine:', exp3)
print('Exp4 tasks:', exp4)
print(f'\nBest: atom={best_atom}, bond={best_bond}, combine={best_combine}')
print(f'Avg: {sum(exp4.values())/len(exp4):.2f}')

results = {
    'exp1': exp1,
    'exp2': exp2,
    'exp3': exp3,
    'exp4': exp4,
    'best': {'atom': best_atom, 'bond': best_bond, 'combine': best_combine},
}
# The context manager guarantees the file is flushed and closed.
with open('poc_results.json', 'w') as fh:
    json.dump(results, fh, indent=2)