diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..1aaedc83 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,43 @@ +# AGENTS.md + +## Project Overview + +- `kvpress` is a Python library for KV cache compression using 🤗 transformers. Read `README.md` for full project context. +- Philosophy: keep one place to compare many KV cache compression methods, make evaluation easy, and favor readability over raw speed. +- Core package code lives in `kvpress/`. +- Compression methods are implemented as "presses" in `kvpress/presses/`. +- Evaluation tooling and benchmark datasets live in `evaluation/`. +- Tests live in `tests/`. + +## Environment Setup + +- Package manager: `uv`. Install: `uv sync`. Activate: `source .venv/bin/activate`. + +## Key Entry Points + +- `KVPressTextGenerationPipeline` in `kvpress/pipeline.py` is the primary user-facing API for applying a press during generation. +- `kvpress/__init__.py`: lists all available presses. +- All presses are `@dataclass` classes inheriting from `BasePress` (`kvpress/presses/base_press.py`), and many presses inherit from `ScorerPress` (`kvpress/presses/scorer_press.py`) for score-based pruning. +- Read `BasePress` and `ScorerPress` implementations to understand the press architecture and hook mechanism. + +## Style + +- `make format` (isort + black), `make style` (flake8, mypy, SPDX header check). +- All Python files **must** have SPDX headers: +```python +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +``` + +## Adding or Modifying a Press + +1. Create `kvpress/presses/my_press.py` as a `@dataclass` inheriting from `BasePress` (or `ScorerPress` if the press is score-based). +2. Export it in `kvpress/__init__.py` (add both the import and the `__all__` entry). +3. Add tests in `tests/default_presses.py` (shared parametrized matrix) and/or `tests/presses/` (press-specific tests). Check existing examples to decide. +4. If evaluation support is needed, add a pre-configured instance to `PRESS_REGISTRY` in `evaluation/evaluate_registry.py`. +5. Update `README.md` with press description, link to paper, and source link. +6. Run `make style` and test only new/modified tests. + +## Commits + +- Sign commits with DCO (`git commit -s`) as required by `CONTRIBUTING.md`. diff --git a/README.md b/README.md index 6dce0964..1e950e98 100644 --- a/README.md +++ b/README.md @@ -18,35 +18,22 @@ Deploying long-context LLMs is costly due to the linear growth of the key-value pip install kvpress ``` -For a local installation with all dev dependencies, use uv: +For a local installation, use [uv](https://docs.astral.sh/uv/): ```bash git clone https://github.com/NVIDIA/kvpress.git cd kvpress -uv sync --all-groups +uv sync ``` -
-Advanced installation settings - -To install optional packages, you can use [uv](https://docs.astral.sh/uv/). -To install with flash attention, just run: +To install with all optional dependencies, run: ```bash git clone https://github.com/NVIDIA/kvpress.git cd kvpress -uv sync --extra flash-attn +uv sync --extra eval --extra flash-attn ``` -To install with dependencies for evaluation, run - -```bash -git clone https://github.com/NVIDIA/kvpress.git -cd kvpress -uv sync --extra eval -``` -
- ## Usage KVPress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you: diff --git a/evaluation/benchmarks/longbench/calculate_metrics.py b/evaluation/benchmarks/longbench/calculate_metrics.py index 0d41a051..d92a027a 100644 --- a/evaluation/benchmarks/longbench/calculate_metrics.py +++ b/evaluation/benchmarks/longbench/calculate_metrics.py @@ -5,19 +5,11 @@ import string from collections import Counter +import jieba import numpy as np +from fuzzywuzzy import fuzz from rouge import Rouge -try: - import jieba - from fuzzywuzzy import fuzz -except ImportError as e: - missing_module = str(e).split()[-1].strip("'") # Extract missing module name - print( - f"Module '{missing_module}' not found. \ - If test Longbench, please install it using 'pip install {missing_module}'" - ) - def calculate_metrics(df): predictions = df["predicted_answer"].tolist() diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 2267432f..b7b59bac 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -23,12 +23,12 @@ from kvpress import ( ComposedPress, DecodingPress, + DMSPress, DuoAttentionPress, FinchPress, ObservedAttentionPress, ScorerPress, ThinKPress, - DMSPress, ) logger = logging.getLogger(__name__) diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index 64d1f16d..bec1c0e9 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -13,10 +13,10 @@ from kvpress.presses.base_press import BasePress from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.dms_press import DMSPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.prefill_decoding_press import PrefillDecodingPress -from kvpress.presses.dms_press import DMSPress logger = logging.getLogger(__name__) diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py index eb929eaf..1c57bbeb 100644 --- a/kvpress/presses/fastkvzip_press.py +++ b/kvpress/presses/fastkvzip_press.py @@ -12,7 +12,7 @@ import torch from huggingface_hub import hf_hub_download from torch import nn -from transformers import AutoConfig, PreTrainedModel, Gemma3ForConditionalGeneration +from transformers import AutoConfig, Gemma3ForConditionalGeneration, PreTrainedModel from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress @@ -24,6 +24,7 @@ class FastKVzipGate(nn.Module): """ Fast KVzip gate architecture (https://arxiv.org/abs/2601.17668). """ + def __init__( self, index: int, @@ -79,7 +80,7 @@ def extra_repr(self): def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"): - """ Load trained gate weights """ + """Load trained gate weights""" if not model_name: raise AssertionError("Model_name is empty. Please check load_gate.") state_dict, gate_id = get_gate_weight(model_name) @@ -105,7 +106,7 @@ def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"): def get_gate_id(model_name: str): - """ Get the gate id from model names """ + """Get the gate id from model names""" config = AutoConfig.from_pretrained(model_name) if hasattr(config, "text_config"): config = config.text_config @@ -118,7 +119,7 @@ def get_gate_id(model_name: str): def get_gate_weight(model_name: str): - """ Load trained gate weights from HuggingFace """ + """Load trained gate weights from HuggingFace""" gate_id = get_gate_id(model_name) file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model") diff --git a/kvpress/presses/kvzap_press.py b/kvpress/presses/kvzap_press.py index f64bb4b5..d2c937ad 100644 --- a/kvpress/presses/kvzap_press.py +++ b/kvpress/presses/kvzap_press.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional, Literal +from typing import Literal, Optional import torch import torch.nn as nn diff --git a/kvzap/data.py b/kvzap/data.py index 99373ff8..beab3a81 100644 --- a/kvzap/data.py +++ b/kvzap/data.py @@ -15,8 +15,8 @@ from datasets import load_dataset from tqdm.auto import tqdm from transformers import PreTrainedModel, PreTrainedTokenizerBase -from transformers.models.llama.modeling_llama import repeat_kv from transformers.integrations.finegrained_fp8 import FP8Linear +from transformers.models.llama.modeling_llama import repeat_kv def load_nemotron_dataset( diff --git a/kvzap/evaluate_aime.py b/kvzap/evaluate_aime.py index 2c224440..8625b87b 100644 --- a/kvzap/evaluate_aime.py +++ b/kvzap/evaluate_aime.py @@ -3,14 +3,14 @@ import json import uuid -from tqdm import tqdm -from pathlib import Path from contextlib import nullcontext +from pathlib import Path -from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer -from kvpress import KVzapPress, DMSPress +from kvpress import DMSPress, KVzapPress def calculate_metrics(df): diff --git a/kvzap/train.py b/kvzap/train.py index dc6831c4..52e13d26 100644 --- a/kvzap/train.py +++ b/kvzap/train.py @@ -9,23 +9,20 @@ KVzapPress to compress the KV cache during inference. """ -import numpy as np from pathlib import Path -from tqdm.auto import tqdm - +import numpy as np import torch -from torch import nn - +from sklearn.linear_model import Ridge from skorch import NeuralNetRegressor -from skorch.callbacks import LRScheduler, GradientNormClipping +from skorch.callbacks import GradientNormClipping, LRScheduler from skorch.dataset import ValidSplit -from sklearn.linear_model import Ridge - +from torch import nn +from tqdm.auto import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config -from kvpress.presses.kvzap_press import KVzapModel, KVzapConfig -from kvzap.data import load_nemotron_dataset, KVzapDataCollector +from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel +from kvzap.data import KVzapDataCollector, load_nemotron_dataset def train_mlp( @@ -188,11 +185,12 @@ def train( print(f"Loading model {model_name} and tokenizer") quantization_config = FineGrainedFP8Config() if fp8 else None model = AutoModelForCausalLM.from_pretrained( - model_name, dtype="auto", - device_map="auto", - attn_implementation="eager", - quantization_config=quantization_config, - ) + model_name, + dtype="auto", + device_map="auto", + attn_implementation="eager", + quantization_config=quantization_config, + ) tokenizer = AutoTokenizer.from_pretrained(model_name) # Load dataset diff --git a/pyproject.toml b/pyproject.toml index 3f68ad88..625271b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kvpress" -version = "0.5.0" +version = "0.5.1" description = "Efficiently compress the KV cache of any pretrained transformer" authors = [ { name = "Simon Jegou" }, @@ -14,8 +14,6 @@ dependencies = [ "numpy>=2.0.0,<3", "torch>=2.3.1,<3", "transformers>=5.0.0", - "sentencepiece>=0.2.0,<0.3", - "protobuf>=5.27.2,<6", "datasets>=2.21.0,<3", "pandas>=2.2.2,<3", "accelerate>=1.0.0,<2", @@ -31,6 +29,8 @@ eval = [ "tqdm>=4.66.4,<5", "scipy>=1.13.1,<2", "bert-score>=0.3.13,<0.4", + "jieba>=0.42.1", + "fuzzywuzzy>=0.18.0", ] flash-attn = [ "flash-attn" @@ -51,6 +51,8 @@ dev = [ "bs4>=0.0.2,<0.0.3", "nvitop>=1.3.2,<2", "matplotlib>=3.9.0,<4", + "sentencepiece>=0.2.0,<0.3", + "protobuf>=5.27.2,<6", ] diff --git a/tests/default_presses.py b/tests/default_presses.py index 5b1278ed..413f1ea6 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -26,8 +26,8 @@ ThinKPress, TOVAPress, ) -from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel from kvpress.presses.fastkvzip_press import FastKVzipGate +from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel class TestDuoAttentionPress(DuoAttentionPress): diff --git a/tests/fixtures.py b/tests/fixtures.py index a95578ea..4462e46a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -20,9 +20,7 @@ def unit_test_model(): @pytest.fixture(scope="session") def unit_test_model_output_attention(): - model = AutoModelForCausalLM.from_pretrained( - "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager" - ).eval() + model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test", attn_implementation="eager").eval() return model.to(get_device()) diff --git a/tests/presses/test_head_compression.py b/tests/presses/test_head_compression.py index f9138f86..ec145db0 100644 --- a/tests/presses/test_head_compression.py +++ b/tests/presses/test_head_compression.py @@ -4,8 +4,8 @@ import torch from transformers import DynamicCache -from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, DMSPress -from tests.fixtures import unit_test_model, kv_press_unit_test_pipeline # noqa: F401 +from kvpress import AdaKVPress, CriticalAdaKVPress, DMSPress, KnormPress, KVzipPress, RandomPress +from tests.fixtures import kv_press_unit_test_pipeline, unit_test_model # noqa: F401 def compute_masked_percentage(module, batch_size, num_key_value_heads, seq_len): diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index 137c6b4e..dd30c7f4 100644 --- a/tests/test_decoding_compression.py +++ b/tests/test_decoding_compression.py @@ -12,14 +12,14 @@ from kvpress import ( CompactorPress, + DecodingPress, + KnormPress, + KVzapPress, LeverageScorePress, NonCausalAttnPress, + PrefillDecodingPress, PyramidKVPress, ScorerPress, - DecodingPress, - KnormPress, - PrefillDecodingPress, - KVzapPress, ) from tests.default_presses import default_presses