From 6301de19e402e0311eb8683381da92ec86d1f9e1 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 14 Oct 2025 11:29:53 +0000
Subject: [PATCH 1/8] layer callback

Signed-off-by: Qubitium
---
 gptqmodel/looper/module_looper.py    | 123 ++++++++++++++++++++++++++-
 tests/test_module_looper_callback.py | 101 ++++++++++++++++++++++
 2 files changed, 223 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_module_looper_callback.py

diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 577b5eaac..770d15107 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -65,6 +65,10 @@ class FinalizeProgressInfo(NamedTuple):
     layer_idx: Optional[int]

+class StopMainLoop(Exception):
+    """Signal that the module loop should abort immediately."""
+
+
 class ModuleLooper():
     """Drive the per-layer quantisation workflow over one or more devices.

@@ -77,6 +81,10 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
         self.gptq_model = model
         self.support_batch_quantize = model.support_batch_quantize
         self.lock = threading.Lock()
+        self._layer_callback = getattr(model, "layer_callback", None)
+        self._loop_stop_event = threading.Event()
+        self._loop_stop_exc: Optional[BaseException] = None
+        self._loop_stop_waited = False

         disk_speed = estimate_disk_io_speed()
         disk_speed_mb = disk_speed / (1024 * 1024)
@@ -105,6 +113,80 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
         for processor in self.processors:
             self._processor_mask_tls(processor)

+    def register_layer_callback(self, callback) -> None:
+        """Register or replace the layer-complete callback target."""
+        self._layer_callback = callback
+
+    def _resolve_layer_callback(self):
+        for candidate in (
+            getattr(self, "_layer_callback", None),
+            getattr(self, "layer_callback", None),
+            getattr(self.gptq_model, "layer_callback", None),
+            getattr(self.gptq_model, "callbackup", None),
+            getattr(self.gptq_model, "callback", None),
+        ):
+            if candidate is not None:
+                return candidate
+        return None
+
+    def callbackup(self, layer_idx: int, submodule_finalized: bool):
+        callback = self._resolve_layer_callback()
+        if callback is None:
+            return None
+
+        handler = getattr(callback, "layer_complete", None)
+        if handler is None and callable(callback):
+            handler = callback
+        if handler is None:
+            return None
+
+        try:
+            result = handler(layer_idx=layer_idx, submodule_finalized=submodule_finalized)
+        except StopMainLoop:
+            raise
+        if result is StopMainLoop:
+            raise StopMainLoop(f"Layer callback requested stop at layer {layer_idx}")
+        if isinstance(result, StopMainLoop):
+            raise result
+        return result
+
+    def _request_loop_stop(self, exc: BaseException) -> None:
+        with self.lock:
+            if self._loop_stop_exc is None:
+                self._loop_stop_exc = exc
+            self._loop_stop_event.set()
+
+    def _check_loop_stop(self) -> None:
+        if not self._loop_stop_event.is_set():
+            return
+        if not self._loop_stop_waited:
+            DEVICE_THREAD_POOL.wait()
+            self._loop_stop_waited = True
+        raise self._loop_stop_exc or StopMainLoop("Module loop stopped by callback")
+
+    def _emit_layer_complete(
+        self,
+        layer_idx: int,
+        submodule_finalized: bool,
+        *,
+        raise_in_place: bool,
+    ) -> None:
+        try:
+            self.callbackup(layer_idx=layer_idx, submodule_finalized=submodule_finalized)
+        except StopMainLoop as exc:
+            if raise_in_place:
+                raise
+            self._request_loop_stop(exc)
+        except BaseException as exc:
+            if raise_in_place:
+                raise
+            log.exception(
+                "Layer completion callback raised an exception (layer=%s, submodule_finalized=%s)",
+ layer_idx, + submodule_finalized, + ) + self._request_loop_stop(exc) + # Processors capture activations through hooks that need thread-local state # so masks survive the roundtrip to worker threads. def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local: @@ -872,6 +954,7 @@ def loop(self, fail_safe: bool = False, **kwargs): setattr(parent, module_path[-1], hooked_lm_head) for layer_index in pb: + self._check_loop_stop() is_lm_head_module = layer_index >= layer_count if is_lm_head_module: @@ -928,6 +1011,17 @@ def loop(self, fail_safe: bool = False, **kwargs): lock_ctx = DEVICE_THREAD_POOL.read_lock(cur_layer_device) with ctx(lock_ctx, device_ctx(device_for_ctx)): processor.layer_quantize(module, cur_layer_device, named_childs) + if p_index == len(self.processors) - 1: + self._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=False, + raise_in_place=True, + ) + self._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=True, + raise_in_place=True, + ) continue layer_inputs = processor.inputs_cache.layer_inputs @@ -1339,6 +1433,12 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): finalize_futures_snapshot = list(finalize_futures) + self._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=False, + raise_in_place=True, + ) + if finalize_futures_snapshot: known_layers = sorted( { @@ -1371,11 +1471,17 @@ def _drain_finalize_futures( futures, finalize_pb_local, finalize_count_local, + layer_idx_for_callback, ): completed_local = 0 try: for future in as_completed(futures): - result = future.result() + try: + result = future.result() + except BaseException as exc: + log.exception("Submodule finalize task raised an exception") + self._request_loop_stop(exc) + return if isinstance(result, FinalizeProgressInfo): module_label = result.module_label @@ -1399,6 +1505,11 @@ def _drain_finalize_futures( ).subtitle(subtitle).draw() finally: finalize_pb_local.close() + self._emit_layer_complete( + layer_idx=layer_idx_for_callback, + submodule_finalized=True, + raise_in_place=False, + ) if finalize_futures_snapshot: # Drain finalize futures asynchronously so the main loop can continue scheduling work. 
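Usage sketch for the layer-complete hook introduced above. It assumes only the surface this patch adds (a `layer_callback` attribute resolved by `_resolve_layer_callback()`, the keyword-only `layer_complete(layer_idx=..., submodule_finalized=...)` signature, and `StopMainLoop`); the `StopAfterLayer` helper name is illustrative, not part of the library:

    from gptqmodel.looper.module_looper import StopMainLoop

    class StopAfterLayer:
        """Ask the looper to stop once the target layer has fully finalized."""

        def __init__(self, last_layer: int):
            self.last_layer = last_layer

        def layer_complete(self, *, layer_idx: int, submodule_finalized: bool):
            # callbackup() treats a returned StopMainLoop class (or a raised
            # instance) as a stop request: it raises StopMainLoop, which
            # _emit_layer_complete() records via _request_loop_stop().
            if submodule_finalized and layer_idx >= self.last_layer:
                return StopMainLoop
            return None

    # model.layer_callback = StopAfterLayer(last_layer=0)
    # The main loop then exits at the next _check_loop_stop() once
    # DEVICE_THREAD_POOL has drained outstanding work.

The stop path is cooperative rather than immediate: the callback only flags the request, and the looper checks the flag at layer boundaries after the device thread pool has finished in-flight tasks.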
@@ -1408,14 +1519,23 @@ def _drain_finalize_futures( [future for future, *_ in finalize_futures_snapshot], finalize_pb, finalize_count, + layer_index, ), name="SubmoduleFinalizeWatcher", daemon=True, ).start() + else: + self._emit_layer_complete( + layer_idx=layer_index, + submodule_finalized=True, + raise_in_place=True, + ) # LifeCycle: All sub-modules have finalized meaning quantization work is complete + self._check_loop_stop() # Ensure ANY remaining tasks the looper submitted have drained DEVICE_THREAD_POOL.wait() # same as wait('all') + self._check_loop_stop() # paranoid safety check # torch_sync() @@ -1437,6 +1557,7 @@ def _drain_finalize_futures( try: for index, reverse_p in enumerate(reversed_processors, start=1): + self._check_loop_stop() if isinstance(reverse_p, GPTQProcessor): pass elif isinstance(reverse_p, EoraProcessor): diff --git a/tests/test_module_looper_callback.py b/tests/test_module_looper_callback.py new file mode 100644 index 000000000..9d6f68d9c --- /dev/null +++ b/tests/test_module_looper_callback.py @@ -0,0 +1,101 @@ +import types + +import pytest + +from gptqmodel.looper.module_looper import ModuleLooper, StopMainLoop + + +class DummyQModel: + def __init__(self): + self.support_batch_quantize = False + self.quantize_config = types.SimpleNamespace(device=None) + self.layer_callback = None + + +def make_looper(layer_callback=None): + model = DummyQModel() + if layer_callback is not None: + model.layer_callback = layer_callback + processors = [types.SimpleNamespace()] + return ModuleLooper(model=model, processors=processors) + + +def test_callbackup_invokes_model_layer_callback(): + calls = [] + + class Recorder: + def layer_complete(self, *, layer_idx, submodule_finalized): + calls.append((layer_idx, submodule_finalized)) + + looper = make_looper(layer_callback=Recorder()) + + looper.callbackup(layer_idx=3, submodule_finalized=False) + looper.callbackup(layer_idx=3, submodule_finalized=True) + + assert calls == [(3, False), (3, True)] + + +def test_callbackup_stop_request_via_returning_class(): + def stopper(**_): + return StopMainLoop + + looper = make_looper(layer_callback=stopper) + + with pytest.raises(StopMainLoop): + looper.callbackup(layer_idx=1, submodule_finalized=False) + + +def test_callbackup_stop_request_via_instance(): + def stopper(**_): + return StopMainLoop("stop") + + looper = make_looper(layer_callback=stopper) + + with pytest.raises(StopMainLoop): + looper.callbackup(layer_idx=1, submodule_finalized=False) + + +def test_emit_layer_complete_records_stop(monkeypatch): + err = ValueError("boom") + + def raising_callback(*, layer_idx, submodule_finalized): + raise err + + looper = make_looper(layer_callback=raising_callback) + + looper._emit_layer_complete( + layer_idx=7, + submodule_finalized=False, + raise_in_place=False, + ) + + assert looper._loop_stop_exc is err + assert looper._loop_stop_event.is_set() + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.wait", + lambda *_, **__: None, + ) + + with pytest.raises(ValueError) as exc: + looper._check_loop_stop() + + assert exc.value is err + + +def test_emit_layer_complete_propagates_when_requested(): + err = RuntimeError("direct") + + def raising_callback(*, layer_idx, submodule_finalized): + raise err + + looper = make_looper(layer_callback=raising_callback) + + with pytest.raises(RuntimeError) as exc: + looper._emit_layer_complete( + layer_idx=2, + submodule_finalized=True, + raise_in_place=True, + ) + + assert exc.value is err From 
54bd4582c2519ba5f519acf6ef2e3aa77573b811 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 14:05:35 +0000 Subject: [PATCH 2/8] add logs Signed-off-by: Qubitium --- gptqmodel/looper/gptq_processor.py | 4 +- gptqmodel/looper/module_looper.py | 22 +- gptqmodel/utils/model.py | 11 +- tests/models/test_multi_vs_single_gpu.py | 302 +++++++++++++++++++++++ tests/test_format_conversion_flow.py | 20 +- tests/test_module_looper_callback.py | 24 ++ 6 files changed, 367 insertions(+), 16 deletions(-) create mode 100644 tests/models/test_multi_vs_single_gpu.py diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 75bc45892..565460634 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -311,7 +311,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): logger=log, module_name=module_label, ): - pack_module( + packer_label = pack_module( name=module.full_name, qModules=qModules, q_scales=q_scales, @@ -326,7 +326,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): timer.record( "submodule_finalize_pack", time.perf_counter() - pack_start, - source=module_label, + source=f"{module_label} [{packer_label or 'module.pack_original'}]", ) # TODO: store module quant results in module, not global processor result diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 770d15107..aed41d0fa 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -150,19 +150,21 @@ def callbackup(self, layer_idx: int, submodule_finalized: bool): raise result return result - def _request_loop_stop(self, exc: BaseException) -> None: + def _request_loop_stop(self, exc: Optional[BaseException]) -> None: with self.lock: - if self._loop_stop_exc is None: + if self._loop_stop_exc is None and exc is not None: self._loop_stop_exc = exc self._loop_stop_event.set() - def _check_loop_stop(self) -> None: + def _check_loop_stop(self) -> bool: if not self._loop_stop_event.is_set(): - return + return False if not self._loop_stop_waited: DEVICE_THREAD_POOL.wait() self._loop_stop_waited = True - raise self._loop_stop_exc or StopMainLoop("Module loop stopped by callback") + if self._loop_stop_exc is not None: + raise self._loop_stop_exc + return True def _emit_layer_complete( self, @@ -173,10 +175,9 @@ def _emit_layer_complete( ) -> None: try: self.callbackup(layer_idx=layer_idx, submodule_finalized=submodule_finalized) - except StopMainLoop as exc: - if raise_in_place: - raise - self._request_loop_stop(exc) + except StopMainLoop: + self._request_loop_stop(None) + return except BaseException as exc: if raise_in_place: raise @@ -954,7 +955,8 @@ def loop(self, fail_safe: bool = False, **kwargs): setattr(parent, module_path[-1], hooked_lm_head) for layer_index in pb: - self._check_loop_stop() + if self._check_loop_stop(): + break is_lm_head_module = layer_index >= layer_count if is_lm_head_module: diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3d2642493..cbb30cf87 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -726,6 +726,8 @@ def pack_module( except (RuntimeError, ValueError): log.warning(f"pack_module: unable to parse target device `{cfg_device}`; defaulting to CUDA auto-select.") + packer_label = None + with lock: layers[name] = layer qModules[name] = module @@ -734,8 +736,9 @@ def pack_module( if quant_linear_cls.QUANT_TYPE == "qqq": if q_scales_extra is not None: q_scales_extra = 
q_scales_extra.to(CPU) + packer_label = "module.pack" with log_time_block( - "module.pack", + packer_label, logger=log, module_name=name, ): @@ -767,8 +770,10 @@ def pack_module( "original": "module.pack_original", } + packer_label = label_map[effective_impl] + with log_time_block( - label_map[effective_impl], + packer_label, logger=log, module_name=name, ): @@ -811,6 +816,8 @@ def pack_module( # qModules[name].to(layer_device) # log.info(f"Pack: moving module back to `{layer_device}` cost = {time.time()-start} seconds") + return packer_label + def pack_model( model, quant_result: Dict[str, Dict[str, Any]], diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py new file mode 100644 index 000000000..c5005b9f7 --- /dev/null +++ b/tests/models/test_multi_vs_single_gpu.py @@ -0,0 +1,302 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import os +import sys +from contextlib import ExitStack +from dataclasses import dataclass +from decimal import Decimal +from typing import Dict, Iterable, List, Tuple +from unittest import mock + +import torch + +from gptqmodel import GPTQModel +from gptqmodel.models.writer import ( + PROCESS_LOG_LAYER, + PROCESS_LOG_MODULE, + QUANT_LOG_LOSS, + QUANT_LOG_NSAMPLES, +) +from gptqmodel.looper.module_looper import StopMainLoop +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.utils.torch import torch_empty_cache + +from model_test import ModelTest + + +@dataclass(frozen=True) +class LayerMetrics: + loss: Decimal + samples: int + + + +def _is_free_threaded() -> bool: + gil_check = getattr(sys, "_is_gil_enabled", None) + if callable(gil_check): + return not gil_check() + env_value = os.environ.get("PYTHON_GIL", "1").lower() + return env_value in {"0", "false", "off"} + + +class TestMultiVsSingleGPU(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + NATIVE_ARC_CHALLENGE_ACC = 0.3311 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3549 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.05 + APPLY_CHAT_TEMPLATE = True + V2 = False + DEBUG = True + ACT_GROUP_AWARE = False + DESC_ACT = True + DATASET_SIZE = 1024 + DATASET_SORT = "desc" + QUANT_BATCH_SIZE = 4 + USE_FLASH_ATTN = True + + def test_quantization_first_layer_metrics_match_between_single_and_dual_gpu(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA required for multi-GPU regression test") + + visible_devices = torch.cuda.device_count() + if visible_devices < 2: + self.skipTest("Requires at least two CUDA devices") + + if sys.version_info < (3, 13): + self.skipTest("Requires Python 3.13 free-threaded runtime") + + if not _is_free_threaded(): + self.skipTest("PYTHON_GIL must be disabled (set PYTHON_GIL=0) for multi-threaded quantization") + + single_layer_metrics, single_batch_stats = self._quantize_first_layer(device_indices=[0]) + multi_layer_metrics, multi_batch_stats = self._quantize_first_layer(device_indices=[0, 1]) + + self.assertTrue(single_layer_metrics, "Single-GPU quantization produced no layer-0 metrics") + self.assertTrue(multi_layer_metrics, "Multi-GPU quantization produced no layer-0 metrics") + self.assertEqual( + set(single_layer_metrics.keys()), + set(multi_layer_metrics.keys()), + "Layer-0 module set differs between single-GPU and multi-GPU quantization", + ) + + mismatches: Dict[str, Dict[str, str]] = {} + for 
module_name in single_layer_metrics: + single = single_layer_metrics[module_name] + multi = multi_layer_metrics[module_name] + if single.samples != multi.samples or single.loss != multi.loss: + mismatches[module_name] = { + "single_samples": str(single.samples), + "multi_samples": str(multi.samples), + "single_loss": str(single.loss), + "multi_loss": str(multi.loss), + } + + if mismatches: + debug_details = self._format_batch_debug(single_batch_stats, multi_batch_stats) + details = "; ".join( + f"{module}: loss {info['single_loss']} vs {info['multi_loss']}, " + f"samples {info['single_samples']} vs {info['multi_samples']}" + for module, info in mismatches.items() + ) + self.fail( + "Layer-0 quantization metrics diverged between device configurations: " + f"{details}; batch-debug: {debug_details}" + ) + + def _quantize_first_layer( + self, device_indices: Iterable[int] + ) -> Tuple[Dict[str, LayerMetrics], Dict[str, List[Dict[str, float]]]]: + target_devices = [torch.device(f"cuda:{idx}") for idx in device_indices] + selection = lambda _base_device: target_devices + + class _StopAfterLayer: + def __init__(self, layer_idx: int): + self._layer_idx = layer_idx + self._triggered = False + + def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): + if self._triggered: + return None + if submodule_finalized and layer_idx >= self._layer_idx: + self._triggered = True + raise StopMainLoop + + quant_config = QuantizeConfig( + quant_method=self.METHOD, + format=self.FORMAT, + bits=self.BITS, + group_size=self.GROUP_SIZE, + desc_act=self.DESC_ACT if not self.ACT_GROUP_AWARE else False, + act_group_aware=self.ACT_GROUP_AWARE, + fail_safe=self.FAIL_SAFE, + sym=self.SYM, + v2=self.V2, + adapter=self.EORA, + device=target_devices[0], + ) + + load_kwargs = {} + if self.USE_FLASH_ATTN: + load_kwargs["attn_implementation"] = "flash_attention_2" + + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + quantize_config=quant_config, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + debug=self.DEBUG, + **load_kwargs, + ) + + dataset = self.load_dataset(model.tokenizer, self.DATASET_SIZE) + model.layer_callback = _StopAfterLayer(layer_idx=0) + + batch_debug: Dict[str, List[Dict[str, float]]] = {} + primary_handles: Dict[str, str] = {} + + with ExitStack() as stack: + stack.enter_context( + mock.patch("gptqmodel.looper.module_looper.select_forward_devices", new=selection) + ) + stack.enter_context( + mock.patch("gptqmodel.utils.looper_helpers.select_forward_devices", new=selection) + ) + stack.enter_context(self._capture_primary_handles(primary_handles)) + stack.enter_context(self._capture_batches(batch_debug, primary_handles)) + model.quantize( + dataset, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + + first_layer_stats = self._extract_first_layer_metrics(model.quant_log) + + # Clear GPU memory before the next run + del dataset + del model + torch_empty_cache() + + return first_layer_stats, batch_debug + + def _extract_first_layer_metrics(self, quant_log: List[Dict[str, str]]) -> Dict[str, LayerMetrics]: + layer_metrics: Dict[str, LayerMetrics] = {} + for entry in quant_log: + try: + layer_index = int(entry.get(PROCESS_LOG_LAYER)) + except (TypeError, ValueError): + continue + + if layer_index != 0: + continue + + module_name = entry.get(PROCESS_LOG_MODULE) + if not module_name: + continue + + loss_value = entry.get(QUANT_LOG_LOSS) + sample_value = entry.get(QUANT_LOG_NSAMPLES) + if loss_value is None or sample_value 
is None: + continue + + layer_metrics[module_name] = LayerMetrics( + loss=Decimal(loss_value), + samples=int(sample_value), + ) + return layer_metrics + + @staticmethod + def _format_batch_debug( + single_batch_stats: Dict[str, List[Dict[str, float]]], + multi_batch_stats: Dict[str, List[Dict[str, float]]], + ) -> str: + def _summarize(stats: Dict[str, List[Dict[str, float]]]) -> Dict[str, Dict[str, float]]: + summary: Dict[str, Dict[str, float]] = {} + for name, entries in stats.items(): + per_handle: Dict[str, Dict[str, float]] = {} + for item in entries: + handle = item["handle"] + info = per_handle.setdefault( + handle, + { + "batches": 0.0, + "samples": 0.0, + "sum_hash": 0.0, + "device": item.get("device", "?"), + "primary": False, + }, + ) + info["batches"] += 1.0 + info["samples"] += item["after"] - item["before"] + info["sum_hash"] += item["sum"] + info["primary"] = info["primary"] or bool(item.get("is_primary", False)) + summary[name] = { + handle: values + for handle, values in sorted(per_handle.items(), key=lambda kv: kv[0]) + } + return summary + + single_summary = _summarize(single_batch_stats) + multi_summary = _summarize(multi_batch_stats) + + parts = [] + module_names = sorted(set(single_summary) | set(multi_summary)) + for module in module_names: + single_info = single_summary.get(module, {}) + multi_info = multi_summary.get(module, {}) + parts.append( + f"{module}:single={single_info},multi={multi_info}" + ) + return " | ".join(parts) + + @staticmethod + def _capture_batches(storage: Dict[str, List[Dict[str, float]]], primary_handles: Dict[str, str]): + from gptqmodel.quantization.gptq import GPTQ # local import to avoid circular refs + + original_add_batch = GPTQ.add_batch + + def wrapped_add_batch(self, inp, out): # type: ignore[override] + module_name = getattr(self, "name", "") + before = getattr(self, "nsamples", 0) + # Summaries calculated before running original implementation + try: + sum_value = inp.detach().to(dtype=torch.float64).sum().item() + except Exception: # pragma: no cover - defensive logging + sum_value = float("nan") + device = str(getattr(inp, "device", "unknown")) + + original_add_batch(self, inp, out) + + after = getattr(self, "nsamples", 0) + storage.setdefault(module_name, []).append( + { + "before": float(before), + "after": float(after), + "sum": float(sum_value), + "handle": hex(id(self)), + "device": device, + "is_primary": hex(id(self)) == primary_handles.get(module_name), + } + ) + + return mock.patch.object(GPTQ, "add_batch", new=wrapped_add_batch) + + @staticmethod + def _capture_primary_handles(primary_handles: Dict[str, str]): + from gptqmodel.looper.gptq_processor import GPTQProcessor # local import to avoid cycles + + original_preprocess = GPTQProcessor.preprocess + + def wrapped_preprocess(self, module, fail_safe=False): # type: ignore[override] + result = original_preprocess(self, module, fail_safe) + task = self.tasks.get(module.name) + if task is not None: + primary_handles[module.name] = hex(id(task)) + return result + + return mock.patch.object(GPTQProcessor, "preprocess", new=wrapped_preprocess) diff --git a/tests/test_format_conversion_flow.py b/tests/test_format_conversion_flow.py index 561887e8d..8f62d41dd 100644 --- a/tests/test_format_conversion_flow.py +++ b/tests/test_format_conversion_flow.py @@ -12,12 +12,13 @@ from gptqmodel.utils.model import pack_module -class _DummyLayer: +class _DummyLayer(torch.nn.Module): def __init__(self): + super().__init__() self.weight = torch.nn.Parameter(torch.zeros(1, 1)) def to(self, 
*_args, **_kwargs): - return self + return super().to(*_args, **_kwargs) class _DummyQuantModule: @@ -33,6 +34,21 @@ def to(self, *_args, **_kwargs): def pack(self, **_kwargs): pass + def parameters(self): + return iter(()) + + def buffers(self): + return iter(()) + + def pack_block(self, **_kwargs): + pass + + def pack_original(self, **_kwargs): + pass + + def pack_gpu(self, **_kwargs): + pass + def qzero_format(self, format: int | None = None): if format is not None: self._fmt = format diff --git a/tests/test_module_looper_callback.py b/tests/test_module_looper_callback.py index 9d6f68d9c..bce22c254 100644 --- a/tests/test_module_looper_callback.py +++ b/tests/test_module_looper_callback.py @@ -99,3 +99,27 @@ def raising_callback(*, layer_idx, submodule_finalized): ) assert exc.value is err + + +def test_emit_layer_complete_stops_cleanly_on_stop_main_loop(monkeypatch): + class Stopper: + def layer_complete(self, *, layer_idx, submodule_finalized): + raise StopMainLoop() + + looper = make_looper(layer_callback=Stopper()) + + looper._emit_layer_complete( + layer_idx=0, + submodule_finalized=True, + raise_in_place=True, + ) + + assert looper._loop_stop_exc is None + assert looper._loop_stop_event.is_set() + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.wait", + lambda *_, **__: None, + ) + + assert looper._check_loop_stop() is True From f41b0fdb3253c2f3686a352fbf348a79af1cfe7b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 14:39:16 +0000 Subject: [PATCH 3/8] default to block_cpu packer Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/__init__.py | 2 +- gptqmodel/quantization/config.py | 4 +- gptqmodel/utils/looper_helpers.py | 67 +++++++++++++----------- gptqmodel/utils/torch.py | 4 +- tests/models/model_test.py | 1 + 5 files changed, 42 insertions(+), 36 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 89b0625bc..eeaa456ae 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -493,7 +493,7 @@ def pack_block( zeros: t.Tensor, g_idx: t.Tensor, block_in: int = 8192, - workers: int = 8, + workers: int = 4, ): """ Parallel qweight pack on CPU (threaded over input blocks). qzeros path = original logic. diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 4ce618708..f86fdf674 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -197,8 +197,8 @@ class QuantizeConfig(): # affects [`qweights`, `qzeros`] pack_dtype: Optional[Union[str, torch.dtype]] = field(default=torch.int32) - # packing implementation hint (`original` = legacy CPU pack, `gpu` enables CUDA pack, `cpu` forces block CPU pack). - pack_impl: str = field(default="original") + # packing implementation hinpt (`original` = legacy CPU pack, `gpu` enables CUDA pack, `cpu` forces block CPU pack). 
+ pack_impl: str = field(default="cpu") # pending used field adapter: Optional[Union[Dict[str, Any], Lora]] = field(default=None) diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index 80b93ce12..db6b0584a 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -5,6 +5,7 @@ from __future__ import annotations import copy +import threading import time from contextlib import contextmanager from typing import Dict, List, Optional, Sequence, Tuple @@ -93,6 +94,7 @@ def device_ctx(dev: Optional[torch.device | "DEVICE"]): # cpu/mps/meta -> nothing special needed yield +_rehome_lock = threading.Lock() @torch.inference_mode() def rehome_module_to_device( @@ -105,47 +107,48 @@ def rehome_module_to_device( only_mismatched: bool = True, ) -> None: """Move registered tensors on ``module`` to ``device`` with defensive fallbacks.""" - for sub in module.modules(): - if move_buffers: - np_set = getattr(sub, "_non_persistent_buffers_set", set()) - for name, buf in list(getattr(sub, "_buffers", {}).items()): - if buf is None or not isinstance(buf, torch.Tensor): - continue - if not include_non_persistent_buffers and name in np_set: - continue - if only_mismatched and buf.device == device: - continue - try: - sub._buffers[name] = buf.to(device, non_blocking=True) - except Exception: + with _rehome_lock: + for sub in module.modules(): + if move_buffers: + np_set = getattr(sub, "_non_persistent_buffers_set", set()) + for name, buf in list(getattr(sub, "_buffers", {}).items()): + if buf is None or not isinstance(buf, torch.Tensor): + continue + if not include_non_persistent_buffers and name in np_set: + continue + if only_mismatched and buf.device == device: + continue try: - sub._buffers[name] = buf.to(device) + sub._buffers[name] = buf.to(device, non_blocking=True) except Exception: - pass - - if move_parameters: - for pname, p in list(getattr(sub, "_parameters", {}).items()): - if p is None or not isinstance(p, torch.nn.Parameter): - continue - if only_mismatched and p.device == device: - continue - try: - with torch.no_grad(): - new_p = torch.nn.Parameter( - p.data.to(device, non_blocking=True), - requires_grad=p.requires_grad, - ) - sub._parameters[pname] = new_p - except Exception: + try: + sub._buffers[name] = buf.to(device) + except Exception: + pass + + if move_parameters: + for pname, p in list(getattr(sub, "_parameters", {}).items()): + if p is None or not isinstance(p, torch.nn.Parameter): + continue + if only_mismatched and p.device == device: + continue try: with torch.no_grad(): new_p = torch.nn.Parameter( - p.data.to(device), + p.data.to(device, non_blocking=True), requires_grad=p.requires_grad, ) sub._parameters[pname] = new_p except Exception: - pass + try: + with torch.no_grad(): + new_p = torch.nn.Parameter( + p.data.to(device), + requires_grad=p.requires_grad, + ) + sub._parameters[pname] = new_p + except Exception: + pass def clear_non_picklable_state(module: torch.nn.Module) -> List[Tuple[str, int]]: diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 93f42a1e3..d5e5f0744 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -21,6 +21,8 @@ # pytorch 2.6.0 fixes many compilation errors TORCH_HAS_COMPILE = version.parse(torch.__version__).release >= version.Version('2.6').release TORCH_GTE_28 = version.parse(torch.__version__).release >= version.Version('2.8').release +TORCH_GTE_210 = version.parse(torch.__version__).release >= version.Version('2.10').release + TORCH_HAS_FUSED_OPS = 
version.parse(torch.__version__).release >= version.Version('2.8').release HAS_CUDA = False @@ -140,7 +142,7 @@ def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="induct log_gil_requirements_for("Torch Compile") return module - if gte_python_3_14(): + if gte_python_3_14() and not TORCH_GTE_210: log_gil_requirements_for("Torch Compile") return module diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 07eaf8257..0fb6f0eea 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -526,6 +526,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne sym=self.SYM, v2=self.V2, adapter=self.EORA, + pack_impl="cpu", ) log.info(f"Quant config: {quantize_config}") From 63c3d3adb362c71323528167ab1e36b0944dc0c5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 14:47:07 +0000 Subject: [PATCH 4/8] check pack edge case Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/__init__.py | 18 +++++++++++++++++- gptqmodel_ext/pack_block_cpu.cpp | 19 ++++++++++++++++--- tests/test_pack.py | 19 +++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index eeaa456ae..419969fb4 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -529,6 +529,7 @@ def pack_block( scales = scales.T.contiguous() # [G, out] zeros = zeros.T.contiguous() # [G, out] scale_zeros = zeros * scales # [G, out] + num_groups = scales.shape[0] # small buffers self.register_buffer("scales", scales.to(dtype=t.float16)) @@ -641,7 +642,22 @@ def _process_block(i0: int, i1: int): # [out, blk] Wblk = W[:, i0:i1] # select group rows for these inputs from [G, out] -> [blk, out], then T -> [out, blk] - gsel = g_idx[i0:i1] # [blk] + gsel = g_idx[i0:i1].to(dtype=t.int64, copy=False) # [blk] + if gsel.numel() == 0: + return + + neg_mask = gsel < 0 + if neg_mask.any(): + gsel = gsel.clone() + gsel[neg_mask] += num_groups + + gsel_max = int(gsel.max().item()) + gsel_min = int(gsel.min().item()) + if gsel_min < 0 or gsel_max >= num_groups: + raise IndexError( + f"pack_block: g_idx values out of range after normalization (min={gsel_min}, max={gsel_max}, groups={num_groups})." 
+ ) + sz_blk_T = scale_zeros.index_select(0, gsel).T # [out, blk] s_blk_T = scales.index_select(0, gsel).T # [out, blk] diff --git a/gptqmodel_ext/pack_block_cpu.cpp b/gptqmodel_ext/pack_block_cpu.cpp index 0b9e9ef00..eae714322 100644 --- a/gptqmodel_ext/pack_block_cpu.cpp +++ b/gptqmodel_ext/pack_block_cpu.cpp @@ -102,9 +102,22 @@ std::tuple pack_block_cpu( for (int out = 0; out < out_features; ++out) { for (int lane = 0; lane < word_bits; ++lane) { const int64_t input_idx = base_input + lane; - const int32_t group = gidx_ptr[input_idx]; - float scale = scales_ptr[group * scales_stride + out]; - float offset = scale_zeros_ptr[group * scales_stride + out]; + const int32_t raw_group = gidx_ptr[input_idx]; + int32_t group = raw_group; + if (group < 0) { + group += static_cast(groups); + } + TORCH_CHECK( + group >= 0 && group < groups, + "pack_block_cpu: g_idx[", + input_idx, + "]=", + raw_group, + " is out of range for groups=", + groups + ); + float scale = scales_ptr[static_cast(group) * scales_stride + out]; + float offset = scale_zeros_ptr[static_cast(group) * scales_stride + out]; float w = weight_ptr[out * out_stride + input_idx]; if (scale == 0.0f) { scale = 1e-6f; diff --git a/tests/test_pack.py b/tests/test_pack.py index 4e35a894b..e898b9c53 100644 --- a/tests/test_pack.py +++ b/tests/test_pack.py @@ -143,3 +143,22 @@ def test_pack_consistency(self, bits, group_size): floatfmt=".3e", ) ) + + def test_pack_negative_g_idx(self): + bits = 4 + group_size = 32 + self.current_bits = bits + self.current_group_size = group_size + linear, scales, zeros, g_idx = self._build_inputs(bits, group_size) + + groups = int(g_idx.max().item() + 1) + g_idx_neg = g_idx.to(dtype=torch.int32) + g_idx_neg[::7] -= groups + + baseline = self._run_impl("original", linear, scales, zeros, g_idx_neg) + pack_cpu = self._run_impl("pack_block", linear, scales, zeros, g_idx_neg) + + self.assertTrue(torch.equal(pack_cpu["qweight"], baseline["qweight"])) + self.assertTrue(torch.equal(pack_cpu["qzeros"], baseline["qzeros"])) + self.assertTrue(torch.equal(pack_cpu["scales"], baseline["scales"])) + self.assertTrue(torch.equal(pack_cpu["g_idx"].to(dtype=baseline["g_idx"].dtype), baseline["g_idx"])) From 78229e733640fada748a9fd32ab74146df62e316 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 15:21:06 +0000 Subject: [PATCH 5/8] fix multi-gpu H accumlation order Signed-off-by: Qubitium --- gptqmodel/looper/gptq_processor.py | 3 +- gptqmodel/looper/loop_processor.py | 17 +++- gptqmodel/looper/module_looper.py | 116 ++++++++++++----------- gptqmodel/nn_modules/qlinear/__init__.py | 2 +- gptqmodel/quantization/gptq.py | 112 ++++++++++++++++++---- gptqmodel/utils/looper_helpers.py | 2 + tests/models/model_test.py | 25 ++++- tests/models/test_multi_vs_single_gpu.py | 5 +- 8 files changed, 201 insertions(+), 81 deletions(-) diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 565460634..a069618d9 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -106,8 +106,9 @@ def is_skipped(self, module: NamedModule) -> bool: def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): g = self.tasks[name] # noqa: F821 + batch_idx = self.current_batch_index() with tf32_disable_guard(): - g.add_batch(inp[0].data, out.data) # noqa: F821 + g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821 del inp, out return tmp diff 
--git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index b9e8ac10b..227addb38 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -175,10 +175,25 @@ def __init__( log.warn(f"The average length of input_ids of calibration_dataset should be greater than " f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.") - self.num_batches = len(calibration) + self.num_batches = len(calibration) self.calibration_dataset = calibration + # Track the current calibration batch index on a per-thread basis so + # processors can retrieve deterministic ordering information (e.g. + # GPTQ's Hessian updates) even when forwards run on multiple threads. + self._batch_tls = threading.local() + + def _set_current_batch_index(self, batch_index: Optional[int]) -> None: + if batch_index is None: + if hasattr(self._batch_tls, "index"): + delattr(self._batch_tls, "index") + else: + self._batch_tls.index = int(batch_index) + + def current_batch_index(self) -> Optional[int]: + return getattr(self._batch_tls, "index", None) + def _async_log_writer(self, stat): with open(self.log_tmp_log_file_name, 'a') as f: json.dump(stat, f, indent=4) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index aed41d0fa..04a3a88b9 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -483,72 +483,76 @@ def _run_forward_batches_single( stage_label = progress_stage or "Forward" for batch_idx in range(total_batches): - layer_input = [move_to(inp, device=cur_layer_device, stream=False) for inp in layer_inputs[batch_idx]] + processor._set_current_batch_index(batch_idx) + try: + layer_input = [move_to(inp, device=cur_layer_device, stream=False) for inp in layer_inputs[batch_idx]] - raw_mask = attention_masks[batch_idx] - attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False) + raw_mask = attention_masks[batch_idx] + attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False) - keep_mask = None - if attn_tensor is not None: - seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None - keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) - self._set_processor_mask(processor, keep_mask) + keep_mask = None + if attn_tensor is not None: + seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None + keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) + self._set_processor_mask(processor, keep_mask) - additional_inputs: Dict[str, torch.Tensor] = {} - if self.support_batch_quantize and attn_tensor is not None: - additional_inputs["attention_mask"] = attn_tensor + additional_inputs: Dict[str, torch.Tensor] = {} + if self.support_batch_quantize and attn_tensor is not None: + additional_inputs["attention_mask"] = attn_tensor - if position_ids: - pos = position_ids[batch_idx] - if pos is not None: - additional_inputs["position_ids"] = move_to(pos, device=cur_layer_device, stream=False) + if position_ids: + pos = position_ids[batch_idx] + if pos is not None: + additional_inputs["position_ids"] = move_to(pos, device=cur_layer_device, stream=False) - for key, value in layer_input_kwargs[batch_idx].items(): - additional_inputs[key] = nested_move_to(value, device=cur_layer_device, stream=False) + for key, value in layer_input_kwargs[batch_idx].items(): + additional_inputs[key] = nested_move_to(value, 
device=cur_layer_device, stream=False) - if reuse_kv and prev_kv is not None: - additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False) + if reuse_kv and prev_kv is not None: + additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False) - rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True) + rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True) - module_output = None - try: - if is_lm_head_module: - module_output = module(*layer_input) - else: - module_output = module(*layer_input, **additional_inputs) - except StopForward: module_output = None + try: + if is_lm_head_module: + module_output = module(*layer_input) + else: + module_output = module(*layer_input, **additional_inputs) + except StopForward: + module_output = None + finally: + self._set_processor_mask(processor, None) + + if ( + reuse_kv + and module_output is not None + and isinstance(module_output, tuple) + and len(module_output) > 0 + and shared_kv_cache_dict.get(layer_index) is None + ): + shared_kv_cache_dict[layer_index] = module_output[-1] + + if need_outputs and module_output is not None: + primary = module_output[0] if isinstance(module_output, tuple) else module_output + primary = move_to(primary, device=cur_layer_device, stream=False) + outputs.append([primary]) + + rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 + if rows_for_batch <= 0: + rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 + rows_for_batch = max(rows_for_batch, 1) + + processed_rows = min(processed_rows + rows_for_batch, total_rows) + if progress_pb is not None: + if progress_title: + progress_pb.title(progress_title) + progress_pb.current_iter_step = processed_rows + progress_pb.subtitle( + f"{stage_label} rows {processed_rows}/{total_rows}" + ).draw() finally: - self._set_processor_mask(processor, None) - - if ( - reuse_kv - and module_output is not None - and isinstance(module_output, tuple) - and len(module_output) > 0 - and shared_kv_cache_dict.get(layer_index) is None - ): - shared_kv_cache_dict[layer_index] = module_output[-1] - - if need_outputs and module_output is not None: - primary = module_output[0] if isinstance(module_output, tuple) else module_output - primary = move_to(primary, device=cur_layer_device, stream=False) - outputs.append([primary]) - - rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 - if rows_for_batch <= 0: - rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 - rows_for_batch = max(rows_for_batch, 1) - - processed_rows = min(processed_rows + rows_for_batch, total_rows) - if progress_pb is not None: - if progress_title: - progress_pb.title(progress_title) - progress_pb.current_iter_step = processed_rows - progress_pb.subtitle( - f"{stage_label} rows {processed_rows}/{total_rows}" - ).draw() + processor._set_current_batch_index(None) return outputs diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 419969fb4..b827d06d3 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -493,7 +493,7 @@ def pack_block( zeros: t.Tensor, g_idx: t.Tensor, block_in: int = 8192, - workers: int = 4, + workers: int = 1, ): """ Parallel qweight pack on CPU (threaded over 
input blocks). qzeros path = original logic. diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index e7692cd20..45906864e 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -10,7 +10,7 @@ import sys import threading import time -from typing import Optional +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -70,6 +70,13 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): module_device = get_device(self.module) setattr(self.module, "target_device", module_device) + if module_device.type == "meta": + self._default_hessian_device = torch.device("cpu") + else: + self._default_hessian_device = torch.device(module_device) + + self._hessian_device: Optional[torch.device] = None + self._validate_module(self.module) self.qcfg = qcfg if qcfg else QuantizeConfig() # HF compat will not pass qcfg @@ -89,6 +96,11 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): self.H = torch.zeros((self.columns, self.columns), dtype=torch.float32) + # Track per-batch Hessian contributions so they can be applied in a + # deterministic order even when forwards execute in parallel. + self._pending_updates: Dict[int, Tuple[int, Optional[torch.Tensor], torch.device]] = {} + self._next_batch_index: int = 0 + @staticmethod def _validate_module(module): assert isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, @@ -127,20 +139,53 @@ def _clone_module(self, copy=True, device: torch.device = None): return clone.float() - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None): with self.lock: self.fwd_counter += 1 - # print(f"self.module.target_device = {self.module.target_device}") + batch_token_size, xtx, device = self.process_batch(inp) + + pending_index = batch_index if batch_index is not None else self._next_batch_index + self._pending_updates[pending_index] = (batch_token_size, xtx, device) + self._flush_pending_updates_locked() + + def _resolve_hessian_device(self, batch_device: torch.device) -> torch.device: + """Select a stable device for Hessian accumulation. + + The first non-meta device we observe (module target, default hint, or + batch input) becomes the canonical Hessian device for the lifetime of + this GPTQ instance. Subsequent batches keep using the same target to + avoid bouncing tensors across GPUs when calibration runs on multiple + devices concurrently. 
+ """ + + if self._hessian_device is not None: + return self._hessian_device + + module_target = getattr(self.module, "target_device", None) + canonical = None + + if module_target is not None: + canonical = torch.device(module_target) + if canonical.type == "meta": + canonical = None + + if canonical is None and hasattr(self, "_default_hessian_device"): + canonical = self._default_hessian_device + + if canonical is None or canonical.type == "meta": + canonical = batch_device - self.process_batch(inp) + if canonical.type == "meta": + canonical = torch.device("cpu") - def process_batch(self, inp: torch.Tensor): + self._hessian_device = canonical + return canonical + + def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]: # print(f"inp = {inp}") # print(f"self.module = {self.module} device = {self.module.target_device}") inp_device = get_device(inp) - if inp_device.type == "cuda": - torch.cuda.set_device(inp_device) #inp = inp.to(device=self.module.target_device, dtype=torch.float32) @@ -182,26 +227,48 @@ def process_batch(self, inp: torch.Tensor): # Delay float32 cast until after reshaping to avoid an extra temporary tensor reshaped_inp = reshaped_inp.to(dtype=torch.float32) + canonical_device = self._resolve_hessian_device(inp_device) + reshaped_inp = reshaped_inp.to(dtype=torch.float64) batch_token_size = reshaped_inp.shape[0] - self.H = self.H.to(device=reshaped_inp.device) - - # moe model may receive an empty batch, return early if batch_token_size == 0: - return batch_token_size, reshaped_inp, 0, 0 + del reshaped_inp + return 0, None, canonical_device + + xtx = torch.matmul(reshaped_inp.T, reshaped_inp).to(dtype=torch.float32) + xtx = xtx.detach() + del reshaped_inp + + return batch_token_size, xtx, canonical_device + + def _flush_pending_updates_locked(self) -> None: + while True: + update = self._pending_updates.pop(self._next_batch_index, None) + if update is None: + break - beta = self.nsamples / (self.nsamples + batch_token_size) - alpha = 2.0 / (self.nsamples + batch_token_size) + batch_token_size, xtx, device = update - self.H.addmm_(reshaped_inp.T, reshaped_inp, beta=beta, alpha=alpha) + if batch_token_size > 0 and xtx is not None: + target_device = device if device is not None else self.H.device + if target_device is None: + target_device = self.H.device - # update number of collected samples - self.nsamples += batch_token_size + self.H = self.H.to(device=target_device) + if xtx.device != target_device: + xtx = xtx.to(device=target_device) - # inp returned here is flattened/reshaped original inp - # return batch_token_size, reshaped_inp, alpha, beta - del batch_token_size, reshaped_inp, alpha, beta + total = self.nsamples + batch_token_size + beta = self.nsamples / total + alpha = 2.0 / total + self.H.mul_(beta) + self.H.add_(xtx, alpha=alpha) + self.nsamples = total + + del xtx + + self._next_batch_index += 1 # FIXME, optimum needs fasterquant, we need to remove it def fasterquant( @@ -278,6 +345,13 @@ def quantize( # log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`") start = time.time() + with self.lock: + self._flush_pending_updates_locked() + if self._pending_updates: + raise RuntimeError( + f"Pending Hessian updates remain for module '{self.name}' before quantization." 
+ ) + # Temporarily disable torch.compile due to compatibility issues with torch 2.8 # Will re-enable once the issue is fixed # if not TORCH_GTE_28 and not self.qcfg.mock_quantization: diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index db6b0584a..edcb71edb 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -337,6 +337,7 @@ def forward_batch_worker( reuse_kv: bool, prev_kv, ): + processor._set_current_batch_index(batch_index) module_device = getattr(module, "_gptqmodule_device_hint", None) or get_device(module) rehome_module_to_device(module, module_device, move_parameters=True, move_buffers=True) @@ -380,6 +381,7 @@ def forward_batch_worker( finally: if mask_tls is not None: mask_tls.value = None + processor._set_current_batch_index(None) if reuse_kv and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0: kv_next = module_output[-1] diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 0fb6f0eea..bb70ab5a4 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -618,11 +618,34 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa active_backend = backend if backend is not None else self.LOAD_BACKEND + default_device_map = {"": "cpu"} if active_backend == BACKEND.TORCH_FUSED else "auto" + explicit_device = "device" in load_kwargs + inserted_device_map = False + if "device_map" not in load_kwargs and not explicit_device: + load_kwargs["device_map"] = default_device_map + inserted_device_map = True + + # Post-quant CI runs may expose multiple GPUs; pin loading to the first one to avoid spread-out auto maps. + if ( + (inserted_device_map or load_kwargs.get("device_map") == "auto") + and not explicit_device + and active_backend != BACKEND.TORCH_FUSED + and torch.cuda.is_available() + ): + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + candidates = [item.strip() for item in visible.split(",") if item.strip()] + try: + multi_device = len(candidates) > 1 if candidates else torch.cuda.device_count() > 1 + except Exception: + multi_device = False + + if multi_device: + load_kwargs["device_map"] = {"": "cuda:0"} + model = GPTQModel.load( model_id_or_path, trust_remote_code=trust_remote_code, backend=active_backend, - device_map={"": "cpu"} if active_backend == BACKEND.TORCH_FUSED else "auto", debug=self.DEBUG, adapter=self.EORA, **load_kwargs diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index c5005b9f7..116ad72ba 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -260,7 +260,7 @@ def _capture_batches(storage: Dict[str, List[Dict[str, float]]], primary_handles original_add_batch = GPTQ.add_batch - def wrapped_add_batch(self, inp, out): # type: ignore[override] + def wrapped_add_batch(self, inp, out, batch_index=None): # type: ignore[override] module_name = getattr(self, "name", "") before = getattr(self, "nsamples", 0) # Summaries calculated before running original implementation @@ -270,7 +270,7 @@ def wrapped_add_batch(self, inp, out): # type: ignore[override] sum_value = float("nan") device = str(getattr(inp, "device", "unknown")) - original_add_batch(self, inp, out) + original_add_batch(self, inp, out, batch_index=batch_index) after = getattr(self, "nsamples", 0) storage.setdefault(module_name, []).append( @@ -281,6 +281,7 @@ def wrapped_add_batch(self, inp, out): # type: ignore[override] "handle": 
hex(id(self)), "device": device, "is_primary": hex(id(self)) == primary_handles.get(module_name), + "batch_index": None if batch_index is None else int(batch_index), } ) From 60e40f5a6c5fc3d5ea805b81bec54670261530a8 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 17:04:39 +0000 Subject: [PATCH 6/8] reduce memory usage with h buffer Signed-off-by: Qubitium --- gptqmodel/quantization/config.py | 17 ++++ gptqmodel/quantization/gptq.py | 157 ++++++++++++++++++++++++++++++- tests/test_hessian_chunk.py | 55 ++++++++++- 3 files changed, 224 insertions(+), 5 deletions(-) diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index f86fdf674..708f71593 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -229,6 +229,11 @@ class QuantizeConfig(): # skip all heavy computations for testing model loading mock_quantization: bool = field(default=False, metadata={"help": "Skip heavy computations for fast model loading validation"}) + # Hessian accumulation controls (GPTQ only) + hessian_chunk_size: Optional[int] = field(default=None, metadata={"help": "Maximum rows per Hessian chunk"}) + hessian_chunk_bytes: Optional[int] = field(default=None, metadata={"help": "Memory budget (in bytes) for Hessian chunk staging"}) + hessian_use_bfloat16_staging: bool = field(default=False, metadata={"help": "Stage Hessian chunks in bfloat16 when supported"}) + def __post_init__(self): fields_info = fields(self) @@ -304,6 +309,18 @@ def __post_init__(self): if self.damp_auto_increment < 0: raise ValueError("QuantizeConfig:: `damp_auto_increment` must greater than 0.") + if self.hessian_chunk_size is not None: + if not isinstance(self.hessian_chunk_size, int): + raise ValueError("QuantizeConfig: `hessian_chunk_size` must be an integer or None.") + if self.hessian_chunk_size <= 0: + raise ValueError("QuantizeConfig: `hessian_chunk_size` must be a positive integer.") + + if self.hessian_chunk_bytes is not None: + if not isinstance(self.hessian_chunk_bytes, int): + raise ValueError("QuantizeConfig: `hessian_chunk_bytes` must be an integer or None.") + if self.hessian_chunk_bytes <= 0: + raise ValueError("QuantizeConfig: `hessian_chunk_bytes` must be a positive integer amount of bytes.") + # validate hybrid act order if self.act_group_aware and self.desc_act: raise ValueError("QuantizeConfig:: `act_group_aware` == `True` requires `desc_act` == `False`.") diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 45906864e..348eb1d5c 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -5,6 +5,7 @@ # adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq) +import contextlib import math import os import sys @@ -31,6 +32,74 @@ lock = threading.Lock() +# Shared workspaces are cached globally per (device, dtype, columns) so that +# concurrent GPTQ instances reuse temporary buffers instead of repeatedly +# allocating large tensors during Hessian accumulation. 
+_WORKSPACE_CACHE: Dict[Tuple[str, Optional[int], torch.dtype, int], torch.Tensor] = {} +_WORKSPACE_LOCKS: Dict[Tuple[str, Optional[int], torch.dtype, int], threading.Lock] = {} +_BF16_SUPPORT_CACHE: Dict[Tuple[str, Optional[int]], bool] = {} + + +def _device_cache_key(device: torch.device) -> Tuple[str, Optional[int]]: + dev = torch.device(device) + return dev.type, dev.index + + +def _workspace_cache_key(device: torch.device, dtype: torch.dtype, cols: int) -> Tuple[str, Optional[int], torch.dtype, int]: + dev = torch.device(device) + return dev.type, dev.index, dtype, cols + + +def _needs_workspace_resize(workspace: Optional[torch.Tensor], required_rows: int, cols: int) -> bool: + if workspace is None: + return True + if workspace.ndim != 2: + return True + if workspace.shape[1] != cols: + return True + if workspace.shape[0] < required_rows: + return True + return False + + +@contextlib.contextmanager +def _lease_workspace(device: torch.device, dtype: torch.dtype, cols: int, required_rows: int): + key = _workspace_cache_key(device, dtype, cols) + lock = _WORKSPACE_LOCKS.setdefault(key, threading.Lock()) + with lock: + workspace = _WORKSPACE_CACHE.pop(key, None) + if _needs_workspace_resize(workspace, required_rows, cols): + rows = max(required_rows, 1) + workspace = torch.empty((rows, cols), dtype=dtype, device=device) + try: + yield workspace + finally: + with lock: + _WORKSPACE_CACHE[key] = workspace + + +def _device_supports_bfloat16(device: torch.device) -> bool: + cache_key = _device_cache_key(device) + cached = _BF16_SUPPORT_CACHE.get(cache_key) + if cached is not None: + return cached + + dev = torch.device(device) + if dev.type == "meta": + _BF16_SUPPORT_CACHE[cache_key] = False + return False + + try: + a = torch.zeros((1, 1), dtype=torch.bfloat16, device=dev) + b = torch.zeros((1, 1), dtype=torch.bfloat16, device=dev) + _ = torch.matmul(a, b) + support = True + except Exception: + support = False + + _BF16_SUPPORT_CACHE[cache_key] = support + return support + def get_number_of_rows_and_cols(layer: nn.Module): # return layer.weight.shape[0], np.prod(layer.weight.shape[1:]) @@ -149,6 +218,87 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[ self._pending_updates[pending_index] = (batch_token_size, xtx, device) self._flush_pending_updates_locked() + def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype: + device = torch.device(device) + + if not self.qcfg.hessian_use_bfloat16_staging: + return torch.float32 + + if input_dtype not in (torch.float16, torch.bfloat16): + return torch.float32 + + if not _device_supports_bfloat16(device): + return torch.float32 + + return torch.bfloat16 + + def _resolve_hessian_chunk_size(self, rows: int, stage_dtype: torch.dtype) -> Optional[int]: + if rows == 0: + return None + + cfg_chunk = self.qcfg.hessian_chunk_size + if cfg_chunk is not None: + return max(1, min(cfg_chunk, rows)) + + bytes_budget = self.qcfg.hessian_chunk_bytes + if bytes_budget is not None: + bytes_per_row = self.columns * torch.tensor([], dtype=stage_dtype).element_size() + if bytes_per_row > 0: + chunk_rows = bytes_budget // bytes_per_row + if chunk_rows > 0: + return max(1, min(int(chunk_rows), rows)) + return 1 + + return None + + @contextlib.contextmanager + def _borrow_materialized_chunk_fp32( + self, + chunk: torch.Tensor, + rows: int, + ) -> torch.Tensor: + if rows == 0: + yield chunk.new_zeros((0, self.columns), dtype=torch.float32) + return + + device = chunk.device + stage_dtype = 
self._preferred_staging_dtype(chunk.dtype, device) + + with _lease_workspace(device, stage_dtype, self.columns, rows) as staging_workspace: + staging_view = staging_workspace[:rows, :] + staging_view.copy_(chunk.to(dtype=stage_dtype)) + + if stage_dtype == torch.float32: + yield staging_view + else: + with _lease_workspace(device, torch.float32, self.columns, rows) as fp32_workspace: + fp32_view = fp32_workspace[:rows, :] + fp32_view.copy_(staging_view.to(torch.float32)) + yield fp32_view + + def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor: + rows = matrix.shape[0] + if rows == 0: + return torch.zeros((self.columns, self.columns), dtype=torch.float64, device=matrix.device) + + stage_dtype = self._preferred_staging_dtype(matrix.dtype, matrix.device) + chunk_size = self._resolve_hessian_chunk_size(rows, stage_dtype) + + if chunk_size is None: + mat64 = matrix.to(dtype=torch.float64) + return torch.matmul(mat64.T, mat64) + + xtx_accum = torch.zeros((self.columns, self.columns), dtype=torch.float64, device=matrix.device) + + for start in range(0, rows, chunk_size): + rows_this = min(chunk_size, rows - start) + source = matrix[start:start + rows_this] + with self._borrow_materialized_chunk_fp32(source, rows_this) as materialized: + materialized64 = materialized.to(dtype=torch.float64) + xtx_accum.add_(torch.matmul(materialized64.T, materialized64)) + + return xtx_accum + def _resolve_hessian_device(self, batch_device: torch.device) -> torch.device: """Select a stable device for Hessian accumulation. @@ -225,10 +375,9 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], reshaped_inp = unfold(reshaped_inp) reshaped_inp = reshaped_inp.transpose(1, 2).flatten(0, 1) - # Delay float32 cast until after reshaping to avoid an extra temporary tensor - reshaped_inp = reshaped_inp.to(dtype=torch.float32) + # Delay dtype conversion until we materialize Hessian chunks to avoid unnecessary temporaries + reshaped_inp = reshaped_inp.contiguous() canonical_device = self._resolve_hessian_device(inp_device) - reshaped_inp = reshaped_inp.to(dtype=torch.float64) batch_token_size = reshaped_inp.shape[0] @@ -236,7 +385,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], del reshaped_inp return 0, None, canonical_device - xtx = torch.matmul(reshaped_inp.T, reshaped_inp).to(dtype=torch.float32) + xtx = self._compute_hessian_xtx(reshaped_inp).to(dtype=torch.float32) xtx = xtx.detach() del reshaped_inp diff --git a/tests/test_hessian_chunk.py b/tests/test_hessian_chunk.py index 4b4a29d5b..0c1acba01 100644 --- a/tests/test_hessian_chunk.py +++ b/tests/test_hessian_chunk.py @@ -3,11 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import contextlib import math import statistics import time import tracemalloc import types +from concurrent.futures import ThreadPoolExecutor from typing import Callable, Dict, Iterable, List, Tuple import pytest @@ -40,9 +42,11 @@ def _clone_module(module: torch.nn.Module) -> torch.nn.Module: def _instrument_chunks(gptq: GPTQ) -> None: original = gptq._borrow_materialized_chunk_fp32 + @contextlib.contextmanager def wrapped(self, chunk, rows): self._chunk_invocations += 1 - return original(chunk, rows) + with original(chunk, rows) as materialized: + yield materialized gptq._chunk_invocations = 0 gptq._borrow_materialized_chunk_fp32 = types.MethodType(wrapped, gptq) @@ -140,6 +144,55 @@ def test_hessian_chunk_bytes_budget(): assert workspace.shape[1] == 
gptq.columns +@pytest.mark.cuda +def test_hessian_workspace_thread_safety_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for workspace stress test") + + device = torch.device("cuda", 0) + + base = torch.nn.Linear(128, 64, bias=False).to(device) + cfg = QuantizeConfig( + hessian_chunk_size=128, + hessian_use_bfloat16_staging=True, + ) + + gptq_workers = [GPTQ(_clone_module(base).to(device), cfg) for _ in range(3)] + rows = 512 + iters_per_worker = 6 + + def worker(task_id: int) -> None: + gptq = gptq_workers[task_id % len(gptq_workers)] + torch.cuda.set_device(device.index or 0) + for i in range(iters_per_worker): + calib = torch.randn(rows, base.in_features, device=device, dtype=torch.float16) + batch_size, xtx, canonical_device = gptq.process_batch(calib) + assert batch_size == rows + assert xtx is not None + assert canonical_device == gptq._hessian_device + + with ThreadPoolExecutor(max_workers=8) as pool: + futures = [pool.submit(worker, idx) for idx in range(16)] + for fut in futures: + fut.result() + + for gptq in gptq_workers: + assert gptq._hessian_device == device + + cols = base.in_features + fp32_key = gptq_impl._workspace_cache_key(device, torch.float32, cols) + assert fp32_key in gptq_impl._WORKSPACE_CACHE + fp32_workspace = gptq_impl._WORKSPACE_CACHE[fp32_key] + expected_rows = cfg.hessian_chunk_size or rows + assert fp32_workspace.shape[0] >= expected_rows + assert fp32_workspace.shape[1] == cols + + stage_dtype = gptq_workers[0]._preferred_staging_dtype(torch.float16, device) + if stage_dtype == torch.bfloat16: + bf16_key = gptq_impl._workspace_cache_key(device, torch.bfloat16, cols) + assert bf16_key in gptq_impl._WORKSPACE_CACHE + + def _benchmark_case( base_module: torch.nn.Module, cfg_factory: Callable[[], QuantizeConfig], From 882b170fdfc90cf82606af2a30b28476d6c88831 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 17:05:25 +0000 Subject: [PATCH 7/8] format Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/__init__.py | 2 +- gptqmodel/utils/env.py | 1 + gptqmodel/utils/looper_helpers.py | 3 ++- tests/models/test_multi_vs_single_gpu.py | 13 +++++-------- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index b827d06d3..589417a44 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -18,8 +18,8 @@ from ...adapter.adapter import LORA_MERGED_WEIGHT_PATHS, Adapter from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND -from ...utils.logger import setup_logger from ...utils.env import env_flag +from ...utils.logger import setup_logger from ...utils.safe import THREADPOOLCTL diff --git a/gptqmodel/utils/env.py b/gptqmodel/utils/env.py index 2b1e85535..e2839083c 100644 --- a/gptqmodel/utils/env.py +++ b/gptqmodel/utils/env.py @@ -9,6 +9,7 @@ import os + _TRUTHY = {"1", "true", "yes", "on", "y"} diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index edcb71edb..b2158e561 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -16,13 +16,14 @@ from .. 
import DEBUG_ON, DEVICE_THREAD_POOL from ..nn_modules.hooked_linear import StopForward from ..utils.attn_mask import normalize_seq_mask -from ..utils.env import env_flag from ..utils.device import get_device +from ..utils.env import env_flag from ..utils.logger import setup_logger from ..utils.model import move_to, nested_move_to from ..utils.safe import ThreadSafe from ..utils.torch import ALL_DEVICES, CPU, torch_sync + USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE") diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index 116ad72ba..f9eecc11e 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -14,20 +14,19 @@ from unittest import mock import torch +from model_test import ModelTest from gptqmodel import GPTQModel +from gptqmodel.looper.module_looper import StopMainLoop from gptqmodel.models.writer import ( PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES, ) -from gptqmodel.looper.module_looper import StopMainLoop from gptqmodel.quantization.config import QuantizeConfig from gptqmodel.utils.torch import torch_empty_cache -from model_test import ModelTest - @dataclass(frozen=True) class LayerMetrics: @@ -112,7 +111,8 @@ def _quantize_first_layer( self, device_indices: Iterable[int] ) -> Tuple[Dict[str, LayerMetrics], Dict[str, List[Dict[str, float]]]]: target_devices = [torch.device(f"cuda:{idx}") for idx in device_indices] - selection = lambda _base_device: target_devices + def selection(_base_device): + return target_devices class _StopAfterLayer: def __init__(self, layer_idx: int): @@ -235,10 +235,7 @@ def _summarize(stats: Dict[str, List[Dict[str, float]]]) -> Dict[str, Dict[str, info["samples"] += item["after"] - item["before"] info["sum_hash"] += item["sum"] info["primary"] = info["primary"] or bool(item.get("is_primary", False)) - summary[name] = { - handle: values - for handle, values in sorted(per_handle.items(), key=lambda kv: kv[0]) - } + summary[name] = dict(sorted(per_handle.items(), key=lambda kv: kv[0])) return summary single_summary = _summarize(single_batch_stats) From 9aaff886a3fcfe6c909c83546be1c84ad593d69d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 17:18:44 +0000 Subject: [PATCH 8/8] extra safety Signed-off-by: Qubitium --- gptqmodel/quantization/gptq.py | 16 ++++++++++++---- gptqmodel/utils/looper_helpers.py | 4 +++- tests/models/test_llama3_2.py | 4 ++-- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 348eb1d5c..ef8a957d1 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -269,12 +269,20 @@ def _borrow_materialized_chunk_fp32( staging_view.copy_(chunk.to(dtype=stage_dtype)) if stage_dtype == torch.float32: - yield staging_view + try: + yield staging_view + finally: + if device.type == "cuda": + torch.cuda.current_stream(device).synchronize() else: with _lease_workspace(device, torch.float32, self.columns, rows) as fp32_workspace: - fp32_view = fp32_workspace[:rows, :] - fp32_view.copy_(staging_view.to(torch.float32)) - yield fp32_view + try: + fp32_view = fp32_workspace[:rows, :] + fp32_view.copy_(staging_view.to(torch.float32)) + yield fp32_view + finally: + if device.type == "cuda": + torch.cuda.current_stream(device).synchronize() def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor: rows = matrix.shape[0] diff --git a/gptqmodel/utils/looper_helpers.py 
b/gptqmodel/utils/looper_helpers.py index b2158e561..5e79876ae 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -28,6 +28,7 @@ _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel) +_DEEPCOPY_LOCK = threading.Lock() def torch_replicate( module: torch.nn.Module, @@ -310,7 +311,8 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None: for dev in devices: start_ts = time.perf_counter() - replica = copy.deepcopy(module) + with _DEEPCOPY_LOCK: + replica = copy.deepcopy(module) replica.eval() rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True) clear_state_fn(replica) diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 51476212b..ae0f6152e 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -18,8 +18,8 @@ class TestLlama3_2(ModelTest): APPLY_CHAT_TEMPLATE = True V2 = False DEBUG = True - ACT_GROUP_AWARE = False - DESC_ACT = True + ACT_GROUP_AWARE = True + DESC_ACT = False DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4
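Editor's note on patch 6/8 ("reduce memory usage with h buffer"): the chunked path in _compute_hessian_xtx relies on the identity that X^T X equals the sum of the chunk-wise Gram matrices, accumulated in float64, so only a chunk-sized float64 temporary ever exists instead of a full float64 copy of the calibration matrix. A minimal self-contained sketch of that invariant (names below are illustrative, not taken from the patch):

    import torch

    torch.manual_seed(0)
    rows, cols, chunk = 1000, 64, 256
    x = torch.randn(rows, cols, dtype=torch.float32)

    # reference: one-shot Gram matrix in float64
    x64 = x.to(torch.float64)
    full = x64.T @ x64

    # chunked: accumulate per-chunk Gram matrices, as the patched GPTQ path does
    acc = torch.zeros(cols, cols, dtype=torch.float64)
    for start in range(0, rows, chunk):
        part = x[start:start + chunk].to(torch.float64)
        acc += part.T @ part

    assert torch.allclose(full, acc)

When hessian_chunk_bytes is set instead of hessian_chunk_size, the chunk row count follows from bytes // (columns * element_size): for example, a 64 MiB budget with 4096 columns staged in bfloat16 (2 bytes per element) yields 67108864 // 8192 = 8192 rows per chunk.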
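Editor's note on patch 8/8 ("extra safety"): the new module-level _DEEPCOPY_LOCK serializes copy.deepcopy(module) while leaving the per-device rehoming outside the lock, so only the clone step - which walks shared parameters and buffers - runs one thread at a time. A minimal sketch of the same pattern (module and device names are illustrative):

    import copy
    import threading
    from concurrent.futures import ThreadPoolExecutor

    import torch

    _DEEPCOPY_LOCK = threading.Lock()
    shared_module = torch.nn.Linear(16, 16)

    def make_replica(device: str) -> torch.nn.Module:
        # only the deepcopy itself is serialized; the (potentially slow)
        # device move still runs in parallel across worker threads
        with _DEEPCOPY_LOCK:
            replica = copy.deepcopy(shared_module)
        return replica.to(device).eval()

    with ThreadPoolExecutor(max_workers=4) as pool:
        replicas = list(pool.map(make_replica, ["cpu"] * 4))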
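Editor's note, also on patch 8/8: because a leased staging workspace is returned to the shared cache as soon as the context manager exits, any asynchronous CUDA copies into that buffer must have completed before another thread can lease it; synchronizing the current stream on exit is what closes that window. A guarded sketch of the hand-off pattern being enforced (illustrative only, a no-op when CUDA is absent):

    import torch

    if torch.cuda.is_available():
        dev = torch.device("cuda", 0)
        buf = torch.empty(1 << 20, device=dev)
        src = torch.randn(1 << 20, device=dev)
        buf.copy_(src)  # enqueued on the current stream, async w.r.t. the host
        torch.cuda.current_stream(dev).synchronize()
        # only after the sync is it safe to publish buf to another thread or
        # park it back in a shared cache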