diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index bfffa308c..358e255f7 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -110,6 +110,25 @@ def check_support_param_buffer_assignment(*args, **kwargs):
     return False
 
 
+def apply_module_tree_override(module_tree, override):
+    """
+    Recursively find each key of `override` in `module_tree` and replace its value.
+    """
+    if isinstance(module_tree, dict) and isinstance(override, dict):
+        for k, v in override.items():
+            if k in module_tree and isinstance(module_tree[k], (dict, list)) and isinstance(v, (dict, list)):
+                module_tree[k] = apply_module_tree_override(module_tree[k], v)
+            else:
+                module_tree[k] = v
+    elif isinstance(module_tree, list) and isinstance(override, list):
+        for o in override:
+            if isinstance(o, dict):
+                for b in module_tree:
+                    if isinstance(b, dict):
+                        apply_module_tree_override(b, o)
+    return module_tree
+
+
 NOT_QUANTIZE_FLAG = ":!"
 
 
@@ -125,6 +144,8 @@ class BaseQModel(nn.Module):
 
     # a tree node of all the roots that contain quantizable modules
     module_tree: List[str] = None
+    # Override module_tree according to the quantization METHOD
+    module_tree_overrides: dict[METHOD, List[str]] = None
 
     # Strict=True -> all layer_modules must exists in model
     # Some models (deepseek2-lite) dynamically create lora modules based on config.rank
@@ -198,6 +219,13 @@ def __init__(
     ):
         super().__init__()
 
+        quant_method = quantize_config.quant_method
+        # override module_tree if needed
+        if self.module_tree_overrides is not None and self.module_tree_overrides.get(quant_method) is not None:
+            log.info(f'Module Tree: overridden by METHOD.{quant_method.upper()}')
+            # set cls.module_tree so the override is visible class-wide
+            type(self).module_tree = apply_module_tree_override(self.module_tree, self.module_tree_overrides[quant_method])
+
         # record configuration early so model lifecycle hooks can rely on them
         self.compiled = False  # set to True while compile() is triggered successfully
         self.quantized = quantized
@@ -794,7 +822,7 @@ def quantize(
         )
 
         if not self.support_batch_quantize:
-            log.warn("Quantize: batch_size overriden by model class definition to `disabled`")
+            log.warn("Quantize: batch_size overridden by model class definition to `disabled`")
             batch_size = 1  # but actually disabled
 
         if self.quantize_config.format == FORMAT.MARLIN:
diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py
index 31e788838..e0f56602a 100644
--- a/gptqmodel/models/definitions/qwen3_moe.py
+++ b/gptqmodel/models/definitions/qwen3_moe.py
@@ -4,6 +4,7 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 from ..base import BaseQModel
+from ...quantization import METHOD
 
 
 class Qwen3MoeQModel(BaseQModel):
@@ -33,3 +34,13 @@ class Qwen3MoeQModel(BaseQModel):
             },
         }
     ]
+
+    module_tree_overrides = {
+        METHOD.AWQ: [
+            {
+                "mlp": {
+                    "gate": ("gate",),
+                }
+            }
+        ]
+    }
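For clarity, here is a minimal, self-contained sketch of the override semantics the patch introduces. The merge function is copied from the diff above; the sample tree is a trimmed, hypothetical stand-in for a MoE `module_tree` (the real `Qwen3MoeQModel.module_tree` has more entries), and the `":!"` exclusion-flag usage in it is illustrative, not taken from the original file.

```python
def apply_module_tree_override(module_tree, override):
    """Recursively merge `override` into `module_tree` (as in the patch above)."""
    if isinstance(module_tree, dict) and isinstance(override, dict):
        for k, v in override.items():
            if k in module_tree and isinstance(module_tree[k], (dict, list)) and isinstance(v, (dict, list)):
                module_tree[k] = apply_module_tree_override(module_tree[k], v)
            else:
                module_tree[k] = v
    elif isinstance(module_tree, list) and isinstance(override, list):
        # list vs. list: merge every override dict into every dict element of the tree
        for o in override:
            if isinstance(o, dict):
                for b in module_tree:
                    if isinstance(b, dict):
                        apply_module_tree_override(b, o)
    return module_tree


# Trimmed, illustrative module tree (NOT the full Qwen3 MoE definition).
module_tree = [
    "model",
    "layers",
    {
        "mlp": {
            "gate": ("gate:!",),  # assumed default: router gate marked with the NOT_QUANTIZE_FLAG
            "experts": {"#": ("gate_proj:0", "up_proj:0", "down_proj:1")},
        },
    },
]

# The AWQ override declared in qwen3_moe.py above.
awq_override = [{"mlp": {"gate": ("gate",)}}]

merged = apply_module_tree_override(module_tree, awq_override)
assert merged[2]["mlp"]["gate"] == ("gate",)   # leaf replaced by the override
assert "experts" in merged[2]["mlp"]           # sibling keys are left untouched
```

Because the merge mutates the tree in place and `__init__` reassigns `type(self).module_tree`, the override effectively applies class-wide once a model is constructed with the matching quantization method.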