From 114498ddb5322587ae7da399a8856c318312e875 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 2 Nov 2025 12:36:44 +0000
Subject: [PATCH 1/4] comments

---
 gptqmodel/looper/module_looper.py |  7 +++++++
 gptqmodel/looper/stage_layer.py   | 14 ++++++++++++++
 gptqmodel/looper/stage_subset.py  | 16 ++++++++++++++++
 3 files changed, 37 insertions(+)

diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 6a414a0aa..88c92cf7f 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -820,6 +820,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
         num_devices = len(devices)

         for index, device in enumerate(devices):
+            # Split the outstanding batches across devices so that each accelerator
+            # receives a contiguous slice.
             remaining_batches = max(total_batches - segment_start, 0)
             remaining_devices = max(num_devices - index, 1)
             segment_length = remaining_batches // remaining_devices
@@ -841,6 +843,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
             max_segment_length = len(indices)

         for position in range(max_segment_length):
+            # Submit one batch per device
             futures = []
             for device in devices:
                 segment_indices = device_segments.get(device, [])
@@ -874,6 +877,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
                 )

             for fut in futures:
+                # Preserve the original batch order
                 batch_idx, module_output, kv_next = fut.result()
                 if need_outputs and module_output is not None:
                     results[batch_idx] = module_output
@@ -903,6 +907,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
         ordered_outputs: List[List[torch.Tensor]] = []

         for idx in range(total_batches):
+            # Rebuild the ordered list of batch outputs expected by the next
+            # stage.
             module_output = results.get(idx)
             if module_output is None:
                 raise RuntimeError("Forward batch returned no output; data-parallel execution produced empty result.")
@@ -1105,6 +1111,7 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):
         try:
             for index, reverse_p in enumerate(reversed_processors, start=1):
+                # Finalize processors in reverse order
                 self._check_loop_stop()

                 if isinstance(reverse_p, GPTQProcessor):
                     pass
diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py
index 972abc295..f7bb7472e 100644
--- a/gptqmodel/looper/stage_layer.py
+++ b/gptqmodel/looper/stage_layer.py
@@ -48,6 +48,8 @@ def run_layer_stage(
     """Execute the main per-layer quantization loop."""
     log = logger or setup_logger()
     for layer_index in pb:
+        # Iterate over every transformer layer (plus lm_head when enabled) as
+        # progress-bar controlled units of work.
         if looper._check_loop_stop():
             break
         is_lm_head_module = layer_index >= layer_count
@@ -78,6 +80,8 @@ def run_layer_stage(
         full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "")

         for p_index, processor in enumerate(looper.processors):
+            # Each processor contributes a quantization phase; walk them in
+            # order so their caches and side effects line up with the pipeline.
             processor.log_call_count = 0 # reset

             processor.collect_memory_info(layer_index)
@@ -101,6 +105,8 @@ def run_layer_stage(
             previous_subset_processed: Optional[Dict[str, NamedModule]] = None

             for index, names in enumerate(modules):
+                # Process the layer in smaller subsets so attention groups or
+                # MoE experts can be quantized independently within a layer.
                 if isinstance(processor, AWQProcessor):
                     log.info(
                         "StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s",
@@ -273,6 +279,8 @@ def run_layer_stage(
             looper.gptq_model.post_quantize(module)

             for finalized in processed_subset.values():
+                # Reset finalized modules to CPU to guarantee deterministic
+                # ownership before the next processor touches the layer.
                 if isinstance(finalized, NamedModule):
                     setattr(finalized, "target_device", CPU)
                     inner_module = getattr(finalized, "module", None)
@@ -299,6 +307,8 @@ def run_layer_stage(
         finalize_tasks = []

         for reverse_p in reversed(looper.processors):
+            # Collect finalize tasks in reverse to mirror the processor
+            # execution order and honor downstream dependencies.
             for module in processed_subset.values():
                 actual_module = module.module if isinstance(module, NamedModule) else module

@@ -383,6 +393,8 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
         # ).draw()

         for index, (process, module, module_label, target_dev, layer_idx) in enumerate(finalize_tasks, start=1):
+            # Schedule finalize work on the device thread pool so CPU
+            # bound tasks do not stall the main orchestration loop.
             future = DEVICE_THREAD_POOL.submit(
                 target_dev,
                 _finalize_on_worker,
@@ -440,6 +452,8 @@ def _drain_finalize_futures(
     completed_local = 0
     try:
         for future in as_completed(futures):
+            # Drain futures as they complete to surface errors
+            # quickly and keep the progress bar in sync.
             try:
                 result = future.result()
             except BaseException as exc:
diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py
index 63e103c58..90d96489c 100644
--- a/gptqmodel/looper/stage_subset.py
+++ b/gptqmodel/looper/stage_subset.py
@@ -148,6 +148,8 @@ def run_subset_stage(
         combined_names.append(candidate)

     for sub_name in combined_names:
+        # Group every expert (including ones outside the current subset) so
+        # load balancing decisions can span the full MoE family.
         group_key = looper._extract_moe_group_key(sub_name)
         if group_key is None:
             continue
@@ -208,6 +210,9 @@ def run_subset_stage(
     subset_size = len(subset)

     for idx, (name, m) in enumerate(subset.items()):
+        # Register the forward hook that captures activations for quantization.
+        # The final module optionally flips a flag so processors can trigger
+        # once-per-subset logic after the forward pass.
         is_last = (idx == subset_size - 1)
         hook_source = getattr(m, "full_name", None)
         if hook_source is None:
@@ -320,9 +325,12 @@ def run_subset_stage(
         pb.title(layer_title).subtitle("").draw()

     for h in handle:
+        # Detach temporary hooks to avoid leaking state into future passes.
         h.remove()

     for name in subset:
+        # Reset inline hook attributes on NamedModule wrappers so future passes
+        # do not reuse state from this subset run.
         if hasattr(subset[name], 'forward_hook'):
             subset[name].forward_hook = None
             subset[name].forward_hook_last = False
@@ -330,6 +338,8 @@ def run_subset_stage(
     moe_skip_modules = []
     if isinstance(processor, GPTQProcessor):
         for name in subset:
+            # Skip MoE experts that never fired; they likely lacked calibration
+            # traffic and would produce invalid statistics.
             if processor.tasks[name].fwd_counter == 0:
                 logger.error(f"`{name}` was not invoked, if it is a MoE module, it may lack sufficient calibration data routed to it.")
                 moe_skip_modules.append(name)
@@ -343,6 +353,8 @@ def run_subset_stage(
     quant_target_devices: Dict[str, torch.device] = {}

     for name, named_module in subset.items():
+        # Ensure each module has a matching processor task before sending it to
+        # the worker pool; otherwise freeze it on the current device.
         task_map = getattr(processor, "tasks", None)
         has_task = bool(task_map and task_map.get(name) is not None)

@@ -424,6 +436,8 @@ def _process_on_worker(
         return nm.name, nm

     for name, named_module in subset.items():
+        # Launch processing for every module in the subset; tasks may run in
+        # parallel as allowed by the device thread pool.
         tgt_dev = quant_target_devices.get(name, cur_layer_device)
         futures.append(
             DEVICE_THREAD_POOL.submit(
@@ -440,6 +454,8 @@ def _process_on_worker(
         )

     for fut in futures:
+        # Collect results in submission order so the final subset map preserves
+        # deterministic iteration for downstream consumers.
         name, named_module = fut.result()
         processed_subset[name] = named_module
     torch_sync()

From 475e6b4526820568eb6af66a7a78f8b2f8ccbcba Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 2 Nov 2025 12:59:30 +0000
Subject: [PATCH 2/4] make sure bg thread partial is ready for consumption before use

---
 gptqmodel/quantization/gptq.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 3ef45433d..e869e2297 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -23,6 +23,7 @@
 from ..quantization import QuantizeConfig
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
+from ..utils.torch import torch_sync
 from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
 from .quantizer import HF_OPTIMUM, Quantizer

@@ -460,7 +461,7 @@ def _select_hessian_target_device(self, requested: Optional[torch.device]) -> to

         return torch.device("cpu")

-    def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
+    def materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
         device = self._select_hessian_target_device(target_device)

         with self.lock:
@@ -500,19 +501,11 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No

         for partial_device, partial in self._device_hessian_partials.items():
             if partial.device != result_accum.device or partial.dtype != torch.float32:
-                # TODO FIXME multi-3090 using P2P is revaling an issue where result_accum and/or partial is not ready for consolidation on the main thread
-                # when parials are calculated on the individual
-                try:
-                    tmp = partial.to(device=result_accum.device, dtype=torch.float32)
-                    result_accum.add_(tmp)
-                    del tmp
-                except:
-                    log.warn(f"Quantization: Module `{self.name}` -> Retry 1/2 partial.to in 0.5s")
-                    time.sleep(0.25)
-                    tmp = partial.to(device=result_accum.device, dtype=torch.float32)
-                    result_accum.add_(tmp)
-                    del tmp
-
+                # TODO: each partial sync should be done at the partial and not here
+                # partials done on background threads, make sure partial is ready for consumption
+                torch_sync(partial.device)
+
+                result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
             else:
                 result_accum.add_(partial)

@@ -527,7 +520,7 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No
         del result_accum

     def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor:
-        self._materialize_global_hessian(target_device=target_device)
+        self.materialize_global_hessian(target_device=target_device)
         if self.H is None:
             self.H = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=self._select_hessian_target_device(target_device))
         return self.H

From 224d80d099baac78c47457b223101100f145d237 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 2 Nov 2025 13:11:13 +0000
Subject: [PATCH 3/4] avoid delattr exception

---
 gptqmodel/models/writer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 4f7d2f4ae..a9cdf644a 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -241,7 +241,11 @@ def strip_attention_impl_fields(target: Any) -> Dict[str, Any]:
             for attr in ("attn_implementation", "_attn_implementation"):
                 if hasattr(target, attr):
                     removed[attr] = getattr(target, attr)
-                    delattr(target, attr)
+                    # Avoid AttributeError: property '_attn_implementation' of 'Qwen2Config' object has no deleter
+                    try:
+                        delattr(target, attr)
+                    except Exception:
+                        pass
             return removed

         generation_config = getattr(self.model, "generation_config", None)

From 36298bf8e7105dab3ef945c16725c4c94c3db4fb Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 2 Nov 2025 14:04:01 +0000
Subject: [PATCH 4/4] fix sync error at src

---
 gptqmodel/quantization/gptq.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index e869e2297..018bc2d3a 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -350,7 +350,9 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:

         if chunk_size is None:
             mat32 = matrix.to(dtype=torch.float32)
-            return torch.matmul(mat32.T, mat32)
+            xtx = torch.matmul(mat32.T, mat32)
+            torch_sync(device=xtx.device)
+            return xtx

         xtx_accum = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=matrix.device)

@@ -361,6 +363,7 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
                 materialized32 = materialized
             xtx_accum.add_(torch.matmul(materialized32.T, materialized32))
+            torch_sync(device=xtx_accum.device)

         return xtx_accum

     def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]:
@@ -501,10 +504,6 @@ def materialize_global_hessian(self, target_device: Optional[torch.device] = Non

         for partial_device, partial in self._device_hessian_partials.items():
             if partial.device != result_accum.device or partial.dtype != torch.float32:
-                # TODO: each partial sync should be done at the partial and not here
-                # partials done on background threads, make sure partial is ready for consumption
-                torch_sync(partial.device)
-
                 result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
             else:
                 result_accum.add_(partial)