diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index c8f56d4e..7926ed2b 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -34,6 +34,7 @@ StreamingLLMPress, ThinKPress, TOVAPress, + QFilterPress, ) logger = logging.getLogger(__name__) @@ -75,9 +76,10 @@ "think": ThinKPress(), "tova": TOVAPress(), "duo_attention": DuoAttentionPress(), + "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True), "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20), + "qfilter": QFilterPress(), "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]), - "full_kv": ExpectedAttentionPress(0.0), } diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index e8966c6c..40abb3d1 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from cachetools import cached, LRUCache # type: ignore[import-untyped] from contextlib import contextmanager from dataclasses import dataclass, field from io import StringIO @@ -8,6 +9,9 @@ import numpy as np import requests # type: ignore[import-untyped] import torch +from datasets import load_dataset +from transformers import AutoTokenizer +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from kvpress.presses.base_press import BasePress @@ -20,6 +24,8 @@ "mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501 } +cache = LRUCache(maxsize=128) + @dataclass class DuoAttentionPress(BasePress): @@ -29,12 +35,15 @@ class DuoAttentionPress(BasePress): Splits attention heads into two types: - Retrieval heads: use the full KV cache - Streaming heads: use only sink and recent tokens. - - Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/ The higher the head_compression_ratio, the more streaming heads are used. + + Head classification is based on scores. + - If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/ + - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly """ head_compression_ratio: float = 0.0 + on_the_fly_scoring: bool = False compression_ratio_: float = field(init=False, default=None) recent_size: int = field(init=False, default=None) sink_size: int = field(init=False, default=None) @@ -45,7 +54,10 @@ def __post_init_from_model__(self, model): Initialize sink_size, recent_size, and streaming_mask from a model """ # Load attention pattern from the DuoAttention repo - self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) + if self.on_the_fly_scoring: + self.sink_size, self.recent_size, head_scores = 128, 256, duo_attention_on_the_fly(model) + else: + self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) # Define retrieval and streaming heads through a binary mask n_pruned = round(head_scores.size * self.head_compression_ratio) @@ -82,6 +94,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): return keys, values @staticmethod + @cached(cache, key=lambda model: model.config.name_or_path) def load_attention_pattern(model): """ Load the attention pattern from the DuoAttention repo @@ -108,3 +121,68 @@ def __call__(self, model): self.__post_init_from_model__(model) with super().__call__(model): yield + + +@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len)) +def duo_attention_on_the_fly(model, num_samples=50, q_len=500): + """ + New experimental method to quickly compute DuoAttention scores: + - Compute the mean query and key on num_samples random samples from BookSum + - Repeat the mean query and key q_len times and apply RoPE to get (Q, K) + - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve" + These scores could also be saved to avoid recomputing them but this method is still experimental + """ + + tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + num_heads = model.config.num_attention_heads + num_key_value_heads = model.config.num_key_value_heads + num_key_value_groups = num_heads // num_key_value_heads + + # Load data + dataset = load_dataset("kmfoda/booksum", split="train").to_pandas() + texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist() + + # Initialize variables + position_ids = torch.arange(q_len).unsqueeze(0) + scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads)) + + # Compute scores + for text in texts: + with torch.no_grad(): + # Compute hidden states + inputs = tokenizer(text, return_tensors="pt").to(model.device) + hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1]) + + for layer_idx, h in enumerate(hidden_states): + module = model.model.layers[layer_idx] + d = module.self_attn.head_dim + h = module.input_layernorm(h) + + # Mean query + q = module.self_attn.q_proj(h) + q = q.view(1, q.shape[1], -1, d) + q = q.mean(dim=1, keepdim=True) + q = q.repeat(1, q_len, 1, 1).transpose(1, 2) + + # Mean key + k = module.self_attn.k_proj(h) + k = k.view(1, k.shape[1], -1, d) + k = k.mean(dim=1, keepdim=True) + k = k.repeat(1, q_len, 1, 1).transpose(1, 2) + + # Apply RoPE + cos, sin = model.model.rotary_emb(h, position_ids.to(h.device)) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + k = k.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention weights for the last token + attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5) + attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze() + + # Compute score: area under the cumulated attention curve + s = torch.cumsum(attn_weights, dim=1).mean(1) + s = s.view(-1, num_key_value_groups).mean(1) + + # Store the scores + scores[layer_idx] += s.cpu() / num_samples + return scores.numpy() diff --git a/kvpress/presses/qfilter_press.py b/kvpress/presses/qfilter_press.py index e1cdd7eb..552d3b4f 100644 --- a/kvpress/presses/qfilter_press.py +++ b/kvpress/presses/qfilter_press.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from functools import cache from contextlib import contextmanager from dataclasses import dataclass @@ -32,6 +33,7 @@ def __post_init_from_model__(self, model): self.q_filters = self.q_filters.to(model.dtype) @staticmethod + @cache def load_q_filters(model_name): try: return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters diff --git a/pyproject.toml b/pyproject.toml index 56ee5532..3495242b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ rouge = "^1.0.1" bert-score = "^0.3.13" accelerate = "^1.0.0" requests = "^2.32.3" +cachetools = "^5.5.2" [tool.poetry.dev-dependencies] pytest = "^7.0.0"