diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn.py b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
index 7dd87fd3..cbab811f 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 from typing import Optional, Tuple
 
 import torch
@@ -52,12 +53,6 @@ def __init__(
         )
         self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads
 
-        # Constant scale (1/√d)
-        self.scale_t = torch.tensor(
-            float(getattr(fp_attn, "scaling", self.head_dim**-0.5))
-        )
-        self.obs_scale = self._make_obs("scale")
-
         # ---- Wrap q k v o projections via PTQWrapper ---------------
         q_cfg = qcfg.child("q_proj") if qcfg else None
         k_cfg = qcfg.child("k_proj") if qcfg else None
@@ -81,7 +76,7 @@ def __init__(
             fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
         )
         self.k_proj = PTQWrapper(
-            fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
+            copy.deepcopy(fp_attn.k_proj), qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
         )
         self.v_proj = PTQWrapper(
             fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
@@ -90,6 +85,17 @@ def __init__(
             fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
         )
 
+        # Constant scale (1/√d)
+        scale_t = torch.tensor(
+            float(getattr(fp_attn, "scaling", self.head_dim**-0.5))
+        )
+        # merge scale_t to k_proj, (otherwise merge it to q_proj)
+        with torch.no_grad():
+            lin = self.k_proj.wrapped.module
+            lin.weight.mul_(scale_t)
+            if lin.bias is not None:
+                lin.bias.mul_(scale_t)
+
         mk = self._make_obs
         self.obs_hidden = mk("hidden")
 
@@ -119,7 +125,6 @@ def __init__(
 
         # Masking & attention math
         self.obs_causal_mask = mk("causal_mask")
-        self.obs_logits_raw = mk("logits_raw")
         self.obs_logits = mk("logits")
         self.obs_mask_add = mk("mask_add")
         self.obs_softmax = mk("softmax")
@@ -226,9 +231,7 @@ def forward(
         v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
 
         # Attention logits: q @ k^T
-        logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
-        scale = self._fq(self.scale_t, self.obs_scale)
-        logits = self._fq(logits_raw * scale, self.obs_logits)
+        logits = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits)
 
         # Build causal mask if needed
         if attention_mask is None or attention_mask.dtype == torch.bool:
@@ -265,7 +268,6 @@ def _all_observers(self):
         # local first
         yield from (
             self.obs_hidden,
-            self.obs_scale,
             self.obs_cos,
             self.obs_sin,
             self.obs_causal_mask,
@@ -283,7 +285,6 @@ def _all_observers(self):
             self.obs_k_cos,
             self.obs_k_sin,
             self.obs_k_rot,
-            self.obs_logits_raw,
             self.obs_logits,
             self.obs_mask_add,
             self.obs_softmax,
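
For reviewers, a minimal standalone sketch of the equivalence this patch relies on: since the constant scale s = 1/√head_dim is applied only once to q @ k^T, pre-multiplying the key projection's weight and bias by s at wrap time produces the same attention logits as scaling the raw logits in forward(), so the separate scale observer and the extra fake-quant step can be dropped. The sketch does not use the wrapq API (PTQWrapper, `self.k_proj.wrapped.module`, etc.); the Linear sizes, tensor shapes, and variable names below are illustrative assumptions, not taken from the patch.

```python
# Sketch only: verify that folding 1/sqrt(head_dim) into k_proj's parameters
# is numerically equivalent to scaling q @ k^T afterwards.
import copy
import torch

torch.manual_seed(0)

head_dim = 64                               # assumed sizes, for illustration
k_proj = torch.nn.Linear(256, head_dim, bias=True)
x = torch.randn(2, 5, 256)                  # (batch, seq, hidden)
q = torch.randn(2, 5, head_dim)             # already-projected queries

scale = head_dim ** -0.5

# Reference path: scale the logits explicitly, as the old forward() did.
logits_ref = (q @ k_proj(x).transpose(-2, -1)) * scale

# Folded path: deep-copy k_proj so the original module is untouched,
# then pre-scale its weight and bias once.
k_scaled = copy.deepcopy(k_proj)
with torch.no_grad():
    k_scaled.weight.mul_(scale)
    if k_scaled.bias is not None:
        k_scaled.bias.mul_(scale)

logits_folded = q @ k_scaled(x).transpose(-2, -1)

print(torch.allclose(logits_ref, logits_folded, atol=1e-6))  # True up to fp rounding
```

The copy.deepcopy added to the k_proj wrapping presumably serves the same purpose as the copy in the sketch: the scale is folded in by mutating the wrapped Linear's parameters, and copying first keeps the original FP module's weights intact.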