Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ jobs:
- name: Install dependencies
run: |
uv sync --all-groups
uv pip install torch==2.10
- run: make test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ reports:

.PHONY: test
test: reports
$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.9.4
PYTHONPATH=. \
$(UV) run pytest \
$(UV) run --no-sync pytest \
--cov-report xml:reports/coverage.xml \
--cov=kvpress/ \
--junitxml=./reports/junit.xml \
Expand Down
13 changes: 9 additions & 4 deletions kvpress/presses/kvcompose_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,16 @@ def compute_important_per_layer(self):
"""
self.compute_composite_scores()

threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio)
self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy()
n_kept = int(self.composite_scores_per_head.numel() * (1 - self.compression_ratio))
kept = self.composite_scores_per_head.reshape(-1).topk(n_kept).indices // self.context_len
bins = self.num_layers * self.num_kv_heads
self.important_per_head = (
torch.bincount(kept, minlength=bins).reshape(self.num_layers, self.num_kv_heads).cpu().numpy()
)

threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
n_kept = int(self.composite_scores_per_layer.numel() * (1 - self.compression_ratio))
kept = self.composite_scores_per_layer.reshape(-1).topk(n_kept).indices // self.context_len
self.important_per_layer = torch.bincount(kept, minlength=self.num_layers).cpu().numpy()

def prepare_important_masks(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "kvpress"
version = "0.5.2"
version = "0.5.3"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
Expand All @@ -13,7 +13,7 @@ readme = "README.md"
dependencies = [
"numpy>=2.0.0,<3",
"torch>=2.3.1,<3",
"transformers>=4.56.0",
"transformers>=4.56.0,<5.3",
"datasets>=2.21.0",
"pandas>=2.2.2,<3",
"accelerate>=1.0.0,<2",
Expand Down
12 changes: 9 additions & 3 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def kv_press_llama3_1_flash_attn_pipeline():
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
yield pipe
del pipe
torch.cuda.empty_cache()


@pytest.fixture(scope="class")
Expand All @@ -94,7 +96,9 @@ def kv_press_llama3_2_flash_attn_pipeline():
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
yield pipe
del pipe
torch.cuda.empty_cache()


@pytest.fixture(scope="class")
Expand All @@ -108,4 +112,6 @@ def kv_press_qwen3_flash_attn_pipeline():
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
yield pipe
del pipe
torch.cuda.empty_cache()
Loading