Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions notebooks/speed_and_memory.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
"outputs": [],
"source": [
"# Run these below to be able to use quantized cache \n",
"# See also: https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers\n",
"#!pip install git+https://github.com/huggingface/transformers.git --upgrade\n",
"#!pip install -U optimum-quanto"
]
},
Expand Down
17 changes: 16 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,25 @@ def kv_press_danube_pipeline():
)


@pytest.fixture(scope="session")
def kv_press_llama3_2_flash_attn_pipeline():
    """Session-scoped kv-press text-generation pipeline.

    Loads meta-llama/Llama-3.2-1B-Instruct on cuda:0 with the
    flash_attention_2 attention implementation and auto dtype.
    Built once per test session to avoid repeated model loads.
    """
    return pipeline(
        "kv-press-text-generation",
        model="meta-llama/Llama-3.2-1B-Instruct",
        device="cuda:0",
        torch_dtype="auto",
        model_kwargs={"attn_implementation": "flash_attention_2"},
    )


@pytest.fixture(scope="session")
def kv_press_llama3_1_flash_attn_pipeline():
device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
ckpt = "meta-llama/Llama-3.1-8B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
Expand Down
9 changes: 8 additions & 1 deletion tests/integration/test_ruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datasets
import pytest
import torch
from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
from transformers import DynamicCache, QuantoQuantizedCache, QuantizedCacheConfig
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from tests.default_presses import default_presses
Expand All @@ -26,6 +26,13 @@ def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press
cls = press_dict["cls"]
kwargs = press_dict["kwargs"][0]
press = cls(**kwargs)
if not hasattr(cls, "compression_ratio"):
pytest.skip(reason="Press does not support compression_ratio")
# set compression ratio to a small value for testing
try:
press.compression_ratio = 0.1
except AttributeError:
pytest.skip(reason="Press does not support setting compression_ratio")

if cache == "dynamic":
cache = DynamicCache()
Expand Down