In [1]:
# log level to debug
import logging
logging.basicConfig(level=logging.DEBUG)


In [2]:
import requests

import torch
from transformers import pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

  from .autonotebook import tqdm as notebook_tqdm
INFO:datasets:PyTorch version 2.8.0 available.


In [3]:
# Fetch the AIME 2025 problems from the Hugging Face Hub; the result is indexed by split
# name (the cells below read the "test" split).
from datasets import load_dataset

dataset = load_dataset(path="math-ai/aime25")

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/datasets/math-ai/aime25 HTTP/1.1" 200 860
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): s3.amazonaws.com:443
DEBUG:urllib3.connectionpool:https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/math-ai/aime25/math-ai/aime25.py HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/datasets/math-ai/aime25 HTTP/1.1" 200 860
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/datasets/math-ai/aime25/revision/7510e41db03507cb7b25add80804e23c44f4d472 HTTP/1.1" 200 860
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/datasets/math-ai/aime25/tree/7510e41db03507cb7b25add80804e23c44f4d472?recursive=False&expand=False HTTP/1.1" 200 295
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): datasets-server.huggingface.co:443
DEBUG:urllib3.connec

In [4]:
from kvpress.presses.generation.decoding_press import DecodingPress

In [5]:
# Which AIME problem (row of the "test" split) the generation cells below run on.
SAMPLE_INDEX = 0
sample = dataset["test"][SAMPLE_INDEX]

In [6]:
# Display the selected problem record (problem statement, reference answer and id).
sample

{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',
 'answer': '70',
 'id': '0'}

In [7]:
# Model / pipeline configuration.
ckpt = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Prefer the first GPU, but fall back to CPU so this cell does not crash on machines
# without CUDA (generation will be much slower there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
attn_implementation = "sdpa"  # use "eager" for ObservedAttentionPress and "sdpa" if you can't use "flash_attention_2"

# "kv-press-text-generation" is the kvpress generation pipeline; it accepts the
# press= / cache= / question= arguments used in the generation cells.
# torch_dtype="auto" lets transformers pick the dtype recorded in the checkpoint.
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/config.json HTTP/1.1" 307 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /api/resolve-cache/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/ad9f0ae0864d7fbcd1cd905e3c6c5b069cc8b562/config.json HTTP/1.1" 200 0


DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/generation_config.json HTTP/1.1" 307 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /api/resolve-cache/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/ad9f0ae0864d7fbcd1cd905e3c6c5b069cc8b562/generation_config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/custom_generate/generate.py HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/tree/main/additional_chat_templates?recursive=False&expand=False HTTP/1.1" 404 64
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/revision/main HTTP/1.1" 200 7786
Fetching 0 files: 0it [00:00, ?it/s]
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/models/deepseek-ai/Deep

In [8]:
from transformers import DynamicCache

In [None]:
%%time
cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=None, cache=cache, max_new_tokens=1000)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

DEBUG:kvpress.pipeline:Context Length: 3
DEBUG:kvpress.pipeline:Compressed Context Length: 3


Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \( b > 9 \) for which \( 17_b \) is a divisor of \( 97_b \). Hmm, okay. Let me try to break this down step by step.

First, I remember that when a number is written in base \( b \), each digit represents a power of \( b \). So, for example, \( 17_b \) would
CPU times: user 1.6 s, sys: 110 ms, total: 1.71 s
Wall time: 1.62 s


In [None]:
%%time
compression_interval = 10  # compress every compression_steps
target_size = 10 # number of tokens to keep after compression


press = DecodingPress(base_press=KnormPress(), compression_interval=compression_interval, target_size=target_size, hidden_states_buffer_size=0)

cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=press, cache=cache, max_new_tokens=100)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

DEBUG:kvpress.pipeline:Context Length: 3
DEBUG:kvpress.pipeline:Compressed Context Length: 3
DEBUG:kvpress.presses.generation.decoding_press:Applying decoding compression: layer_step_count=1 >= compression_steps=10
DEBUG:kvpress.presses.generation.decoding_press:Compressing 34 to 10 with ratio 0.7058823529411764
DEBUG:kvpress.presses.generation.decoding_press:Applied decoding compression: keys.shape: torch.Size([1, 2, 10, 128]), values.shape: torch.Size([1, 2, 10, 128])
DEBUG:kvpress.presses.generation.decoding_press:Applying decoding compression: layer_step_count=1 >= compression_steps=10
DEBUG:kvpress.presses.generation.decoding_press:Compressing 34 to 10 with ratio 0.7058823529411764
DEBUG:kvpress.presses.generation.decoding_press:Applied decoding compression: keys.shape: torch.Size([1, 2, 10, 128]), values.shape: torch.Size([1, 2, 10, 128])
DEBUG:kvpress.presses.generation.decoding_press:Applying decoding compression: layer_step_count=1 >= compression_steps=10
DEBUG:kvpress.presses

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: Okay, so I have this problem here: Find the base \( b \) such that when you write down the number \( 11111111111111111111111111111111111111111111111111111111111111111111111111
CPU times: user 2.26 s, sys: 116 ms, total: 2.37 s
Wall time: 2.3 s


Cache size: torch.Size([1, 2, 19, 128])
