From 971ef7cb119405d9fb2844f19f2fa3ddcb6e81f8 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Mon, 3 Nov 2025 21:40:23 +0800
Subject: [PATCH 01/16] Fix an issue where AWQ quantization used the wrong
 input_feature["mlp"] tensor

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/looper/awq_processor.py | 16 ++++++++--------
 gptqmodel/looper/named_module.py  | 14 ++++++++------
 gptqmodel/looper/stage_subset.py  | 15 +++++++++++++++
 3 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index 2b858ca7a..958bbff48 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -202,13 +202,13 @@ def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor
         #     tuple(features[name].shape),
         # )

-        for root, tensors in root_buckets.items():
-            if not tensors or root in features:
-                continue
-            try:
-                features[root] = torch.cat(tensors, dim=0)
-            except RuntimeError:
-                features[root] = tensors[0]
+        # for root, tensors in root_buckets.items():
+        #     if not tensors or root in features:
+        #         continue
+        #     try:
+        #         features[root] = torch.cat(tensors, dim=0)
+        #     except RuntimeError:
+        #         features[root] = tensors[0]
         return features

     def _refresh_forward_kwargs_from_cache(self) -> None:
@@ -1231,7 +1231,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso
         def hook(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             if not inp:
                 return
-            feature = inp[0]
+            feature = inp
             if isinstance(feature, (tuple, list)) and feature:
                 feature = feature[0]
             self._record_input_feature(name, feature)
diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py
index db078c3ab..32d98b769 100644
--- a/gptqmodel/looper/named_module.py
+++ b/gptqmodel/looper/named_module.py
@@ -54,12 +54,14 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
             in_features = module.weight.shape[0]
             out_features = module.weight.shape[1]
         else:
-            raise NotImplementedError(f"Unsupported module.module type: `{type(module)}`")
-
-        self.state.update({
-            "in_features": in_features,
-            "out_features": out_features,
-        })
+            in_features = None
+            out_features = None
+
+        if in_features and out_features:
+            self.state.update({
+                "in_features": in_features,
+                "out_features": out_features,
+            })

     def parameters(self, recurse: bool = True):
         return self.module.parameters(recurse=recurse)
diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py
index 5738481fe..da11965a4 100644
--- a/gptqmodel/looper/stage_subset.py
+++ b/gptqmodel/looper/stage_subset.py
@@ -220,6 +220,21 @@ def run_subset_stage(
         if len(forward_row_counts) < batch_count:
             forward_row_counts.extend([1] * (batch_count - len(forward_row_counts)))

+    if is_awq_processor:
+        model_type = looper.gptq_model.model.config.model_type
+        if model_type == "mixtral":
+            subset['block_sparse_moe'] = NamedModule(module.mlp, name="block_sparse_moe",
+                                                     full_name=f"model.layers.{layer_index}.block_sparse_moe",
+                                                     layer_index=layer_index)
+
+        if model_type == "deepseek_v2" or model_type == "deepseek_v3":
+            subset['mlp'] = NamedModule(module.mlp, name="mlp", full_name=f"model.layers.{layer_index}.mlp",
+                                        layer_index=layer_index)
+
+        if model_type == "qwen2_moe" or model_type == "qwen3_moe":
+            subset['mlp'] = NamedModule(module.mlp, name="mlp", full_name=f"model.layers.{layer_index}.mlp",
+                                        layer_index=layer_index)
+
     subset_size = len(subset)
     for idx, (name, m) in enumerate(subset.items()):
         # Register the forward hook that captures activations for quantization.
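The hook change in awq_processor.py is the heart of this patch: forward hooks receive `inp` as a tuple of positional arguments, and for MoE blocks the first entry can itself be a tuple, so a single `inp[0]` index is not always a tensor. A minimal standalone sketch of that normalization pattern (illustrative only; `make_capture_hook` is not a gptqmodel API):

import torch
from torch import nn

def make_capture_hook(records):
    # Forward hooks receive `inp` as a tuple of positional args; for MoE
    # blocks the first entry can itself be a tuple, so unwrap until a
    # tensor surfaces instead of assuming `inp[0]` already is one.
    def hook(module, inp, out):
        if not inp:
            return
        feature = inp
        while isinstance(feature, (tuple, list)) and feature:
            feature = feature[0]
        if isinstance(feature, torch.Tensor):
            records.append(feature.detach())
    return hook

records = []
layer = nn.Linear(4, 4)
handle = layer.register_forward_hook(make_capture_hook(records))
layer(torch.randn(2, 4))
handle.remove()
assert len(records) == 1 and records[0].shape == (2, 4)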
From 800b74c285071e03563a897909d47d5790a35a2b Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Tue, 4 Nov 2025 12:01:28 +0800
Subject: [PATCH 02/16] process moe block

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/looper/stage_subset.py          | 15 ---------------
 gptqmodel/models/base.py                  | 11 +++++++++--
 gptqmodel/models/definitions/qwen2_moe.py | 16 ++++++++++++++++
 3 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py
index da11965a4..5738481fe 100644
--- a/gptqmodel/looper/stage_subset.py
+++ b/gptqmodel/looper/stage_subset.py
@@ -220,21 +220,6 @@ def run_subset_stage(
         if len(forward_row_counts) < batch_count:
             forward_row_counts.extend([1] * (batch_count - len(forward_row_counts)))

-    if is_awq_processor:
-        model_type = looper.gptq_model.model.config.model_type
-        if model_type == "mixtral":
-            subset['block_sparse_moe'] = NamedModule(module.mlp, name="block_sparse_moe",
-                                                     full_name=f"model.layers.{layer_index}.block_sparse_moe",
-                                                     layer_index=layer_index)
-
-        if model_type == "deepseek_v2" or model_type == "deepseek_v3":
-            subset['mlp'] = NamedModule(module.mlp, name="mlp", full_name=f"model.layers.{layer_index}.mlp",
-                                        layer_index=layer_index)
-
-        if model_type == "qwen2_moe" or model_type == "qwen3_moe":
-            subset['mlp'] = NamedModule(module.mlp, name="mlp", full_name=f"model.layers.{layer_index}.mlp",
-                                        layer_index=layer_index)
-
     subset_size = len(subset)
     for idx, (name, m) in enumerate(subset.items()):
         # Register the forward hook that captures activations for quantization.
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index b02d8f9d6..677e18379 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -1305,7 +1305,7 @@ def shell_module_materialize(
     #     else:
     #         log.info(f"{self.__class__.__name__}: `MODEL switching to eval mode.")

     @classmethod
-    def build_layer_modules(cls, tree):
+    def build_layer_modules(cls, tree, is_awq_quantize):
         """
         tree format:
         [, , "#", { parent_module: ( "child[:!][:grp]", ... ), ... }]
@@ -1329,10 +1329,17 @@ def build_layer_modules(cls, tree):

         out_blocks = []

-        def process_entries(parent, entries, parent_group_offset=0):
+        def process_entries(parent_name, entries, parent_group_offset=0):
             """Process entries recursively to handle nested dict structures for MoE"""
             groups = defaultdict(list)

+            parent = parent_name.split(":", 1)[0]
+            if is_awq_quantize:
+                has_question = ('?' in parent_name)
+                # process moe block
+                if has_question:
+                    out_blocks.append([parent])
+
             # Handle tuple/list of strings (traditional format)
             if isinstance(entries, (tuple, list)):
                 for ent in entries:
diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py
index 89446c5b3..0ecea7c18 100644
--- a/gptqmodel/models/definitions/qwen2_moe.py
+++ b/gptqmodel/models/definitions/qwen2_moe.py
@@ -4,6 +4,7 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

 from ..base import BaseQModel
+from ...quantization import METHOD


 class Qwen2MoeQModel(BaseQModel):
@@ -30,3 +31,18 @@ class Qwen2MoeQModel(BaseQModel):
             },
         }
     ]
+
+    module_tree_overrides = {
+        METHOD.AWQ: [
+            {
+                "mlp:?": {
+                    "gate": ("gate:!",),
+                    "shared_expert": None,
+                    "experts": {
+                        "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+                    },
+                },
+            }
+        ]
+    }
+
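The `"mlp:?"` key in the override above relies on the module-tree token syntax documented in `build_layer_modules` (`child[:!][:grp]`, extended here with `?`): a numeric flag assigns the child to a quantization subset group, `!` marks a module that must never be quantized, and `?` marks a capture-only module whose activations are recorded for AWQ scaling but which is not itself quantized. A small illustrative parser under those assumptions (not the library's actual implementation):

def parse_token(token: str):
    name, *flags = token.split(":")
    group = next((int(f) for f in flags if f.isdigit()), 0)
    return {
        "name": name,
        "group": group,                # numeric flag: subset/group id
        "skip_quant": "!" in flags,    # ':!' -> never quantized (e.g. layernorms)
        "capture_only": "?" in flags,  # ':?' -> capture activations only (AWQ)
    }

for tok in ("q_proj:0", "o_proj:1", "input_layernorm:!", "mlp:?"):
    print(tok, "->", parse_token(tok))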
From 479f5a44807f4499dbaae540b026449d6a830c3f Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Tue, 4 Nov 2025 13:46:37 +0800
Subject: [PATCH 03/16] fix merge error

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 417939d4f..6f38d8b1a 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -392,12 +392,12 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
     @classmethod
     def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
         layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
-
+        print(f"simple_layer_modules build_layer_modules: {layer_modules}")
         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
-
+        print(f"simple_layer_modules build_moe_modules_if_need: {layer_modules}")
         layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)

-        # print(f"simple_layer_modules layer_modules: {layer_modules}")
+        print(f"simple_layer_modules layer_modules: {layer_modules}")
         return layer_modules

     @classmethod
@@ -1367,13 +1367,6 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0):
             groups[parent_group].append((parent_name, parent_has_bang, parent_capture_only))
             child_group_offset = max(child_group_offset, parent_group + 1)

-            parent = parent_name.split(":", 1)[0]
-            if is_awq_quantize:
-                has_question = ('?' in parent_name)
-                # process moe block
-                if has_question:
-                    out_blocks.append([parent])
-
             # Handle tuple/list of strings (traditional format)
             if isinstance(entries, (tuple, list)):
                 for ent in entries:

From 29333e8085c6e156ce6461a155e7483dfeff9efb Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Tue, 4 Nov 2025 15:26:30 +0800
Subject: [PATCH 04/16] Obtain the CAPTURE_ONLY_FLAG Module

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/looper/module_looper.py         | 7 ++++++-
 gptqmodel/looper/stage_subset.py          | 1 +
 gptqmodel/models/base.py                  | 6 +-----
 gptqmodel/models/definitions/qwen2_moe.py | 2 +-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 69f061b97..1a8cab28d 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -33,6 +33,7 @@
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
 from ..models._const import SUPPORTS_MODULE_TYPES
+from ..models.base import CAPTURE_ONLY_FLAG
 from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward,
                                         replace_module_with_hooked_legacy)
 from ..quantization.config import VRAMStrategy
@@ -1251,11 +1252,15 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):

         return total_log

-    def crate_named_modules(self, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe, layer_module=None) -> Dict[str, NamedModule]:
+    def crate_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe, layer_module=None) -> Dict[str, NamedModule]:
         subset = {}
         for n in names:
             if n in full:
                 subset[n] = full[n]
+            elif n.endswith(CAPTURE_ONLY_FLAG):
+                # Obtain the CAPTURE_ONLY_FLAG module separately
+                n = n.split(CAPTURE_ONLY_FLAG, 1)[0]
+                subset[n], _ = get_module_by_name_prefix(module, module_name=n)
             # some modules have layer_modules that are dynamic based on config
             # ref: deepseek v2/v3/r1
             elif self.gptq_model.layer_modules_strict:
diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py
index fe6bbe83b..2cf65345b 100644
--- a/gptqmodel/looper/stage_subset.py
+++ b/gptqmodel/looper/stage_subset.py
@@ -79,6 +79,7 @@ def run_subset_stage(
     is_awq_processor = processor_name_lower.startswith("awq")

     subset = looper.crate_named_modules(
+        module=module,
         full=full,
         is_lm_head_module=is_lm_head_module,
         layer_index=layer_index,
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 6f38d8b1a..876e88466 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -364,11 +364,7 @@ def get_num_experts(cls, model_config):
     @classmethod
     def filter_not_quantize_module(cls, layer_modules, quantize_config):
         layer_modules = [
-            [
-                name
-                for name in block
-                if all(flag not in name for flag in NON_QUANTIZE_FLAGS)
-            ]
+            [name for name in block if NOT_QUANTIZE_FLAG not in name]
             for block in layer_modules
         ]
         layer_modules = [block for block in layer_modules if block]  # drop empty blocks
diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py
index 0ecea7c18..cdcce4a1c 100644
--- a/gptqmodel/models/definitions/qwen2_moe.py
+++ b/gptqmodel/models/definitions/qwen2_moe.py
@@ -22,7 +22,7 @@ class Qwen2MoeQModel(BaseQModel):
                 "input_layernorm": ("input_layernorm:!",),
                 "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
                 "post_attention_layernorm": ("post_attention_layernorm:!",),
-                "mlp": {
+                "mlp:?": {
                     "gate": ("gate:0",),
                     "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"),
                     "experts": {
                         "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
                     },
                 },
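For context on how a capture-only entry is resolved at runtime: the `crate_named_modules` change strips the flag and fetches the live submodule by its dotted path. A hedged sketch of that resolution, with `get_module_by_prefix` standing in for gptqmodel's `get_module_by_name_prefix` and `CAPTURE_ONLY_FLAG` assumed to be the literal `":?"` suffix:

from torch import nn

CAPTURE_ONLY_FLAG = ":?"  # assumed value; the patch imports it from models.base

def get_module_by_prefix(root: nn.Module, dotted: str) -> nn.Module:
    # Walk a dotted path such as "mlp" or "mlp.experts.0" down from `root`.
    mod = root
    for part in dotted.split("."):
        mod = getattr(mod, part)
    return mod

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(4, 4)

layer = Layer()
name = "mlp" + CAPTURE_ONLY_FLAG
if name.endswith(CAPTURE_ONLY_FLAG):
    # Strip the flag, then resolve the live submodule by its dotted name.
    name = name.split(CAPTURE_ONLY_FLAG, 1)[0]
    captured = get_module_by_prefix(layer, name)
    assert captured is layer.mlp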
"down_proj:1"), "experts": { From 1f6eaad22972464df44e435d42e0e5260c84e99d Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 4 Nov 2025 15:39:33 +0800 Subject: [PATCH 05/16] Add "mlp" to the subset ["mlp.experts.#.gate_proj", "mlp.experts.#.up_proj"] Signed-off-by: ZX-ModelCloud --- gptqmodel/models/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 876e88466..8b7399d15 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -326,10 +326,15 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: num_experts = cls.get_num_experts(model_config) moe_simple = [] + capture_only_modules = [] for names in layer_modules: moe_simple.append([]) has_expert = all(EXPERT_INDEX_PLACEHOLDER in n for n in names) + has_capture_only = all(CAPTURE_ONLY_FLAG in n for n in names) + if has_capture_only: + capture_only_modules.append(names) + continue if not has_expert: moe_simple[-1].extend(names) @@ -341,6 +346,9 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: for index in range(num_experts): for n in names: moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index))) + if capture_only_modules: + # Extend all elements in capture_only_modules + moe_simple[-1].extend(sum(capture_only_modules, [])) else: # result like: ['mlp.experts.0.gate_proj', 'mlp.experts.1.gate_proj', 'mlp.experts.0.up_proj', 'mlp.experts.1.up_proj', ...] for n in names: From c2074a2a77ab9d86e256a5e3d12863bcdc7bc3a1 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 4 Nov 2025 16:54:22 +0800 Subject: [PATCH 06/16] NamedModule override register_forward_hook() Signed-off-by: ZX-ModelCloud --- gptqmodel/looper/named_module.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index 32d98b769..42a74fd40 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -75,6 +75,12 @@ def buffers(self, recurse: bool = True): def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix=prefix, recurse=recurse) + def register_forward_hook( + self, *args, **kwargs + ): + with self._parent_lock: + return self.module.register_forward_hook(*args, **kwargs) + def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: From 5d5ccebf89d2f7003627ef9601b72caf19b17b19 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 4 Nov 2025 17:33:20 +0800 Subject: [PATCH 07/16] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/models/base.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 8b7399d15..e63e37406 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -326,14 +326,14 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: num_experts = cls.get_num_experts(model_config) moe_simple = [] - capture_only_modules = [] + capture_only_modules = None for names in layer_modules: moe_simple.append([]) has_expert = all(EXPERT_INDEX_PLACEHOLDER in n for n in names) has_capture_only = all(CAPTURE_ONLY_FLAG in n for n in names) if has_capture_only: - capture_only_modules.append(names) + capture_only_modules = names continue if not has_expert: @@ -346,9 +346,10 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: for index in range(num_experts): for n in names: 
                         moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index)))
-                if capture_only_modules:
+                # Currently, we only need to add `capture_only_modules` to `['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']`
+                if len(names) == 2 and capture_only_modules:
                     # Extend all elements in capture_only_modules
-                    moe_simple[-1].extend(sum(capture_only_modules, []))
+                    moe_simple[-1].extend(capture_only_modules)
             else:
                 # result like: ['mlp.experts.0.gate_proj', 'mlp.experts.1.gate_proj', 'mlp.experts.0.up_proj', 'mlp.experts.1.up_proj', ...]
                 for n in names:
@@ -1083,7 +1084,8 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 last_module_root = root
             module2inspect, _ = get_module_by_name_prefix(module, root)

-            if num_experts is not None and len(block) == 2 * num_experts and module2inspect is not None:
+            # process ['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']
+            if num_experts is not None and len(block) == 2 * num_experts + 1 and module2inspect is not None:
                 if last_module_root not in input_feat:
                     log.debug(
                         "awq_get_modules_for_scaling: missing input feature for `%s` while processing experts block (layer block size=%s)",

From a42dfdadad4abbe807181b1ca310a316053e4197 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Tue, 4 Nov 2025 17:34:08 +0800
Subject: [PATCH 08/16] cleanup

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index e63e37406..a7ea718fc 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -397,12 +397,12 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
     @classmethod
     def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
         layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
-        print(f"simple_layer_modules build_layer_modules: {layer_modules}")
+
         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
-        print(f"simple_layer_modules build_moe_modules_if_need: {layer_modules}")
+
         layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)
-        print(f"simple_layer_modules layer_modules: {layer_modules}")
+        # print(f"simple_layer_modules layer_modules: {layer_modules}")
         return layer_modules

     @classmethod

From fe0997be0190243f412653f1dc6e880bdc0d564b Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 4 Nov 2025 12:16:58 +0000
Subject: [PATCH 09/16] remove custom override

---
 gptqmodel/models/definitions/qwen2_moe.py | 26 +++++++++++------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py
index cdcce4a1c..96814b032 100644
--- a/gptqmodel/models/definitions/qwen2_moe.py
+++ b/gptqmodel/models/definitions/qwen2_moe.py
@@ -32,17 +32,17 @@ class Qwen2MoeQModel(BaseQModel):
         }
     ]

-    module_tree_overrides = {
-        METHOD.AWQ: [
-            {
-                "mlp:?": {
-                    "gate": ("gate:!",),
-                    "shared_expert": None,
-                    "experts": {
-                        "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
-                    },
-                },
-            }
-        ]
-    }
+    # module_tree_overrides = {
+    #     METHOD.AWQ: [
+    #         {
+    #             "mlp:?": {
+    #                 "gate": ("gate:!",),
+    #                 "shared_expert": None,
+    #                 "experts": {
+    #                     "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+    #                 },
+    #             },
+    #         }
+    #     ]
+    # }

From 98d0176c26ea5f62a1228b9b252c6d5e457ab902 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 4 Nov 2025 14:22:14 +0000
Subject: [PATCH 10/16] new 
parent node subset merging --- gptqmodel/models/base.py | 135 +++++++++++++++++----- gptqmodel/models/definitions/qwen2_moe.py | 5 +- tests/test_subset_parsing.py | 24 +++- 3 files changed, 133 insertions(+), 31 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index a7ea718fc..d0d60fa11 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -11,6 +11,7 @@ import threading import time from collections import defaultdict +from itertools import count from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union @@ -333,7 +334,8 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: has_expert = all(EXPERT_INDEX_PLACEHOLDER in n for n in names) has_capture_only = all(CAPTURE_ONLY_FLAG in n for n in names) if has_capture_only: - capture_only_modules = names + capture_only_modules = list(names) + moe_simple[-1].extend(capture_only_modules) continue if not has_expert: @@ -1345,6 +1347,10 @@ def build_layer_modules(cls, tree, include_capture_only: bool = False): raise ValueError("Mapping configuration not found in the tree.") out_blocks = [] + alias_groups: Dict[tuple[str | None, int], List[tuple[str, bool, bool]]] = {} + alias_meta: Dict[tuple[str | None, int], Dict[str, int]] = {} + alias_seq = count() + group_seq = count() def _parse_token(token: str) -> tuple[str, List[str]]: parts = token.split(":") @@ -1358,19 +1364,46 @@ def _group_from_flags(flags: List[str]) -> int: return int(flag) return 0 - def process_entries(parent_token: str, entries, parent_group_offset: int = 0): + def _has_numeric_flag(flags: List[str]) -> bool: + return any(flag.isdigit() for flag in flags) + + def _get_scope(parent_name: str) -> str | None: + if not parent_name: + return None + return parent_name.split(".", 1)[0] + + def process_entries(parent_token: str, entries, parent_group_offset: int = 0, scope_key: str | None = None): """Process entries recursively to handle nested dict structures for MoE""" - groups: defaultdict[int, List[tuple[str, bool, bool]]] = defaultdict(list) + groups: defaultdict[int, List[tuple]] = defaultdict(list) parent_name, parent_flags = _parse_token(parent_token) - parent_group = parent_group_offset + _group_from_flags(parent_flags) + parent_rel_group = _group_from_flags(parent_flags) + parent_group = parent_group_offset + parent_rel_group parent_has_bang = "!" in parent_flags parent_capture_only = "?" 
in parent_flags + parent_has_numeric = _has_numeric_flag(parent_flags) + + scope = scope_key if scope_key is not None else _get_scope(parent_name) + parent_alias_scope = scope if parent_has_numeric else parent_name + + def _make_entry(full_path: str, has_bang: bool, capture_only: bool, *, alias_base: int, alias_rel: int, alias_scope: str | None) -> tuple: + return (full_path, has_bang, capture_only, alias_scope, (alias_base, alias_rel)) child_group_offset = parent_group_offset add_parent = parent_has_bang or (parent_capture_only and include_capture_only) if add_parent: - groups[parent_group].append((parent_name, parent_has_bang, parent_capture_only)) + alias_base = parent_rel_group if parent_has_numeric else parent_group + parent_entry_scope = f"{parent_alias_scope}.__parent__" if parent_alias_scope is not None else None + groups[parent_group].append( + _make_entry( + parent_name, + parent_has_bang, + parent_capture_only, + alias_base=alias_base, + alias_rel=0, + alias_scope=parent_entry_scope, + ) + ) child_group_offset = max(child_group_offset, parent_group + 1) # Handle tuple/list of strings (traditional format) @@ -1381,7 +1414,8 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): has_bang = "!" in child_flags capture_only = "?" in child_flags # first numeric tag is the group id; default 0 - grp = child_group_offset + _group_from_flags(child_flags) + child_rel_group = _group_from_flags(child_flags) + grp = child_group_offset + child_rel_group # Apply parent group offset to avoid conflicts between different nesting levels # Store the full path including parent for later use if parent_name.endswith(f".{child_name}") or parent_name == child_name: @@ -1393,7 +1427,19 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): if capture_only and not include_capture_only: continue - groups[grp].append((full_path, has_bang, capture_only)) + alias_scope = scope if parent_has_numeric else parent_name + alias_base = parent_rel_group if parent_has_numeric else grp + alias_rel = child_rel_group if parent_has_numeric else 0 + groups[grp].append( + _make_entry( + full_path, + has_bang, + capture_only, + alias_base=alias_base, + alias_rel=alias_rel, + alias_scope=alias_scope, + ) + ) elif isinstance(entries, dict): # Calculate max group number used at current level to avoid conflicts @@ -1414,15 +1460,31 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): f"{parent_name}.{EXPERT_INDEX_PLACEHOLDER}" if parent_name else EXPERT_INDEX_PLACEHOLDER ) + template_parent_token = ( + f"{template_parent}:{parent_rel_group}" + if parent_has_numeric + else template_parent + ) # Use a higher offset for expert modules to avoid conflicts with parent level expert_offset = current_offset + max_current_group + 100 # Large offset to avoid conflicts # Handle special case where sub_entries is ("#",) or "#" - this means use the parent path directly if (isinstance(sub_entries, (tuple, list)) and len(sub_entries) == 1 and sub_entries[0] == "#") or sub_entries == "#": # For ("#",) or "#" format, use the template_parent directly with default group 0 - groups[expert_offset].append((template_parent, False, False)) + alias_scope = scope if parent_has_numeric else template_parent + alias_base = parent_rel_group if parent_has_numeric else expert_offset + groups[expert_offset].append( + _make_entry( + template_parent, + False, + False, + alias_base=alias_base, + alias_rel=0, + alias_scope=alias_scope, + ) + ) else: - sub_groups = 
process_entries(template_parent, sub_entries, expert_offset) + sub_groups = process_entries(template_parent_token, sub_entries, expert_offset, scope) for grp, items in sub_groups.items(): groups[grp].extend(items) else: @@ -1435,7 +1497,7 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): f"{parent_name}.{sub_parent}" if parent_name else sub_parent ) - sub_groups = process_entries(full_sub_parent, sub_entries, current_offset) + sub_groups = process_entries(full_sub_parent, sub_entries, current_offset, scope) for grp, items in sub_groups.items(): groups[grp].extend(items) # Update offset for next sibling to avoid conflicts @@ -1444,28 +1506,47 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): return groups + def _register_alias(order_idx: int, item: tuple[str, bool, bool, str | None, tuple[int, int]]): + full_path, has_bang, capture_only, scope, alias_parts = item + if capture_only and not include_capture_only: + return + alias_scope = scope + alias_base, alias_rel = alias_parts + alias_index = alias_base + alias_rel + key = (alias_scope, alias_index) + meta = alias_meta.get(key) + if meta is None: + alias_meta[key] = {"order": order_idx, "seq": next(alias_seq)} + alias_groups[key] = [(full_path, has_bang, capture_only)] + else: + meta["order"] = min(meta["order"], order_idx) + alias_groups[key].append((full_path, has_bang, capture_only)) + for parent, entries in mapping.items(): groups = process_entries(parent, entries) - # Emit per-group, skipping pure-:! blocks (norm-only), but - # preserving :! markers on mixed blocks if they ever occur. for g in sorted(groups): + order_idx = next(group_seq) items = groups[g] - # if every entry is :!, skip this block (matches your expected output) - # if all(has_bang for _, has_bang in items): - # continue - - block = [] - for full_path, has_bang, capture_only in items: - # The full path is already constructed in process_entries - name = full_path - if has_bang: - name += NOT_QUANTIZE_FLAG - if capture_only and include_capture_only: - name += CAPTURE_ONLY_FLAG - block.append(name) - - out_blocks.append(block) + for item in items: + if len(item) == 3: + full_path, has_bang, capture_only = item + scope = full_path + alias_parts = (g, 0) + _register_alias(order_idx, (full_path, has_bang, capture_only, scope, alias_parts)) + else: + _register_alias(order_idx, item) + + for key in sorted(alias_groups.keys(), key=lambda k: (alias_meta[k]["order"], alias_meta[k]["seq"])): + block = [] + for full_path, has_bang, capture_only in alias_groups[key]: + name = full_path + if has_bang: + name += NOT_QUANTIZE_FLAG + if capture_only and include_capture_only: + name += CAPTURE_ONLY_FLAG + block.append(name) + out_blocks.append(block) return out_blocks diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py index 96814b032..b65a2fc79 100644 --- a/gptqmodel/models/definitions/qwen2_moe.py +++ b/gptqmodel/models/definitions/qwen2_moe.py @@ -24,8 +24,8 @@ class Qwen2MoeQModel(BaseQModel): "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:?": { "gate": ("gate:0",), - "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"), - "experts": { + "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "experts:0": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, }, @@ -45,4 +45,3 @@ class Qwen2MoeQModel(BaseQModel): # } # ] # } - diff --git a/tests/test_subset_parsing.py b/tests/test_subset_parsing.py index 614547e62..e9818e343 100644 --- 
a/tests/test_subset_parsing.py +++ b/tests/test_subset_parsing.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +import sys +from pathlib import Path from types import SimpleNamespace from typing import Callable, Dict, List, Optional @@ -10,6 +12,11 @@ from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock +repo_root = Path(__file__).resolve().parents[1] +repo_str = str(repo_root) +if repo_str not in sys.path: + sys.path.insert(0, repo_str) + from gptqmodel.looper.awq_processor import AWQProcessor from gptqmodel.looper.loop_processor import LoopProcessor from gptqmodel.looper.module_looper import ModuleLooper @@ -19,6 +26,7 @@ from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.utils.model import find_modules, get_module_by_name_prefix +from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel @@ -71,7 +79,7 @@ def test_mlp_capture_flag_propagates_to_layer_modules(): include_capture_only=True, ) capture_blocks = [block for block in full if any(":?" in name for name in block)] - assert capture_blocks and capture_blocks[0] == ["mlp:?"] + assert capture_blocks and "mlp:?" in capture_blocks[0] simple = Qwen3MoeQModel.simple_layer_modules( model_config=model_config, @@ -85,6 +93,20 @@ def test_mlp_capture_flag_propagates_to_layer_modules(): assert isinstance(mlp_module, Qwen3MoeSparseMoeBlock) +def test_qwen2_moe_shared_expert_merges_with_experts(): + blocks = Qwen2MoeQModel.build_layer_modules(Qwen2MoeQModel.module_tree) + + gate_block = next(block for block in blocks if "mlp.shared_expert.gate_proj" in block) + assert "mlp.experts.{expert_index}.gate_proj" in gate_block + assert "mlp.experts.{expert_index}.up_proj" in gate_block + + down_block = next(block for block in blocks if "mlp.shared_expert.down_proj" in block) + assert "mlp.experts.{expert_index}.down_proj" in down_block + + expert_gate_blocks = [block for block in blocks if "mlp.experts.{expert_index}.gate_proj" in block] + assert len(expert_gate_blocks) == 1 + + def test_awq_processor_enables_subset_early_stop(): calibration = [{"input_ids": torch.tensor([1, 2, 3])}] qcfg = _make_quant_config() From ba6b80d993866ae25fa7c6d2d1cf7b1d441d26a7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 4 Nov 2025 15:49:33 +0000 Subject: [PATCH 11/16] do not quantize gate --- gptqmodel/models/definitions/qwen2_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py index b65a2fc79..2097f2712 100644 --- a/gptqmodel/models/definitions/qwen2_moe.py +++ b/gptqmodel/models/definitions/qwen2_moe.py @@ -23,7 +23,7 @@ class Qwen2MoeQModel(BaseQModel): "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:?": { - "gate": ("gate:0",), + "gate": ("gate:!",), "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), "experts:0": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), From 440e29e36dbb1704a24ffdb803f7f22e5608e967 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 4 Nov 2025 15:51:36 +0000 Subject: [PATCH 12/16] cleanup --- tests/models/test_qwen2_moe_quant.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/tests/models/test_qwen2_moe_quant.py b/tests/models/test_qwen2_moe_quant.py
index 316594e6d..f614b93fd 100644
--- a/tests/models/test_qwen2_moe_quant.py
+++ b/tests/models/test_qwen2_moe_quant.py
@@ -17,8 +17,6 @@ class TestQwen2_5_Moe(ModelTest):
             "acc_norm": {"value": 0.3055, "floor_pct": 0.2},
         },
     }
-    TRUST_REMOTE_CODE = False
-    EVAL_BATCH_SIZE = 6

     def test_qwen2_5(self):
         self.quant_lm_eval()

From f2d1652ba750f3a71bd3223d826c4f9ce3517b92 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 4 Nov 2025 15:51:50 +0000
Subject: [PATCH 13/16] ruff

---
 gptqmodel/models/base.py                  |  2 +-
 gptqmodel/models/definitions/qwen2_moe.py |  1 -
 gptqmodel/quantization/gptq.py            |  4 ++--
 tests/models/model_test.py                |  2 +-
 tests/models/test_glm4_moe._awq.py        |  4 +---
 tests/test_awq.py                         |  5 ++---
 tests/test_awq_weight_mean.py             |  5 +++--
 tests/test_subset_parsing.py              | 25 ++++++++++++-----------
 8 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index d0d60fa11..f69ced010 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -11,8 +11,8 @@
 import threading
 import time
 from collections import defaultdict
-from itertools import count
 from contextlib import nullcontext
+from itertools import count
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union

 import torch
diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py
index 2097f2712..f923d6b69 100644
--- a/gptqmodel/models/definitions/qwen2_moe.py
+++ b/gptqmodel/models/definitions/qwen2_moe.py
@@ -4,7 +4,6 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

 from ..base import BaseQModel
-from ...quantization import METHOD


 class Qwen2MoeQModel(BaseQModel):
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index fe4a00301..44bcab9a9 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -571,8 +571,8 @@ def materialize_global_hessian(self, target_device: Optional[torch.device] = Non
         for partial_device, partial in self._device_hessian_partials.items():
             if partial.device != result_accum.device or partial.dtype != torch.float32:
                 # TODO FIXME multi-3090 using P2P is revealing an issue where result_accum and/or partial is not ready for consolidation on the main thread
-                # when partials are calculated on the individual
-                try:
+                # when partials are calculated on the individual
+                try:
                     result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
                 except:
                     log.warn(f"Quantization: Module `{self.name}` -> Retry partial.to in 0.25s")
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 2048cc4ac..501585335 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -58,6 +58,7 @@ def is_flash_attn_2_available():  # type: ignore
         return False

 from gptqmodel import BACKEND, DEBUG_ON, GPTQModel  # noqa: E402
+from gptqmodel.looper.module_looper import StopMainLoop  # noqa: E402
 from gptqmodel.models.base import BaseQModel  # noqa: E402
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear  # noqa: E402
 from gptqmodel.quantization import FORMAT, METHOD  # noqa: E402
@@ -65,7 +66,6 @@ def is_flash_attn_2_available():  # type: ignore
 from gptqmodel.utils.eval import EVAL  # noqa: E402
 from gptqmodel.utils.model import MODALITY  # noqa: E402
 from gptqmodel.utils.torch import torch_empty_cache  # noqa: E402
-from gptqmodel.looper.module_looper import StopMainLoop  # noqa: E402

 RAND_SEED = 898
diff --git a/tests/models/test_glm4_moe._awq.py 
b/tests/models/test_glm4_moe._awq.py index 57991cf46..ad7730bdb 100644 --- a/tests/models/test_glm4_moe._awq.py +++ b/tests/models/test_glm4_moe._awq.py @@ -4,10 +4,8 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.quantization.config import VRAMStrategy -from gptqmodel.utils.eval import EVAL from gptqmodel.quantization import FORMAT, METHOD - +from gptqmodel.utils.eval import EVAL # | Metric | MARLIN | diff --git a/tests/test_awq.py b/tests/test_awq.py index 6230a3de5..e1ec41e65 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -47,15 +47,14 @@ def setUpClass(cls): if requested_samples is not None: sample_count = max(1, int(requested_samples)) else: - total_mem_gb = 0 if torch.cuda.is_available(): try: - total_mem_gb = ( + ( torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / (1024 ** 3) ) except Exception: - total_mem_gb = 0 + pass # if total_mem_gb >= 80: # sample_count = 1024 diff --git a/tests/test_awq_weight_mean.py b/tests/test_awq_weight_mean.py index c8b541750..10dbd8257 100644 --- a/tests/test_awq_weight_mean.py +++ b/tests/test_awq_weight_mean.py @@ -1,12 +1,13 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True" import time -import torch -import pytest +import pytest +import torch from parameterized import parameterized from pytest import MonkeyPatch from torch import nn diff --git a/tests/test_subset_parsing.py b/tests/test_subset_parsing.py index e9818e343..aec349436 100644 --- a/tests/test_subset_parsing.py +++ b/tests/test_subset_parsing.py @@ -12,6 +12,7 @@ from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + repo_root = Path(__file__).resolve().parents[1] repo_str = str(repo_root) if repo_str not in sys.path: @@ -20,14 +21,14 @@ from gptqmodel.looper.awq_processor import AWQProcessor from gptqmodel.looper.loop_processor import LoopProcessor from gptqmodel.looper.module_looper import ModuleLooper -from gptqmodel.looper.stage_subset import run_subset_stage from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.stage_subset import run_subset_stage +from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel +from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel +from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy -from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.utils.model import find_modules, get_module_by_name_prefix -from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel -from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel # honour the request to bind the test harness to GPU index 5 when CUDA is available @@ -187,14 +188,14 @@ def __call__( processor: str, ): self.events.append( - dict( - stage=stage, - layer_idx=layer_idx, - subset_index=subset_index, - subset_total=subset_total, - module_names=module_names, - processor=processor, - ) + { + "stage": stage, + "layer_idx": layer_idx, + "subset_index": subset_index, + "subset_total": subset_total, + "module_names": module_names, + "processor": processor, + } ) From 1af314040254df4ed942febf953066570b4059f3 Mon Sep 17 00:00:00 2001 From: 
ZX-ModelCloud
Date: Wed, 5 Nov 2025 10:35:33 +0800
Subject: [PATCH 14/16] fix build_moe_modules_if_need()

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index f69ced010..f4c87dff0 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -331,11 +331,10 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize:

         for names in layer_modules:
             moe_simple.append([])
-            has_expert = all(EXPERT_INDEX_PLACEHOLDER in n for n in names)
+            has_expert = any(EXPERT_INDEX_PLACEHOLDER in n for n in names)
             has_capture_only = all(CAPTURE_ONLY_FLAG in n for n in names)
             if has_capture_only:
                 capture_only_modules = list(names)
-                moe_simple[-1].extend(capture_only_modules)
                 continue

             if not has_expert:
@@ -346,9 +346,16 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize:
                 # result like: ['mlp.experts.0.gate_proj', 'mlp.experts.0.up_proj', 'mlp.experts.1.gate_proj', 'mlp.experts.1.up_proj', ...]
                 for index in range(num_experts):
                     for n in names:
-                        moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index)))
-                # Currently, we only need to add `capture_only_modules` to `['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']`
-                if len(names) == 2 and capture_only_modules:
+                        if EXPERT_INDEX_PLACEHOLDER in n:
+                            moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index)))
+                # add 'mlp.shared_expert.gate_proj', 'mlp.shared_expert.up_proj' once
+                for n in names:
+                    if EXPERT_INDEX_PLACEHOLDER not in n:
+                        moe_simple[-1].append(n)
+                # Currently, we only need to add `capture_only_modules` to `['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']`
+                # or ['mlp.shared_expert.gate_proj', 'mlp.shared_expert.up_proj', 'mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']
+                add_capture_only_module = len(names) == (4 if any("shared_expert" in n for n in names) else 2)
+                if add_capture_only_module and capture_only_modules:
                     # Extend all elements in capture_only_modules
                     moe_simple[-1].extend(capture_only_modules)
             else:
@@ -399,12 +405,13 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
     @classmethod
     def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
         layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
-
+        print(f"simple_layer_modules build_layer_modules: {layer_modules}")
         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
-
+        print(f"simple_layer_modules build_moe_modules_if_need: {layer_modules}")
         layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)
-        # print(f"simple_layer_modules layer_modules: {layer_modules}")
+        print(f"simple_layer_modules layer_modules: {layer_modules}")
+
         return layer_modules

     @classmethod
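The net expansion rule this patch settles on is easiest to see in isolation: expert-templated names repeat once per expert index, shared-expert names are appended once, and the capture-only parent (e.g. `mlp`) rides along at the end of the gate/up block. A standalone, hedged sketch of that rule (illustrative only; `EXPERT_INDEX_PLACEHOLDER` is assumed here to be the literal `{expert_index}` string the tests reference, and `expand_block` is not a gptqmodel API):

EXPERT_INDEX_PLACEHOLDER = "{expert_index}"

def expand_block(names, num_experts, capture_only=None):
    out = []
    # expert-templated entries repeat per expert index
    for index in range(num_experts):
        for n in names:
            if EXPERT_INDEX_PLACEHOLDER in n:
                out.append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index)))
    # shared (non-expert) entries are appended once
    out.extend(n for n in names if EXPERT_INDEX_PLACEHOLDER not in n)
    # the capture-only parent ("mlp") joins the block so AWQ can record
    # the block's input once for all experts
    if capture_only:
        out.extend(capture_only)
    return out

block = [
    "mlp.experts.{expert_index}.gate_proj",
    "mlp.experts.{expert_index}.up_proj",
    "mlp.shared_expert.gate_proj",
    "mlp.shared_expert.up_proj",
]
print(expand_block(block, num_experts=2, capture_only=["mlp"]))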
From 5b09816dc6e57a21b1e3b6f1c91d28b93581a83c Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Wed, 5 Nov 2025 13:15:13 +0800
Subject: [PATCH 15/16] fix wrong inp tensor

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index f4c87dff0..e9590d187 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -1031,7 +1031,9 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 _try_update_last_module(candidate_name)
                 continue

-            if num_experts is not None and len(block) == num_experts and last_module is not None and last_module_name is not None:
+            is_down_proj_block = (num_experts is not None and
+                                  len(block) == (num_experts + 1 if any("shared_expert" in n for n in block) else num_experts))
+            if is_down_proj_block and last_module is not None and last_module_name is not None:
                 # mlp.experts.0.down_proj
                 target_suffix = last_module_name.split(".")[-1]
                 for name in block:
@@ -1094,7 +1096,9 @@ def _try_update_last_module(candidate_name: str) -> bool:
             module2inspect, _ = get_module_by_name_prefix(module, root)

             # process ['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']
-            if num_experts is not None and len(block) == 2 * num_experts + 1 and module2inspect is not None:
+            is_gate_up_proj_block = (num_experts is not None and
+                                     len(block) == (2 * num_experts + 2 if any("shared_expert" in n for n in block) else 2 * num_experts) + 1)
+            if is_gate_up_proj_block and module2inspect is not None:
                 if last_module_root not in input_feat:
                     log.debug(
                         "awq_get_modules_for_scaling: missing input feature for `%s` while processing experts block (layer block size=%s)",

From 47caadce71ce2dc40546ed545447fe208860b9e7 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Wed, 5 Nov 2025 15:48:57 +0800
Subject: [PATCH 16/16] fix expert `mlp.experts.0.down_proj` due to missing
 prev_op

Signed-off-by: ZX-ModelCloud
---
 gptqmodel/models/base.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index e9590d187..f7daa8e72 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -405,13 +405,12 @@ def filter_not_quantize_module(cls, layer_modules, quantize_config):
     @classmethod
     def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False, include_capture_only: bool = False):
         layer_modules = cls.build_layer_modules(cls.module_tree, include_capture_only=include_capture_only)
-        print(f"simple_layer_modules build_layer_modules: {layer_modules}")
+
         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
-        print(f"simple_layer_modules build_moe_modules_if_need: {layer_modules}")
+
         layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)
-        print(f"simple_layer_modules layer_modules: {layer_modules}")
-
+        # print(f"simple_layer_modules layer_modules: {layer_modules}")
         return layer_modules

     @classmethod
@@ -1031,8 +1030,11 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 _try_update_last_module(candidate_name)
                 continue

+            has_shared_expert = any("shared_expert" in n for n in block)
             is_down_proj_block = (num_experts is not None and
-                                  len(block) == (num_experts + 1 if any("shared_expert" in n for n in block) else num_experts))
+                                  len(block) == (num_experts + 1 if has_shared_expert else num_experts))
+            is_gate_up_proj_block = (num_experts is not None and
+                                     len(block) == (2 * num_experts + 2 if has_shared_expert else 2 * num_experts) + 1)
             if is_down_proj_block and last_module is not None and last_module_name is not None:
                 # mlp.experts.0.down_proj
                 target_suffix = last_module_name.split(".")[-1]
@@ -1096,8 +1098,6 @@ def _try_update_last_module(candidate_name: str) -> bool:
             module2inspect, _ = get_module_by_name_prefix(module, root)

             # process ['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']
-            is_gate_up_proj_block = (num_experts is not None and
-                                     len(block) == (2 * num_experts + 2 if any("shared_expert" 
in n for n in block) else 2 * num_experts) + 1) if is_gate_up_proj_block and module2inspect is not None: if last_module_root not in input_feat: log.debug( @@ -1120,7 +1120,13 @@ def _try_update_last_module(candidate_name: str) -> bool: nodes.append(n) # Update tracker to the LAST item of this block - candidate_name = strip_non_quantize_flags(block[-1]) + if is_gate_up_proj_block: + # The block content is [..., mlp.experts.{last_index}.up_proj, shared_expert.gate_proj, shared_expert.up_proj, mlp] + # mlp.experts.{last_index}.up_proj should be selected as last_module + offset_from_end = 4 if has_shared_expert else 2 + else: + offset_from_end = 1 + candidate_name = strip_non_quantize_flags(block[-offset_from_end]) _try_update_last_module(candidate_name) import torch