diff --git a/README.md b/README.md index d47c63a2..a16bfdc6 100644 --- a/README.md +++ b/README.md @@ -67,15 +67,47 @@ answer = pipe(context, question=question, press=press)["answer"] In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)). -> [!IMPORTANT] -> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequence (100k - 1M tokens) which are essentially long context prompts. This would typically apply to improving prompt caching systems. +
+Decoding Compression
+
+By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:
-> [!NOTE]
-> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}` as this press requires to materialize the attention weights
+- `base_press`: Any `ScorerPress` (e.g., `KnormPress`, `CriticalKVPress`)
+- `compression_interval`: Number of decoding steps between compressions (default: 128)
+- `target_size`: Target size of the cache after compression (default: 1024)
+- `hidden_states_buffer_size`: Number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.
-## Contributing
+Unlike most presses, which are configured with a compression ratio, `DecodingPress` uses a `target_size`: the cache is compressed every `compression_interval` steps, and the compression ratio is computed automatically so that the cache size after compression equals `target_size`.
-We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
+Example of decoding compression:
+
+```python
+from transformers import pipeline
+from kvpress import DecodingPress, KnormPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps down to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+> Not all existing presses are compatible with `DecodingPress` due to fundamental differences in how compression works during decoding versus prefilling. In particular, only `ScorerPress` instances are supported as base presses.
+
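+To make the `target_size` behavior concrete, here is a minimal sketch of how a compression ratio can be derived from a target cache size. The standalone helper below is only illustrative; the actual press computes the ratio internally (see `_find_target_compression_ratio` in `kvpress/presses/decoding_press.py`), refining it with a short binary search to account for integer rounding:
+
+```python
+def ratio_for_target_size(current_len: int, target_size: int) -> float:
+    # No compression needed if the cache is already at or below the target
+    if current_len <= target_size:
+        return 0.0
+    # Prune everything beyond target_size tokens
+    return 1.0 - target_size / current_len
+
+# e.g. a cache that has grown to 600 tokens with target_size=512
+# is compressed with ratio 1 - 512/600 ~= 0.147, keeping 512 tokens
+print(ratio_for_target_size(600, 512))
+```
+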
 ## Available presses
@@ -111,6 +143,8 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
 - `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
 - `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
+- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows compression during decoding; see the decoding compression section in this README.
+- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows compression during both prefilling and decoding.
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
@@ -131,6 +165,7 @@ Below we report the average performance on the RULER dataset with 4k context len
 Leaderboard

+ ## Quantization We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline: @@ -149,6 +184,10 @@ By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] > To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`). +## Contributing + +We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. + ## FAQ
@@ -224,4 +263,80 @@ with press(model): However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once. +
+ + + +
+
+### Can I combine compression during prefilling and decoding?
+
+Yes. `PrefillDecodingPress` combines separate presses for the prefilling and decoding phases.
+
+**Parameters:**
+- `prefilling_press`: Press used during the prefill phase
+- `decoding_press`: Press used during the decoding phase
+
+## Usage Examples
+
+### Basic Decoding Compression
+
+```python
+from transformers import pipeline
+from kvpress import DecodingPress, KnormPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Create a decoding press that compresses every 10 steps down to 512 tokens
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=10,
+    target_size=512
+)
+
+# Use with pipeline
+context = "A very long text you want to compress during generation"
+question = "Tell me a long story about this context"
+response = pipe(context, question=question, press=decoding_press)["answer"]
+```
+
+### Combined Prefill + Decoding Compression
+
+```python
+from transformers import pipeline
+from kvpress import CriticalKVPress, KnormPress
+from kvpress import DecodingPress, PrefillDecodingPress
+
+# Initialize the pipeline
+device = "cuda:0"
+model = "meta-llama/Llama-3.1-8B-Instruct"
+model_kwargs = {"attn_implementation": "flash_attention_2"}
+pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
+
+# Different strategies for prefill vs decoding
+prefill_press = CriticalKVPress(KnormPress())
+decoding_press = DecodingPress(
+    base_press=KnormPress(),
+    compression_interval=5,
+    target_size=256
+)
+
+# Combine them
+combined_press = PrefillDecodingPress(
+    prefilling_press=prefill_press,
+    decoding_press=decoding_press
+)
+
+context = "A very long context that will be compressed during prefill"
+question = "Generate a detailed analysis that will be compressed during decoding"
+response = pipe(context, question=question, press=combined_press)["answer"]
+```
+
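+Note that both arguments of `PrefillDecodingPress` are optional (see `kvpress/presses/prefill_decoding_press.py`), so it can also be used for decoding-only compression by leaving `prefilling_press` unset. A minimal sketch, reusing `pipe`, `context`, and `question` from the examples above:
+
+```python
+from kvpress import DecodingPress, KnormPress, PrefillDecodingPress
+
+# No compression during prefill, periodic compression down to 512 tokens during generation
+decoding_only_press = PrefillDecodingPress(
+    decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=10, target_size=512),
+)
+
+response = pipe(context, question=question, press=decoding_only_press)["answer"]
+```
+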
\ No newline at end of file diff --git a/kvpress/__init__.py b/kvpress/__init__.py index cccd1f02..19165465 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -1,68 +1,74 @@ -# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - - -from kvpress.attention_patch import patch_attention_functions -from kvpress.pipeline import KVPressTextGenerationPipeline -from kvpress.presses.adakv_press import AdaKVPress -from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.block_press import BlockPress -from kvpress.presses.chunk_press import ChunkPress -from kvpress.presses.chunkkv_press import ChunkKVPress -from kvpress.presses.composed_press import ComposedPress -from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress -from kvpress.presses.duo_attention_press import DuoAttentionPress -from kvpress.presses.expected_attention_press import ExpectedAttentionPress -from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress -from kvpress.presses.finch_press import FinchPress -from kvpress.presses.key_rerotation_press import KeyRerotationPress -from kvpress.presses.keydiff_press import KeyDiffPress -from kvpress.presses.knorm_press import KnormPress -from kvpress.presses.kvzip_press import KVzipPress -from kvpress.presses.lagkv_press import LagKVPress -from kvpress.presses.observed_attention_press import ObservedAttentionPress -from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress -from kvpress.presses.pyramidkv_press import PyramidKVPress -from kvpress.presses.qfilter_press import QFilterPress -from kvpress.presses.random_press import RandomPress -from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.simlayerkv_press import SimLayerKVPress -from kvpress.presses.snapkv_press import SnapKVPress -from kvpress.presses.streaming_llm_press import StreamingLLMPress -from kvpress.presses.think_press import ThinKPress -from kvpress.presses.tova_press import TOVAPress - -# Patch the attention functions to support head-wise compression -patch_attention_functions() - -__all__ = [ - "CriticalAdaKVPress", - "CriticalKVPress", - "AdaKVPress", - "BasePress", - "ComposedPress", - "ScorerPress", - "ExpectedAttentionPress", - "KnormPress", - "ObservedAttentionPress", - "RandomPress", - "SimLayerKVPress", - "SnapKVPress", - "StreamingLLMPress", - "ThinKPress", - "TOVAPress", - "KVPressTextGenerationPipeline", - "PerLayerCompressionPress", - "KeyRerotationPress", - "ChunkPress", - "DuoAttentionPress", - "ChunkKVPress", - "QFilterPress", - "PyramidKVPress", - "FinchPress", - "LagKVPress", - "BlockPress", - "KeyDiffPress", - "KVzipPress", - "ExpectedAttentionStatsPress", -] +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +from kvpress.attention_patch import patch_attention_functions +from kvpress.pipeline import KVPressTextGenerationPipeline +from kvpress.presses.adakv_press import AdaKVPress +from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress +from kvpress.presses.block_press import BlockPress +from kvpress.presses.chunk_press import ChunkPress +from kvpress.presses.chunkkv_press import ChunkKVPress +from kvpress.presses.composed_press import ComposedPress +from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.duo_attention_press import DuoAttentionPress +from kvpress.presses.expected_attention_press import ExpectedAttentionPress +from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress +from kvpress.presses.finch_press import FinchPress +from kvpress.presses.key_rerotation_press import KeyRerotationPress +from kvpress.presses.keydiff_press import KeyDiffPress +from kvpress.presses.knorm_press import KnormPress +from kvpress.presses.kvzip_press import KVzipPress +from kvpress.presses.lagkv_press import LagKVPress +from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from kvpress.presses.pyramidkv_press import PyramidKVPress +from kvpress.presses.qfilter_press import QFilterPress +from kvpress.presses.random_press import RandomPress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.simlayerkv_press import SimLayerKVPress +from kvpress.presses.snapkv_press import SnapKVPress +from kvpress.presses.streaming_llm_press import StreamingLLMPress +from kvpress.presses.think_press import ThinKPress +from kvpress.presses.tova_press import TOVAPress + +# Patch the attention functions to support head-wise compression +patch_attention_functions() + +__all__ = [ + "CriticalAdaKVPress", + "CriticalKVPress", + "AdaKVPress", + "BasePress", + "ComposedPress", + "ScorerPress", + "ExpectedAttentionPress", + "KnormPress", + "ObservedAttentionPress", + "RandomPress", + "SimLayerKVPress", + "SnapKVPress", + "StreamingLLMPress", + "ThinKPress", + "TOVAPress", + "KVPressTextGenerationPipeline", + "PerLayerCompressionPress", + "KeyRerotationPress", + "ChunkPress", + "DuoAttentionPress", + "ChunkKVPress", + "QFilterPress", + "PyramidKVPress", + "FinchPress", + "LagKVPress", + "BlockPress", + "KeyDiffPress", + "KVzipPress", + "DecodingPress", + "PrefillDecodingPress", + "ExpectedAttentionStatsPress", + "DecodingPress", + "PrefillDecodingPress", +] diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 55eb506a..a8765cab 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -12,9 +12,11 @@ from transformers.pipelines.base import GenericTensor from kvpress.presses.base_press import BasePress +from kvpress.presses.decoding_press import DecodingPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress logger = logging.getLogger(__name__) @@ -189,6 +191,11 @@ def _forward( list[str] Generated answers for each input question. 
""" + if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1: + raise ValueError( + "DecodingPress is not compatible with multiple questions. Please specify a single question." + ) + context_ids = input_tensors["context_ids"].to(self.model.device) context_length = context_ids.shape[1] @@ -196,7 +203,9 @@ def _forward( if cache is None: cache = DynamicCache() - with press(self.model) if press is not None else contextlib.nullcontext(): + # We only perform prefill compression if the press is a prefill press + perform_prefill_compression = press is not None and not isinstance(press, DecodingPress) + with press(self.model) if perform_prefill_compression else contextlib.nullcontext(): # We run the model without the lm head for pre-filling. self.model.model( input_ids=context_ids, @@ -204,24 +213,44 @@ def _forward( output_attentions=self.output_attentions(press), ) - logger.debug(f"Context Length: {context_length}") - logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + logger.debug(f"Context Length: {context_length}") + logger.debug(f"Compressed Context Length: {cache.get_seq_length()}") + + # We only perform decoding compression if the press is a decoding or prefill decoding press + perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress)) + with press(self.model) if perform_decoding_compression else contextlib.nullcontext(): + # Greedy decoding for each question + answers = [] + for question_ids in input_tensors["questions_ids"]: + if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): + context_length = cache.get_seq_length() + + cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] + answer = self.generate_answer( + question_ids=question_ids.to(self.model.device), + cache=cache, + context_length=context_length, + max_new_tokens=max_new_tokens, + ) + self._remove_answer_from_cache(cache, cache_seq_lengths) + + answers.append(answer) + return answers - # Greedy decoding for each question - answers = [] - for question_ids in input_tensors["questions_ids"]: - if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys): - context_length = cache.get_seq_length() + def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]): - answer = self.generate_answer( - question_ids=question_ids.to(self.model.device), - cache=cache, - context_length=context_length, - max_new_tokens=max_new_tokens, - ) - answers.append(answer) + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] + cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - return answers + if isinstance(cache, QuantizedCache): + for layer_idx, sequence_length in enumerate(cache_seq_lengths): + cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ + :, :, :sequence_length + ] + cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ + :, :, :sequence_length + ] def generate_answer( self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int @@ -245,7 +274,6 @@ def generate_answer( str The generated answer. 
""" - cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))] position_ids = torch.arange( context_length, context_length + question_ids.shape[1], device=self.model.device ).unsqueeze(0) @@ -276,20 +304,6 @@ def generate_answer( if new_id.item() in should_stop_token_ids: break answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True) - # Remove the generated tokens from the cache - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length] - cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length] - - if isinstance(cache, QuantizedCache): - for layer_idx, sequence_length in enumerate(cache_seq_lengths): - cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[ - :, :, :sequence_length - ] - cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[ - :, :, :sequence_length - ] - return answer def output_attentions(self, press: BasePress): diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 22393678..815c777a 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -55,21 +55,21 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute scores scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - bsz, num_key_value_heads, q_len = scores.shape + bsz, num_key_value_heads, k_len = scores.shape # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head - n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition n_safe = int(n_kept * self.alpha_safeguard) top_indices = torch.topk(scores, n_safe, dim=-1).indices scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) # Compute bottom-k across heads - n_pruned = num_key_value_heads * (q_len - n_kept) + n_pruned = num_key_value_heads * (k_len - n_kept) indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices to mask during the attention mechanism. 
Please refer to attention_patch.py for more details batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) - head_indices = indices // q_len - seq_indices = indices % q_len + head_indices = indices // k_len + seq_indices = indices % k_len module.masked_key_indices = (batch_indices, head_indices, seq_indices) return keys, values diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 7d40fcd7..8fb4acf2 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -20,6 +20,8 @@ Qwen3ForCausalLM, ) +from kvpress.presses.utils import extract_keys_and_values + logger = logging.getLogger(__name__) SUPPORTED_MODELS = ( @@ -124,24 +126,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic hidden_states = kwargs["hidden_states"] cache = kwargs["past_key_values"] + cache_layer = cache.layers[module.layer_idx] q_len = hidden_states.shape[1] # Don't compress after pre-filling if kwargs["cache_position"][-1] > q_len: return output - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) diff --git a/kvpress/presses/block_press.py b/kvpress/presses/block_press.py index 6d8da788..b4edafdb 100644 --- a/kvpress/presses/block_press.py +++ b/kvpress/presses/block_press.py @@ -58,18 +58,18 @@ def compress( assert attentions is None, "BlockPress does not support attentions." 
- bsz, num_key_value_heads, q_len, head_dim = keys.shape + bsz, num_key_value_heads, k_len, head_dim = keys.shape - block_size = self.block_size if self.block_size < q_len else q_len - n_kept = int(q_len * (1 - self.compression_ratio)) + block_size = self.block_size if self.block_size < k_len else k_len + n_kept = int(k_len * (1 - self.compression_ratio)) kept_indices = torch.arange(n_kept, device=keys.device).expand(bsz, num_key_value_heads, -1) # Reshape hidden states to match the kept_indices - states = hidden_states.view(bsz, q_len, num_key_value_heads, -1).transpose(1, 2) + states = hidden_states.view(bsz, k_len, num_key_value_heads, -1).transpose(1, 2) - for i in range(n_kept, q_len, block_size): - end = min(i + block_size, q_len) + for i in range(n_kept, k_len, block_size): + end = min(i + block_size, k_len) current_indices = torch.arange(i, end, device=keys.device).expand(bsz, num_key_value_heads, -1) current_indices = torch.cat([kept_indices, current_indices], dim=-1) diff --git a/kvpress/presses/criticalkv_press.py b/kvpress/presses/criticalkv_press.py index 5513bc67..abcb8b83 100644 --- a/kvpress/presses/criticalkv_press.py +++ b/kvpress/presses/criticalkv_press.py @@ -53,7 +53,7 @@ def compression_ratio(self, value): @staticmethod def vwl1norm(values, module): - bsz, num_key_value_heads, q_len, _ = values.shape + bsz, num_key_value_heads, k_len, _ = values.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads Wo = module.o_proj.weight.transpose(0, 1) Wo = Wo.view(module.config.num_attention_heads, module.config.head_dim, module.config.hidden_size) @@ -67,16 +67,16 @@ def vwl1norm(values, module): head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1) head_WoV_norm_list.append(head_WoV_norm) - # b_size, num_heads, q_len , k_len + # b_size, num_heads, k_len , k_len WoV_norm = torch.stack(head_WoV_norm_list, dim=1) - WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len).mean(dim=2) + WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, k_len).mean(dim=2) return WoV_norm def score(self, module, hidden_states, keys, values, attentions, kwargs): # Stage 1 scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - q_len = keys.shape[2] - selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio) + k_len = keys.shape[2] + selection_budget = int((1 - self.compression_ratio) * k_len * self.first_stage_ratio) top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices # Stage 2 @@ -140,10 +140,10 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute scores scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) - bsz, num_key_value_heads, q_len = scores.shape + bsz, num_key_value_heads, k_len = scores.shape # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head - n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition n_safe = int(n_kept * self.alpha_safeguard) top_indices = torch.topk(scores, n_safe, dim=-1).indices scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) @@ -156,7 +156,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): budget_scores = scores.scatter(-1, top_indices, torch.finfo(scores.dtype).max) budget_scores = budget_scores.reshape(bsz, -1) top_indices = torch.topk(budget_scores, n_kept * 
num_key_value_heads, dim=-1).indices - top_indices_head_idx = top_indices // q_len + top_indices_head_idx = top_indices // k_len head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64) head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten())) @@ -180,12 +180,12 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): ########################## # Compute bottom-k across heads - n_pruned = num_key_value_heads * (q_len - n_kept) + n_pruned = num_key_value_heads * (k_len - n_kept) indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) - head_indices = indices // q_len - seq_indices = indices % q_len + head_indices = indices // k_len + seq_indices = indices % k_len module.masked_key_indices = (batch_indices, head_indices, seq_indices) return keys, values diff --git a/kvpress/presses/decoding_press.py b/kvpress/presses/decoding_press.py new file mode 100644 index 00000000..ec583ab6 --- /dev/null +++ b/kvpress/presses/decoding_press.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import defaultdict +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers.cache_utils import QuantizedCache + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.utils import extract_keys_and_values + +logger = logging.getLogger(__name__) + + +@dataclass +class DecodingPress(BasePress): + """ + A press that only operates during decoding phase and maintains a running buffer of hidden states. + + This press accumulates hidden states during decoding and applies compression every N steps + using a scorer press to determine which tokens to keep. + + + Parameters + ---------- + base_press : ScorerPress + The scorer press used to compute importance scores for tokens. + compression_interval : int, default=10 + Number of decoding steps between compression, i.e. compression will be applied every compression_interval steps. + target_size : int, default=1024 + Target number of tokens to keep after compression. + hidden_states_buffer_size : int, default=128 + Maximum number of hidden states to keep before compression. Larger values use more GPU memory. + Note: Some presses don't need buffered hidden states and can set this to 0 to use only the + current hidden state for compression scoring. + """ + + base_press: ScorerPress + compression_interval: int = 128 + target_size: int = 1024 + hidden_states_buffer_size: int = 128 + + def __post_init__(self): + # Buffer to store hidden states during decoding (per layer) + assert isinstance(self.base_press, ScorerPress), "DecodingPress requires a ScorerPress as input" + self.hidden_states_buffer = defaultdict(list) # Per-layer buffer + self.layer_step_counts = defaultdict(int) # Track step count per layer + + assert self.compression_interval > 0, "compression_interval must be greater than 0" + assert self.target_size > 0, "target_size must be greater than 0" + + if self.base_press.compression_ratio: + logger.warning( + f"compression_ratio is set for base press ({self.base_press.compression_ratio}). 
" + f"This will be overridden by the decoding press." + ) + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Delegate compression to the base press during decoding phase. + + Args: + module: The transformer module being compressed + hidden_states: Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim]) + keys: Key cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + values: Value cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim]) + attentions: Attention weights (shape varies by implementation) + kwargs: Additional keyword arguments + + Returns: + tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) tensors + + Note: + **Sequence length alignment**: During decoding compression, `hidden_states` contains the + buffered hidden states from recent decoding steps (buffer_len tokens), while `keys` and + `values` contain the full sequence history (seq_len tokens). The base press implementation + should use keys.shape[2] for full sequence length calculations. The buffered hidden_states + provide context for the most recent tokens when computing compression scores. + + Performance Note: + It would be possible to speed up compression during decoding for certain scorer presses by + storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions. + """ + k_len = keys.shape[2] + target_compression_ratio = self._find_target_compression_ratio(k_len, self.target_size) + logger.debug(f"Compressing {k_len} to {self.target_size} with ratio {target_compression_ratio}") + + original_compression_ratio = self.base_press.compression_ratio + self.base_press.compression_ratio = target_compression_ratio + result = self.base_press.compress(module, hidden_states, keys, values, attentions, kwargs) + self.base_press.compression_ratio = original_compression_ratio + return result + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Forward hook that manages decoding-specific compression logic. + + This hook: + 1. Detects when we're in decoding phase (not prefilling) + 2. Accumulates hidden states in a buffer + 3. Applies compression every N steps + 4. 
Clears the buffer after compression + """ + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] + layer_idx = module.layer_idx + + # Only operate during decoding phase (after prefilling) + if kwargs["cache_position"][-1] <= q_len: + # We're still in prefilling phase, don't do anything + return output + # print(f"Adding hidden states to buffer: {hidden_states.shape}") + # Add current hidden states to buffer for this layer + self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone()) + + # print(f"Layer step counts: {self.layer_step_counts[layer_idx]}") + self.layer_step_counts[layer_idx] += 1 + + # Apply compression if we've reached the compression step threshold + if (self.layer_step_counts[layer_idx] >= self.compression_interval) or (q_len >= self.target_size): + logger.debug( + f"Applying decoding compression: layer_step_count ({self.layer_step_counts[layer_idx]}) >= compression_steps ({self.compression_interval})" # noqa: E501 + ) + + cache_layer = cache.layers[module.layer_idx] + keys, values = extract_keys_and_values(cache, module.layer_idx) + + # Get attention weights from output + attentions = output[1] if len(output) > 1 and output[1] is not None else None + + # Apply compression using buffered hidden states for this layer + buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1) + keys, values = self.compress(module, buffered_hidden_states, keys, values, attentions, kwargs) + logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}") + + # Update cache with compressed keys and values + if isinstance(cache, QuantizedCache): + cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) + cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) + cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.cumulative_length = keys.shape[2] + else: + cache_layer.keys = keys + cache_layer.values = values + + # Reset step count and clear buffer for this layer + self.layer_step_counts[layer_idx] = 0 + # Always clear the buffer after compression - otherwise there's a mismatch between + # hidden states buffer and kv cache + self.hidden_states_buffer[layer_idx] = [] + + self.hidden_states_buffer[layer_idx] = ( + self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + if self.hidden_states_buffer_size > 0 + else [] + ) + return output + + def reset(self): + """Reset the decoding press state.""" + self.hidden_states_buffer = defaultdict(list) + self.layer_step_counts = defaultdict(int) + + def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float: + """ + Find the compression ratio that results in exactly target_tokens after int() rounding. 
+ + Args: + q_len: Current sequence length + target_tokens: Desired number of tokens after compression + + Returns: + Compression ratio that gives exactly target_tokens + """ + if q_len <= target_tokens: + return 0.0 + + # Start with theoretical ratio + ratio = 1.0 - (target_tokens / q_len) + + # Binary search to handle int() rounding + low, high = 0.0, 1.0 + max_iterations = 20 + iteration = 0 + + while iteration < max_iterations: + n_kept = int(q_len * (1 - ratio)) + if n_kept == target_tokens: + break + elif n_kept > target_tokens: + # Need more compression + low = ratio + ratio = (ratio + high) / 2 + else: + # Need less compression + high = ratio + ratio = (low + ratio) / 2 + iteration += 1 + + final_n_kept = int(q_len * (1 - ratio)) + if final_n_kept != target_tokens: + logger.warning( + f"Binary search failed: q_len={q_len}, target={target_tokens}, got={final_n_kept}, ratio={ratio}" + ) + + return ratio diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index b88d2f03..2843f8dd 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -103,9 +103,9 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): raise ValueError( "Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press." ) - q_len = hidden_states.shape[1] + k_len = keys.shape[2] - if (self.head_compression_ratio > 0) or (q_len > (self.sink_size + self.recent_size)): + if (self.head_compression_ratio > 0) or (k_len > (self.sink_size + self.recent_size)): # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details masked_keys = torch.zeros_like(keys[..., 0], dtype=torch.bool) @@ -114,7 +114,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute the compression ratio self.compression_ratio_ = self.streaming_mask.float().mean().item() - self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / q_len + self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / k_len return keys, values diff --git a/kvpress/presses/finch_press.py b/kvpress/presses/finch_press.py index e22952c3..29051b92 100644 --- a/kvpress/presses/finch_press.py +++ b/kvpress/presses/finch_press.py @@ -58,7 +58,7 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs): Similar to SnapKVPress except it adds a normalization step before averaging on the context window. """ - bsz, num_key_value_heads, q_len, _ = keys.shape + bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads if attentions is not None: @@ -69,13 +69,13 @@ def score(self, module, hidden_states, keys, values, attentions, kwargs): ) if self.normalize_scores: - non_zero_counts = torch.arange(q_len - self.window_size, q_len)[None, None, :, None] + non_zero_counts = torch.arange(k_len - self.window_size, k_len)[None, None, :, None] non_zero_counts = non_zero_counts.to(attn_weights.device) attn_weights = attn_weights * non_zero_counts # Average per group scores = attn_weights.mean(dim=-2) - scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size) scores = scores.mean(dim=2) # Add back the observation window. Use max score to make sure the window is not pruned. 
@@ -95,14 +95,14 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Compute indices to keep (optionally by chunks) - q_len = hidden_states.shape[1] + k_len = keys.shape[2] # Use actual sequence length from keys instead of hidden_states if self.chunk_length is None: - n_kept = int(q_len * (1 - self.compression_ratio)) + n_kept = int(k_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices else: assert self.chunk_length > self.window_size / (1 - self.compression_ratio) indices = [] - for i in range(0, q_len, self.chunk_length): + for i in range(0, k_len, self.chunk_length): chunk_scores = scores[:, :, i : i + self.chunk_length] n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio))) chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 129072f8..a264b3d4 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -137,7 +137,7 @@ def compress( scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] + q_len = keys.shape[2] n_kept = int(q_len * (1 - self.press.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = torch.sort(indices, dim=2).values diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index 59b565e7..0f13e68a 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -14,7 +14,7 @@ from transformers.models.llama.modeling_llama import rotate_half from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress -from kvpress.presses.utils import get_query_states +from kvpress.presses.utils import extract_keys_and_values, get_query_states logger = logging.getLogger(__name__) @@ -153,25 +153,16 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic hidden_states = kwargs["hidden_states"] cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None) - cache_layer = cache.layers[module.layer_idx] - if isinstance(cache, QuantizedCache): - keys = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_keys # type: ignore[index] - ) - values = cache_layer._dequantize( # type: ignore[index] - cache_layer._quantized_values # type: ignore[index] - ) - else: - keys = cache_layer.keys - values = cache_layer.values + keys, values = extract_keys_and_values(cache, module.layer_idx) # Compute importance scores for KV pairs in the prefilled context, # retaining only the originally prefilled KV pairs. keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): + # Update cache with compressed keys and values cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] diff --git a/kvpress/presses/prefill_decoding_press.py b/kvpress/presses/prefill_decoding_press.py new file mode 100644 index 00000000..85b5d465 --- /dev/null +++ b/kvpress/presses/prefill_decoding_press.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +from kvpress.presses.base_press import BasePress + +from .decoding_press import DecodingPress + +logger = logging.getLogger(__name__) + + +@dataclass +class PrefillDecodingPress(BasePress): + """ + A wrapper press that combines separate prefilling and decoding compression strategies. + + This press acts as a single press interface but internally delegates to different + presses based on the current phase (prefilling vs decoding). During prefilling, + it uses the prefilling_press. During decoding, it uses the decoding_press. + + Parameters + ---------- + prefilling_press : BasePress, optional + Press to use during the prefilling phase. If None, no compression is applied during prefilling. + decoding_press : DecodingPress, optional + Press to use during the decoding phase. If None, no compression is applied during decoding. + """ + + prefilling_press: Optional[BasePress] = None + decoding_press: Optional[DecodingPress] = None + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + q_len = hidden_states.shape[1] + + # Determine if we're in prefilling or decoding phase + if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs) + elif self.decoding_press is not None: + return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs) + + # No compression applied + logger.warning("No compression applied during prefill or decoding phase") + + return keys, values + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Forward hook that delegates to the appropriate press based on current phase. 
+ """ + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Determine if we're in prefilling or decoding phase + if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None: + return self.prefilling_press.forward_hook(module, input, kwargs, output) + elif self.decoding_press is not None: + return self.decoding_press.forward_hook(module, input, kwargs, output) + + # No hook applied + return output + + @contextmanager + def __call__(self, model: PreTrainedModel): + try: + with super().__call__(model): + yield + finally: + # Reset decoding press if it exists + if self.decoding_press is not None: + self.decoding_press.reset() diff --git a/kvpress/presses/pyramidkv_press.py b/kvpress/presses/pyramidkv_press.py index 14b6c03b..006f83ce 100644 --- a/kvpress/presses/pyramidkv_press.py +++ b/kvpress/presses/pyramidkv_press.py @@ -100,8 +100,8 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] - n_kept = self.get_layer_budget(module, q_len) + k_len = keys.shape[2] + n_kept = self.get_layer_budget(module, k_len) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 5c84b379..f7f33d16 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -90,8 +90,8 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) # Get indices of KV pairs with the lowest scores - q_len = hidden_states.shape[1] - n_kept = int(q_len * (1 - self.compression_ratio)) + k_len = keys.shape[2] + n_kept = int(k_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 86086e7c..3c114814 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -94,13 +94,13 @@ def compress( self.compression_ratios = [] # Check if compression is needed - q_len = hidden_states.shape[1] + k_len = keys.shape[2] min_length = self.n_initial + self.n_recent + self.n_last - if q_len <= min_length: + if k_len <= min_length: logger.warning(f"Sequence length is shorter than {min_length}: no compression applied") - if (self.lazy_threshold == 1.0) or (q_len <= min_length): + if (self.lazy_threshold == 1.0) or (k_len <= min_length): self.compression_ratios.append(0.0) return keys, values @@ -109,7 +109,7 @@ def compress( # If layer is lazy, only keep the initial and recent KV pairs keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) - self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len) + self.compression_ratios.append((k_len - self.n_initial - self.n_recent + 1) / k_len) else: self.compression_ratios.append(0.0) diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 68f92af9..1f3a46fd 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -44,7 +44,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ Compute the last window_size queries and associated attention weights for the first q_len - 
window_size keys. """ - bsz, q_len, _ = hidden_states.shape + bsz, _, k_len, _ = keys.shape num_heads = module.config.num_attention_heads head_dim = module.head_dim num_key_value_groups = num_heads // module.config.num_key_value_heads @@ -61,7 +61,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ key_states = repeat_kv(keys, num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) attention_mask = torch.ones_like(attn_weights) * float("-inf") - attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1) + attention_mask = torch.triu(attention_mask, diagonal=k_len - window_size + 1) attn_weights += attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = attn_weights[..., :-window_size] @@ -78,10 +78,12 @@ def score( kwargs, ) -> torch.Tensor: - bsz, num_key_value_heads, q_len, _ = keys.shape + bsz, num_key_value_heads, k_len, _ = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads - assert q_len > self.window_size, "Query length should be greater than the window size" + assert ( + hidden_states.shape[1] > self.window_size + ), f"Query length {hidden_states.shape[1]} should be greater than the window size {self.window_size}" if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] @@ -94,7 +96,7 @@ def score( scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) # Average per group (https://github.com/FasterDecoding/SnapKV/issues/22) - scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size) scores = scores.mean(2) # Add back the observation window. Use max score to make sure the window is not pruned. 
diff --git a/kvpress/presses/streaming_llm_press.py b/kvpress/presses/streaming_llm_press.py index 4c0ac218..6d657389 100644 --- a/kvpress/presses/streaming_llm_press.py +++ b/kvpress/presses/streaming_llm_press.py @@ -44,9 +44,9 @@ def score( kwargs, ) -> torch.Tensor: - q_len = hidden_states.shape[1] - assert q_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" - n_pruned = q_len - int(q_len * (1 - self.compression_ratio)) + k_len = keys.shape[2] + assert k_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}" + n_pruned = k_len - int(k_len * (1 - self.compression_ratio)) scores = torch.ones_like(keys[..., 0]) scores[:, :, self.n_sink : self.n_sink + n_pruned] = 0 diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index a19379c2..f277bb5c 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -72,7 +72,7 @@ def compress( return keys, values # Compute scores per dimension - bsz, num_key_value_heads, q_len, head_dim = keys.shape + bsz, num_key_value_heads, k_len, head_dim = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"]) @@ -84,7 +84,7 @@ def compress( # Prune dimensions with the lowest scores by setting them to 0 n_pruned = int(head_dim * self.key_channel_compression_ratio) indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices - indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1) + indices = indices.unsqueeze(2).expand(-1, -1, k_len, -1) keys = keys.scatter_(-1, indices, 0) return keys, values diff --git a/kvpress/presses/utils.py b/kvpress/presses/utils.py index 938c15ba..3614fa5b 100644 --- a/kvpress/presses/utils.py +++ b/kvpress/presses/utils.py @@ -3,6 +3,7 @@ import torch from torch import nn +from transformers import Cache, QuantizedCache from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention @@ -50,3 +51,22 @@ def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Te query_states = module.q_norm(query_states) return query_states + + +def dequantize_layer(cache_layer) -> tuple[torch.Tensor, torch.Tensor]: + keys = cache_layer._dequantize(cache_layer._quantized_keys) + values = cache_layer._dequantize(cache_layer._quantized_values) + return keys, values + + +def extract_keys_and_values(cache: Cache, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Extracts the keys and values from a given cache layer, + handling both quantized and unquantized caches. + """ + if isinstance(cache, QuantizedCache): + keys, values = dequantize_layer(cache.layers[layer_idx]) + else: + keys = cache.layers[layer_idx].keys + values = cache.layers[layer_idx].values + return keys, values diff --git a/notebooks/kvpress_decoding_aime25.ipynb b/notebooks/kvpress_decoding_aime25.ipynb new file mode 100644 index 00000000..55ccd8bd --- /dev/null +++ b/notebooks/kvpress_decoding_aime25.ipynb @@ -0,0 +1,671 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5b82a94ac3477fd", + "metadata": {}, + "source": [ + "# Using decoding compression on the AIME25 Math Dataset\n", + "\n", + "This notebook demonstrates how to compress during text generation.\n", + "We use `nvidia/OpenMath-Nemotron-7B` to solve math problems from the AIME25 dataset. 
For each problem, the model generates an answer in a boxed format (e.g., `\\boxed{42}`).\n", + "\n", + "To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.\n", + "Compression periodically reduces the cache size by keeping only the most relevant tokens, enabling efficient inference without sacrificing answer quality." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4cd1c9e43c5ca1bf", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "import torch\n", + "from transformers import pipeline\n", + "\n", + "from kvpress import (\n", + " ExpectedAttentionPress,\n", + " KnormPress,\n", + " ObservedAttentionPress,\n", + " RandomPress,\n", + " SnapKVPress,\n", + " StreamingLLMPress,\n", + " TOVAPress,\n", + ")\n", + "\n", + "from kvpress.presses.decoding_press import DecodingPress" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "dataset = load_dataset(\"math-ai/aime25\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5e912ac7-e5d8-42ff-8f9e-855b8ae999a8", + "metadata": {}, + "outputs": [], + "source": [ + "sample = dataset[\"test\"][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f8c2689d-e461-4608-aa69-a4811f71d639", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',\n", + " 'answer': '70',\n", + " 'id': '0'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "02b0af83-e71a-4129-9692-921d2bc16cb8", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bd9b941a02394e7d936e018a25c4628b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:009$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: \n", + "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n", + "\n", + "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n", + "\n", + "17_b = b + 7\n", + "\n", + "97_b = 9b + 7\n", + "\n", + "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n", + "\n", + "Let me write that as an equation:\n", + "\n", + "(9b + 7) ÷ (b + 7) = integer.\n", + "\n", + "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n", + "\n", + "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n", + "\n", + "9*(b + 7) = 9b + 63\n", + "\n", + "But 9b + 7 is the original numerator. 
So, subtracting 9*(b + 7) from 9b +7 gives:\n", + "\n", + "9b +7 - (9b +63) = 7 - 63 = -56\n", + "\n", + "Therefore, 9b +7 = 9*(b +7) -56\n", + "\n", + "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n", + "\n", + "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n", + "\n", + "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n", + "\n", + "1. b > 9\n", + "\n", + "2. (b +7) divides 56.\n", + "\n", + "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n", + "\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n", + "\n", + "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n", + "\n", + "If b +7 =28, then b=21\n", + "\n", + "If b +7=56, then b=49\n", + "\n", + "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n", + "\n", + "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n", + "\n", + "But wait, let me verify this. Let's check for b=21:\n", + "\n", + "17 in base 21 is 1*21 +7=28\n", + "\n", + "97 in base21 is 9*21 +7=196\n", + "\n", + "Check if 28 divides 196. 196 divided by 28 is 7. Yes, that's an integer. So that works.\n", + "\n", + "For b=49:\n", + "\n", + "17 in base49 is 1*49 +7=56\n", + "\n", + "97 in base49 is 9*49 +7=448\n", + "\n", + "448 divided by 56 is 8. That's also an integer. So that works too.\n", + "\n", + "So, both bases are valid. Therefore, the sum is 21 +49=70.\n", + "\n", + "Wait, but hold on. Let me check if there are any negative divisors. But since b is a base, it must be a positive integer greater than 9, so b +7 is positive. Therefore, we don't need to consider negative divisors. So, yes, only 28 and 56. Therefore, the answer is 70.\n", + "\n", + "But let me just make sure I didn't miss any divisors. Let's list all divisors again:\n", + "\n", + "Positive divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "Negative divisors would be -1, -2, -4, -7, -8, -14, -28, -56. But since b +7 must be positive (as b>9), we can ignore the negative ones. So, only 28 and 56. Therefore, the answer is 70.\n", + "\n", + "But wait, another thought: when we say that (9b +7) is divisible by (b +7), we can also approach this by setting up the equation (9b +7) = k*(b +7), where k is an integer. Then, solving for b:\n", + "\n", + "9b +7 = k*b +7k\n", + "\n", + "Bring terms with b to one side:\n", + "\n", + "9b -k*b = 7k -7\n", + "\n", + "b*(9 -k) = 7(k -1)\n", + "\n", + "Therefore, b = [7(k -1)] / (9 -k)\n", + "\n", + "Since b must be an integer greater than 9, we can look for integer values of k such that 9 -k divides 7(k -1), and the result is b>9.\n", + "\n", + "Let me see. 
Let's solve for k:\n", + "\n", + "b = 7(k -1)/(9 -k)\n", + "\n", + "We need 9 -k to divide 7(k -1). Let's denote d = 9 -k, so k =9 -d. Then:\n", + "\n", + "b =7( (9 -d) -1 ) / d =7(8 -d)/d =7*(8/d -1)\n", + "\n", + "Since b must be a positive integer greater than 9, d must be a positive divisor of 56 (since 7*(8/d -1) must be integer). Wait, maybe this approach complicates things. Let's try substituting possible integer values of k.\n", + "\n", + "Alternatively, since k must be an integer such that 9 -k divides 7(k -1). Let's note that 9 -k and k -1 are related. Let's set m =9 -k, so k =9 -m. Then:\n", + "\n", + "b =7*( (9 -m) -1 ) / m =7*(8 -m)/m =7*(8/m -1)\n", + "\n", + "So, 8/m must be an integer because b must be integer. Therefore, m must be a positive divisor of 8. But m =9 -k, and since k is an integer, m can be positive or negative. However, since b must be positive, let's see:\n", + "\n", + "If m is positive, then 8/m -1 must be positive as well because 7*(8/m -1) =b>9. So:\n", + "\n", + "8/m -1 >0 =>8/m >1 =>m <8. But m is a positive divisor of 8. The positive divisors of 8 are 1,2,4,8. So m can be 1,2,4,8. But m must be less than 8 (from 8/m >1). So m=1,2,4.\n", + "\n", + "Wait, if m=8, then 8/m -1=1 -1=0, so b=0, which is invalid. So m=1,2,4.\n", + "\n", + "Let's check each:\n", + "\n", + "Case 1: m=1\n", + "\n", + "Then, b=7*(8/1 -1)=7*(7)=49. So b=49. Which is valid, as before.\n", + "\n", + "Case 2: m=2\n", + "\n", + "b=7*(8/2 -1)=7*(4 -1)=7*3=21. So b=21. Also valid.\n", + "\n", + "Case3: m=4\n", + "\n", + "b=7*(8/4 -1)=7*(2 -1)=7*1=7. But b=7 is not greater than 9, so invalid.\n", + "\n", + "So only m=1 and m=2 give valid bases, which are 49 and 21. So sum is 70. So same result as before.\n", + "\n", + "Alternatively, if m is negative, then m divides 8, but m negative. Let's see:\n", + "\n", + "If m is a negative divisor of 8, then m can be -1,-2,-4,-8.\n", + "\n", + "Then, let's check:\n", + "\n", + "Case m=-1:\n", + "\n", + "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. Negative base is invalid.\n", + "\n", + "Similarly, m=-2:\n", + "\n", + "b=7*(8/(-2) -1)=7*(-4 -1)=7*(-5)=-35. Invalid.\n", + "\n", + "m=-4:\n", + "\n", + "b=7*(8/(-4) -1)=7*(-2 -1)=7*(-3)=-21. Invalid.\n", + "\n", + "m=-8:\n", + "\n", + "b=7*(8/(-8) -1)=7*(-1 -1)=7*(-2)=-14. Invalid.\n", + "\n", + "So all negative m give invalid bases. Therefore, only m=1 and m=2, leading to b=49 and 21. So sum is 70.\n", + "\n", + "Therefore, the answer is 70. So I think that's correct. Let me just check once more.\n", + "\n", + "Alternatively, maybe there's another approach. Let's consider that (9b +7)/(b +7) must be integer. Let's call this integer k. So:\n", + "\n", + "(9b +7) =k*(b +7)\n", + "\n", + "So, 9b +7 =k*b +7k\n", + "\n", + "Rearranged:\n", + "\n", + "(9 -k)*b =7k -7\n", + "\n", + "So, b=(7k -7)/(9 -k)\n", + "\n", + "We need b to be an integer greater than 9. So, (7k -7) must be divisible by (9 -k). Let's factor 7:\n", + "\n", + "7(k -1)/(9 -k) = -7(k -1)/(k -9)\n", + "\n", + "So, b= -7(k -1)/(k -9)\n", + "\n", + "We can write this as 7(1 -k)/(k -9). Hmm, not sure if that helps. Let's look for integer k such that (9 -k) divides 7(k -1). Let's denote d=9 -k, so k=9 -d. Then:\n", + "\n", + "b=7(k -1)/d=7(8 -d)/d=7*(8/d -1)\n", + "\n", + "So, 8/d must be integer, so d must divide 8. So d is a divisor of 8. Since d=9 -k, and k is an integer, d can be any integer (positive or negative) that divides 8. 
But since b must be positive and greater than 9, let's see:\n", + "\n", + "If d is positive, then 8/d must be integer, so d is a positive divisor of 8: 1,2,4,8.\n", + "\n", + "Then, for each d:\n", + "\n", + "d=1:\n", + "\n", + "b=7*(8/1 -1)=7*7=49. Valid.\n", + "\n", + "d=2:\n", + "\n", + "b=7*(8/2 -1)=7*(4 -1)=21. Valid.\n", + "\n", + "d=4:\n", + "\n", + "b=7*(8/4 -1)=7*(2 -1)=7. Not valid (b>9 required).\n", + "\n", + "d=8:\n", + "\n", + "b=7*(8/8 -1)=7*(1 -1)=0. Invalid.\n", + "\n", + "If d is negative, then d divides 8, so d=-1,-2,-4,-8.\n", + "\n", + "For d=-1:\n", + "\n", + "b=7*(8/(-1) -1)=7*(-8 -1)=7*(-9)=-63. Invalid.\n", + "\n", + "Similarly, other negative d's give negative b's. So only d=1 and d=2 give valid bases. So again, 49 and 21. Sum is 70.\n", + "\n", + "Therefore, I think the answer is 70. So I can be confident that the sum is 70.\n", + "\n", + "**Final Answer**\n", + "\\boxed{70}\n", + "To solve the problem, we need to find all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\).\n", + "\n", + "First, we convert the numbers from base \\( b \\) to base 10:\n", + "- \\( 17_b = 1 \\cdot b + 7 = b + 7 \\)\n", + "- \\( 97_b = 9 \\cdot b + 7 = 9b + 7 \\)\n", + "\n", + "We need \\( b + 7 \\) to divide \\( 9b + 7 \\). This can be expressed as:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} = k \\quad \\text{for some integer } k\n", + "\\]\n", + "\n", + "Rewriting the equation, we get:\n", + "\\[\n", + "9b + 7 = k(b + 7)\n", + "\\]\n", + "\\[\n", + "9b + 7 = kb + 7k\n", + "\\]\n", + "\\[\n", + "9b - kb = 7k - 7\n", + "\\]\n", + "\\[\n", + "b(9 - k) = 7(k - 1)\n", + "\\]\n", + "\\[\n", + "b = \\frac{7(k - 1)}{9 - k}\n", + "\\]\n", + "\n", + "For \\( b \\) to be an integer, \\( 9 - k \\) must be a divisor of \\( 7(k - 1) \\). We need to find the values of \\( k \\) such that \\( b > 9 \\).\n", + "\n", + "Let's consider the possible values of \\( k \\) by checking the divisors of 56 (since \\( 9b + 7 = 9(b + 7) - 56 \\)):\n", + "\n", + "The positive divisors of 56 are: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "Since \\( b > 9 \\), \\( b + 7 \\) must be greater than 16. Therefore, the valid divisors are 28 and 56.\n", + "\n", + "1. If \\( b + 7 = 28 \\):\n", + " \\[\n", + " b = 28 - 7 = 21\n", + " \\]\n", + "\n", + "2. If \\( b + 7 = 56 \\):\n", + " \\[\n", + " b = 56 - 7 = 49\n", + " \\]\n", + "\n", + "We verify these values:\n", + "- For \\( b = 21 \\):\n", + " \\[\n", + " 17_{21} = 21 + 7 = 28\n", + " \\]\n", + " \\[\n", + " 97_{21} = 9 \\cdot 21 + 7 = 196\n", + " \\]\n", + " \\[\n", + " 196 \\div 28 = 7 \\quad \\text{(an integer)}\n", + " \\]\n", + "\n", + "- For \\( b = 49 \\):\n", + " \\[\n", + " 17_{49} = 49 + 7 = 56\n", + " \\]\n", + " \\[\n", + " 97_{49} = 9 \\cdot 49 + 7 = 448\n", + " \\]\n", + " \\[\n", + " 448 \\div 56 = 8 \\quad \\text{(an integer)}\n", + " \\]\n", + "\n", + "Both values of \\( b \\) are valid. 
The sum of these bases is:\n", + "\\[\n", + "21 + 49 = 70\n", + "\\]\n", + "\n", + "Thus, the sum of all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\) is:\n", + "\\[\n", + "\\boxed{70}\n", + "\\]\n", + "Cache size: torch.Size([1, 4, 3781, 128])\n" + ] + } + ], + "source": [ + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=None, cache=cache, max_new_tokens=16_000)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7f4fa366-7e62-443a-b5c8-274128fe6237", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n", + "Answer: 70\n", + "Prediction: \n", + "Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n", + "\n", + "First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n", + "\n", + "17_b = b + 7\n", + "\n", + "97_b = 9b + 7\n", + "\n", + "The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n", + "\n", + "Let me write that as an equation:\n", + "\n", + "(9b + 7) ÷ (b + 7) = integer.\n", + "\n", + "To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n", + "\n", + "Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n", + "\n", + "9*(b + 7) = 9b + 63\n", + "\n", + "But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n", + "\n", + "9b +7 - (9b +63) = 7 - 63 = -56\n", + "\n", + "Therefore, 9b +7 = 9*(b +7) -56\n", + "\n", + "So, (9b +7)/(b +7) = 9 - 56/(b +7)\n", + "\n", + "For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n", + "\n", + "But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n", + "\n", + "1. b > 9\n", + "\n", + "2. (b +7) divides 56.\n", + "\n", + "So, first, let's find all divisors of 56. The positive divisors of 56 are:\n", + "\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. 
Wait, let's check:\n", + "\n", + "Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n", + "\n", + "So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n", + "\n", + "If b +7 =28, then b=21\n", + "\n", + "If b +7=56, then b=49\n", + "\n", + "So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n", + "\n", + "Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So 14 is too small. Similarly, 8,7, etc. So only 28 and 56. Therefore, the possible bases are 21 and 49. So their sum is 70.\n", + "\n", + "But wait, let me verify this. Let's check for b=21:\n", + "\n", + "17 in base 21 is 21 +7=28\n", + "\n", + "97 in base21 is 9*21 +7=189 +7=196\n", + "\n", + "196 divided by28 is 7, which is an integer. So that works.\n", + "\n", + "For b=49:\n", + "\n", + "17 in base49 is 49 +7=56\n", + "\n", + "97 in base49 is 9*49 +7=441 +7=448\n", + "\n", + "448 divided by56 is 8, which is an integer. So that works too.\n", + "\n", + "So both bases are valid. Therefore, the sum is 21 +49=70.\n", + "\n", + "Wait, but let me check if there are any other divisors. Wait, 56 is a divisor of 56, but 56 is 56. So, if b +7=56, then b=49. Which we already have. Similarly, 28 gives b=21. Are there any negative divisors? But since b is a base, it must be a positive integer greater than 9, so negative divisors don't make sense here. So, yes, only 28 and 56. Therefore, the answer is 70.\n", + "\n", + "But wait, let me check if there's a mistake here. Let me think again. The problem says \"all bases b>9\". So, the possible divisors of 56 are 1,2,4,7,8,14,28,56. But since b +7 must be one of these, and b>9, so b +7 must be at least 17. So, the possible divisors are 28 and 56. Therefore, b=21 and 49. So sum is 70. That seems correct.\n", + "\n", + "But let me check if there's another way. Let me think. Suppose we have (9b +7)/(b +7) must be integer. Let me compute this fraction:\n", + "\n", + "(9b +7)/(b +7) = 9 - 56/(b +7). Wait, how?\n", + "\n", + "Let me do the division:\n", + "\n", + "Divide 9b +7 by b +7.\n", + "\n", + "Dividing 9b by b gives 9. Multiply (b +7) by 9: 9b +63. Subtract that from 9b +7: (9b +7) - (9b +63) = -56. So, the division gives 9 with a remainder of -56. Therefore, (9b +7) = 9*(b +7) -56. Therefore, (9b +7)/(b +7) = 9 -56/(b +7). For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. Which is the same conclusion as before.\n", + "\n", + "Therefore, the possible values of (b +7) are the positive divisors of 56 that are greater than or equal to 17 (since b>9 implies b +7>16). The divisors of 56 are 1,2,4,7,8,14,28,56. So, the ones greater than 16 are 28 and 56. Therefore, b +7=28 gives b=21, and b +7=56 gives b=49. So, sum is 21 +49=70.\n", + "\n", + "Therefore, the answer is 70. Let me check once more with another example. Suppose b=21:\n", + "\n", + "17 in base21 is 2*10 +7=27? Wait, no. Wait, in base b, 17_b is 1*b +7. So, 1*21 +7=28. 97_b is 9*21 +7=189 +7=196. 196 divided by28 is 7. Correct.\n", + "\n", + "For b=49:\n", + "\n", + "17_b is 1*49 +7=56. 97_b is 9*49 +7=441 +7=448. 448 divided by56 is 8. Correct.\n", + "\n", + "So, both bases work. Therefore, the sum is 21 +49=70. 
So, the answer is 70.\n", + "\n", + "**Final Answer**\n", + "\\boxed{70}\n", + "To solve the problem, we need to find all bases \\( b > 9 \\) such that \\( 17_b \\) divides \\( 97_b \\).\n", + "\n", + "First, we convert the numbers from base \\( b \\) to base 10:\n", + "- \\( 17_b \\) in base 10 is \\( 1 \\cdot b + 7 = b + 7 \\).\n", + "- \\( 97_b \\) in base 10 is \\( 9 \\cdot b + 7 = 9b + 7 \\).\n", + "\n", + "We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} \\text{ must be an integer.}\n", + "\\]\n", + "\n", + "Rewriting the fraction:\n", + "\\[\n", + "\\frac{9b + 7}{b + 7} = \\frac{9(b + 7) - 56}{b + 7} = 9 - \\frac{56}{b + 7}.\n", + "\\]\n", + "\n", + "For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. Therefore, \\( b + 7 \\) must be a divisor of 56. The divisors of 56 are:\n", + "\\[\n", + "1, 2, 4, 7, 8, 14, 28, 56.\n", + "\\]\n", + "\n", + "Since \\( b > 9 \\), we need \\( b + 7 > 16 \\). The valid divisors of 56 that are greater than 16 are:\n", + "\\[\n", + "28 \\quad \\text{and} \\quad 56.\n", + "\\]\n", + "\n", + "Thus, we have:\n", + "\\[\n", + "b + 7 = 28 \\quad \\Rightarrow \\quad b = 21,\n", + "\\]\n", + "\\[\n", + "b + 7 = 56 \\quad \\Rightarrow \\quad b = 49.\n", + "\\]\n", + "\n", + "We verify these bases:\n", + "- For \\( b = 21 \\):\n", + " \\[\n", + " 17_{21} = 21 + 7 = 28,\n", + " \\]\n", + " \\[\n", + " 97_{21} = 9 \\cdot 21 + 7 = 196.\n", + " \\]\n", + " Since \\( 196 \\div 28 = 7 \\), \\( 17_{21} \\) divides \\( 97_{21} \\).\n", + "\n", + "- For \\( b = 49 \\):\n", + " \\[\n", + " 17_{49} = 49 + 7 = 56,\n", + " \\]\n", + " \\[\n", + " 97_{49} = 9 \\cdot 49 + 7 = 448.\n", + " \\]\n", + " Since \\( 448 \\div 56 = 8 \\), \\( 17_{49} \\) divides \\( 97_{49} \\).\n", + "\n", + "The valid bases are \\( b = 21 \\) and \\( b = 49 \\). Summing these bases:\n", + "\\[\n", + "21 + 49 = 70.\n", + "\\]\n", + "\n", + "Thus, the sum of all bases \\( b > 9 \\) where \\( 17_b \\) divides \\( 97_b \\) is:\n", + "\\[\n", + "\\boxed{70}.\n", + "\\]\n", + "Cache size: torch.Size([1, 4, 1182, 128])\n" + ] + } + ], + "source": [ + "compression_interval = 1024 # compress every compression_steps\n", + "target_size = 512 # number of tokens to keep after compression. 
Note that the actual cache size lies in [target_size, target_size + compression_interval]\n", + "\n", + "\n", + "press = DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=compression_interval, target_size=target_size)\n", + "\n", + "cache = DynamicCache()\n", + "question = sample[\"problem\"]\n", + "true_answer = sample[\"answer\"]\n", + "pred_answer = pipe(\" \", question=question, press=press, cache=cache, max_new_tokens=16_000)[\"answer\"]\n", + "\n", + "print(f\"Question: {question}\")\n", + "print(f\"Answer: {true_answer}\")\n", + "print(f\"Prediction: {pred_answer}\")\n", + "print(f\"Cache size: {cache.layers[0].keys.shape}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "kvpress", + "language": "python", + "name": "kvpress" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/fixtures.py b/tests/fixtures.py index 9ebceb19..3323ee60 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -94,7 +94,7 @@ def kv_press_llama3_2_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe @@ -108,6 +108,6 @@ def kv_press_qwen3_flash_attn_pipeline(): "kv-press-text-generation", model=ckpt, device=device, - model_kwargs={"attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16}, + model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16}, ) return pipe diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py index 43cfe290..155cf6dd 100644 --- a/tests/integration/test_ruler.py +++ b/tests/integration/test_ruler.py @@ -6,6 +6,7 @@ import torch from transformers import DynamicCache, QuantoQuantizedCache from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available + from kvpress import QFilterPress from tests.default_presses import default_presses from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline # noqa: F401 @@ -58,12 +59,9 @@ def test_ruler_is_correct( # QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class. return else: - pred_answer = kv_press_qwen3_flash_attn_pipeline( - context, - question=question, - press=press, - cache=cache - )["answer"] + pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[ + "answer" + ] assert true_answer in pred_answer @@ -102,10 +100,7 @@ def test_ruler_is_correct_for_qfilter( question = df_ruler.iloc[idx]["question"] true_answer = df_ruler.iloc[idx]["answer"][0] - pred_answer = kv_press_llama3_2_flash_attn_pipeline( - context, - question=question, - press=press, - cache=cache - )["answer"] + pred_answer = kv_press_llama3_2_flash_attn_pipeline(context, question=question, press=press, cache=cache)[ + "answer" + ] assert true_answer in pred_answer diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py new file mode 100644 index 00000000..b36dae4c --- /dev/null +++ b/tests/test_decoding_compression.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +""" +Test script to verify that DecodingPress actually compresses during decoding. +""" +import logging + +import pytest +import torch +from transformers import DynamicCache, pipeline + +from kvpress import PyramidKVPress, ScorerPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.knorm_press import KnormPress +from kvpress.presses.prefill_decoding_press import PrefillDecodingPress +from tests.default_presses import default_presses + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("token_buffer_size", [32, 64, 128]) +def test_decoding_compression(token_buffer_size): + """Test that DecodingPress compresses the cache during decoding.""" + + # Initialize pipeline with a small model + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create a DecodingPress with KnormPress + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens + compression_interval=4, # Compress every 4 tokens + target_size=token_buffer_size, + ) + + # Create cache + cache = DynamicCache() + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 10 # Repeat for longer context + question = "What animal jumps over the dog?" + + # Run pipeline + pipe(context, question=question, press=press, cache=cache, max_new_tokens=20) + + # Assert that all layers have the expected cache size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + max_expected_size = token_buffer_size + press.compression_interval - 1 + assert layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_calls_both_phases(): + """Test that PrefillDecodingPress calls both prefilling and decoding presses.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with both presses + combined_press = PrefillDecodingPress( + prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill + decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 12 # Longer context + question = "What animal jumps over the dog?" 
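+    # Prefill compression shrinks the context cache once before generation; the decoding
+    # press then re-compresses every compression_interval steps, so each layer's final
+    # cache length is expected to land in [target_size, target_size + compression_interval - 1].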
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15) + + # Check that cache was compressed during both phases + # Final cache should be compressed to decoding press target size + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 # token_buffer_size from decoding press + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected final cache size to be between {target_size} " + f"and {max_expected_size} (decoding target), but got {layer_seq_len}" + ) + + +def test_decoding_press_without_prefill(): + """Test that DecodingPress works correctly when used standalone (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create DecodingPress only + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" + + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25) + + # Check that cache was compressed during decoding + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 64 + compression_steps = 5 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_prefill_decoding_press_decoding_only(): + """Test PrefillDecodingPress with only decoding press (no prefill compression).""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 9 + question = "What animal jumps over the dog?" 
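+    # With prefilling_press=None, the context cache should be left untouched at prefill
+    # time; only the decoding press is expected to shrink the cache during generation.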
+ + # Run pipeline + cache = DynamicCache() + pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12) + + # Check that only decoding compression was applied + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 56 + compression_steps = 4 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert target_size <= layer_seq_len <= max_expected_size, ( + f"Layer {layer_idx}: Expected cache size to be between {target_size} " + f"and {max_expected_size}, but got {layer_seq_len}" + ) + + +def test_decoding_press_equivalence(): + """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only.""" + + # Set random seed for reproducibility + torch.manual_seed(42) + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Create standalone decoding press + decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52) + + # Create PrefillDecodingPress with only decoding press + combined_press = PrefillDecodingPress( + prefilling_press=None, + decoding_press=DecodingPress( + base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52 + ), + ) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 7 + question = "What animal jumps over the dog?" + + # Run with standalone decoding press + cache1 = DynamicCache() + result1 = pipe(context, question=question, press=decoding_press, cache=cache1, max_new_tokens=10) + + # Run with combined press (decoding only) + cache2 = DynamicCache() + result2 = pipe(context, question=question, press=combined_press, cache=cache2, max_new_tokens=10) + + # Compare cache sizes (should be identical) + for layer_idx in range(len(cache1.layers)): + cache1_size = cache1.layers[layer_idx].keys.shape[2] + cache2_size = cache2.layers[layer_idx].keys.shape[2] + assert cache1_size == cache2_size, ( + f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != " + f"combined press cache size {cache2_size}" + ) + + # Compare generated text results (should be identical) + assert result1["answer"] == result2["answer"], ( + f"Generated answers differ:\n" + f"Standalone decoding: '{result1['answer']}'\n" + f"Combined press: '{result2['answer']}'" + ) + + +""" +E AttributeError: 'QFilterPress' object has no attribute 'q_filters' +E Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12 +> query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) +E RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12 +""" + + +@pytest.mark.parametrize("press_config", default_presses) +def test_all_presses_work_with_decoding_press(press_config): + """Test that all default presses work as base presses for DecodingPress.""" + + # Initialize pipeline + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + # Get press class and use the first (easier) configuration + press_cls = press_config["cls"] + press_kwargs = press_config["kwargs"][0] # Use easier compression settings + + base_press = press_cls(**press_kwargs) + if not isinstance(base_press, ScorerPress): + logger.info(f"Press {press_cls.__name__} is 
not a ScorerPress, skipping test") + return + if isinstance(base_press, (PyramidKVPress)): + # PyramidKVPress -> Pyramid shape, not compatible with token_buffer_size=48 + logger.info(f"Press {press_cls.__name__} is not supported, skipping test") + return + if hasattr(base_press, "__post_init_from_model__"): + base_press.__post_init_from_model__(pipe.model) + + # Create DecodingPress with this base press + decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48) + + # Test context and question + context = "The quick brown fox jumps over the lazy dog. " * 8 + question = "What animal jumps over the dog?" + + # Run pipeline + cache = DynamicCache() + result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15) + + # Verify compression worked + assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}" + + # Check that cache was compressed (allow some tolerance for rounding) + for layer_idx, cache_layer in enumerate(cache.layers): + layer_seq_len = cache_layer.keys.shape[2] + # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger + target_size = 48 + compression_steps = 3 # from the decoding press configuration + max_expected_size = target_size + compression_steps - 1 + assert ( + target_size <= layer_seq_len <= max_expected_size + ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 + + +def test_compression_actually_reduces_memory(): + """Test that compression actually reduces memory usage compared to no compression.""" + + pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") + + context = "The quick brown fox jumps over the lazy dog. " * 15 # Long context + question = "What animal jumps over the dog?" 
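+    # Generate twice over the same prompt: once with the full KV cache as a baseline and
+    # once with DecodingPress, then compare the approximate key/value memory of the two caches.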
+ + # Run without compression + cache_uncompressed = DynamicCache() + result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25) + + # Run with compression + press = DecodingPress( + base_press=KnormPress(compression_ratio=0.3), # Aggressive compression + compression_interval=3, + target_size=40, + ) + cache_compressed = DynamicCache() + result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25) + + # Calculate memory usage (approximate) + uncompressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_uncompressed.layers + ) + compressed_memory = sum( + (cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size() + for cache_layer in cache_compressed.layers + ) + + # Compression should significantly reduce memory usage + compression_ratio = compressed_memory / uncompressed_memory + assert compression_ratio < 0.6, ( + f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} " + f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)" + ) + + # Both should still generate reasonable answers + assert len(result_uncompressed["answer"]) > 0 + assert len(result_compressed["answer"]) > 0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c3e6f0b4..f0686cdd 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -155,7 +155,9 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 keys = [layer.keys.clone() for layer in past_key_values.layers] values = [layer.values.clone() for layer in past_key_values.layers] + cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))] compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10) + compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths) assert past_key_values.get_seq_length() == seq_len assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)]) assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)])