From 8c4448f60df65f8c9bd6db638a4f97c96318a0f4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 20 Oct 2025 07:10:10 +0000 Subject: [PATCH] merge locking code --- gptqmodel/looper/gptq_processor.py | 5 ++- gptqmodel/looper/named_module.py | 57 ++++++++++++++++++++---------- gptqmodel/utils/module_locks.py | 53 +++++++++++++++++++++++++++ gptqmodel/utils/offload.py | 33 ++--------------- tests/test_named_module.py | 15 ++++---- 5 files changed, 103 insertions(+), 60 deletions(-) create mode 100644 gptqmodel/utils/module_locks.py diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 8b660d210..c0622ad43 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -24,7 +24,7 @@ from ..utils.logger import setup_logger, log_time_block from ..utils.device import get_device from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.offload import parent_module_lock +from ..utils.module_locks import parent_module_lock from ..utils.torch import tf32_disable_guard log = setup_logger() @@ -268,6 +268,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): layers = find_modules(model.model) module_label = getattr(module, "full_name", getattr(module, "name", "")) + parent_key = getattr(module, "full_name", getattr(module, "name", None)) # replace module with quantized module timer = getattr(model, "quant_region_timer", None) @@ -278,7 +279,6 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): logger=log, module_name=module_label, ): - parent_key = getattr(module, "full_name", getattr(module, "name", "")) with parent_module_lock(parent_key): create_quant_module( name=module.full_name, @@ -314,7 +314,6 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): logger=log, module_name=module_label, ): - parent_key = getattr(module, "full_name", getattr(module, "name", "")) with parent_module_lock(parent_key): packer_label = pack_module( name=module.full_name, diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index 570456f25..db078c3ab 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import threading +import contextlib from typing import Any, Dict, Optional import torch @@ -13,6 +13,7 @@ from torch.nn.modules.conv import _ConvNd from ..utils.logger import setup_logger +from ..utils.module_locks import get_parent_lock, parent_module_lock from ..utils.stream import stream_sync as stream_sync_events from ..utils.stream import stream_tensor_dict_to_cpu @@ -23,16 +24,16 @@ class NamedModule(torch.nn.Module): def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None: super().__init__() - self.module = module # wrapped module + self.module = module # wrapped module self.module_dtype = next(module.parameters()).dtype - self.name = name # module name - self.full_name = full_name # module full name (path) within model - self.layer_index = layer_index # layerid in a repeating layer, if in outside layer, this info may be fake + self.name = name # module name + self.full_name = full_name # full dotted path inside model + self.layer_index = layer_index # layer index for repeated blocks + self._parent_lock = get_parent_lock(full_name) # persistent work state for named module (used by some LoopProcessors) # store all 
`processed()` work state/data/result here self.state = {} - self._state_lock = threading.RLock() # print(f"NamedModule init: name: `{name}, full-name: `{full_name}`") @@ -75,22 +76,22 @@ def named_buffers(self, prefix: str = "", recurse: bool = True): def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: - with self._state_lock: + with self._parent_lock: return self.module.register_buffer(name, tensor, persistent) def unregister_buffer(self, name: str): - with self._state_lock: + with self._parent_lock: if name in self.module._buffers: del self.module._buffers[name] if hasattr(self.module, name): delattr(self.module, name) def register_parameter(self, name: str, param: Optional[Parameter]) -> None: - with self._state_lock: + with self._parent_lock: return self.module.register_parameter(name, param) def unregister_parameter(self, name: str) -> None: - with self._state_lock: + with self._parent_lock: if name in self.module._parameters: del self.module._parameters[name] if hasattr(self.module, name): @@ -107,8 +108,15 @@ def unregister_parameter(self, name: str) -> None: # getattr is only called if python cannot find attr for `self` def __getattr__(self, name: str): - with self._state_lock: - return getattr(self.module, name) + try: + lock = object.__getattribute__(self, "_parent_lock") + except AttributeError: + lock = None + module = object.__getattribute__(self, "module") + if lock is None: + return getattr(module, name) + with lock: + return getattr(module, name) # setattr is always called by python even if attr exists in `self` def __setattr__(self, name: str, value: Any) -> None: @@ -119,46 +127,57 @@ def __setattr__(self, name: str, value: Any) -> None: "full_name", "layer_index", "state", + "_parent_lock", "target_device", "register_buffer", "unregister_buffer", "register_parameter", "unregister_parameter", - "_state_lock", ]: object.__setattr__(self, name, value) return - with self._state_lock: - setattr(self.module, name, value) + try: + lock = object.__getattribute__(self, "_parent_lock") + except AttributeError: + lock = None + module = object.__getattribute__(self, "module") + if lock is None: + setattr(module, name, value) + else: + with lock: + setattr(module, name, value) def stream_state_payload_to_cpu( self, tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + state_lock = self._parent_lock return stream_tensor_dict_to_cpu( tensors, store_callback=lambda host_map: self.state.update(host_map), state=self.state, - state_lock=self._state_lock, + state_lock=state_lock, ) def stream_parameters_to_cpu(self) -> Dict[str, torch.Tensor]: + state_lock = self._parent_lock tensor_map = {name: param for name, param in self.module.named_parameters(recurse=False)} return stream_tensor_dict_to_cpu( tensor_map, store_callback=lambda host_map: self.state.setdefault("parameters_cpu", {}).update(host_map), state=self.state, - state_lock=self._state_lock, + state_lock=state_lock, ) def stream_buffers_to_cpu(self) -> Dict[str, torch.Tensor]: + state_lock = self._parent_lock tensor_map = {name: buf for name, buf in self.module.named_buffers(recurse=False)} return stream_tensor_dict_to_cpu( tensor_map, store_callback=lambda host_map: self.state.setdefault("buffers_cpu", {}).update(host_map), state=self.state, - state_lock=self._state_lock, + state_lock=state_lock, ) def stream_all_to_cpu(self) -> Dict[str, Dict[str, torch.Tensor]]: @@ -167,4 +186,4 @@ def stream_all_to_cpu(self) -> Dict[str, Dict[str, torch.Tensor]]: return {"parameters": 
params, "buffers": buffers} def stream_sync(self) -> None: - stream_sync_events(self.state, self._state_lock) + stream_sync_events(self.state, self._parent_lock) diff --git a/gptqmodel/utils/module_locks.py b/gptqmodel/utils/module_locks.py new file mode 100644 index 000000000..9819609ec --- /dev/null +++ b/gptqmodel/utils/module_locks.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import contextlib +from threading import Lock, RLock +from typing import Dict, Iterator, Optional + + +__all__ = [ + "ROOT_PARENT_KEY", + "get_parent_lock", + "parent_lock_key", + "parent_module_lock", +] + + +ROOT_PARENT_KEY = "" + +_PARENT_LOCKS: Dict[str, RLock] = {} +_PARENT_LOCKS_GUARD = Lock() + + +def parent_lock_key(module_name: Optional[str]) -> str: + if not module_name: + return ROOT_PARENT_KEY + parts = module_name.split(".") + if len(parts) <= 1: + return parts[0] + return ".".join(parts[:-1]) + + +def get_parent_lock(module_name: Optional[str]) -> RLock: + key = parent_lock_key(module_name) + with _PARENT_LOCKS_GUARD: + lock = _PARENT_LOCKS.get(key) + if lock is None: + lock = RLock() + _PARENT_LOCKS[key] = lock + return lock + + +@contextlib.contextmanager +def parent_module_lock(module_name: Optional[str]) -> Iterator[None]: + lock = get_parent_lock(module_name) + lock.acquire() + try: + yield + finally: + lock.release() diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index 0cdc0b543..0c714c3f3 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -8,8 +8,7 @@ import os import shutil import struct -from threading import Lock, RLock -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple import accelerate import torch @@ -23,6 +22,7 @@ from ..looper.named_module import NamedModule from .device import get_device +from .module_locks import parent_module_lock from .torch import CPU, META @@ -70,35 +70,6 @@ def is_meta_module(m: nn.Module) -> bool: # Serialize access to module.state_dict(), which is not thread-safe under # concurrent calls that mutate the same parent module. 
-_PARENT_LOCKS: Dict[str, RLock] = {} -_PARENT_LOCKS_GUARD = Lock() -_ROOT_PARENT_KEY = "" - - -def _parent_lock_key(module_name: str) -> str: - if not module_name: - return _ROOT_PARENT_KEY - parts = module_name.split(".") - if len(parts) <= 1: - return parts[0] - return ".".join(parts[:-1]) - - -@contextlib.contextmanager -def parent_module_lock(module_name: str): - key = _parent_lock_key(module_name) - with _PARENT_LOCKS_GUARD: - lock = _PARENT_LOCKS.get(key) - if lock is None: - lock = RLock() - _PARENT_LOCKS[key] = lock - lock.acquire() - try: - yield - finally: - lock.release() - - def _prepare_offload_directory(target_dir: str) -> None: if os.path.isdir(target_dir): shutil.rmtree(target_dir) diff --git a/tests/test_named_module.py b/tests/test_named_module.py index db368ea49..87188aec3 100644 --- a/tests/test_named_module.py +++ b/tests/test_named_module.py @@ -17,6 +17,7 @@ import torch from gptqmodel.looper.named_module import NamedModule +from gptqmodel.utils.module_locks import parent_module_lock def _make_linear(features: int = 8, device: torch.device | None = None) -> torch.nn.Linear: @@ -181,16 +182,16 @@ def _fingerprint_last_value(tensor: torch.Tensor) -> float: def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool: named = ctx.named for key, expected in expected_items: - with named._state_lock: + with parent_module_lock(named.full_name): host_tensor = named.state.get(key) event_map = named.state.get("streaming_event_map", {}) pending_event = event_map.get(key) if host_tensor is None: ctx.named.stream_sync() - with named._state_lock: + with parent_module_lock(named.full_name): host_tensor = named.state.get(key) if host_tensor is None: - with named._state_lock: + with parent_module_lock(named.full_name): available = sorted(str(k) for k in named.state.keys()) error_queue.put(f"Missing host tensor for key {key}; available={available}") stop_event.set() @@ -204,7 +205,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _Expe or not math.isclose(actual_sum, expected.checksum, rel_tol=0.0, abs_tol=1e-2) ): ctx.named.stream_sync() - with named._state_lock: + with parent_module_lock(named.full_name): retry_tensor = named.state.get(key) if retry_tensor is not None: retry_val = _fingerprint_last_value(retry_tensor) @@ -213,7 +214,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _Expe retry_val = None retry_sum = None del host_tensor - with named._state_lock: + with parent_module_lock(named.full_name): named.state.pop(key, None) error_queue.put( "Mismatch for " @@ -224,7 +225,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _Expe stop_event.set() return False del host_tensor - with named._state_lock: + with parent_module_lock(named.full_name): named.state.pop(key, None) named.state.get("streaming_event_map", {}).pop(key, None) return True @@ -340,7 +341,7 @@ def _worker(thread_id: int) -> None: for device in devices: torch.cuda.synchronize(device=device) for ctx in module_contexts: - with ctx.named._state_lock: + with parent_module_lock(ctx.named.full_name): keys_to_remove = [key for key in ctx.named.state.keys() if key.startswith("thread")] for key in keys_to_remove: ctx.named.state.pop(key, None)
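---

Usage sketch (not part of the patch): the snippet below illustrates how the new `gptqmodel.utils.module_locks` helpers key locks by the parent module path, so sibling submodules under one parent (the `q_proj`/`k_proj` names are illustrative) serialize on a single shared `RLock`. It assumes an environment with this patch installed; only the functions defined in `gptqmodel/utils/module_locks.py` above are used.

    # Usage sketch, assuming a gptqmodel build that includes this patch.
    # Module paths below are hypothetical examples of transformer submodules.
    from gptqmodel.utils.module_locks import (
        get_parent_lock,
        parent_lock_key,
        parent_module_lock,
    )

    # Sibling submodules resolve to the same parent key ...
    assert parent_lock_key("model.layers.0.self_attn.q_proj") == "model.layers.0.self_attn"
    assert parent_lock_key("model.layers.0.self_attn.k_proj") == "model.layers.0.self_attn"

    # ... and therefore share one reentrant lock instance.
    q_lock = get_parent_lock("model.layers.0.self_attn.q_proj")
    k_lock = get_parent_lock("model.layers.0.self_attn.k_proj")
    assert q_lock is k_lock

    # Context-manager form used by gptq_processor.py, offload.py and the tests.
    with parent_module_lock("model.layers.0.self_attn.q_proj"):
        # Because the lock is an RLock, re-acquiring it from the same thread
        # (e.g. NamedModule.__setattr__ firing while pack_module already holds
        # the parent lock) does not deadlock.
        with parent_module_lock("model.layers.0.self_attn.k_proj"):
            pass

Using an `RLock` keyed by the parent path is what lets `NamedModule` reuse the same lock for its attribute and buffer mutations that the processors take around `create_quant_module`/`pack_module`: nested acquisition on one thread is a no-op, while concurrent threads touching children of the same parent are serialized.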