From 0a3134098b4d1a2486cbd7bfd180aa64f94a3961 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Tue, 7 Oct 2025 21:15:41 +0200 Subject: [PATCH 01/20] Add decoding press functionality This commit adds comprehensive decoding press functionality including: - New DecodingPress and PrefillDecodingPress classes - Updates to pipeline for decoding support - Modified existing presses to support decoding compression - New test files and documentation - Enhanced base press with decoding capabilities Signed-off-by: Max Jeblick Signed-off-by: Max Jeblick --- README.md | 130 ++++++- kvpress/__init__.py | 142 ++++---- kvpress/pipeline.py | 78 ++-- kvpress/presses/base_press.py | 16 +- kvpress/presses/decoding_press.py | 224 ++++++++++++ kvpress/presses/duo_attention_press.py | 2 +- kvpress/presses/finch_press.py | 2 +- kvpress/presses/key_rerotation_press.py | 2 +- kvpress/presses/kvzip_press.py | 15 +- kvpress/presses/prefill_decoding_press.py | 86 +++++ kvpress/presses/pyramidkv_press.py | 2 +- kvpress/presses/scorer_press.py | 2 +- kvpress/presses/simlayerkv_press.py | 2 +- kvpress/presses/snapkv_press.py | 11 +- kvpress/presses/streaming_llm_press.py | 2 +- kvpress/presses/utils.py | 20 ++ notebooks/kvpress_decoding_aime25.ipynb | 417 ++++++++++++++++++++++ pyproject.toml | 2 +- tests/fixtures.py | 4 +- tests/test_decoding_compression.py | 305 ++++++++++++++++ tests/test_pipeline.py | 2 + 21 files changed, 1321 insertions(+), 145 deletions(-) create mode 100644 kvpress/presses/decoding_press.py create mode 100644 kvpress/presses/prefill_decoding_press.py create mode 100644 notebooks/kvpress_decoding_aime25.ipynb create mode 100644 tests/test_decoding_compression.py diff --git a/README.md b/README.md index d47c63a2..b7d53e75 100644 --- a/README.md +++ b/README.md @@ -67,15 +67,7 @@ answer = pipe(context, question=question, press=press)["answer"] In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different 
questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)). -> [!IMPORTANT] -> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequence (100k - 1M tokens) which are essentially long context prompts. This would typically apply to improving prompt caching systems. - -> [!NOTE] -> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}` as this press requires to materialize the attention weights - -## Contributing -We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. ## Available presses @@ -131,6 +123,48 @@ Below we report the average performance on the RULER dataset with 4k context len Leaderboard

+
+
+## Decoding Compression
+
+By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
+
+- `base_press`: Any ScorerPress (e.g., `KnormPress`, `CriticalKVPress`)
+- `compression_interval`: Steps between compressions (default: 128)
+- `target_size`: Target size of the cache after compression (default: 1024)
+- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
+
+Unlike a compression ratio, decoding press uses a `target_size` to compress the cache. This means that the cache is compressed every `compression_interval` steps, and the compression ratio is automatically computed such that the size of the cache after compression equals `target_size`. 
+
+An example for decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import KnormPress
+from kvpress import DecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling.
+
+ ## Quantization We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline: @@ -149,6 +183,10 @@ By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] > To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`). +## Contributing + +We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. + ## FAQ
@@ -224,4 +262,80 @@ with press(model): However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once. +
+ + + +
+ +### Can I combine compression during prefilling and decoding ? + + + +Combines separate presses for prefilling and decoding phases. + +**Parameters:** +- `prefilling_press`: Press used during prefill phase +- `decoding_press`: Press used during decoding phase + +## Usage Examples + +### Basic Decoding Compression + +```python +from transformers import pipeline +from kvpress import KnormPress +from kvpress import DecodingPress + +# Initialize the pipeline +device = "cuda:0" +model = "meta-llama/Llama-3.1-8B-Instruct" +model_kwargs = {"attn_implementation": "flash_attention_2"} +pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) + +# Create a decoding press that compresses every 10 steps to 512 tokens +decoding_press = DecodingPress( + base_press=KnormPress(), + compression_steps=10, + token_buffer_size=512 +) + +# Use with pipeline +context = "A very long text you want to compress during generation" +question = "Tell me a long story about this context" +response = pipe(context, question=question, press=decoding_press)["answer"] +``` + +### Combined Prefill + Decoding Compression + +```python +from transformers import pipeline +from kvpress import CriticalKVPress, KnormPress +from kvpress import DecodingPress, PrefillDecodingPress + +# Initialize the pipeline +device = "cuda:0" +model = "meta-llama/Llama-3.1-8B-Instruct" +model_kwargs = {"attn_implementation": "flash_attention_2"} +pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) + +# Different strategies for prefill vs decoding +prefill_press = CriticalKVPress(KnormPress()) +decoding_press = DecodingPress( + base_press=KnormPress(compression_ratio=0.2), + compression_steps=5, + token_buffer_size=256 +) + +# Combine them +combined_press = PrefillDecodingPress( + prefilling_press=prefill_press, + decoding_press=decoding_press +) + +context = "A very long context that will be compressed during prefill" +question = 
"Generate a detailed analysis that will be compressed during decoding" +response = pipe(context, question=question, press=combined_press)["answer"] +``` +
\ No newline at end of file diff --git a/kvpress/__init__.py b/kvpress/__init__.py index cccd1f02..19165465 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -1,68 +1,74 @@ -# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - - -from kvpress.attention_patch import patch_attention_functions -from kvpress.pipeline import KVPressTextGenerationPipeline -from kvpress.presses.adakv_press import AdaKVPress -from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.block_press import BlockPress -from kvpress.presses.chunk_press import ChunkPress -from kvpress.presses.chunkkv_press import ChunkKVPress -from kvpress.presses.composed_press import ComposedPress -from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress -from kvpress.presses.duo_attention_press import DuoAttentionPress -from kvpress.presses.expected_attention_press import ExpectedAttentionPress -from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress -from kvpress.presses.finch_press import FinchPress -from kvpress.presses.key_rerotation_press import KeyRerotationPress -from kvpress.presses.keydiff_press import KeyDiffPress -from kvpress.presses.knorm_press import KnormPress -from kvpress.presses.kvzip_press import KVzipPress -from kvpress.presses.lagkv_press import LagKVPress -from kvpress.presses.observed_attention_press import ObservedAttentionPress -from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress -from kvpress.presses.pyramidkv_press import PyramidKVPress -from kvpress.presses.qfilter_press import QFilterPress -from kvpress.presses.random_press import RandomPress -from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.simlayerkv_press import SimLayerKVPress -from kvpress.presses.snapkv_press import SnapKVPress -from kvpress.presses.streaming_llm_press import 
StreamingLLMPress -from kvpress.presses.think_press import ThinKPress -from kvpress.presses.tova_press import TOVAPress - -# Patch the attention functions to support head-wise compression -patch_attention_functions() - -__all__ = [ - "CriticalAdaKVPress", - "CriticalKVPress", - "AdaKVPress", - "BasePress", - "ComposedPress", - "ScorerPress", - "ExpectedAttentionPress", - "KnormPress", - "ObservedAttentionPress", - "RandomPress", - "SimLayerKVPress", - "SnapKVPress", - "StreamingLLMPress", - "ThinKPress", - "TOVAPress", - "KVPressTextGenerationPipeline", - "PerLayerCompressionPress", - "KeyRerotationPress", - "ChunkPress", - "DuoAttentionPress", - "ChunkKVPress", - "QFilterPress", - "PyramidKVPress", - "FinchPress", - "LagKVPress", - "BlockPress", - "KeyDiffPress", - "KVzipPress", - "ExpectedAttentionStatsPress", -] +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from kvpress.attention_patch import patch_attention_functions +from kvpress.pipeline import KVPressTextGenerationPipeline +from kvpress.presses.adakv_press import AdaKVPress +from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress +from kvpress.presses.block_press import BlockPress +from kvpress.presses.chunk_press import ChunkPress +from kvpress.presses.chunkkv_press import ChunkKVPress +from kvpress.presses.composed_press import ComposedPress +from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.duo_attention_press import DuoAttentionPress +from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress +from kvpress.presses.finch_press import FinchPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress +from kvpress.presses.keydiff_press import KeyDiffPress +from 
kvpress.presses.knorm_press import KnormPress +from kvpress.presses.kvzip_press import KVzipPress +from kvpress.presses.lagkv_press import LagKVPress +from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from kvpress.presses.pyramidkv_press import PyramidKVPress +from kvpress.presses.qfilter_press import QFilterPress +from kvpress.presses.random_press import RandomPress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.simlayerkv_press import SimLayerKVPress +from kvpress.presses.snapkv_press import SnapKVPress +from kvpress.presses.streaming_llm_press import StreamingLLMPress +from kvpress.presses.think_press import ThinKPress +from kvpress.presses.tova_press import TOVAPress + +# Patch the attention functions to support head-wise compression +patch_attention_functions() + +__all__ = [ + "CriticalAdaKVPress", + "CriticalKVPress", + "AdaKVPress", + "BasePress", + "ComposedPress", + "ScorerPress", + "ExpectedAttentionPress", + "KnormPress", + "ObservedAttentionPress", + "RandomPress", + "SimLayerKVPress", + "SnapKVPress", + "StreamingLLMPress", + "ThinKPress", + "TOVAPress", + "KVPressTextGenerationPipeline", + "PerLayerCompressionPress", + "KeyRerotationPress", + "ChunkPress", + "DuoAttentionPress", + "ChunkKVPress", + "QFilterPress", + "PyramidKVPress", + "FinchPress", + "LagKVPress", + "BlockPress", + "KeyDiffPress", + "KVzipPress", + "DecodingPress", + "PrefillDecodingPress", + "ExpectedAttentionStatsPress", + "DecodingPress", + "PrefillDecodingPress", +] diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 55eb506a..76f345d3 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -12,9 +12,11 @@ from transformers.pipelines.base import GenericTensor from kvpress.presses.base_press import BasePress +from kvpress.presses.decoding_press import 
DecodingPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress logger = logging.getLogger(__name__) @@ -189,6 +191,11 @@ def _forward( list[str] Generated answers for each input question. """ + if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1: + raise ValueError( + "DecodingPress is not compatible with multiple questions. Please specify a single question." + ) + context_ids = input_tensors["context_ids"].to(self.model.device) context_length = context_ids.shape[1] @@ -196,7 +203,9 @@ def _forward( if cache is None: cache = DynamicCache() - with press(self.model) if press is not None else contextlib.nullcontext(): + # We only perform prefill compression if the press is not a decoding or prefill decoding press + perform_prefill_compression = press is not None and not isinstance(press, DecodingPress) + with press(self.model) if perform_prefill_compression else contextlib.nullcontext(): # We run the model without the lm head for pre-filling. 
self.model.model( input_ids=context_ids, @@ -204,24 +213,46 @@ def _forward( output_attentions=self.output_attentions(press), ) - logger.debug(f"Context Length: {context_length}") - logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + logger.debug(f"Context Length: {context_length}") + logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + + # We only perform decoding compression if the press is a decoding or prefill decoding press + perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress)) + with press(self.model) if perform_decoding_compression else contextlib.nullcontext(): + # Greedy decoding for each question + answers = [] + for question_ids in input_tensors["questions_ids"]: + if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): + context_length = cache.get_seq_length() + + cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + answer = self.generate_answer( + question_ids=question_ids.to(self.model.device), + cache=cache, + context_length=context_length, + max_new_tokens=max_new_tokens, + ) + if len(input_tensors["questions_ids"]) > 1: + print(f"Removing answer from cache: {cache_seq_lengths}") + self._remove_answer_from_cache(cache, cache_seq_lengths) + + answers.append(answer) + return answers - # Greedy decoding for each question - answers = [] - for question_ids in input_tensors["questions_ids"]: - if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): - context_length = cache.get_seq_length() + def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]): - answer = self.generate_answer( - question_ids=question_ids.to(self.model.device), - cache=cache, - context_length=context_length, - max_new_tokens=max_new_tokens, - ) - answers.append(answer) + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + 
cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] + cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - return answers + if isinstance(cache, QuantizedCache): + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ + :, :, :sequence_length + ] + cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ + :, :, :sequence_length + ] def generate_answer( self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int @@ -245,7 +276,6 @@ def generate_answer( str The generated answer. """ - cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device ).unsqueeze(0) @@ -276,20 +306,6 @@ def generate_answer( if new_id.item() in should_stop_token_ids: break answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True) - # Remove the generated tokens from the cache - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] - cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - - if isinstance(cache, QuantizedCache): - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ - :, :, :sequence_length - ] - cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ - :, :, :sequence_length - ] - return answer def output_attentions(self, press: BasePress): diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 7d40fcd7..8fb4acf2 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -20,6 +20,8 @@ 
Qwen3ForCausalLM, ) +from kvpress.presses.utils import extract_keys_and_values + logger = logging.getLogger(__name__) SUPPORTED_MODELS = ( @@ -124,24 +126,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic hidden_states = kwargs["hidden_states"] cache = kwargs["past_key_values"] + cache_layer = cache.layers[module.layer_idx] q_len = hidden_states.shape[1] # Don't compress after pre-filling if kwargs["cache_position"][-1] > q_len: return output - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py new file mode 100644 index 00000000..9e5b1a6c --- /dev/null +++ b/kvpress/presses/decoding_press.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import defaultdict +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers.cache_utils import QuantizedCache + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.utils import extract_keys_and_values + +logger = logging.getLogger(__name__) + + +@dataclass +class DecodingPress(BasePress): + """ + A press that only operates during decoding phase and maintains a running buffer of hidden states. 
+ + This press accumulates hidden states during decoding and applies compression every N steps + using a scorer press to determine which tokens to keep. + + + Parameters + ---------- + base_press : ScorerPress + The scorer press used to compute importance scores for tokens. + compression_interval : int, default=10 + Number of decoding steps between compression, i.e. compression will be applied every compression_interval steps. + target_size : int, default=1024 + Target number of tokens to keep after compression. + hidden_states_buffer_size : int, default=128 + Maximum number of hidden states to keep before compression. Larger values use more GPU memory. + Note: Some presses don't need buffered hidden states and can set this to 0 to use only the + current hidden state for compression scoring. + """ + + base_press: ScorerPress + compression_interval: int = 128 + target_size: int = 1024 + hidden_states_buffer_size: int = 128 + + def __post_init__(self): + # Buffer to store hidden states during decoding (per layer) + assert isinstance(self.base_press, ScorerPress), "DecodingPress requires a ScorerPress as input" + self.hidden_states_buffer = defaultdict(list) # Per-layer buffer + self.layer_step_counts = defaultdict(int) # Track step count per layer + + # Warn if compression happens before buffer is fully utilized + # TODO: would it make sense to not reset the buffer? + if self.hidden_states_buffer_size > 0 and self.compression_interval < self.hidden_states_buffer_size: + logger.warning( + f"compression_interval ({self.compression_interval}) < hidden_states_buffer_size ({self.hidden_states_buffer_size}). " # noqa: E501 + f"Buffer will be reset before reaching full capacity, potentially reducing compression quality." 
+ ) + + assert self.compression_interval > 0, "compression_interval must be greater than 0" + assert self.target_size > 0, "target_size must be greater than 0" + + if self.base_press.compression_ratio: + logger.warning( + f"compression_ratio is set for base press ({self.base_press.compression_ratio}). " + f"This will be overridden by the decoding press." + ) + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Delegate compression to the base press during decoding phase. + + Args: + module: The transformer module being compressed + hidden_states: Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim]) + keys: Key cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + values: Value cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + attentions: Attention weights (shape varies by implementation) + kwargs: Additional keyword arguments + + Returns: + tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) tensors + + Note: + **Sequence length alignment**: During decoding compression, `hidden_states` contains the + buffered hidden states from recent decoding steps (buffer_len tokens), while `keys` and + `values` contain the full sequence history (seq_len tokens). The base press implementation + should use keys.shape[2] for full sequence length calculations. The buffered hidden_states + provide context for the most recent tokens when computing compression scores. + + Performance Note: + It would be possible to speed up compression during decoding for certain scorer presses by + storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions. 
+ """ + q_len = keys.shape[2] + target_compression_ratio = self._find_target_compression_ratio(q_len, self.target_size) + logger.debug(f"Compressing {q_len} to {self.target_size} with ratio {target_compression_ratio}") + + original_compression_ratio = self.base_press.compression_ratio + self.base_press.compression_ratio = target_compression_ratio + result = self.base_press.compress(module, hidden_states, keys, values, attentions, kwargs) + self.base_press.compression_ratio = original_compression_ratio + return result + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Forward hook that manages decoding-specific compression logic. + + This hook: + 1. Detects when we're in decoding phase (not prefilling) + 2. Accumulates hidden states in a buffer + 3. Applies compression every N steps + 4. Clears the buffer after compression + """ + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] + layer_idx = module.layer_idx + + # Only operate during decoding phase (after prefilling) + if kwargs["cache_position"][-1] <= q_len: + # We're still in prefilling phase, don't do anything + return output + # print(f"Adding hidden states to buffer: {hidden_states.shape}") + # Add current hidden states to buffer for this layer + self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone()) + + # print(f"Layer step counts: {self.layer_step_counts[layer_idx]}") + self.layer_step_counts[layer_idx] += 1 + + # Apply compression if we've reached the compression step threshold + if (self.layer_step_counts[layer_idx] >= self.compression_interval) or (q_len >= self.target_size): + logger.debug( + f"Applying decoding compression: layer_step_count ({self.layer_step_counts[layer_idx]}) >= compression_steps ({self.compression_interval})" # noqa: E501 + ) + + cache_layer = cache.layers[module.layer_idx] + keys, values = extract_keys_and_values(cache, module.layer_idx) + + # 
Get attention weights from output + attentions = output[1] if len(output) > 1 and output[1] is not None else None + + # Apply compression using buffered hidden states for this layer + buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1) + keys, values = self.compress(module, buffered_hidden_states, keys, values, attentions, kwargs) + logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}") + + # Update cache with compressed keys and values + if isinstance(cache, QuantizedCache): + cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) + cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) + cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.cumulative_length = keys.shape[2] + else: + cache_layer.keys = keys + cache_layer.values = values + + # Reset step count and clear buffer for this layer + self.layer_step_counts[layer_idx] = 0 + # Always clear the buffer after compression - otherwise there's a mismatch between + # hidden states buffer and kv cache + self.hidden_states_buffer[layer_idx] = [] + + self.hidden_states_buffer[layer_idx] = self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + return output + + def reset(self): + """Reset the decoding press state.""" + self.hidden_states_buffer = defaultdict(list) + self.layer_step_counts = defaultdict(int) + + def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float: + """ + Find the compression ratio that results in exactly target_tokens after int() rounding. 
+ + Args: + q_len: Current sequence length + target_tokens: Desired number of tokens after compression + + Returns: + Compression ratio that gives exactly target_tokens + """ + if q_len <= target_tokens: + return 0.0 + + # Start with theoretical ratio + ratio = 1.0 - (target_tokens / q_len) + + # Binary search to handle int() rounding + low, high = 0.0, 1.0 + max_iterations = 20 + iteration = 0 + + while iteration < max_iterations: + n_kept = int(q_len * (1 - ratio)) + if n_kept == target_tokens: + break + elif n_kept > target_tokens: + # Need more compression + low = ratio + ratio = (ratio + high) / 2 + else: + # Need less compression + high = ratio + ratio = (low + ratio) / 2 + iteration += 1 + + final_n_kept = int(q_len * (1 - ratio)) + if final_n_kept != target_tokens: + logger.warning( + f"Binary search failed: q_len={q_len}, target={target_tokens}, got={final_n_kept}, ratio={ratio}" + ) + + return ratio diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index b88d2f03..24c5b17a 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -103,7 +103,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): raise ValueError( "Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press." 
) - q_len = hidden_states.shape[1] + q_len = keys.shape[2] if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)): diff --git a/kvpress/presses/finch_press.py b/kvpress/presses/finch_press.py index e22952c3..39a8cd48 100644 --- a/kvpress/presses/finch_press.py +++ b/kvpress/presses/finch_press.py @@ -95,7 +95,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Compute indices to keep (optionally by chunks) - q_len = hidden_states.shape[1] + q_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states if self.chunk_length is None: n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 129072f8..a264b3d4 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -137,7 +137,7 @@ def compress( scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = int(q_len * (1 - self.press.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = torch.sort(indices, dim=2).values diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index 59b565e7..0f13e68a 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -14,7 +14,7 @@ from transformers.models.llama.modeling_llama import rotate_half from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.utils import get_query_states +from kvpress.presses.utils import extract_keys_and_values, get_query_states logger = logging.getLogger(__name__) @@ -153,25 +153,16 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic 
hidden_states = kwargs["hidden_states"] cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None) - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) # Compute importance scores for KV pairs in the prefilled context, # retaining only the originally prefilled KV pairs. keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): + # Update cache with compressed keys and values cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] diff --git a/kvpress/presses/prefill_decoding_press.py b/kvpress/presses/prefill_decoding_press.py new file mode 100644 index 00000000..85b5d465 --- /dev/null +++ b/kvpress/presses/prefill_decoding_press.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +from kvpress.presses.base_press import BasePress + +from .decoding_press import DecodingPress + +logger = logging.getLogger(__name__) + + +@dataclass +class PrefillDecodingPress(BasePress): + """ + A wrapper press that combines separate prefilling and decoding compression strategies. 
+
+    This press acts as a single press interface but internally delegates to different
+    presses based on the current phase (prefilling vs decoding). During prefilling,
+    it uses the prefilling_press. During decoding, it uses the decoding_press.
+
+    Parameters
+    ----------
+    prefilling_press : BasePress, optional
+        Press to use during the prefilling phase. If None, prefill calls fall through to decoding_press.
+    decoding_press : DecodingPress, optional
+        Press to use during the decoding phase. If None, no compression is applied during decoding.
+    """
+
+    prefilling_press: Optional[BasePress] = None
+    decoding_press: Optional[DecodingPress] = None
+
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        q_len = hidden_states.shape[1]
+
+        # Determine if we're in prefilling or decoding phase
+        if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
+            return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs)
+        elif self.decoding_press is not None:
+            return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs)
+
+        # No compression applied
+        logger.warning("No compression applied during prefill or decoding phase")
+
+        return keys, values
+
+    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        """
+        Forward hook that delegates to the appropriate press based on current phase.
+ """ + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Determine if we're in prefilling or decoding phase + if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + return self.prefilling_press.forward_hook(module, input, kwargs, output) + elif self.decoding_press is not None: + return self.decoding_press.forward_hook(module, input, kwargs, output) + + # No hook applied + return output + + @contextmanager + def __call__(self, model: PreTrainedModel): + try: + with super().__call__(model): + yield + finally: + # Reset decoding press if it exists + if self.decoding_press is not None: + self.decoding_press.reset() diff --git a/kvpress/presses/pyramidkv_press.py b/kvpress/presses/pyramidkv_press.py index 14b6c03b..cc0a0106 100644 --- a/kvpress/presses/pyramidkv_press.py +++ b/kvpress/presses/pyramidkv_press.py @@ -100,7 +100,7 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = self.get_layer_budget(module, q_len) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 5c84b379..1cbcb748 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -90,7 +90,7 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 86086e7c..48af6bc4 100644 --- a/kvpress/presses/simlayerkv_press.py +++ 
b/kvpress/presses/simlayerkv_press.py @@ -94,7 +94,7 @@ def compress( self.compression_ratios = [] # Check if compression is needed - q_len = hidden_states.shape[1] + q_len = keys.shape[2] min_length = self.n_initial + self.n_recent + self.n_last if q_len <= min_length: diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 68f92af9..549f3865 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -44,7 +44,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ - bsz, q_len, _ = hidden_states.shape + bsz, _, q_len, _ = keys.shape num_heads = module.config.num_attention_heads head_dim = module.head_dim num_key_value_groups = num_heads // module.config.num_key_value_heads @@ -78,10 +78,13 @@ def score( kwargs, ) -> torch.Tensor: - bsz, num_key_value_heads, q_len, _ = keys.shape + bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads - assert q_len > self.window_size, "Query length should be greater than the window size" + q_len = hidden_states.shape[1] + assert ( + q_len > self.window_size + ), f"Query length {q_len} should be greater than the window size {self.window_size}" if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] @@ -94,7 +97,7 @@ def score( scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) # Average per group (https://github.com/FasterDecoding/SnapKV/issues/22) - scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size) scores = scores.mean(2) # Add back the observation window. Use max score to make sure the window is not pruned. 
diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index 4c0ac218..66c9ecf8 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -44,7 +44,7 @@ def score( kwargs, ) -> torch.Tensor: - q_len = hidden_states.shape[1] + q_len = keys.shape[2] assert q_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" n_pruned = q_len - int(q_len * (1 - self.compression_ratio)) scores = torch.ones_like(keys[..., 0]) diff --git a/kvpress/presses/utils.py b/kvpress/presses/utils.py index 938c15ba..3614fa5b 100644 --- a/kvpress/presses/utils.py +++ b/kvpress/presses/utils.py @@ -3,6 +3,7 @@ import torch from torch import nn +from transformers import Cache, QuantizedCache from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention @@ -50,3 +51,22 @@ def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Te query_states = module.q_norm(query_states) return query_states + + +def dequantize_layer(cache_layer) -> tuple[torch.Tensor, torch.Tensor]: + keys = cache_layer._dequantize(cache_layer._quantized_keys) + values = cache_layer._dequantize(cache_layer._quantized_values) + return keys, values + + +def extract_keys_and_values(cache: Cache, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Extracts the keys and values from a given cache layer, + handling both quantized and unquantized caches. 
+ """ + if isinstance(cache, QuantizedCache): + keys, values = dequantize_layer(cache.layers[layer_idx]) + else: + keys = cache.layers[layer_idx].keys + values = cache.layers[layer_idx].values + return keys, values diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb new file mode 100644 index 00000000..626c06da --- /dev/null +++ b/notebooks/kvpress_decoding_aime25.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "5705c4fb-d665-48e8-a4dd-a43725e5f7f4", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "import torch\n", + "from transformers import pipeline\n", + "\n", + "from kvpress import (\n", + " ExpectedAttentionPress,\n", + " KnormPress,\n", + " ObservedAttentionPress,\n", + " RandomPress,\n", + " SnapKVPress,\n", + " StreamingLLMPress,\n", + " TOVAPress,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "dataset = load_dataset(\"math-ai/aime25\")\n", + "sample = dataset[\"test\"][0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2e0398e8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',\n", + " 'answer': '70',\n", + " 'id': '0'}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dd926e5b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching 0 files: 0it [00:00, ?it/s]\n", + "Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 18641.35it/s]\n", + "Fetching 0 files: 0it [00:00, ?it/s]\n", + "Device set to use cuda:0\n" + ] + } + ], + "source": [ + 
"device = \"cuda:0\"\n", + "ckpt = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n", + "attn_implementation = \"flash_attention_2\" # use \"eager\" for ObservedAttentionPress and \"sdpa\" if you can't use \"flash_attention_2\"\n", + "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, model_kwargs={\"attn_implementation\":attn_implementation, \"dtype\": torch.bfloat16})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc932bbc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. Let me try to break this down step by step.\n", + "\n", + "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", + "\n", + "Let me write that out:\n", + "\n", + "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "\n", + "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "\n", + "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "\n", + "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). 
So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "\n", + "Let me write that equation:\n", + "\n", + "\\( 9b + 7 = k(b + 7) \\)\n", + "\n", + "Now, I can expand the right-hand side:\n", + "\n", + "\\( 9b + 7 = kb + 7k \\)\n", + "\n", + "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Factor out \\( b \\) and \\( k \\):\n", + "\n", + "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "\n", + "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "\n", + "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "\n", + "Which gives:\n", + "\n", + "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "\n", + "So, this is a linear equation in terms of \\( b \\). Let me solve for \\( b \\):\n", + "\n", + "\\( (9 - k)b = 7k - 7 \\)\n", + "\n", + "So,\n", + "\n", + "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "\n", + "Hmm, okay. Since \\( b \\) must be an integer greater than 9, the right-hand side must also be an integer. So, \\( 7k - 7 \\) must be divisible by \\( 9 - k \\). Let me write that as:\n", + "\n", + "\\( 7(k - 1) \\) is divisible by \\( 9 - k \\)\n", + "\n", + "Which can be written as:\n", + "\n", + "\\( 7(k - 1) \\) is divisible by \\( -(k - 9) \\)\n", + "\n", + "But since divisibility is unaffected by the sign, I can say that \\( 7(k - 1) \\) is divisible by \\( k - 9 \\).\n", + "\n", + "So, \\( k - 9 \\) divides \\( 7(k - 1) \\). Let me denote \\( d = k - 9 \\), so \\( d \\) divides \\( 7(k - 1) \\). But \\( k = d + 9 \\), so substituting back:\n", + "\n", + "\\( d \\) divides \\( 7(d + 9 - 1) = 7(d + 8) \\)\n", + "\n", + "So, \\( d \\) divides \\( 7(d + 8) \\). Which means \\( d \\) divides \\( 7 \\times 8 = 56 \\). Because \\( d \\) divides \\( 7d + 56 \\), so \\( d \\) must divide 56.\n", + "\n", + "Therefore, \\( d \\) is a divisor of 56. So, the possible values of \\( d \\) are the divisors of 56. 
Let me list them:\n", + "\n", + "Positive divisors: 1, 2, 4, 7, 8, 14, 28, 56\n", + "\n", + "Negative divisors: -1, -2, -4, -7, -8, -14, -28, -56\n", + "\n", + "So, \\( d \\) can be any of these. But remember that \\( d = k - 9 \\), and \\( k \\) is an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\). So, \\( k \\) must be an integer.\n", + "\n", + "But let's see, since \\( b > 9 \\), and \\( b + 7 \\) is a divisor of \\( 9b + 7 \\), we can also think about the possible values of \\( k \\). Let me see.\n", + "\n", + "From the equation \\( b = \\frac{7(k - 1)}{9 - k} \\), since \\( b > 9 \\), the numerator and denominator must have the same sign. So, either both numerator and denominator are positive or both are negative.\n", + "\n", + "Case 1: Both numerator and denominator are positive.\n", + "\n", + "So, \\( 7(k - 1) > 0 \\) implies \\( k > 1 \\)\n", + "\n", + "And \\( 9 - k > 0 \\) implies \\( k < 9 \\)\n", + "\n", + "So, \\( 1 < k < 9 \\)\n", + "\n", + "Case 2: Both numerator and denominator are negative.\n", + "\n", + "So, \\( 7(k - 1) < 0 \\) implies \\( k < 1 \\)\n", + "\n", + "And \\( 9 - k < 0 \\) implies \\( k > 9 \\)\n", + "\n", + "But \\( k < 1 \\) and \\( k > 9 \\) can't happen at the same time. So, only Case 1 is possible.\n", + "\n", + "Therefore, \\( k \\) must be an integer between 2 and 8 inclusive.\n", + "\n", + "So, \\( k \\in \\{2, 3, 4, 5, 6, 7, 8\\} \\)\n", + "\n", + "Now, let's compute \\( d = k - 9 \\) for each \\( k \\):\n", + "\n", + "For \\( k = 2 \\): \\( d = -7 \\)\n", + "\n", + "For \\( k = 3 \\): \\( d = -6 \\)\n", + "\n", + "For \\( k = 4 \\): \\( d = -5 \\)\n", + "\n", + "For \\( k = 5 \\): \\( d = -4 \\)\n", + "\n", + "For \\( k = 6 \\): \\( d = -3 \\)\n", + "\n", + "For \\( k = 7 \\): \\( d = -2 \\)\n", + "\n", + "For \\( k = 8 \\): \\( d = -1 \\)\n", + "\n", + "So, \\( d \\) can be -7, -6, -5, -4, -3, -2, -1.\n", + "\n", + "But earlier, we said that \\( d \\) must divide 56. 
Let's check if each of these \\( d \\) values divides 56.\n", + "\n", + "- \\( d = -7 \\): 56 ÷ (-7) = -8, which is integer. So, yes.\n", + "\n", + "- \\( d = -6 \\): 56 ÷ (-6) ≈ -9.333... Not integer. So, no.\n", + "\n", + "- \\( d = -5 \\): 56 ÷ (-5) = -11.2. Not integer. No.\n", + "\n", + "- \\( d = -4 \\): 56 ÷ (-4) = -14. Integer. Yes.\n", + "\n", + "- \\( d = -3 \\): 56 ÷ (-3) ≈ -18.666... Not integer. No.\n", + "\n", + "- \\( d = -2 \\): 56 ÷ (-2) = -28. Integer. Yes.\n", + "\n", + "- \\( d = -1 \\): 56 ÷ (-1) = -56. Integer. Yes.\n", + "\n", + "So, the valid \\( d \\) values are -7, -4, -2, -1.\n", + "\n", + "Therefore, for each of these \\( d \\), we can find \\( k \\):\n", + "\n", + "- \\( d = -7 \\): \\( k = d + 9 = 2 \\)\n", + "\n", + "- \\( d = -4 \\): \\( k = 5 \\)\n", + "\n", + "- \\( d = -2 \\): \\( k = 7 \\)\n", + "\n", + "- \\( d = -1 \\): \\( k = 8 \\)\n", + "\n", + "So, now, let's compute \\( b \\) for each \\( k \\):\n", + "\n", + "Recall \\( b = \\frac{7(k - 1)}{9 - k} \\)\n", + "\n", + "Let's compute for each \\( k \\):\n", + "\n", + "1. \\( k = 2 \\):\n", + "\n", + "\\( b = \\frac{7(2 - 1)}{9 - 2} = \\frac{7(1)}{7} = 1 \\)\n", + "\n", + "But \\( b > 9 \\), so 1 is invalid.\n", + "\n", + "2. \\( k = 5 \\):\n", + "\n", + "\\( b = \\frac{7(5 - 1)}{9 - 5} = \\frac{7(4)}{4} = 7 \\)\n", + "\n", + "Again, \\( b = 7 \\) is less than 9, so invalid.\n", + "\n", + "3. 
\\( k = 7 \\):\n", + "\n", + "\\( b = \\frac{7(7 - 1)}{9 - 7} = \\frac{7(6)}{\n", + "Cache size: torch.Size([1, 2, 2081, 128])\n" + ] + } + ], + "source": [ + "%%time\n", + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=2048)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ff177e00-cabf-4680-9cfc-b6cb00770527", + "metadata": {}, + "outputs": [], + "source": [ + "from kvpress.presses.decoding_press import DecodingPress" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7f4fa366-7e62-443a-b5c8-274128fe6237", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. Let me try to break this down step by step.\n", + "\n", + "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", + "\n", + "Let me write that out:\n", + "\n", + "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "\n", + "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "\n", + "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). 
In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "\n", + "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "\n", + "Let me write that equation:\n", + "\n", + "\\( 9b + 7 = k(b + 7) \\)\n", + "\n", + "Now, I can expand the right-hand side:\n", + "\n", + "\\( 9b + 7 = kb + 7k \\)\n", + "\n", + "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Factor out \\( b \\) and \\( k \\):\n", + "\n", + "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "\n", + "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "\n", + "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "\n", + "Expanding the \\( k(b + 7) \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Now, let's collect like terms:\n", + "\n", + "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "\n", + "So, this is a linear equation in terms of \\( b \\). Let me write it as:\n", + "\n", + "\\( (9 - k)b = 7k - 7 \\)\n", + "\n", + "Then, solving for \\( b \\):\n", + "\n", + "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "\n", + "Hmm, okay. So, \\( b \\) must be an integer greater than 9. So, \\( b \\) is an integer, and \\( k \\) is also an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\).\n", + "\n", + "So, \\( b \\) is expressed in terms of \\( k \\). Let me see if I can find integer values of \\( k \\) such that \\( b \\) is an integer greater than 9.\n", + "\n", + "First, let's note that \\( 9 - k \\) cannot be zero because that would make the denominator zero, which is undefined. 
So, \\( k \\neq 9 \\).\n", + "\n", + "Also, since \\( b > 9 \\), let's see what constraints that imposes on \\( k \\).\n", + "\n", + "Looking at the expression \\( b = \\frac{7k - 7}{9 - k} \\), let's see when this is positive and greater than 9.\n", + "\n", + "First, let's note that \\( 7k - 7 \\) and \\( 9 - k \\) must have the same sign because \\( b \\) is positive.\n", + "\n", + "So, either both numerator and denominator are positive or both are negative.\n", + "\n", + "Case 1: Both numerator and denominator are positive.\n", + "\n", + "So, \\( 7k - 7 > 0 \\) implies \\( k > 1 \\).\n", + "\n", + "And \\( 9 - k > 0 \\) implies \\( k < 9 \\).\n", + "\n", + "So, \\( k \\) must satisfy \\( 1 < k < 9 \\). Since \\( k \\) is an integer, \\( k \\) can be 2, 3, 4, 5, 6, 7, or 8.\n", + "\n", + "Case 2: Both numerator and\n", + "Cache size: torch.Size([1, 2, 500, 128])\n", + "CPU times: user 36.7 s, sys: 26.8 ms, total: 36.8 s\n", + "Wall time: 36.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "compression_interval = 500 # compress every compression_steps\n", + "target_size = 500 # number of tokens to keep after compression\n", + "\n", + "\n", + "press = DecodingPress(base_press=KnormPress(), compression_interval=compression_interval, target_size=target_size, hidden_states_buffer_size=0)\n", + "\n", + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=1000)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 7bc1e993..935cf50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,4 +91,4 @@ disable_error_code = ["attr-defined"] [[tool.mypy.overrides]] module = "kvpress.pipeline" -disable_error_code = ["attr-defined", "assignment", "override"] \ No newline at end of file +disable_error_code = ["attr-defined", "assignment", "override"] diff --git a/tests/fixtures.py b/tests/fixtures.py index c446db4d..ae855a0a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -51,7 +51,7 @@ def kv_press_llama3_2_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe @@ -65,6 +65,6 @@ def kv_press_llama3_1_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py new file mode 100644 index 00000000..90acb8e0 --- /dev/null +++ b/tests/test_decoding_compression.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Test script to verify that DecodingPress actually compresses during decoding. 
+""" +import logging + +import pytest +import torch +from transformers import DynamicCache, pipeline + +from kvpress import PyramidKVPress, ScorerPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.knorm_press import KnormPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from tests.default_presses import default_presses + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("token_buffer_size", [32, 64, 128]) +def test_decoding_compression(token_buffer_size): + """Test that DecodingPress compresses the cache during decoding.""" + + # Initialize pipeline with a small model + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create a DecodingPress with KnormPress + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens + compression_interval=4, # Compress every 4 tokens + target_size=token_buffer_size, + ) + + # Create cache + cache = DynamicCache() + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 10 # Repeat for longer context + question = "What animal jumps over the dog?" 
+ + # Run pipeline + pipe(context, question=question, press=press, cache=cache, max_new_tokens=20) + + # Assert that all layers have the expected cache size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + max_expected_size = token_buffer_size + press.compression_interval - 1 + assert token_buffer_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_calls_both_phases(): + """Test that PrefillDecodingPress calls both prefilling and decoding presses.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with both presses + combined_press = PrefillDecodingPress( + prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill + decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 12 # Longer context + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15) + + # Check that cache was compressed during both phases + # Final cache should be compressed to decoding press target size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 # token_buffer_size from decoding press + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected final cache size to be between {target_size} " + f"and {max_expected_size} (decoding target), but got {layer_seq_len}" + ) + + +def test_decoding_press_without_prefill(): + """Test that DecodingPress works correctly when used standalone (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create DecodingPress only + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25) + + # Check that cache was compressed during decoding + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 64 + compression_steps = 5 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_decoding_only(): + """Test PrefillDecodingPress with only decoding press (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 9 + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12) + + # Check that only decoding compression was applied + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 56 + compression_steps = 4 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_decoding_press_equivalence(): + """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only.""" + + # Set random seed for reproducibility + torch.manual_seed(42) + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create standalone decoding press + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52) + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 7 + question = "What animal jumps over the dog?" 
+ + # Run with standalone decoding press + cache1 = DynamicCache() + result1 = pipe(context, question=question, press=decoding_press, cache=cache1, max_new_tokens=10) + + # Run with combined press (decoding only) + cache2 = DynamicCache() + result2 = pipe(context, question=question, press=combined_press, cache=cache2, max_new_tokens=10) + + # Compare cache sizes (should be identical) + for layer_idx in range(len(cache1.layers)): + cache1_size = cache1.layers[layer_idx].keys.shape[2] + cache2_size = cache2.layers[layer_idx].keys.shape[2] + assert cache1_size == cache2_size, ( + f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != " + f"combined press cache size {cache2_size}" + ) + + # Compare generated text results (should be identical) + assert result1["answer"] == result2["answer"], ( + f"Generated answers differ:\n" + f"Standalone decoding: '{result1['answer']}'\n" + f"Combined press: '{result2['answer']}'" + ) + + +""" +E AttributeError: 'QFilterPress' object has no attribute 'q_filters' +E Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12 +> query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) +E RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12 +""" + + +@pytest.mark.parametrize("press_config", default_presses) +def test_all_presses_work_with_decoding_press(press_config): + """Test that all default presses work as base presses for DecodingPress.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Get press class and use the first (easier) configuration + press_cls = press_config["cls"] + press_kwargs = press_config["kwargs"][0] # Use easier compression settings + + base_press = press_cls(**press_kwargs) + if not isinstance(base_press, ScorerPress): + logger.info(f"Press {press_cls.__name__} is not a ScorerPress, skipping test") + return + if 
isinstance(base_press, (PyramidKVPress)): + # PyramidKVPress -> Pyramid shape, not compatible with token_buffer_size=48 + logger.info(f"Press {press_cls.__name__} is not supported, skipping test") + return + if hasattr(base_press, "__post_init_from_model__"): + base_press.__post_init_from_model__(pipe.model) + + # Create DecodingPress with this base press + decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" + + # Run pipeline + cache = DynamicCache() + try: + result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15) + + # Verify compression worked + assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}" + + # Check that cache was compressed (allow some tolerance for rounding) + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert ( + target_size <= layer_seq_len <= max_expected_size + ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 + + except Exception as e: + pytest.fail(f"DecodingPress failed with {press_cls.__name__}: {e}") + + +def test_compression_actually_reduces_memory(): + """Test that compression actually reduces memory usage compared to no compression.""" + + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + context = "The quick brown fox jumps over the lazy dog. " * 15 # Long context + question = "What animal jumps over the dog?" 
+ + # Run without compression + cache_uncompressed = DynamicCache() + result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25) + + # Run with compression + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.3), # Aggressive compression + compression_interval=3, + target_size=40, + ) + cache_compressed = DynamicCache() + result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25) + + # Calculate memory usage (approximate) + uncompressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_uncompressed.layers + ) + compressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_compressed.layers + ) + + # Compression should significantly reduce memory usage + compression_ratio = compressed_memory / uncompressed_memory + assert compression_ratio < 0.6, ( + f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} " + f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)" + ) + + # Both should still generate reasonable answers + assert len(result_uncompressed["answer"]) > 0 + assert len(result_compressed["answer"]) > 0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 030d83cb..26f28f86 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -155,7 +155,9 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] + cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) + 
compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) From b16417bb5d51127da095f55e4ce54d986bf0f4eb Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Tue, 7 Oct 2025 21:15:41 +0200 Subject: [PATCH 02/20] Add decoding press functionality This commit adds comprehensive decoding press functionality including: - New DecodingPress and PrefillDecodingPress classes - Updates to pipeline for decoding support - Modified existing presses to support decoding compression - New test files and documentation - Enhanced base press with decoding capabilities Co-authored-by: alessiodevoto Signed-off-by: Max Jeblick Signed-off-by: Max Jeblick --- README.md | 130 ++++++- kvpress/__init__.py | 142 ++++---- kvpress/pipeline.py | 78 ++-- kvpress/presses/base_press.py | 16 +- kvpress/presses/decoding_press.py | 224 ++++++++++++ kvpress/presses/duo_attention_press.py | 2 +- kvpress/presses/finch_press.py | 2 +- kvpress/presses/key_rerotation_press.py | 2 +- kvpress/presses/kvzip_press.py | 15 +- kvpress/presses/prefill_decoding_press.py | 86 +++++ kvpress/presses/pyramidkv_press.py | 2 +- kvpress/presses/scorer_press.py | 2 +- kvpress/presses/simlayerkv_press.py | 2 +- kvpress/presses/snapkv_press.py | 11 +- kvpress/presses/streaming_llm_press.py | 2 +- kvpress/presses/utils.py | 20 ++ notebooks/kvpress_decoding_aime25.ipynb | 417 ++++++++++++++++++++++ pyproject.toml | 2 +- tests/fixtures.py | 4 +- tests/test_decoding_compression.py | 305 ++++++++++++++++ tests/test_pipeline.py | 2 + 21 files changed, 1321 insertions(+), 145 deletions(-) create mode 100644 kvpress/presses/decoding_press.py create mode 100644 kvpress/presses/prefill_decoding_press.py create mode 100644 notebooks/kvpress_decoding_aime25.ipynb 
create mode 100644 tests/test_decoding_compression.py diff --git a/README.md b/README.md index d47c63a2..b7d53e75 100644 --- a/README.md +++ b/README.md @@ -67,15 +67,7 @@ answer = pipe(context, question=question, press=press)["answer"] In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)). -> [!IMPORTANT] -> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequence (100k - 1M tokens) which are essentially long context prompts. This would typically apply to improving prompt caching systems. - -> [!NOTE] -> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}` as this press requires to materialize the attention weights - -## Contributing -We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. ## Available presses @@ -131,6 +123,48 @@ Below we report the average performance on the RULER dataset with 4k context len Leaderboard

