diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 4fc19e4fe..8b660d210 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -24,6 +24,7 @@
 from ..utils.logger import setup_logger, log_time_block
 from ..utils.device import get_device
 from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
+from ..utils.offload import parent_module_lock
 from ..utils.torch import tf32_disable_guard
 
 log = setup_logger()
@@ -277,21 +278,23 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
                 logger=log,
                 module_name=module_label,
             ):
-                create_quant_module(
-                    name=module.full_name,
-                    linear_cls=model.qlinear_kernel,
-                    bits=self.qcfg.bits,
-                    desc_act=self.qcfg.desc_act,
-                    dynamic=self.qcfg.dynamic,
-                    group_size=self.qcfg.group_size,
-                    module=model.model,
-                    submodule=module,
-                    sym=self.qcfg.sym,
-                    device=self.qcfg.device,
-                    lm_head_name=model.lm_head,
-                    pack_dtype=self.qcfg.pack_dtype,
-                    register_buffers=False,
-                )
+                parent_key = getattr(module, "full_name", getattr(module, "name", ""))
+                with parent_module_lock(parent_key):
+                    create_quant_module(
+                        name=module.full_name,
+                        linear_cls=model.qlinear_kernel,
+                        bits=self.qcfg.bits,
+                        desc_act=self.qcfg.desc_act,
+                        dynamic=self.qcfg.dynamic,
+                        group_size=self.qcfg.group_size,
+                        module=model.model,
+                        submodule=module,
+                        sym=self.qcfg.sym,
+                        device=self.qcfg.device,
+                        lm_head_name=model.lm_head,
+                        pack_dtype=self.qcfg.pack_dtype,
+                        register_buffers=False,
+                    )
             if timer is not None and create_start is not None:
                 timer.record(
                     "submodule_finalize_create",
@@ -311,17 +314,19 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
                 logger=log,
                 module_name=module_label,
             ):
-                packer_label = pack_module(
-                    name=module.full_name,
-                    qModules=qModules,
-                    q_scales=q_scales,
-                    q_zeros=q_zeros,
-                    q_g_idx=q_g_idx,
-                    layers=layers,
-                    quant_linear_cls=model.qlinear_kernel,
-                    lock=self.lock,
-                    quantize_config=self.qcfg,
-                )
+                parent_key = getattr(module, "full_name", getattr(module, "name", ""))
+                with parent_module_lock(parent_key):
+                    packer_label = pack_module(
+                        name=module.full_name,
+                        qModules=qModules,
+                        q_scales=q_scales,
+                        q_zeros=q_zeros,
+                        q_g_idx=q_g_idx,
+                        layers=layers,
+                        quant_linear_cls=model.qlinear_kernel,
+                        lock=self.lock,
+                        quantize_config=self.qcfg,
+                    )
             if timer is not None and pack_start is not None:
                 timer.record(
                     "submodule_finalize_pack",
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 9481d6a55..298f47850 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -1505,7 +1505,6 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
                 time.perf_counter() - start,
                 source=resolved_label,
             )
-        process_name = process.name() if process is not None else ""
         return FinalizeProgressInfo(module_label, process_name, layer_idx)
 
 
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 8d0ee5b36..0cdc0b543 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -8,8 +8,8 @@
 import os
 import shutil
 import struct
-from threading import Lock
-from typing import Iterable, List, Optional, Set, Tuple
+from threading import Lock, RLock
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 import accelerate
 import torch
@@ -68,8 +68,35 @@ def is_meta_module(m: nn.Module) -> bool:
             return True
     return False
 
-# Serialize access to module.state_dict(), which is not thread-safe under concurrent calls.
-_OFFLOAD_LOCK = Lock()
+# Serialize access to module.state_dict(), which is not thread-safe under
+# concurrent calls that mutate the same parent module.
+_PARENT_LOCKS: Dict[str, RLock] = {}
+_PARENT_LOCKS_GUARD = Lock()
+_ROOT_PARENT_KEY = ""
+
+
+def _parent_lock_key(module_name: str) -> str:
+    if not module_name:
+        return _ROOT_PARENT_KEY
+    parts = module_name.split(".")
+    if len(parts) <= 1:
+        return parts[0]
+    return ".".join(parts[:-1])
+
+
+@contextlib.contextmanager
+def parent_module_lock(module_name: str):
+    key = _parent_lock_key(module_name)
+    with _PARENT_LOCKS_GUARD:
+        lock = _PARENT_LOCKS.get(key)
+        if lock is None:
+            lock = RLock()
+            _PARENT_LOCKS[key] = lock
+    lock.acquire()
+    try:
+        yield
+    finally:
+        lock.release()
 
 
 def _prepare_offload_directory(target_dir: str) -> None:
@@ -131,8 +158,7 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
 
 
 def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
-    with _OFFLOAD_LOCK:
-        _offload_to_disk_impl(module=module, model=model, disk_path=disk_path)
+    _offload_to_disk_impl(module=module, model=model, disk_path=disk_path)
 
 
 def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
@@ -173,6 +199,11 @@ def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_
 #offload_to_disk = _OFFLOAD_SAFE.offload_to_disk
 
 def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
+    with parent_module_lock(name):
+        _offload_disk_locked(module=module, name=name, disk_path=disk_path)
+
+
+def _offload_disk_locked(module: nn.Module, name: str, disk_path: str = "."):
     if is_meta_module(module):
         # print(f"[skip] '{name}' is on meta; leaving as-is")
         return
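Note: a minimal standalone sketch of the per-parent locking pattern this patch introduces, for reference only. The key is derived from the submodule's parent name, so threads finalizing sibling submodules of the same parent serialize while unrelated parents proceed in parallel. All names below (parent_lock, finalize, the layer paths) are illustrative and not part of the library.

import contextlib
import threading
from typing import Dict

_locks: Dict[str, threading.RLock] = {}
_locks_guard = threading.Lock()

def _parent_key(name: str) -> str:
    # "model.layers.0.self_attn.q_proj" -> "model.layers.0.self_attn"
    return name.rsplit(".", 1)[0] if "." in name else name

@contextlib.contextmanager
def parent_lock(name: str):
    # Lazily create one re-entrant lock per parent module name.
    with _locks_guard:
        lock = _locks.setdefault(_parent_key(name), threading.RLock())
    with lock:
        yield

def finalize(name: str) -> None:
    # Stand-in for create_quant_module(...) / pack_module(...) on one submodule.
    with parent_lock(name):
        print(f"finalizing {name} under lock for {_parent_key(name)!r}")

threads = [threading.Thread(target=finalize, args=(n,)) for n in
           ("model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.k_proj")]
for t in threads:
    t.start()
for t in threads:
    t.join()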