From 5f7d624c1fde9d1d6b323b887c79fedc3ad2fb86 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 1 Oct 2025 07:40:51 +0000
Subject: [PATCH 1/2] remove unused named_module.target_device_stream and
 pre_process_streaming() loop processor API

Signed-off-by: Qubitium
---
 gptqmodel/looper/eora_processor.py | 10 +---------
 gptqmodel/looper/gptq_processor.py |  9 +--------
 gptqmodel/looper/module_looper.py  | 36 +++++++++++-----------------------
 gptqmodel/looper/named_module.py   |  4 ++--
 gptqmodel/looper/qqq_processor.py  |  9 +--------
 gptqmodel/quantization/gptq.py     |  2 +-
 gptqmodel/quantization/qqq.py      |  2 +-
 gptqmodel/utils/torch.py           |  7 +++----
 tests/test_tf32_performance.py     |  1 +
 9 files changed, 24 insertions(+), 56 deletions(-)

diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py
index 088895013..218a45211 100644
--- a/gptqmodel/looper/eora_processor.py
+++ b/gptqmodel/looper/eora_processor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import contextlib
 import copy
 import time
 from typing import Callable, Dict, Optional, Tuple
@@ -106,15 +107,6 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
             )
         return tmp
 
-    def pre_process_streaming(self, module: NamedModule):
-        eigen_matrix = self.eigen_scaling_diag_matrix[module.name]
-        with torch_streamCtx(module.target_device_stream):
-            if eigen_matrix is not None:
-                self.eigen_scaling_diag_matrix[module.name] = eigen_matrix.to(device=module.target_device, non_blocking=True)
-
-            module.state["w_wq_diff"] = module.state["w_wq_diff"].to(device=module.target_device, non_blocking=True)
-            module.state["wq"] = module.state["wq"].to(device=module.target_device, non_blocking=True)
-
     def process(self, module: NamedModule):
         assert isinstance(module.adapter_cfg, Lora)
 
diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 409705d93..1a8f394ba 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import contextlib
 import copy
 import threading
 from typing import Callable, Optional, Tuple
@@ -108,14 +109,6 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             del inp, out
         return tmp
 
-    def pre_process_streaming(self, module: NamedModule):
-        g = self.tasks[module.name]
-        with torch_streamCtx(module.target_device_stream):
-            # log.debug(f"streaming module `{g.name}` to device = `{module.target_device}`")
-            if g.H is not None:
-                g.H = g.H.to(device=module.target_device, non_blocking=True)
-            g.module.weight.data = g.module.weight.data.to(device=module.target_device, non_blocking=True)
-
     def process(self, module: NamedModule):
         # Reset peak memory stats
         #torch.cuda.reset_peak_memory_stats()
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index dbc18b746..ecdd4b282 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -197,20 +197,16 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch
         module_label = getattr(module, "full_name", module.__class__.__name__)
         clone_timings = [] if DEBUG_ON else None
 
-        cleared_attrs = self._clear_non_picklable_state(module)
-        try:
-            for dev in devices:
-                start_ts = time.perf_counter() if DEBUG_ON else None
-                replica = copy.deepcopy(module)
-                replica = replica.to(dev)
-                replica.eval()
-                _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True)
-                self._clear_non_picklable_state(replica)
-                clones[dev] = replica
-                if clone_timings is not None and start_ts is not None:
-                    clone_timings.append((dev, time.perf_counter() - start_ts))
-        finally:
-            self._restore_non_picklable_state(cleared_attrs)
+        for dev in devices:
+            start_ts = time.perf_counter() if DEBUG_ON else None
+            replica = copy.deepcopy(module)
+            replica = replica.to(dev)
+            replica.eval()
+            _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True)
+            clones[dev] = replica
+            if clone_timings is not None and start_ts is not None:
+                clone_timings.append((dev, time.perf_counter() - start_ts))
+
         if clone_timings:
             timing_str = ", ".join(f"{str(dev)}={duration * 1000:.2f}ms" for dev, duration in clone_timings)
             log.debug(f"ModuleLooper: deepcopy {module_label} -> {timing_str}")
@@ -224,9 +220,6 @@ def maybe_clear(obj: torch.nn.Module):
             if id(obj) in seen:
                 return
             seen.add(id(obj))
-            if hasattr(obj, "target_device_stream"):
-                cleared.append((obj, getattr(obj, "target_device_stream")))
-                setattr(obj, "target_device_stream", None)
 
         if isinstance(module, torch.nn.Module):
             for sub in module.modules():
@@ -236,11 +229,6 @@ def maybe_clear(obj: torch.nn.Module):
 
         return cleared
 
