From c4757167c5dba78c6d8413d9ed5dc78bd5b8a6a2 Mon Sep 17 00:00:00 2001
From: LRL2-ModelCloud
Date: Wed, 15 Oct 2025 18:38:30 +0800
Subject: [PATCH 1/2] fix attn mask

---
 gptqmodel/utils/attn_mask.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/utils/attn_mask.py b/gptqmodel/utils/attn_mask.py
index 5da0572f5..308cf829a 100644
--- a/gptqmodel/utils/attn_mask.py
+++ b/gptqmodel/utils/attn_mask.py
@@ -22,9 +22,29 @@ def normalize_seq_mask(mask: torch.Tensor | None, seq_len: int | None = None) ->
         return None
 
     m = mask
-    # Convert numeric to bool 'keep' (HF tends to use >0 for keep; extended masks use big negatives for masked)
-    if m.dtype != torch.bool:
-        m = (m > 0)
+    if m.dtype == torch.bool:
+        keep = m
+    else:
+        if m.is_floating_point():
+            z = torch.isclose(m, torch.zeros((), device=m.device, dtype=m.dtype))
+            o = torch.isclose(m, torch.ones((), device=m.device, dtype=m.dtype))
+            is_binary = torch.all(z | o)
+        else:
+            is_binary = torch.all((m == 0) | (m == 1))
+
+        if not is_binary:
+            nonneg = torch.all(m >= 0)
+            has_zero = torch.any(m == 0)
+            if nonneg and has_zero:
+                maxv = torch.amax(m)
+                is_scaled_binary = torch.all((m == 0) | (m == maxv))
+            else:
+                is_scaled_binary = False
+        else:
+            is_scaled_binary = False
+
+        keep = (m > 0) if (is_binary or is_scaled_binary) else (m >= 0)
+    m = keep
 
     # Squeeze broadcast dims to reach [B, S]
     if m.dim() == 4 and m.size(1) == 1 and m.size(2) == 1:
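The hunk above distinguishes three numeric encodings before collapsing to a boolean "keep" mask: exact 0/1 masks (keep where > 0), scaled-binary masks whose only values are 0 and a single positive constant (also keep where > 0), and additive/extended masks that mark masked positions with large negatives (keep where >= 0). Below is a minimal usage sketch of the expected result per family; it assumes normalize_seq_mask is importable from gptqmodel.utils.attn_mask as in the diff header and that a 2-D [B, S] input passes through the later squeeze logic unchanged, neither of which is guaranteed by this hunk alone.

import torch

from gptqmodel.utils.attn_mask import normalize_seq_mask

# 0/1 float mask (HF attention_mask style): keep where value > 0.
binary = torch.tensor([[1.0, 1.0, 0.0]])
print(normalize_seq_mask(binary))    # expected: [[True, True, False]]

# Scaled-binary mask: only values are 0 and one positive constant.
scaled = torch.tensor([[2.5, 0.0, 2.5]])
print(normalize_seq_mask(scaled))    # expected: [[True, False, True]]

# Additive/extended mask: 0 keeps, large negatives mask; keep where >= 0.
neg = torch.finfo(torch.float32).min
additive = torch.tensor([[0.0, 0.0, neg]])
print(normalize_seq_mask(additive))  # expected: [[True, True, False]]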
From 11775052d1c640aedffa7d6048867c62a53126c9 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 15 Oct 2025 12:04:43 +0000
Subject: [PATCH 2/2] skip expensive checks

---
 gptqmodel/utils/attn_mask.py | 37 ++++++++++++++++++++++--------------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/gptqmodel/utils/attn_mask.py b/gptqmodel/utils/attn_mask.py
index 308cf829a..009795c70 100644
--- a/gptqmodel/utils/attn_mask.py
+++ b/gptqmodel/utils/attn_mask.py
@@ -25,25 +25,34 @@ def normalize_seq_mask(mask: torch.Tensor | None, seq_len: int | None = None) ->
     if m.dtype == torch.bool:
         keep = m
     else:
-        if m.is_floating_point():
-            z = torch.isclose(m, torch.zeros((), device=m.device, dtype=m.dtype))
-            o = torch.isclose(m, torch.ones((), device=m.device, dtype=m.dtype))
-            is_binary = torch.all(z | o)
+        has_negative = torch.any(m < 0)
+        if has_negative:
+            keep = m >= 0
         else:
-            is_binary = torch.all((m == 0) | (m == 1))
+            maxv = torch.amax(m)
+            if m.is_floating_point():
+                minv = torch.amin(m)
+                approx_zero = torch.isclose(minv, torch.zeros((), device=m.device, dtype=m.dtype))
+                approx_one = torch.isclose(maxv, torch.ones((), device=m.device, dtype=m.dtype))
+                same_extreme = torch.isclose(maxv, minv)
+                has_mid = torch.any((m > minv) & (m < maxv))
+                is_binary = approx_zero and (approx_one or same_extreme) and not has_mid
+            else:
+                minv = torch.amin(m)
+                outside = torch.any((m != 0) & (m != 1))
+                is_binary = minv >= 0 and maxv <= 1 and not outside
 
-        if not is_binary:
-            nonneg = torch.all(m >= 0)
-            has_zero = torch.any(m == 0)
-            if nonneg and has_zero:
-                maxv = torch.amax(m)
-                is_scaled_binary = torch.all((m == 0) | (m == maxv))
+            has_positive = torch.any(m > 0)
+            if has_positive and not is_binary:
+                scaled_mismatch = torch.any((m > 0) & (m != maxv))
+                is_scaled_binary = not scaled_mismatch
             else:
                 is_scaled_binary = False
-        else:
-            is_scaled_binary = False
 
-        keep = (m > 0) if (is_binary or is_scaled_binary) else (m >= 0)
+            if not has_positive:
+                keep = torch.zeros_like(m, dtype=torch.bool)
+            else:
+                keep = (m > 0) if (is_binary or is_scaled_binary) else (m >= 0)
     m = keep
 
     # Squeeze broadcast dims to reach [B, S]
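The second commit keeps the same decision rule but reorders it so the common additive-mask case exits after a single torch.any(m < 0) reduction, and the remaining branches reuse the amax/amin extremes instead of full elementwise torch.all scans. The sketch below is a rough, standalone micro-benchmark of that early exit; it reimplements only the two classification paths rather than calling the library code, and the non-negative branches of the new path are deliberately omitted, so the timings are illustrative only. For this input both paths return the same keep mask; only the amount of work differs.

import time

import torch

def classify_old(m: torch.Tensor) -> torch.Tensor:
    # Patch-1 behavior: the full classification runs even for additive masks.
    z = torch.isclose(m, torch.zeros((), dtype=m.dtype))
    o = torch.isclose(m, torch.ones((), dtype=m.dtype))
    is_binary = torch.all(z | o)
    if not is_binary and torch.all(m >= 0) and torch.any(m == 0):
        is_scaled_binary = torch.all((m == 0) | (m == torch.amax(m)))
    else:
        is_scaled_binary = False
    return (m > 0) if (is_binary or is_scaled_binary) else (m >= 0)

def classify_new(m: torch.Tensor) -> torch.Tensor:
    # Patch-2 behavior: any negative value short-circuits straight to m >= 0.
    # The non-negative branches are elided; this benchmark only feeds
    # additive masks, so they never run.
    if torch.any(m < 0):
        return m >= 0
    raise NotImplementedError("non-negative branches omitted in this sketch")

mask = torch.zeros(64, 4096)
mask[:, 2048:] = torch.finfo(torch.float32).min  # additive-style masking

for fn in (classify_old, classify_new):
    start = time.perf_counter()
    for _ in range(100):
        keep = fn(mask)
    elapsed = time.perf_counter() - start
    print(f"{fn.__name__}: {elapsed:.4f}s, kept={keep.sum().item()}")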