From 93d331894620d6c9e9b41646942d369c426284d6 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 08:44:48 +0000 Subject: [PATCH 1/3] rehome params + skip deepcopy for single device Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index d67ab8452..07c36c608 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -197,13 +197,22 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch module_label = getattr(module, "full_name", module.__class__.__name__) clone_timings = [] if DEBUG_ON else None + if len(devices) == 1: + target_device = devices[0] + module = module.to(target_device) + module.eval() + _rehome_module_to_device(module, target_device, move_parameters=True, move_buffers=True) + self._clear_non_picklable_state(module) + clones[target_device] = module + return clones + self._clear_non_picklable_state(module) for dev in devices: start_ts = time.perf_counter() if DEBUG_ON else None replica = copy.deepcopy(module) replica = replica.to(dev) replica.eval() - _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True) + _rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True) self._clear_non_picklable_state(replica) clones[dev] = replica if clone_timings is not None and start_ts is not None: @@ -255,7 +264,7 @@ def _forward_batch_worker( queue. """ module_device = get_device(module) - _rehome_module_to_device(module, module_device, move_parameters=False, move_buffers=True) + _rehome_module_to_device(module, module_device, move_parameters=True, move_buffers=True) inputs = [move_to(inp, device=module_device, stream=False) for inp in layer_input] @@ -408,7 +417,7 @@ def _run_forward_batches_single( if reuse_kv and prev_kv is not None: additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False) - _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) + _rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True) module_output = None try: From d4d0288c5a0c8c85ae337b08b05e27c65051fa31 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 09:38:22 +0000 Subject: [PATCH 2/3] faster clone Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 43 ++++++++- tests/test_torch_replicate.py | 154 ++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 tests/test_torch_replicate.py diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 07c36c608..42ca3edc7 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional import torch +from torch.nn.parallel import replicate as torch_replicate from ..looper.dequantize_processor import DequantizeProcessor from ..looper.eora_processor import EoraProcessor @@ -203,10 +204,49 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch module.eval() _rehome_module_to_device(module, target_device, move_parameters=True, move_buffers=True) self._clear_non_picklable_state(module) + setattr(module, "_gptqmodule_device_hint", target_device) clones[target_device] = module return clones self._clear_non_picklable_state(module) + primary_device = devices[0] + replicate_devices = [dev for dev 
in devices if dev.type == primary_device.type] + use_replicate = ( + len(replicate_devices) == len(devices) + and primary_device.type == "cuda" + and torch.cuda.is_available() + ) + + if use_replicate: + try: + module = module.to(primary_device) + module.eval() + _rehome_module_to_device(module, primary_device, move_parameters=True, move_buffers=True) + setattr(module, "_gptqmodule_device_hint", primary_device) + replicate_start = time.perf_counter() if DEBUG_ON else None + replicas = torch_replicate(module, devices) + if clone_timings is not None and replicate_start is not None: + clone_timings.append(("replicate", time.perf_counter() - replicate_start)) + + for dev, replica in zip(devices, replicas): + replica.eval() + _rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True) + self._clear_non_picklable_state(replica) + setattr(replica, "_gptqmodule_device_hint", dev) + clones[dev] = replica + + if clone_timings: + timing_str = ", ".join( + f"{str(dev)}={duration * 1000:.2f}ms" for dev, duration in clone_timings + ) + log.debug(f"ModuleLooper: replicate {module_label} -> {timing_str}") + return clones + except Exception: + log.info("Clone: fast clone failed") + # Fall back to deepcopy path if replicate is unsupported for this module. + if clone_timings is not None: + clone_timings.append(("replicate_failed", 0.0)) + for dev in devices: start_ts = time.perf_counter() if DEBUG_ON else None replica = copy.deepcopy(module) @@ -214,6 +254,7 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch replica.eval() _rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True) self._clear_non_picklable_state(replica) + setattr(replica, "_gptqmodule_device_hint", dev) clones[dev] = replica if clone_timings is not None and start_ts is not None: clone_timings.append((dev, time.perf_counter() - start_ts)) @@ -263,7 +304,7 @@ def _forward_batch_worker( update. The thin signature keeps the function pickleable for the worker queue. 
""" - module_device = get_device(module) + module_device = getattr(module, "_gptqmodule_device_hint", None) or get_device(module) _rehome_module_to_device(module, module_device, move_parameters=True, move_buffers=True) inputs = [move_to(inp, device=module_device, stream=False) for inp in layer_input] diff --git a/tests/test_torch_replicate.py b/tests/test_torch_replicate.py new file mode 100644 index 000000000..9aacf9e01 --- /dev/null +++ b/tests/test_torch_replicate.py @@ -0,0 +1,154 @@ +import copy +from typing import Callable + +import pytest +import torch +from tabulate import tabulate +from torch.nn.parallel import replicate as torch_replicate + + +TIMED_TRIALS = 5 +WARMUP_TRIALS = 1 + + +def _build_template_module() -> torch.nn.Module: + torch.manual_seed(0) + return torch.nn.Sequential( + torch.nn.Linear(4096, 4096, bias=False), + torch.nn.GELU(), + torch.nn.Linear(4096, 4096, bias=False), + ) + + +def _replicate_strategy(module: torch.nn.Module, devices: list[torch.device]) -> list[torch.nn.Module]: + return torch_replicate(module, devices) + + +def _deepcopy_strategy(module: torch.nn.Module, devices: list[torch.device]) -> list[torch.nn.Module]: + clones = [] + for dev in devices: + replica = copy.deepcopy(module) + clones.append(replica.to(dev)) + return clones + + +def _benchmark( + strategy: Callable[[torch.nn.Module, list[torch.device]], list[torch.nn.Module]], + devices: list[torch.device], + template: torch.nn.Module, + *, + trials: int = TIMED_TRIALS, + warmup: int = WARMUP_TRIALS, +) -> tuple[list[float], list[int]]: + times: list[float] = [] + mems: list[int] = [] + + def _run(record: bool) -> None: + module = copy.deepcopy(template).to(devices[0]).eval() + torch.cuda.synchronize() + + baselines = {} + for dev in devices: + baselines[dev] = torch.cuda.memory_allocated(dev) + torch.cuda.reset_peak_memory_stats(dev) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + clones = strategy(module, devices) + end_event.record() + torch.cuda.synchronize() + + if record: + duration = start_event.elapsed_time(end_event) / 1000.0 + extra_mem = 0 + for dev in devices: + peak = torch.cuda.max_memory_allocated(dev) + extra_mem += max(0, peak - baselines[dev]) + + times.append(duration) + mems.append(extra_mem) + + del clones + del module + torch.cuda.empty_cache() + torch.cuda.synchronize() + + for _ in range(warmup): + _run(record=False) + for _ in range(trials): + _run(record=True) + + return times, mems + + +def _summarise_metrics(times: list[float], mems: list[int]): + avg_time = sum(times) / len(times) + avg_mem = sum(mems) / len(mems) + return { + "time_avg": avg_time, + "time_min": min(times), + "time_max": max(times), + "mem_avg": avg_mem, + "mem_min": min(mems), + "mem_max": max(mems), + } + + +@pytest.mark.cuda +def test_torch_replicate(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("torch.nn.parallel.replicate comparison requires at least two CUDA devices") + + devices = [torch.device(f"cuda:{idx}") for idx in range(2)] + template = _build_template_module() + + replicate_times, replicate_mems = _benchmark(_replicate_strategy, devices, template) + deepcopy_times, deepcopy_mems = _benchmark(_deepcopy_strategy, devices, template) + + replicate_summary = _summarise_metrics(replicate_times, replicate_mems) + deepcopy_summary = _summarise_metrics(deepcopy_times, deepcopy_mems) + + table = [ + [ + "replicate", + replicate_summary["time_avg"], + 
replicate_summary["time_min"], + replicate_summary["time_max"], + replicate_summary["mem_avg"] / (1024 ** 2), + replicate_summary["mem_min"] / (1024 ** 2), + replicate_summary["mem_max"] / (1024 ** 2), + ], + [ + "deepcopy", + deepcopy_summary["time_avg"], + deepcopy_summary["time_min"], + deepcopy_summary["time_max"], + deepcopy_summary["mem_avg"] / (1024 ** 2), + deepcopy_summary["mem_min"] / (1024 ** 2), + deepcopy_summary["mem_max"] / (1024 ** 2), + ], + ] + + headers = [ + "strategy", + "time_avg_s", + "time_min_s", + "time_max_s", + "mem_avg_MB", + "mem_min_MB", + "mem_max_MB", + ] + + print(tabulate(table, headers=headers, floatfmt=".4f")) + + assert replicate_summary["time_avg"] <= deepcopy_summary["time_avg"], ( + "replicate slower than deepcopy: " + f"replicate={replicate_summary['time_avg']:.4f}s, deepcopy={deepcopy_summary['time_avg']:.4f}s" + ) + assert replicate_summary["mem_avg"] <= deepcopy_summary["mem_avg"], ( + "replicate used more memory: " + f"replicate={replicate_summary['mem_avg'] / (1024 ** 2):.1f}MB, " + f"deepcopy={deepcopy_summary['mem_avg'] / (1024 ** 2):.1f}MB" + ) From e4332bff7bc7f635f7e12f80c93880ef1158e146 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 10:27:52 +0000 Subject: [PATCH 3/3] fix turtle model thread safe state Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 2 + gptqmodel/models/base.py | 65 ++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 42ca3edc7..f3444c9a4 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -808,6 +808,8 @@ def loop(self, fail_safe: bool = False, **kwargs): quant_modules_pb.title(f"Quantizing layer {layer_index} of {layer_count - 1}").draw() module = layers[layer_index] + self.gptq_model.wait_for_turtle_reload() + if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) continue diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index b1a9ca8b4..3525134e1 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -207,6 +207,8 @@ def __init__( self._background_pool: Optional["DeviceThreadPool"] = None self._turtle_reload_future: Optional[Future] = None self._turtle_reload_lock = threading.Lock() + self._turtle_ready = threading.Event() + self._turtle_ready.set() # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion self.qlinear_kernel = qlinear_kernel @@ -1086,6 +1088,24 @@ def format_nodes(nodes): def register_background_pool(self, pool: Optional["DeviceThreadPool"]) -> None: self._background_pool = pool + def _wait_for_turtle_ready(self, timeout: Optional[float] = None) -> None: + if self.turtle_model is None: + return + + self._turtle_ready.wait(timeout=timeout) + + def _ensure_turtle_ready(self) -> None: + if self.turtle_model is None: + return + + self._apply_completed_turtle_reload() + + if self._turtle_ready.is_set(): + return + + self._wait_for_turtle_ready() + self._apply_completed_turtle_reload() + def _clone_model_init_kwargs(self, source: PreTrainedModel) -> Dict[str, Any]: kwargs = getattr(source, "_model_init_kwargs", {}) or {} if isinstance(kwargs, dict): @@ -1112,13 +1132,16 @@ def _schedule_turtle_reload(self) -> None: self._apply_completed_turtle_reload() if self.turtle_model is None or self.model_local_path is None: + self._turtle_ready.set() return pool = 
self._background_pool if pool is None: + self._turtle_ready.clear() new_model = self._reload_turtle_model_sync() if new_model is not None: self.turtle_model = new_model + self._turtle_ready.set() return with self._turtle_reload_lock: @@ -1132,25 +1155,35 @@ def _schedule_turtle_reload(self) -> None: loader = self.loader def _reload_task(): - model = loader.from_pretrained( - model_local_path, - config=config, - low_cpu_mem_usage=True, - **reload_kwargs, - ) - model._model_init_kwargs = reload_kwargs - return model + try: + model = loader.from_pretrained( + model_local_path, + config=config, + low_cpu_mem_usage=True, + **reload_kwargs, + ) + model._model_init_kwargs = reload_kwargs + return model + finally: + self._turtle_ready.set() + + self._turtle_ready.clear() try: - self._turtle_reload_future = pool.submit(CPU, _reload_task) + future = pool.submit(CPU, _reload_task) except Exception as exc: log.warning("Turtle reload scheduling failed; falling back to sync reload: %s", exc) self._turtle_reload_future = None + else: + self._turtle_reload_future = future + return - if self._turtle_reload_future is None: - new_model = self._reload_turtle_model_sync() - if new_model is not None: - self.turtle_model = new_model + # Fallback: synchronous reload (either pool missing or submit failed) + self._turtle_ready.clear() + new_model = self._reload_turtle_model_sync() + if new_model is not None: + self.turtle_model = new_model + self._turtle_ready.set() def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None: """Adopt the result of any finished background turtle reload.""" @@ -1181,15 +1214,17 @@ def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None: self._turtle_reload_future = None if new_model is not None: self.turtle_model = new_model + self._turtle_ready.set() return with self._turtle_reload_lock: if self._turtle_reload_future is future: self.turtle_model = new_model self._turtle_reload_future = None + self._turtle_ready.set() def wait_for_turtle_reload(self) -> None: - self._apply_completed_turtle_reload(wait=True) + self._ensure_turtle_ready() # transfer actually materizlied module from turtle (real) to shell def shell_module_materialize( @@ -1198,7 +1233,7 @@ def shell_module_materialize( device: torch.device, non_blocking: bool = False, ) -> torch.nn.Module: - self._apply_completed_turtle_reload() + self._ensure_turtle_ready() if self.turtle_model is None: if get_device(target_submodule) != device:
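
Note on the turtle-reload gating introduced in PATCH 3/3: the diff wires a threading.Event (`_turtle_ready`) around the background turtle reload so that `wait_for_turtle_reload` (called at the top of each layer in `ModuleLooper.loop`) and `shell_module_materialize` never observe a half-reloaded turtle model, whether the reload ran on the background pool or fell back to the synchronous path. The sketch below is a minimal, self-contained illustration of that Event-gated pattern only, not gptqmodel code: it assumes a plain `ThreadPoolExecutor` in place of the project's `DeviceThreadPool`, and the names `TurtleHost` and `rebuild_model` are hypothetical placeholders.

    import threading
    from concurrent.futures import Future, ThreadPoolExecutor
    from typing import Any, Callable, Optional


    class TurtleHost:
        """Holds a CPU-side 'turtle' model that a background worker may be rebuilding."""

        def __init__(self, rebuild_model: Callable[[], Any]) -> None:
            self._rebuild_model = rebuild_model   # stands in for the loader.from_pretrained(...) call
            self.model: Optional[Any] = None
            self._ready = threading.Event()
            self._ready.set()                     # no reload in flight at construction time
            self._pool = ThreadPoolExecutor(max_workers=1)
            self._future: Optional[Future] = None

        def schedule_reload(self) -> None:
            # Analogue of _schedule_turtle_reload: gate consumers first, then hand the
            # heavy reconstruction to the background pool.
            self._ready.clear()
            self._future = self._pool.submit(self._reload_task)

        def _reload_task(self) -> Any:
            try:
                return self._rebuild_model()
            finally:
                self._ready.set()                 # never leave waiters blocked, even on failure

        def wait_ready(self) -> None:
            # Analogue of wait_for_turtle_reload / _ensure_turtle_ready: consumers call this
            # before touching self.model. Assumes a single consumer thread.
            self._ready.wait()
            if self._future is not None:
                self.model = self._future.result()  # result() blocks until the future finalizes
                self._future = None

A consumer loop would call `host.wait_ready()` at the top of each iteration, matching the `self.gptq_model.wait_for_turtle_reload()` call the patch adds to `ModuleLooper.loop()`, while `host.schedule_reload()` corresponds to kicking off the background turtle rebuild after a layer is consumed.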