In [None]:
import torch
from torch import nn
from transformers import pipeline

from kvpress import SimLayerKVPress

In [None]:
# Load the KV-press text-generation pipeline.
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without one instead of failing at model load.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ckpt = "Qwen/Qwen2.5-1.5B-Instruct"
pipe = pipeline(
    # Custom task name; presumably registered by the kvpress import - TODO confirm.
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
)




In [None]:
# Shared inputs for the prefilling and decoding checks.
# The passage is written one literal per line so that the trailing space and
# newline ending each line of the prompt are visible in the source.
context = (
    "SimLayerKV is a method for efficient transformer inference that identifies and optimizes \n"
    "lazy attention layers. It works in two phases: prefilling and decoding. During prefilling, it analyzes \n"
    "the last w_last tokens to identify lazy layers. During decoding, it examines the attention patterns of \n"
    "the first generated token."
)

question = "\nWhat are the two phases of SimLayerKV?"

# Encode the context only (the question is not included) and move the
# encoded batch onto `device`.
tokens = pipe.tokenizer(
    context,
    return_tensors="pt",
).to(device)

In [None]:

# Prefilling check: run one forward pass over the context without and with the
# press, then print the shape of the first cache entry from each run.
press = SimLayerKVPress(
    initial_tokens=4,
    recent_tokens=1024,
    w_last=32,
    window_size=32,
    # 0.85 is the compression ratio / threshold used for Qwen models in the
    # original implementation.
    compression_ratio=0.85,
)

print("Testing Prefilling Phase:")
print("-" * 50)

# Baseline forward pass. Only `past_key_values` is inspected afterwards, so the
# per-layer hidden states are not requested.
with torch.no_grad():
    outputs_without_press = pipe.model(**tokens)

# Same forward pass, with the press active on the model for the duration of
# the `with` block.
with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens)

# `[0][0]` takes the first element of the first cache entry - presumably the
# key tensor of layer 0; TODO confirm against the cache type in use.
print(f"Original cache shape: {outputs_without_press.past_key_values[0][0].shape}")
print(f"Compressed cache shape: {output_with_press.past_key_values[0][0].shape}")



# Decoding check: have the pipeline answer the question about the context
# while the press is supplied to it.
print("\nTesting Decoding Phase:", "-" * 50, sep="\n")

# The press is handed to the pipeline as an argument rather than applied
# through a `with` block; generation is capped at 150 new tokens.
output = pipe(context, question=question, press=press, max_new_tokens=150)

print("Generated Answer:", output["answer"], sep="\n")
