From 6d3e6e6a8a508d742ef1dc27dbb330f2670ccd70 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 3 Nov 2025 13:54:07 +0000
Subject: [PATCH 1/4] add awq moe test

---
 tests/models/test_glm4_moe._awq.py | 45 ++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 tests/models/test_glm4_moe._awq.py

diff --git a/tests/models/test_glm4_moe._awq.py b/tests/models/test_glm4_moe._awq.py
new file mode 100644
index 000000000..57991cf46
--- /dev/null
+++ b/tests/models/test_glm4_moe._awq.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+from model_test import ModelTest
+
+from gptqmodel.quantization.config import VRAMStrategy
+from gptqmodel.utils.eval import EVAL
+from gptqmodel.quantization import FORMAT, METHOD
+
+
+
+# | Metric                         | MARLIN   |
+# |--------------------------------|----------|
+# | arc_challenge :: acc,none      | 0.5094   |
+# | arc_challenge :: acc_norm,none | 0.5486   |
+class TestQwen3Moe(ModelTest):
+    FORMAT = FORMAT.GEMM
+    METHOD = METHOD.AWQ
+
+    # HESSIAN_CHUNK_SIZE = 256 * 1024 * 1024
+    NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B"
+    EVAL_TASKS = {
+        EVAL.LM_EVAL.ARC_CHALLENGE: {
+            "acc": {"value": 0.5094, "floor_pct": 0.04},
+            "acc_norm": {"value": 0.5486, "floor_pct": 0.04},
+        },
+    }
+
+    # VRAM_STRATEGY = VRAMStrategy.BALANCED
+    # TRUST_REMOTE_CODE = False
+    # APPLY_CHAT_TEMPLATE = True
+    # EVAL_BATCH_SIZE = 6
+    # V2 = False
+    # DEBUG = True
+    # ACT_GROUP_AWARE = True
+    # DESC_ACT = False
+    # DATASET_SIZE = 512
+    # DATASET_SORT = "desc"
+    # QUANT_BATCH_SIZE = 4
+    # CALIB_NOISE_MODE = "unseen"
+    # CALIB_NOISE_PERCENT = 0.025
+
+    def test_mimo(self):
+        self.quant_lm_eval()

From a5a0cec487d26d0907e287be9a60b7075515bdb0 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 3 Nov 2025 14:14:23 +0000
Subject: [PATCH 2/4] update scores

---
 tests/models/test_llama3_2_awq.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py
index 0556673aa..451124113 100644
--- a/tests/models/test_llama3_2_awq.py
+++ b/tests/models/test_llama3_2_awq.py
@@ -11,10 +11,10 @@
 
 # | Metric                         | MARLIN   |
 # |--------------------------------|----------|
-# | arc_challenge :: acc,none      | 0.3131   |
-# | arc_challenge :: acc_norm,none | 0.3379   |
-# | mmlu_stem :: acc,none          | 0.3527   |
-# | gsm8k_plat :: exact,flexible   | 0.2754   |
+# | arc_challenge :: acc,none      | 0.3166   |
+# | arc_challenge :: acc_norm,none | 0.3464   |
+# | mmlu_stem :: acc,none          | 0.3692   |
+# | gsm8k_plat :: exact,flexible   | 0.2994   |
 class TestLlama3_2_awq(ModelTest):
     NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"  # "meta-llama/Llama-3.2-1B-Instruct"
     EVAL_BATCH_SIZE = 64
@@ -24,25 +24,25 @@ class TestLlama3_2_awq(ModelTest):
         EVAL.LM_EVAL.GSM8K_PLATINUM_COT: {
             "chat_template": True,
             "exact_match,flexible-extract": {
-                "value": 0.2754,
+                "value": 0.2994,
                 "floor_pct": 0.04,
             },
         },
         EVAL.LM_EVAL.ARC_CHALLENGE: {
             "chat_template": True,
             "acc": {
-                "value": 0.3131,
+                "value": 0.3166,
                 "floor_pct": 0.04,
             },
             "acc_norm": {
-                "value": 0.3379,
+                "value": 0.3464,
                 "floor_pct": 0.04,
             },
         },
         EVAL.LM_EVAL.MMLU_STEM: {
             "chat_template": False,
             "acc": {
-                "value": 0.3527,
+                "value": 0.3692,
                 "floor_pct": 0.04,
             },
         },

From 38353df17167d9d3fb987e84e91698d3e8cc0b3b Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 3 Nov 2025 14:41:33 +0000
Subject: [PATCH 3/4] fix quant speed
regression --- gptqmodel/looper/awq_processor.py | 3 +- gptqmodel/looper/gptq_processor.py | 2 + gptqmodel/looper/loop_processor.py | 5 +- gptqmodel/looper/module_looper.py | 96 +++++++++++++++++++++++++++++- gptqmodel/looper/stage_layer.py | 3 +- gptqmodel/looper/stage_subset.py | 25 +++++++- 6 files changed, 127 insertions(+), 7 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 2b858ca7a..ea34b286d 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -74,7 +74,8 @@ def __init__( prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, require_fwd=require_fwd, - fwd_after_process=False, + fwd_after_process=True, + subset_forward_early_stop=True, ) self.calculate_w_wq_diff = calculate_w_wq_diff diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index db0100133..487efff98 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -55,6 +55,8 @@ def __init__( prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, require_fwd=require_fwd, + fwd_after_process=True, + subset_forward_early_stop=True, ) self.calculate_w_wq_diff = calculate_w_wq_diff diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index fe52bb053..7c2f536da 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -62,6 +62,7 @@ def __init__( require_fwd: bool = True, fwd_after_process: bool = True, fwd_all_modules_in_single_pass: bool = False, + subset_forward_early_stop: bool = False, ): # process level lock self.lock = threading.Lock() @@ -79,7 +80,7 @@ def __init__( # looper should bypass generate + hooks if this is false self.require_fwd = require_fwd # default True - # after process(), do we need to forward again? paried with require_fwd == True + # after process(), do we need to forward again? 
paired with require_fwd == True # if true, forward output is captured post process() and saved for next loop as input # if false, forward output before process() call is saved for next loop as input self.fwd_after_process = fwd_after_process # default True @@ -88,6 +89,8 @@ def __init__( # if true, fwd is repeated based on module dep sub-groups # if false, sub-module groups are merged as one and fwd happens in one pass self.fwd_all_modules_in_single_pass = fwd_all_modules_in_single_pass # default False + # when True, stop the layer forward immediately after the final module in a subset fires + self.subset_forward_early_stop = subset_forward_early_stop self.inputs_cache: InputCache = InputCache(None, None, None, None) self.tasks = {} diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 88c92cf7f..c9129086a 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -132,6 +132,7 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): vram_strategy = VRAMStrategy.EXCLUSIVE self._vram_strategy = vram_strategy self._moe_subset_threshold = 16 + self._subset_callback = getattr(self.gptq_model, "subset_callback", None) for processor in self.processors: self._processor_mask_tls(processor) @@ -140,6 +141,10 @@ def register_layer_callback(self, callback) -> None: """Register or replace the layer-complete callback target.""" self._layer_callback = callback + def register_subset_callback(self, callback) -> None: + """Register or replace the subset event callback target.""" + self._subset_callback = callback + def _resolve_layer_callback(self): for candidate in ( getattr(self, "_layer_callback", None), @@ -152,6 +157,16 @@ def _resolve_layer_callback(self): return candidate return None + def _resolve_subset_callback(self): + for candidate in ( + getattr(self, "_subset_callback", None), + getattr(self, "subset_callback", None), + getattr(self.gptq_model, "subset_callback", None), + ): + if candidate is not None: + return candidate + return None + def callbackup(self, layer_idx: int, submodule_finalized: bool): callback = self._resolve_layer_callback() if callback is None: @@ -173,6 +188,26 @@ def callbackup(self, layer_idx: int, submodule_finalized: bool): raise result return result + def _subset_event_dispatch( + self, + *, + stage: str, + layer_idx: int, + subset_index: int, + subset_total: int, + module_names: List[str], + processor: str, + ) -> None: + self._emit_subset_event( + stage=stage, + layer_idx=layer_idx, + subset_index=subset_index, + subset_total=subset_total, + module_names=module_names, + processor=processor, + raise_in_place=True, + ) + def _request_loop_stop(self, exc: Optional[BaseException]) -> None: with self.lock: if self._loop_stop_exc is None and exc is not None: @@ -189,6 +224,60 @@ def _check_loop_stop(self) -> bool: raise self._loop_stop_exc return True + def _emit_subset_event( + self, + *, + stage: str, + layer_idx: int, + subset_index: int, + subset_total: int, + module_names: List[str], + processor: str, + raise_in_place: bool, + ) -> None: + callback = self._resolve_subset_callback() + if callback is None: + return + + handler = getattr(callback, "subset_event", None) + if handler is None and callable(callback): + handler = callback + if handler is None: + return + + try: + result = handler( + stage=stage, + layer_idx=layer_idx, + subset_index=subset_index, + subset_total=subset_total, + module_names=module_names, + processor=processor, + ) + except StopMainLoop as exc: + 
self._request_loop_stop(exc) + if raise_in_place: + raise + return + except BaseException as exc: + self._request_loop_stop(exc) + if raise_in_place: + raise + return + + if result is StopMainLoop: + exc = StopMainLoop(f"Subset callback requested stop at layer {layer_idx} subset {subset_index}") + self._request_loop_stop(exc) + if raise_in_place: + raise exc + return + + if isinstance(result, StopMainLoop): + self._request_loop_stop(result) + if raise_in_place: + raise result + return + def _emit_layer_complete( self, layer_idx: int, @@ -266,11 +355,14 @@ def _coerce_to_int(self, value) -> Optional[int]: def _resolve_batch_total(self, raw_count, fallback_sequence) -> int: count = self._coerce_to_int(raw_count) + fallback_len = self._safe_len(fallback_sequence) + fallback = self._coerce_to_int(fallback_len) + if count is not None and count > 0: + if fallback is not None and fallback >= 0: + return min(count, fallback) return count - fallback_len = self._safe_len(fallback_sequence) - fallback = self._coerce_to_int(fallback_len) if fallback is not None: return max(fallback, 0) diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py index e1c41ccfc..1d2cfa759 100644 --- a/gptqmodel/looper/stage_layer.py +++ b/gptqmodel/looper/stage_layer.py @@ -128,7 +128,7 @@ def run_layer_stage( names[:8], ) subset_result = run_subset_stage( - looper, + looper=looper, processor=processor, module=module, layer_inputs=layer_inputs, @@ -151,6 +151,7 @@ def run_layer_stage( log=log, region_timer=region_timer, previous_processed_subset=previous_subset_processed, + subset_event_cb=looper._subset_event_dispatch, ) layer_inputs = subset_result.layer_inputs diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py index 5738481fe..fe6bbe83b 100644 --- a/gptqmodel/looper/stage_subset.py +++ b/gptqmodel/looper/stage_subset.py @@ -11,7 +11,7 @@ import math import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch @@ -69,6 +69,7 @@ def run_subset_stage( log=None, region_timer=None, previous_processed_subset: Optional[Dict[str, NamedModule]] = None, + subset_event_cb: Optional[Callable[..., None]] = None, ) -> SubsetStageResult: """Process a single subset of modules within the layer quantization loop.""" logger = log or setup_logger() @@ -88,6 +89,18 @@ def run_subset_stage( layer_module=module, ) + def emit_subset_event(stage: str) -> None: + if subset_event_cb is None: + return + subset_event_cb( + stage=stage, + layer_idx=layer_index, + subset_index=subset_index, + subset_total=subset_total, + module_names=list(subset.keys()), + processor=processor_name, + ) + # TODO FIXME: If a full layer has no module to quantize a simple forward() is enough and output is captured # to be used as next layer's input. So one pass forward (entire layer simple forward wihout need of dealing # with subset loops and micro forward loops, just full layer, usally XXXDecodeLayer.forward(). 
@@ -235,7 +248,8 @@ def run_subset_stage(
         if hasattr(subset[name], 'forward_hook'):
             original_hook = processor.pre_process_fwd_hook(name)
             subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source)
-            if is_last and processor.fwd_after_process:
+            enable_stop = processor.fwd_after_process or getattr(processor, "subset_forward_early_stop", False)
+            if is_last and enable_stop:
                 subset[name].forward_hook_last = True
         else:
             original_hook = processor.pre_process_fwd_hook(name)
@@ -262,6 +276,8 @@ def run_subset_stage(
         len(subset),
     )
 
+    emit_subset_event("forward_start")
+
     fwd_start = time.perf_counter()
 
     forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}"
@@ -325,6 +341,7 @@ def run_subset_stage(
         processor.receive_layer_inputs(forward_outputs)
         layer_inputs = processor.inputs_cache.layer_inputs
         del forward_outputs
+        emit_subset_event("forward_end")
 
     fwd_time = time.perf_counter() - fwd_start
     processor.set_fwd_time(fwd_time)
@@ -387,6 +404,8 @@ def run_subset_stage(
     processed_subset: Dict[str, NamedModule] = {}
     futures = []
 
+    emit_subset_event("quant_start")
+
     @torch.inference_mode()
     def _process_on_worker(
         proc: LoopProcessor,
@@ -474,6 +493,8 @@ def _process_on_worker(
             processed_subset[name] = named_module
     torch_sync()
 
+    emit_subset_event("quant_complete")
+
     context = SubsetForwardContext(
         subset=subset,
         forward_device_map=forward_device_map,

From b2dba50e053e534c0cd4cb0e3322c213d9f5924f Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 3 Nov 2025 14:51:40 +0000
Subject: [PATCH 4/4] comments

---
 gptqmodel/looper/stage_inputs_capture.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py
index f1d6f954f..dbf2609d1 100644
--- a/gptqmodel/looper/stage_inputs_capture.py
+++ b/gptqmodel/looper/stage_inputs_capture.py
@@ -117,6 +117,9 @@ def store_input_hook(module, args, kwargs):
                     one_kwargs[k] = nested_move_to(v, device=data_device)
                 layer_input_kwargs.append(one_kwargs)
 
+        # For the normal repeating layers/subsets, early stop happens on the last module forward,
+        # but for the first model input embedding call we register a plain forward hook on the model
+        # and raise STOP_FORWARD_EXCEPTION the first time this callback is called.
         if cur_layer_device == META:
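
A minimal, hypothetical usage sketch of the subset-event hook added in PATCH 3/4, not part of the patches themselves. Only register_subset_callback() and the keyword arguments passed to subset_event() mirror the code above; the callback class, its timing bookkeeping, and the print-based reporting are illustrative assumptions.

import time
from typing import Dict, List, Optional, Tuple


class SubsetTimingCallback:
    """Logs per-subset forward and quantization timings emitted by ModuleLooper."""

    def __init__(self) -> None:
        # (layer_idx, subset_index, phase) -> perf_counter() start timestamp
        self._starts: Dict[Tuple[int, int, str], float] = {}

    def subset_event(self, *, stage: str, layer_idx: int, subset_index: int,
                     subset_total: int, module_names: List[str], processor: str) -> Optional[object]:
        phase = stage.split("_")[0]  # "forward" or "quant"
        if stage in ("forward_start", "quant_start"):
            self._starts[(layer_idx, subset_index, phase)] = time.perf_counter()
        elif stage in ("forward_end", "quant_complete"):
            started = self._starts.pop((layer_idx, subset_index, phase), None)
            elapsed = time.perf_counter() - started if started is not None else float("nan")
            print(f"[{processor}] layer {layer_idx} subset {subset_index + 1}/{subset_total} "
                  f"{phase} took {elapsed:.2f}s over {len(module_names)} modules")
        # Returning a StopMainLoop instance (or the StopMainLoop class itself) asks the looper
        # to halt, per _emit_subset_event above; returning None lets quantization continue.
        return None


# Registration against an existing looper instance (constructor arguments elided):
#   looper.register_subset_callback(SubsetTimingCallback())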