15 changes: 8 additions & 7 deletions kvpress/presses/base_press.py
@@ -10,7 +10,7 @@
import torch
from torch import nn
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
@@ -28,7 +28,7 @@
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
)


@@ -178,16 +178,17 @@ def __call__(self, model: PreTrainedModel) -> Generator:
if not isinstance(model, SUPPORTED_MODELS):
logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")

if isinstance(model, Gemma3ForCausalLM):
logger.warning("Compression in Gemma3 is only applied to layer without sliding window attention")
if isinstance(model, Gemma3ForConditionalGeneration):
logger.warning_once("Compression in Gemma3 is only applied to layers without sliding window attention")

hooks = []
try:
for layer in model.model.layers:
if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
for layer in language_model.layers:
if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
continue
layer.self_attn.rotary_emb = model.model.rotary_emb
layer.self_attn.rotary_emb = language_model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
yield
finally:
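For reference, a minimal usage sketch of the updated context manager with a multimodal Gemma3 checkpoint. The checkpoint id and the choice of KnormPress with a 0.5 compression ratio are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from kvpress import KnormPress

# Hypothetical checkpoint id; any model that loads as Gemma3ForConditionalGeneration behaves the same way
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

press = KnormPress(compression_ratio=0.5)
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt").to(model.device)

# Hooks are registered on language_model.layers; sliding-window layers are skipped, as warned above
with press(model):
    model(**inputs)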
41 changes: 13 additions & 28 deletions kvpress/presses/expected_attention_press.py
@@ -8,12 +8,10 @@
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.utils import get_query_states


@dataclass
@@ -66,46 +64,33 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
Compute the mean and covariance matrix of the queries
"""

bsz, q_len, _ = hidden_states.shape
n, d = module.config.num_attention_heads, module.head_dim
q_len = hidden_states.shape[1]
head_dim = module.head_dim

# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]

if isinstance(module, (Qwen3Attention, Gemma3Attention)):
# Qwen and Gemma use QK norm, which is not compatible with ExpectedAttentionPress (for now)
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")
elif isinstance(module, Phi3Attention):
Wq = module.qkv_proj.weight[: n * d]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
Wq = module.q_proj.weight # type: ignore[assignment]
else:
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")
query_states = get_query_states(module, h)

# Query mean
mean_h = torch.mean(h, dim=1, keepdim=True)
mu = torch.matmul(mean_h, Wq.T).squeeze(1)
mu = mu.view(bsz, n, d)
mu = query_states.mean(dim=2, keepdim=True)

# Query covariance
cov = None
if self.use_covariance:
h = h - mean_h
q = torch.matmul(h, Wq.T).view(bsz, h.shape[1], n, d)
# Compute per-head query covariance directly in the projected space.
# This avoids forming an intermediate O((n * d)^2) covariance matrix
# for the full hidden states, reducing both memory and compute cost.
cov = torch.einsum("bsni,bsnj->bnij", q, q) / h.shape[1]
centered_states = query_states - mu
cov = torch.einsum("bnsi,bnsj->bnij", centered_states, centered_states) / h.shape[1]
mu = mu.squeeze(2)

# RoPE rotation matrix on next n_future_positions
position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
cos, sin = module.rotary_emb(mu, position_ids)
cos, sin = cos[0], sin[0]

Id = torch.eye(d, device=cos.device, dtype=cos.dtype)
P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype)
P[d // 2 :, : d // 2], P[: d // 2, d // 2 :] = torch.eye(d // 2), -torch.eye(d // 2)
Id = torch.eye(head_dim, device=cos.device, dtype=cos.dtype)
P = torch.zeros((head_dim, head_dim), device=cos.device, dtype=cos.dtype)
P[head_dim // 2 :, : head_dim // 2], P[: head_dim // 2, head_dim // 2 :] = torch.eye(head_dim // 2), -torch.eye(
head_dim // 2
)
R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P

# Apply average rotation to the mean and covariance
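The einsum above computes one covariance matrix per head directly from the projected query states. A small self-contained sketch (toy shapes, made-up values, assumptions for illustration) checking it against an explicit loop:

import torch

bsz, n_heads, seq, head_dim = 2, 3, 5, 4
q = torch.randn(bsz, n_heads, seq, head_dim)   # query states as returned by get_query_states
mu = q.mean(dim=2, keepdim=True)               # per-head mean over the sequence dimension
centered = q - mu

# Vectorized per-head covariance, as in get_query_statistics
cov = torch.einsum("bnsi,bnsj->bnij", centered, centered) / seq

# Reference: explicit loop over batch and heads
ref = torch.stack([
    torch.stack([centered[b, h].T @ centered[b, h] / seq for h in range(n_heads)])
    for b in range(bsz)
])
assert torch.allclose(cov, ref, atol=1e-6)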
19 changes: 2 additions & 17 deletions kvpress/presses/kvzip_press.py
@@ -11,12 +11,10 @@
import torch
from torch import nn
from transformers import AutoTokenizer, Gemma3ForCausalLM, PreTrainedModel, PreTrainedTokenizer, QuantizedCache
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.llama.modeling_llama import rotate_half
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.utils import get_query_states

logger = logging.getLogger(__name__)

@@ -306,20 +304,7 @@ def score_kvzip(
head_dim = module.head_dim
num_key_value_groups = num_heads // num_heads_kv

if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states)
queries = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
queries = module.q_proj(hidden_states)
else:
raise NotImplementedError(f"KVzip not yet implemented for {module.__class__}.")

queries = queries.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
queries = module.q_norm(queries)
queries = get_query_states(module, hidden_states)

# Apply RoPE
cos, sin = kwargs["position_embeddings"]
19 changes: 2 additions & 17 deletions kvpress/presses/snapkv_press.py
@@ -8,12 +8,10 @@
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.utils import get_query_states


@dataclass
@@ -52,20 +50,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
num_key_value_groups = num_heads // module.config.num_key_value_heads

# Get last window_size queries
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -window_size:])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)
query_states = get_query_states(module, hidden_states[:, -window_size:])

# Apply RoPE
cos, sin = position_embeddings
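The step that follows applies RoPE to the recomputed window queries. A standalone sketch of that standard rotation (rotate_half mirrors the transformers helper; the shapes and random cos/sin are assumptions for illustration):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, window_size, head_dim = 1, 8, 16, 64
query_states = torch.randn(bsz, num_heads, window_size, head_dim)
cos = torch.randn(bsz, window_size, head_dim)  # in practice: the position_embeddings passed to the hook
sin = torch.randn(bsz, window_size, head_dim)

# unsqueeze(1) broadcasts the position embeddings over the head dimension
query_states = query_states * cos.unsqueeze(1) + rotate_half(query_states) * sin.unsqueeze(1)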
23 changes: 2 additions & 21 deletions kvpress/presses/think_press.py
@@ -6,12 +6,10 @@

import torch
from torch import nn
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.llama.modeling_llama import rotate_half
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from kvpress.presses.base_press import BasePress
from kvpress.presses.utils import get_query_states


@dataclass
@@ -46,25 +44,8 @@ def compute_window_queries(self, module, hidden_states, position_embeddings):
"""
Re-compute the last window_size query states
"""
bsz, q_len, _ = hidden_states.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim

# Get last self.window_size queries
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -self.window_size :])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)
query_states = get_query_states(module, hidden_states[:, -self.window_size :])

# Apply RoPE
cos, sin = position_embeddings
52 changes: 52 additions & 0 deletions kvpress/presses/utils.py
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention


def get_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Extracts the query states from a given attention module and hidden states tensor.

This function supports multiple attention module types: Phi3Attention, Qwen3Attention, Gemma3Attention,
and Llama-like modules. It handles the appropriate projection and reshaping to obtain the query states
in the expected format.

Parameters
----------
module : nn.Module
The attention module from which to extract query states. Must be one of
Phi3Attention, Qwen3Attention, Gemma3Attention, or a Llama-like attention module
with a 'q_proj' attribute.
hidden_states : torch.Tensor
The input hidden states of shape (batch_size, seq_len, hidden_dim).

Returns
-------
query_states : torch.Tensor
The extracted query states of shape (batch_size, num_heads, seq_len, head_dim).
"""
bsz, q_len, _ = hidden_states.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim

if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states)
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states)
else:
raise NotImplementedError(f"Press not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)

return query_states
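A minimal self-contained check of the Llama-like branch (the q_proj path) of get_query_states. The ToyAttention class below is a stand-in for illustration, not a real transformers attention module:

import torch
from torch import nn
from types import SimpleNamespace
from kvpress.presses.utils import get_query_states

class ToyAttention(nn.Module):
    def __init__(self, hidden_dim=32, num_heads=4):
        super().__init__()
        self.config = SimpleNamespace(num_attention_heads=num_heads)
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

module = ToyAttention()
hidden_states = torch.randn(2, 10, 32)    # (batch_size, seq_len, hidden_dim)
query_states = get_query_states(module, hidden_states)
print(query_states.shape)                 # torch.Size([2, 4, 10, 8]) -> (batch_size, num_heads, seq_len, head_dim)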