From 835ae1507aaf7b28779d4abc12d60fd570ef34fb Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Nov 2025 05:52:37 +0000 Subject: [PATCH 1/5] don't emit debug logs --- gptqmodel/looper/stage_layer.py | 41 +++++------ gptqmodel/looper/stage_subset.py | 117 ++++++++++++++++--------------- 2 files changed, 81 insertions(+), 77 deletions(-) diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py index f7bb7472e..e1c41ccfc 100644 --- a/gptqmodel/looper/stage_layer.py +++ b/gptqmodel/looper/stage_layer.py @@ -15,7 +15,7 @@ import torch -from .. import DEVICE_THREAD_POOL +from .. import DEBUG_ON, DEVICE_THREAD_POOL from ..looper.awq_processor import AWQProcessor from ..looper.gptq_processor import GPTQProcessor from ..looper.named_module import NamedModule @@ -107,25 +107,26 @@ def run_layer_stage( for index, names in enumerate(modules): # Process the layer in smaller subsets so attention groups or # MoE experts can be quantized independently within a layer. - if isinstance(processor, AWQProcessor): - log.info( - "StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s", - layer_index, - index + 1, - subset_total, - len(names), - names[:5], - ) - elif log.isEnabledFor(logging.DEBUG): - log.debug( - "StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s", - layer_index, - index + 1, - subset_total, - processor.name(), - len(names), - names[:8], - ) + if DEBUG_ON and log.isEnabledFor(logging.DEBUG): + if isinstance(processor, AWQProcessor): + log.debug( + "StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s", + layer_index, + index + 1, + subset_total, + len(names), + names[:5], + ) + else: + log.debug( + "StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s", + layer_index, + index + 1, + subset_total, + processor.name(), + len(names), + names[:8], + ) subset_result = run_subset_stage( looper, processor=processor, diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py index bd306d080..5738481fe 100644 --- a/gptqmodel/looper/stage_subset.py +++ b/gptqmodel/looper/stage_subset.py @@ -15,7 +15,7 @@ import torch -from .. import DEVICE_THREAD_POOL +from .. 
import DEBUG_ON, DEVICE_THREAD_POOL from ..looper.gptq_processor import GPTQProcessor from ..looper.loop_processor import LoopProcessor from ..looper.named_module import NamedModule @@ -111,25 +111,26 @@ def run_subset_stage( # ) # return SubsetStageResult(processed_subset={}, layer_inputs=layer_inputs, forward_context=None) - if is_awq_processor: - logger.info( - "StageSubset[awq]: layer=%s subset=%s/%s modules=%s sample=%s", - layer_index, - subset_index + 1, - subset_total, - len(subset), - list(subset.keys())[:8], - ) - elif logger.isEnabledFor(logging.DEBUG): - logger.debug( - "StageSubset: layer=%s subset=%s/%s processor=%s created %s modules (sample=%s)", - layer_index, - subset_index + 1, - subset_total, - processor_name, - len(subset), - list(subset.keys())[:8], - ) + if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): + if is_awq_processor: + logger.debug( + "StageSubset[awq]: layer=%s subset=%s/%s modules=%s sample=%s", + layer_index, + subset_index + 1, + subset_total, + len(subset), + list(subset.keys())[:8], + ) + else: + logger.debug( + "StageSubset: layer=%s subset=%s/%s processor=%s created %s modules (sample=%s)", + layer_index, + subset_index + 1, + subset_total, + processor_name, + len(subset), + list(subset.keys())[:8], + ) moe_group_keys_all: List[str] = [] forward_device_map: Dict[str, torch.device] = {} @@ -242,23 +243,24 @@ def run_subset_stage( looper._masked_hook_wrapper(processor, original_hook, hook_source) )) - if is_awq_processor: - logger.info( - "StageSubset[awq]: layer=%s subset=%s/%s registering hooks for %s modules", - layer_index, - subset_index + 1, - subset_total, - len(subset), - ) - elif logger.isEnabledFor(logging.DEBUG): - logger.debug( - "StageSubset: layer=%s subset=%s/%s processor=%s registering hooks for %s modules", - layer_index, - subset_index + 1, - subset_total, - processor_name, - len(subset), - ) + if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): + if is_awq_processor: + logger.debug( + "StageSubset[awq]: layer=%s subset=%s/%s registering hooks for %s modules", + layer_index, + subset_index + 1, + subset_total, + len(subset), + ) + else: + logger.debug( + "StageSubset: layer=%s subset=%s/%s processor=%s registering hooks for %s modules", + layer_index, + subset_index + 1, + subset_total, + processor_name, + len(subset), + ) fwd_start = time.perf_counter() forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}" @@ -410,26 +412,27 @@ def _process_on_worker( timer = getattr(looper.gptq_model, "quant_region_timer", None) start = time.perf_counter() if timer else None try: - if is_awq_processor: - logger.info( - "StageSubsetWorker[awq]: layer=%s subset=%s/%s module=%s previous_subset=%s", - getattr(nm, "layer_index", None), - subset_idx + 1, - subset_total_count, - module_label, - bool(previous_subset_ref), - ) - elif logger.isEnabledFor(logging.DEBUG): - logger.debug( - "StageSubsetWorker: processor=%s layer=%s subset=%s/%s module=%s running on %s (previous_subset=%s)", - proc_name, - getattr(nm, "layer_index", None), - subset_idx + 1, - subset_total_count, - module_label, - expected_device, - bool(previous_subset_ref), - ) + if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): + if is_awq_processor: + logger.debug( + "StageSubsetWorker[awq]: layer=%s subset=%s/%s module=%s previous_subset=%s", + getattr(nm, "layer_index", None), + subset_idx + 1, + subset_total_count, + module_label, + bool(previous_subset_ref), + ) + else: + logger.debug( + "StageSubsetWorker: processor=%s layer=%s subset=%s/%s module=%s running 
on %s (previous_subset=%s)", + proc_name, + getattr(nm, "layer_index", None), + subset_idx + 1, + subset_total_count, + module_label, + expected_device, + bool(previous_subset_ref), + ) proc.process( module=nm, subset=subset_ref, From 95d265390bf13adce916465d358654c7f44d1982 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Nov 2025 06:06:10 +0000 Subject: [PATCH 2/5] log --- tests/models/model_test.py | 27 +++++++++++++++++++++++++-- tests/models/test_llama3_2_awq.py | 1 + 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index b4612b723..f0ebaba3b 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -24,7 +24,7 @@ from enum import Enum # noqa: E402 from pathlib import Path # noqa: E402 -from typing import Any, Dict, List # noqa: E402 +from typing import Any, Dict, List, Optional # noqa: E402 from logbar import LogBar # noqa: E402 from tabulate import tabulate # noqa: E402 @@ -57,7 +57,7 @@ def is_flash_attn_2_available(): # type: ignore return False -from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel import BACKEND, DEBUG_ON, GPTQModel # noqa: E402 from gptqmodel.models.base import BaseQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 @@ -65,6 +65,7 @@ def is_flash_attn_2_available(): # type: ignore from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from gptqmodel.looper.module_looper import StopMainLoop # noqa: E402 RAND_SEED = 898 @@ -127,6 +128,7 @@ class ModelTest(unittest.TestCase): LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1 # ±10% EXPECT_LM_HEAD_LOSS = None + STOP_AFTER_LAYER: Optional[int] = None GENERIC_TEST_PROMPTS = [ {"prompt": "Which city is the capital city of France?", "keywords": ["paris"]}, @@ -136,6 +138,22 @@ class ModelTest(unittest.TestCase): {"prompt": "Name the largest ocean on Earth.", "keywords": ["pacific"]}, ] + @staticmethod + def _build_layer_stop_callback(layer_idx: int): + class _StopAfterLayer: + def __init__(self, target: int): + self._target = target + self._triggered = False + + def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): + if self._triggered: + return None + if layer_idx > self._target or (submodule_finalized and layer_idx >= self._target): + self._triggered = True + raise StopMainLoop + + return _StopAfterLayer(layer_idx) + def _normalize_task_identifier(self, task): if isinstance(task, Enum): return task.value @@ -816,6 +834,11 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne **args, ) + self._layer_stop_callback = None + if DEBUG_ON and self.STOP_AFTER_LAYER is not None: + self._layer_stop_callback = self._build_layer_stop_callback(self.STOP_AFTER_LAYER) + model.layer_callback = self._layer_stop_callback + tokenizer = model.tokenizer self._post_quant_eval_records = {} self._effective_load_backend = None diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index 855736ddd..6128d9073 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -19,6 +19,7 @@ class TestLlama3_2_awq(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 # new + # STOP_AFTER_LAYER = 0 EVAL_TASKS = { 
EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { "chat_template": True, From 3ac9924d8b6b78e4184038fae716be62f15d2f5b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Nov 2025 06:19:07 +0000 Subject: [PATCH 3/5] update comments --- gptqmodel/looper/awq_processor.py | 22 ++++++++++----- tests/models/model_test.py | 45 +++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 8b39f45bc..2b858ca7a 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -578,8 +578,10 @@ def _search_best_scale( inp = inp.to(next(module2inspect.parameters()).device) # [STEP 1]: Compute per-channel mean of normalised weights - # Accumulate statistics per-layer to avoid concatenating large tensors - # (original implementation materialized a giant cat() that doubled VRAM usage) + # Stream across each Linear instead of concatenating all weights at once. This mirrors the + # previous cat()+view pipeline while keeping peak memory low: for every group we normalise + # |w| by its per-group max so the values land on a [0, 1] scale, then accumulate the totals + # per output channel so the mean can be computed without allocating the combined tensor. first_weight = layers[0].weight weight_dtype = first_weight.dtype weight_device = first_weight.device @@ -609,14 +611,15 @@ def _search_best_scale( w_mean = (w_sum / row_count).to(weight_dtype) # [STEP 2]: Compute per-channel mean of the input activation with chunking - # Stream directly on the source device to avoid creating full CPU copies + # Stream directly on the source device to avoid creating full CPU copies while still enforcing + # a predictable memory bound derived from max_chunk_memory. inp_flat = inp.abs().view(-1, inp.shape[-1]) num_elements = inp_flat.size(0) num_channels = inp_flat.size(1) float32_size = torch.tensor([], dtype=torch.float32).element_size() element_size_bytes = float32_size # accumulation happens in FP32 - # Calculate chunk size dynamically based on max_chunk_memory + # Calculate chunk size dynamically based on the available memory budget (default 1 GiB). chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels)) chunk_size = min(chunk_size, num_elements) chunk_size = max(chunk_size, 1) @@ -627,6 +630,7 @@ def _search_best_scale( for i in range(0, num_elements, chunk_size): end = min(i + chunk_size, num_elements) chunk = inp_flat[i:end] + # Accumulate each chunk in FP32 to balance precision and memory usage. chunk_sum = chunk.to(torch.float32).sum(dim=0) x_sum += chunk_sum @@ -705,7 +709,7 @@ def _compute_best_clip( w_all = w best_max_val_all = [] device = w_all.device - # Pre-allocate scratch buffers so the inner loop never allocates large temporaries + # Pre-allocate scratch buffers so the inner clamp loop never allocates large temporaries. 
scratch_clamp = torch.empty_like(w_all[:oc_batch_size]) scratch_quant = torch.empty_like(scratch_clamp) input_feat = input_feat.to(device) @@ -713,14 +717,14 @@ def _compute_best_clip( for i_b in range(org_w_shape[0] // oc_batch_size): w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size] - org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 + org_max_val = w.abs().amax(dim=-1, keepdim=True) # [co_batch, 1, n_group, 1] best_max_val = org_max_val.clone() min_errs = torch.ones_like(org_max_val) * 1e9 clamp_slice = scratch_clamp[: w.shape[0]] quant_slice = scratch_quant[: w.shape[0]] - org_out = (input_feat * w).sum(dim=-1) + org_out = (input_feat * w).sum(dim=-1) # [co_batch, n_token, n_group] for i_s in range(int(max_shrink * n_grid)): max_val = org_max_val * (1 - i_s / n_grid) @@ -729,6 +733,7 @@ def _compute_best_clip( self._pseudo_quantize_tensor_into(clamp_slice, quant_slice) cur_out = (input_feat * quant_slice).sum(dim=-1) + # Evaluate the reconstruction error for the current clamp ratio and keep the best one. err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) cur_best_idx = err < min_errs min_errs[cur_best_idx] = err[cur_best_idx] @@ -874,6 +879,8 @@ def _compute_best_scale( scales[torch.isnan(scales)] = 1 # Q(W * s) + # Temporarily apply the candidate scale, quantize the in-flight weights without allocating, + # and rely on the CPU master copy to restore the original FP values after evaluation. for fc in linears2scale: fc.weight.mul_(scales_view) self._pseudo_quantize_tensor_into(fc.weight, fc.weight) @@ -894,6 +901,7 @@ def _compute_best_scale( for fc in linears2scale: fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype)) + # Reset weights one final time so callers always see the pristine FP copy. 
for fc in linears2scale: fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype)) orig_weights_cpu.clear() diff --git a/tests/models/model_test.py b/tests/models/model_test.py index f0ebaba3b..e6a6afc3f 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -154,6 +154,30 @@ def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): return _StopAfterLayer(layer_idx) + def _debug_layer_stop_triggered(self) -> bool: + if not DEBUG_ON: + return False + callback = getattr(self, "_layer_stop_callback", None) + return bool(callback and getattr(callback, "_triggered", False)) + + def _finalize_quant_debug_path( + self, + *, + model, + tokenizer, + processor, + need_create_processor: bool, + cleanup_callback, + ): + if cleanup_callback is not None: + try: + cleanup_callback() + except Exception: + pass + if need_create_processor: + return model, tokenizer, processor + return model, tokenizer + def _normalize_task_identifier(self, task): if isinstance(task, Enum): return task.value @@ -842,6 +866,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne tokenizer = model.tokenizer self._post_quant_eval_records = {} self._effective_load_backend = None + processor = None is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) @@ -857,6 +882,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne # ovis cannot load processor is_ovis_model = model.__class__.__name__ == "OvisGPTQ" need_create_processor = is_image_to_text_model and not is_ovis_model + + debug_short_circuit = False if not is_quantized: save_context = None planned_save_path = None @@ -875,6 +902,20 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne self.check_kernel(model, self.KERNEL_QUANT) + debug_short_circuit = self._debug_layer_stop_triggered() + if debug_short_circuit: + log.info( + "DEBUG mode: layer stop triggered at %s; skipping post-quant save and evaluation pipeline.", + self.STOP_AFTER_LAYER, + ) + return self._finalize_quant_debug_path( + model=model, + tokenizer=tokenizer, + processor=None, + need_create_processor=need_create_processor, + cleanup_callback=cleanup_callback, + ) + # TODO: make into shared method with save_context as path: cleanup_callback = None @@ -1117,6 +1158,10 @@ def quant_lm_eval(self): self.check_kernel(self.model, self.KERNEL_INFERENCE) + if self._debug_layer_stop_triggered(): + log.info("DEBUG mode: skipping lm_eval and baseline checks after early layer stop.") + return + eval_records = getattr(self, "_post_quant_eval_records", {}) target_backend = self._current_load_backend() if eval_records and len(eval_records) == 1 and target_backend in eval_records: From ec7e9514ef1cc7e8caebaab97ec55c5dd203a45a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Nov 2025 09:08:18 +0000 Subject: [PATCH 4/5] update scores, add single gpu ci test option --- tests/models/model_test.py | 29 +++++++++++++++++++++++------ tests/models/test_llama3_2_awq.py | 12 ++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index e6a6afc3f..2048cc4ac 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -96,6 +96,7 @@ class ModelTest(unittest.TestCase): DATASET_SORT = "desc" DELETE_QUANTIZED_MODEL = True EVAL_TASKS = None + EVAL_SINGLE_GPU = 
True LOAD_MODEL_EXTRA_ARGS: Dict[str, Any] = {} KERNEL_QUANT = {} # kernel sets @@ -478,8 +479,12 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") - # Pin post-quant loads to the first CUDA device to avoid auto sharding across GPUs. - use_cuda_map = torch.cuda.is_available() and backend != BACKEND.TORCH_FUSED + # When EVAL_SINGLE_GPU is enabled, pin post-quant loads to the first CUDA device to avoid auto sharding. + use_cuda_map = ( + self.EVAL_SINGLE_GPU + and torch.cuda.is_available() + and backend != BACKEND.TORCH_FUSED + ) if use_cuda_map: model = self.loadQuantModel( model_path, @@ -932,8 +937,12 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne q_model = reuse_candidates.pop(target_backend, None) if q_model is None: - # Ensure the post-quant reload stays on a single CUDA device when available. - use_cuda_map = torch.cuda.is_available() and target_backend != BACKEND.TORCH_FUSED + # When single-GPU evaluation is requested, keep the reload scoped to cuda:0. + use_cuda_map = ( + self.EVAL_SINGLE_GPU + and torch.cuda.is_available() + and target_backend != BACKEND.TORCH_FUSED + ) if use_cuda_map: q_model = self.loadQuantModel( path, @@ -1010,7 +1019,8 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa multi_device = False if multi_device: - load_kwargs["device_map"] = {"": "cuda:0"} + if self.EVAL_SINGLE_GPU: + load_kwargs["device_map"] = {"": "cuda:0"} model = GPTQModel.load( model_id_or_path, @@ -1032,11 +1042,18 @@ def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, model_path = model if self.USE_VLLM: + tensor_parallel = 1 + if not self.EVAL_SINGLE_GPU: + try: + candidate = torch.cuda.device_count() + except Exception: + candidate = 1 + tensor_parallel = max(1, candidate) model_args = { "pretrained": model_path, "dtype": "auto", #"float16", "gpu_memory_utilization": 0.8, - "tensor_parallel_size": 1, + "tensor_parallel_size": tensor_parallel, "trust_remote_code": trust_remote_code, "max_model_len": self.MODEL_MAX_LEN } diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index 6128d9073..0556673aa 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -11,10 +11,10 @@ # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.3106 | -# | arc_challenge :: acc_norm,none | 0.3532 | +# | arc_challenge :: acc,none | 0.3131 | +# | arc_challenge :: acc_norm,none | 0.3379 | # | mmlu_stem :: acc,none | 0.3527 | -# | gsm8k_plat :: exact,flexible | 0.2192 | +# | gsm8k_plat :: exact,flexible | 0.2754 | class TestLlama3_2_awq(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 @@ -24,18 +24,18 @@ class TestLlama3_2_awq(ModelTest): EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { "chat_template": True, "exact_match,flexible-extract": { - "value": 0.2192, + "value": 0.2754, "floor_pct": 0.04, }, }, EVAL.LM_EVAL.ARC_CHALLENGE: { "chat_template": True, "acc": { - "value": 0.3106, + "value": 0.3131, "floor_pct": 0.04, }, "acc_norm": { - "value": 0.3532, + "value": 0.3379, "floor_pct": 0.04, }, }, From d57650308f9e1d7cf7372327ef05a3bfa84f20e4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Nov 2025 09:09:12 +0000 Subject: [PATCH 5/5] odd version for dev --- gptqmodel/version.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/gptqmodel/version.py b/gptqmodel/version.py index 14e239dfe..b62cbd7ea 100644 --- a/gptqmodel/version.py +++ b/gptqmodel/version.py @@ -7,4 +7,4 @@ # even minor versions are release # 5.2.0 => release, 5.1.0 => devel # micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release -__version__ = "5.2.0" +__version__ = "5.3.0"
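
A minimal, self-contained sketch of the DEBUG_ON-gated logging pattern these patches apply across the stage modules (the logger name and the environment-variable source of DEBUG_ON below are illustrative assumptions; in the repository DEBUG_ON is a package-level flag imported from gptqmodel):

import logging
import os

# Assumption for this sketch: DEBUG_ON mirrors an environment switch. In the
# repository it is imported from the gptqmodel package instead.
DEBUG_ON = os.environ.get("DEBUG", "") not in ("", "0")
log = logging.getLogger("stage_layer_sketch")


def log_subset_debug(layer_index, index, subset_total, processor_name, names):
    # Gate on both the package-level switch and the logger level so the
    # format arguments (len() calls, name slices) are never built in normal runs.
    if DEBUG_ON and log.isEnabledFor(logging.DEBUG):
        log.debug(
            "StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s",
            layer_index,
            index + 1,
            subset_total,
            processor_name,
            len(names),
            names[:8],
        )

The double check keeps the formerly info-level AWQ messages quiet by default while still letting a developer opt in by enabling both the flag and DEBUG-level logging.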
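
The reworked comments in _search_best_scale describe streaming the activation statistics in FP32 chunks bounded by max_chunk_memory rather than materializing a full flattened copy. A self-contained sketch of that chunked per-channel mean (the function name and the 1 GiB default below are illustrative, not the library API):

import torch


def chunked_abs_channel_mean(inp: torch.Tensor, max_chunk_memory: int = 1 << 30) -> torch.Tensor:
    # Per-channel mean of |inp| over all leading dimensions, accumulated in FP32.
    inp_flat = inp.abs().reshape(-1, inp.shape[-1])
    num_rows, num_channels = inp_flat.shape
    fp32_bytes = torch.tensor([], dtype=torch.float32).element_size()
    # Derive the chunk size from the memory budget so each FP32 chunk stays bounded.
    chunk_rows = int(max_chunk_memory // (fp32_bytes * num_channels))
    chunk_rows = max(1, min(chunk_rows, num_rows))
    acc = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)
    for start in range(0, num_rows, chunk_rows):
        # Accumulate each chunk in FP32 to balance precision and memory usage.
        acc += inp_flat[start:start + chunk_rows].to(torch.float32).sum(dim=0)
    return (acc / num_rows).to(inp.dtype)

Accumulating in FP32 and casting back only at the end matches the precision/memory trade-off the patch comments call out.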