diff --git a/README.md b/README.md
index d47c63a2..a16bfdc6 100644
--- a/README.md
+++ b/README.md
@@ -67,15 +67,47 @@ answer = pipe(context, question=question, press=press)["answer"]
In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)).
-> [!IMPORTANT]
-> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequence (100k - 1M tokens) which are essentially long context prompts. This would typically apply to improving prompt caching systems.
+
+## Decoding Compression
+
+By default, KVPress applies compression during the pre-filling phase. As an experimental feature, we also support compression during decoding via the `DecodingPress` wrapper, which periodically compresses the KV cache during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
-> [!NOTE]
-> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}` as this press requires to materialize the attention weights
+- `base_press`: Any `ScorerPress` (e.g., `KnormPress`, `CriticalKVPress`)
+- `compression_interval`: Number of decoding steps between compressions (default: 128)
+- `target_size`: Target size of the cache after compression (default: 1024)
+- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
-## Contributing
+Unlike other presses, which are parameterized by a compression ratio, `DecodingPress` uses a `target_size`: the cache is compressed every `compression_interval` steps, and the compression ratio is computed automatically so that the cache size after compression equals `target_size`.
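+
+For intuition, the ratio applied at a given compression step is roughly `1 - target_size / cache_length` (internally, `DecodingPress` refines this value to account for integer rounding). A minimal sketch of the arithmetic, not the library API:
+
+```python
+# Illustrative only: how target_size maps to a per-step compression ratio
+cache_length = 2048                       # KV cache length when compression triggers
+target_size = 1024                        # desired cache size after compression
+ratio = 1 - target_size / cache_length    # 0.5, i.e. half of the KV pairs are evicted
+```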
-We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
+An example of decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import KnormPress
+from kvpress import DecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps to 512 tokens
+decoding_press = DecodingPress(
+ base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+> [!NOTE]
+> Not all existing presses are compatible with `DecodingPress` due to fundamental differences in how compression works during decoding versus prefilling. In particular, only `ScorerPress` instances are currently supported as base presses.
+
+
## Available presses
@@ -111,6 +143,8 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
+- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): compress the KV cache periodically during decoding, see the Decoding Compression section in this README.
+- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): combine a prefilling press with a decoding press to compress during both phases.
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
@@ -131,6 +165,7 @@ Below we report the average performance on the RULER dataset with 4k context len
+
## Quantization
We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
@@ -149,6 +184,10 @@ By default, the `DynamicCache` is used (no quantization).
> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`).
+## Contributing
+
+We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
+
## FAQ
@@ -224,4 +263,80 @@ with press(model):
However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once.
+
+### Can I combine compression during prefilling and decoding?
+
+Yes, using `PrefillDecodingPress`, which combines separate presses for the prefilling and decoding phases.
+
+**Parameters:**
+- `prefilling_press`: Press used during the prefilling phase
+- `decoding_press`: `DecodingPress` used during the decoding phase
+
+An example combining prefill and decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import CriticalKVPress, KnormPress
+from kvpress import DecodingPress, PrefillDecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Different strategies for prefill vs decoding
+prefill_press = CriticalKVPress(KnormPress())
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=5,
+    target_size=256
+)
+
+# Combine them
+combined_press = PrefillDecodingPress(
+ prefilling_press=prefill_press,
+ decoding_press=decoding_press
+)
+
+context = "A very long context that will be compressed during prefill"
+question = "Generate a detailed analysis that will be compressed during decoding"
+response = pipe(context, question=question, press=combined_press)["answer"]
+```
+
\ No newline at end of file
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index cccd1f02..19165465 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -1,68 +1,74 @@
-# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-
-from kvpress.attention_patch import patch_attention_functions
-from kvpress.pipeline import KVPressTextGenerationPipeline
-from kvpress.presses.adakv_press import AdaKVPress
-from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
-from kvpress.presses.block_press import BlockPress
-from kvpress.presses.chunk_press import ChunkPress
-from kvpress.presses.chunkkv_press import ChunkKVPress
-from kvpress.presses.composed_press import ComposedPress
-from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
-from kvpress.presses.duo_attention_press import DuoAttentionPress
-from kvpress.presses.expected_attention_press import ExpectedAttentionPress
-from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
-from kvpress.presses.finch_press import FinchPress
-from kvpress.presses.key_rerotation_press import KeyRerotationPress
-from kvpress.presses.keydiff_press import KeyDiffPress
-from kvpress.presses.knorm_press import KnormPress
-from kvpress.presses.kvzip_press import KVzipPress
-from kvpress.presses.lagkv_press import LagKVPress
-from kvpress.presses.observed_attention_press import ObservedAttentionPress
-from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
-from kvpress.presses.pyramidkv_press import PyramidKVPress
-from kvpress.presses.qfilter_press import QFilterPress
-from kvpress.presses.random_press import RandomPress
-from kvpress.presses.scorer_press import ScorerPress
-from kvpress.presses.simlayerkv_press import SimLayerKVPress
-from kvpress.presses.snapkv_press import SnapKVPress
-from kvpress.presses.streaming_llm_press import StreamingLLMPress
-from kvpress.presses.think_press import ThinKPress
-from kvpress.presses.tova_press import TOVAPress
-
-# Patch the attention functions to support head-wise compression
-patch_attention_functions()
-
-__all__ = [
- "CriticalAdaKVPress",
- "CriticalKVPress",
- "AdaKVPress",
- "BasePress",
- "ComposedPress",
- "ScorerPress",
- "ExpectedAttentionPress",
- "KnormPress",
- "ObservedAttentionPress",
- "RandomPress",
- "SimLayerKVPress",
- "SnapKVPress",
- "StreamingLLMPress",
- "ThinKPress",
- "TOVAPress",
- "KVPressTextGenerationPipeline",
- "PerLayerCompressionPress",
- "KeyRerotationPress",
- "ChunkPress",
- "DuoAttentionPress",
- "ChunkKVPress",
- "QFilterPress",
- "PyramidKVPress",
- "FinchPress",
- "LagKVPress",
- "BlockPress",
- "KeyDiffPress",
- "KVzipPress",
- "ExpectedAttentionStatsPress",
-]
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from kvpress.attention_patch import patch_attention_functions
+from kvpress.pipeline import KVPressTextGenerationPipeline
+from kvpress.presses.adakv_press import AdaKVPress
+from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
+from kvpress.presses.block_press import BlockPress
+from kvpress.presses.chunk_press import ChunkPress
+from kvpress.presses.chunkkv_press import ChunkKVPress
+from kvpress.presses.composed_press import ComposedPress
+from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
+from kvpress.presses.decoding_press import DecodingPress
+from kvpress.presses.duo_attention_press import DuoAttentionPress
+from kvpress.presses.expected_attention_press import ExpectedAttentionPress
+from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
+from kvpress.presses.finch_press import FinchPress
+from kvpress.presses.key_rerotation_press import KeyRerotationPress
+from kvpress.presses.keydiff_press import KeyDiffPress
+from kvpress.presses.knorm_press import KnormPress
+from kvpress.presses.kvzip_press import KVzipPress
+from kvpress.presses.lagkv_press import LagKVPress
+from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
+from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
+from kvpress.presses.pyramidkv_press import PyramidKVPress
+from kvpress.presses.qfilter_press import QFilterPress
+from kvpress.presses.random_press import RandomPress
+from kvpress.presses.scorer_press import ScorerPress
+from kvpress.presses.simlayerkv_press import SimLayerKVPress
+from kvpress.presses.snapkv_press import SnapKVPress
+from kvpress.presses.streaming_llm_press import StreamingLLMPress
+from kvpress.presses.think_press import ThinKPress
+from kvpress.presses.tova_press import TOVAPress
+
+# Patch the attention functions to support head-wise compression
+patch_attention_functions()
+
+__all__ = [
+ "CriticalAdaKVPress",
+ "CriticalKVPress",
+ "AdaKVPress",
+ "BasePress",
+ "ComposedPress",
+ "ScorerPress",
+ "ExpectedAttentionPress",
+ "KnormPress",
+ "ObservedAttentionPress",
+ "RandomPress",
+ "SimLayerKVPress",
+ "SnapKVPress",
+ "StreamingLLMPress",
+ "ThinKPress",
+ "TOVAPress",
+ "KVPressTextGenerationPipeline",
+ "PerLayerCompressionPress",
+ "KeyRerotationPress",
+ "ChunkPress",
+ "DuoAttentionPress",
+ "ChunkKVPress",
+ "QFilterPress",
+ "PyramidKVPress",
+ "FinchPress",
+ "LagKVPress",
+ "BlockPress",
+ "KeyDiffPress",
+ "KVzipPress",
+    "DecodingPress",
+    "PrefillDecodingPress",
+    "ExpectedAttentionStatsPress",
+]
diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py
index 55eb506a..a8765cab 100644
--- a/kvpress/pipeline.py
+++ b/kvpress/pipeline.py
@@ -12,9 +12,11 @@
from transformers.pipelines.base import GenericTensor
from kvpress.presses.base_press import BasePress
+from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
logger = logging.getLogger(__name__)
@@ -189,6 +191,11 @@ def _forward(
list[str]
Generated answers for each input question.
"""
+ if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1:
+ raise ValueError(
+            "DecodingPress and PrefillDecodingPress are not compatible with multiple questions. "
+            "Please specify a single question."
+ )
+
context_ids = input_tensors["context_ids"].to(self.model.device)
context_length = context_ids.shape[1]
@@ -196,7 +203,9 @@ def _forward(
if cache is None:
cache = DynamicCache()
- with press(self.model) if press is not None else contextlib.nullcontext():
+ # We only perform prefill compression if the press is a prefill press
+ perform_prefill_compression = press is not None and not isinstance(press, DecodingPress)
+ with press(self.model) if perform_prefill_compression else contextlib.nullcontext():
# We run the model without the lm head for pre-filling.
self.model.model(
input_ids=context_ids,
@@ -204,24 +213,44 @@ def _forward(
output_attentions=self.output_attentions(press),
)
- logger.debug(f"Context Length: {context_length}")
- logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")
+ logger.debug(f"Context Length: {context_length}")
+ logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")
+
+ # We only perform decoding compression if the press is a decoding or prefill decoding press
+ perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
+ with press(self.model) if perform_decoding_compression else contextlib.nullcontext():
+ # Greedy decoding for each question
+ answers = []
+ for question_ids in input_tensors["questions_ids"]:
+ if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
+ context_length = cache.get_seq_length()
+
+ cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
+ answer = self.generate_answer(
+ question_ids=question_ids.to(self.model.device),
+ cache=cache,
+ context_length=context_length,
+ max_new_tokens=max_new_tokens,
+ )
+ self._remove_answer_from_cache(cache, cache_seq_lengths)
+
+ answers.append(answer)
+ return answers
- # Greedy decoding for each question
- answers = []
- for question_ids in input_tensors["questions_ids"]:
- if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
- context_length = cache.get_seq_length()
+ def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]):
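+        """Truncate each cache layer back to its recorded length, removing the generated answer tokens."""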
- answer = self.generate_answer(
- question_ids=question_ids.to(self.model.device),
- cache=cache,
- context_length=context_length,
- max_new_tokens=max_new_tokens,
- )
- answers.append(answer)
+ for layer_idx, sequence_length in enumerate(cache_seq_lengths):
+ cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
+ cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]
- return answers
+ if isinstance(cache, QuantizedCache):
+ for layer_idx, sequence_length in enumerate(cache_seq_lengths):
+ cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[
+ :, :, :sequence_length
+ ]
+ cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[
+ :, :, :sequence_length
+ ]
def generate_answer(
self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
@@ -245,7 +274,6 @@ def generate_answer(
str
The generated answer.
"""
- cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
position_ids = torch.arange(
context_length, context_length + question_ids.shape[1], device=self.model.device
).unsqueeze(0)
@@ -276,20 +304,6 @@ def generate_answer(
if new_id.item() in should_stop_token_ids:
break
answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
- # Remove the generated tokens from the cache
- for layer_idx, sequence_length in enumerate(cache_seq_lengths):
- cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
- cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]
-
- if isinstance(cache, QuantizedCache):
- for layer_idx, sequence_length in enumerate(cache_seq_lengths):
- cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[
- :, :, :sequence_length
- ]
- cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[
- :, :, :sequence_length
- ]
-
return answer
def output_attentions(self, press: BasePress):
diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py
index 22393678..815c777a 100644
--- a/kvpress/presses/adakv_press.py
+++ b/kvpress/presses/adakv_press.py
@@ -55,21 +55,21 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
# Compute scores
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
- bsz, num_key_value_heads, q_len = scores.shape
+ bsz, num_key_value_heads, k_len = scores.shape
# Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
- n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition
+ n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition
n_safe = int(n_kept * self.alpha_safeguard)
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)
# Compute bottom-k across heads
- n_pruned = num_key_value_heads * (q_len - n_kept)
+ n_pruned = num_key_value_heads * (k_len - n_kept)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
- head_indices = indices // q_len
- seq_indices = indices % q_len
+ head_indices = indices // k_len
+ seq_indices = indices % k_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
return keys, values
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 7d40fcd7..8fb4acf2 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -20,6 +20,8 @@
Qwen3ForCausalLM,
)
+from kvpress.presses.utils import extract_keys_and_values
+
logger = logging.getLogger(__name__)
SUPPORTED_MODELS = (
@@ -124,24 +126,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_values"]
+ cache_layer = cache.layers[module.layer_idx]
q_len = hidden_states.shape[1]
# Don't compress after pre-filling
if kwargs["cache_position"][-1] > q_len:
return output
- cache_layer = cache.layers[module.layer_idx]
- if isinstance(cache, QuantizedCache):
- keys = cache_layer._dequantize( # type: ignore[index]
- cache_layer._quantized_keys # type: ignore[index]
- )
- values = cache_layer._dequantize( # type: ignore[index]
- cache_layer._quantized_values # type: ignore[index]
- )
-
- else:
- keys = cache_layer.keys
- values = cache_layer.values
+ keys, values = extract_keys_and_values(cache, module.layer_idx)
keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)
diff --git a/kvpress/presses/block_press.py b/kvpress/presses/block_press.py
index 6d8da788..b4edafdb 100644
--- a/kvpress/presses/block_press.py
+++ b/kvpress/presses/block_press.py
@@ -58,18 +58,18 @@ def compress(
assert attentions is None, "BlockPress does not support attentions."
- bsz, num_key_value_heads, q_len, head_dim = keys.shape
+ bsz, num_key_value_heads, k_len, head_dim = keys.shape
- block_size = self.block_size if self.block_size < q_len else q_len
- n_kept = int(q_len * (1 - self.compression_ratio))
+ block_size = self.block_size if self.block_size < k_len else k_len
+ n_kept = int(k_len * (1 - self.compression_ratio))
kept_indices = torch.arange(n_kept, device=keys.device).expand(bsz, num_key_value_heads, -1)
# Reshape hidden states to match the kept_indices
- states = hidden_states.view(bsz, q_len, num_key_value_heads, -1).transpose(1, 2)
+ states = hidden_states.view(bsz, k_len, num_key_value_heads, -1).transpose(1, 2)
- for i in range(n_kept, q_len, block_size):
- end = min(i + block_size, q_len)
+ for i in range(n_kept, k_len, block_size):
+ end = min(i + block_size, k_len)
current_indices = torch.arange(i, end, device=keys.device).expand(bsz, num_key_value_heads, -1)
current_indices = torch.cat([kept_indices, current_indices], dim=-1)
diff --git a/kvpress/presses/criticalkv_press.py b/kvpress/presses/criticalkv_press.py
index 5513bc67..abcb8b83 100644
--- a/kvpress/presses/criticalkv_press.py
+++ b/kvpress/presses/criticalkv_press.py
@@ -53,7 +53,7 @@ def compression_ratio(self, value):
@staticmethod
def vwl1norm(values, module):
- bsz, num_key_value_heads, q_len, _ = values.shape
+ bsz, num_key_value_heads, k_len, _ = values.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
Wo = module.o_proj.weight.transpose(0, 1)
Wo = Wo.view(module.config.num_attention_heads, module.config.head_dim, module.config.hidden_size)
@@ -67,16 +67,16 @@ def vwl1norm(values, module):
head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
head_WoV_norm_list.append(head_WoV_norm)
- # b_size, num_heads, q_len , k_len
+        # b_size, num_heads, k_len
WoV_norm = torch.stack(head_WoV_norm_list, dim=1)
- WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len).mean(dim=2)
+ WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, k_len).mean(dim=2)
return WoV_norm
def score(self, module, hidden_states, keys, values, attentions, kwargs):
# Stage 1
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
- q_len = keys.shape[2]
- selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio)
+ k_len = keys.shape[2]
+ selection_budget = int((1 - self.compression_ratio) * k_len * self.first_stage_ratio)
top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices
# Stage 2
@@ -140,10 +140,10 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
# Compute scores
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
- bsz, num_key_value_heads, q_len = scores.shape
+ bsz, num_key_value_heads, k_len = scores.shape
# Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
- n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition
+ n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition
n_safe = int(n_kept * self.alpha_safeguard)
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)
@@ -156,7 +156,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
budget_scores = scores.scatter(-1, top_indices, torch.finfo(scores.dtype).max)
budget_scores = budget_scores.reshape(bsz, -1)
top_indices = torch.topk(budget_scores, n_kept * num_key_value_heads, dim=-1).indices
- top_indices_head_idx = top_indices // q_len
+ top_indices_head_idx = top_indices // k_len
head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64)
head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten()))
@@ -180,12 +180,12 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
##########################
# Compute bottom-k across heads
- n_pruned = num_key_value_heads * (q_len - n_kept)
+ n_pruned = num_key_value_heads * (k_len - n_kept)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
- head_indices = indices // q_len
- seq_indices = indices % q_len
+ head_indices = indices // k_len
+ seq_indices = indices % k_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
return keys, values
diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py
new file mode 100644
index 00000000..ec583ab6
--- /dev/null
+++ b/kvpress/presses/decoding_press.py
@@ -0,0 +1,220 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from collections import defaultdict
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from transformers.cache_utils import QuantizedCache
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+from kvpress.presses.utils import extract_keys_and_values
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DecodingPress(BasePress):
+ """
+ A press that only operates during decoding phase and maintains a running buffer of hidden states.
+
+ This press accumulates hidden states during decoding and applies compression every N steps
+ using a scorer press to determine which tokens to keep.
+
+
+ Parameters
+ ----------
+ base_press : ScorerPress
+ The scorer press used to compute importance scores for tokens.
+    compression_interval : int, default=128
+        Number of decoding steps between compressions, i.e. compression is applied every compression_interval steps.
+ target_size : int, default=1024
+ Target number of tokens to keep after compression.
+ hidden_states_buffer_size : int, default=128
+ Maximum number of hidden states to keep before compression. Larger values use more GPU memory.
+ Note: Some presses don't need buffered hidden states and can set this to 0 to use only the
+ current hidden state for compression scoring.
+ """
+
+ base_press: ScorerPress
+ compression_interval: int = 128
+ target_size: int = 1024
+ hidden_states_buffer_size: int = 128
+
+ def __post_init__(self):
+ # Buffer to store hidden states during decoding (per layer)
+ assert isinstance(self.base_press, ScorerPress), "DecodingPress requires a ScorerPress as input"
+ self.hidden_states_buffer = defaultdict(list) # Per-layer buffer
+ self.layer_step_counts = defaultdict(int) # Track step count per layer
+
+ assert self.compression_interval > 0, "compression_interval must be greater than 0"
+ assert self.target_size > 0, "target_size must be greater than 0"
+
+ if self.base_press.compression_ratio:
+ logger.warning(
+ f"compression_ratio is set for base press ({self.base_press.compression_ratio}). "
+ f"This will be overridden by the decoding press."
+ )
+
+ def compress(
+ self,
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ attentions: torch.Tensor,
+ kwargs: dict,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Delegate compression to the base press during decoding phase.
+
+ Args:
+ module: The transformer module being compressed
+ hidden_states: Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim])
+ keys: Key cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim])
+ values: Value cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim])
+ attentions: Attention weights (shape varies by implementation)
+ kwargs: Additional keyword arguments
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) tensors
+
+ Note:
+ **Sequence length alignment**: During decoding compression, `hidden_states` contains the
+ buffered hidden states from recent decoding steps (buffer_len tokens), while `keys` and
+ `values` contain the full sequence history (seq_len tokens). The base press implementation
+ should use keys.shape[2] for full sequence length calculations. The buffered hidden_states
+ provide context for the most recent tokens when computing compression scores.
+
+ Performance Note:
+ It would be possible to speed up compression during decoding for certain scorer presses by
+ storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions.
+ """
+ k_len = keys.shape[2]
+ target_compression_ratio = self._find_target_compression_ratio(k_len, self.target_size)
+ logger.debug(f"Compressing {k_len} to {self.target_size} with ratio {target_compression_ratio}")
+
+ original_compression_ratio = self.base_press.compression_ratio
+ self.base_press.compression_ratio = target_compression_ratio
+ result = self.base_press.compress(module, hidden_states, keys, values, attentions, kwargs)
+ self.base_press.compression_ratio = original_compression_ratio
+ return result
+
+ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+ """
+ Forward hook that manages decoding-specific compression logic.
+
+ This hook:
+ 1. Detects when we're in decoding phase (not prefilling)
+ 2. Accumulates hidden states in a buffer
+ 3. Applies compression every N steps
+ 4. Clears the buffer after compression
+ """
+ hidden_states = kwargs["hidden_states"]
+ cache = kwargs["past_key_values"]
+ q_len = hidden_states.shape[1]
+ layer_idx = module.layer_idx
+
+ # Only operate during decoding phase (after prefilling)
+ if kwargs["cache_position"][-1] <= q_len:
+ # We're still in prefilling phase, don't do anything
+ return output
+ # Add current hidden states to buffer for this layer
+ self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone())
+
+ self.layer_step_counts[layer_idx] += 1
+
+        # Apply compression once enough decoding steps have accumulated, or when the current
+        # forward pass alone (e.g. a long question) already reaches the target size
+ if (self.layer_step_counts[layer_idx] >= self.compression_interval) or (q_len >= self.target_size):
+ logger.debug(
+                f"Applying decoding compression: layer_step_count ({self.layer_step_counts[layer_idx]}) >= compression_interval ({self.compression_interval})"  # noqa: E501
+ )
+
+ cache_layer = cache.layers[module.layer_idx]
+ keys, values = extract_keys_and_values(cache, module.layer_idx)
+
+ # Get attention weights from output
+ attentions = output[1] if len(output) > 1 and output[1] is not None else None
+
+ # Apply compression using buffered hidden states for this layer
+ buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1)
+ keys, values = self.compress(module, buffered_hidden_states, keys, values, attentions, kwargs)
+ logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}")
+
+ # Update cache with compressed keys and values
+ if isinstance(cache, QuantizedCache):
+ cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
+ cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
+ cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
+ cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
+ cache_layer.cumulative_length = keys.shape[2]
+ else:
+ cache_layer.keys = keys
+ cache_layer.values = values
+
+ # Reset step count and clear buffer for this layer
+ self.layer_step_counts[layer_idx] = 0
+ # Always clear the buffer after compression - otherwise there's a mismatch between
+ # hidden states buffer and kv cache
+ self.hidden_states_buffer[layer_idx] = []
+
+ self.hidden_states_buffer[layer_idx] = (
+ self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :]
+ if self.hidden_states_buffer_size > 0
+ else []
+ )
+ return output
+
+ def reset(self):
+ """Reset the decoding press state."""
+ self.hidden_states_buffer = defaultdict(list)
+ self.layer_step_counts = defaultdict(int)
+
+ def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float:
+ """
+ Find the compression ratio that results in exactly target_tokens after int() rounding.
+
+ Args:
+ q_len: Current sequence length
+ target_tokens: Desired number of tokens after compression
+
+ Returns:
+ Compression ratio that gives exactly target_tokens
+ """
+ if q_len <= target_tokens:
+ return 0.0
+
+ # Start with theoretical ratio
+ ratio = 1.0 - (target_tokens / q_len)
+
+ # Binary search to handle int() rounding
+ low, high = 0.0, 1.0
+ max_iterations = 20
+ iteration = 0
+
+ while iteration < max_iterations:
+ n_kept = int(q_len * (1 - ratio))
+ if n_kept == target_tokens:
+ break
+ elif n_kept > target_tokens:
+ # Need more compression
+ low = ratio
+ ratio = (ratio + high) / 2
+ else:
+ # Need less compression
+ high = ratio
+ ratio = (low + ratio) / 2
+ iteration += 1
+
+ final_n_kept = int(q_len * (1 - ratio))
+ if final_n_kept != target_tokens:
+ logger.warning(
+ f"Binary search failed: q_len={q_len}, target={target_tokens}, got={final_n_kept}, ratio={ratio}"
+ )
+
+ return ratio
diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py
index b88d2f03..2843f8dd 100644
--- a/kvpress/presses/duo_attention_press.py
+++ b/kvpress/presses/duo_attention_press.py
@@ -103,9 +103,9 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
raise ValueError(
"Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press."
)
- q_len = hidden_states.shape[1]
+ k_len = keys.shape[2]
- if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)):
+ if (self.head_compression_ratio > 0) or (k_len > (self.sink_size + self.recent_size)):
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
masked_keys = torch.zeros_like(keys[..., 0], dtype=torch.bool)
@@ -114,7 +114,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
# Compute the compression ratio
self.compression_ratio_ = self.streaming_mask.float().mean().item()
- self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / q_len
+ self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / k_len
return keys, values
diff --git a/kvpress/presses/finch_press.py b/kvpress/presses/finch_press.py
index e22952c3..29051b92 100644
--- a/kvpress/presses/finch_press.py
+++ b/kvpress/presses/finch_press.py
@@ -58,7 +58,7 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs):
Similar to SnapKVPress except it adds a normalization step before averaging on the context window.
"""
- bsz, num_key_value_heads, q_len, _ = keys.shape
+ bsz, num_key_value_heads, k_len, _ = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
if attentions is not None:
@@ -69,13 +69,13 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs):
)
if self.normalize_scores:
- non_zero_counts = torch.arange(q_len - self.window_size, q_len)[None, None, :, None]
+ non_zero_counts = torch.arange(k_len - self.window_size, k_len)[None, None, :, None]
non_zero_counts = non_zero_counts.to(attn_weights.device)
attn_weights = attn_weights * non_zero_counts
# Average per group
scores = attn_weights.mean(dim=-2)
- scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size)
+ scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size)
scores = scores.mean(dim=2)
# Add back the observation window. Use max score to make sure the window is not pruned.
@@ -95,14 +95,14 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Compute indices to keep (optionally by chunks)
- q_len = hidden_states.shape[1]
+ k_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states
if self.chunk_length is None:
- n_kept = int(q_len * (1 - self.compression_ratio))
+ n_kept = int(k_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
else:
assert self.chunk_length > self.window_size / (1 - self.compression_ratio)
indices = []
- for i in range(0, q_len, self.chunk_length):
+ for i in range(0, k_len, self.chunk_length):
chunk_scores = scores[:, :, i : i + self.chunk_length]
n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio)))
chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py
index 129072f8..a264b3d4 100644
--- a/kvpress/presses/key_rerotation_press.py
+++ b/kvpress/presses/key_rerotation_press.py
@@ -137,7 +137,7 @@ def compress(
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
# Get indices of KV pairs with the lowest scores
- q_len = hidden_states.shape[1]
+ q_len = keys.shape[2]
n_kept = int(q_len * (1 - self.press.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = torch.sort(indices, dim=2).values
diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py
index 59b565e7..0f13e68a 100644
--- a/kvpress/presses/kvzip_press.py
+++ b/kvpress/presses/kvzip_press.py
@@ -14,7 +14,7 @@
from transformers.models.llama.modeling_llama import rotate_half
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
-from kvpress.presses.utils import get_query_states
+from kvpress.presses.utils import extract_keys_and_values, get_query_states
logger = logging.getLogger(__name__)
@@ -153,25 +153,16 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
hidden_states = kwargs["hidden_states"]
cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None)
-
cache_layer = cache.layers[module.layer_idx]
- if isinstance(cache, QuantizedCache):
- keys = cache_layer._dequantize( # type: ignore[index]
- cache_layer._quantized_keys # type: ignore[index]
- )
- values = cache_layer._dequantize( # type: ignore[index]
- cache_layer._quantized_values # type: ignore[index]
- )
- else:
- keys = cache_layer.keys
- values = cache_layer.values
+ keys, values = extract_keys_and_values(cache, module.layer_idx)
# Compute importance scores for KV pairs in the prefilled context,
# retaining only the originally prefilled KV pairs.
keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs)
if isinstance(cache, QuantizedCache):
+ # Update cache with compressed keys and values
cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
diff --git a/kvpress/presses/prefill_decoding_press.py b/kvpress/presses/prefill_decoding_press.py
new file mode 100644
index 00000000..85b5d465
--- /dev/null
+++ b/kvpress/presses/prefill_decoding_press.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from kvpress.presses.base_press import BasePress
+
+from .decoding_press import DecodingPress
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PrefillDecodingPress(BasePress):
+ """
+ A wrapper press that combines separate prefilling and decoding compression strategies.
+
+ This press acts as a single press interface but internally delegates to different
+ presses based on the current phase (prefilling vs decoding). During prefilling,
+ it uses the prefilling_press. During decoding, it uses the decoding_press.
+
+ Parameters
+ ----------
+ prefilling_press : BasePress, optional
+ Press to use during the prefilling phase. If None, no compression is applied during prefilling.
+ decoding_press : DecodingPress, optional
+ Press to use during the decoding phase. If None, no compression is applied during decoding.
+ """
+
+ prefilling_press: Optional[BasePress] = None
+ decoding_press: Optional[DecodingPress] = None
+
+ def compress(
+ self,
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ attentions: torch.Tensor,
+ kwargs: dict,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ q_len = hidden_states.shape[1]
+
+ # Determine if we're in prefilling or decoding phase
+ if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
+ return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs)
+ elif self.decoding_press is not None:
+ return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs)
+
+ # No compression applied
+ logger.warning("No compression applied during prefill or decoding phase")
+
+ return keys, values
+
+ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+ """
+ Forward hook that delegates to the appropriate press based on current phase.
+ """
+ hidden_states = kwargs["hidden_states"]
+ q_len = hidden_states.shape[1]
+
+ # Determine if we're in prefilling or decoding phase
+ if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
+ return self.prefilling_press.forward_hook(module, input, kwargs, output)
+ elif self.decoding_press is not None:
+ return self.decoding_press.forward_hook(module, input, kwargs, output)
+
+ # No hook applied
+ return output
+
+ @contextmanager
+ def __call__(self, model: PreTrainedModel):
+ try:
+ with super().__call__(model):
+ yield
+ finally:
+ # Reset decoding press if it exists
+ if self.decoding_press is not None:
+ self.decoding_press.reset()
diff --git a/kvpress/presses/pyramidkv_press.py b/kvpress/presses/pyramidkv_press.py
index 14b6c03b..006f83ce 100644
--- a/kvpress/presses/pyramidkv_press.py
+++ b/kvpress/presses/pyramidkv_press.py
@@ -100,8 +100,8 @@ def compress(
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Get indices of KV pairs with the lowest scores
- q_len = hidden_states.shape[1]
- n_kept = self.get_layer_budget(module, q_len)
+ k_len = keys.shape[2]
+ n_kept = self.get_layer_budget(module, k_len)
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py
index 5c84b379..f7f33d16 100644
--- a/kvpress/presses/scorer_press.py
+++ b/kvpress/presses/scorer_press.py
@@ -90,8 +90,8 @@ def compress(
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Get indices of KV pairs with the lowest scores
- q_len = hidden_states.shape[1]
- n_kept = int(q_len * (1 - self.compression_ratio))
+ k_len = keys.shape[2]
+ n_kept = int(k_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py
index 86086e7c..3c114814 100644
--- a/kvpress/presses/simlayerkv_press.py
+++ b/kvpress/presses/simlayerkv_press.py
@@ -94,13 +94,13 @@ def compress(
self.compression_ratios = []
# Check if compression is needed
- q_len = hidden_states.shape[1]
+ k_len = keys.shape[2]
min_length = self.n_initial + self.n_recent + self.n_last
- if q_len <= min_length:
+ if k_len <= min_length:
logger.warning(f"Sequence length is shorter than {min_length}: no compression applied")
- if (self.lazy_threshold == 1.0) or (q_len <= min_length):
+ if (self.lazy_threshold == 1.0) or (k_len <= min_length):
self.compression_ratios.append(0.0)
return keys, values
@@ -109,7 +109,7 @@ def compress(
# If layer is lazy, only keep the initial and recent KV pairs
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2)
values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2)
- self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len)
+ self.compression_ratios.append((k_len - self.n_initial - self.n_recent + 1) / k_len)
else:
self.compression_ratios.append(0.0)
diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py
index 68f92af9..1f3a46fd 100644
--- a/kvpress/presses/snapkv_press.py
+++ b/kvpress/presses/snapkv_press.py
@@ -44,7 +44,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""
- bsz, q_len, _ = hidden_states.shape
+ bsz, _, k_len, _ = keys.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim
num_key_value_groups = num_heads // module.config.num_key_value_heads
@@ -61,7 +61,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
key_states = repeat_kv(keys, num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attention_mask = torch.ones_like(attn_weights) * float("-inf")
- attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
+ attention_mask = torch.triu(attention_mask, diagonal=k_len - window_size + 1)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = attn_weights[..., :-window_size]
@@ -78,10 +78,12 @@ def score(
kwargs,
) -> torch.Tensor:
- bsz, num_key_value_heads, q_len, _ = keys.shape
+ bsz, num_key_value_heads, k_len, _ = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
- assert q_len > self.window_size, "Query length should be greater than the window size"
+ assert (
+ hidden_states.shape[1] > self.window_size
+ ), f"Query length {hidden_states.shape[1]} should be greater than the window size {self.window_size}"
if attentions is not None:
attn_weights = attentions[..., -self.window_size :, : -self.window_size]
@@ -94,7 +96,7 @@ def score(
scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
# Average per group (https://github.com/FasterDecoding/SnapKV/issues/22)
- scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size)
+ scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size)
scores = scores.mean(2)
# Add back the observation window. Use max score to make sure the window is not pruned.
diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py
index 4c0ac218..6d657389 100644
--- a/kvpress/presses/streaming_llm_press.py
+++ b/kvpress/presses/streaming_llm_press.py
@@ -44,9 +44,9 @@ def score(
kwargs,
) -> torch.Tensor:
- q_len = hidden_states.shape[1]
- assert q_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}"
- n_pruned = q_len - int(q_len * (1 - self.compression_ratio))
+ k_len = keys.shape[2]
+ assert k_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}"
+ n_pruned = k_len - int(k_len * (1 - self.compression_ratio))
scores = torch.ones_like(keys[..., 0])
scores[:, :, self.n_sink : self.n_sink + n_pruned] = 0
diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py
index a19379c2..f277bb5c 100644
--- a/kvpress/presses/think_press.py
+++ b/kvpress/presses/think_press.py
@@ -72,7 +72,7 @@ def compress(
return keys, values
# Compute scores per dimension
- bsz, num_key_value_heads, q_len, head_dim = keys.shape
+ bsz, num_key_value_heads, k_len, head_dim = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"])
@@ -84,7 +84,7 @@ def compress(
# Prune dimensions with the lowest scores by setting them to 0
n_pruned = int(head_dim * self.key_channel_compression_ratio)
indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
- indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
+ indices = indices.unsqueeze(2).expand(-1, -1, k_len, -1)
keys = keys.scatter_(-1, indices, 0)
return keys, values
diff --git a/kvpress/presses/utils.py b/kvpress/presses/utils.py
index 938c15ba..3614fa5b 100644
--- a/kvpress/presses/utils.py
+++ b/kvpress/presses/utils.py
@@ -3,6 +3,7 @@
import torch
from torch import nn
+from transformers import Cache, QuantizedCache
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
@@ -50,3 +51,22 @@ def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Te
query_states = module.q_norm(query_states)
return query_states
+
+
+def dequantize_layer(cache_layer) -> tuple[torch.Tensor, torch.Tensor]:
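+    """Dequantize and return the keys and values stored in a quantized cache layer."""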
+ keys = cache_layer._dequantize(cache_layer._quantized_keys)
+ values = cache_layer._dequantize(cache_layer._quantized_values)
+ return keys, values
+
+
+def extract_keys_and_values(cache: Cache, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Extracts the keys and values from a given cache layer,
+ handling both quantized and unquantized caches.
+ """
+ if isinstance(cache, QuantizedCache):
+ keys, values = dequantize_layer(cache.layers[layer_idx])
+ else:
+ keys = cache.layers[layer_idx].keys
+ values = cache.layers[layer_idx].values
+ return keys, values
diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb
new file mode 100644
index 00000000..55ccd8bd
--- /dev/null
+++ b/notebooks/kvpress_decoding_aime25.ipynb
@@ -0,0 +1,671 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d5b82a94ac3477fd",
+ "metadata": {},
+ "source": [
+ "# Using decoding compression on the AIME25 Math Dataset\n",
+ "\n",
+ "This notebook demonstrates how to compress during text generation.\n",
+ "We use `nvidia/OpenMath-Nemotron-7B` to solve math problems from the AIME25 dataset. For each problem, the model generates an answer in a boxed format (e.g., `\\boxed{42}`).\n",
+ "\n",
+ "To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.\n",
+ "Compression periodically reduces the cache size by keeping only the most relevant tokens, enabling efficient inference without sacrificing answer quality."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "4cd1c9e43c5ca1bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import requests\n",
+ "\n",
+ "import torch\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "from kvpress import (\n",
+ " ExpectedAttentionPress,\n",
+ " KnormPress,\n",
+ " ObservedAttentionPress,\n",
+ " RandomPress,\n",
+ " SnapKVPress,\n",
+ " StreamingLLMPress,\n",
+ " TOVAPress,\n",
+ ")\n",
+ "\n",
+ "from kvpress.presses.decoding_press import DecodingPress"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "dataset = load_dataset(\"math-ai/aime25\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "5e912ac7-e5d8-42ff-8f9e-855b8ae999a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample = dataset[\"test\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f8c2689d-e461-4608-aa69-a4811f71d639",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',\n",
+ " 'answer': '70',\n",
+ " 'id': '0'}"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "02b0af83-e71a-4129-9692-921d2bc16cb8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bd9b941a02394e7d936e018a25c4628b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cuda:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "device = \"cuda:0\"\n",
+ "ckpt = \"nvidia/OpenMath-Nemotron-7B\"\n",
+ "attn_implementation = \"flash_attention_2\"\n",
+ "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ccb1ff5e-5e16-4779-bbae-5781e8255345",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import DynamicCache"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "96e1f430-4ece-49dd-a5fb-68ae3c2ee8b8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n",
+ "Answer: 70\n",
+ "Prediction: \n",
+ "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n",
+ "\n",
+ "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n",
+ "\n",
+ "17_b = b + 7\n",
+ "\n",
+ "97_b = 9b + 7\n",
+ "\n",
+ "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n",
+ "\n",
+ "Let me write that as an equation:\n",
+ "\n",
+ "(9b + 7) ÷ (b + 7) = integer.\n",
+ "\n",
+ "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n",
+ "\n",
+ "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n",
+ "\n",
+ "9*(b + 7) = 9b + 63\n",
+ "\n",
+ "But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n",
+ "\n",
+ "9b +7 - (9b +63) = 7 - 63 = -56\n",
+ "\n",
+ "Therefore, 9b +7 = 9*(b +7) -56\n",
+ "\n",
+ "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n",
+ "\n",
+ "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n",
+ "\n",
+ "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n",
+ "\n",
+ "1. b > 9\n",
+ "\n",
+ "2. (b +7) divides 56.\n",
+ "\n",
+ "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n",
+ "\n",
+ "1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n",
+ "\n",
+ "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n",
+ "\n",
+ "If b +7 =28, then b=21\n",
+ "\n",
+ "If b +7=56, then b=49\n",
+ "\n",
+ "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n",
+ "\n",
+ "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n",
+ "\n",
+ "But wait, let me verify this. Let's check for b=21:\n",
+ "\n",
+ "17 in base 21 is 1*21 +7=28\n",
+ "\n",
+ "97 in base21 is 9*21 +7=196\n",
+ "\n",
+ "Check if 28 divides 196. 196 divided by 28 is 7. Yes, that's an integer. So that works.\n",
+ "\n",
+ "For b=49:\n",
+ "\n",
+ "17 in base49 is 1*49 +7=56\n",
+ "\n",
+ "97 in base49 is 9*49 +7=448\n",
+ "\n",
+ "448 divided by 56 is 8. That's also an integer. So that works too.\n",
+ "\n",
+ "So, both bases are valid. Therefore, the sum is 21 +49=70.\n",
+ "\n",
+ "Wait, but hold on. Let me check if there are any negative divisors. But since b is a base, it must be a positive integer greater than 9, so b +7 is positive. Therefore, we don't need to consider negative divisors. So, yes, only 28 and 56. Therefore, the answer is 70.\n",
+ "\n",
+ "But let me just make sure I didn't miss any divisors. Let's list all divisors again:\n",
+ "\n",
+ "Positive divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "Negative divisors would be -1, -2, -4, -7, -8, -14, -28, -56. But since b +7 must be positive (as b>9), we can ignore the negative ones. So, only 28 and 56. Therefore, the answer is 70.\n",
+ "\n",
+ "But wait, another thought: when we say that (9b +7) is divisible by (b +7), we can also approach this by setting up the equation (9b +7) = k*(b +7), where k is an integer. Then, solving for b:\n",
+ "\n",
+ "9b +7 = k*b +7k\n",
+ "\n",
+ "Bring terms with b to one side:\n",
+ "\n",
+ "9b -k*b = 7k -7\n",
+ "\n",
+ "b*(9 -k) = 7(k -1)\n",
+ "\n",
+ "Therefore, b = [7(k -1)] / (9 -k)\n",
+ "\n",
+ "Since b must be an integer greater than 9, we can look for integer values of k such that 9 -k divides 7(k -1), and the result is b>9.\n",
+ "\n",
+ "Let me see. Let's solve for k:\n",
+ "\n",
+ "b = 7(k -1)/(9 -k)\n",
+ "\n",
+ "We need 9 -k to divide 7(k -1). Let's denote d = 9 -k, so k =9 -d. Then:\n",
+ "\n",
+ "b =7( (9 -d) -1 ) / d =7(8 -d)/d =7*(8/d -1)\n",
+ "\n",
+ "Since b must be a positive integer greater than 9, d must be a positive divisor of 56 (since 7*(8/d -1) must be integer). Wait, maybe this approach complicates things. Let's try substituting possible integer values of k.\n",
+ "\n",
+ "Alternatively, since k must be an integer such that 9 -k divides 7(k -1). Let's note that 9 -k and k -1 are related. Let's set m =9 -k, so k =9 -m. Then:\n",
+ "\n",
+ "b =7*( (9 -m) -1 ) / m =7*(8 -m)/m =7*(8/m -1)\n",
+ "\n",
+ "So, 8/m must be an integer because b must be integer. Therefore, m must be a positive divisor of 8. But m =9 -k, and since k is an integer, m can be positive or negative. However, since b must be positive, let's see:\n",
+ "\n",
+ "If m is positive, then 8/m -1 must be positive as well because 7*(8/m -1) =b>9. So:\n",
+ "\n",
+ "8/m -1 >0 =>8/m >1 =>m <8. But m is a positive divisor of 8. The positive divisors of 8 are 1,2,4,8. So m can be 1,2,4,8. But m must be less than 8 (from 8/m >1). So m=1,2,4.\n",
+ "\n",
+ "Wait, if m=8, then 8/m -1=1 -1=0, so b=0, which is invalid. So m=1,2,4.\n",
+ "\n",
+ "Let's check each:\n",
+ "\n",
+ "Case 1: m=1\n",
+ "\n",
+ "Then, b=7*(8/1 -1)=7*(7)=49. So b=49. Which is valid, as before.\n",
+ "\n",
+ "Case 2: m=2\n",
+ "\n",
+ "b=7*(8/2 -1)=7*(4 -1)=7*3=21. So b=21. Also valid.\n",
+ "\n",
+ "Case3: m=4\n",
+ "\n",
+ "b=7*(8/4 -1)=7*(2 -1)=7*1=7. But b=7 is not greater than 9, so invalid.\n",
+ "\n",
+ "So only m=1 and m=2 give valid bases, which are 49 and 21. So sum is 70. So same result as before.\n",
+ "\n",
+ "Alternatively, if m is negative, then m divides 8, but m negative. Let's see:\n",
+ "\n",
+ "If m is a negative divisor of 8, then m can be -1,-2,-4,-8.\n",
+ "\n",
+ "Then, let's check:\n",
+ "\n",
+ "Case m=-1:\n",
+ "\n",
+ "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. Negative base is invalid.\n",
+ "\n",
+ "Similarly, m=-2:\n",
+ "\n",
+ "b=7*(8/(-2) -1)=7*(-4 -1)=7*(-5)=-35. Invalid.\n",
+ "\n",
+ "m=-4:\n",
+ "\n",
+ "b=7*(8/(-4) -1)=7*(-2 -1)=7*(-3)=-21. Invalid.\n",
+ "\n",
+ "m=-8:\n",
+ "\n",
+ "b=7*(8/(-8) -1)=7*(-1 -1)=7*(-2)=-14. Invalid.\n",
+ "\n",
+ "So all negative m give invalid bases. Therefore, only m=1 and m=2, leading to b=49 and 21. So sum is 70.\n",
+ "\n",
+ "Therefore, the answer is 70. So I think that's correct. Let me just check once more.\n",
+ "\n",
+ "Alternatively, maybe there's another approach. Let's consider that (9b +7)/(b +7) must be integer. Let's call this integer k. So:\n",
+ "\n",
+ "(9b +7) =k*(b +7)\n",
+ "\n",
+ "So, 9b +7 =k*b +7k\n",
+ "\n",
+ "Rearranged:\n",
+ "\n",
+ "(9 -k)*b =7k -7\n",
+ "\n",
+ "So, b=(7k -7)/(9 -k)\n",
+ "\n",
+ "We need b to be an integer greater than 9. So, (7k -7) must be divisible by (9 -k). Let's factor 7:\n",
+ "\n",
+ "7(k -1)/(9 -k) = -7(k -1)/(k -9)\n",
+ "\n",
+ "So, b= -7(k -1)/(k -9)\n",
+ "\n",
+ "We can write this as 7(1 -k)/(k -9). Hmm, not sure if that helps. Let's look for integer k such that (9 -k) divides 7(k -1). Let's denote d=9 -k, so k=9 -d. Then:\n",
+ "\n",
+ "b=7(k -1)/d=7(8 -d)/d=7*(8/d -1)\n",
+ "\n",
+ "So, 8/d must be integer, so d must divide 8. So d is a divisor of 8. Since d=9 -k, and k is an integer, d can be any integer (positive or negative) that divides 8. But since b must be positive and greater than 9, let's see:\n",
+ "\n",
+ "If d is positive, then 8/d must be integer, so d is a positive divisor of 8: 1,2,4,8.\n",
+ "\n",
+ "Then, for each d:\n",
+ "\n",
+ "d=1:\n",
+ "\n",
+ "b=7*(8/1 -1)=7*7=49. Valid.\n",
+ "\n",
+ "d=2:\n",
+ "\n",
+ "b=7*(8/2 -1)=7*(4 -1)=21. Valid.\n",
+ "\n",
+ "d=4:\n",
+ "\n",
+ "b=7*(8/4 -1)=7*(2 -1)=7. Not valid (b>9 required).\n",
+ "\n",
+ "d=8:\n",
+ "\n",
+ "b=7*(8/8 -1)=7*(1 -1)=0. Invalid.\n",
+ "\n",
+ "If d is negative, then d divides 8, so d=-1,-2,-4,-8.\n",
+ "\n",
+ "For d=-1:\n",
+ "\n",
+ "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. Invalid.\n",
+ "\n",
+ "Similarly, other negative d's give negative b's. So only d=1 and d=2 give valid bases. So again, 49 and 21. Sum is 70.\n",
+ "\n",
+ "Therefore, I think the answer is 70. So I can be confident that the sum is 70.\n",
+ "\n",
+ "**Final Answer**\n",
+ "\\boxed{70}\n",
+ "To solve the problem, we need to find all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\).\n",
+ "\n",
+ "First, we convert the numbers from base \\( b \\) to base 10:\n",
+ "- \\( 17_b = 1 \\cdot b + 7 = b + 7 \\)\n",
+ "- \\( 97_b = 9 \\cdot b + 7 = 9b + 7 \\)\n",
+ "\n",
+ "We need \\( b + 7 \\) to divide \\( 9b + 7 \\). This can be expressed as:\n",
+ "\\[\n",
+ "\\frac{9b + 7}{b + 7} = k \\quad \\text{for some integer } k\n",
+ "\\]\n",
+ "\n",
+ "Rewriting the equation, we get:\n",
+ "\\[\n",
+ "9b + 7 = k(b + 7)\n",
+ "\\]\n",
+ "\\[\n",
+ "9b + 7 = kb + 7k\n",
+ "\\]\n",
+ "\\[\n",
+ "9b - kb = 7k - 7\n",
+ "\\]\n",
+ "\\[\n",
+ "b(9 - k) = 7(k - 1)\n",
+ "\\]\n",
+ "\\[\n",
+ "b = \\frac{7(k - 1)}{9 - k}\n",
+ "\\]\n",
+ "\n",
+ "For \\( b \\) to be an integer, \\( 9 - k \\) must be a divisor of \\( 7(k - 1) \\). We need to find the values of \\( k \\) such that \\( b > 9 \\).\n",
+ "\n",
+ "Let's consider the possible values of \\( k \\) by checking the divisors of 56 (since \\( 9b + 7 = 9(b + 7) - 56 \\)):\n",
+ "\n",
+ "The positive divisors of 56 are: 1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "Since \\( b > 9 \\), \\( b + 7 \\) must be greater than 16. Therefore, the valid divisors are 28 and 56.\n",
+ "\n",
+ "1. If \\( b + 7 = 28 \\):\n",
+ " \\[\n",
+ " b = 28 - 7 = 21\n",
+ " \\]\n",
+ "\n",
+ "2. If \\( b + 7 = 56 \\):\n",
+ " \\[\n",
+ " b = 56 - 7 = 49\n",
+ " \\]\n",
+ "\n",
+ "We verify these values:\n",
+ "- For \\( b = 21 \\):\n",
+ " \\[\n",
+ " 17_{21} = 21 + 7 = 28\n",
+ " \\]\n",
+ " \\[\n",
+ " 97_{21} = 9 \\cdot 21 + 7 = 196\n",
+ " \\]\n",
+ " \\[\n",
+ " 196 \\div 28 = 7 \\quad \\text{(an integer)}\n",
+ " \\]\n",
+ "\n",
+ "- For \\( b = 49 \\):\n",
+ " \\[\n",
+ " 17_{49} = 49 + 7 = 56\n",
+ " \\]\n",
+ " \\[\n",
+ " 97_{49} = 9 \\cdot 49 + 7 = 448\n",
+ " \\]\n",
+ " \\[\n",
+ " 448 \\div 56 = 8 \\quad \\text{(an integer)}\n",
+ " \\]\n",
+ "\n",
+ "Both values of \\( b \\) are valid. The sum of these bases is:\n",
+ "\\[\n",
+ "21 + 49 = 70\n",
+ "\\]\n",
+ "\n",
+ "Thus, the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\) is:\n",
+ "\\[\n",
+ "\\boxed{70}\n",
+ "\\]\n",
+ "Cache size: torch.Size([1, 4, 3781, 128])\n"
+ ]
+ }
+ ],
+ "source": [
+ "cache = DynamicCache()\n",
+ "question = sample[\"problem\"]\n",
+ "true_answer = sample[\"answer\"]\n",
+ "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=16_000)[\"answer\"]\n",
+ "\n",
+ "print(f\"Question: {question}\")\n",
+ "print(f\"Answer: {true_answer}\")\n",
+ "print(f\"Prediction: {pred_answer}\")\n",
+ "print(f\"Cache size: {cache.layers[0].keys.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "7f4fa366-7e62-443a-b5c8-274128fe6237",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n",
+ "Answer: 70\n",
+ "Prediction: \n",
+ "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n",
+ "\n",
+ "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n",
+ "\n",
+ "17_b = b + 7\n",
+ "\n",
+ "97_b = 9b + 7\n",
+ "\n",
+ "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n",
+ "\n",
+ "Let me write that as an equation:\n",
+ "\n",
+ "(9b + 7) ÷ (b + 7) = integer.\n",
+ "\n",
+ "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n",
+ "\n",
+ "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n",
+ "\n",
+ "9*(b + 7) = 9b + 63\n",
+ "\n",
+ "But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n",
+ "\n",
+ "9b +7 - (9b +63) = 7 - 63 = -56\n",
+ "\n",
+ "Therefore, 9b +7 = 9*(b +7) -56\n",
+ "\n",
+ "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n",
+ "\n",
+ "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n",
+ "\n",
+ "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n",
+ "\n",
+ "1. b > 9\n",
+ "\n",
+ "2. (b +7) divides 56.\n",
+ "\n",
+ "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n",
+ "\n",
+ "1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n",
+ "\n",
+ "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\n",
+ "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n",
+ "\n",
+ "If b +7 =28, then b=21\n",
+ "\n",
+ "If b +7=56, then b=49\n",
+ "\n",
+ "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n",
+ "\n",
+ "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n",
+ "\n",
+ "But wait, let me verify this. Let's check for b=21:\n",
+ "\n",
+ "17 in base 21 is 21 +7=28\n",
+ "\n",
+ "97 in base21 is 9*21 +7=189 +7=196\n",
+ "\n",
+ "196 divided by28 is 7, which is an integer. So that works.\n",
+ "\n",
+ "For b=49:\n",
+ "\n",
+ "17 in base49 is 49 +7=56\n",
+ "\n",
+ "97 in base49 is 9*49 +7=441 +7=448\n",
+ "\n",
+ "448 divided by56 is 8, which is an integer. So that works too.\n",
+ "\n",
+ "So both bases are valid. Therefore, the sum is 21 +49=70.\n",
+ "\n",
+ "Wait, but let me check if there are any other divisors. Wait, 56 is a divisor of 56, but 56 is 56. So, if b +7=56, then b=49. Which we already have. Similarly, 28 gives b=21. Are there any negative divisors? But since b is a base, it must be a positive integer greater than 9, so negative divisors don't make sense here. So, yes, only 28 and 56. Therefore, the answer is 70.\n",
+ "\n",
+ "But wait, let me check if there's a mistake here. Let me think again. The problem says \"all bases b>9\". So, the possible divisors of 56 are 1,2,4,7,8,14,28,56. But since b +7 must be one of these, and b>9, so b +7 must be at least 17. So, the possible divisors are 28 and 56. Therefore, b=21 and 49. So sum is 70. That seems correct.\n",
+ "\n",
+ "But let me check if there's another way. Let me think. Suppose we have (9b +7)/(b +7) must be integer. Let me compute this fraction:\n",
+ "\n",
+ "(9b +7)/(b +7) = 9 - 56/(b +7). Wait, how?\n",
+ "\n",
+ "Let me do the division:\n",
+ "\n",
+ "Divide 9b +7 by b +7.\n",
+ "\n",
+ "Dividing 9b by b gives 9. Multiply (b +7) by 9: 9b +63. Subtract that from 9b +7: (9b +7) - (9b +63) = -56. So, the division gives 9 with a remainder of -56. Therefore, (9b +7) = 9*(b +7) -56. Therefore, (9b +7)/(b +7) = 9 -56/(b +7). For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. Which is the same conclusion as before.\n",
+ "\n",
+ "Therefore, the possible values of (b +7) are the positive divisors of 56 that are greater than or equal to 17 (since b>9 implies b +7>16). The divisors of 56 are 1,2,4,7,8,14,28,56. So, the ones greater than 16 are 28 and 56. Therefore, b +7=28 gives b=21, and b +7=56 gives b=49. So, sum is 21 +49=70.\n",
+ "\n",
+ "Therefore, the answer is 70. Let me check once more with another example. Suppose b=21:\n",
+ "\n",
+ "17 in base21 is 2*10 +7=27? Wait, no. Wait, in base b, 17_b is 1*b +7. So, 1*21 +7=28. 97_b is 9*21 +7=189 +7=196. 196 divided by28 is 7. Correct.\n",
+ "\n",
+ "For b=49:\n",
+ "\n",
+ "17_b is 1*49 +7=56. 97_b is 9*49 +7=441 +7=448. 448 divided by56 is 8. Correct.\n",
+ "\n",
+ "So, both bases work. Therefore, the sum is 21 +49=70. So, the answer is 70.\n",
+ "\n",
+ "**Final Answer**\n",
+ "\\boxed{70}\n",
+ "To solve the problem, we need to find all bases \\( b > 9 \\) such that \\( 17_b \\) divides \\( 97_b \\).\n",
+ "\n",
+ "First, we convert the numbers from base \\( b \\) to base 10:\n",
+ "- \\( 17_b \\) in base 10 is \\( 1 \\cdot b + 7 = b + 7 \\).\n",
+ "- \\( 97_b \\) in base 10 is \\( 9 \\cdot b + 7 = 9b + 7 \\).\n",
+ "\n",
+ "We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n",
+ "\\[\n",
+ "\\frac{9b + 7}{b + 7} \\text{ must be an integer.}\n",
+ "\\]\n",
+ "\n",
+ "Rewriting the fraction:\n",
+ "\\[\n",
+ "\\frac{9b + 7}{b + 7} = \\frac{9(b + 7) - 56}{b + 7} = 9 - \\frac{56}{b + 7}.\n",
+ "\\]\n",
+ "\n",
+ "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. Therefore, \\( b + 7 \\) must be a divisor of 56. The divisors of 56 are:\n",
+ "\\[\n",
+ "1, 2, 4, 7, 8, 14, 28, 56.\n",
+ "\\]\n",
+ "\n",
+ "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). The valid divisors of 56 that are greater than 16 are:\n",
+ "\\[\n",
+ "28 \\quad \\text{and} \\quad 56.\n",
+ "\\]\n",
+ "\n",
+ "Thus, we have:\n",
+ "\\[\n",
+ "b + 7 = 28 \\quad \\Rightarrow \\quad b = 21,\n",
+ "\\]\n",
+ "\\[\n",
+ "b + 7 = 56 \\quad \\Rightarrow \\quad b = 49.\n",
+ "\\]\n",
+ "\n",
+ "We verify these bases:\n",
+ "- For \\( b = 21 \\):\n",
+ " \\[\n",
+ " 17_{21} = 21 + 7 = 28,\n",
+ " \\]\n",
+ " \\[\n",
+ " 97_{21} = 9 \\cdot 21 + 7 = 196.\n",
+ " \\]\n",
+ " Since \\( 196 \\div 28 = 7 \\), \\( 17_{21} \\) divides \\( 97_{21} \\).\n",
+ "\n",
+ "- For \\( b = 49 \\):\n",
+ " \\[\n",
+ " 17_{49} = 49 + 7 = 56,\n",
+ " \\]\n",
+ " \\[\n",
+ " 97_{49} = 9 \\cdot 49 + 7 = 448.\n",
+ " \\]\n",
+ " Since \\( 448 \\div 56 = 8 \\), \\( 17_{49} \\) divides \\( 97_{49} \\).\n",
+ "\n",
+ "The valid bases are \\( b = 21 \\) and \\( b = 49 \\). Summing these bases:\n",
+ "\\[\n",
+ "21 + 49 = 70.\n",
+ "\\]\n",
+ "\n",
+ "Thus, the sum of all bases \\( b > 9 \\) where \\( 17_b \\) divides \\( 97_b \\) is:\n",
+ "\\[\n",
+ "\\boxed{70}.\n",
+ "\\]\n",
+ "Cache size: torch.Size([1, 4, 1182, 128])\n"
+ ]
+ }
+ ],
+ "source": [
+ "compression_interval = 1024 # compress every compression_steps\n",
+ "target_size = 512 # number of tokens to keep after compression. Note that actual cache size lies in [target_size, compression_interval]\n",
+ "\n",
+ "\n",
+ "press = DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=compression_interval, target_size=target_size)\n",
+ "\n",
+ "cache = DynamicCache()\n",
+ "question = sample[\"problem\"]\n",
+ "true_answer = sample[\"answer\"]\n",
+ "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=16_000)[\"answer\"]\n",
+ "\n",
+ "print(f\"Question: {question}\")\n",
+ "print(f\"Answer: {true_answer}\")\n",
+ "print(f\"Prediction: {pred_answer}\")\n",
+ "print(f\"Cache size: {cache.layers[0].keys.shape}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "kvpress",
+ "language": "python",
+ "name": "kvpress"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 9ebceb19..3323ee60 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -94,7 +94,7 @@ def kv_press_llama3_2_flash_attn_pipeline():
"kv-press-text-generation",
model=ckpt,
device=device,
- model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16},
+ model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
@@ -108,6 +108,6 @@ def kv_press_qwen3_flash_attn_pipeline():
"kv-press-text-generation",
model=ckpt,
device=device,
- model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16},
+ model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
index 43cfe290..155cf6dd 100644
--- a/tests/integration/test_ruler.py
+++ b/tests/integration/test_ruler.py
@@ -6,6 +6,7 @@
import torch
from transformers import DynamicCache, QuantoQuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
+
from kvpress import QFilterPress
from tests.default_presses import default_presses
from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline # noqa: F401
@@ -58,12 +59,9 @@ def test_ruler_is_correct(
# QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class.
return
else:
- pred_answer = kv_press_qwen3_flash_attn_pipeline(
- context,
- question=question,
- press=press,
- cache=cache
- )["answer"]
+ pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
+ "answer"
+ ]
assert true_answer in pred_answer
@@ -102,10 +100,7 @@ def test_ruler_is_correct_for_qfilter(
question = df_ruler.iloc[idx]["question"]
true_answer = df_ruler.iloc[idx]["answer"][0]
- pred_answer = kv_press_llama3_2_flash_attn_pipeline(
- context,
- question=question,
- press=press,
- cache=cache
- )["answer"]
+ pred_answer = kv_press_llama3_2_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
+ "answer"
+ ]
assert true_answer in pred_answer
diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py
new file mode 100644
index 00000000..b36dae4c
--- /dev/null
+++ b/tests/test_decoding_compression.py
@@ -0,0 +1,301 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Test script to verify that DecodingPress actually compresses during decoding.
+"""
+import logging
+
+import pytest
+import torch
+from transformers import DynamicCache, pipeline
+
+from kvpress import PyramidKVPress, ScorerPress
+from kvpress.presses.decoding_press import DecodingPress
+from kvpress.presses.knorm_press import KnormPress
+from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
+from tests.default_presses import default_presses
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.parametrize("token_buffer_size", [32, 64, 128])
+def test_decoding_compression(token_buffer_size):
+ """Test that DecodingPress compresses the cache during decoding."""
+
+ # Initialize pipeline with a small model
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Create a DecodingPress with KnormPress
+ press = DecodingPress(
+        base_press=KnormPress(compression_ratio=0.5),  # base scorer; DecodingPress derives the actual ratio from target_size
+ compression_interval=4, # Compress every 4 tokens
+ target_size=token_buffer_size,
+ )
+
+ # Create cache
+ cache = DynamicCache()
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 10 # Repeat for longer context
+ question = "What animal jumps over the dog?"
+
+ # Run pipeline
+ pipe(context, question=question, press=press, cache=cache, max_new_tokens=20)
+
+ # Assert that all layers have the expected cache size
+ for layer_idx, cache_layer in enumerate(cache.layers):
+ layer_seq_len = cache_layer.keys.shape[2]
+        # Allow for the compression interval: the cache can be up to compression_interval - 1 tokens larger than the target
+ max_expected_size = token_buffer_size + press.compression_interval - 1
+ assert layer_seq_len <= max_expected_size, (
+ f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} "
+ f"and {max_expected_size}, but got {layer_seq_len}"
+ )
+
+
+def test_prefill_decoding_press_calls_both_phases():
+ """Test that PrefillDecodingPress calls both prefilling and decoding presses."""
+
+ # Initialize pipeline
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Create PrefillDecodingPress with both presses
+ combined_press = PrefillDecodingPress(
+        prefilling_press=KnormPress(compression_ratio=0.6),  # Remove 60% of the KV cache during prefill
+ decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48),
+ )
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 12 # Longer context
+ question = "What animal jumps over the dog?"
+
+ # Run pipeline
+ cache = DynamicCache()
+ pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15)
+
+ # Check that cache was compressed during both phases
+ # Final cache should be compressed to decoding press target size
+ for layer_idx, cache_layer in enumerate(cache.layers):
+ layer_seq_len = cache_layer.keys.shape[2]
+        # Allow for the compression interval: the cache can be up to compression_steps - 1 tokens larger than the target
+        target_size = 48  # target_size of the decoding press
+ compression_steps = 3 # from the decoding press configuration
+ max_expected_size = target_size + compression_steps - 1
+ assert target_size <= layer_seq_len <= max_expected_size, (
+ f"Layer {layer_idx}: Expected final cache size to be between {target_size} "
+ f"and {max_expected_size} (decoding target), but got {layer_seq_len}"
+ )
+
+
+def test_decoding_press_without_prefill():
+ """Test that DecodingPress works correctly when used standalone (no prefill compression)."""
+
+ # Initialize pipeline
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Create DecodingPress only
+ decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 8
+ question = "What animal jumps over the dog?"
+
+ # Run pipeline
+ cache = DynamicCache()
+ pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25)
+
+ # Check that cache was compressed during decoding
+ for layer_idx, cache_layer in enumerate(cache.layers):
+ layer_seq_len = cache_layer.keys.shape[2]
+        # Allow for the compression interval: the cache can be up to compression_steps - 1 tokens larger than the target
+ target_size = 64
+ compression_steps = 5 # from the decoding press configuration
+ max_expected_size = target_size + compression_steps - 1
+ assert target_size <= layer_seq_len <= max_expected_size, (
+ f"Layer {layer_idx}: Expected cache size to be between {target_size} "
+ f"and {max_expected_size}, but got {layer_seq_len}"
+ )
+
+
+def test_prefill_decoding_press_decoding_only():
+ """Test PrefillDecodingPress with only decoding press (no prefill compression)."""
+
+ # Initialize pipeline
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Create PrefillDecodingPress with only decoding press
+ combined_press = PrefillDecodingPress(
+ prefilling_press=None,
+ decoding_press=DecodingPress(
+ base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56
+ ),
+ )
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 9
+ question = "What animal jumps over the dog?"
+
+ # Run pipeline
+ cache = DynamicCache()
+ pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12)
+
+ # Check that only decoding compression was applied
+ for layer_idx, cache_layer in enumerate(cache.layers):
+ layer_seq_len = cache_layer.keys.shape[2]
+        # Allow for the compression interval: the cache can be up to compression_steps - 1 tokens larger than the target
+ target_size = 56
+ compression_steps = 4 # from the decoding press configuration
+ max_expected_size = target_size + compression_steps - 1
+ assert target_size <= layer_seq_len <= max_expected_size, (
+ f"Layer {layer_idx}: Expected cache size to be between {target_size} "
+ f"and {max_expected_size}, but got {layer_seq_len}"
+ )
+
+
+def test_decoding_press_equivalence():
+ """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only."""
+
+ # Set random seed for reproducibility
+ torch.manual_seed(42)
+
+ # Initialize pipeline
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Create standalone decoding press
+ decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52)
+
+ # Create PrefillDecodingPress with only decoding press
+ combined_press = PrefillDecodingPress(
+ prefilling_press=None,
+ decoding_press=DecodingPress(
+ base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52
+ ),
+ )
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 7
+ question = "What animal jumps over the dog?"
+
+ # Run with standalone decoding press
+ cache1 = DynamicCache()
+ result1 = pipe(context, question=question, press=decoding_press, cache=cache1, max_new_tokens=10)
+
+ # Run with combined press (decoding only)
+ cache2 = DynamicCache()
+ result2 = pipe(context, question=question, press=combined_press, cache=cache2, max_new_tokens=10)
+
+ # Compare cache sizes (should be identical)
+ for layer_idx in range(len(cache1.layers)):
+ cache1_size = cache1.layers[layer_idx].keys.shape[2]
+ cache2_size = cache2.layers[layer_idx].keys.shape[2]
+ assert cache1_size == cache2_size, (
+ f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != "
+ f"combined press cache size {cache2_size}"
+ )
+
+ # Compare generated text results (should be identical)
+ assert result1["answer"] == result2["answer"], (
+ f"Generated answers differ:\n"
+ f"Standalone decoding: '{result1['answer']}'\n"
+ f"Combined press: '{result2['answer']}'"
+ )
+
+
+"""
+E AttributeError: 'QFilterPress' object has no attribute 'q_filters'
+E Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12
+> query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)
+E RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12
+"""
+
+
+@pytest.mark.parametrize("press_config", default_presses)
+def test_all_presses_work_with_decoding_press(press_config):
+ """Test that all default presses work as base presses for DecodingPress."""
+
+ # Initialize pipeline
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ # Get press class and use the first (easier) configuration
+ press_cls = press_config["cls"]
+ press_kwargs = press_config["kwargs"][0] # Use easier compression settings
+
+ base_press = press_cls(**press_kwargs)
+ if not isinstance(base_press, ScorerPress):
+ logger.info(f"Press {press_cls.__name__} is not a ScorerPress, skipping test")
+ return
+ if isinstance(base_press, (PyramidKVPress)):
+        # PyramidKVPress assigns per-layer (pyramid-shaped) budgets, which is incompatible with a uniform target_size=48
+ logger.info(f"Press {press_cls.__name__} is not supported, skipping test")
+ return
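+    # Some presses (e.g. QFilterPress) require model-dependent initialization before scoring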
+ if hasattr(base_press, "__post_init_from_model__"):
+ base_press.__post_init_from_model__(pipe.model)
+
+ # Create DecodingPress with this base press
+ decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48)
+
+ # Test context and question
+ context = "The quick brown fox jumps over the lazy dog. " * 8
+ question = "What animal jumps over the dog?"
+
+ # Run pipeline
+ cache = DynamicCache()
+ result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15)
+
+ # Verify compression worked
+ assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}"
+
+ # Check that cache was compressed (allow some tolerance for rounding)
+ for layer_idx, cache_layer in enumerate(cache.layers):
+ layer_seq_len = cache_layer.keys.shape[2]
+        # Allow for the compression interval: the cache can be up to compression_steps - 1 tokens larger than the target
+ target_size = 48
+ compression_steps = 3 # from the decoding press configuration
+ max_expected_size = target_size + compression_steps - 1
+ assert (
+ target_size <= layer_seq_len <= max_expected_size
+ ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501
+
+
+def test_compression_actually_reduces_memory():
+ """Test that compression actually reduces memory usage compared to no compression."""
+
+ pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
+
+ context = "The quick brown fox jumps over the lazy dog. " * 15 # Long context
+ question = "What animal jumps over the dog?"
+
+ # Run without compression
+ cache_uncompressed = DynamicCache()
+ result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25)
+
+ # Run with compression
+ press = DecodingPress(
+ base_press=KnormPress(compression_ratio=0.3), # Aggressive compression
+ compression_interval=3,
+ target_size=40,
+ )
+ cache_compressed = DynamicCache()
+ result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25)
+
+ # Calculate memory usage (approximate)
+ uncompressed_memory = sum(
+ (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size()
+ for cache_layer in cache_uncompressed.layers
+ )
+ compressed_memory = sum(
+ (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size()
+ for cache_layer in cache_compressed.layers
+ )
+
+ # Compression should significantly reduce memory usage
+ compression_ratio = compressed_memory / uncompressed_memory
+ assert compression_ratio < 0.6, (
+ f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} "
+ f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)"
+ )
+
+ # Both should still generate reasonable answers
+ assert len(result_uncompressed["answer"]) > 0
+ assert len(result_compressed["answer"]) > 0
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index c3e6f0b4..f0686cdd 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -155,7 +155,9 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811
keys = [layer.keys.clone() for layer in past_key_values.layers]
values = [layer.values.clone() for layer in past_key_values.layers]
+ cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))]
compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10)
+ compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths)
assert past_key_values.get_seq_length() == seq_len
assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)])
assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)])