From 4960d019d121564ee80f057c45b4e9b65889e1bf Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Thu, 6 Nov 2025 14:12:07 +0800
Subject: [PATCH 1/3] cleanup is_moe_down_block / is_moe_gate_up_block

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 39 +++++++++------------------------------
 1 file changed, 9 insertions(+), 30 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 9fe7cf482..f8d7a417d 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -203,10 +203,11 @@ class BaseQModel(nn.Module):
 
     server = None
 
-    support_batch_quantize = True
-
     support_offload_to_disk = True
 
+    moe_expert_module_name_prefixes = [".expert"]
+    moe_shared_expert_module_name_prefixes = [".shared_expert"]
+
     ATTENTION_MASKS_DTYPE = torch.bool  # default to bool
 
     ATTENTION_MASKS_REQUIRED_FOR_INPUT: bool = False
@@ -1030,32 +1031,10 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 _try_update_last_module(candidate_name)
                 continue
 
-            has_shared_expert = any("shared_expert" in n for n in block)
-
-            # Determine if this block is a down_proj block:
-            # - If a shared_expert exists, the block will include an additional shared_expert.down_proj,
-            #   so its length becomes num_experts + 1.
-            # - Otherwise, the length is num_experts.
-            # - Additionally, the block must contain at least one item whose name includes "down".
-            is_down_proj_block = (
-                num_experts is not None
-                and len(block) == (num_experts + 1 if has_shared_expert else num_experts)
-                and any("down" in name for name in block)
-            )
-
-            # Determine if this block is a gate_up_proj block:
-            # - If a shared_expert exists, the block will include shared_expert.gate_proj and shared_expert.up_proj,
-            #   so its length becomes 2 * num_experts + 2.
-            # - Otherwise, the length is 2 * num_experts.
-            # - The additional +1 accounts for an extra MLP layer appended to this block.
-            # - The block must contain at least one item with "gate" in its name and one with "up" in its name.
-            is_gate_up_proj_block = (
-                num_experts is not None
-                and len(block) == (2 * num_experts + 2 if has_shared_expert else 2 * num_experts) + 1
-                and any("gate" in name for name in block)
-                and any("up" in name for name in block)
-            )
-            if is_down_proj_block and last_module is not None and last_module_name is not None:
+            is_moe_block = any(any(k in name for k in self.moe_expert_module_name_prefixes) for name in block)
+            is_moe_down_block = is_moe_block and any("down" in name for name in block)
+            is_moe_gate_up_block = is_moe_block and any("gate" in name for name in block) and any("up" in name for name in block)
+            if is_moe_down_block and last_module is not None and last_module_name is not None:
                 # mlp.experts.0.down_proj
                 target_suffix = last_module_name.split(".")[-1]
                 for name in block:
@@ -1118,7 +1097,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
             module2inspect, _ = get_module_by_name_prefix(module, root)
 
             # process ['mlp.experts.#.gate_proj', 'mlp.experts.#.gup_proj']
-            if is_gate_up_proj_block and module2inspect is not None:
+            if is_moe_gate_up_block and module2inspect is not None:
                 if last_module_root not in input_feat:
                     log.debug(
                         "awq_get_modules_for_scaling: missing input feature for `%s` while processing experts block (layer block size=%s)",
@@ -1140,7 +1119,7 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 nodes.append(n)
 
             # Update tracker to the LAST item of this block
-            if is_gate_up_proj_block:
+            if is_moe_gate_up_block:
                 # The block content is [..., mlp.experts.{last_index}.up_proj, shared_expert.gate_proj, shared_expert.up_proj, mlp]
                 # mlp.experts.{last_index}.up_proj should be selected as last_module
                 last_up_proj_index = 2 * num_experts - 1

From 5eb86ef185ba77564e13d633763de8759e025575 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Thu, 6 Nov 2025 15:34:44 +0800
Subject: [PATCH 2/3] cleanup last_up_proj_index

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index f8d7a417d..dfa16b42b 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -206,7 +206,6 @@ class BaseQModel(nn.Module):
     support_offload_to_disk = True
 
     moe_expert_module_name_prefixes = [".expert"]
-    moe_shared_expert_module_name_prefixes = [".shared_expert"]
 
     ATTENTION_MASKS_DTYPE = torch.bool  # default to bool
 
@@ -1122,9 +1121,18 @@ def _try_update_last_module(candidate_name: str) -> bool:
             if is_moe_gate_up_block:
                 # The block content is [..., mlp.experts.{last_index}.up_proj, shared_expert.gate_proj, shared_expert.up_proj, mlp]
                 # mlp.experts.{last_index}.up_proj should be selected as last_module
-                last_up_proj_index = 2 * num_experts - 1
+                # Find all indices that contain both ".experts" and "gate_proj"/"up_proj"
+                gate_up_proj_indices = [
+                    i for i, name in enumerate(block)
+                    if any(k in name for k in self.moe_expert_module_name_prefixes) and ("gate" in name or "up" in name)
+                ]
+
+                # Use the last one if any exist
+                assert len(gate_up_proj_indices) > 0, "No expert gate_proj/up_proj found in block."
+                last_up_proj_index = gate_up_proj_indices[-1]
+
                 candidate_name = strip_non_quantize_flags(block[last_up_proj_index])
-                assert "up" in candidate_name
+                assert "gate" in candidate_name or "up" in candidate_name
             else:
                 candidate_name = strip_non_quantize_flags(block[-1])
             _try_update_last_module(candidate_name)

From 74ce86f17323463629ef9af4e05f4a2d71c02a18 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Thu, 6 Nov 2025 16:55:33 +0800
Subject: [PATCH 3/3] Filtering MLP modules like Qwen3MoeSparseMoeBlock

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/looper/awq_processor.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index 6ac798a6b..f365dfdac 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -17,6 +17,7 @@
 from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
+from ..models._const import SUPPORTS_MODULE_TYPES
 from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME,
                              PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
 from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear
@@ -332,7 +333,12 @@ def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None:
             return
 
         with state.lock:
-            named_childs = dict(state.modules)
+            # Filtering MLP modules like Qwen3MoeSparseMoeBlock
+            named_childs = {
+                name: module
+                for name, module in state.modules.items()
+                if isinstance(module, tuple(SUPPORTS_MODULE_TYPES))
+            }
 
             module_kwargs_global = dict(self._module_forward_kwargs)
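
A few illustrative notes on the patches above. PATCH 1/3 replaces the old length-based heuristic (block length derived from num_experts plus shared_expert bookkeeping) with a purely name-based check: a block is treated as a MoE block when any entry matches a prefix in moe_expert_module_name_prefixes, and it is then classified as a down block or a gate/up block by substring tests. Below is a minimal standalone sketch of that classification; the prefix list and example blocks are illustrative assumptions, not taken from a real model config.

    # Name-based MoE block classification, mirroring the logic added in PATCH 1/3.
    # MOE_EXPERT_PREFIXES and the example blocks are illustrative assumptions.
    MOE_EXPERT_PREFIXES = [".expert"]

    def classify_block(block):
        is_moe_block = any(any(k in name for k in MOE_EXPERT_PREFIXES) for name in block)
        is_moe_down_block = is_moe_block and any("down" in name for name in block)
        is_moe_gate_up_block = (
            is_moe_block
            and any("gate" in name for name in block)
            and any("up" in name for name in block)
        )
        if is_moe_gate_up_block:
            return "gate_up"
        if is_moe_down_block:
            return "down"
        return "other"

    print(classify_block(["mlp.experts.0.down_proj", "mlp.experts.1.down_proj"]))      # down
    print(classify_block(["mlp.experts.0.gate_proj", "mlp.experts.0.up_proj", "mlp"]))  # gate_up
    print(classify_block(["self_attn.q_proj", "self_attn.k_proj"]))                     # other

Unlike the removed version, this no longer depends on num_experts or on whether a shared expert is present, so a different block length in another MoE architecture does not break the detection.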
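
PATCH 2/3 drops the assumption that the last routed-expert up_proj sits at index 2 * num_experts - 1 and instead collects the indices of all expert gate/up entries and takes the last one. A small sketch of that selection on an illustrative block; note that shared_expert names do not contain the ".expert" prefix, so they are skipped.

    # Selecting the last routed-expert gate/up projection, as in PATCH 2/3.
    # The block below is an illustrative assumption, not real model output.
    MOE_EXPERT_PREFIXES = [".expert"]

    block = [
        "mlp.experts.0.gate_proj", "mlp.experts.0.up_proj",
        "mlp.experts.1.gate_proj", "mlp.experts.1.up_proj",
        "mlp.shared_expert.gate_proj", "mlp.shared_expert.up_proj",
        "mlp",
    ]

    gate_up_proj_indices = [
        i for i, name in enumerate(block)
        if any(k in name for k in MOE_EXPERT_PREFIXES) and ("gate" in name or "up" in name)
    ]
    assert gate_up_proj_indices, "No expert gate_proj/up_proj found in block."

    last_up_proj_index = gate_up_proj_indices[-1]
    print(block[last_up_proj_index])  # mlp.experts.1.up_proj

The relaxed assert ("gate" or "up" instead of only "up") matches the new selection: the last matching entry can presumably be a gate_proj for some block layouts.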
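
PATCH 3/3 stops copying state.modules wholesale and keeps only modules whose type appears in SUPPORTS_MODULE_TYPES, so container modules such as Qwen3MoeSparseMoeBlock no longer reach the per-layer AWQ quantization step. A rough sketch of the same filtering with plain torch layers; SUPPORTS_MODULE_TYPES here is a stand-in for the list imported from ..models._const, assumed to contain quantizable leaf types such as nn.Linear.

    import torch.nn as nn

    # Stand-in for gptqmodel.models._const.SUPPORTS_MODULE_TYPES (assumption:
    # it lists quantizable leaf layer classes such as nn.Linear).
    SUPPORTS_MODULE_TYPES = [nn.Linear]

    modules = {
        "mlp.experts.0.up_proj": nn.Linear(8, 8),
        "mlp": nn.Sequential(nn.Linear(8, 8)),  # container block, should be dropped
    }

    named_childs = {
        name: module
        for name, module in modules.items()
        if isinstance(module, tuple(SUPPORTS_MODULE_TYPES))
    }
    print(list(named_childs))  # ['mlp.experts.0.up_proj']

The tuple(...) conversion matters: isinstance accepts a tuple of types, not a list, so the list constant has to be converted before the check, exactly as the patch does.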