In [6]:
import requests

import torch
from datasets import load_dataset
from transformers import DynamicCache, pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

In [7]:
# Running example for the notebook: one problem taken from the AIME 2025
# dataset on the Hugging Face Hub. `load_dataset` is imported in the
# notebook's top import cell.
DATASET_ID = "math-ai/aime25"
SAMPLE_INDEX = 0  # position in the "test" split of the problem to solve

dataset = load_dataset(DATASET_ID)
sample = dataset["test"][SAMPLE_INDEX]


In [8]:
# Show the selected record so its fields can be checked before prompting the model.
sample 

{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',
 'answer': '70',
 'id': '0'}

In [9]:
# Runtime settings: target GPU, model checkpoint and attention backend.
device = "cuda:0"
ckpt = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Use "eager" for ObservedAttentionPress and "sdpa" if you can't use "flash_attention_2".
attn_implementation = "flash_attention_2"

# Build the kvpress generation pipeline with the model loaded in bfloat16.
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    model_kwargs=dict(
        attn_implementation=attn_implementation,
        dtype=torch.bfloat16,
    ),
)

Fetching 0 files: 0it [00:00, ?it/s]
Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 18641.35it/s]
Fetching 0 files: 0it [00:00, ?it/s]
Device set to use cuda:0


In [None]:
%%time
cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=None, cache=cache, max_new_tokens=2048)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \( b > 9 \) for which \( 17_b \) is a divisor of \( 97_b \). Hmm, okay. Let me try to break this down step by step.

First, I know that when a number is written in base \( b \), each digit represents a power of \( b \). So, for example, \( 17_b \) would be equal to \( 1 \times b + 7 \times 1 \) in decimal, right? Similarly, \( 97_b \) would be \( 9 \times b + 7 \times 1 \) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.

Let me write that out:

\( 17_b = 1 \times b + 7 = b + 7 \)

\( 97_b = 9 \times b + 7 = 9b + 7 \)

So, the problem is asking for bases \( b > 9 \) such that \( b + 7 \) divides \( 9b + 7 \). In other words, \( b + 7 \) is a divisor of \( 9b + 7 \).

I remember that if \( a \) divides \( b \), then \( b = k \times a \) f

In [12]:
from kvpress.presses.generation.decoding_press import DecodingPress

In [13]:
%%time
compression_interval = 500  # compress every compression_steps
target_size = 500 # number of tokens to keep after compression


press = DecodingPress(base_press=KnormPress(), compression_interval=compression_interval, target_size=target_size, hidden_states_buffer_size=0)

cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=press, cache=cache, max_new_tokens=1000)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \( b > 9 \) for which \( 17_b \) is a divisor of \( 97_b \). Hmm, okay. Let me try to break this down step by step.

First, I know that when a number is written in base \( b \), each digit represents a power of \( b \). So, for example, \( 17_b \) would be equal to \( 1 \times b + 7 \times 1 \) in decimal, right? Similarly, \( 97_b \) would be \( 9 \times b + 7 \times 1 \) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.

Let me write that out:

\( 17_b = 1 \times b + 7 = b + 7 \)

\( 97_b = 9 \times b + 7 = 9b + 7 \)

So, the problem is asking for bases \( b > 9 \) such that \( b + 7 \) divides \( 9b + 7 \). In other words, \( b + 7 \) is a divisor of \( 9b + 7 \).

I remember that if \( a \) divides \( b \), then \( b = k \times a \) f