127 changes: 121 additions & 6 deletions README.md
@@ -67,15 +67,47 @@ answer = pipe(context, question=question, press=press)["answer"]

In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP)).

> [!IMPORTANT]
> We focus on compression during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long prompts. This typically applies to improving prompt caching systems.

> [!NOTE]
> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}`, as this press requires materializing the attention weights.

<details><summary>
Decoding Compression
</summary>

By default, KVPress applies compression during the pre-filling phase. As a new (experimental) feature, we now support decoding compression via the `DecodingPress` wrapper. `DecodingPress` compresses the KV cache periodically during token generation, optionally maintaining a buffer of recent hidden states. `DecodingPress` supports the following parameters:

- `base_press`: any `ScorerPress` (e.g., `KnormPress`, `CriticalKVPress`)
- `compression_interval`: number of generation steps between compressions (default: 10)
- `target_size`: size of the cache after compression (default: 1024)
- `hidden_states_buffer_size`: number of hidden states to buffer before compression (default: 128). Some presses don't need buffered hidden states and can set this to 0.

Unlike prefilling presses, which are configured with a compression ratio, `DecodingPress` uses a `target_size`: the cache is compressed every `compression_interval` steps, and the compression ratio is computed automatically so that the size of the cache after compression equals `target_size`. For example, with `compression_interval=10` and `target_size=1024`, every 10 generated tokens the cache is pruned back down to 1024 entries.

An example for decoding compression:

```python
from transformers import pipeline
from kvpress import DecodingPress, KnormPress

# Initialize the pipeline
device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

# Create a decoding press that compresses every 10 steps to 512 tokens
decoding_press = DecodingPress(
    base_press=KnormPress(),
    compression_interval=10,
    target_size=512
)

# Use with pipeline
context = "A very long text you want to compress during generation"
question = "Tell me a long story about this context"
response = pipe(context, question=question, press=decoding_press)["answer"]
```

> [!NOTE]
> Not all existing presses are fully compatible with `DecodingPress` due to fundamental differences in how compression works during decoding versus prefilling. In particular, only `ScorerPress` instances are supported as base presses.

</details>

## Available presses

@@ -111,6 +143,8 @@ Finally, we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments input sequence into non-overlapping blocks and compresses iteratively.
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows compression during decoding; see the Decoding Compression section above.
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows compressing both during prefilling and during decoding, as sketched below.
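
As a minimal sketch of how these wrappers compose (using only presses that appear elsewhere in this README), a wrapper press simply takes another press as argument:

```python
from kvpress import CriticalKVPress, KnormPress

# Refine the key-norm scores with CriticalKV's two-stage selection
press = CriticalKVPress(KnormPress())
```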

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

@@ -131,6 +165,7 @@ Below we report the average performance on the RULER dataset with 4k context length.
<img src="evaluation/assets/leaderboard_plot_score.png" alt="Leaderboard">
</p>


## Quantization

We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
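
As a minimal sketch (assuming the `QuantizedCacheConfig` and `QuantoQuantizedCache` classes from transformers, and that the kvpress pipeline accepts a `cache` argument):

```python
from transformers import QuantizedCacheConfig, QuantoQuantizedCache

# 4-bit quantized KV cache; the press is applied on top of the quantized cache
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
answer = pipe(context, question=question, press=press, cache=cache)["answer"]
```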
@@ -149,6 +184,10 @@ By default, the `DynamicCache` is used (no quantization).
> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`).

## Contributing

We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.

## FAQ

<details><summary>
@@ -224,4 +263,80 @@ with press(model):

However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not allow generating answers for multiple questions at once.
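
A minimal sketch of this approach (assuming `model`, `tokenizer`, `context`, `question`, and a `press` are already created as in the snippets above):

```python
# The press compresses the whole prompt, question included, during pre-filling
inputs = tokenizer(context + question, return_tensors="pt").to(model.device)
with press(model):
    outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:]))
```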

</details>



<details><summary>

### Can I combine compression during prefilling and decoding?
</summary>


Yes. `PrefillDecodingPress` combines separate presses for the prefilling and decoding phases.

**Parameters:**
- `prefilling_press`: press used during the prefilling phase
- `decoding_press`: press used during the decoding phase

An example combining prefill and decoding compression:

```python
from transformers import pipeline
from kvpress import CriticalKVPress, DecodingPress, KnormPress, PrefillDecodingPress

# Initialize the pipeline
device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

# Different strategies for prefill vs decoding
prefill_press = CriticalKVPress(KnormPress())
decoding_press = DecodingPress(
    base_press=KnormPress(compression_ratio=0.2),
    compression_interval=5,
    target_size=256
)

# Combine them
combined_press = PrefillDecodingPress(
    prefilling_press=prefill_press,
    decoding_press=decoding_press
)

context = "A very long context that will be compressed during prefill"
question = "Generate a detailed analysis that will be compressed during decoding"
response = pipe(context, question=question, press=combined_press)["answer"]
```

</details>
142 changes: 74 additions & 68 deletions kvpress/__init__.py
@@ -1,68 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.block_press import BlockPress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.keydiff_press import KeyDiffPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.kvzip_press import KVzipPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
"LagKVPress",
"BlockPress",
"KeyDiffPress",
"KVzipPress",
"ExpectedAttentionStatsPress",
]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.block_press import BlockPress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.keydiff_press import KeyDiffPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.kvzip_press import KVzipPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
"LagKVPress",
"BlockPress",
"KeyDiffPress",
"KVzipPress",
"DecodingPress",
"PrefillDecodingPress",
"ExpectedAttentionStatsPress",
"DecodingPress",
"PrefillDecodingPress",
]