diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index da850dd45..6ac798a6b 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -204,13 +204,13 @@ def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor # tuple(features[name].shape), # ) - for root, tensors in root_buckets.items(): - if not tensors or root in features: - continue - try: - features[root] = torch.cat(tensors, dim=0) - except RuntimeError: - features[root] = tensors[0] + # for root, tensors in root_buckets.items(): + # if not tensors or root in features: + # continue + # try: + # features[root] = torch.cat(tensors, dim=0) + # except RuntimeError: + # features[root] = tensors[0] return features def _refresh_forward_kwargs_from_cache(self) -> None: @@ -1233,7 +1233,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso def hook(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): if not inp: return - feature = inp[0] + feature = inp if isinstance(feature, (tuple, list)) and feature: feature = feature[0] self._record_input_feature(name, feature) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 69f061b97..1a8cab28d 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -33,6 +33,7 @@ from ..looper.named_module import NamedModule from ..models import BaseQModel from ..models._const import SUPPORTS_MODULE_TYPES +from ..models.base import CAPTURE_ONLY_FLAG from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward, replace_module_with_hooked_legacy) from ..quantization.config import VRAMStrategy @@ -1251,11 +1252,15 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs): return total_log - def crate_named_modules(self, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe, layer_module=None) -> Dict[str, NamedModule]: + def crate_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe, layer_module=None) -> Dict[str, NamedModule]: subset = {} for n in names: if n in full: subset[n] = full[n] + elif n.endswith(CAPTURE_ONLY_FLAG): + # Obtain the CAPTURE_ONLY_FLAG Module separately + n = n.split(CAPTURE_ONLY_FLAG, 1)[0] + subset[n], _ = get_module_by_name_prefix(module, module_name=n) # some modules have layer_modules that are dynamic based on config # ref: deepseek v2/v3/r1 elif self.gptq_model.layer_modules_strict: diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index db078c3ab..42a74fd40 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -54,12 +54,14 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde in_features = module.weight.shape[0] out_features = module.weight.shape[1] else: - raise NotImplementedError(f"Unsupported module.module type: `{type(module)}`") + in_features = None + out_features = None - self.state.update({ - "in_features": in_features, - "out_features": out_features, - }) + if in_features and out_features: + self.state.update({ + "in_features": in_features, + "out_features": out_features, + }) def parameters(self, recurse: bool = True): return self.module.parameters(recurse=recurse) @@ -73,6 +75,12 @@ def buffers(self, recurse: bool = True): def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix=prefix, recurse=recurse) + def 
register_forward_hook( + self, *args, **kwargs + ): + with self._parent_lock: + return self.module.register_forward_hook(*args, **kwargs) + def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py index fe6bbe83b..2cf65345b 100644 --- a/gptqmodel/looper/stage_subset.py +++ b/gptqmodel/looper/stage_subset.py @@ -79,6 +79,7 @@ def run_subset_stage( is_awq_processor = processor_name_lower.startswith("awq") subset = looper.crate_named_modules( + module=module, full=full, is_lm_head_module=is_lm_head_module, layer_index=layer_index, diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 0fda18350..f7daa8e72 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -12,6 +12,7 @@ import time from collections import defaultdict from contextlib import nullcontext +from itertools import count from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union import torch @@ -326,10 +327,15 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: num_experts = cls.get_num_experts(model_config) moe_simple = [] + capture_only_modules = None for names in layer_modules: moe_simple.append([]) - has_expert = all(EXPERT_INDEX_PLACEHOLDER in n for n in names) + has_expert = any(EXPERT_INDEX_PLACEHOLDER in n for n in names) + has_capture_only = all(CAPTURE_ONLY_FLAG in n for n in names) + if has_capture_only: + capture_only_modules = list(names) + continue if not has_expert: moe_simple[-1].extend(names) @@ -340,7 +346,18 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize: # result like: ['mlp.experts.0.gate_proj', 'mlp.experts.0.up_proj', 'mlp.experts.1.gate_proj', 'mlp.experts.1.up_proj', ...] for index in range(num_experts): for n in names: - moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index))) + if EXPERT_INDEX_PLACEHOLDER in n: + moe_simple[-1].append(n.replace(EXPERT_INDEX_PLACEHOLDER, str(index))) + # added 'mlp.shared_expert.gate_proj', 'mlp.shared_expert.up_proj' + for n in names: + if EXPERT_INDEX_PLACEHOLDER not in n: + moe_simple[-1].append(n) + # Currently, only need to add `capture_only_modules` to `['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']` + # or ['mlp.shared_expert.gate_proj', 'mlp.shared_expert.up_proj', 'mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj'] + add_capture_only_module = len(names) == (4 if any("shared_expert" in n for n in names) else 2) + if add_capture_only_module and capture_only_modules: + # Extend all elements in capture_only_modules + moe_simple[-1].extend(capture_only_modules) else: # result like: ['mlp.experts.0.gate_proj', 'mlp.experts.1.gate_proj', 'mlp.experts.0.up_proj', 'mlp.experts.1.up_proj', ...] 
                for n in names:
@@ -364,11 +381,7 @@ def get_num_experts(cls, model_config):
     @classmethod
     def filter_not_quantize_module(cls, layer_modules, quantize_config):
         layer_modules = [
-            [
-                name
-                for name in block
-                if all(flag not in name for flag in NON_QUANTIZE_FLAGS)
-            ]
+            [name for name in block if NOT_QUANTIZE_FLAG not in name]
             for block in layer_modules
         ]
         layer_modules = [block for block in layer_modules if block]  # drop empty blocks
@@ -1017,7 +1030,12 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 _try_update_last_module(candidate_name)
                 continue

-            if num_experts is not None and len(block) == num_experts and last_module is not None and last_module_name is not None:
+            has_shared_expert = any("shared_expert" in n for n in block)
+            is_down_proj_block = (num_experts is not None and
+                                  len(block) == (num_experts + 1 if has_shared_expert else num_experts))
+            is_gate_up_proj_block = (num_experts is not None and
+                                     len(block) == (2 * num_experts + 2 if has_shared_expert else 2 * num_experts) + 1)
+            if is_down_proj_block and last_module is not None and last_module_name is not None:
                 # mlp.experts.0.down_proj
                 target_suffix = last_module_name.split(".")[-1]
                 for name in block:
@@ -1079,7 +1097,8 @@ def _try_update_last_module(candidate_name: str) -> bool:
                 last_module_root = root
             module2inspect, _ = get_module_by_name_prefix(module, root)

-            if num_experts is not None and len(block) == 2 * num_experts and module2inspect is not None:
+            # process ['mlp.experts.#.gate_proj', 'mlp.experts.#.up_proj']
+            if is_gate_up_proj_block and module2inspect is not None:
                 if last_module_root not in input_feat:
                     log.debug(
                         "awq_get_modules_for_scaling: missing input feature for `%s` while processing experts block (layer block size=%s)",
@@ -1101,7 +1120,13 @@ def _try_update_last_module(candidate_name: str) -> bool:
                         nodes.append(n)

             # Update tracker to the LAST item of this block
-            candidate_name = strip_non_quantize_flags(block[-1])
+            if is_gate_up_proj_block:
+                # The block content is [..., mlp.experts.{last_index}.up_proj, shared_expert.gate_proj, shared_expert.up_proj, mlp]
+                # mlp.experts.{last_index}.up_proj should be selected as last_module
+                offset_from_end = 4 if has_shared_expert else 2
+            else:
+                offset_from_end = 1
+            candidate_name = strip_non_quantize_flags(block[-offset_from_end])
             _try_update_last_module(candidate_name)

         import torch
@@ -1339,6 +1364,10 @@ def build_layer_modules(cls, tree, include_capture_only: bool = False):
             raise ValueError("Mapping configuration not found in the tree.")

         out_blocks = []
+        alias_groups: Dict[tuple[str | None, int], List[tuple[str, bool, bool]]] = {}
+        alias_meta: Dict[tuple[str | None, int], Dict[str, int]] = {}
+        alias_seq = count()
+        group_seq = count()

         def _parse_token(token: str) -> tuple[str, List[str]]:
             parts = token.split(":")
@@ -1352,19 +1381,46 @@ def _group_from_flags(flags: List[str]) -> int:
                     return int(flag)
             return 0

-        def process_entries(parent_token: str, entries, parent_group_offset: int = 0):
+        def _has_numeric_flag(flags: List[str]) -> bool:
+            return any(flag.isdigit() for flag in flags)
+
+        def _get_scope(parent_name: str) -> str | None:
+            if not parent_name:
+                return None
+            return parent_name.split(".", 1)[0]
+
+        def process_entries(parent_token: str, entries, parent_group_offset: int = 0, scope_key: str | None = None):
             """Process entries recursively to handle nested dict structures for MoE"""
-            groups: defaultdict[int, List[tuple[str, bool, bool]]] = defaultdict(list)
+            groups: defaultdict[int, List[tuple]] = defaultdict(list)
             parent_name, 
parent_flags = _parse_token(parent_token) - parent_group = parent_group_offset + _group_from_flags(parent_flags) + parent_rel_group = _group_from_flags(parent_flags) + parent_group = parent_group_offset + parent_rel_group parent_has_bang = "!" in parent_flags parent_capture_only = "?" in parent_flags + parent_has_numeric = _has_numeric_flag(parent_flags) + + scope = scope_key if scope_key is not None else _get_scope(parent_name) + parent_alias_scope = scope if parent_has_numeric else parent_name + + def _make_entry(full_path: str, has_bang: bool, capture_only: bool, *, alias_base: int, alias_rel: int, alias_scope: str | None) -> tuple: + return (full_path, has_bang, capture_only, alias_scope, (alias_base, alias_rel)) child_group_offset = parent_group_offset add_parent = parent_has_bang or (parent_capture_only and include_capture_only) if add_parent: - groups[parent_group].append((parent_name, parent_has_bang, parent_capture_only)) + alias_base = parent_rel_group if parent_has_numeric else parent_group + parent_entry_scope = f"{parent_alias_scope}.__parent__" if parent_alias_scope is not None else None + groups[parent_group].append( + _make_entry( + parent_name, + parent_has_bang, + parent_capture_only, + alias_base=alias_base, + alias_rel=0, + alias_scope=parent_entry_scope, + ) + ) child_group_offset = max(child_group_offset, parent_group + 1) # Handle tuple/list of strings (traditional format) @@ -1375,7 +1431,8 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): has_bang = "!" in child_flags capture_only = "?" in child_flags # first numeric tag is the group id; default 0 - grp = child_group_offset + _group_from_flags(child_flags) + child_rel_group = _group_from_flags(child_flags) + grp = child_group_offset + child_rel_group # Apply parent group offset to avoid conflicts between different nesting levels # Store the full path including parent for later use if parent_name.endswith(f".{child_name}") or parent_name == child_name: @@ -1387,7 +1444,19 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): if capture_only and not include_capture_only: continue - groups[grp].append((full_path, has_bang, capture_only)) + alias_scope = scope if parent_has_numeric else parent_name + alias_base = parent_rel_group if parent_has_numeric else grp + alias_rel = child_rel_group if parent_has_numeric else 0 + groups[grp].append( + _make_entry( + full_path, + has_bang, + capture_only, + alias_base=alias_base, + alias_rel=alias_rel, + alias_scope=alias_scope, + ) + ) elif isinstance(entries, dict): # Calculate max group number used at current level to avoid conflicts @@ -1408,15 +1477,31 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): f"{parent_name}.{EXPERT_INDEX_PLACEHOLDER}" if parent_name else EXPERT_INDEX_PLACEHOLDER ) + template_parent_token = ( + f"{template_parent}:{parent_rel_group}" + if parent_has_numeric + else template_parent + ) # Use a higher offset for expert modules to avoid conflicts with parent level expert_offset = current_offset + max_current_group + 100 # Large offset to avoid conflicts # Handle special case where sub_entries is ("#",) or "#" - this means use the parent path directly if (isinstance(sub_entries, (tuple, list)) and len(sub_entries) == 1 and sub_entries[0] == "#") or sub_entries == "#": # For ("#",) or "#" format, use the template_parent directly with default group 0 - groups[expert_offset].append((template_parent, False, False)) + alias_scope = scope if parent_has_numeric else 
template_parent + alias_base = parent_rel_group if parent_has_numeric else expert_offset + groups[expert_offset].append( + _make_entry( + template_parent, + False, + False, + alias_base=alias_base, + alias_rel=0, + alias_scope=alias_scope, + ) + ) else: - sub_groups = process_entries(template_parent, sub_entries, expert_offset) + sub_groups = process_entries(template_parent_token, sub_entries, expert_offset, scope) for grp, items in sub_groups.items(): groups[grp].extend(items) else: @@ -1429,7 +1514,7 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): f"{parent_name}.{sub_parent}" if parent_name else sub_parent ) - sub_groups = process_entries(full_sub_parent, sub_entries, current_offset) + sub_groups = process_entries(full_sub_parent, sub_entries, current_offset, scope) for grp, items in sub_groups.items(): groups[grp].extend(items) # Update offset for next sibling to avoid conflicts @@ -1438,28 +1523,47 @@ def process_entries(parent_token: str, entries, parent_group_offset: int = 0): return groups + def _register_alias(order_idx: int, item: tuple[str, bool, bool, str | None, tuple[int, int]]): + full_path, has_bang, capture_only, scope, alias_parts = item + if capture_only and not include_capture_only: + return + alias_scope = scope + alias_base, alias_rel = alias_parts + alias_index = alias_base + alias_rel + key = (alias_scope, alias_index) + meta = alias_meta.get(key) + if meta is None: + alias_meta[key] = {"order": order_idx, "seq": next(alias_seq)} + alias_groups[key] = [(full_path, has_bang, capture_only)] + else: + meta["order"] = min(meta["order"], order_idx) + alias_groups[key].append((full_path, has_bang, capture_only)) + for parent, entries in mapping.items(): groups = process_entries(parent, entries) - # Emit per-group, skipping pure-:! blocks (norm-only), but - # preserving :! markers on mixed blocks if they ever occur. 
            for g in sorted(groups):
+                order_idx = next(group_seq)
                 items = groups[g]
-                # if every entry is :!, skip this block (matches your expected output)
-                # if all(has_bang for _, has_bang in items):
-                #     continue
-
-                block = []
-                for full_path, has_bang, capture_only in items:
-                    # The full path is already constructed in process_entries
-                    name = full_path
-                    if has_bang:
-                        name += NOT_QUANTIZE_FLAG
-                    if capture_only and include_capture_only:
-                        name += CAPTURE_ONLY_FLAG
-                    block.append(name)
-
-                out_blocks.append(block)
+                for item in items:
+                    if len(item) == 3:
+                        full_path, has_bang, capture_only = item
+                        scope = full_path
+                        alias_parts = (g, 0)
+                        _register_alias(order_idx, (full_path, has_bang, capture_only, scope, alias_parts))
+                    else:
+                        _register_alias(order_idx, item)
+
+        for key in sorted(alias_groups.keys(), key=lambda k: (alias_meta[k]["order"], alias_meta[k]["seq"])):
+            block = []
+            for full_path, has_bang, capture_only in alias_groups[key]:
+                name = full_path
+                if has_bang:
+                    name += NOT_QUANTIZE_FLAG
+                if capture_only and include_capture_only:
+                    name += CAPTURE_ONLY_FLAG
+                block.append(name)
+            out_blocks.append(block)

         return out_blocks
diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py
index 89446c5b3..f923d6b69 100644
--- a/gptqmodel/models/definitions/qwen2_moe.py
+++ b/gptqmodel/models/definitions/qwen2_moe.py
@@ -21,12 +21,26 @@ class Qwen2MoeQModel(BaseQModel):
             "input_layernorm": ("input_layernorm:!",),
             "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
             "post_attention_layernorm": ("post_attention_layernorm:!",),
-            "mlp": {
-                "gate": ("gate:0",),
-                "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"),
-                "experts": {
+            "mlp:?": {
+                "gate": ("gate:!",),
+                "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+                "experts:0": {
                     "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
                 },
             },
         }
     ]
+
+    # module_tree_overrides = {
+    #     METHOD.AWQ: [
+    #         {
+    #             "mlp:?": {
+    #                 "gate": ("gate:!",),
+    #                 "shared_expert": None,
+    #                 "experts": {
+    #                     "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+    #                 },
+    #             },
+    #         }
+    #     ]
+    # }
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index fe4a00301..44bcab9a9 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -571,8 +571,8 @@ def materialize_global_hessian(self, target_device: Optional[torch.device] = Non
         for partial_device, partial in self._device_hessian_partials.items():
             if partial.device != result_accum.device or partial.dtype != torch.float32:
                 # TODO FIXME multi-3090 using P2P is revealing an issue where result_accum and/or partial is not ready for consolidation on the main thread
-                # when parials are calculated on the individual
-                try:
+                # when partials are calculated on the individual devices
+                try:
                     result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
                 except:
                     log.warn(f"Quantization: Module `{self.name}` -> Retry partial.to in 0.25s")
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 2048cc4ac..501585335 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -58,6 +58,7 @@ def is_flash_attn_2_available():  # type: ignore
         return False

 from gptqmodel import BACKEND, DEBUG_ON, GPTQModel  # noqa: E402
+from gptqmodel.looper.module_looper import StopMainLoop  # noqa: E402
 from gptqmodel.models.base import BaseQModel  # noqa: E402
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear  # noqa: E402
 from gptqmodel.quantization import FORMAT, METHOD  # noqa: E402
@@ -65,7 
+66,6 @@ def is_flash_attn_2_available(): # type: ignore from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from gptqmodel.looper.module_looper import StopMainLoop # noqa: E402 RAND_SEED = 898 diff --git a/tests/models/test_glm4_moe._awq.py b/tests/models/test_glm4_moe._awq.py index 57991cf46..ad7730bdb 100644 --- a/tests/models/test_glm4_moe._awq.py +++ b/tests/models/test_glm4_moe._awq.py @@ -4,10 +4,8 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.quantization.config import VRAMStrategy -from gptqmodel.utils.eval import EVAL from gptqmodel.quantization import FORMAT, METHOD - +from gptqmodel.utils.eval import EVAL # | Metric | MARLIN | diff --git a/tests/models/test_qwen2_moe_quant.py b/tests/models/test_qwen2_moe_quant.py index 316594e6d..f614b93fd 100644 --- a/tests/models/test_qwen2_moe_quant.py +++ b/tests/models/test_qwen2_moe_quant.py @@ -17,8 +17,6 @@ class TestQwen2_5_Moe(ModelTest): "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } - TRUST_REMOTE_CODE = False - EVAL_BATCH_SIZE = 6 def test_qwen2_5(self): self.quant_lm_eval() diff --git a/tests/test_awq.py b/tests/test_awq.py index 6230a3de5..e1ec41e65 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -47,15 +47,14 @@ def setUpClass(cls): if requested_samples is not None: sample_count = max(1, int(requested_samples)) else: - total_mem_gb = 0 if torch.cuda.is_available(): try: - total_mem_gb = ( + ( torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / (1024 ** 3) ) except Exception: - total_mem_gb = 0 + pass # if total_mem_gb >= 80: # sample_count = 1024 diff --git a/tests/test_awq_weight_mean.py b/tests/test_awq_weight_mean.py index c8b541750..10dbd8257 100644 --- a/tests/test_awq_weight_mean.py +++ b/tests/test_awq_weight_mean.py @@ -1,12 +1,13 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True" import time -import torch -import pytest +import pytest +import torch from parameterized import parameterized from pytest import MonkeyPatch from torch import nn diff --git a/tests/test_subset_parsing.py b/tests/test_subset_parsing.py index 614547e62..aec349436 100644 --- a/tests/test_subset_parsing.py +++ b/tests/test_subset_parsing.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +import sys +from pathlib import Path from types import SimpleNamespace from typing import Callable, Dict, List, Optional @@ -10,16 +12,23 @@ from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + +repo_root = Path(__file__).resolve().parents[1] +repo_str = str(repo_root) +if repo_str not in sys.path: + sys.path.insert(0, repo_str) + from gptqmodel.looper.awq_processor import AWQProcessor from gptqmodel.looper.loop_processor import LoopProcessor from gptqmodel.looper.module_looper import ModuleLooper -from gptqmodel.looper.stage_subset import run_subset_stage from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.stage_subset import run_subset_stage +from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel +from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel +from 
gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.quantization.config import QuantizeConfig, VRAMStrategy -from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.utils.model import find_modules, get_module_by_name_prefix -from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel # honour the request to bind the test harness to GPU index 5 when CUDA is available @@ -71,7 +80,7 @@ def test_mlp_capture_flag_propagates_to_layer_modules(): include_capture_only=True, ) capture_blocks = [block for block in full if any(":?" in name for name in block)] - assert capture_blocks and capture_blocks[0] == ["mlp:?"] + assert capture_blocks and "mlp:?" in capture_blocks[0] simple = Qwen3MoeQModel.simple_layer_modules( model_config=model_config, @@ -85,6 +94,20 @@ def test_mlp_capture_flag_propagates_to_layer_modules(): assert isinstance(mlp_module, Qwen3MoeSparseMoeBlock) +def test_qwen2_moe_shared_expert_merges_with_experts(): + blocks = Qwen2MoeQModel.build_layer_modules(Qwen2MoeQModel.module_tree) + + gate_block = next(block for block in blocks if "mlp.shared_expert.gate_proj" in block) + assert "mlp.experts.{expert_index}.gate_proj" in gate_block + assert "mlp.experts.{expert_index}.up_proj" in gate_block + + down_block = next(block for block in blocks if "mlp.shared_expert.down_proj" in block) + assert "mlp.experts.{expert_index}.down_proj" in down_block + + expert_gate_blocks = [block for block in blocks if "mlp.experts.{expert_index}.gate_proj" in block] + assert len(expert_gate_blocks) == 1 + + def test_awq_processor_enables_subset_early_stop(): calibration = [{"input_ids": torch.tensor([1, 2, 3])}] qcfg = _make_quant_config() @@ -165,14 +188,14 @@ def __call__( processor: str, ): self.events.append( - dict( - stage=stage, - layer_idx=layer_idx, - subset_index=subset_index, - subset_total=subset_total, - module_names=module_names, - processor=processor, - ) + { + "stage": stage, + "layer_idx": layer_idx, + "subset_index": subset_index, + "subset_total": subset_total, + "module_names": module_names, + "processor": processor, + } )
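
For reference, the block grouping that the new alias bookkeeping in build_layer_modules() is intended to produce for Qwen2-MoE can be exercised with the same classmethod call the added test uses; a minimal sketch (assumes gptqmodel is importable; the printed blocks are illustrative of the expected grouping, not verified output):

    # Sketch: inspect the merged shared-expert / routed-expert blocks emitted by
    # build_layer_modules() for the updated Qwen2-MoE module tree.
    from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel

    blocks = Qwen2MoeQModel.build_layer_modules(Qwen2MoeQModel.module_tree)
    for block in blocks:
        print(block)

    # Per test_qwen2_moe_shared_expert_merges_with_experts above:
    # - 'mlp.shared_expert.gate_proj' / 'mlp.shared_expert.up_proj' share a block with
    #   'mlp.experts.{expert_index}.gate_proj' / 'mlp.experts.{expert_index}.up_proj'
    # - 'mlp.shared_expert.down_proj' shares a block with 'mlp.experts.{expert_index}.down_proj'
    # - the expert gate_proj template appears in exactly one block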