From 15e6ed5bc366778ead2f9645db8a6458783c4f2d Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 8 Apr 2026 14:09:49 +0000
Subject: [PATCH 1/8] Fix KVComposePress compression ratio accuracy and cap
 transformers<5.3

Replace quantile-based threshold with topk ranking in
compute_important_per_layer to match the pattern used by all other presses
and avoid overshooting the target compression ratio due to quantile
interpolation.
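
A minimal illustration of the difference (toy scores, not from the
codebase): an interpolated quantile threshold compared with >= can keep
more or fewer entries than intended, and tied scores make it worse,
whereas topk keeps exactly the requested count.

    import torch

    scores = torch.tensor([1.0, 1.0, 1.0, 1.0, 2.0])
    compression_ratio = 0.5

    # Quantile threshold: all five scores tie at or above 1.0, so
    # nothing is pruned (0% compression instead of the requested 50%).
    threshold = scores.quantile(compression_ratio)  # tensor(1.)
    print((scores >= threshold).sum().item())  # 5

    # Topk ranking: exactly n_kept entries survive.
    n_kept = int(scores.numel() * (1 - compression_ratio))  # 2
    print(scores.topk(n_kept).indices.numel())  # 2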

Co-Authored-By: Claude Opus 4.6 (1M context)
Signed-off-by: SimJeg
---
 kvpress/presses/kvcompose_press.py | 13 ++++++++-----
 pyproject.toml                     |  2 +-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index 642c90fe..0c3232c1 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -273,11 +273,14 @@ def compute_important_per_layer(self):
         """
         self.compute_composite_scores()
 
-        threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio)
-        self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy()
-
-        threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
-        self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
+        n_kept = int(self.composite_scores_per_head.numel() * (1 - self.compression_ratio))
+        kept = self.composite_scores_per_head.reshape(-1).topk(n_kept).indices // self.context_len
+        bins = self.num_layers * self.num_kv_heads
+        self.important_per_head = torch.bincount(kept, minlength=bins).reshape(self.num_layers, self.num_kv_heads).cpu().numpy()
+
+        n_kept = int(self.composite_scores_per_layer.numel() * (1 - self.compression_ratio))
+        kept = self.composite_scores_per_layer.reshape(-1).topk(n_kept).indices // self.context_len
+        self.important_per_layer = torch.bincount(kept, minlength=self.num_layers).cpu().numpy()
 
     def prepare_important_masks(self):
         """
diff --git a/pyproject.toml b/pyproject.toml
index de1fe1bc..1a555d23 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ readme = "README.md"
 dependencies = [
     "numpy>=2.0.0,<3",
     "torch>=2.3.1,<3",
-    "transformers>=4.56.0",
+    "transformers>=4.56.0,<5.3",
     "datasets>=2.21.0",
     "pandas>=2.2.2,<3",
     "accelerate>=1.0.0,<2",

From efc294b6a07a715fcdeb10b373078a5bd91bf977 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 8 Apr 2026 14:21:15 +0000
Subject: [PATCH 2/8] Fix style

Signed-off-by: SimJeg
---
 kvpress/presses/kvcompose_press.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index 0c3232c1..601d763a 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -276,7 +276,9 @@ def compute_important_per_layer(self):
         n_kept = int(self.composite_scores_per_head.numel() * (1 - self.compression_ratio))
         kept = self.composite_scores_per_head.reshape(-1).topk(n_kept).indices // self.context_len
         bins = self.num_layers * self.num_kv_heads
-        self.important_per_head = torch.bincount(kept, minlength=bins).reshape(self.num_layers, self.num_kv_heads).cpu().numpy()
+        self.important_per_head = (
+            torch.bincount(kept, minlength=bins).reshape(self.num_layers, self.num_kv_heads).cpu().numpy()
+        )
 
         n_kept = int(self.composite_scores_per_layer.numel() * (1 - self.compression_ratio))
         kept = self.composite_scores_per_layer.reshape(-1).topk(n_kept).indices // self.context_len

From 31b4a0ef281e409d6b80d40435b04c71c80948b1 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 8 Apr 2026 15:02:39 +0000
Subject: [PATCH 3/8] Update Makefile

Run pytest through uv run --no-sync so that uv does not re-sync the
environment and uninstall the flash-attn wheel installed on the previous
recipe line.

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index c69b3eee..6cf0dd7d 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ reports:
 test: reports
 	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
 	PYTHONPATH=. \
-	$(UV) run pytest \
+	$(UV) run --no-sync pytest \
 	--cov-report xml:reports/coverage.xml \
 	--cov=kvpress/ \
 	--junitxml=./reports/junit.xml \

From 8b5325ba5375c2e8ef6288580f39f3e477ac5296 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 8 Apr 2026 15:29:54 +0000
Subject: [PATCH 4/8] Update FA wheels

---
 .github/workflows/test.yml | 1 -
 Makefile                   | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 33470022..ad1e849c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -23,7 +23,6 @@ jobs:
       - name: Install dependencies
         run: |
           uv sync --all-groups
-          uv pip install torch==2.10
       - run: make test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
diff --git a/Makefile b/Makefile
index 6cf0dd7d..89939a97 100644
--- a/Makefile
+++ b/Makefile
@@ -41,7 +41,7 @@ reports:
 
 .PHONY: test
 test: reports
-	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
+	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.9.4
 	PYTHONPATH=. \
 	$(UV) run --no-sync pytest \
 	--cov-report xml:reports/coverage.xml \

From 298240167643099007e2fa9556827a45b0e77e15 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 8 Apr 2026 16:08:29 +0000
Subject: [PATCH 5/8] Skip memory-hungry presses in RULER integration test on
 L4 GPU

CompactorPress, LeverageScorePress, and NonCausalAttnPress OOM when run
with Qwen3-4B on the L4 GPU (23GB), corrupting CUDA state and cascading
into 31 test failures.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 tests/integration/test_ruler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
index f7bc62bb..ffd76c76 100644
--- a/tests/integration/test_ruler.py
+++ b/tests/integration/test_ruler.py
@@ -7,7 +7,7 @@
 from transformers import DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
-from kvpress import QFilterPress
+from kvpress import CompactorPress, LeverageScorePress, NonCausalAttnPress, QFilterPress
 from tests.default_presses import default_presses
 from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline  # noqa: F401
 
@@ -58,6 +58,8 @@ def test_ruler_is_correct(
     if isinstance(press, QFilterPress):
         # QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class.
         return
+    if isinstance(press, (CompactorPress, LeverageScorePress, NonCausalAttnPress)):
+        pytest.skip("Skipped: these presses OOM on L4 GPU with Qwen3-4B")
     else:
         pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
             "answer"

From 339014b7803641b45ac62e690d76d59a3e114c04 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Thu, 9 Apr 2026 08:54:04 +0000
Subject: [PATCH 6/8] Revert "Skip memory-hungry presses in RULER integration
 test on L4 GPU"

This reverts commit 298240167643099007e2fa9556827a45b0e77e15.
---
 tests/integration/test_ruler.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
index ffd76c76..f7bc62bb 100644
--- a/tests/integration/test_ruler.py
+++ b/tests/integration/test_ruler.py
@@ -7,7 +7,7 @@
 from transformers import DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
-from kvpress import CompactorPress, LeverageScorePress, NonCausalAttnPress, QFilterPress
+from kvpress import QFilterPress
 from tests.default_presses import default_presses
 from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline  # noqa: F401
 
@@ -58,8 +58,6 @@ def test_ruler_is_correct(
     if isinstance(press, QFilterPress):
         # QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class.
         return
-    if isinstance(press, (CompactorPress, LeverageScorePress, NonCausalAttnPress)):
-        pytest.skip("Skipped: these presses OOM on L4 GPU with Qwen3-4B")
     else:
         pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
             "answer"

From c27fe1c2e0db702af9bdf7eb62b56e3f59649eb5 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Thu, 9 Apr 2026 09:13:37 +0000
Subject: [PATCH 7/8] Clean memory
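
Convert each pipeline fixture from return to yield so that the code after
yield runs at fixture teardown (when the class scope ends), releasing the
model and freeing CUDA memory before the next test class loads its own.

A generic sketch of the pattern (the fixture below is illustrative, not
one of the fixtures in tests/fixtures.py, and requires a CUDA device):

    import gc

    import pytest
    import torch

    @pytest.fixture(scope="class")
    def gpu_resource():
        # Stand-in for loading a model pipeline onto the GPU.
        data = torch.zeros(1024, 1024, device="cuda")
        yield data  # all tests in the class run here
        # Teardown: executed after the last test of the class.
        del data
        gc.collect()
        torch.cuda.empty_cache()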

Signed-off-by: SimJeg
---
 tests/fixtures.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/fixtures.py b/tests/fixtures.py
index 4462e46a..b58e2624 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import gc
+
 import pytest
 import torch
 from transformers import AutoModelForCausalLM, pipeline
 
@@ -80,7 +82,10 @@ def kv_press_llama3_1_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
     )
-    return pipe
+    yield pipe
+    del pipe
+    gc.collect()
+    torch.cuda.empty_cache()
 
 
 @pytest.fixture(scope="class")
@@ -94,7 +99,10 @@ def kv_press_llama3_2_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
     )
-    return pipe
+    yield pipe
+    del pipe
+    gc.collect()
+    torch.cuda.empty_cache()
 
 
 @pytest.fixture(scope="class")
@@ -108,4 +116,7 @@ def kv_press_qwen3_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
     )
-    return pipe
+    yield pipe
+    del pipe
+    gc.collect()
+    torch.cuda.empty_cache()

From f04313cc2edde3162478de74e0c21a57eeea9b0d Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Thu, 9 Apr 2026 09:24:40 +0000
Subject: [PATCH 8/8] Remove gc collect and bump version to 0.5.3

Signed-off-by: SimJeg
---
 pyproject.toml    | 2 +-
 tests/fixtures.py | 5 -----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1a555d23..679fc57b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kvpress"
-version = "0.5.2"
+version = "0.5.3"
 description = "Efficiently compress the KV cache of any pretrained transformer"
 authors = [
     { name = "Simon Jegou" },
diff --git a/tests/fixtures.py b/tests/fixtures.py
index b58e2624..35364327 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -2,8 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import gc
-
 import pytest
 import torch
 from transformers import AutoModelForCausalLM, pipeline
@@ -84,7 +82,6 @@ def kv_press_llama3_1_flash_attn_pipeline():
     )
     yield pipe
     del pipe
-    gc.collect()
     torch.cuda.empty_cache()
 
 
@@ -101,7 +98,6 @@ def kv_press_llama3_2_flash_attn_pipeline():
     )
     yield pipe
     del pipe
-    gc.collect()
     torch.cuda.empty_cache()
 
 
@@ -118,5 +114,4 @@ def kv_press_qwen3_flash_attn_pipeline():
     )
     yield pipe
     del pipe
-    gc.collect()
     torch.cuda.empty_cache()