diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 33470022..ad1e849c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -23,7 +23,6 @@ jobs:
       - name: Install dependencies
         run: |
           uv sync --all-groups
-          uv pip install torch==2.10
       - run: make test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
diff --git a/Makefile b/Makefile
index c69b3eee..89939a97 100644
--- a/Makefile
+++ b/Makefile
@@ -41,9 +41,9 @@ reports:
 
 .PHONY: test
 test: reports
-	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
+	$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.9.4
 	PYTHONPATH=. \
-	$(UV) run pytest \
+	$(UV) run --no-sync pytest \
 		--cov-report xml:reports/coverage.xml \
 		--cov=kvpress/ \
 		--junitxml=./reports/junit.xml \
diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index 642c90fe..601d763a 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -273,11 +273,16 @@ def compute_important_per_layer(self):
         """
         self.compute_composite_scores()
 
-        threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio)
-        self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy()
+        n_kept = int(self.composite_scores_per_head.numel() * (1 - self.compression_ratio))
+        kept = self.composite_scores_per_head.reshape(-1).topk(n_kept).indices // self.context_len
+        bins = self.num_layers * self.num_kv_heads
+        self.important_per_head = (
+            torch.bincount(kept, minlength=bins).reshape(self.num_layers, self.num_kv_heads).cpu().numpy()
+        )
 
-        threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
-        self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
+        n_kept = int(self.composite_scores_per_layer.numel() * (1 - self.compression_ratio))
+        kept = self.composite_scores_per_layer.reshape(-1).topk(n_kept).indices // self.context_len
+        self.important_per_layer = torch.bincount(kept, minlength=self.num_layers).cpu().numpy()
 
     def prepare_important_masks(self):
         """
diff --git a/pyproject.toml b/pyproject.toml
index de1fe1bc..679fc57b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kvpress"
-version = "0.5.2"
+version = "0.5.3"
 description = "Efficiently compress the KV cache of any pretrained transformer"
 authors = [
     { name = "Simon Jegou" },
@@ -13,7 +13,7 @@ readme = "README.md"
 dependencies = [
     "numpy>=2.0.0,<3",
     "torch>=2.3.1,<3",
-    "transformers>=4.56.0",
+    "transformers>=4.56.0,<5.3",
     "datasets>=2.21.0",
     "pandas>=2.2.2,<3",
     "accelerate>=1.0.0,<2",
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 4462e46a..35364327 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -80,7 +80,9 @@ def kv_press_llama3_1_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
     )
-    return pipe
+    yield pipe
+    del pipe
+    torch.cuda.empty_cache()
 
 
 @pytest.fixture(scope="class")
@@ -94,7 +96,9 @@ def kv_press_llama3_2_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
     )
-    return pipe
+    yield pipe
+    del pipe
+    torch.cuda.empty_cache()
 
 
 @pytest.fixture(scope="class")
@@ -108,4 +112,6 @@ def kv_press_qwen3_flash_attn_pipeline():
         device=device,
         model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
    )
-    return pipe
+    yield pipe
+    del pipe
+    torch.cuda.empty_cache()
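
Note on the kvpress/presses/kvcompose_press.py hunk: it swaps a quantile threshold for an exact top-k budget. A quantile cutoff keeps only approximately numel * (1 - compression_ratio) entries (and can over-keep when scores tie at the threshold), whereas topk pins the kept count exactly; integer-dividing the flattened topk indices by context_len recovers which (layer, head) slot each kept token came from, and bincount turns that into per-slot budgets. A minimal sketch of the new selection logic, with made-up shapes standing in for the real KVComposePress attributes:

    import torch

    # Hypothetical stand-ins for KVComposePress attributes, chosen for illustration.
    num_layers, num_kv_heads, context_len = 4, 2, 16
    compression_ratio = 0.5
    # One composite score per (layer, kv_head, token), as in composite_scores_per_head.
    scores = torch.rand(num_layers, num_kv_heads, context_len)

    # Keep exactly n_kept entries globally; topk enforces the budget even when
    # scores tie at the cutoff, where a quantile threshold would not.
    n_kept = int(scores.numel() * (1 - compression_ratio))

    # Flattened index // context_len gives the (layer * num_kv_heads + head) slot
    # each kept token belongs to; bincount converts that into per-head counts.
    kept = scores.reshape(-1).topk(n_kept).indices // context_len
    important_per_head = torch.bincount(kept, minlength=num_layers * num_kv_heads).reshape(
        num_layers, num_kv_heads
    )
    assert important_per_head.sum().item() == n_kept  # exact global budget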
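
Note on the tests/fixtures.py hunks: converting the return-style fixtures into generator fixtures gives pytest a teardown phase, so after the last test in each class the pipeline reference is dropped and the CUDA allocator cache is released before the next class-scoped fixture builds another model. A minimal sketch of the pattern, using a hypothetical fixture name and tiny model rather than the repository's kv-press pipelines:

    import pytest
    import torch
    from transformers import pipeline

    @pytest.fixture(scope="class")
    def tiny_gpu_pipeline():
        # Hypothetical tiny model for illustration; assumes a CUDA device.
        pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2", device="cuda")
        yield pipe  # every test in the class runs while the fixture is live
        # `return pipe` gives pytest no teardown hook; `yield` does. Dropping the
        # reference and emptying the CUDA cache frees GPU memory before the next
        # class-scoped fixture allocates its own model.
        del pipe
        torch.cuda.empty_cache()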