Merged
43 changes: 43 additions & 0 deletions AGENTS.md
@@ -0,0 +1,43 @@
# AGENTS.md

## Project Overview

- `kvpress` is a Python library for KV cache compression using 🤗 transformers. Read `README.md` for full project context.
- Philosophy: keep one place to compare many KV cache compression methods, make evaluation easy, and favor readability over raw speed.
- Core package code lives in `kvpress/`.
- Compression methods are implemented as "presses" in `kvpress/presses/`.
- Evaluation tooling and benchmark datasets live in `evaluation/`.
- Tests live in `tests/`.

## Environment Setup

- Package manager: `uv`. Install: `uv sync`. Activate: `source .venv/bin/activate`.

## Key Entry Points

- `KVPressTextGenerationPipeline` in `kvpress/pipeline.py` is the primary user-facing API for applying a press during generation.
- `kvpress/__init__.py`: lists all available presses.
- All presses are `@dataclass` classes inheriting from `BasePress` (`kvpress/presses/base_press.py`), and many presses inherit from `ScorerPress` (`kvpress/presses/scorer_press.py`) for score-based pruning.
- Read `BasePress` and `ScorerPress` implementations to understand the press architecture and hook mechanism.
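The score-then-prune idea behind `ScorerPress` can be sketched independently of the library. This is a toy illustration, not kvpress code: the real hook signatures live in `kvpress/presses/scorer_press.py`, and the key-norm scoring rule below is just one possible scorer.

```python
import numpy as np

def prune_kv(keys, values, scores, compression_ratio):
    """Keep the highest-scoring fraction of KV entries per head.

    keys/values: (num_heads, seq_len, head_dim); scores: (num_heads, seq_len).
    Illustrative only -- kvpress operates on the transformers cache via hooks.
    """
    seq_len = keys.shape[1]
    n_keep = int(seq_len * (1 - compression_ratio))
    idx = np.argsort(scores, axis=-1)[:, -n_keep:]  # top-n_keep entries per head
    keys = np.take_along_axis(keys, idx[:, :, None], axis=1)
    values = np.take_along_axis(values, idx[:, :, None], axis=1)
    return keys, values

rng = np.random.default_rng(0)
k = rng.normal(size=(2, 8, 4))
v = rng.normal(size=(2, 8, 4))
scores = np.linalg.norm(k, axis=-1)  # toy scorer: L2 norm of each key
k2, v2 = prune_kv(k, v, scores, compression_ratio=0.5)
print(k2.shape)  # (2, 4, 4)
```

A concrete press would replace the toy scorer with its own scoring rule; the pruning step is shared.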

## Style

- Run `make format` (isort + black) and `make style` (flake8, mypy, SPDX header check) before committing.
- All Python files **must** have SPDX headers:
```python
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```

## Adding or Modifying a Press

1. Create `kvpress/presses/my_press.py` as a `@dataclass` inheriting from `BasePress` (or `ScorerPress` if the press is score-based).
2. Export it in `kvpress/__init__.py` (add both the import and the `__all__` entry).
3. Add tests in `tests/default_presses.py` (shared parametrized matrix) and/or `tests/presses/` (press-specific tests); check existing presses to decide which fits.
4. If evaluation support is needed, add a pre-configured instance to `PRESS_REGISTRY` in `evaluation/evaluate_registry.py`.
5. Update `README.md` with press description, link to paper, and source link.
6. Run `make style` and run only the new or modified tests.
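Steps 1-2 can be sketched as follows. `ScorerPressStandIn` is a stand-in for kvpress's `ScorerPress` (the real base class and its `score` signature are in `kvpress/presses/scorer_press.py`; names here are assumptions for illustration):

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class ScorerPressStandIn:
    """Stand-in for kvpress's ScorerPress: subclasses only supply a scorer."""
    compression_ratio: float = 0.0

@dataclass
class MyPress(ScorerPressStandIn):
    """Toy press: score each KV entry by its key's L2 norm."""

    def score(self, keys: np.ndarray) -> np.ndarray:
        # keys: (num_heads, seq_len, head_dim) -> scores: (num_heads, seq_len)
        return np.linalg.norm(keys, axis=-1)

# Step 2's export would then add to kvpress/__init__.py:
#   from kvpress.presses.my_press import MyPress   # plus an __all__ entry
press = MyPress(compression_ratio=0.5)
print(press.score(np.ones((1, 3, 4))).shape)  # (1, 3)
```

The `@dataclass` pattern is what lets every press expose `compression_ratio` (and any press-specific knobs) as plain constructor arguments.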

## Commits

- Sign commits with DCO (`git commit -s`) as required by `CONTRIBUTING.md`.
21 changes: 4 additions & 17 deletions README.md
@@ -18,35 +18,22 @@ Deploying long-context LLMs is costly due to the linear growth of the key-value
pip install kvpress
```

For a local installation with all dev dependencies, use uv:
For a local installation, use [uv](https://docs.astral.sh/uv/):

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --all-groups
uv sync
```
<details><summary>
Advanced installation settings
</summary>

To install optional packages, you can use [uv](https://docs.astral.sh/uv/).
To install with flash attention, just run:
To install with all optional dependencies, run:

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra flash-attn
uv sync --extra eval --extra flash-attn
```

To install with dependencies for evaluation, run

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra eval
```
</details>

## Usage

KVPress provides a set of "presses" that compress the KV cache during the prefilling phase. Each press has a `compression_ratio` attribute that measures the fraction of the cache that is pruned. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`: it is automatically registered as a transformers pipeline under the name "kv-press-text-generation" when kvpress is imported, and it handles chat templates and tokenization for you:
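The README's own pipeline snippet is elided above. As a toy illustration of what the `compression_ratio` attribute means for cache size (an interpretation of the attribute, not kvpress code):

```python
def kept_entries(seq_len: int, compression_ratio: float) -> int:
    """KV entries remaining after pruning a prefill of `seq_len` tokens,
    assuming compression_ratio is the fraction of entries removed."""
    return round(seq_len * (1 - compression_ratio))

print(kept_entries(4096, 0.25))  # 3072
```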
12 changes: 2 additions & 10 deletions evaluation/benchmarks/longbench/calculate_metrics.py
@@ -5,19 +5,11 @@
import string
from collections import Counter

import jieba
import numpy as np
from fuzzywuzzy import fuzz
from rouge import Rouge

try:
import jieba
from fuzzywuzzy import fuzz
except ImportError as e:
missing_module = str(e).split()[-1].strip("'") # Extract missing module name
print(
f"Module '{missing_module}' not found. \
If test Longbench, please install it using 'pip install {missing_module}'"
)


def calculate_metrics(df):
predictions = df["predicted_answer"].tolist()
2 changes: 1 addition & 1 deletion evaluation/evaluate.py
@@ -23,12 +23,12 @@
from kvpress import (
ComposedPress,
DecodingPress,
DMSPress,
DuoAttentionPress,
FinchPress,
ObservedAttentionPress,
ScorerPress,
ThinKPress,
DMSPress,
)

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion kvpress/pipeline.py
@@ -13,10 +13,10 @@

from kvpress.presses.base_press import BasePress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
from kvpress.presses.dms_press import DMSPress

logger = logging.getLogger(__name__)

9 changes: 5 additions & 4 deletions kvpress/presses/fastkvzip_press.py
@@ -12,7 +12,7 @@
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from transformers import AutoConfig, PreTrainedModel, Gemma3ForConditionalGeneration
from transformers import AutoConfig, Gemma3ForConditionalGeneration, PreTrainedModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
@@ -24,6 +24,7 @@ class FastKVzipGate(nn.Module):
"""
Fast KVzip gate architecture (https://arxiv.org/abs/2601.17668).
"""

def __init__(
self,
index: int,
@@ -79,7 +80,7 @@ def extra_repr(self):


def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"):
""" Load trained gate weights """
"""Load trained gate weights"""
if not model_name:
raise AssertionError("Model_name is empty. Please check load_gate.")
state_dict, gate_id = get_gate_weight(model_name)
@@ -105,7 +106,7 @@ def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"):


def get_gate_id(model_name: str):
""" Get the gate id from model names """
"""Get the gate id from model names"""
config = AutoConfig.from_pretrained(model_name)
if hasattr(config, "text_config"):
config = config.text_config
@@ -118,7 +119,7 @@ def get_gate_id(model_name: str):


def get_gate_weight(model_name: str):
""" Load trained gate weights from HuggingFace """
"""Load trained gate weights from HuggingFace"""
gate_id = get_gate_id(model_name)
file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model")

2 changes: 1 addition & 1 deletion kvpress/presses/kvzap_press.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Optional, Literal
from typing import Literal, Optional

import torch
import torch.nn as nn
2 changes: 1 addition & 1 deletion kvzap/data.py
@@ -15,8 +15,8 @@
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.integrations.finegrained_fp8 import FP8Linear
from transformers.models.llama.modeling_llama import repeat_kv


def load_nemotron_dataset(
8 changes: 4 additions & 4 deletions kvzap/evaluate_aime.py
@@ -3,14 +3,14 @@

import json
import uuid
from tqdm import tqdm
from pathlib import Path
from contextlib import nullcontext
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import KVzapPress, DMSPress
from kvpress import DMSPress, KVzapPress


def calculate_metrics(df):
28 changes: 13 additions & 15 deletions kvzap/train.py
@@ -9,23 +9,20 @@
KVzapPress to compress the KV cache during inference.
"""

import numpy as np
from pathlib import Path

from tqdm.auto import tqdm

import numpy as np
import torch
from torch import nn

from sklearn.linear_model import Ridge
from skorch import NeuralNetRegressor
from skorch.callbacks import LRScheduler, GradientNormClipping
from skorch.callbacks import GradientNormClipping, LRScheduler
from skorch.dataset import ValidSplit
from sklearn.linear_model import Ridge

from torch import nn
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config

from kvpress.presses.kvzap_press import KVzapModel, KVzapConfig
from kvzap.data import load_nemotron_dataset, KVzapDataCollector
from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel
from kvzap.data import KVzapDataCollector, load_nemotron_dataset


def train_mlp(
@@ -188,11 +185,12 @@ def train(
print(f"Loading model {model_name} and tokenizer")
quantization_config = FineGrainedFP8Config() if fp8 else None
model = AutoModelForCausalLM.from_pretrained(
model_name, dtype="auto",
device_map="auto",
attn_implementation="eager",
quantization_config=quantization_config,
)
model_name,
dtype="auto",
device_map="auto",
attn_implementation="eager",
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load dataset
8 changes: 5 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kvpress"
version = "0.5.0"
version = "0.5.1"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
@@ -14,8 +14,6 @@ dependencies = [
"numpy>=2.0.0,<3",
"torch>=2.3.1,<3",
"transformers>=5.0.0",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
"datasets>=2.21.0,<3",
"pandas>=2.2.2,<3",
"accelerate>=1.0.0,<2",
@@ -31,6 +29,8 @@ eval = [
"tqdm>=4.66.4,<5",
"scipy>=1.13.1,<2",
"bert-score>=0.3.13,<0.4",
"jieba>=0.42.1",
"fuzzywuzzy>=0.18.0",
]
flash-attn = [
"flash-attn"
@@ -51,6 +51,8 @@ dev = [
"bs4>=0.0.2,<0.0.3",
"nvitop>=1.3.2,<2",
"matplotlib>=3.9.0,<4",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
]


2 changes: 1 addition & 1 deletion tests/default_presses.py
@@ -26,8 +26,8 @@
ThinKPress,
TOVAPress,
)
from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel
from kvpress.presses.fastkvzip_press import FastKVzipGate
from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel


class TestDuoAttentionPress(DuoAttentionPress):
4 changes: 1 addition & 3 deletions tests/fixtures.py
@@ -20,9 +20,7 @@ def unit_test_model():

@pytest.fixture(scope="session")
def unit_test_model_output_attention():
model = AutoModelForCausalLM.from_pretrained(
"MaxJeblick/llama2-0b-unit-test", attn_implementation="eager"
).eval()
model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test", attn_implementation="eager").eval()
return model.to(get_device())


4 changes: 2 additions & 2 deletions tests/presses/test_head_compression.py
@@ -4,8 +4,8 @@
import torch
from transformers import DynamicCache

from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, RandomPress, DMSPress
from tests.fixtures import unit_test_model, kv_press_unit_test_pipeline # noqa: F401
from kvpress import AdaKVPress, CriticalAdaKVPress, DMSPress, KnormPress, KVzipPress, RandomPress
from tests.fixtures import kv_press_unit_test_pipeline, unit_test_model # noqa: F401


def compute_masked_percentage(module, batch_size, num_key_value_heads, seq_len):
8 changes: 4 additions & 4 deletions tests/test_decoding_compression.py
@@ -12,14 +12,14 @@

from kvpress import (
CompactorPress,
DecodingPress,
KnormPress,
KVzapPress,
LeverageScorePress,
NonCausalAttnPress,
PrefillDecodingPress,
PyramidKVPress,
ScorerPress,
DecodingPress,
KnormPress,
PrefillDecodingPress,
KVzapPress,
)
from tests.default_presses import default_presses
