3 changes: 2 additions & 1 deletion gptqmodel/looper/awq_processor.py
@@ -74,7 +74,8 @@ def __init__(
prepare_dataset_func=prepare_dataset_func,
batch_size=batch_size,
require_fwd=require_fwd,
fwd_after_process=False,
fwd_after_process=True,
subset_forward_early_stop=True,
)

self.calculate_w_wq_diff = calculate_w_wq_diff
2 changes: 2 additions & 0 deletions gptqmodel/looper/gptq_processor.py
@@ -55,6 +55,8 @@ def __init__(
prepare_dataset_func=prepare_dataset_func,
batch_size=batch_size,
require_fwd=require_fwd,
fwd_after_process=True,
subset_forward_early_stop=True,
)

self.calculate_w_wq_diff = calculate_w_wq_diff
5 changes: 4 additions & 1 deletion gptqmodel/looper/loop_processor.py
@@ -62,6 +62,7 @@ def __init__(
require_fwd: bool = True,
fwd_after_process: bool = True,
fwd_all_modules_in_single_pass: bool = False,
subset_forward_early_stop: bool = False,
):
# process level lock
self.lock = threading.Lock()
@@ -79,7 +80,7 @@ def __init__(
# looper should bypass generate + hooks if this is false
self.require_fwd = require_fwd # default True

# after process(), do we need to forward again? paried with require_fwd == True
# after process(), do we need to forward again? paired with require_fwd == True
# if true, forward output is captured post process() and saved for next loop as input
# if false, forward output before process() call is saved for next loop as input
self.fwd_after_process = fwd_after_process # default True
@@ -88,6 +89,8 @@ def __init__(
# if true, fwd is repeated based on module dep sub-groups
# if false, sub-module groups are merged as one and fwd happens in one pass
self.fwd_all_modules_in_single_pass = fwd_all_modules_in_single_pass # default False
# when True, stop the layer forward immediately after the final module in a subset fires
self.subset_forward_early_stop = subset_forward_early_stop

self.inputs_cache: InputCache = InputCache(None, None, None, None)
self.tasks = {}
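The diff above adds subset_forward_early_stop alongside the existing forward-control flags. A minimal sketch of how a processor subclass might pass them together follows; ExampleProcessor is hypothetical, the import path assumes LoopProcessor lives in gptqmodel/looper/loop_processor.py, and only the keyword names and their documented defaults come from the diff.

# A minimal, illustrative sketch; not code from this PR.
from gptqmodel.looper.loop_processor import LoopProcessor  # assumed import path

class ExampleProcessor(LoopProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            require_fwd=True,                      # looper runs forward + hooks for this processor
            fwd_after_process=True,                # forward output captured after process() feeds the next layer
            fwd_all_modules_in_single_pass=False,  # keep the per-subset forward passes
            subset_forward_early_stop=True,        # stop the layer forward once the subset's last module fires
            **kwargs,
        )
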
96 changes: 94 additions & 2 deletions gptqmodel/looper/module_looper.py
@@ -132,6 +132,7 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
vram_strategy = VRAMStrategy.EXCLUSIVE
self._vram_strategy = vram_strategy
self._moe_subset_threshold = 16
self._subset_callback = getattr(self.gptq_model, "subset_callback", None)

for processor in self.processors:
self._processor_mask_tls(processor)
@@ -140,6 +141,10 @@ def register_layer_callback(self, callback) -> None:
"""Register or replace the layer-complete callback target."""
self._layer_callback = callback

def register_subset_callback(self, callback) -> None:
"""Register or replace the subset event callback target."""
self._subset_callback = callback

def _resolve_layer_callback(self):
for candidate in (
getattr(self, "_layer_callback", None),
@@ -152,6 +157,16 @@ def _resolve_layer_callback(self):
return candidate
return None

def _resolve_subset_callback(self):
for candidate in (
getattr(self, "_subset_callback", None),
getattr(self, "subset_callback", None),
getattr(self.gptq_model, "subset_callback", None),
):
if candidate is not None:
return candidate
return None

def callbackup(self, layer_idx: int, submodule_finalized: bool):
callback = self._resolve_layer_callback()
if callback is None:
@@ -173,6 +188,26 @@ def callbackup(self, layer_idx: int, submodule_finalized: bool):
raise result
return result

def _subset_event_dispatch(
self,
*,
stage: str,
layer_idx: int,
subset_index: int,
subset_total: int,
module_names: List[str],
processor: str,
) -> None:
self._emit_subset_event(
stage=stage,
layer_idx=layer_idx,
subset_index=subset_index,
subset_total=subset_total,
module_names=module_names,
processor=processor,
raise_in_place=True,
)

def _request_loop_stop(self, exc: Optional[BaseException]) -> None:
with self.lock:
if self._loop_stop_exc is None and exc is not None:
@@ -189,6 +224,60 @@ def _check_loop_stop(self) -> bool:
raise self._loop_stop_exc
return True

def _emit_subset_event(
self,
*,
stage: str,
layer_idx: int,
subset_index: int,
subset_total: int,
module_names: List[str],
processor: str,
raise_in_place: bool,
) -> None:
callback = self._resolve_subset_callback()
if callback is None:
return

handler = getattr(callback, "subset_event", None)
if handler is None and callable(callback):
handler = callback
if handler is None:
return

try:
result = handler(
stage=stage,
layer_idx=layer_idx,
subset_index=subset_index,
subset_total=subset_total,
module_names=module_names,
processor=processor,
)
except StopMainLoop as exc:
self._request_loop_stop(exc)
if raise_in_place:
raise
return
except BaseException as exc:
self._request_loop_stop(exc)
if raise_in_place:
raise
return

if result is StopMainLoop:
exc = StopMainLoop(f"Subset callback requested stop at layer {layer_idx} subset {subset_index}")
self._request_loop_stop(exc)
if raise_in_place:
raise exc
return

if isinstance(result, StopMainLoop):
self._request_loop_stop(result)
if raise_in_place:
raise result
return

def _emit_layer_complete(
self,
layer_idx: int,
@@ -266,11 +355,14 @@ def _coerce_to_int(self, value) -> Optional[int]:

def _resolve_batch_total(self, raw_count, fallback_sequence) -> int:
count = self._coerce_to_int(raw_count)
fallback_len = self._safe_len(fallback_sequence)
fallback = self._coerce_to_int(fallback_len)

if count is not None and count > 0:
if fallback is not None and fallback >= 0:
return min(count, fallback)
return count

fallback_len = self._safe_len(fallback_sequence)
fallback = self._coerce_to_int(fallback_len)
if fallback is not None:
return max(fallback, 0)

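The callback added above can be supplied as gptq_model.subset_callback or registered later via register_subset_callback(); the handler may be a plain callable or an object exposing a subset_event method, and returning StopMainLoop (class or instance), or raising it, requests a loop stop. Below is a minimal sketch of a callable handler; the keyword names mirror _emit_subset_event(), while the logging body and the `looper` variable are illustrative assumptions.

# Minimal subset-event handler sketch; keywords mirror _emit_subset_event().
def subset_event(*, stage, layer_idx, subset_index, subset_total, module_names, processor):
    print(f"[{processor}] layer {layer_idx} "
          f"subset {subset_index + 1}/{subset_total} {stage}: {module_names}")
    # Returning StopMainLoop (class or instance), or raising it, asks the looper to stop.
    return None

looper.register_subset_callback(subset_event)  # `looper` is an existing looper instance (assumed)
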
3 changes: 3 additions & 0 deletions gptqmodel/looper/stage_inputs_capture.py
@@ -117,6 +117,9 @@ def store_input_hook(module, args, kwargs):
one_kwargs[k] = nested_move_to(v, device=data_device)
layer_input_kwargs.append(one_kwargs)

# In a normal repeating layer/subset pass, early stop fires on the last module's forward,
# but for the first model input embedding call we register a simple forward hook on the model
# and stop as soon as this callback is invoked for the first time.
raise STOP_FORWARD_EXCEPTION

if cur_layer_device == META:
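The comment above describes a capture-then-raise pattern: a hook records the first call's inputs, then raises to abort the rest of the model forward. The sketch below is a generic PyTorch illustration of that pattern; the _StopForward exception and the helper are assumptions for illustration, not gptqmodel's STOP_FORWARD_EXCEPTION machinery.

# Generic sketch of the capture-then-raise early-stop pattern described above.
import torch

class _StopForward(Exception):
    """Raised by the hook to abort the forward once inputs are captured."""

def capture_first_inputs(module: torch.nn.Module, *model_args, **model_kwargs):
    captured = {}

    def hook(mod, args, kwargs):
        captured["args"], captured["kwargs"] = args, kwargs
        raise _StopForward  # stop the forward as soon as the hook fires once

    handle = module.register_forward_pre_hook(hook, with_kwargs=True)
    try:
        module(*model_args, **model_kwargs)
    except _StopForward:
        pass
    finally:
        handle.remove()
    return captured
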
3 changes: 2 additions & 1 deletion gptqmodel/looper/stage_layer.py
@@ -128,7 +128,7 @@ def run_layer_stage(
names[:8],
)
subset_result = run_subset_stage(
looper,
looper=looper,
processor=processor,
module=module,
layer_inputs=layer_inputs,
@@ -151,6 +151,7 @@ def run_layer_stage(
log=log,
region_timer=region_timer,
previous_processed_subset=previous_subset_processed,
subset_event_cb=looper._subset_event_dispatch,
)

layer_inputs = subset_result.layer_inputs
25 changes: 23 additions & 2 deletions gptqmodel/looper/stage_subset.py
@@ -11,7 +11,7 @@
import math
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch

@@ -69,6 +69,7 @@ def run_subset_stage(
log=None,
region_timer=None,
previous_processed_subset: Optional[Dict[str, NamedModule]] = None,
subset_event_cb: Optional[Callable[..., None]] = None,
) -> SubsetStageResult:
"""Process a single subset of modules within the layer quantization loop."""
logger = log or setup_logger()
@@ -88,6 +89,18 @@ def run_subset_stage(
layer_module=module,
)

def emit_subset_event(stage: str) -> None:
if subset_event_cb is None:
return
subset_event_cb(
stage=stage,
layer_idx=layer_index,
subset_index=subset_index,
subset_total=subset_total,
module_names=list(subset.keys()),
processor=processor_name,
)

# TODO FIXME: If a full layer has no module to quantize, a simple forward() is enough and the output is captured
# to be used as the next layer's input. So one pass forward (entire layer simple forward without needing to deal
# with subset loops and micro forward loops, just the full layer, usually XXXDecodeLayer.forward()).
@@ -235,7 +248,8 @@ def run_subset_stage(
if hasattr(subset[name], 'forward_hook'):
original_hook = processor.pre_process_fwd_hook(name)
subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source)
if is_last and processor.fwd_after_process:
enable_stop = processor.fwd_after_process or getattr(processor, "subset_forward_early_stop", False)
if is_last and enable_stop:
subset[name].forward_hook_last = True
else:
original_hook = processor.pre_process_fwd_hook(name)
@@ -262,6 +276,8 @@ def run_subset_stage(
len(subset),
)

emit_subset_event("forward_start")

fwd_start = time.perf_counter()
forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}"

@@ -325,6 +341,7 @@ def run_subset_stage(
processor.receive_layer_inputs(forward_outputs)
layer_inputs = processor.inputs_cache.layer_inputs
del forward_outputs
emit_subset_event("forward_end")

fwd_time = time.perf_counter() - fwd_start
processor.set_fwd_time(fwd_time)
@@ -387,6 +404,8 @@ def run_subset_stage(
processed_subset: Dict[str, NamedModule] = {}
futures = []

emit_subset_event("quant_start")

@torch.inference_mode()
def _process_on_worker(
proc: LoopProcessor,
@@ -474,6 +493,8 @@ def _process_on_worker(
processed_subset[name] = named_module
torch_sync()

emit_subset_event("quant_complete")

context = SubsetForwardContext(
subset=subset,
forward_device_map=forward_device_map,
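Based on the emit_subset_event() call sites added above, a subset callback observes up to four stages per subset, in order: forward_start, forward_end, quant_start, quant_complete. The recorder below is an illustrative sketch of consuming that sequence; it is not part of the PR, and the expected order assumes the subset runs its forward pass.

# Illustrative recorder for the per-subset event sequence emitted by run_subset_stage().
from collections import defaultdict

EXPECTED_ORDER = ("forward_start", "forward_end", "quant_start", "quant_complete")
events = defaultdict(list)

def subset_event(*, stage, layer_idx, subset_index, subset_total, module_names, processor):
    key = (layer_idx, subset_index)
    events[key].append(stage)
    if stage == "quant_complete":
        # When the subset ran a forward pass, the collected stages typically match EXPECTED_ORDER.
        print(key, events[key])
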
45 changes: 45 additions & 0 deletions tests/models/test_glm4_moe._awq.py
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
from model_test import ModelTest

from gptqmodel.quantization.config import VRAMStrategy
from gptqmodel.utils.eval import EVAL
from gptqmodel.quantization import FORMAT, METHOD



# | Metric | MARLIN |
# |--------------------------------|----------|
# | arc_challenge :: acc,none | 0.5094 |
# | arc_challenge :: acc_norm,none | 0.5486 |
class TestQwen3Moe(ModelTest):
FORMAT = FORMAT.GEMM
METHOD = METHOD.AWQ

# HESSIAN_CHUNK_SIZE = 256 * 1024 * 1024
NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B"
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
"acc": {"value": 0.5094, "floor_pct": 0.04},
"acc_norm": {"value": 0.5486, "floor_pct": 0.04},
},
}

# VRAM_STRATEGY = VRAMStrategy.BALANCED
# TRUST_REMOTE_CODE = False
# APPLY_CHAT_TEMPLATE = True
# EVAL_BATCH_SIZE = 6
# V2 = False
# DEBUG = True
# ACT_GROUP_AWARE = True
# DESC_ACT = False
# DATASET_SIZE = 512
# DATASET_SORT = "desc"
# QUANT_BATCH_SIZE = 4
# CALIB_NOISE_MODE = "unseen"
# CALIB_NOISE_PERCENT = 0.025

def test_mimo(self):
self.quant_lm_eval()
16 changes: 8 additions & 8 deletions tests/models/test_llama3_2_awq.py
@@ -11,10 +11,10 @@

# | Metric | MARLIN |
# |--------------------------------|----------|
# | arc_challenge :: acc,none | 0.3131 |
# | arc_challenge :: acc_norm,none | 0.3379 |
# | mmlu_stem :: acc,none | 0.3527 |
# | gsm8k_plat :: exact,flexible | 0.2754 |
# | arc_challenge :: acc,none | 0.3166 |
# | arc_challenge :: acc_norm,none | 0.3464 |
# | mmlu_stem :: acc,none | 0.3692 |
# | gsm8k_plat :: exact,flexible | 0.2994 |
class TestLlama3_2_awq(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct"
EVAL_BATCH_SIZE = 64
@@ -24,25 +24,25 @@ class TestLlama3_2_awq(ModelTest):
EVAL.LM_EVAL.GSM8K_PLATINUM_COT: {
"chat_template": True,
"exact_match,flexible-extract": {
"value": 0.2754,
"value": 0.2994,
"floor_pct": 0.04,
},
},
EVAL.LM_EVAL.ARC_CHALLENGE: {
"chat_template": True,
"acc": {
"value": 0.3131,
"value": 0.3166,
"floor_pct": 0.04,
},
"acc_norm": {
"value": 0.3379,
"value": 0.3464,
"floor_pct": 0.04,
},
},
EVAL.LM_EVAL.MMLU_STEM: {
"chat_template": False,
"acc": {
"value": 0.3527,
"value": 0.3692,
"floor_pct": 0.04,
},
},