# Using decoding compression on the AIME25 Math Dataset

This notebook demonstrates how to compress the key-value (KV) cache during text generation, i.e. while decoding.
We use `nvidia/OpenMath-Nemotron-7B` to solve math problems from the AIME25 dataset. For each problem, the model generates an answer in a boxed format (e.g., `\boxed{42}`).
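
To compare a generated solution with the reference answer, the boxed value has to be pulled out of the generated text. The helper below is a minimal sketch of one way to do that; `extract_boxed_answer` is a hypothetical name introduced here, not part of `kvpress` or `transformers`, and it assumes the boxed content has no nested braces (true for integer AIME answers).

```python
import re
from typing import Optional


def extract_boxed_answer(text: str) -> Optional[str]:
    r"""Return the content of the last \boxed{...} in `text`, or None if there is none."""
    # [^{}]* matches only brace-free content, so nested LaTeX such as
    # \boxed{\frac{1}{2}} is deliberately not handled by this sketch.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None


print(extract_boxed_answer(r"Adding the two bases gives \boxed{70}."))  # prints 70
```

Taking the last match rather than the first is a design choice: reasoning models sometimes box intermediate values before committing to a final answer.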

To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.
Compression periodically shrinks the cache by keeping only the tokens scored as most relevant, which limits memory use while aiming to preserve answer quality.
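
The effect of periodic compression on cache size can be illustrated with a toy simulation. This is a sketch under stated assumptions, not the `kvpress` implementation: `simulate_decoding` is a hypothetical helper, random numbers stand in for the relevance scores a real press would compute, and the prompt tokens are ignored.

```python
import random
from typing import List


def simulate_decoding(n_steps: int, compression_interval: int, target_size: int, seed: int = 0) -> List[int]:
    """Return the toy cache size after each decoding step."""
    rng = random.Random(seed)
    cache = []  # one (position, score) pair per cached token
    sizes = []
    for step in range(1, n_steps + 1):
        # Each decoding step appends one token; the score is a random stand-in.
        cache.append((step, rng.random()))
        if step % compression_interval == 0 and len(cache) > target_size:
            # Keep the target_size highest-scoring tokens, then restore positional order.
            kept = sorted(cache, key=lambda token: token[1], reverse=True)[:target_size]
            cache = sorted(kept)
        sizes.append(len(cache))
    return sizes


sizes = simulate_decoding(n_steps=4096, compression_interval=1024, target_size=512)
print(f"size after last step: {sizes[-1]}, largest recorded size: {max(sizes)}")
```

In this toy model the cache size follows a sawtooth: it grows by one token per step and drops back to `target_size` at every compression, instead of growing linearly with the number of generated tokens.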

In [1]:
import requests

import torch
from transformers import pipeline

from kvpress import (
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    TOVAPress,
)

from kvpress.presses.decoding_press import DecodingPress

In [2]:
from datasets import load_dataset
dataset = load_dataset("math-ai/aime25")

In [3]:
sample = dataset["test"][0]

In [4]:
sample

{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',
 'answer': '70',
 'id': '0'}

In [5]:
device = "cuda:0"
ckpt = "nvidia/OpenMath-Nemotron-7B"
attn_implementation = "flash_attention_2"
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)

Device set to use cuda:0


In [6]:
from transformers import DynamicCache

In [7]:
cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=None, cache=cache, max_new_tokens=16_000)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: <think>
Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.

First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:

17_b = b + 7

97_b = 9b + 7

The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). 

Let me write that as an equation:

(9b + 7) ÷ (b + 7) = integer.

To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's 

In [11]:
compression_interval = 1024  # compress the cache every `compression_interval` decoding steps
target_size = 512  # number of tokens kept after each compression; the cache then grows again until the next compression


press = DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=compression_interval, target_size=target_size)

cache = DynamicCache()
question = sample["problem"]
true_answer = sample["answer"]
pred_answer = pipe(" ", question=question, press=press, cache=cache, max_new_tokens=16_000)["answer"]

print(f"Question:   {question}")
print(f"Answer:     {true_answer}")
print(f"Prediction: {pred_answer}")
print(f"Cache size: {cache.layers[0].keys.shape}")

Question:   Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$
Answer:     70
Prediction: <think>
Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.

First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:

17_b = b + 7

97_b = 9b + 7

The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). 

Let me write that as an equation:

(9b + 7) ÷ (b + 7) = integer.

To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's 
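
The predictions shown above stop before a final answer is reached. As an independent sanity check on the reference answer, the divisibility condition from the problem can be brute-forced directly. This is a small sketch, not part of the original notebook; `valid_bases` is a hypothetical helper. Because 9b + 7 = 9(b + 7) - 56, the divisor b + 7 must divide 56, so no base above 49 can qualify and a search limit of 1000 is more than enough.

```python
from typing import List


def valid_bases(limit: int = 1000) -> List[int]:
    """Return all bases 9 < b <= limit for which 17_b divides 97_b."""
    # In base 10, 17_b = b + 7 and 97_b = 9b + 7.
    return [b for b in range(10, limit + 1) if (9 * b + 7) % (b + 7) == 0]


bases = valid_bases()
print(bases, sum(bases))  # prints [21, 49] 70
```

The sum matches the `answer` field of the sample, `'70'`.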