-    @staticmethod
-    def _restore_non_picklable_state(cleared):
-        for obj, value in cleared:
-            setattr(obj, "target_device_stream", value)
-
     @staticmethod
     def _forward_batch_worker(
         module: torch.nn.Module,
@@ -828,7 +816,9 @@ def loop(self, fail_safe: bool = False, **kwargs):
                 for idx, (name, m) in enumerate(subset.items()):
                     is_last = (idx == subset_size - 1)
 
-                    m.module.target_device, m.module.target_device_stream = device_next()
+                    target_device = device_next()
+                    m.target_device = target_device
+                    m.module.target_device = target_device
 
                     # Wrap the processor hook with masking
                     if hasattr(subset[name], 'forward_hook'):
diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py
index e10120d30..a8c618a11 100644
--- a/gptqmodel/looper/named_module.py
+++ b/gptqmodel/looper/named_module.py
@@ -28,8 +28,8 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
         self.layer_index = layer_index # layerid in a repeating layer, if in outside layer, this info may be fake
 
         # some processing will move this module to target_device gptq, eora, etc
-        # self.target_device, self.target_device_stream = device_next()
-        self.target_device, self.target_device_stream = None, None
+        # self.target_device = device_next()
+        self.target_device = None
 
         # persistent work state for named module (used by some LoopProcessors)
diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py
index d1bc276c3..f140ab737 100644
--- a/gptqmodel/looper/qqq_processor.py
+++ b/gptqmodel/looper/qqq_processor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import contextlib
 import copy
 from typing import Callable, Optional, Tuple
 
@@ -107,14 +108,6 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             q.add_batch(inp[0].data, out.data)  # noqa: F821
         return tmp
 
-    def pre_process_streaming(self, module: NamedModule):
-        q = self.tasks[module.name]
-        with torch_streamCtx(module.target_device_stream):
-            if q.H is not None:
-                q.H = q.H.to(device=module.target_device, non_blocking=True)
-            module.weight.data = module.weight.data.to(device=module.target_device, non_blocking=True)
-
-
     def process(self, module: NamedModule):
         self.pb.title(f"Quantizing {module.name} in layer ").draw()
         qqq = self.tasks
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index f3ab4951c..894fadcb2 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -74,7 +74,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
         self.name = HF_OPTIMUM
 
         self.module = module
         # emulate NamedModule properties
-        self.module.target_device, self.module.target_device_stream = device_next()
+        self.module.target_device = device_next()
 
         self._validate_module(self.module)
diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py
index fe7bcc26e..e392cdc41 100644
--- a/gptqmodel/quantization/qqq.py
+++ b/gptqmodel/quantization/qqq.py
@@ -196,7 +196,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
         self.name = HF_OPTIMUM
         self.layer = module
         # emulate NamedModule properties
-        self.layer.target_device, self.layer.target_device_stream = device_next()
+        self.layer.target_device = device_next()
         self.dev = self.layer.weight.device
 
         self._validate_module(self.layer)
diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py
index bcda2b2ed..564b2f8f2 100644
--- a/gptqmodel/utils/torch.py
+++ b/gptqmodel/utils/torch.py
@@ -217,14 +217,13 @@ def device_next_reset():
     global NEXT_DEVICE_INDEX
     NEXT_DEVICE_INDEX = 0
 
-def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) -> (torch.device, Union[torch.cuda.Stream, torch.xpu.Stream]):
+def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) -> torch.device:
     global NEXT_DEVICE_INDEX
 
     if len(ALL_DEVICES) <= 1:
-        return ALL_DEVICES[0], ALL_STREAMS[0]
+        return ALL_DEVICES[0]
 
     device = ALL_DEVICES[NEXT_DEVICE_INDEX]
-    device_stream = ALL_STREAMS[NEXT_DEVICE_INDEX]
     if NEXT_DEVICE_INDEX < len(ALL_DEVICES) - 1:
         NEXT_DEVICE_INDEX += 1
     else:
@@ -233,7 +232,7 @@ def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) ->
         else:
             NEXT_DEVICE_INDEX = 0
 
-    return (device, device_stream)
+    return device
 
 def torch_streamCtx(stream: Union[torch.cuda.Stream, torch.xpu.Stream]) -> StreamContext:
     return torch.cuda.stream(stream) if HAS_CUDA else torch.xpu.stream(stream)
diff --git a/tests/test_tf32_performance.py b/tests/test_tf32_performance.py
index 48fbcf94a..3f37e588e 100644
--- a/tests/test_tf32_performance.py
+++ b/tests/test_tf32_performance.py
@@ -3,6 +3,7 @@
 
 from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard
 
+
 try:
     from tabulate import tabulate
 except ImportError:  # pragma: no cover

From 1c2fc10ad1b2ba2ee0f29b01978dca4044873cc1 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 1 Oct 2025 07:53:19 +0000
Subject: [PATCH 2/2] fix test: guard fwd_time log formatting when unset and
 add gptq device ctx test

Signed-off-by: Qubitium
---
 gptqmodel/looper/eora_processor.py |  2 +-
 gptqmodel/looper/gptq_processor.py |  2 +-
 gptqmodel/looper/loop_processor.py |  6 ++-
 gptqmodel/looper/qqq_processor.py  |  2 +-
 tests/test_gptq_device_ctx.py      | 84 ++++++++++++++++++++++++++++++
 5 files changed, 92 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_gptq_device_ctx.py

diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py
index 218a45211..ae2c5a7b7 100644
--- a/gptqmodel/looper/eora_processor.py
+++ b/gptqmodel/looper/eora_processor.py
@@ -177,7 +177,7 @@ def process(self, module: NamedModule):
             PROCESS_LOG_LAYER: module.layer_index,
             PROCESS_LOG_MODULE: module.name,
             PROCESS_LOG_TIME: f"{duration:.3f}",
-            PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
+            PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
             PROCESS_MAX_MEMORY: max_memory,
         }
 
diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 1a8f394ba..37a5a483d 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -165,7 +165,7 @@ def process(self, module: NamedModule):
             QUANT_LOG_NSAMPLES: f"{nsamples}",
             QUANT_LOG_DAMP: f"{damp_percent:.5f}",
             PROCESS_LOG_TIME: f"{duration:.3f}",
-            PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
+            PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
             PROCESS_MAX_MEMORY: get_max_memory(),
         }
 
diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index ff654a425..c0821b4bb 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -277,6 +277,10 @@ def set_calibration_dataset(self, calibration_dataset):
     def set_fwd_time(self, fwd_time: float):
         self.fwd_time = fwd_time
 
+    def formatted_fwd_time(self) -> str:
+        fwd_time = self.fwd_time if self.fwd_time is not None else 0.0
+        return f"{fwd_time:.3f}"
+
     # called first
     def preprocess(self, module: NamedModule, **kwargs):
         pass
@@ -353,4 +357,4 @@ def get_max_memory() -> str:
         max_memory = f"{active_0:.2f}MB, {active_1:.2f}MB"
     else:
         max_memory = f"{active_0:.2f}MB"
-    return max_memory
\ No newline at end of file
+    return max_memory
diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py
index f140ab737..f60eef4ba 100644
--- a/gptqmodel/looper/qqq_processor.py
+++ b/gptqmodel/looper/qqq_processor.py
@@ -149,7 +149,7 @@ def process(self, module: NamedModule):
             QUANT_LOG_NSAMPLES: f"{nsamples}",
             QUANT_LOG_DAMP: f"{damp_percent:.5f}",
             PROCESS_LOG_TIME: f"{duration:.3f}",
-            PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
+            PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
         }
 
         if self.qcfg.dynamic is not None:
diff --git a/tests/test_gptq_device_ctx.py b/tests/test_gptq_device_ctx.py
new file mode 100644
index 000000000..4d57e5a5b
--- /dev/null
+++ b/tests/test_gptq_device_ctx.py
@@ -0,0 +1,84 @@
+import concurrent.futures
+import os
+import sys
+from typing import Dict, List
+
+import pytest
+import torch
+
+from gptqmodel.looper.gptq_processor import GPTQProcessor
+from gptqmodel.looper.named_module import NamedModule
+from gptqmodel.quantization.config import QuantizeConfig
+
+
+def _dummy_prepare_dataset(*, calibration_dataset, calibration_dataset_concat_size, calibration_dataset_sort, batch_size):
+    return calibration_dataset
+
+
+class _DummyProgressBar:
+    def title(self, _):
+        return self
+
+    def draw(self):
+        return None
+
+
+def _is_free_threaded() -> bool:
+    gil_check = getattr(sys, "_is_gil_enabled", None)
+    if callable(gil_check):
+        return not gil_check()
+    env_value = os.environ.get("PYTHON_GIL", "1").lower()
+    return env_value in {"0", "false", "off"}
+
+
+def _run_quant_on_device(device_index: int) -> torch.device:
+    torch.cuda.set_device(device_index)
+    target = torch.device(f"cuda:{device_index}")
+    module = torch.nn.Linear(8, 8, bias=False).to(target)
+    named = NamedModule(module, name=f"linear_{device_index}", full_name=f"model.layers.{device_index}.linear", layer_index=device_index)
+
+    qcfg = QuantizeConfig(mock_quantization=True, group_size=-1, desc_act=False)
+    processor = GPTQProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=None,
+        prepare_dataset_func=_dummy_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+        logger_board="",
+        require_fwd=False,
+        calculate_w_wq_diff=False,
+    )
+    processor.pb = _DummyProgressBar()
+
+    processor.preprocess(named, fail_safe=False)
+    named.module.target_device = target
+
+    processor.process(named)
+
+    return named.weight.data.device
+
+
+#@pytest.mark.cuda
+def test_gptq_quantize_keeps_weight_on_assigned_device_multigpu_free_thread():
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required for multi-GPU device context test")
+
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Requires at least 8 CUDA devices")
+
+    if sys.version_info < (3, 13):
+        pytest.skip("Requires Python 3.13 free-threading build")
+
+    if not _is_free_threaded():
+        pytest.skip("Requires PYTHON_GIL=0 (free-threading)")
+
+    device_indices: List[int] = list(range(8))
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=len(device_indices)) as pool:
+        futures = [pool.submit(_run_quant_on_device, idx) for idx in device_indices]
+        results: Dict[int, torch.device] = {idx: future.result() for idx, future in zip(device_indices, futures)}
+
+    for idx, device in results.items():
+        assert device.type == "cuda" and device.index == idx