+
+
+## Decoding Compression
+
+By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
+
+- `base_press`: Any ScorerPress (e.g., `KnormPress`, `CriticalKVPress`)
+- `compression_interval`: Steps between compressions (default: 128)
+- `target_size`: Target size of the cache after compression (default: 1024)
+- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
+
+Unlike a compression ratio, decoding press uses a `target_size` to compress the cache. This means that the cache is compressed every `compression_interval` steps, and the compression ratio is automatically computed such that the size of the cache after compression equals `target_size`.
+
+An example for decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import KnormPress
+from kvpress import DecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling.
+
+
+ ## Quantization We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline: @@ -149,6 +183,10 @@ By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] > To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`). +## Contributing + +We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. + ## FAQ
@@ -224,4 +262,80 @@ with press(model): However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once. +
+ + + +
+
+### Can I combine compression during prefilling and decoding?
+
+
+
+Combines separate presses for prefilling and decoding phases.
+
+**Parameters:**
+- `prefilling_press`: Press used during prefill phase
+- `decoding_press`: Press used during decoding phase
+
+## Usage Examples
+
+### Basic Decoding Compression
+
+```python
+from transformers import pipeline
+from kvpress import KnormPress
+from kvpress import DecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+### Combined Prefill + Decoding Compression
+
+```python
+from transformers import pipeline
+from kvpress import CriticalKVPress, KnormPress
+from kvpress import DecodingPress, PrefillDecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Different strategies for prefill vs decoding
+prefill_press = CriticalKVPress(KnormPress())
+decoding_press = DecodingPress(
+    base_press=KnormPress(compression_ratio=0.2),
+    compression_interval=5,
+    target_size=256
+)
+
+# Combine them
+combined_press = PrefillDecodingPress(
+    prefilling_press=prefill_press,
+    decoding_press=decoding_press
+)
+
+context = "A very long context that will be compressed during prefill"
+question = "Generate a detailed analysis that will be compressed during decoding"
+response = pipe(context, question=question, press=combined_press)["answer"]
+```
+
\ No newline at end of file diff --git a/kvpress/__init__.py b/kvpress/__init__.py index cccd1f02..19165465 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -1,68 +1,74 @@ -# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - - -from kvpress.attention_patch import patch_attention_functions -from kvpress.pipeline import KVPressTextGenerationPipeline -from kvpress.presses.adakv_press import AdaKVPress -from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.block_press import BlockPress -from kvpress.presses.chunk_press import ChunkPress -from kvpress.presses.chunkkv_press import ChunkKVPress -from kvpress.presses.composed_press import ComposedPress -from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress -from kvpress.presses.duo_attention_press import DuoAttentionPress -from kvpress.presses.expected_attention_press import ExpectedAttentionPress -from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress -from kvpress.presses.finch_press import FinchPress -from kvpress.presses.key_rerotation_press import KeyRerotationPress -from kvpress.presses.keydiff_press import KeyDiffPress -from kvpress.presses.knorm_press import KnormPress -from kvpress.presses.kvzip_press import KVzipPress -from kvpress.presses.lagkv_press import LagKVPress -from kvpress.presses.observed_attention_press import ObservedAttentionPress -from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress -from kvpress.presses.pyramidkv_press import PyramidKVPress -from kvpress.presses.qfilter_press import QFilterPress -from kvpress.presses.random_press import RandomPress -from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.simlayerkv_press import SimLayerKVPress -from kvpress.presses.snapkv_press import SnapKVPress -from kvpress.presses.streaming_llm_press import 
StreamingLLMPress -from kvpress.presses.think_press import ThinKPress -from kvpress.presses.tova_press import TOVAPress - -# Patch the attention functions to support head-wise compression -patch_attention_functions() - -__all__ = [ - "CriticalAdaKVPress", - "CriticalKVPress", - "AdaKVPress", - "BasePress", - "ComposedPress", - "ScorerPress", - "ExpectedAttentionPress", - "KnormPress", - "ObservedAttentionPress", - "RandomPress", - "SimLayerKVPress", - "SnapKVPress", - "StreamingLLMPress", - "ThinKPress", - "TOVAPress", - "KVPressTextGenerationPipeline", - "PerLayerCompressionPress", - "KeyRerotationPress", - "ChunkPress", - "DuoAttentionPress", - "ChunkKVPress", - "QFilterPress", - "PyramidKVPress", - "FinchPress", - "LagKVPress", - "BlockPress", - "KeyDiffPress", - "KVzipPress", - "ExpectedAttentionStatsPress", -] +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from kvpress.attention_patch import patch_attention_functions +from kvpress.pipeline import KVPressTextGenerationPipeline +from kvpress.presses.adakv_press import AdaKVPress +from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress +from kvpress.presses.block_press import BlockPress +from kvpress.presses.chunk_press import ChunkPress +from kvpress.presses.chunkkv_press import ChunkKVPress +from kvpress.presses.composed_press import ComposedPress +from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.duo_attention_press import DuoAttentionPress +from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress +from kvpress.presses.finch_press import FinchPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress +from kvpress.presses.keydiff_press import KeyDiffPress +from 
kvpress.presses.knorm_press import KnormPress +from kvpress.presses.kvzip_press import KVzipPress +from kvpress.presses.lagkv_press import LagKVPress +from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from kvpress.presses.pyramidkv_press import PyramidKVPress +from kvpress.presses.qfilter_press import QFilterPress +from kvpress.presses.random_press import RandomPress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.simlayerkv_press import SimLayerKVPress +from kvpress.presses.snapkv_press import SnapKVPress +from kvpress.presses.streaming_llm_press import StreamingLLMPress +from kvpress.presses.think_press import ThinKPress +from kvpress.presses.tova_press import TOVAPress + +# Patch the attention functions to support head-wise compression +patch_attention_functions() + +__all__ = [ + "CriticalAdaKVPress", + "CriticalKVPress", + "AdaKVPress", + "BasePress", + "ComposedPress", + "ScorerPress", + "ExpectedAttentionPress", + "KnormPress", + "ObservedAttentionPress", + "RandomPress", + "SimLayerKVPress", + "SnapKVPress", + "StreamingLLMPress", + "ThinKPress", + "TOVAPress", + "KVPressTextGenerationPipeline", + "PerLayerCompressionPress", + "KeyRerotationPress", + "ChunkPress", + "DuoAttentionPress", + "ChunkKVPress", + "QFilterPress", + "PyramidKVPress", + "FinchPress", + "LagKVPress", + "BlockPress", + "KeyDiffPress", + "KVzipPress", + "DecodingPress", + "PrefillDecodingPress", + "ExpectedAttentionStatsPress", + "DecodingPress", + "PrefillDecodingPress", +] diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 55eb506a..76f345d3 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -12,9 +12,11 @@ from transformers.pipelines.base import GenericTensor from kvpress.presses.base_press import BasePress +from kvpress.presses.decoding_press import 
DecodingPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress logger = logging.getLogger(__name__) @@ -189,6 +191,11 @@ def _forward( list[str] Generated answers for each input question. """ + if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1: + raise ValueError( + "DecodingPress is not compatible with multiple questions. Please specify a single question." + ) + context_ids = input_tensors["context_ids"].to(self.model.device) context_length = context_ids.shape[1] @@ -196,7 +203,9 @@ def _forward( if cache is None: cache = DynamicCache() - with press(self.model) if press is not None else contextlib.nullcontext(): + # We only perform prefill compression if the press is not a decoding or prefill decoding press + perform_prefill_compression = press is not None and not isinstance(press, DecodingPress) + with press(self.model) if perform_prefill_compression else contextlib.nullcontext(): # We run the model without the lm head for pre-filling. 
self.model.model( input_ids=context_ids, @@ -204,24 +213,46 @@ def _forward( output_attentions=self.output_attentions(press), ) - logger.debug(f"Context Length: {context_length}") - logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + logger.debug(f"Context Length: {context_length}") + logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + + # We only perform decoding compression if the press is a decoding or prefill decoding press + perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress)) + with press(self.model) if perform_decoding_compression else contextlib.nullcontext(): + # Greedy decoding for each question + answers = [] + for question_ids in input_tensors["questions_ids"]: + if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): + context_length = cache.get_seq_length() + + cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + answer = self.generate_answer( + question_ids=question_ids.to(self.model.device), + cache=cache, + context_length=context_length, + max_new_tokens=max_new_tokens, + ) + if len(input_tensors["questions_ids"]) > 1: + print(f"Removing answer from cache: {cache_seq_lengths}") + self._remove_answer_from_cache(cache, cache_seq_lengths) + + answers.append(answer) + return answers - # Greedy decoding for each question - answers = [] - for question_ids in input_tensors["questions_ids"]: - if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): - context_length = cache.get_seq_length() + def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]): - answer = self.generate_answer( - question_ids=question_ids.to(self.model.device), - cache=cache, - context_length=context_length, - max_new_tokens=max_new_tokens, - ) - answers.append(answer) + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + 
cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] + cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - return answers + if isinstance(cache, QuantizedCache): + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ + :, :, :sequence_length + ] + cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ + :, :, :sequence_length + ] def generate_answer( self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int @@ -245,7 +276,6 @@ def generate_answer( str The generated answer. """ - cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device ).unsqueeze(0) @@ -276,20 +306,6 @@ def generate_answer( if new_id.item() in should_stop_token_ids: break answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True) - # Remove the generated tokens from the cache - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] - cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - - if isinstance(cache, QuantizedCache): - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ - :, :, :sequence_length - ] - cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ - :, :, :sequence_length - ] - return answer def output_attentions(self, press: BasePress): diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 7d40fcd7..8fb4acf2 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -20,6 +20,8 @@ 
Qwen3ForCausalLM, ) +from kvpress.presses.utils import extract_keys_and_values + logger = logging.getLogger(__name__) SUPPORTED_MODELS = ( @@ -124,24 +126,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic hidden_states = kwargs["hidden_states"] cache = kwargs["past_key_values"] + cache_layer = cache.layers[module.layer_idx] q_len = hidden_states.shape[1] # Don't compress after pre-filling if kwargs["cache_position"][-1] > q_len: return output - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py new file mode 100644 index 00000000..9e5b1a6c --- /dev/null +++ b/kvpress/presses/decoding_press.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import defaultdict +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers.cache_utils import QuantizedCache + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.utils import extract_keys_and_values + +logger = logging.getLogger(__name__) + + +@dataclass +class DecodingPress(BasePress): + """ + A press that only operates during decoding phase and maintains a running buffer of hidden states. 
+ + This press accumulates hidden states during decoding and applies compression every N steps + using a scorer press to determine which tokens to keep. + + + Parameters + ---------- + base_press : ScorerPress + The scorer press used to compute importance scores for tokens. + compression_interval : int, default=10 + Number of decoding steps between compression, i.e. compression will be applied every compression_interval steps. + target_size : int, default=1024 + Target number of tokens to keep after compression. + hidden_states_buffer_size : int, default=128 + Maximum number of hidden states to keep before compression. Larger values use more GPU memory. + Note: Some presses don't need buffered hidden states and can set this to 0 to use only the + current hidden state for compression scoring. + """ + + base_press: ScorerPress + compression_interval: int = 128 + target_size: int = 1024 + hidden_states_buffer_size: int = 128 + + def __post_init__(self): + # Buffer to store hidden states during decoding (per layer) + assert isinstance(self.base_press, ScorerPress), "DecodingPress requires a ScorerPress as input" + self.hidden_states_buffer = defaultdict(list) # Per-layer buffer + self.layer_step_counts = defaultdict(int) # Track step count per layer + + # Warn if compression happens before buffer is fully utilized + # TODO: would it make sense to not reset the buffer? + if self.hidden_states_buffer_size > 0 and self.compression_interval < self.hidden_states_buffer_size: + logger.warning( + f"compression_interval ({self.compression_interval}) < hidden_states_buffer_size ({self.hidden_states_buffer_size}). " # noqa: E501 + f"Buffer will be reset before reaching full capacity, potentially reducing compression quality." 
+ ) + + assert self.compression_interval > 0, "compression_interval must be greater than 0" + assert self.target_size > 0, "target_size must be greater than 0" + + if self.base_press.compression_ratio: + logger.warning( + f"compression_ratio is set for base press ({self.base_press.compression_ratio}). " + f"This will be overridden by the decoding press." + ) + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Delegate compression to the base press during decoding phase. + + Args: + module: The transformer module being compressed + hidden_states: Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim]) + keys: Key cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + values: Value cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + attentions: Attention weights (shape varies by implementation) + kwargs: Additional keyword arguments + + Returns: + tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) tensors + + Note: + **Sequence length alignment**: During decoding compression, `hidden_states` contains the + buffered hidden states from recent decoding steps (buffer_len tokens), while `keys` and + `values` contain the full sequence history (seq_len tokens). The base press implementation + should use keys.shape[2] for full sequence length calculations. The buffered hidden_states + provide context for the most recent tokens when computing compression scores. + + Performance Note: + It would be possible to speed up compression during decoding for certain scorer presses by + storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions. 
+ """ + q_len = keys.shape[2] + target_compression_ratio = self._find_target_compression_ratio(q_len, self.target_size) + logger.debug(f"Compressing {q_len} to {self.target_size} with ratio {target_compression_ratio}") + + original_compression_ratio = self.base_press.compression_ratio + self.base_press.compression_ratio = target_compression_ratio + result = self.base_press.compress(module, hidden_states, keys, values, attentions, kwargs) + self.base_press.compression_ratio = original_compression_ratio + return result + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Forward hook that manages decoding-specific compression logic. + + This hook: + 1. Detects when we're in decoding phase (not prefilling) + 2. Accumulates hidden states in a buffer + 3. Applies compression every N steps + 4. Clears the buffer after compression + """ + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] + layer_idx = module.layer_idx + + # Only operate during decoding phase (after prefilling) + if kwargs["cache_position"][-1] <= q_len: + # We're still in prefilling phase, don't do anything + return output + # print(f"Adding hidden states to buffer: {hidden_states.shape}") + # Add current hidden states to buffer for this layer + self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone()) + + # print(f"Layer step counts: {self.layer_step_counts[layer_idx]}") + self.layer_step_counts[layer_idx] += 1 + + # Apply compression if we've reached the compression step threshold + if (self.layer_step_counts[layer_idx] >= self.compression_interval) or (q_len >= self.target_size): + logger.debug( + f"Applying decoding compression: layer_step_count ({self.layer_step_counts[layer_idx]}) >= compression_steps ({self.compression_interval})" # noqa: E501 + ) + + cache_layer = cache.layers[module.layer_idx] + keys, values = extract_keys_and_values(cache, module.layer_idx) + + # 
Get attention weights from output + attentions = output[1] if len(output) > 1 and output[1] is not None else None + + # Apply compression using buffered hidden states for this layer + buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1) + keys, values = self.compress(module, buffered_hidden_states, keys, values, attentions, kwargs) + logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}") + + # Update cache with compressed keys and values + if isinstance(cache, QuantizedCache): + cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) + cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) + cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.cumulative_length = keys.shape[2] + else: + cache_layer.keys = keys + cache_layer.values = values + + # Reset step count and clear buffer for this layer + self.layer_step_counts[layer_idx] = 0 + # Always clear the buffer after compression - otherwise there's a mismatch between + # hidden states buffer and kv cache + self.hidden_states_buffer[layer_idx] = [] + + self.hidden_states_buffer[layer_idx] = self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + return output + + def reset(self): + """Reset the decoding press state.""" + self.hidden_states_buffer = defaultdict(list) + self.layer_step_counts = defaultdict(int) + + def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float: + """ + Find the compression ratio that results in exactly target_tokens after int() rounding. 
+ + Args: + q_len: Current sequence length + target_tokens: Desired number of tokens after compression + + Returns: + Compression ratio that gives exactly target_tokens + """ + if q_len <= target_tokens: + return 0.0 + + # Start with theoretical ratio + ratio = 1.0 - (target_tokens / q_len) + + # Binary search to handle int() rounding + low, high = 0.0, 1.0 + max_iterations = 20 + iteration = 0 + + while iteration < max_iterations: + n_kept = int(q_len * (1 - ratio)) + if n_kept == target_tokens: + break + elif n_kept > target_tokens: + # Need more compression + low = ratio + ratio = (ratio + high) / 2 + else: + # Need less compression + high = ratio + ratio = (low + ratio) / 2 + iteration += 1 + + final_n_kept = int(q_len * (1 - ratio)) + if final_n_kept != target_tokens: + logger.warning( + f"Binary search failed: q_len={q_len}, target={target_tokens}, got={final_n_kept}, ratio={ratio}" + ) + + return ratio diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index b88d2f03..24c5b17a 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -103,7 +103,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): raise ValueError( "Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press." 
) - q_len = hidden_states.shape[1] + q_len = keys.shape[2] if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)): diff --git a/kvpress/presses/finch_press.py b/kvpress/presses/finch_press.py index e22952c3..39a8cd48 100644 --- a/kvpress/presses/finch_press.py +++ b/kvpress/presses/finch_press.py @@ -95,7 +95,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Compute indices to keep (optionally by chunks) - q_len = hidden_states.shape[1] + q_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states if self.chunk_length is None: n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 129072f8..a264b3d4 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -137,7 +137,7 @@ def compress( scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = int(q_len * (1 - self.press.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = torch.sort(indices, dim=2).values diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index 59b565e7..0f13e68a 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -14,7 +14,7 @@ from transformers.models.llama.modeling_llama import rotate_half from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.utils import get_query_states +from kvpress.presses.utils import extract_keys_and_values, get_query_states logger = logging.getLogger(__name__) @@ -153,25 +153,16 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic 
hidden_states = kwargs["hidden_states"] cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None) - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) # Compute importance scores for KV pairs in the prefilled context, # retaining only the originally prefilled KV pairs. keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): + # Update cache with compressed keys and values cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] diff --git a/kvpress/presses/prefill_decoding_press.py b/kvpress/presses/prefill_decoding_press.py new file mode 100644 index 00000000..85b5d465 --- /dev/null +++ b/kvpress/presses/prefill_decoding_press.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +from kvpress.presses.base_press import BasePress + +from .decoding_press import DecodingPress + +logger = logging.getLogger(__name__) + + +@dataclass +class PrefillDecodingPress(BasePress): + """ + A wrapper press that combines separate prefilling and decoding compression strategies. 
+ + This press acts as a single press interface but internally delegates to different + presses based on the current phase (prefilling vs decoding). During prefilling, + it uses the prefilling_press. During decoding, it uses the decoding_press. + + Parameters + ---------- + prefilling_press : BasePress, optional + Press to use during the prefilling phase. If None, no compression is applied during prefilling. + decoding_press : DecodingPress, optional + Press to use during the decoding phase. If None, no compression is applied during decoding. + """ + + prefilling_press: Optional[BasePress] = None + decoding_press: Optional[DecodingPress] = None + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + q_len = hidden_states.shape[1] + + # Determine if we're in prefilling or decoding phase + if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs) + elif self.decoding_press is not None: + return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs) + + # No compression applied + logger.warning("No compression applied during prefill or decoding phase") + + return keys, values + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Forward hook that delegates to the appropriate press based on current phase. 
+ """ + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Determine if we're in prefilling or decoding phase + if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + return self.prefilling_press.forward_hook(module, input, kwargs, output) + elif self.decoding_press is not None: + return self.decoding_press.forward_hook(module, input, kwargs, output) + + # No hook applied + return output + + @contextmanager + def __call__(self, model: PreTrainedModel): + try: + with super().__call__(model): + yield + finally: + # Reset decoding press if it exists + if self.decoding_press is not None: + self.decoding_press.reset() diff --git a/kvpress/presses/pyramidkv_press.py b/kvpress/presses/pyramidkv_press.py index 14b6c03b..cc0a0106 100644 --- a/kvpress/presses/pyramidkv_press.py +++ b/kvpress/presses/pyramidkv_press.py @@ -100,7 +100,7 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = self.get_layer_budget(module, q_len) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 5c84b379..1cbcb748 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -90,7 +90,7 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 86086e7c..48af6bc4 100644 --- a/kvpress/presses/simlayerkv_press.py +++ 
b/kvpress/presses/simlayerkv_press.py @@ -94,7 +94,7 @@ def compress( self.compression_ratios = [] # Check if compression is needed - q_len = hidden_states.shape[1] + q_len = keys.shape[2] min_length = self.n_initial + self.n_recent + self.n_last if q_len <= min_length: diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 68f92af9..549f3865 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -44,7 +44,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ - bsz, q_len, _ = hidden_states.shape + bsz, _, q_len, _ = keys.shape num_heads = module.config.num_attention_heads head_dim = module.head_dim num_key_value_groups = num_heads // module.config.num_key_value_heads @@ -78,10 +78,13 @@ def score( kwargs, ) -> torch.Tensor: - bsz, num_key_value_heads, q_len, _ = keys.shape + bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads - assert q_len > self.window_size, "Query length should be greater than the window size" + q_len = hidden_states.shape[1] + assert ( + q_len > self.window_size + ), f"Query length {q_len} should be greater than the window size {self.window_size}" if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] @@ -94,7 +97,7 @@ def score( scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) # Average per group (https://github.com/FasterDecoding/SnapKV/issues/22) - scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size) scores = scores.mean(2) # Add back the observation window. Use max score to make sure the window is not pruned. 
diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index 4c0ac218..66c9ecf8 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -44,7 +44,7 @@ def score( kwargs, ) -> torch.Tensor: - q_len = hidden_states.shape[1] + q_len = keys.shape[2] assert q_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" n_pruned = q_len - int(q_len * (1 - self.compression_ratio)) scores = torch.ones_like(keys[..., 0]) diff --git a/kvpress/presses/utils.py b/kvpress/presses/utils.py index 938c15ba..3614fa5b 100644 --- a/kvpress/presses/utils.py +++ b/kvpress/presses/utils.py @@ -3,6 +3,7 @@ import torch from torch import nn +from transformers import Cache, QuantizedCache from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention @@ -50,3 +51,22 @@ def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Te query_states = module.q_norm(query_states) return query_states + + +def dequantize_layer(cache_layer) -> tuple[torch.Tensor, torch.Tensor]: + keys = cache_layer._dequantize(cache_layer._quantized_keys) + values = cache_layer._dequantize(cache_layer._quantized_values) + return keys, values + + +def extract_keys_and_values(cache: Cache, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Extracts the keys and values from a given cache layer, + handling both quantized and unquantized caches. 
+ """ + if isinstance(cache, QuantizedCache): + keys, values = dequantize_layer(cache.layers[layer_idx]) + else: + keys = cache.layers[layer_idx].keys + values = cache.layers[layer_idx].values + return keys, values diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb new file mode 100644 index 00000000..626c06da --- /dev/null +++ b/notebooks/kvpress_decoding_aime25.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "5705c4fb-d665-48e8-a4dd-a43725e5f7f4", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "import torch\n", + "from transformers import pipeline\n", + "\n", + "from kvpress import (\n", + " ExpectedAttentionPress,\n", + " KnormPress,\n", + " ObservedAttentionPress,\n", + " RandomPress,\n", + " SnapKVPress,\n", + " StreamingLLMPress,\n", + " TOVAPress,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "dataset = load_dataset(\"math-ai/aime25\")\n", + "sample = dataset[\"test\"][0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2e0398e8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',\n", + " 'answer': '70',\n", + " 'id': '0'}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dd926e5b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching 0 files: 0it [00:00, ?it/s]\n", + "Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 18641.35it/s]\n", + "Fetching 0 files: 0it [00:00, ?it/s]\n", + "Device set to use cuda:0\n" + ] + } + ], + "source": [ + 
"device = \"cuda:0\"\n", + "ckpt = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n", + "attn_implementation = \"flash_attention_2\" # use \"eager\" for ObservedAttentionPress and \"sdpa\" if you can't use \"flash_attention_2\"\n", + "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, model_kwargs={\"attn_implementation\":attn_implementation, \"dtype\": torch.bfloat16})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc932bbc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. Let me try to break this down step by step.\n", + "\n", + "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", + "\n", + "Let me write that out:\n", + "\n", + "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "\n", + "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "\n", + "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "\n", + "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). 
So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "\n", + "Let me write that equation:\n", + "\n", + "\\( 9b + 7 = k(b + 7) \\)\n", + "\n", + "Now, I can expand the right-hand side:\n", + "\n", + "\\( 9b + 7 = kb + 7k \\)\n", + "\n", + "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Factor out \\( b \\) and \\( k \\):\n", + "\n", + "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "\n", + "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "\n", + "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "\n", + "Which gives:\n", + "\n", + "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "\n", + "So, this is a linear equation in terms of \\( b \\). Let me solve for \\( b \\):\n", + "\n", + "\\( (9 - k)b = 7k - 7 \\)\n", + "\n", + "So,\n", + "\n", + "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "\n", + "Hmm, okay. Since \\( b \\) must be an integer greater than 9, the right-hand side must also be an integer. So, \\( 7k - 7 \\) must be divisible by \\( 9 - k \\). Let me write that as:\n", + "\n", + "\\( 7(k - 1) \\) is divisible by \\( 9 - k \\)\n", + "\n", + "Which can be written as:\n", + "\n", + "\\( 7(k - 1) \\) is divisible by \\( -(k - 9) \\)\n", + "\n", + "But since divisibility is unaffected by the sign, I can say that \\( 7(k - 1) \\) is divisible by \\( k - 9 \\).\n", + "\n", + "So, \\( k - 9 \\) divides \\( 7(k - 1) \\). Let me denote \\( d = k - 9 \\), so \\( d \\) divides \\( 7(k - 1) \\). But \\( k = d + 9 \\), so substituting back:\n", + "\n", + "\\( d \\) divides \\( 7(d + 9 - 1) = 7(d + 8) \\)\n", + "\n", + "So, \\( d \\) divides \\( 7(d + 8) \\). Which means \\( d \\) divides \\( 7 \\times 8 = 56 \\). Because \\( d \\) divides \\( 7d + 56 \\), so \\( d \\) must divide 56.\n", + "\n", + "Therefore, \\( d \\) is a divisor of 56. So, the possible values of \\( d \\) are the divisors of 56. 
Let me list them:\n", + "\n", + "Positive divisors: 1, 2, 4, 7, 8, 14, 28, 56\n", + "\n", + "Negative divisors: -1, -2, -4, -7, -8, -14, -28, -56\n", + "\n", + "So, \\( d \\) can be any of these. But remember that \\( d = k - 9 \\), and \\( k \\) is an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\). So, \\( k \\) must be an integer.\n", + "\n", + "But let's see, since \\( b > 9 \\), and \\( b + 7 \\) is a divisor of \\( 9b + 7 \\), we can also think about the possible values of \\( k \\). Let me see.\n", + "\n", + "From the equation \\( b = \\frac{7(k - 1)}{9 - k} \\), since \\( b > 9 \\), the numerator and denominator must have the same sign. So, either both numerator and denominator are positive or both are negative.\n", + "\n", + "Case 1: Both numerator and denominator are positive.\n", + "\n", + "So, \\( 7(k - 1) > 0 \\) implies \\( k > 1 \\)\n", + "\n", + "And \\( 9 - k > 0 \\) implies \\( k < 9 \\)\n", + "\n", + "So, \\( 1 < k < 9 \\)\n", + "\n", + "Case 2: Both numerator and denominator are negative.\n", + "\n", + "So, \\( 7(k - 1) < 0 \\) implies \\( k < 1 \\)\n", + "\n", + "And \\( 9 - k < 0 \\) implies \\( k > 9 \\)\n", + "\n", + "But \\( k < 1 \\) and \\( k > 9 \\) can't happen at the same time. So, only Case 1 is possible.\n", + "\n", + "Therefore, \\( k \\) must be an integer between 2 and 8 inclusive.\n", + "\n", + "So, \\( k \\in \\{2, 3, 4, 5, 6, 7, 8\\} \\)\n", + "\n", + "Now, let's compute \\( d = k - 9 \\) for each \\( k \\):\n", + "\n", + "For \\( k = 2 \\): \\( d = -7 \\)\n", + "\n", + "For \\( k = 3 \\): \\( d = -6 \\)\n", + "\n", + "For \\( k = 4 \\): \\( d = -5 \\)\n", + "\n", + "For \\( k = 5 \\): \\( d = -4 \\)\n", + "\n", + "For \\( k = 6 \\): \\( d = -3 \\)\n", + "\n", + "For \\( k = 7 \\): \\( d = -2 \\)\n", + "\n", + "For \\( k = 8 \\): \\( d = -1 \\)\n", + "\n", + "So, \\( d \\) can be -7, -6, -5, -4, -3, -2, -1.\n", + "\n", + "But earlier, we said that \\( d \\) must divide 56. 
Let's check if each of these \\( d \\) values divides 56.\n", + "\n", + "- \\( d = -7 \\): 56 ÷ (-7) = -8, which is integer. So, yes.\n", + "\n", + "- \\( d = -6 \\): 56 ÷ (-6) ≈ -9.333... Not integer. So, no.\n", + "\n", + "- \\( d = -5 \\): 56 ÷ (-5) = -11.2. Not integer. No.\n", + "\n", + "- \\( d = -4 \\): 56 ÷ (-4) = -14. Integer. Yes.\n", + "\n", + "- \\( d = -3 \\): 56 ÷ (-3) ≈ -18.666... Not integer. No.\n", + "\n", + "- \\( d = -2 \\): 56 ÷ (-2) = -28. Integer. Yes.\n", + "\n", + "- \\( d = -1 \\): 56 ÷ (-1) = -56. Integer. Yes.\n", + "\n", + "So, the valid \\( d \\) values are -7, -4, -2, -1.\n", + "\n", + "Therefore, for each of these \\( d \\), we can find \\( k \\):\n", + "\n", + "- \\( d = -7 \\): \\( k = d + 9 = 2 \\)\n", + "\n", + "- \\( d = -4 \\): \\( k = 5 \\)\n", + "\n", + "- \\( d = -2 \\): \\( k = 7 \\)\n", + "\n", + "- \\( d = -1 \\): \\( k = 8 \\)\n", + "\n", + "So, now, let's compute \\( b \\) for each \\( k \\):\n", + "\n", + "Recall \\( b = \\frac{7(k - 1)}{9 - k} \\)\n", + "\n", + "Let's compute for each \\( k \\):\n", + "\n", + "1. \\( k = 2 \\):\n", + "\n", + "\\( b = \\frac{7(2 - 1)}{9 - 2} = \\frac{7(1)}{7} = 1 \\)\n", + "\n", + "But \\( b > 9 \\), so 1 is invalid.\n", + "\n", + "2. \\( k = 5 \\):\n", + "\n", + "\\( b = \\frac{7(5 - 1)}{9 - 5} = \\frac{7(4)}{4} = 7 \\)\n", + "\n", + "Again, \\( b = 7 \\) is less than 9, so invalid.\n", + "\n", + "3. 
\\( k = 7 \\):\n", + "\n", + "\\( b = \\frac{7(7 - 1)}{9 - 7} = \\frac{7(6)}{\n", + "Cache size: torch.Size([1, 2, 2081, 128])\n" + ] + } + ], + "source": [ + "%%time\n", + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=2048)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ff177e00-cabf-4680-9cfc-b6cb00770527", + "metadata": {}, + "outputs": [], + "source": [ + "from kvpress.presses.decoding_press import DecodingPress" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7f4fa366-7e62-443a-b5c8-274128fe6237", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. Let me try to break this down step by step.\n", + "\n", + "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", + "\n", + "Let me write that out:\n", + "\n", + "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "\n", + "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "\n", + "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). 
In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "\n", + "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "\n", + "Let me write that equation:\n", + "\n", + "\\( 9b + 7 = k(b + 7) \\)\n", + "\n", + "Now, I can expand the right-hand side:\n", + "\n", + "\\( 9b + 7 = kb + 7k \\)\n", + "\n", + "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Factor out \\( b \\) and \\( k \\):\n", + "\n", + "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "\n", + "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "\n", + "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "\n", + "Expanding the \\( k(b + 7) \\):\n", + "\n", + "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "\n", + "Now, let's collect like terms:\n", + "\n", + "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "\n", + "So, this is a linear equation in terms of \\( b \\). Let me write it as:\n", + "\n", + "\\( (9 - k)b = 7k - 7 \\)\n", + "\n", + "Then, solving for \\( b \\):\n", + "\n", + "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "\n", + "Hmm, okay. So, \\( b \\) must be an integer greater than 9. So, \\( b \\) is an integer, and \\( k \\) is also an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\).\n", + "\n", + "So, \\( b \\) is expressed in terms of \\( k \\). Let me see if I can find integer values of \\( k \\) such that \\( b \\) is an integer greater than 9.\n", + "\n", + "First, let's note that \\( 9 - k \\) cannot be zero because that would make the denominator zero, which is undefined. 
So, \\( k \\neq 9 \\).\n", + "\n", + "Also, since \\( b > 9 \\), let's see what constraints that imposes on \\( k \\).\n", + "\n", + "Looking at the expression \\( b = \\frac{7k - 7}{9 - k} \\), let's see when this is positive and greater than 9.\n", + "\n", + "First, let's note that \\( 7k - 7 \\) and \\( 9 - k \\) must have the same sign because \\( b \\) is positive.\n", + "\n", + "So, either both numerator and denominator are positive or both are negative.\n", + "\n", + "Case 1: Both numerator and denominator are positive.\n", + "\n", + "So, \\( 7k - 7 > 0 \\) implies \\( k > 1 \\).\n", + "\n", + "And \\( 9 - k > 0 \\) implies \\( k < 9 \\).\n", + "\n", + "So, \\( k \\) must satisfy \\( 1 < k < 9 \\). Since \\( k \\) is an integer, \\( k \\) can be 2, 3, 4, 5, 6, 7, or 8.\n", + "\n", + "Case 2: Both numerator and\n", + "Cache size: torch.Size([1, 2, 500, 128])\n", + "CPU times: user 36.7 s, sys: 26.8 ms, total: 36.8 s\n", + "Wall time: 36.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "compression_interval = 500 # compress every compression_steps\n", + "target_size = 500 # number of tokens to keep after compression\n", + "\n", + "\n", + "press = DecodingPress(base_press=KnormPress(), compression_interval=compression_interval, target_size=target_size, hidden_states_buffer_size=0)\n", + "\n", + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=1000)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 7bc1e993..935cf50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,4 +91,4 @@ disable_error_code = ["attr-defined"] [[tool.mypy.overrides]] module = "kvpress.pipeline" -disable_error_code = ["attr-defined", "assignment", "override"] \ No newline at end of file +disable_error_code = ["attr-defined", "assignment", "override"] diff --git a/tests/fixtures.py b/tests/fixtures.py index c446db4d..ae855a0a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -51,7 +51,7 @@ def kv_press_llama3_2_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe @@ -65,6 +65,6 @@ def kv_press_llama3_1_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py new file mode 100644 index 00000000..90acb8e0 --- /dev/null +++ b/tests/test_decoding_compression.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Test script to verify that DecodingPress actually compresses during decoding. 
+""" +import logging + +import pytest +import torch +from transformers import DynamicCache, pipeline + +from kvpress import PyramidKVPress, ScorerPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.knorm_press import KnormPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from tests.default_presses import default_presses + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("token_buffer_size", [32, 64, 128]) +def test_decoding_compression(token_buffer_size): + """Test that DecodingPress compresses the cache during decoding.""" + + # Initialize pipeline with a small model + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create a DecodingPress with KnormPress + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens + compression_interval=4, # Compress every 4 tokens + target_size=token_buffer_size, + ) + + # Create cache + cache = DynamicCache() + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 10 # Repeat for longer context + question = "What animal jumps over the dog?" 
+ + # Run pipeline + pipe(context, question=question, press=press, cache=cache, max_new_tokens=20) + + # Assert that all layers have the expected cache size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + max_expected_size = token_buffer_size + press.compression_interval - 1 + assert token_buffer_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_calls_both_phases(): + """Test that PrefillDecodingPress calls both prefilling and decoding presses.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with both presses + combined_press = PrefillDecodingPress( + prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill + decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 12 # Longer context + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15) + + # Check that cache was compressed during both phases + # Final cache should be compressed to decoding press target size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 # token_buffer_size from decoding press + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected final cache size to be between {target_size} " + f"and {max_expected_size} (decoding target), but got {layer_seq_len}" + ) + + +def test_decoding_press_without_prefill(): + """Test that DecodingPress works correctly when used standalone (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create DecodingPress only + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25) + + # Check that cache was compressed during decoding + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 64 + compression_steps = 5 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_decoding_only(): + """Test PrefillDecodingPress with only decoding press (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 9 + question = "What animal jumps over the dog?" 
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12) + + # Check that only decoding compression was applied + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 56 + compression_steps = 4 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_decoding_press_equivalence(): + """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only.""" + + # Set random seed for reproducibility + torch.manual_seed(42) + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create standalone decoding press + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52) + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 7 + question = "What animal jumps over the dog?" 
+ + # Run with standalone decoding press + cache1 = DynamicCache() + result1 = pipe(context, question=question, press=decoding_press, cache=cache1, max_new_tokens=10) + + # Run with combined press (decoding only) + cache2 = DynamicCache() + result2 = pipe(context, question=question, press=combined_press, cache=cache2, max_new_tokens=10) + + # Compare cache sizes (should be identical) + for layer_idx in range(len(cache1.layers)): + cache1_size = cache1.layers[layer_idx].keys.shape[2] + cache2_size = cache2.layers[layer_idx].keys.shape[2] + assert cache1_size == cache2_size, ( + f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != " + f"combined press cache size {cache2_size}" + ) + + # Compare generated text results (should be identical) + assert result1["answer"] == result2["answer"], ( + f"Generated answers differ:\n" + f"Standalone decoding: '{result1['answer']}'\n" + f"Combined press: '{result2['answer']}'" + ) + + +""" +E AttributeError: 'QFilterPress' object has no attribute 'q_filters' +E Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12 +> query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) +E RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12 +""" + + +@pytest.mark.parametrize("press_config", default_presses) +def test_all_presses_work_with_decoding_press(press_config): + """Test that all default presses work as base presses for DecodingPress.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Get press class and use the first (easier) configuration + press_cls = press_config["cls"] + press_kwargs = press_config["kwargs"][0] # Use easier compression settings + + base_press = press_cls(**press_kwargs) + if not isinstance(base_press, ScorerPress): + logger.info(f"Press {press_cls.__name__} is not a ScorerPress, skipping test") + return + if 
isinstance(base_press, (PyramidKVPress)): + # PyramidKVPress -> Pyramid shape, not compatible with token_buffer_size=48 + logger.info(f"Press {press_cls.__name__} is not supported, skipping test") + return + if hasattr(base_press, "__post_init_from_model__"): + base_press.__post_init_from_model__(pipe.model) + + # Create DecodingPress with this base press + decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" + + # Run pipeline + cache = DynamicCache() + try: + result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15) + + # Verify compression worked + assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}" + + # Check that cache was compressed (allow some tolerance for rounding) + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert ( + target_size <= layer_seq_len <= max_expected_size + ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 + + except Exception as e: + pytest.fail(f"DecodingPress failed with {press_cls.__name__}: {e}") + + +def test_compression_actually_reduces_memory(): + """Test that compression actually reduces memory usage compared to no compression.""" + + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + context = "The quick brown fox jumps over the lazy dog. " * 15 # Long context + question = "What animal jumps over the dog?" 
+ + # Run without compression + cache_uncompressed = DynamicCache() + result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25) + + # Run with compression + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.3), # Aggressive compression + compression_interval=3, + target_size=40, + ) + cache_compressed = DynamicCache() + result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25) + + # Calculate memory usage (approximate) + uncompressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_uncompressed.layers + ) + compressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_compressed.layers + ) + + # Compression should significantly reduce memory usage + compression_ratio = compressed_memory / uncompressed_memory + assert compression_ratio < 0.6, ( + f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} " + f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)" + ) + + # Both should still generate reasonable answers + assert len(result_uncompressed["answer"]) > 0 + assert len(result_compressed["answer"]) > 0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 030d83cb..26f28f86 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -155,7 +155,9 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] + cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) + 
compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) From 77c0f55f0e54a77bb03d9c77629538bc4e167069 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Tue, 7 Oct 2025 21:24:53 +0200 Subject: [PATCH 03/20] improve readme Signed-off-by: Max Jeblick --- README.md | 81 +++++++++++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index b7d53e75..704a5220 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,47 @@ answer = pipe(context, question=question, press=press)["answer"] In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)). +
+## Decoding Compression
+
+By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
+
+- `base_press`: Any ScorerPress (e.g., `KnormPress`, `CriticalKVPress`)
+- `compression_interval`: Steps between compressions (default: 10)
+- `target_size`: Target size of the cache after compression (default: 1024)
+- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
+
+Unlike a compression ratio, decoding press uses a `target_size` to compress the cache. This means that the cache is compressed every `compression_interval` steps, and the compression ratio is automatically computed such that the size of the cache after compression equals `target_size`.
+
+An example for decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import KnormPress
+from kvpress import DecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling.
+
+
## Available presses @@ -123,47 +163,6 @@ Below we report the average performance on the RULER dataset with 4k context len Leaderboard

-
-## Decoding Compression - -By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters: - -- `base_press`: Any ScorerPress (e.g., `KNormPress`, `CriticalKVPress`) -- `compression_interval`: Steps between compressions (default: 10) -- `target_size`: Target cache size of the cache after compression (default: 1024) -- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0. - -Unlike a compression ratio, decoding press uses a `target_size` to compress the cache. This means that the cache is compressed every `compression_interval` steps, and the compression ratio is automatically computed such that the size of the cache after compression equals `target_size`. 
- -An example for decoding compression: - -```python -from transformers import pipeline -from kvpress import KnormPress -from kvpress import DecodingPress - -# Initialize the pipeline -device = "cuda:0" -model = "meta-llama/Llama-3.1-8B-Instruct" -model_kwargs = {"attn_implementation": "flash_attention_2"} -pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) - -# Create a decoding press that compresses every 10 steps to 512 tokens -decoding_press = DecodingPress( - base_press=KnormPress(), - compression_steps=10, - token_buffer_size=512 -) - -# Use with pipeline -context = "A very long text you want to compress during generation" -question = "Tell me a long story about this context" -response = pipe(context, question=question, press=decoding_press)["answer"] -``` - -> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling. - -
## Quantization

