5 changes: 2 additions & 3 deletions gptqmodel/looper/gptq_processor.py
@@ -24,7 +24,7 @@
from ..utils.logger import setup_logger, log_time_block
from ..utils.device import get_device
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
-from ..utils.offload import parent_module_lock
+from ..utils.module_locks import parent_module_lock
from ..utils.torch import tf32_disable_guard

log = setup_logger()
@@ -268,6 +268,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):

layers = find_modules(model.model)
module_label = getattr(module, "full_name", getattr(module, "name", ""))
+parent_key = getattr(module, "full_name", getattr(module, "name", None))

# replace module with quantized module
timer = getattr(model, "quant_region_timer", None)
@@ -278,7 +279,6 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
logger=log,
module_name=module_label,
):
-parent_key = getattr(module, "full_name", getattr(module, "name", ""))
with parent_module_lock(parent_key):
create_quant_module(
name=module.full_name,
@@ -314,7 +314,6 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
logger=log,
module_name=module_label,
):
-parent_key = getattr(module, "full_name", getattr(module, "name", ""))
with parent_module_lock(parent_key):
packer_label = pack_module(
name=module.full_name,
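Taken together, this file's change hoists the `parent_key` lookup to a single spot at the top of `submodule_finalize` and sources `parent_module_lock` from the new shared registry. A minimal sketch of the resulting control flow, assuming the surrounding quant/pack calls stay as in the diff (the `pass` bodies stand in for them):

```python
from gptqmodel.utils.module_locks import parent_module_lock

def finalize_sketch(module, model):
    # Key is computed once, up front (the hoist this diff makes),
    # rather than separately inside each timed block.
    parent_key = getattr(module, "full_name", getattr(module, "name", None))

    # Phase 1: swap in the quantized module under the parent lock.
    with parent_module_lock(parent_key):
        pass  # create_quant_module(...) runs here in the real code

    # Phase 2: pack weights, again serialized against siblings that
    # share the same parent module.
    with parent_module_lock(parent_key):
        pass  # pack_module(...) runs here in the real code
```
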
57 changes: 38 additions & 19 deletions gptqmodel/looper/named_module.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

-import threading
+import contextlib
from typing import Any, Dict, Optional

import torch
@@ -13,6 +13,7 @@
from torch.nn.modules.conv import _ConvNd

from ..utils.logger import setup_logger
+from ..utils.module_locks import get_parent_lock, parent_module_lock
from ..utils.stream import stream_sync as stream_sync_events
from ..utils.stream import stream_tensor_dict_to_cpu

@@ -23,16 +24,16 @@ class NamedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None:
super().__init__()

-self.module = module # wrapped module
+self.module = module  # wrapped module
self.module_dtype = next(module.parameters()).dtype
-self.name = name # module name
-self.full_name = full_name # module full name (path) within model
-self.layer_index = layer_index # layerid in a repeating layer, if in outside layer, this info may be fake
+self.name = name # module name
+self.full_name = full_name # full dotted path inside model
+self.layer_index = layer_index # layer index for repeated blocks
+self._parent_lock = get_parent_lock(full_name)

# persistent work state for named module (used by some LoopProcessors)
# store all `processed()` work state/data/result here
self.state = {}
-self._state_lock = threading.RLock()

# print(f"NamedModule init: name: `{name}, full-name: `{full_name}`")

@@ -75,22 +76,22 @@ def named_buffers(self, prefix: str = "", recurse: bool = True):
def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
-with self._state_lock:
+with self._parent_lock:
return self.module.register_buffer(name, tensor, persistent)

def unregister_buffer(self, name: str):
-with self._state_lock:
+with self._parent_lock:
if name in self.module._buffers:
del self.module._buffers[name]
if hasattr(self.module, name):
delattr(self.module, name)

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
-with self._state_lock:
+with self._parent_lock:
return self.module.register_parameter(name, param)

def unregister_parameter(self, name: str) -> None:
-with self._state_lock:
+with self._parent_lock:
if name in self.module._parameters:
del self.module._parameters[name]
if hasattr(self.module, name):
@@ -107,8 +108,15 @@ def unregister_parameter(self, name: str) -> None:

# getattr is only called if python cannot find attr for `self`
def __getattr__(self, name: str):
-with self._state_lock:
-return getattr(self.module, name)
+try:
+lock = object.__getattribute__(self, "_parent_lock")
+except AttributeError:
+lock = None
+module = object.__getattribute__(self, "module")
+if lock is None:
+return getattr(module, name)
+with lock:
+return getattr(module, name)

# setattr is always called by python even if attr exists in `self`
def __setattr__(self, name: str, value: Any) -> None:
@@ -119,46 +127,57 @@ def __setattr__(self, name: str, value: Any) -> None:
"full_name",
"layer_index",
"state",
"_parent_lock",
"target_device",
"register_buffer",
"unregister_buffer",
"register_parameter",
"unregister_parameter",
"_state_lock",
]:
object.__setattr__(self, name, value)
return

-with self._state_lock:
-setattr(self.module, name, value)
+try:
+lock = object.__getattribute__(self, "_parent_lock")
+except AttributeError:
+lock = None
+module = object.__getattribute__(self, "module")
+if lock is None:
+setattr(module, name, value)
+else:
+with lock:
+setattr(module, name, value)

def stream_state_payload_to_cpu(
self,
tensors: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
+state_lock = self._parent_lock
return stream_tensor_dict_to_cpu(
tensors,
store_callback=lambda host_map: self.state.update(host_map),
state=self.state,
-state_lock=self._state_lock,
+state_lock=state_lock,
)

def stream_parameters_to_cpu(self) -> Dict[str, torch.Tensor]:
+state_lock = self._parent_lock
tensor_map = {name: param for name, param in self.module.named_parameters(recurse=False)}
return stream_tensor_dict_to_cpu(
tensor_map,
store_callback=lambda host_map: self.state.setdefault("parameters_cpu", {}).update(host_map),
state=self.state,
-state_lock=self._state_lock,
+state_lock=state_lock,
)

def stream_buffers_to_cpu(self) -> Dict[str, torch.Tensor]:
+state_lock = self._parent_lock
tensor_map = {name: buf for name, buf in self.module.named_buffers(recurse=False)}
return stream_tensor_dict_to_cpu(
tensor_map,
store_callback=lambda host_map: self.state.setdefault("buffers_cpu", {}).update(host_map),
state=self.state,
-state_lock=self._state_lock,
+state_lock=state_lock,
)

def stream_all_to_cpu(self) -> Dict[str, Dict[str, torch.Tensor]]:
@@ -167,4 +186,4 @@ def stream_all_to_cpu(self) -> Dict[str, Dict[str, torch.Tensor]]:
return {"parameters": params, "buffers": buffers}

def stream_sync(self) -> None:
-stream_sync_events(self.state, self._state_lock)
+stream_sync_events(self.state, self._parent_lock)
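The `object.__getattribute__` indirection above guards against recursion: `__getattr__` fires whenever normal lookup fails, including for `_parent_lock` itself before `__init__` has assigned it. A self-contained sketch of the pattern, with illustrative class and attribute names not taken from the codebase:

```python
import threading

class LockedProxy:
    """Minimal sketch (illustrative names) of the forwarding pattern."""

    def __init__(self, wrapped):
        self.__dict__["wrapped"] = wrapped
        self.__dict__["_lock"] = threading.RLock()

    def __getattr__(self, name):
        # Reached only when normal lookup fails. object.__getattribute__
        # avoids re-entering __getattr__ for "_lock" itself, which can
        # happen before __init__ has finished assigning it.
        try:
            lock = object.__getattribute__(self, "_lock")
        except AttributeError:
            lock = None
        wrapped = object.__getattribute__(self, "wrapped")
        if lock is None:
            return getattr(wrapped, name)
        with lock:  # forward the lookup under the shared lock
            return getattr(wrapped, name)

# __setattr__ in the diff mirrors this: resolve the lock defensively,
# then forward the assignment to the wrapped module under it.
```
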
53 changes: 53 additions & 0 deletions gptqmodel/utils/module_locks.py
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

import contextlib
from threading import Lock, RLock
from typing import Dict, Iterator, Optional


__all__ = [
"ROOT_PARENT_KEY",
"get_parent_lock",
"parent_lock_key",
"parent_module_lock",
]


ROOT_PARENT_KEY = "<root>"

_PARENT_LOCKS: Dict[str, RLock] = {}
_PARENT_LOCKS_GUARD = Lock()


def parent_lock_key(module_name: Optional[str]) -> str:
if not module_name:
return ROOT_PARENT_KEY
parts = module_name.split(".")
if len(parts) <= 1:
return parts[0]
return ".".join(parts[:-1])


def get_parent_lock(module_name: Optional[str]) -> RLock:
key = parent_lock_key(module_name)
with _PARENT_LOCKS_GUARD:
lock = _PARENT_LOCKS.get(key)
if lock is None:
lock = RLock()
_PARENT_LOCKS[key] = lock
return lock


@contextlib.contextmanager
def parent_module_lock(module_name: Optional[str]) -> Iterator[None]:
lock = get_parent_lock(module_name)
lock.acquire()
try:
yield
finally:
lock.release()
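
Under this scheme, sibling submodules resolve to one shared `RLock`, keyed by their parent's dotted path; a short illustration of the new helpers (the module paths here are made up):

```python
from gptqmodel.utils.module_locks import get_parent_lock, parent_lock_key

# Siblings share one lock: both keys collapse to "model.layers.0.mlp".
assert parent_lock_key("model.layers.0.mlp.up_proj") == "model.layers.0.mlp"
assert get_parent_lock("model.layers.0.mlp.up_proj") is get_parent_lock(
    "model.layers.0.mlp.down_proj"
)

# A top-level name keys off itself; empty/None fall back to ROOT_PARENT_KEY.
assert parent_lock_key("lm_head") == "lm_head"
assert parent_lock_key(None) == "<root>"
```
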
33 changes: 2 additions & 31 deletions gptqmodel/utils/offload.py
@@ -8,8 +8,7 @@
import os
import shutil
import struct
-from threading import Lock, RLock
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Iterable, List, Optional, Set, Tuple

import accelerate
import torch
@@ -23,6 +22,7 @@

from ..looper.named_module import NamedModule
from .device import get_device
+from .module_locks import parent_module_lock
from .torch import CPU, META


@@ -70,35 +70,6 @@ def is_meta_module(m: nn.Module) -> bool:

# Serialize access to module.state_dict(), which is not thread-safe under
# concurrent calls that mutate the same parent module.
-_PARENT_LOCKS: Dict[str, RLock] = {}
-_PARENT_LOCKS_GUARD = Lock()
-_ROOT_PARENT_KEY = "<root>"
-
-
-def _parent_lock_key(module_name: str) -> str:
-if not module_name:
-return _ROOT_PARENT_KEY
-parts = module_name.split(".")
-if len(parts) <= 1:
-return parts[0]
-return ".".join(parts[:-1])
-
-
-@contextlib.contextmanager
-def parent_module_lock(module_name: str):
-key = _parent_lock_key(module_name)
-with _PARENT_LOCKS_GUARD:
-lock = _PARENT_LOCKS.get(key)
-if lock is None:
-lock = RLock()
-_PARENT_LOCKS[key] = lock
-lock.acquire()
-try:
-yield
-finally:
-lock.release()
-
-
def _prepare_offload_directory(target_dir: str) -> None:
if os.path.isdir(target_dir):
shutil.rmtree(target_dir)
15 changes: 8 additions & 7 deletions tests/test_named_module.py
@@ -17,6 +17,7 @@
import torch

from gptqmodel.looper.named_module import NamedModule
+from gptqmodel.utils.module_locks import parent_module_lock


def _make_linear(features: int = 8, device: torch.device | None = None) -> torch.nn.Linear:
@@ -181,16 +182,16 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
named = ctx.named
for key, expected in expected_items:
-with named._state_lock:
+with parent_module_lock(named.full_name):
host_tensor = named.state.get(key)
event_map = named.state.get("streaming_event_map", {})
pending_event = event_map.get(key)
if host_tensor is None:
ctx.named.stream_sync()
-with named._state_lock:
+with parent_module_lock(named.full_name):
host_tensor = named.state.get(key)
if host_tensor is None:
-with named._state_lock:
+with parent_module_lock(named.full_name):
available = sorted(str(k) for k in named.state.keys())
error_queue.put(f"Missing host tensor for key {key}; available={available}")
stop_event.set()
@@ -204,7 +205,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
or not math.isclose(actual_sum, expected.checksum, rel_tol=0.0, abs_tol=1e-2)
):
ctx.named.stream_sync()
-with named._state_lock:
+with parent_module_lock(named.full_name):
retry_tensor = named.state.get(key)
if retry_tensor is not None:
retry_val = _fingerprint_last_value(retry_tensor)
@@ -213,7 +214,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
retry_val = None
retry_sum = None
del host_tensor
-with named._state_lock:
+with parent_module_lock(named.full_name):
named.state.pop(key, None)
error_queue.put(
"Mismatch for "
@@ -224,7 +225,7 @@ def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
stop_event.set()
return False
del host_tensor
-with named._state_lock:
+with parent_module_lock(named.full_name):
named.state.pop(key, None)
named.state.get("streaming_event_map", {}).pop(key, None)
return True
@@ -340,7 +341,7 @@ def _worker(thread_id: int) -> None:
for device in devices:
torch.cuda.synchronize(device=device)
for ctx in module_contexts:
-with ctx.named._state_lock:
+with parent_module_lock(ctx.named.full_name):
keys_to_remove = [key for key in ctx.named.state.keys() if key.startswith("thread")]
for key in keys_to_remove:
ctx.named.state.pop(key, None)
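Note the tests rely on the lock being re-entrant: a thread already inside `parent_module_lock(...)` can call `NamedModule` methods that acquire the same `RLock` without deadlocking. A small demonstration of that property, using a made-up module path:

```python
import threading

from gptqmodel.utils.module_locks import get_parent_lock, parent_module_lock

done = []

def worker() -> None:
    # Outer acquisition through the context manager...
    with parent_module_lock("model.layers.0.mlp.up_proj"):
        # ...then a nested acquisition of the same underlying RLock, as a
        # NamedModule method taking self._parent_lock would. RLock allows
        # re-acquisition by the holding thread, so this cannot deadlock.
        with get_parent_lock("model.layers.0.mlp.up_proj"):
            done.append(threading.get_ident())

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(done) == 4
```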