From 8737b56fa50a2102c0f95ea3c0403c57272fc2e4 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Mon, 25 Aug 2025 19:00:35 +0200 Subject: [PATCH 1/4] improve ea Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_press.py | 41 +++++++++++---------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 8b504625..e659ae69 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -67,45 +67,46 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ bsz, q_len, _ = hidden_states.shape - n, d = module.config.num_attention_heads, module.head_dim + num_heads, head_dim = module.config.num_attention_heads, module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] - if isinstance(module, (Qwen3Attention, Gemma3Attention)): - # Qwen and Gemma use QK norm, which is not compatible with ExpectedAttentionPress (for now) - raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") - elif isinstance(module, Phi3Attention): - Wq = module.qkv_proj.weight[: n * d] + if isinstance(module, Phi3Attention): + qkv = module.qkv_proj(h) + query_states = qkv[..., : num_heads * head_dim] elif hasattr(module, "q_proj"): # Assume Llama-like attention layer - Wq = module.q_proj.weight # type: ignore[assignment] + query_states = module.q_proj(h) else: raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") - # Query mean - mean_h = torch.mean(h, dim=1, keepdim=True) - mu = torch.matmul(mean_h, Wq.T).squeeze(1) - mu = mu.view(bsz, n, d) + query_states = query_states.view(bsz, h.shape[1], num_heads, head_dim).transpose(1, 2) + + # Support for Qwen3 and Gemma3 QK norm + if isinstance(module, (Qwen3Attention, Gemma3Attention)): + query_states = module.q_norm(query_states) 
+ + mu = query_states.mean(dim=2, keepdim=True) # Query covariance cov = None if self.use_covariance: - h = h - mean_h - q = torch.matmul(h, Wq.T).view(bsz, h.shape[1], n, d) - # Compute per-head query covariance directly in the projected space. - # This avoids forming an intermediate O((n * d)^2) covariance matrix - # for the full hidden states, reducing both memory and compute cost. - cov = torch.einsum("bsni,bsnj->bnij", q, q) / h.shape[1] + centered_states = query_states - mu + centered_states = centered_states.transpose(1, 2) + cov = torch.einsum("bsni,bsnj->bnij", centered_states, centered_states) / h.shape[1] + mu = mu.squeeze(2) # RoPE rotation matrix on next n_future_positions position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) cos, sin = module.rotary_emb(mu, position_ids) cos, sin = cos[0], sin[0] - Id = torch.eye(d, device=cos.device, dtype=cos.dtype) - P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype) - P[d // 2 :, : d // 2], P[: d // 2, d // 2 :] = torch.eye(d // 2), -torch.eye(d // 2) + Id = torch.eye(head_dim, device=cos.device, dtype=cos.dtype) + P = torch.zeros((head_dim, head_dim), device=cos.device, dtype=cos.dtype) + P[head_dim // 2 :, : head_dim // 2], P[: head_dim // 2, head_dim // 2 :] = torch.eye(head_dim // 2), -torch.eye( + head_dim // 2 + ) R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P # Apply average rotation to the mean and covariance From 1e207e550b5be0f1696b13d705d6444627170065 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Tue, 26 Aug 2025 09:17:18 +0000 Subject: [PATCH 2/4] support Gemma3 Signed-off-by: alessiodevoto --- kvpress/presses/base_press.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index c5c8511b..69ddb1a8 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -10,7 +10,7 @@ import torch from torch import nn from transformers import ( - 
Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, @@ -28,7 +28,7 @@ Phi3ForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, - Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, ) @@ -178,16 +178,17 @@ def __call__(self, model: PreTrainedModel) -> Generator: if not isinstance(model, SUPPORTED_MODELS): logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}") - if isinstance(model, Gemma3ForCausalLM): - logger.warning("Compression in Gemma3 is only applied to layer without sliding window attention") + if isinstance(model, Gemma3ForConditionalGeneration): + logger.warning_once("Compression in Gemma3 is only applied to layer without sliding window attention") hooks = [] try: - for layer in model.model.layers: - if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding: + lm = model.model.language_model if isinstance(model, Gemma3ForConditionalGeneration) else model.model + for layer in lm.layers: + if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding: # Skip layers with sliding window attention, only for Gemma3 continue - layer.self_attn.rotary_emb = model.model.rotary_emb + layer.self_attn.rotary_emb = lm.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) yield finally: From 244163b526aca671e71f33c70036c7aa73bf76d2 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Tue, 26 Aug 2025 11:26:35 +0000 Subject: [PATCH 3/4] fixes Signed-off-by: alessiodevoto --- kvpress/presses/base_press.py | 6 +++--- kvpress/presses/expected_attention_press.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 69ddb1a8..1528f55d 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -183,12 +183,12 @@ def __call__(self, model: PreTrainedModel) -> Generator: hooks = [] try: - lm = model.model.language_model if 
isinstance(model, Gemma3ForConditionalGeneration) else model.model - for layer in lm.layers: + language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model + for layer in language_model.layers: if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding: # Skip layers with sliding window attention, only for Gemma3 continue - layer.self_attn.rotary_emb = lm.rotary_emb + layer.self_attn.rotary_emb = language_model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) yield finally: diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index e659ae69..861cd396 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -93,8 +93,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): cov = None if self.use_covariance: centered_states = query_states - mu - centered_states = centered_states.transpose(1, 2) - cov = torch.einsum("bsni,bsnj->bnij", centered_states, centered_states) / h.shape[1] + cov = torch.einsum("bnsi,bnsj->bnij", centered_states, centered_states) / h.shape[1] mu = mu.squeeze(2) # RoPE rotation matrix on next n_future_positions From 241f088781f04a44833fd147147f8097f1ec74c0 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Tue, 26 Aug 2025 11:52:21 +0000 Subject: [PATCH 4/4] refactor query extraction Signed-off-by: alessiodevoto --- kvpress/presses/expected_attention_press.py | 25 ++-------- kvpress/presses/kvzip_press.py | 19 +------- kvpress/presses/snapkv_press.py | 19 +------- kvpress/presses/think_press.py | 23 +-------- kvpress/presses/utils.py | 52 +++++++++++++++++++++ 5 files changed, 63 insertions(+), 75 deletions(-) create mode 100644 kvpress/presses/utils.py diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 861cd396..e22e2b02 100644 --- 
a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -8,12 +8,10 @@ import torch from torch import nn from torch.nn import functional as F -from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.llama.modeling_llama import repeat_kv -from transformers.models.phi3.modeling_phi3 import Phi3Attention -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.utils import get_query_states @dataclass @@ -66,27 +64,14 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): Compute the mean and covariance matrix of the queries """ - bsz, q_len, _ = hidden_states.shape - num_heads, head_dim = module.config.num_attention_heads, module.head_dim + q_len = hidden_states.shape[1] + head_dim = module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] + query_states = get_query_states(module, h) - if isinstance(module, Phi3Attention): - qkv = module.qkv_proj(h) - query_states = qkv[..., : num_heads * head_dim] - elif hasattr(module, "q_proj"): - # Assume Llama-like attention layer - query_states = module.q_proj(h) - else: - raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") - - query_states = query_states.view(bsz, h.shape[1], num_heads, head_dim).transpose(1, 2) - - # Support for Qwen3 and Gemma3 QK norm - if isinstance(module, (Qwen3Attention, Gemma3Attention)): - query_states = module.q_norm(query_states) - + # Query mean mu = query_states.mean(dim=2, keepdim=True) # Query covariance diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index 0c68fd79..d49fd802 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -11,12 +11,10 @@ import torch from torch import nn from transformers import AutoTokenizer, Gemma3ForCausalLM, 
PreTrainedModel, PreTrainedTokenizer, QuantizedCache -from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.llama.modeling_llama import rotate_half -from transformers.models.phi3.modeling_phi3 import Phi3Attention -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress +from kvpress.presses.utils import get_query_states logger = logging.getLogger(__name__) @@ -306,20 +304,7 @@ def score_kvzip( head_dim = module.head_dim num_key_value_groups = num_heads // num_heads_kv - if isinstance(module, Phi3Attention): - qkv = module.qkv_proj(hidden_states) - queries = qkv[..., : num_heads * head_dim] - elif hasattr(module, "q_proj"): - # Assume Llama-like attention layer - queries = module.q_proj(hidden_states) - else: - raise NotImplementedError(f"KVzip not yet implemented for {module.__class__}.") - - queries = queries.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) - - # Support for Qwen3 and Gemma3 QK norm - if isinstance(module, (Qwen3Attention, Gemma3Attention)): - queries = module.q_norm(queries) + queries = get_query_states(module, hidden_states) # Apply RoPE cos, sin = kwargs["position_embeddings"] diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 0a962f46..68f92af9 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -8,12 +8,10 @@ import torch from torch import nn from torch.nn import functional as F -from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.llama.modeling_llama import repeat_kv, rotate_half -from transformers.models.phi3.modeling_phi3 import Phi3Attention -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.utils import get_query_states @dataclass @@ -52,20 +50,7 @@ def compute_window_attention(module, hidden_states, keys, 
window_size, position_ num_key_value_groups = num_heads // module.config.num_key_value_heads # Get last window_size queries - if isinstance(module, Phi3Attention): - qkv = module.qkv_proj(hidden_states[:, -window_size:]) - query_states = qkv[..., : num_heads * head_dim] - elif hasattr(module, "q_proj"): - # Assume Llama-like attention layer - query_states = module.q_proj(hidden_states[:, -window_size:]) - else: - raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - - query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) - - # Support for Qwen3 and Gemma3 QK norm - if isinstance(module, (Qwen3Attention, Gemma3Attention)): - query_states = module.q_norm(query_states) + query_states = get_query_states(module, hidden_states[:, -window_size:]) # Apply RoPE cos, sin = position_embeddings diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 710ff4e6..a19379c2 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -6,12 +6,10 @@ import torch from torch import nn -from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.llama.modeling_llama import rotate_half -from transformers.models.phi3.modeling_phi3 import Phi3Attention -from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from kvpress.presses.base_press import BasePress +from kvpress.presses.utils import get_query_states @dataclass @@ -46,25 +44,8 @@ def compute_window_queries(self, module, hidden_states, position_embeddings): """ Re-compute the last window_size query states """ - bsz, q_len, _ = hidden_states.shape - num_heads = module.config.num_attention_heads - head_dim = module.head_dim - # Get last self.window_size queries - if isinstance(module, Phi3Attention): - qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) - query_states = qkv[..., : num_heads * head_dim] - elif hasattr(module, "q_proj"): - # Assume Llama-like attention 
layer - query_states = module.q_proj(hidden_states[:, -self.window_size :]) - else: - raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - - query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2) - - # Support for Qwen3 and Gemma3 QK norm - if isinstance(module, (Qwen3Attention, Gemma3Attention)): - query_states = module.q_norm(query_states) + query_states = get_query_states(module, hidden_states[:, -self.window_size :]) # Apply RoPE cos, sin = position_embeddings diff --git a/kvpress/presses/utils.py b/kvpress/presses/utils.py new file mode 100644 index 00000000..938c15ba --- /dev/null +++ b/kvpress/presses/utils.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention +from transformers.models.phi3.modeling_phi3 import Phi3Attention +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention + + +def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Extracts the query states from a given attention module and hidden states tensor. + + This function supports multiple attention module types: Phi3Attention, Qwen3Attention, Gemma3Attention, + and Llama-like modules. It handles the appropriate projection and reshaping to obtain the query states + in the expected format. + + Parameters + ---------- + module : nn.Module + The attention module from which to extract query states. Must be one of + Phi3Attention, Qwen3Attention, Gemma3Attention, or a Llama-like attention module + with a 'q_proj' attribute. + hidden_states : torch.Tensor + The input hidden states of shape (batch_size, seq_len, hidden_dim). + + Returns + ------- + query_states : torch.Tensor + The extracted query states of shape (batch_size, num_heads, seq_len, head_dim). 
+ """ + bsz, q_len, _ = hidden_states.shape + num_heads = module.config.num_attention_heads + head_dim = module.head_dim + + if isinstance(module, Phi3Attention): + qkv = module.qkv_proj(hidden_states) + query_states = qkv[..., : num_heads * head_dim] + elif hasattr(module, "q_proj"): + # Assume Llama-like attention layer + query_states = module.q_proj(hidden_states) + else: + raise NotImplementedError(f"Press not yet implemented for {module.__class__}.") + + query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) + + # Support for Qwen3 and Gemma3 QK norm + if isinstance(module, (Qwen3Attention, Gemma3Attention)): + query_states = module.q_norm(query_states) + + return query_states