In [1]:
import requests

import torch
from transformers import pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

In [2]:
from datasets import load_dataset
# Load the "math-ai/aime25" dataset; later cells index into its "test" split.
dataset = load_dataset(path="math-ai/aime25")

In [3]:
from kvpress.presses.generation.decoding_press import DecodingPress


In [4]:
# Use the first entry of the "test" split as the running example.
test_split = dataset["test"]
sample = test_split[0]

In [5]:
# Show the selected record so its fields ('problem', 'answer', 'id') are visible.
sample

{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',
 'answer': '70',
 'id': '0'}

In [6]:
# Model and runtime configuration for the kvpress text-generation pipeline.
# Fall back to CPU when CUDA is unavailable so this cell does not fail on
# machines without a GPU (generation will be much slower there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
ckpt = "nvidia/OpenMath-Nemotron-7B"
# Attention backend: use "eager" for ObservedAttentionPress; otherwise prefer
# "flash_attention_2" and use "sdpa" if flash attention is not available.
attn_implementation = "sdpa"
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [7]:
from transformers import DynamicCache

In [8]:
%%time
cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=None, cache=cache, max_new_tokens=2048)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.key_cache[0].shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: <think>
Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.

First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:

17_b = b + 7

97_b = 9b + 7

The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). 

Let me write that as an equation:

(9b + 7) ÷ (b + 7) = integer.

To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's d

In [9]:
%%time
compression_steps = 48  # compress every compression_steps
token_buffer_size = 512 # number of tokens to keep after compression


press = DecodingPress(base_press=ExpectedAttentionPress(), compression_steps=compression_steps, token_buffer_size=token_buffer_size)

cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=press, cache=cache, max_new_tokens=2048)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.key_cache[0].shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: <think>
Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.

First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:

17_b = b + 7

97_b = 9b + 7

The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). 

Let me write that as an equation:

(9b + 7) ÷ (b + 7) = integer.

To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's d