From 5984ad2a767f30dffbccc2347a094fed6b98e405 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 11:51:18 +0000 Subject: [PATCH 1/6] Add DuoAttention on the fly Signed-off-by: SimJeg --- evaluation/evaluate.py | 1 + kvpress/presses/duo_attention_press.py | 101 ++++++++++++++++++++++--- kvpress/presses/qfilter_press.py | 8 +- 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 7ee65d59..8087cb34 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -65,6 +65,7 @@ "think": ThinKPress(), "tova": TOVAPress(), "duo_attention": DuoAttentionPress(), + "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True), "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20), } diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index e8966c6c..02713ee7 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -1,13 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from time import time from contextlib import contextmanager from dataclasses import dataclass, field from io import StringIO + import numpy as np import requests # type: ignore[import-untyped] import torch +from datasets import load_dataset +from transformers import AutoTokenizer +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from kvpress.presses.base_press import BasePress @@ -29,12 +34,15 @@ class DuoAttentionPress(BasePress): Splits attention heads into two types: - Retrieval heads: use the full KV cache - Streaming heads: use only sink and recent tokens. - - Head classification is based on scores loaded from https://github.com/mit-han-lab/duo-attention/ The higher the head_compression_ratio, the more streaming heads are used. + + Head classification is based on scores. + - If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/ + - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly """ head_compression_ratio: float = 0.0 + on_the_fly_scoring: bool = False compression_ratio_: float = field(init=False, default=None) recent_size: int = field(init=False, default=None) sink_size: int = field(init=False, default=None) @@ -44,15 +52,19 @@ def __post_init_from_model__(self, model): """ Initialize sink_size, recent_size, and streaming_mask from a model """ - # Load attention pattern from the DuoAttention repo - self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) - - # Define retrieval and streaming heads through a binary mask - n_pruned = round(head_scores.size * self.head_compression_ratio) - self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device) - if n_pruned > 0: - indices = np.argsort(head_scores, axis=None)[:n_pruned] - self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True + if getattr(self, "_post_init_model_name", None) != model.config.name_or_path: + # Load attention pattern from the DuoAttention repo + self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) + if self.on_the_fly_scoring: + head_scores = duo_attention_on_the_fly(model) + + # Define retrieval and streaming heads through a binary mask + n_pruned = round(head_scores.size * self.head_compression_ratio) + self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device) + if n_pruned > 0: + indices = np.argsort(head_scores, axis=None)[:n_pruned] + self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True + self._post_init_model_name = model.config.name_or_path @property def compression_ratio(self) -> float: @@ -108,3 +120,70 @@ def __call__(self, model): self.__post_init_from_model__(model) with super().__call__(model): yield + + +def duo_attention_on_the_fly(model, num_samples=50, q_len=500): + """ + New experimental method to quickly compute DuoAttention scores: + - Compute the mean query and key on num_samples random samples from BookSum + - Repeat the mean query and key q_len times and apply RoPE to get (Q, K) + - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve" + These scores could also be saved to avoid recomputing them but this method is still experimental + """ + + start = time() + print(f"Starting computation of DuoAttention scores based on {num_samples} samples") + tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + num_heads = model.config.num_attention_heads + num_key_value_heads = model.config.num_key_value_heads + num_key_value_groups = num_heads // num_key_value_heads + + # Load data + dataset = load_dataset("kmfoda/booksum", split="train").to_pandas() + texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist() + + # Initialize variables + position_ids = torch.arange(q_len).unsqueeze(0) + scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads)) + + # Compute scores + for text in texts: + with torch.no_grad(): + # Compute hidden states + inputs = tokenizer(text, return_tensors="pt").to(model.device) + hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1]) + + for layer_idx, h in enumerate(hidden_states): + module = model.model.layers[layer_idx] + d = module.self_attn.head_dim + h = module.input_layernorm(h) + + # Mean query + q = module.self_attn.q_proj(h) + q = q.view(1, q.shape[1], -1, d) + q = q.mean(dim=1, keepdim=True) + q = q.repeat(1, q_len, 1, 1).transpose(1, 2) + + # Mean key + k = module.self_attn.k_proj(h) + k = k.view(1, k.shape[1], -1, d) + k = k.mean(dim=1, keepdim=True) + k = k.repeat(1, q_len, 1, 1).transpose(1, 2) + + # Apply RoPE + cos, sin = model.model.rotary_emb(h, position_ids.to(h.device)) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + k = k.repeat_interleave(num_key_value_groups, dim=1) + + # Compute attention weights for the last token + attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5) + attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze() + + # Compute score: area under the cumulated attention curve + s = torch.cumsum(attn_weights, dim=1).mean(1).cpu() + s = s.view(-1, num_key_value_groups).mean(1) + + # Store the scores + scores[layer_idx] += s / num_samples + print(f"Finished computation of DuoAttention scores in {time() - start:.2f}s") + return scores.numpy() diff --git a/kvpress/presses/qfilter_press.py b/kvpress/presses/qfilter_press.py index e1cdd7eb..338314e6 100644 --- a/kvpress/presses/qfilter_press.py +++ b/kvpress/presses/qfilter_press.py @@ -27,9 +27,11 @@ class QFilterPress(ScorerPress): """ def __post_init_from_model__(self, model): - model_name = model.config.name_or_path.split("/")[-1] - self.q_filters = self.load_q_filters(model_name) - self.q_filters = self.q_filters.to(model.dtype) + if getattr(self, "_post_init_model_name", None) != model.config.name_or_path: + model_name = model.config.name_or_path.split("/")[-1] + self.q_filters = self.load_q_filters(model_name) + self.q_filters = self.q_filters.to(model.dtype) + self._post_init_model_name = model.config.name_or_path @staticmethod def load_q_filters(model_name): From c844335423157833d8b8ad16ce170ba71b095e78 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 12:39:51 +0000 Subject: [PATCH 2/6] Fix flake8 Signed-off-by: SimJeg --- kvpress/presses/duo_attention_press.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index 02713ee7..aab8099d 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -38,7 +38,7 @@ class DuoAttentionPress(BasePress): Head classification is based on scores. - If on_the_fly_scoring=False, scores are loaded from https://github.com/mit-han-lab/duo-attention/ - - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly + - (experimental) If on_the_fly_scoring=True, scores are computed using duo_attention_on_the_fly """ head_compression_ratio: float = 0.0 From f527ac749996edc8ac2485937306238eb9028a4f Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 13:03:44 +0000 Subject: [PATCH 3/6] Use cachetools Signed-off-by: SimJeg --- evaluation/evaluate.py | 2 ++ kvpress/presses/duo_attention_press.py | 29 +++++++++++++------------- kvpress/presses/qfilter_press.py | 10 ++++----- pyproject.toml | 1 + 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 8087cb34..82e84f00 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -30,6 +30,7 @@ StreamingLLMPress, ThinKPress, TOVAPress, + QFilterPress, ) logger = logging.getLogger(__name__) @@ -67,6 +68,7 @@ "duo_attention": DuoAttentionPress(), "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True), "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20), + "qfilter": QFilterPress(), } diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index aab8099d..e426b71f 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from time import time +from cachetools import cached, LRUCache from contextlib import contextmanager from dataclasses import dataclass, field from io import StringIO - +from time import time import numpy as np import requests # type: ignore[import-untyped] @@ -48,23 +48,22 @@ class DuoAttentionPress(BasePress): sink_size: int = field(init=False, default=None) streaming_mask: torch.Tensor = field(init=False, default=None) + @cached(LRUCache(maxsize=128), key=lambda self, model: model.config.name_or_path) def __post_init_from_model__(self, model): """ Initialize sink_size, recent_size, and streaming_mask from a model """ - if getattr(self, "_post_init_model_name", None) != model.config.name_or_path: - # Load attention pattern from the DuoAttention repo - self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) - if self.on_the_fly_scoring: - head_scores = duo_attention_on_the_fly(model) - - # Define retrieval and streaming heads through a binary mask - n_pruned = round(head_scores.size * self.head_compression_ratio) - self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device) - if n_pruned > 0: - indices = np.argsort(head_scores, axis=None)[:n_pruned] - self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True - self._post_init_model_name = model.config.name_or_path + # Load attention pattern from the DuoAttention repo + self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) + if self.on_the_fly_scoring: + head_scores = duo_attention_on_the_fly(model) + + # Define retrieval and streaming heads through a binary mask + n_pruned = round(head_scores.size * self.head_compression_ratio) + self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device) + if n_pruned > 0: + indices = np.argsort(head_scores, axis=None)[:n_pruned] + self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True @property def compression_ratio(self) -> float: diff --git a/kvpress/presses/qfilter_press.py b/kvpress/presses/qfilter_press.py index 338314e6..038fea4a 100644 --- a/kvpress/presses/qfilter_press.py +++ b/kvpress/presses/qfilter_press.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from cachetools import cached, LRUCache from contextlib import contextmanager from dataclasses import dataclass @@ -26,12 +27,11 @@ class QFilterPress(ScorerPress): Prune KV pairs with Q-filters """ + @cached(LRUCache(maxsize=128), key=lambda self, model: model.config.name_or_path) def __post_init_from_model__(self, model): - if getattr(self, "_post_init_model_name", None) != model.config.name_or_path: - model_name = model.config.name_or_path.split("/")[-1] - self.q_filters = self.load_q_filters(model_name) - self.q_filters = self.q_filters.to(model.dtype) - self._post_init_model_name = model.config.name_or_path + model_name = model.config.name_or_path.split("/")[-1] + self.q_filters = self.load_q_filters(model_name) + self.q_filters = self.q_filters.to(model.dtype) @staticmethod def load_q_filters(model_name): diff --git a/pyproject.toml b/pyproject.toml index 56ee5532..3495242b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ rouge = "^1.0.1" bert-score = "^0.3.13" accelerate = "^1.0.0" requests = "^2.32.3" +cachetools = "^5.5.2" [tool.poetry.dev-dependencies] pytest = "^7.0.0" From d754356f001e2018763fedbd65484737534614ad Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 13:26:10 +0000 Subject: [PATCH 4/6] Update LRU cache usage Signed-off-by: SimJeg --- kvpress/presses/duo_attention_press.py | 10 +++++++--- kvpress/presses/qfilter_press.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index e426b71f..ae4b2bae 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -25,6 +25,8 @@ "mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501 } +cache = LRUCache(maxsize=128) + @dataclass class DuoAttentionPress(BasePress): @@ -48,15 +50,15 @@ class DuoAttentionPress(BasePress): sink_size: int = field(init=False, default=None) streaming_mask: torch.Tensor = field(init=False, default=None) - @cached(LRUCache(maxsize=128), key=lambda self, model: model.config.name_or_path) def __post_init_from_model__(self, model): """ Initialize sink_size, recent_size, and streaming_mask from a model """ # Load attention pattern from the DuoAttention repo - self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) if self.on_the_fly_scoring: - head_scores = duo_attention_on_the_fly(model) + self.sink_size, self.recent_size, head_scores = 128, 256, duo_attention_on_the_fly(model) + else: + self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model) # Define retrieval and streaming heads through a binary mask n_pruned = round(head_scores.size * self.head_compression_ratio) @@ -93,6 +95,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): return keys, values @staticmethod + @cached(cache, key=lambda model: model.config.name_or_path) def load_attention_pattern(model): """ Load the attention pattern from the DuoAttention repo @@ -121,6 +124,7 @@ def __call__(self, model): yield +@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len)) def duo_attention_on_the_fly(model, num_samples=50, q_len=500): """ New experimental method to quickly compute DuoAttention scores: diff --git a/kvpress/presses/qfilter_press.py b/kvpress/presses/qfilter_press.py index 038fea4a..552d3b4f 100644 --- a/kvpress/presses/qfilter_press.py +++ b/kvpress/presses/qfilter_press.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from cachetools import cached, LRUCache +from functools import cache from contextlib import contextmanager from dataclasses import dataclass @@ -27,13 +27,13 @@ class QFilterPress(ScorerPress): Prune KV pairs with Q-filters """ - @cached(LRUCache(maxsize=128), key=lambda self, model: model.config.name_or_path) def __post_init_from_model__(self, model): model_name = model.config.name_or_path.split("/")[-1] self.q_filters = self.load_q_filters(model_name) self.q_filters = self.q_filters.to(model.dtype) @staticmethod + @cache def load_q_filters(model_name): try: return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters From d09960c1b97241b826a66401fff1926491c021af Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 13:36:23 +0000 Subject: [PATCH 5/6] Fix mypy Signed-off-by: SimJeg --- kvpress/presses/duo_attention_press.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index ae4b2bae..cd39e8aa 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from cachetools import cached, LRUCache +from cachetools import cached, LRUCache # type: ignore[import-untyped] from contextlib import contextmanager from dataclasses import dataclass, field from io import StringIO From 338f4f7803b87718a2033655de70e2a058246028 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 19 Mar 2025 15:54:59 +0000 Subject: [PATCH 6/6] Signed-off-by: Simon Jegou Signed-off-by: SimJeg --- evaluation/evaluate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 76bba49f..7926ed2b 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -79,6 +79,7 @@ "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True), "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20), "qfilter": QFilterPress(), + "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]), }