From 8ead77334c6ce5c188a311b008d62c31f2af4bba Mon Sep 17 00:00:00 2001
From: Max Jeblick
Date: Tue, 7 Oct 2025 21:27:51 +0200
Subject: [PATCH 04/20] improve readme

Signed-off-by: Max Jeblick
---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 704a5220..46b415b5 100644
--- a/README.md
+++ b/README.md
@@ -143,6 +143,8 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
 - `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
 - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
+- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): Allows for compression during decoding, see decoding section in this README.
+- `BlockPress` ([source](kvpress/presses/prefill_decoding_press.py)): Allows to compress both during prefilling and during decoding.
 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

From 049bffcf2d43cbfa4b0d1302fc0caa9de479653e Mon Sep 17 00:00:00 2001
From: Max Jeblick
Date: Tue, 7 Oct 2025 21:28:10 +0200
Subject: [PATCH 05/20] improve readme

Signed-off-by: Max Jeblick
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 46b415b5..038deacb 100644
--- a/README.md
+++ b/README.md
@@ -144,7 +144,7 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
 - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
 - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): Allows for compression during decoding, see decoding section in this README.
-- `BlockPress` ([source](kvpress/presses/prefill_decoding_press.py)): Allows to compress both during prefilling and during decoding.
+- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): Allows to compress both during prefilling and during decoding.
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) From db40e4c8c38560d3fddc6f3eb839ed0007155bf8 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Wed, 8 Oct 2025 08:41:25 +0200 Subject: [PATCH 06/20] update notebook Signed-off-by: Max Jeblick --- notebooks/kvpress_decoding_aime25.ipynb | 465 ++++++++++++++---------- 1 file changed, 263 insertions(+), 202 deletions(-) diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb index 626c06da..b84291bb 100644 --- a/notebooks/kvpress_decoding_aime25.ipynb +++ b/notebooks/kvpress_decoding_aime25.ipynb @@ -1,11 +1,30 @@ { "cells": [ { - "cell_type": "code", - "execution_count": 6, - "id": "5705c4fb-d665-48e8-a4dd-a43725e5f7f4", "metadata": {}, + "cell_type": "raw", + "source": "", + "id": "6458dd86772cbe6a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Using decoding compression on the AIME25 Math Dataset\n", + "\n", + "This notebook demonstrates how to compress during text generation.\n", + "We use `nvidia/OpenMath-Nemotron-7B` to solve math problems from the AIME25 dataset. For each problem, the model generates an answer in a boxed format (e.g., `\\boxed{42}`).\n", + "\n", + "To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.\n", + "Compression periodically reduces the cache size by keeping only the most relevant tokens, enabling efficient inference without sacrificing answer quality." 
+ ], + "id": "d5b82a94ac3477fd" + }, + { + "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": 1, "source": [ "import requests\n", "\n", @@ -20,25 +39,37 @@ " SnapKVPress,\n", " StreamingLLMPress,\n", " TOVAPress,\n", - ")" - ] + ")\n", + "\n", + "from kvpress.presses.decoding_press import DecodingPress" + ], + "id": "4cd1c9e43c5ca1bf" }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", - "dataset = load_dataset(\"math-ai/aime25\")\n", - "sample = dataset[\"test\"][0]\n" + "dataset = load_dataset(\"math-ai/aime25\")" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "2e0398e8", + "execution_count": 4, + "id": "5e912ac7-e5d8-42ff-8f9e-855b8ae999a8", + "metadata": {}, + "outputs": [], + "source": [ + "sample = dataset[\"test\"][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f8c2689d-e461-4608-aa69-a4811f71d639", "metadata": {}, "outputs": [ { @@ -49,43 +80,64 @@ " 'id': '0'}" ] }, - "execution_count": 8, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "sample " + "sample" ] }, { "cell_type": "code", - "execution_count": 9, - "id": "dd926e5b", + "execution_count": 6, + "id": "02b0af83-e71a-4129-9692-921d2bc16cb8", "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "96a774394acf47948fdc5cd642df43a7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:009$ for which $17_b$ is a divisor of $97_b.$\n", "Answer: 70\n", - "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. 
Let me try to break this down step by step.\n", - "\n", - "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", - "\n", - "Let me write that out:\n", + "Prediction: \n", + "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n", "\n", - "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n", "\n", - "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "17_b = b + 7\n", "\n", - "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "97_b = 9b + 7\n", "\n", - "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n", "\n", - "Let me write that equation:\n", + "Let me write that as an equation:\n", "\n", - "\\( 9b + 7 = k(b + 7) \\)\n", + "(9b + 7) ÷ (b + 7) = integer.\n", "\n", - "Now, I can expand the right-hand side:\n", + "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. 
If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n", "\n", - "\\( 9b + 7 = kb + 7k \\)\n", + "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n", "\n", - "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "9*(b + 7) = 9b + 63\n", "\n", - "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "But 9b + 7 is the original numerator. So, subtracting these:\n", "\n", - "Factor out \\( b \\) and \\( k \\):\n", + "9b + 7 = 9*(b + 7) - 56\n", "\n", - "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "So, 9b + 7 = 9*(b + 7) - 56\n", "\n", - "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "Therefore, when we divide 9b + 7 by (b + 7), we get:\n", "\n", - "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "(9*(b + 7) - 56) / (b + 7) = 9 - 56/(b + 7)\n", "\n", - "Which gives:\n", + "So, for this to be an integer, 56/(b + 7) must be an integer. That means that (b + 7) must be a divisor of 56. \n", "\n", - "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "So, (b + 7) divides 56. Therefore, (b + 7) is a positive divisor of 56. But since b > 9, then b + 7 > 16. So, we need to find all divisors of 56 that are greater than 16, and then subtract 7 to find the corresponding b.\n", "\n", - "So, this is a linear equation in terms of \\( b \\). Let me solve for \\( b \\):\n", + "First, let's list all positive divisors of 56. The divisors of 56 are:\n", "\n", - "\\( (9 - k)b = 7k - 7 \\)\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", "\n", - "So,\n", + "Now, from these, the divisors greater than 16 are 28 and 56. So, (b + 7) can be 28 or 56. Therefore, solving for b:\n", "\n", - "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "If b + 7 = 28, then b = 28 - 7 = 21.\n", "\n", - "Hmm, okay. Since \\( b \\) must be an integer greater than 9, the right-hand side must also be an integer. So, \\( 7k - 7 \\) must be divisible by \\( 9 - k \\). 
Let me write that as:\n", + "If b + 7 = 56, then b = 56 - 7 = 49.\n", "\n", - "\\( 7(k - 1) \\) is divisible by \\( 9 - k \\)\n", + "So, the possible bases are 21 and 49. Therefore, the sum of all such bases is 21 + 49 = 70.\n", "\n", - "Which can be written as:\n", + "Wait, but let me double-check. Let's verify for b = 21 and b = 49 whether 17_b divides 97_b.\n", "\n", - "\\( 7(k - 1) \\) is divisible by \\( -(k - 9) \\)\n", + "First, for b = 21:\n", "\n", - "But since divisibility is unaffected by the sign, I can say that \\( 7(k - 1) \\) is divisible by \\( k - 9 \\).\n", + "17_21 = 1*21 + 7 = 28\n", "\n", - "So, \\( k - 9 \\) divides \\( 7(k - 1) \\). Let me denote \\( d = k - 9 \\), so \\( d \\) divides \\( 7(k - 1) \\). But \\( k = d + 9 \\), so substituting back:\n", + "97_21 = 9*21 + 7 = 189 + 7 = 196\n", "\n", - "\\( d \\) divides \\( 7(d + 9 - 1) = 7(d + 8) \\)\n", + "Now, 196 divided by 28 is 7, which is an integer. So that works.\n", "\n", - "So, \\( d \\) divides \\( 7(d + 8) \\). Which means \\( d \\) divides \\( 7 \\times 8 = 56 \\). Because \\( d \\) divides \\( 7d + 56 \\), so \\( d \\) must divide 56.\n", + "For b = 49:\n", "\n", - "Therefore, \\( d \\) is a divisor of 56. So, the possible values of \\( d \\) are the divisors of 56. Let me list them:\n", + "17_49 = 1*49 + 7 = 56\n", "\n", - "Positive divisors: 1, 2, 4, 7, 8, 14, 28, 56\n", + "97_49 = 9*49 + 7 = 441 + 7 = 448\n", "\n", - "Negative divisors: -1, -2, -4, -7, -8, -14, -28, -56\n", + "448 divided by 56 is 8, which is also an integer. So that works too.\n", "\n", - "So, \\( d \\) can be any of these. But remember that \\( d = k - 9 \\), and \\( k \\) is an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\). So, \\( k \\) must be an integer.\n", + "Therefore, the bases are indeed 21 and 49, and their sum is 70. 
So the answer should be 70.\n", "\n", - "But let's see, since \\( b > 9 \\), and \\( b + 7 \\) is a divisor of \\( 9b + 7 \\), we can also think about the possible values of \\( k \\). Let me see.\n", + "But wait, let me check if there are any other divisors of 56 that I might have missed. The divisors of 56 are 1, 2, 4, 7, 8, 14, 28, 56. So, the ones greater than 16 are 28 and 56. So, that's correct. So, no other divisors. Therefore, the sum is 21 + 49 = 70.\n", "\n", - "From the equation \\( b = \\frac{7(k - 1)}{9 - k} \\), since \\( b > 9 \\), the numerator and denominator must have the same sign. So, either both numerator and denominator are positive or both are negative.\n", + "But just to make sure, let's check if there's a possibility of negative divisors. But since the base b must be greater than 9, and b + 7 must be a positive divisor (since base is a positive integer greater than the digits used, which in 17_b, the digits are 1 and 7, so base must be at least 8. But the problem states b > 9, so base is at least 10. So, b + 7 is at least 17. So, negative divisors are irrelevant here. So, we don't need to consider them.\n", "\n", - "Case 1: Both numerator and denominator are positive.\n", + "Therefore, the answer is 70.\n", "\n", - "So, \\( 7(k - 1) > 0 \\) implies \\( k > 1 \\)\n", + "**Final Answer**\n", + "\\boxed{70}\n", + "To solve the problem, we need to find all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\).\n", "\n", - "And \\( 9 - k > 0 \\) implies \\( k < 9 \\)\n", + "First, we convert the numbers from base \\( b \\) to base 10:\n", + "- \\( 17_b = 1 \\cdot b + 7 = b + 7 \\)\n", + "- \\( 97_b = 9 \\cdot b + 7 = 9b + 7 \\)\n", "\n", - "So, \\( 1 < k < 9 \\)\n", + "We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} = 9 - \\frac{56}{b + 7}\n", + "\\]\n", + "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. 
Therefore, \\( b + 7 \\) must be a divisor of 56.\n", "\n", - "Case 2: Both numerator and denominator are negative.\n", + "The positive divisors of 56 are:\n", + "\\[\n", + "1, 2, 4, 7, 8, 14, 28, 56\n", + "\\]\n", + "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). The divisors of 56 that are greater than 16 are 28 and 56. Thus, we have:\n", + "\\[\n", + "b + 7 = 28 \\quad \\text{or} \\quad b + 7 = 56\n", + "\\]\n", + "Solving for \\( b \\):\n", + "\\[\n", + "b = 28 - 7 = 21 \\quad \\text{or} \\quad b = 56 - 7 = 49\n", + "\\]\n", "\n", - "So, \\( 7(k - 1) < 0 \\) implies \\( k < 1 \\)\n", + "We verify these bases:\n", + "- For \\( b = 21 \\):\n", + " \\[\n", + " 17_{21} = 1 \\cdot 21 + 7 = 28\n", + " \\]\n", + " \\[\n", + " 97_{21} = 9 \\cdot 21 + 7 = 196\n", + " \\]\n", + " \\[\n", + " 196 \\div 28 = 7 \\quad \\text{(an integer)}\n", + " \\]\n", "\n", - "And \\( 9 - k < 0 \\) implies \\( k > 9 \\)\n", + "- For \\( b = 49 \\):\n", + " \\[\n", + " 17_{49} = 1 \\cdot 49 + 7 = 56\n", + " \\]\n", + " \\[\n", + " 97_{49} = 9 \\cdot 49 + 7 = 448\n", + " \\]\n", + " \\[\n", + " 448 \\div 56 = 8 \\quad \\text{(an integer)}\n", + " \\]\n", "\n", - "But \\( k < 1 \\) and \\( k > 9 \\) can't happen at the same time. So, only Case 1 is possible.\n", + "Both bases 21 and 49 satisfy the condition. 
The sum of these bases is:\n", + "\\[\n", + "21 + 49 = 70\n", + "\\]\n", "\n", - "Therefore, \\( k \\) must be an integer between 2 and 8 inclusive.\n", - "\n", - "So, \\( k \\in \\{2, 3, 4, 5, 6, 7, 8\\} \\)\n", - "\n", - "Now, let's compute \\( d = k - 9 \\) for each \\( k \\):\n", - "\n", - "For \\( k = 2 \\): \\( d = -7 \\)\n", - "\n", - "For \\( k = 3 \\): \\( d = -6 \\)\n", - "\n", - "For \\( k = 4 \\): \\( d = -5 \\)\n", - "\n", - "For \\( k = 5 \\): \\( d = -4 \\)\n", - "\n", - "For \\( k = 6 \\): \\( d = -3 \\)\n", - "\n", - "For \\( k = 7 \\): \\( d = -2 \\)\n", - "\n", - "For \\( k = 8 \\): \\( d = -1 \\)\n", - "\n", - "So, \\( d \\) can be -7, -6, -5, -4, -3, -2, -1.\n", - "\n", - "But earlier, we said that \\( d \\) must divide 56. Let's check if each of these \\( d \\) values divides 56.\n", - "\n", - "- \\( d = -7 \\): 56 ÷ (-7) = -8, which is integer. So, yes.\n", - "\n", - "- \\( d = -6 \\): 56 ÷ (-6) ≈ -9.333... Not integer. So, no.\n", - "\n", - "- \\( d = -5 \\): 56 ÷ (-5) = -11.2. Not integer. No.\n", - "\n", - "- \\( d = -4 \\): 56 ÷ (-4) = -14. Integer. Yes.\n", - "\n", - "- \\( d = -3 \\): 56 ÷ (-3) ≈ -18.666... Not integer. No.\n", - "\n", - "- \\( d = -2 \\): 56 ÷ (-2) = -28. Integer. Yes.\n", - "\n", - "- \\( d = -1 \\): 56 ÷ (-1) = -56. Integer. Yes.\n", - "\n", - "So, the valid \\( d \\) values are -7, -4, -2, -1.\n", - "\n", - "Therefore, for each of these \\( d \\), we can find \\( k \\):\n", - "\n", - "- \\( d = -7 \\): \\( k = d + 9 = 2 \\)\n", - "\n", - "- \\( d = -4 \\): \\( k = 5 \\)\n", - "\n", - "- \\( d = -2 \\): \\( k = 7 \\)\n", - "\n", - "- \\( d = -1 \\): \\( k = 8 \\)\n", - "\n", - "So, now, let's compute \\( b \\) for each \\( k \\):\n", - "\n", - "Recall \\( b = \\frac{7(k - 1)}{9 - k} \\)\n", - "\n", - "Let's compute for each \\( k \\):\n", - "\n", - "1. \\( k = 2 \\):\n", - "\n", - "\\( b = \\frac{7(2 - 1)}{9 - 2} = \\frac{7(1)}{7} = 1 \\)\n", - "\n", - "But \\( b > 9 \\), so 1 is invalid.\n", - "\n", - "2. 
\\( k = 5 \\):\n", - "\n", - "\\( b = \\frac{7(5 - 1)}{9 - 5} = \\frac{7(4)}{4} = 7 \\)\n", - "\n", - "Again, \\( b = 7 \\) is less than 9, so invalid.\n", - "\n", - "3. \\( k = 7 \\):\n", - "\n", - "\\( b = \\frac{7(7 - 1)}{9 - 7} = \\frac{7(6)}{\n", - "Cache size: torch.Size([1, 2, 2081, 128])\n" + "Thus, the final answer is:\n", + "\\[\n", + "\\boxed{70}\n", + "\\]\n", + "Cache size: torch.Size([1, 4, 1899, 128])\n", + "CPU times: user 45.5 s, sys: 126 ms, total: 45.6 s\n", + "Wall time: 45.6 s\n" ] } ], "source": [ - "%%time\n", "cache = DynamicCache()\n", "question = sample[\"problem\"]\n", "true_answer = sample[\"answer\"]\n", @@ -269,22 +293,12 @@ "print(f\"Question: {question}\")\n", "print(f\"Answer: {true_answer}\")\n", "print(f\"Prediction: {pred_answer}\")\n", - "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + "print(f\"Cache size: {cache.key_cache[0].shape}\")" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "ff177e00-cabf-4680-9cfc-b6cb00770527", - "metadata": {}, - "outputs": [], - "source": [ - "from kvpress.presses.decoding_press import DecodingPress" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "id": "7f4fa366-7e62-443a-b5c8-274128fe6237", "metadata": {}, "outputs": [ @@ -294,108 +308,155 @@ "text": [ "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", "Answer: 70\n", - "Prediction: Okay, so I have this problem here: I need to find the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\). Hmm, okay. Let me try to break this down step by step.\n", + "Prediction: \n", + "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n", + "\n", + "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? 
So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n", "\n", - "First, I know that when a number is written in base \\( b \\), each digit represents a power of \\( b \\). So, for example, \\( 17_b \\) would be equal to \\( 1 \\times b + 7 \\times 1 \\) in decimal, right? Similarly, \\( 97_b \\) would be \\( 9 \\times b + 7 \\times 1 \\) in decimal. So, I can convert both of these numbers to base 10 and then set up the division condition.\n", + "17_b = b + 7\n", "\n", - "Let me write that out:\n", + "97_b = 9b + 7\n", "\n", - "\\( 17_b = 1 \\times b + 7 = b + 7 \\)\n", + "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n", "\n", - "\\( 97_b = 9 \\times b + 7 = 9b + 7 \\)\n", + "Let me write that as an equation:\n", "\n", - "So, the problem is asking for bases \\( b > 9 \\) such that \\( b + 7 \\) divides \\( 9b + 7 \\). In other words, \\( b + 7 \\) is a divisor of \\( 9b + 7 \\).\n", + "(9b + 7) ÷ (b + 7) = integer.\n", "\n", - "I remember that if \\( a \\) divides \\( b \\), then \\( b = k \\times a \\) for some integer \\( k \\). So, in this case, \\( 9b + 7 = k \\times (b + 7) \\) for some integer \\( k \\).\n", + "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n", "\n", - "Let me write that equation:\n", + "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n", "\n", - "\\( 9b + 7 = k(b + 7) \\)\n", + "9*(b + 7) = 9b + 63\n", "\n", - "Now, I can expand the right-hand side:\n", + "But 9b + 7 is the original numerator. 
So, subtracting these:\n", "\n", - "\\( 9b + 7 = kb + 7k \\)\n", + "9b + 7 = 9*(b + 7) - 56\n", "\n", - "Now, let's bring all terms to one side to see if I can solve for \\( k \\):\n", + "So, 9b + 7 = 9*(b + 7) - 56\n", "\n", - "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "Therefore, when we divide 9b + 7 by (b + 7), we get:\n", "\n", - "Factor out \\( b \\) and \\( k \\):\n", + "(9*(b + 7) - 56) / (b + 7) = 9 - 56/(b + 7)\n", "\n", - "\\( b(9 - k) + 7(1 - k) = 0 \\)\n", + "So, for this to be an integer, 56/(b + 7) must be an integer. That means that (b + 7) must be a divisor of 56. \n", "\n", - "Hmm, that's a bit messy. Maybe I can rearrange the equation differently. Let me subtract \\( k(b + 7) \\) from both sides:\n", + "So, (b + 7) divides 56. Therefore, b + 7 is a positive divisor of 56. But since b > 9, then b + 7 > 16. So, we need to find all divisors of 56 that are greater than 16, then subtract 7 to find the corresponding b.\n", "\n", - "\\( 9b + 7 - k(b + 7) = 0 \\)\n", + "First, let's list all positive divisors of 56. The divisors of 56 are:\n", "\n", - "Expanding the \\( k(b + 7) \\):\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", "\n", - "\\( 9b + 7 - kb - 7k = 0 \\)\n", + "Now, from these, the divisors greater than 16 are 28 and 56. So, b + 7 can be 28 or 56. Therefore, solving for b:\n", "\n", - "Now, let's collect like terms:\n", + "If b + 7 = 28, then b = 28 - 7 = 21.\n", "\n", - "\\( (9 - k)b + (7 - 7k) = 0 \\)\n", + "If b + 7 = 56, then b = 56 - 7 = 49.\n", "\n", - "So, this is a linear equation in terms of \\( b \\). Let me write it as:\n", + "So, the possible bases are 21 and 49. Therefore, the sum of these bases is 21 + 49 = 70.\n", "\n", - "\\( (9 - k)b = 7k - 7 \\)\n", + "Wait, but let me double-check. The problem says \"integer bases b > 9\". So, 21 and 49 are both greater than 9, so they are valid. 
Let me confirm that these bases actually work.\n", "\n", - "Then, solving for \\( b \\):\n", + "First, check base 21:\n", "\n", - "\\( b = \\frac{7k - 7}{9 - k} \\)\n", + "17 in base 21 is 1*21 + 7 = 28.\n", "\n", - "Hmm, okay. So, \\( b \\) must be an integer greater than 9. So, \\( b \\) is an integer, and \\( k \\) is also an integer because \\( k \\) is the quotient when \\( 9b + 7 \\) is divided by \\( b + 7 \\).\n", + "97 in base 21 is 9*21 + 7 = 189 + 7 = 196.\n", "\n", - "So, \\( b \\) is expressed in terms of \\( k \\). Let me see if I can find integer values of \\( k \\) such that \\( b \\) is an integer greater than 9.\n", + "Now, check if 28 divides 196. 196 ÷ 28 = 7, which is an integer. So that works.\n", "\n", - "First, let's note that \\( 9 - k \\) cannot be zero because that would make the denominator zero, which is undefined. So, \\( k \\neq 9 \\).\n", + "Now check base 49:\n", "\n", - "Also, since \\( b > 9 \\), let's see what constraints that imposes on \\( k \\).\n", + "17 in base 49 is 1*49 + 7 = 56.\n", "\n", - "Looking at the expression \\( b = \\frac{7k - 7}{9 - k} \\), let's see when this is positive and greater than 9.\n", + "97 in base 49 is 9*49 + 7 = 441 + 7 = 448.\n", "\n", - "First, let's note that \\( 7k - 7 \\) and \\( 9 - k \\) must have the same sign because \\( b \\) is positive.\n", + "Check if 56 divides 448. 448 ÷ 56 = 8, which is an integer. So that works too.\n", "\n", - "So, either both numerator and denominator are positive or both are negative.\n", + "Therefore, the valid bases are 21 and 49, and their sum is 70. So the answer should be 70.\n", "\n", - "Case 1: Both numerator and denominator are positive.\n", + "Wait, but let me check if there are any other divisors of 56 that I might have missed. The divisors of 56 are 1, 2, 4, 7, 8, 14, 28, 56. So the ones greater than 16 are 28 and 56. So yes, only those two. So 21 and 49. So sum is 70. 
That seems correct.\n", "\n", - "So, \\( 7k - 7 > 0 \\) implies \\( k > 1 \\).\n", + "But just to be thorough, let me check if there's a possibility that b + 7 could be a negative divisor of 56. But since b is a base, it must be an integer greater than 9. So b + 7 is at least 17, so positive. Therefore, negative divisors don't apply here. So no other possibilities.\n", "\n", - "And \\( 9 - k > 0 \\) implies \\( k < 9 \\).\n", + "Therefore, the answer is 70.\n", + "To solve the problem, we need to find all integer bases \\( b > 9 \\) such that the number \\( 17_b \\) (which is \\( b + 7 \\) in decimal) divides \\( 97_b \\) (which is \\( 9b + 7 \\) in decimal). We then sum these valid bases.\n", "\n", - "So, \\( k \\) must satisfy \\( 1 < k < 9 \\). Since \\( k \\) is an integer, \\( k \\) can be 2, 3, 4, 5, 6, 7, or 8.\n", + "First, we express the condition mathematically:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} \\text{ must be an integer.}\n", + "\\]\n", + "This can be rewritten as:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} = 9 - \\frac{56}{b + 7}.\n", + "\\]\n", + "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer, meaning \\( b + 7 \\) must be a divisor of 56. The divisors of 56 are:\n", + "\\[\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", + "\\]\n", + "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). Therefore, the valid divisors of 56 that are greater than 16 are 28 and 56. This gives us:\n", + "\\[\n", + "b + 7 = 28 \\implies b = 21,\n", + "\\]\n", + "\\[\n", + "b + 7 = 56 \\implies b = 49.\n", + "\\]\n", + "We now verify that these bases satisfy the original condition:\n", + "- For \\( b = 21 \\):\n", + " \\[\n", + " 17_{21} = 21 + 7 = 28, \\quad 97_{21} = 9 \\cdot 21 + 7 = 196.\n", + " \\]\n", + " We check if 196 is divisible by 28:\n", + " \\[\n", + " 196 \\div 28 = 7,\n", + " \\]\n", + " which is an integer. 
So, \\( b = 21 \\) is valid.\n", + "- For \\( b = 49 \\):\n", + " \\[\n", + " 17_{49} = 49 + 7 = 56, \\quad 97_{49} = 9 \\cdot 49 + 7 = 448.\n", + " \\]\n", + " We check if 448 is divisible by 56:\n", + " \\[\n", + " 448 \\div 56 = 8,\n", + " \\]\n", + " which is an integer. So, \\( b = 49 \\) is also valid.\n", "\n", - "Case 2: Both numerator and\n", - "Cache size: torch.Size([1, 2, 500, 128])\n", - "CPU times: user 36.7 s, sys: 26.8 ms, total: 36.8 s\n", - "Wall time: 36.8 s\n" + "The valid bases are 21 and 49. Summing these bases, we get:\n", + "\\[\n", + "21 + 49 = 70.\n", + "\\]\n", + "Thus, the sum of the valid bases is:\n", + "\\[\n", + "\\boxed{70}.\n", + "\\]\n", + "Cache size: torch.Size([1, 4, 515, 128])\n", + "CPU times: user 50.5 s, sys: 25.2 ms, total: 50.5 s\n", + "Wall time: 50.5 s\n" ] } ], "source": [ - "%%time\n", - "compression_interval = 500 # compress every compression_steps\n", - "target_size = 500 # number of tokens to keep after compression\n", + "compression_steps = 48 # compress every compression_steps\n", + "token_buffer_size = 512 # number of tokens to keep after compression\n", "\n", "\n", - "press = DecodingPress(base_press=KnormPress(), compression_interval=compression_interval, target_size=target_size, hidden_states_buffer_size=0)\n", + "press = DecodingPress(base_press=ExpectedAttentionPress(), compression_steps=compression_steps, token_buffer_size=token_buffer_size)\n", "\n", "cache = DynamicCache()\n", "question = sample[\"problem\"]\n", "true_answer = sample[\"answer\"]\n", - "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=1000)[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=2048)[\"answer\"]\n", "\n", "print(f\"Question: {question}\")\n", "print(f\"Answer: {true_answer}\")\n", "print(f\"Prediction: {pred_answer}\")\n", - "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + "print(f\"Cache size: {cache.key_cache[0].shape}\")" ] } 
], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -409,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.14" } }, "nbformat": 4, From cce1196ac50a9d6cb5801722c87c6fcbeac32448 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Wed, 8 Oct 2025 16:16:58 +0200 Subject: [PATCH 07/20] update notebook Signed-off-by: Max Jeblick --- kvpress/presses/decoding_press.py | 8 - notebooks/kvpress_decoding_aime25.ipynb | 435 +++++++++++++++++------- 2 files changed, 314 insertions(+), 129 deletions(-) diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py index 9e5b1a6c..8b96fc18 100644 --- a/kvpress/presses/decoding_press.py +++ b/kvpress/presses/decoding_press.py @@ -50,14 +50,6 @@ def __post_init__(self): self.hidden_states_buffer = defaultdict(list) # Per-layer buffer self.layer_step_counts = defaultdict(int) # Track step count per layer - # Warn if compression happens before buffer is fully utilized - # TODO: would it make sense to not reset the buffer? - if self.hidden_states_buffer_size > 0 and self.compression_interval < self.hidden_states_buffer_size: - logger.warning( - f"compression_interval ({self.compression_interval}) < hidden_states_buffer_size ({self.hidden_states_buffer_size}). " # noqa: E501 - f"Buffer will be reset before reaching full capacity, potentially reducing compression quality." 
- ) - assert self.compression_interval > 0, "compression_interval must be greater than 0" assert self.target_size > 0, "target_size must be greater than 0" diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb index b84291bb..55ccd8bd 100644 --- a/notebooks/kvpress_decoding_aime25.ipynb +++ b/notebooks/kvpress_decoding_aime25.ipynb @@ -1,14 +1,9 @@ { "cells": [ { - "metadata": {}, - "cell_type": "raw", - "source": "", - "id": "6458dd86772cbe6a" - }, - { - "metadata": {}, "cell_type": "markdown", + "id": "d5b82a94ac3477fd", + "metadata": {}, "source": [ "# Using decoding compression on the AIME25 Math Dataset\n", "\n", @@ -17,14 +12,14 @@ "\n", "To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.\n", "Compression periodically reduces the cache size by keeping only the most relevant tokens, enabling efficient inference without sacrificing answer quality." - ], - "id": "d5b82a94ac3477fd" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": 1, + "id": "4cd1c9e43c5ca1bf", + "metadata": {}, + "outputs": [], "source": [ "import requests\n", "\n", @@ -42,8 +37,7 @@ ")\n", "\n", "from kvpress.presses.decoding_press import DecodingPress" - ], - "id": "4cd1c9e43c5ca1bf" + ] }, { "cell_type": "code", @@ -58,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "5e912ac7-e5d8-42ff-8f9e-855b8ae999a8", "metadata": {}, "outputs": [], @@ -68,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "f8c2689d-e461-4608-aa69-a4811f71d639", "metadata": {}, "outputs": [ @@ -80,7 +74,7 @@ " 'id': '0'}" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -91,14 +85,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "02b0af83-e71a-4129-9692-921d2bc16cb8", "metadata": {}, "outputs": [ { 
"data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "96a774394acf47948fdc5cd642df43a7", + "model_id": "bd9b941a02394e7d936e018a25c4628b", "version_major": 2, "version_minor": 0 }, @@ -121,12 +115,12 @@ "device = \"cuda:0\"\n", "ckpt = \"nvidia/OpenMath-Nemotron-7B\"\n", "attn_implementation = \"flash_attention_2\"\n", - "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, torch_dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})" + "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "ccb1ff5e-5e16-4779-bbae-5781e8255345", "metadata": {}, "outputs": [], @@ -136,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "96e1f430-4ece-49dd-a5fb-68ae3c2ee8b8", "metadata": {}, "outputs": [ @@ -167,57 +161,197 @@ "\n", "9*(b + 7) = 9b + 63\n", "\n", - "But 9b + 7 is the original numerator. So, subtracting these:\n", + "But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n", "\n", - "9b + 7 = 9*(b + 7) - 56\n", + "9b +7 - (9b +63) = 7 - 63 = -56\n", "\n", - "So, 9b + 7 = 9*(b + 7) - 56\n", + "Therefore, 9b +7 = 9*(b +7) -56\n", "\n", - "Therefore, when we divide 9b + 7 by (b + 7), we get:\n", + "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n", "\n", - "(9*(b + 7) - 56) / (b + 7) = 9 - 56/(b + 7)\n", + "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n", "\n", - "So, for this to be an integer, 56/(b + 7) must be an integer. That means that (b + 7) must be a divisor of 56. \n", + "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. 
Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n", "\n", - "So, (b + 7) divides 56. Therefore, (b + 7) is a positive divisor of 56. But since b > 9, then b + 7 > 16. So, we need to find all divisors of 56 that are greater than 16, and then subtract 7 to find the corresponding b.\n", + "1. b > 9\n", "\n", - "First, let's list all positive divisors of 56. The divisors of 56 are:\n", + "2. (b +7) divides 56.\n", + "\n", + "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n", "\n", "1, 2, 4, 7, 8, 14, 28, 56.\n", "\n", - "Now, from these, the divisors greater than 16 are 28 and 56. So, (b + 7) can be 28 or 56. Therefore, solving for b:\n", + "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n", + "\n", + "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n", + "\n", + "If b +7 =28, then b=21\n", + "\n", + "If b +7=56, then b=49\n", + "\n", + "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n", + "\n", + "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n", + "\n", + "But wait, let me verify this. Let's check for b=21:\n", + "\n", + "17 in base 21 is 1*21 +7=28\n", + "\n", + "97 in base21 is 9*21 +7=196\n", + "\n", + "Check if 28 divides 196. 196 divided by 28 is 7. Yes, that's an integer. 
So that works.\n", + "\n", + "For b=49:\n", + "\n", + "17 in base49 is 1*49 +7=56\n", + "\n", + "97 in base49 is 9*49 +7=448\n", + "\n", + "448 divided by 56 is 8. That's also an integer. So that works too.\n", + "\n", + "So, both bases are valid. Therefore, the sum is 21 +49=70.\n", + "\n", + "Wait, but hold on. Let me check if there are any negative divisors. But since b is a base, it must be a positive integer greater than 9, so b +7 is positive. Therefore, we don't need to consider negative divisors. So, yes, only 28 and 56. Therefore, the answer is 70.\n", + "\n", + "But let me just make sure I didn't miss any divisors. Let's list all divisors again:\n", + "\n", + "Positive divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "Negative divisors would be -1, -2, -4, -7, -8, -14, -28, -56. But since b +7 must be positive (as b>9), we can ignore the negative ones. So, only 28 and 56. Therefore, the answer is 70.\n", + "\n", + "But wait, another thought: when we say that (9b +7) is divisible by (b +7), we can also approach this by setting up the equation (9b +7) = k*(b +7), where k is an integer. Then, solving for b:\n", + "\n", + "9b +7 = k*b +7k\n", + "\n", + "Bring terms with b to one side:\n", + "\n", + "9b -k*b = 7k -7\n", + "\n", + "b*(9 -k) = 7(k -1)\n", + "\n", + "Therefore, b = [7(k -1)] / (9 -k)\n", + "\n", + "Since b must be an integer greater than 9, we can look for integer values of k such that 9 -k divides 7(k -1), and the result is b>9.\n", + "\n", + "Let me see. Let's solve for k:\n", + "\n", + "b = 7(k -1)/(9 -k)\n", + "\n", + "We need 9 -k to divide 7(k -1). Let's denote d = 9 -k, so k =9 -d. Then:\n", + "\n", + "b =7( (9 -d) -1 ) / d =7(8 -d)/d =7*(8/d -1)\n", + "\n", + "Since b must be a positive integer greater than 9, d must be a positive divisor of 56 (since 7*(8/d -1) must be integer). Wait, maybe this approach complicates things. 
Let's try substituting possible integer values of k.\n", + "\n", + "Alternatively, since k must be an integer such that 9 -k divides 7(k -1). Let's note that 9 -k and k -1 are related. Let's set m =9 -k, so k =9 -m. Then:\n", + "\n", + "b =7*( (9 -m) -1 ) / m =7*(8 -m)/m =7*(8/m -1)\n", + "\n", + "So, 8/m must be an integer because b must be integer. Therefore, m must be a positive divisor of 8. But m =9 -k, and since k is an integer, m can be positive or negative. However, since b must be positive, let's see:\n", + "\n", + "If m is positive, then 8/m -1 must be positive as well because 7*(8/m -1) =b>9. So:\n", + "\n", + "8/m -1 >0 =>8/m >1 =>m <8. But m is a positive divisor of 8. The positive divisors of 8 are 1,2,4,8. So m can be 1,2,4,8. But m must be less than 8 (from 8/m >1). So m=1,2,4.\n", + "\n", + "Wait, if m=8, then 8/m -1=1 -1=0, so b=0, which is invalid. So m=1,2,4.\n", + "\n", + "Let's check each:\n", + "\n", + "Case 1: m=1\n", + "\n", + "Then, b=7*(8/1 -1)=7*(7)=49. So b=49. Which is valid, as before.\n", + "\n", + "Case 2: m=2\n", + "\n", + "b=7*(8/2 -1)=7*(4 -1)=7*3=21. So b=21. Also valid.\n", + "\n", + "Case3: m=4\n", + "\n", + "b=7*(8/4 -1)=7*(2 -1)=7*1=7. But b=7 is not greater than 9, so invalid.\n", "\n", - "If b + 7 = 28, then b = 28 - 7 = 21.\n", + "So only m=1 and m=2 give valid bases, which are 49 and 21. So sum is 70. So same result as before.\n", "\n", - "If b + 7 = 56, then b = 56 - 7 = 49.\n", + "Alternatively, if m is negative, then m divides 8, but m negative. Let's see:\n", "\n", - "So, the possible bases are 21 and 49. Therefore, the sum of all such bases is 21 + 49 = 70.\n", + "If m is a negative divisor of 8, then m can be -1,-2,-4,-8.\n", "\n", - "Wait, but let me double-check. Let's verify for b = 21 and b = 49 whether 17_b divides 97_b.\n", + "Then, let's check:\n", "\n", - "First, for b = 21:\n", + "Case m=-1:\n", "\n", - "17_21 = 1*21 + 7 = 28\n", + "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. 
Negative base is invalid.\n", "\n", - "97_21 = 9*21 + 7 = 189 + 7 = 196\n", + "Similarly, m=-2:\n", "\n", - "Now, 196 divided by 28 is 7, which is an integer. So that works.\n", + "b=7*(8/(-2) -1)=7*(-4 -1)=7*(-5)=-35. Invalid.\n", "\n", - "For b = 49:\n", + "m=-4:\n", "\n", - "17_49 = 1*49 + 7 = 56\n", + "b=7*(8/(-4) -1)=7*(-2 -1)=7*(-3)=-21. Invalid.\n", "\n", - "97_49 = 9*49 + 7 = 441 + 7 = 448\n", + "m=-8:\n", "\n", - "448 divided by 56 is 8, which is also an integer. So that works too.\n", + "b=7*(8/(-8) -1)=7*(-1 -1)=7*(-2)=-14. Invalid.\n", "\n", - "Therefore, the bases are indeed 21 and 49, and their sum is 70. So the answer should be 70.\n", + "So all negative m give invalid bases. Therefore, only m=1 and m=2, leading to b=49 and 21. So sum is 70.\n", "\n", - "But wait, let me check if there are any other divisors of 56 that I might have missed. The divisors of 56 are 1, 2, 4, 7, 8, 14, 28, 56. So, the ones greater than 16 are 28 and 56. So, that's correct. So, no other divisors. Therefore, the sum is 21 + 49 = 70.\n", + "Therefore, the answer is 70. So I think that's correct. Let me just check once more.\n", "\n", - "But just to make sure, let's check if there's a possibility of negative divisors. But since the base b must be greater than 9, and b + 7 must be a positive divisor (since base is a positive integer greater than the digits used, which in 17_b, the digits are 1 and 7, so base must be at least 8. But the problem states b > 9, so base is at least 10. So, b + 7 is at least 17. So, negative divisors are irrelevant here. So, we don't need to consider them.\n", + "Alternatively, maybe there's another approach. Let's consider that (9b +7)/(b +7) must be integer. Let's call this integer k. So:\n", "\n", - "Therefore, the answer is 70.\n", + "(9b +7) =k*(b +7)\n", + "\n", + "So, 9b +7 =k*b +7k\n", + "\n", + "Rearranged:\n", + "\n", + "(9 -k)*b =7k -7\n", + "\n", + "So, b=(7k -7)/(9 -k)\n", + "\n", + "We need b to be an integer greater than 9. 
So, (7k -7) must be divisible by (9 -k). Let's factor 7:\n", + "\n", + "7(k -1)/(9 -k) = -7(k -1)/(k -9)\n", + "\n", + "So, b= -7(k -1)/(k -9)\n", + "\n", + "We can write this as 7(1 -k)/(k -9). Hmm, not sure if that helps. Let's look for integer k such that (9 -k) divides 7(k -1). Let's denote d=9 -k, so k=9 -d. Then:\n", + "\n", + "b=7(k -1)/d=7(8 -d)/d=7*(8/d -1)\n", + "\n", + "So, 8/d must be integer, so d must divide 8. So d is a divisor of 8. Since d=9 -k, and k is an integer, d can be any integer (positive or negative) that divides 8. But since b must be positive and greater than 9, let's see:\n", + "\n", + "If d is positive, then 8/d must be integer, so d is a positive divisor of 8: 1,2,4,8.\n", + "\n", + "Then, for each d:\n", + "\n", + "d=1:\n", + "\n", + "b=7*(8/1 -1)=7*7=49. Valid.\n", + "\n", + "d=2:\n", + "\n", + "b=7*(8/2 -1)=7*(4 -1)=21. Valid.\n", + "\n", + "d=4:\n", + "\n", + "b=7*(8/4 -1)=7*(2 -1)=7. Not valid (b>9 required).\n", + "\n", + "d=8:\n", + "\n", + "b=7*(8/8 -1)=7*(1 -1)=0. Invalid.\n", + "\n", + "If d is negative, then d divides 8, so d=-1,-2,-4,-8.\n", + "\n", + "For d=-1:\n", + "\n", + "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. Invalid.\n", + "\n", + "Similarly, other negative d's give negative b's. So only d=1 and d=2 give valid bases. So again, 49 and 21. Sum is 70.\n", + "\n", + "Therefore, I think the answer is 70. So I can be confident that the sum is 70.\n", "\n", "**Final Answer**\n", "\\boxed{70}\n", @@ -227,29 +361,50 @@ "- \\( 17_b = 1 \\cdot b + 7 = b + 7 \\)\n", "- \\( 97_b = 9 \\cdot b + 7 = 9b + 7 \\)\n", "\n", - "We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n", + "We need \\( b + 7 \\) to divide \\( 9b + 7 \\). This can be expressed as:\n", "\\[\n", - "\\frac{9b + 7}{b + 7} = 9 - \\frac{56}{b + 7}\n", + "\\frac{9b + 7}{b + 7} = k \\quad \\text{for some integer } k\n", "\\]\n", - "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. 
Therefore, \\( b + 7 \\) must be a divisor of 56.\n", "\n", - "The positive divisors of 56 are:\n", + "Rewriting the equation, we get:\n", + "\\[\n", + "9b + 7 = k(b + 7)\n", + "\\]\n", + "\\[\n", + "9b + 7 = kb + 7k\n", + "\\]\n", "\\[\n", - "1, 2, 4, 7, 8, 14, 28, 56\n", + "9b - kb = 7k - 7\n", "\\]\n", - "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). The divisors of 56 that are greater than 16 are 28 and 56. Thus, we have:\n", "\\[\n", - "b + 7 = 28 \\quad \\text{or} \\quad b + 7 = 56\n", + "b(9 - k) = 7(k - 1)\n", "\\]\n", - "Solving for \\( b \\):\n", "\\[\n", - "b = 28 - 7 = 21 \\quad \\text{or} \\quad b = 56 - 7 = 49\n", + "b = \\frac{7(k - 1)}{9 - k}\n", "\\]\n", "\n", - "We verify these bases:\n", + "For \\( b \\) to be an integer, \\( 9 - k \\) must be a divisor of \\( 7(k - 1) \\). We need to find the values of \\( k \\) such that \\( b > 9 \\).\n", + "\n", + "Let's consider the possible values of \\( k \\) by checking the divisors of 56 (since \\( 9b + 7 = 9(b + 7) - 56 \\)):\n", + "\n", + "The positive divisors of 56 are: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "Since \\( b > 9 \\), \\( b + 7 \\) must be greater than 16. Therefore, the valid divisors are 28 and 56.\n", + "\n", + "1. If \\( b + 7 = 28 \\):\n", + " \\[\n", + " b = 28 - 7 = 21\n", + " \\]\n", + "\n", + "2. If \\( b + 7 = 56 \\):\n", + " \\[\n", + " b = 56 - 7 = 49\n", + " \\]\n", + "\n", + "We verify these values:\n", "- For \\( b = 21 \\):\n", " \\[\n", - " 17_{21} = 1 \\cdot 21 + 7 = 28\n", + " 17_{21} = 21 + 7 = 28\n", " \\]\n", " \\[\n", " 97_{21} = 9 \\cdot 21 + 7 = 196\n", @@ -260,7 +415,7 @@ "\n", "- For \\( b = 49 \\):\n", " \\[\n", - " 17_{49} = 1 \\cdot 49 + 7 = 56\n", + " 17_{49} = 49 + 7 = 56\n", " \\]\n", " \\[\n", " 97_{49} = 9 \\cdot 49 + 7 = 448\n", @@ -269,18 +424,16 @@ " 448 \\div 56 = 8 \\quad \\text{(an integer)}\n", " \\]\n", "\n", - "Both bases 21 and 49 satisfy the condition. The sum of these bases is:\n", + "Both values of \\( b \\) are valid. 
The sum of these bases is:\n", "\\[\n", "21 + 49 = 70\n", "\\]\n", "\n", - "Thus, the final answer is:\n", + "Thus, the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\) is:\n", "\\[\n", "\\boxed{70}\n", "\\]\n", - "Cache size: torch.Size([1, 4, 1899, 128])\n", - "CPU times: user 45.5 s, sys: 126 ms, total: 45.6 s\n", - "Wall time: 45.6 s\n" + "Cache size: torch.Size([1, 4, 3781, 128])\n" ] } ], @@ -288,17 +441,17 @@ "cache = DynamicCache()\n", "question = sample[\"problem\"]\n", "true_answer = sample[\"answer\"]\n", - "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=2048)[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=16_000)[\"answer\"]\n", "\n", "print(f\"Question: {question}\")\n", "print(f\"Answer: {true_answer}\")\n", "print(f\"Prediction: {pred_answer}\")\n", - "print(f\"Cache size: {cache.key_cache[0].shape}\")" + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "7f4fa366-7e62-443a-b5c8-274128fe6237", "metadata": {}, "outputs": [ @@ -329,136 +482,176 @@ "\n", "9*(b + 7) = 9b + 63\n", "\n", - "But 9b + 7 is the original numerator. So, subtracting these:\n", + "But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n", + "\n", + "9b +7 - (9b +63) = 7 - 63 = -56\n", "\n", - "9b + 7 = 9*(b + 7) - 56\n", + "Therefore, 9b +7 = 9*(b +7) -56\n", "\n", - "So, 9b + 7 = 9*(b + 7) - 56\n", + "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n", "\n", - "Therefore, when we divide 9b + 7 by (b + 7), we get:\n", + "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n", "\n", - "(9*(b + 7) - 56) / (b + 7) = 9 - 56/(b + 7)\n", + "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. 
So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n", "\n", - "So, for this to be an integer, 56/(b + 7) must be an integer. That means that (b + 7) must be a divisor of 56. \n", + "1. b > 9\n", "\n", - "So, (b + 7) divides 56. Therefore, b + 7 is a positive divisor of 56. But since b > 9, then b + 7 > 16. So, we need to find all divisors of 56 that are greater than 16, then subtract 7 to find the corresponding b.\n", + "2. (b +7) divides 56.\n", "\n", - "First, let's list all positive divisors of 56. The divisors of 56 are:\n", + "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n", "\n", "1, 2, 4, 7, 8, 14, 28, 56.\n", "\n", - "Now, from these, the divisors greater than 16 are 28 and 56. So, b + 7 can be 28 or 56. Therefore, solving for b:\n", + "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n", + "\n", + "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n", + "\n", + "If b +7 =28, then b=21\n", + "\n", + "If b +7=56, then b=49\n", + "\n", + "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n", + "\n", + "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n", + "\n", + "But wait, let me verify this. 
Let's check for b=21:\n", + "\n", + "17 in base 21 is 21 +7=28\n", + "\n", + "97 in base21 is 9*21 +7=189 +7=196\n", + "\n", + "196 divided by28 is 7, which is an integer. So that works.\n", + "\n", + "For b=49:\n", + "\n", + "17 in base49 is 49 +7=56\n", "\n", - "If b + 7 = 28, then b = 28 - 7 = 21.\n", + "97 in base49 is 9*49 +7=441 +7=448\n", "\n", - "If b + 7 = 56, then b = 56 - 7 = 49.\n", + "448 divided by56 is 8, which is an integer. So that works too.\n", "\n", - "So, the possible bases are 21 and 49. Therefore, the sum of these bases is 21 + 49 = 70.\n", + "So both bases are valid. Therefore, the sum is 21 +49=70.\n", "\n", - "Wait, but let me double-check. The problem says \"integer bases b > 9\". So, 21 and 49 are both greater than 9, so they are valid. Let me confirm that these bases actually work.\n", + "Wait, but let me check if there are any other divisors. Wait, 56 is a divisor of 56, but 56 is 56. So, if b +7=56, then b=49. Which we already have. Similarly, 28 gives b=21. Are there any negative divisors? But since b is a base, it must be a positive integer greater than 9, so negative divisors don't make sense here. So, yes, only 28 and 56. Therefore, the answer is 70.\n", "\n", - "First, check base 21:\n", + "But wait, let me check if there's a mistake here. Let me think again. The problem says \"all bases b>9\". So, the possible divisors of 56 are 1,2,4,7,8,14,28,56. But since b +7 must be one of these, and b>9, so b +7 must be at least 17. So, the possible divisors are 28 and 56. Therefore, b=21 and 49. So sum is 70. That seems correct.\n", "\n", - "17 in base 21 is 1*21 + 7 = 28.\n", + "But let me check if there's another way. Let me think. Suppose we have (9b +7)/(b +7) must be integer. Let me compute this fraction:\n", "\n", - "97 in base 21 is 9*21 + 7 = 189 + 7 = 196.\n", + "(9b +7)/(b +7) = 9 - 56/(b +7). Wait, how?\n", "\n", - "Now, check if 28 divides 196. 196 ÷ 28 = 7, which is an integer. 
So that works.\n", + "Let me do the division:\n", "\n", - "Now check base 49:\n", + "Divide 9b +7 by b +7.\n", "\n", - "17 in base 49 is 1*49 + 7 = 56.\n", + "Dividing 9b by b gives 9. Multiply (b +7) by 9: 9b +63. Subtract that from 9b +7: (9b +7) - (9b +63) = -56. So, the division gives 9 with a remainder of -56. Therefore, (9b +7) = 9*(b +7) -56. Therefore, (9b +7)/(b +7) = 9 -56/(b +7). For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. Which is the same conclusion as before.\n", "\n", - "97 in base 49 is 9*49 + 7 = 441 + 7 = 448.\n", + "Therefore, the possible values of (b +7) are the positive divisors of 56 that are greater than or equal to 17 (since b>9 implies b +7>16). The divisors of 56 are 1,2,4,7,8,14,28,56. So, the ones greater than 16 are 28 and 56. Therefore, b +7=28 gives b=21, and b +7=56 gives b=49. So, sum is 21 +49=70.\n", "\n", - "Check if 56 divides 448. 448 ÷ 56 = 8, which is an integer. So that works too.\n", + "Therefore, the answer is 70. Let me check once more with another example. Suppose b=21:\n", "\n", - "Therefore, the valid bases are 21 and 49, and their sum is 70. So the answer should be 70.\n", + "17 in base21 is 2*10 +7=27? Wait, no. Wait, in base b, 17_b is 1*b +7. So, 1*21 +7=28. 97_b is 9*21 +7=189 +7=196. 196 divided by28 is 7. Correct.\n", "\n", - "Wait, but let me check if there are any other divisors of 56 that I might have missed. The divisors of 56 are 1, 2, 4, 7, 8, 14, 28, 56. So the ones greater than 16 are 28 and 56. So yes, only those two. So 21 and 49. So sum is 70. That seems correct.\n", + "For b=49:\n", "\n", - "But just to be thorough, let me check if there's a possibility that b + 7 could be a negative divisor of 56. But since b is a base, it must be an integer greater than 9. So b + 7 is at least 17, so positive. Therefore, negative divisors don't apply here. So no other possibilities.\n", + "17_b is 1*49 +7=56. 97_b is 9*49 +7=441 +7=448. 448 divided by56 is 8. 
Correct.\n", "\n", - "Therefore, the answer is 70.\n", - "To solve the problem, we need to find all integer bases \\( b > 9 \\) such that the number \\( 17_b \\) (which is \\( b + 7 \\) in decimal) divides \\( 97_b \\) (which is \\( 9b + 7 \\) in decimal). We then sum these valid bases.\n", + "So, both bases work. Therefore, the sum is 21 +49=70. So, the answer is 70.\n", "\n", - "First, we express the condition mathematically:\n", + "**Final Answer**\n", + "\\boxed{70}\n", + "To solve the problem, we need to find all bases \\( b > 9 \\) such that \\( 17_b \\) divides \\( 97_b \\).\n", + "\n", + "First, we convert the numbers from base \\( b \\) to base 10:\n", + "- \\( 17_b \\) in base 10 is \\( 1 \\cdot b + 7 = b + 7 \\).\n", + "- \\( 97_b \\) in base 10 is \\( 9 \\cdot b + 7 = 9b + 7 \\).\n", + "\n", + "We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n", "\\[\n", "\\frac{9b + 7}{b + 7} \\text{ must be an integer.}\n", "\\]\n", - "This can be rewritten as:\n", + "\n", + "Rewriting the fraction:\n", "\\[\n", - "\\frac{9b + 7}{b + 7} = 9 - \\frac{56}{b + 7}.\n", + "\\frac{9b + 7}{b + 7} = \\frac{9(b + 7) - 56}{b + 7} = 9 - \\frac{56}{b + 7}.\n", "\\]\n", - "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer, meaning \\( b + 7 \\) must be a divisor of 56. The divisors of 56 are:\n", + "\n", + "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. Therefore, \\( b + 7 \\) must be a divisor of 56. The divisors of 56 are:\n", "\\[\n", "1, 2, 4, 7, 8, 14, 28, 56.\n", "\\]\n", - "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). Therefore, the valid divisors of 56 that are greater than 16 are 28 and 56. This gives us:\n", + "\n", + "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). 
The valid divisors of 56 that are greater than 16 are:\n", + "\\[\n", + "28 \\quad \\text{and} \\quad 56.\n", + "\\]\n", + "\n", + "Thus, we have:\n", "\\[\n", - "b + 7 = 28 \\implies b = 21,\n", + "b + 7 = 28 \\quad \\Rightarrow \\quad b = 21,\n", "\\]\n", "\\[\n", - "b + 7 = 56 \\implies b = 49.\n", + "b + 7 = 56 \\quad \\Rightarrow \\quad b = 49.\n", "\\]\n", - "We now verify that these bases satisfy the original condition:\n", + "\n", + "We verify these bases:\n", "- For \\( b = 21 \\):\n", " \\[\n", - " 17_{21} = 21 + 7 = 28, \\quad 97_{21} = 9 \\cdot 21 + 7 = 196.\n", + " 17_{21} = 21 + 7 = 28,\n", " \\]\n", - " We check if 196 is divisible by 28:\n", " \\[\n", - " 196 \\div 28 = 7,\n", + " 97_{21} = 9 \\cdot 21 + 7 = 196.\n", " \\]\n", - " which is an integer. So, \\( b = 21 \\) is valid.\n", + " Since \\( 196 \\div 28 = 7 \\), \\( 17_{21} \\) divides \\( 97_{21} \\).\n", + "\n", "- For \\( b = 49 \\):\n", " \\[\n", - " 17_{49} = 49 + 7 = 56, \\quad 97_{49} = 9 \\cdot 49 + 7 = 448.\n", + " 17_{49} = 49 + 7 = 56,\n", " \\]\n", - " We check if 448 is divisible by 56:\n", " \\[\n", - " 448 \\div 56 = 8,\n", + " 97_{49} = 9 \\cdot 49 + 7 = 448.\n", " \\]\n", - " which is an integer. So, \\( b = 49 \\) is also valid.\n", + " Since \\( 448 \\div 56 = 8 \\), \\( 17_{49} \\) divides \\( 97_{49} \\).\n", "\n", - "The valid bases are 21 and 49. Summing these bases, we get:\n", + "The valid bases are \\( b = 21 \\) and \\( b = 49 \\). 
Summing these bases:\n", "\\[\n", "21 + 49 = 70.\n", "\\]\n", - "Thus, the sum of the valid bases is:\n", + "\n", + "Thus, the sum of all bases \\( b > 9 \\) where \\( 17_b \\) divides \\( 97_b \\) is:\n", "\\[\n", "\\boxed{70}.\n", "\\]\n", - "Cache size: torch.Size([1, 4, 515, 128])\n", - "CPU times: user 50.5 s, sys: 25.2 ms, total: 50.5 s\n", - "Wall time: 50.5 s\n" + "Cache size: torch.Size([1, 4, 1182, 128])\n" ] } ], "source": [ - "compression_steps = 48 # compress every compression_steps\n", - "token_buffer_size = 512 # number of tokens to keep after compression\n", + "compression_interval = 1024 # compress every compression_steps\n", + "target_size = 512 # number of tokens to keep after compression. Note that actual cache size lies in [target_size, compression_interval]\n", "\n", "\n", - "press = DecodingPress(base_press=ExpectedAttentionPress(), compression_steps=compression_steps, token_buffer_size=token_buffer_size)\n", + "press = DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=compression_interval, target_size=target_size)\n", "\n", "cache = DynamicCache()\n", "question = sample[\"problem\"]\n", "true_answer = sample[\"answer\"]\n", - "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=2048)[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=16_000)[\"answer\"]\n", "\n", "print(f\"Question: {question}\")\n", "print(f\"Answer: {true_answer}\")\n", "print(f\"Prediction: {pred_answer}\")\n", - "print(f\"Cache size: {cache.key_cache[0].shape}\")" + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "kvpress", "language": "python", - "name": "python3" + "name": "kvpress" }, "language_info": { "codemirror_mode": { @@ -470,7 +663,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": 
"3.12.3" } }, "nbformat": 4, From d27e2a3b42a363a4ad0b6ae2ae1a007040d2d9dc Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:22:14 +0200 Subject: [PATCH 08/20] fix 0 buffer size Signed-off-by: Max Jeblick --- kvpress/presses/decoding_press.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py index 8b96fc18..8271dfe5 100644 --- a/kvpress/presses/decoding_press.py +++ b/kvpress/presses/decoding_press.py @@ -163,7 +163,11 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic # hidden states buffer and kv cache self.hidden_states_buffer[layer_idx] = [] - self.hidden_states_buffer[layer_idx] = self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + self.hidden_states_buffer[layer_idx] = ( + self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + if self.hidden_states_buffer_size > 0 + else [] + ) return output def reset(self): From a01946f0553e5543040deae6cf14c4df5e17124c Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:34:37 +0200 Subject: [PATCH 09/20] be explcit about k_len Signed-off-by: Max Jeblick --- kvpress/presses/adakv_press.py | 10 +++++----- kvpress/presses/block_press.py | 12 ++++++------ kvpress/presses/criticalkv_press.py | 22 +++++++++++----------- kvpress/presses/decoding_press.py | 6 +++--- kvpress/presses/duo_attention_press.py | 6 +++--- kvpress/presses/finch_press.py | 12 ++++++------ kvpress/presses/pyramidkv_press.py | 4 ++-- kvpress/presses/scorer_press.py | 4 ++-- kvpress/presses/simlayerkv_press.py | 8 ++++---- kvpress/presses/snapkv_press.py | 9 ++++----- kvpress/presses/streaming_llm_press.py | 6 +++--- kvpress/presses/think_press.py | 4 ++-- 12 files changed, 51 insertions(+), 52 deletions(-) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 22393678..815c777a 100644 --- a/kvpress/presses/adakv_press.py +++ 
b/kvpress/presses/adakv_press.py @@ -55,21 +55,21 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute scores scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - bsz, num_key_value_heads, q_len = scores.shape + bsz, num_key_value_heads, k_len = scores.shape # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head - n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition n_safe = int(n_kept * self.alpha_safeguard) top_indices = torch.topk(scores, n_safe, dim=-1).indices scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) # Compute bottom-k across heads - n_pruned = num_key_value_heads * (q_len - n_kept) + n_pruned = num_key_value_heads * (k_len - n_kept) indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) - head_indices = indices // q_len - seq_indices = indices % q_len + head_indices = indices // k_len + seq_indices = indices % k_len module.masked_key_indices = (batch_indices, head_indices, seq_indices) return keys, values diff --git a/kvpress/presses/block_press.py b/kvpress/presses/block_press.py index 6d8da788..b4edafdb 100644 --- a/kvpress/presses/block_press.py +++ b/kvpress/presses/block_press.py @@ -58,18 +58,18 @@ def compress( assert attentions is None, "BlockPress does not support attentions." 
- bsz, num_key_value_heads, q_len, head_dim = keys.shape + bsz, num_key_value_heads, k_len, head_dim = keys.shape - block_size = self.block_size if self.block_size < q_len else q_len - n_kept = int(q_len * (1 - self.compression_ratio)) + block_size = self.block_size if self.block_size < k_len else k_len + n_kept = int(k_len * (1 - self.compression_ratio)) kept_indices = torch.arange(n_kept, device=keys.device).expand(bsz, num_key_value_heads, -1) # Reshape hidden states to match the kept_indices - states = hidden_states.view(bsz, q_len, num_key_value_heads, -1).transpose(1, 2) + states = hidden_states.view(bsz, k_len, num_key_value_heads, -1).transpose(1, 2) - for i in range(n_kept, q_len, block_size): - end = min(i + block_size, q_len) + for i in range(n_kept, k_len, block_size): + end = min(i + block_size, k_len) current_indices = torch.arange(i, end, device=keys.device).expand(bsz, num_key_value_heads, -1) current_indices = torch.cat([kept_indices, current_indices], dim=-1) diff --git a/kvpress/presses/criticalkv_press.py b/kvpress/presses/criticalkv_press.py index 5513bc67..abcb8b83 100644 --- a/kvpress/presses/criticalkv_press.py +++ b/kvpress/presses/criticalkv_press.py @@ -53,7 +53,7 @@ def compression_ratio(self, value): @staticmethod def vwl1norm(values, module): - bsz, num_key_value_heads, q_len, _ = values.shape + bsz, num_key_value_heads, k_len, _ = values.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads Wo = module.o_proj.weight.transpose(0, 1) Wo = Wo.view(module.config.num_attention_heads, module.config.head_dim, module.config.hidden_size) @@ -67,16 +67,16 @@ def vwl1norm(values, module): head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1) head_WoV_norm_list.append(head_WoV_norm) - # b_size, num_heads, q_len , k_len + # b_size, num_heads, k_len , k_len WoV_norm = torch.stack(head_WoV_norm_list, dim=1) - WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len).mean(dim=2) + WoV_norm = 
WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, k_len).mean(dim=2) return WoV_norm def score(self, module, hidden_states, keys, values, attentions, kwargs): # Stage 1 scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - q_len = keys.shape[2] - selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio) + k_len = keys.shape[2] + selection_budget = int((1 - self.compression_ratio) * k_len * self.first_stage_ratio) top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices # Stage 2 @@ -140,10 +140,10 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute scores scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - bsz, num_key_value_heads, q_len = scores.shape + bsz, num_key_value_heads, k_len = scores.shape # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head - n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition n_safe = int(n_kept * self.alpha_safeguard) top_indices = torch.topk(scores, n_safe, dim=-1).indices scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) @@ -156,7 +156,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): budget_scores = scores.scatter(-1, top_indices, torch.finfo(scores.dtype).max) budget_scores = budget_scores.reshape(bsz, -1) top_indices = torch.topk(budget_scores, n_kept * num_key_value_heads, dim=-1).indices - top_indices_head_idx = top_indices // q_len + top_indices_head_idx = top_indices // k_len head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64) head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten())) @@ -180,12 +180,12 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): 
########################## # Compute bottom-k across heads - n_pruned = num_key_value_heads * (q_len - n_kept) + n_pruned = num_key_value_heads * (k_len - n_kept) indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) - head_indices = indices // q_len - seq_indices = indices % q_len + head_indices = indices // k_len + seq_indices = indices % k_len module.masked_key_indices = (batch_indices, head_indices, seq_indices) return keys, values diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py index 8271dfe5..ec583ab6 100644 --- a/kvpress/presses/decoding_press.py +++ b/kvpress/presses/decoding_press.py @@ -93,9 +93,9 @@ def compress( It would be possible to speed up compression during decoding for certain scorer presses by storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions. """ - q_len = keys.shape[2] - target_compression_ratio = self._find_target_compression_ratio(q_len, self.target_size) - logger.debug(f"Compressing {q_len} to {self.target_size} with ratio {target_compression_ratio}") + k_len = keys.shape[2] + target_compression_ratio = self._find_target_compression_ratio(k_len, self.target_size) + logger.debug(f"Compressing {k_len} to {self.target_size} with ratio {target_compression_ratio}") original_compression_ratio = self.base_press.compression_ratio self.base_press.compression_ratio = target_compression_ratio diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index 24c5b17a..2843f8dd 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -103,9 +103,9 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): raise ValueError( "Streaming mask not initialized. 
Make sure to call __post_init_from_model__ to initialize this press." ) - q_len = keys.shape[2] + k_len = keys.shape[2] - if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)): + if (self.head_compression_ratio > 0) or (k_len > (self.sink_size + self.recent_size)): # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details masked_keys = torch.zeros_like(keys[..., 0], dtype=torch.bool) @@ -114,7 +114,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute the compression ratio self.compression_ratio_ = self.streaming_mask.float().mean().item() - self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / q_len + self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / k_len return keys, values diff --git a/kvpress/presses/finch_press.py b/kvpress/presses/finch_press.py index 39a8cd48..29051b92 100644 --- a/kvpress/presses/finch_press.py +++ b/kvpress/presses/finch_press.py @@ -58,7 +58,7 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs): Similar to SnapKVPress except it adds a normalization step before averaging on the context window. 
""" - bsz, num_key_value_heads, q_len, _ = keys.shape + bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads if attentions is not None: @@ -69,13 +69,13 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs): ) if self.normalize_scores: - non_zero_counts = torch.arange(q_len - self.window_size, q_len)[None, None, :, None] + non_zero_counts = torch.arange(k_len - self.window_size, k_len)[None, None, :, None] non_zero_counts = non_zero_counts.to(attn_weights.device) attn_weights = attn_weights * non_zero_counts # Average per group scores = attn_weights.mean(dim=-2) - scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size) scores = scores.mean(dim=2) # Add back the observation window. Use max score to make sure the window is not pruned. @@ -95,14 +95,14 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Compute indices to keep (optionally by chunks) - q_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states + k_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states if self.chunk_length is None: - n_kept = int(q_len * (1 - self.compression_ratio)) + n_kept = int(k_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices else: assert self.chunk_length > self.window_size / (1 - self.compression_ratio) indices = [] - for i in range(0, q_len, self.chunk_length): + for i in range(0, k_len, self.chunk_length): chunk_scores = scores[:, :, i : i + self.chunk_length] n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio))) chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices diff --git a/kvpress/presses/pyramidkv_press.py 
b/kvpress/presses/pyramidkv_press.py index cc0a0106..006f83ce 100644 --- a/kvpress/presses/pyramidkv_press.py +++ b/kvpress/presses/pyramidkv_press.py @@ -100,8 +100,8 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = keys.shape[2] - n_kept = self.get_layer_budget(module, q_len) + k_len = keys.shape[2] + n_kept = self.get_layer_budget(module, k_len) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 1cbcb748..f7f33d16 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -90,8 +90,8 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = keys.shape[2] - n_kept = int(q_len * (1 - self.compression_ratio)) + k_len = keys.shape[2] + n_kept = int(k_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 48af6bc4..3c114814 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -94,13 +94,13 @@ def compress( self.compression_ratios = [] # Check if compression is needed - q_len = keys.shape[2] + k_len = keys.shape[2] min_length = self.n_initial + self.n_recent + self.n_last - if q_len <= min_length: + if k_len <= min_length: logger.warning(f"Sequence length is shorter than {min_length}: no compression applied") - if (self.lazy_threshold == 1.0) or (q_len <= min_length): + if (self.lazy_threshold == 1.0) or (k_len <= min_length): self.compression_ratios.append(0.0) return keys, values @@ -109,7 +109,7 @@ def compress( # If layer is lazy, only keep the initial and recent KV pairs 
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) - self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len) + self.compression_ratios.append((k_len - self.n_initial - self.n_recent + 1) / k_len) else: self.compression_ratios.append(0.0) diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 549f3865..684d40f4 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -44,7 +44,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ - bsz, _, q_len, _ = keys.shape + bsz, _, k_len, _ = keys.shape num_heads = module.config.num_attention_heads head_dim = module.head_dim num_key_value_groups = num_heads // module.config.num_key_value_heads @@ -61,7 +61,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ key_states = repeat_kv(keys, num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) attention_mask = torch.ones_like(attn_weights) * float("-inf") - attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1) + attention_mask = torch.triu(attention_mask, diagonal=k_len - window_size + 1) attn_weights += attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = attn_weights[..., :-window_size] @@ -81,10 +81,9 @@ def score( bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads - q_len = hidden_states.shape[1] assert ( - q_len > self.window_size - ), f"Query length {q_len} should be greater than the window size {self.window_size}" + 
hidden_states.shape[1] > self.window_size + ), f"Query length {hidden_states.shape[1]} should be greater than the window size {self.window_size}" if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index 66c9ecf8..6d657389 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -44,9 +44,9 @@ def score( kwargs, ) -> torch.Tensor: - q_len = keys.shape[2] - assert q_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" - n_pruned = q_len - int(q_len * (1 - self.compression_ratio)) + k_len = keys.shape[2] + assert k_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" + n_pruned = k_len - int(k_len * (1 - self.compression_ratio)) scores = torch.ones_like(keys[..., 0]) scores[:, :, self.n_sink : self.n_sink + n_pruned] = 0 diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index a19379c2..f277bb5c 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -72,7 +72,7 @@ def compress( return keys, values # Compute scores per dimension - bsz, num_key_value_heads, q_len, head_dim = keys.shape + bsz, num_key_value_heads, k_len, head_dim = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"]) @@ -84,7 +84,7 @@ def compress( # Prune dimensions with the lowest scores by setting them to 0 n_pruned = int(head_dim * self.key_channel_compression_ratio) indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices - indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1) + indices = indices.unsqueeze(2).expand(-1, -1, k_len, -1) keys = keys.scatter_(-1, indices, 0) return keys, values From 70128414633260388f0edf27c50acaee5c1a7cae Mon Sep 17 00:00:00 
2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:39:36 +0200 Subject: [PATCH 10/20] address pr feedback Signed-off-by: Max Jeblick --- kvpress/pipeline.py | 6 ++---- kvpress/presses/snapkv_press.py | 2 +- tests/test_decoding_compression.py | 34 +++++++++++++----------------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 76f345d3..a8765cab 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -203,7 +203,7 @@ def _forward( if cache is None: cache = DynamicCache() - # We only perform prefill compression if the press is not a decoding or prefill decoding press + # We only perform prefill compression if the press is a prefill press perform_prefill_compression = press is not None and not isinstance(press, DecodingPress) with press(self.model) if perform_prefill_compression else contextlib.nullcontext(): # We run the model without the lm head for pre-filling. @@ -232,9 +232,7 @@ def _forward( context_length=context_length, max_new_tokens=max_new_tokens, ) - if len(input_tensors["questions_ids"]) > 1: - print(f"Removing answer from cache: {cache_seq_lengths}") - self._remove_answer_from_cache(cache, cache_seq_lengths) + self._remove_answer_from_cache(cache, cache_seq_lengths) answers.append(answer) return answers diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 684d40f4..1f3a46fd 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -82,7 +82,7 @@ def score( num_key_value_groups = module.config.num_attention_heads // num_key_value_heads assert ( - hidden_states.shape[1] > self.window_size + hidden_states.shape[1] > self.window_size ), f"Query length {hidden_states.shape[1]} should be greater than the window size {self.window_size}" if attentions is not None: diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index 90acb8e0..9fc2adc4 100644 --- a/tests/test_decoding_compression.py +++ 
b/tests/test_decoding_compression.py @@ -241,25 +241,21 @@ def test_all_presses_work_with_decoding_press(press_config): # Run pipeline cache = DynamicCache() - try: - result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15) - - # Verify compression worked - assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}" - - # Check that cache was compressed (allow some tolerance for rounding) - for layer_idx, cache_layer in enumerate(cache.layers): - layer_seq_len = cache_layer.keys.shape[2] - # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger - target_size = 48 - compression_steps = 3 # from the decoding press configuration - max_expected_size = target_size + compression_steps - 1 - assert ( - target_size <= layer_seq_len <= max_expected_size - ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 - - except Exception as e: - pytest.fail(f"DecodingPress failed with {press_cls.__name__}: {e}") + result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15) + + # Verify compression worked + assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}" + + # Check that cache was compressed (allow some tolerance for rounding) + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert ( + target_size <= layer_seq_len <= max_expected_size + ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 def test_compression_actually_reduces_memory(): From 
5c1687305eb605db269a0aa3ffbbdc76d17dd342 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:51:27 +0200 Subject: [PATCH 11/20] address pr feedback Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 26f28f86..c298b6f8 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -157,7 +157,6 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 values = [layer.values.clone() for layer in past_key_values.layers] cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) - compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) From 77b4352e8580801006d433fe4d2101f31bb765a7 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:53:51 +0200 Subject: [PATCH 12/20] address pr feedback Signed-off-by: Max Jeblick --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 038deacb..a16bfdc6 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ question = "Tell me a long story about this context" response = pipe(context, question=question, press=decoding_press)["answer"] ``` -> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling. +> Not all existing presses are fully compatible with DecodingPress due to fundamental differences in how compression works during decoding versus prefilling. 
In particular, we only support ScorerPresses as base presses. From 2ca1a337de73ddf1e1851a12d41d61387a7b24d9 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:57:08 +0200 Subject: [PATCH 13/20] fix style Signed-off-by: Max Jeblick --- tests/integration/test_ruler.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 43cfe290..155cf6dd 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -6,6 +6,7 @@ import torch from transformers import DynamicCache, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available + from kvpress import QFilterPress from tests.default_presses import default_presses from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline # noqa: F401 @@ -58,12 +59,9 @@ def test_ruler_is_correct( # QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class.
return else: - pred_answer = kv_press_qwen3_flash_attn_pipeline( - context, - question=question, - press=press, - cache=cache - )["answer"] + pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[ + "answer" + ] assert true_answer in pred_answer @@ -102,10 +100,7 @@ def test_ruler_is_correct_for_qfilter( question = df_ruler.iloc[idx]["question"] true_answer = df_ruler.iloc[idx]["answer"][0] - pred_answer = kv_press_llama3_2_flash_attn_pipeline( - context, - question=question, - press=press, - cache=cache - )["answer"] + pred_answer = kv_press_llama3_2_flash_attn_pipeline(context, question=question, press=press, cache=cache)[ + "answer" + ] assert true_answer in pred_answer From 537e4a26bd013ea5b21221ecc1f8a03cfa2d8d16 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 15:59:59 +0200 Subject: [PATCH 14/20] fix style Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 473e174e..c3e6f0b4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -155,7 +155,6 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] - cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) From 1199bb75043f4bce845e784e8a1bb269c9b0888d Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 16:31:49 +0200 Subject: [PATCH 15/20] fix broken test Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 66 ++++++++++++++++++++++-------------------- 
1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c3e6f0b4..26f28f86 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -43,34 +43,33 @@ def test_pipeline_with_cache(kv_press_unit_test_pipeline): # noqa: F811 assert isinstance(answers[0], str) -class TestPipelineFA2: - @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") - @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") - @pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) - def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio): # noqa: F811 - context = "This is a test article. It was written on 2022-01-01." - questions = ["Repeat the last sentence"] - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - - assert len(answers) == 1 - assert isinstance(answers[0], str) - - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") - - assert ( - answers_sdpa[0] == answers[0] - ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" - assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") +@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") +@pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) +def test_pipeline_fa2(compression_ratio, kv_press_llama3_2_flash_attn_pipeline): # noqa: F811 + context = "This is a test article. It was written on 2022-01-01." + questions = ["Repeat the last sentence"] + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") + + assert ( + answers_sdpa[0] == answers[0] + ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" + assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." 
@pytest.mark.parametrize("question", ["When was this article written?", ""]) @@ -93,6 +92,7 @@ def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: kv_press_unit_test_pipeline(context, question=question) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): answers = generate_answer(danube_500m_model) @@ -142,7 +142,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 model = unit_test_model questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - device = model.device + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device) input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device) @@ -155,19 +155,23 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] + cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) + compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) def generate_answer(model): - device = model.device + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model.to(device) context = 
"This is a test article. It was written on 2022-01-01." questions = ["When was this article written?", "When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)( + answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)( context, questions=questions, press=press )["answers"] + model.to(torch.device("cpu")) return answers From 28f07e50ff5e3675476c2ae91593f98f2837c3d4 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Fri, 10 Oct 2025 16:51:03 +0200 Subject: [PATCH 16/20] fix broken test Signed-off-by: Max Jeblick --- tests/test_decoding_compression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index 9fc2adc4..b36dae4c 100644 --- a/tests/test_decoding_compression.py +++ b/tests/test_decoding_compression.py @@ -48,7 +48,7 @@ def test_decoding_compression(token_buffer_size): layer_seq_len = cache_layer.keys.shape[2] # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger max_expected_size = token_buffer_size + press.compression_interval - 1 - assert token_buffer_size <= layer_seq_len <= max_expected_size, ( + assert layer_seq_len <= max_expected_size, ( f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} " f"and {max_expected_size}, but got {layer_seq_len}" ) From de51e03fe9f6d4bd47729f5767821c2922de6f4f Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Mon, 13 Oct 2025 10:16:29 +0200 Subject: [PATCH 17/20] wrap test in a class Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 68 ++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 26f28f86..0cbac965 100644 --- a/tests/test_pipeline.py +++ 
b/tests/test_pipeline.py @@ -43,33 +43,34 @@ def test_pipeline_with_cache(kv_press_unit_test_pipeline): # noqa: F811 assert isinstance(answers[0], str) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") -@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") -@pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) -def test_pipeline_fa2(compression_ratio, kv_press_llama3_2_flash_attn_pipeline): # noqa: F811 - context = "This is a test article. It was written on 2022-01-01." - questions = ["Repeat the last sentence"] - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - - assert len(answers) == 1 - assert isinstance(answers[0], str) - - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") - press = ExpectedAttentionPress(compression_ratio=compression_ratio) - cache = DynamicCache() - answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( - context, questions=questions, press=press, cache=cache, max_new_tokens=6 - )["answers"] - kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") - - assert ( - answers_sdpa[0] == answers[0] - ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" - assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." +class TestPipelineFA2: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") + @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed") + @pytest.mark.parametrize("compression_ratio", [0.0, 0.2]) + def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio): # noqa: F811 + context = "This is a test article. It was written on 2022-01-01." 
+ questions = ["Repeat the last sentence"] + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + + assert len(answers) == 1 + assert isinstance(answers[0], str) + + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa") + press = ExpectedAttentionPress(compression_ratio=compression_ratio) + cache = DynamicCache() + answers_sdpa = kv_press_llama3_2_flash_attn_pipeline( + context, questions=questions, press=press, cache=cache, max_new_tokens=6 + )["answers"] + kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2") + + assert ( + answers_sdpa[0] == answers[0] + ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}" + assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}." 
@pytest.mark.parametrize("question", ["When was this article written?", ""]) @@ -92,7 +93,6 @@ def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: kv_press_unit_test_pipeline(context, question=question) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available") def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811 with caplog.at_level(logging.DEBUG): answers = generate_answer(danube_500m_model) @@ -142,7 +142,7 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 model = unit_test_model questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = model.device compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device) input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device) @@ -155,23 +155,19 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] - cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) - compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) def generate_answer(model): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model.to(device) + device = model.device context = 
"This is a test article. It was written on 2022-01-01." questions = ["When was this article written?", "When was this article written?"] press = ExpectedAttentionPress(compression_ratio=0.4) tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) - answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)( + answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)( context, questions=questions, press=press )["answers"] - model.to(torch.device("cpu")) - return answers + return answers \ No newline at end of file From cb94097ba1311727d0ef266250020dee5277a168 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Mon, 13 Oct 2025 10:22:22 +0200 Subject: [PATCH 18/20] fix format Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 0cbac965..c3e6f0b4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -170,4 +170,4 @@ def generate_answer(model): answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)( context, questions=questions, press=press )["answers"] - return answers \ No newline at end of file + return answers From f5d73cb6694692909fe059e95c02c2067b4f4800 Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Mon, 13 Oct 2025 10:48:16 +0200 Subject: [PATCH 19/20] fix failing test Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c3e6f0b4..89daf663 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -155,8 +155,9 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] + cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in
range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) - assert past_key_values.get_seq_length() == seq_len + compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)]) From ede8c3d1c7d580388bc77fcb87f9ba9a767368de Mon Sep 17 00:00:00 2001 From: Max Jeblick Date: Mon, 13 Oct 2025 10:48:35 +0200 Subject: [PATCH 20/20] fix failing test Signed-off-by: Max Jeblick --- tests/test_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 89daf663..f0686cdd 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -157,7 +157,8 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 values = [layer.values.clone() for layer in past_key_values.layers] cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) - compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len + compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) + assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)])