From 88ea2cc0ec7ee33edd4872b5084d1857586076da Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 21:31:05 +0000 Subject: [PATCH 1/7] missing optional default Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a827ad48c..36f8af445 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -627,7 +627,7 @@ def lock(self, devices: Optional[Iterable[DeviceLike]] = None): # --------------- Public Wait API --------------- - def wait(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> None | _WaitAndLock: + def wait(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None) -> None | _WaitAndLock: """ Wait until in-flight tasks for `scope` drain to zero. """ @@ -726,7 +726,7 @@ def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: keys.append(k) return keys - def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> List[str]: + def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None) -> List[str]: if scope is None or (isinstance(scope, str) and scope == "all"): return list(self._ordered_keys) if isinstance(scope, (str, torch.device, int)): @@ -992,7 +992,7 @@ def _empty_all_caches(self): continue with torch.cuda.device(dev.index): TORCH_CUDA_EMPTY_CACHE() - log.debug(f"cuda empty cache called on {dev.index}") + # log.debug(f"cuda empty cache called on {dev.index}") # XPU if TORCH_XPU_EMPTY_CACHE is not None: From ba36fdf2449ac9e8c2d0088b98054c83d48194d2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 21:50:27 +0000 Subject: [PATCH 2/7] fixed missing wait(lock) Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 57 +++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 36f8af445..3fc79df61 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -34,6 +34,7 @@ def _mps_available() -> bool: # --- HARD COPIES of original empty_cache callables (never auto-switched) --- +# We capture the original callables at import-time so later code cannot alias them to no-ops. TORCH_CUDA_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_XPU_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_MPS_EMPTY_CACHE: Optional[Callable[[], None]] = None @@ -65,6 +66,7 @@ def _coerce_device(d: DeviceLike) -> torch.device: if isinstance(d, torch.device): return d if isinstance(d, int): + # Order: prefer CUDA -> XPU -> MPS -> CPU for numeric indices if torch.cuda.is_available(): return torch.device("cuda", d) if hasattr(torch, "xpu") and torch.xpu.is_available(): @@ -586,6 +588,7 @@ def read_lock(self, device: DeviceLike | str): - 'all' for every device in the pool Returns a context manager. 
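
        Illustrative usage (an assumption based on the lock semantics above,
        not taken from this patch):
            with pool.read_lock('cuda'):
                ...  # shared with other readers; blocks only a writer (e.g. GC)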
""" + # family string shortcut if isinstance(device, str): if device == "all": pairs = [(k, self._locks[k]) for k in self._ordered_keys] @@ -600,6 +603,7 @@ def read_lock(self, device: DeviceLike | str): dev = _coerce_device(device) key = self._key(dev) + # family device with index=None -> all devices of that type if dev.index is None: fam = dev.type keys = [k for k in self._ordered_keys if k.startswith(fam)] @@ -608,6 +612,7 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) + # concrete device lk = self._locks.get(key) if lk is None: raise ValueError(f"Unknown device for pool: {dev}") @@ -627,11 +632,54 @@ def lock(self, devices: Optional[Iterable[DeviceLike]] = None): # --------------- Public Wait API --------------- - def wait(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None) -> None | _WaitAndLock: + def wait( + self, + scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None, + *, + lock: bool = False, + ) -> None | _WaitAndLock: """ Wait until in-flight tasks for `scope` drain to zero. + + scope: + - None or 'all' -> all devices + - 'cuda' | 'xpu' | 'mps' | 'cpu' -> all devices of that type + - 'cuda:0' | 'xpu:1' -> specific device key + - torch.device or iterable of the above + + lock: + - False (default): block until drained, then return None. + - True: return a context manager that **waits for drain AND acquires + exclusive write locks** over the scope. Usage: + `with pool.wait("cuda", lock=True): ...` """ keys = self._resolve_scope_to_keys(scope) + + if lock: + # Build a context manager that, on __enter__, acquires writer locks for the scope. + # We additionally ensure the scope is drained *before* locks are taken, so that + # once the locks are held there is no in-flight work and no new readers can start. + pairs = [(k, self._locks[k]) for k in sorted(keys)] + + class _WaitThenLock(_WaitAndLock): + def __init__(self_outer, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock]], keys_local: List[str]): + self_outer._outer = outer + self_outer._pairs = pairs_local + super().__init__(pairs_local) + + def __enter__(self_outer): + # Drain first + for kk in keys: + cv = self_outer._outer._inflight_cv[kk] + with cv: + while self_outer._outer._inflight[kk] > 0: + cv.wait() + # Then acquire writer locks to block any new tasks on the scope + return super()._WaitAndLock__enter__() if hasattr(super(), "_WaitAndLock__enter__") else super().__enter__() + + return _WaitThenLock(self, pairs, keys) + + # Pure wait without lock: wait for inflight to reach zero for each key. 
for k in keys: cv = self._inflight_cv[k] with cv: @@ -713,6 +761,7 @@ def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: raise ValueError(f"Unknown device key in scope: {s}") keys.append(s) else: + # family: cuda/xpu/mps/cpu fam = s fam_keys = [k for k in self._ordered_keys if k.startswith(fam)] if not fam_keys: @@ -748,6 +797,7 @@ def _mark_finished(self, key: str) -> None: cv.notify_all() def _on_task_finished(self, key: str) -> None: + # inflight decrement + counters + potential GC trigger self._mark_finished(key) trigger_gc = False @@ -765,6 +815,7 @@ def _on_task_finished(self, key: str) -> None: # ---- ANSI table rendering for GC diagnostics ---- def _ansi_table(self, headers: List[str], rows: List[List[str]]) -> str: + """Render a simple ANSI/ASCII table with bold headers.""" widths = [len(h) for h in headers] for r in rows: for i, cell in enumerate(r): @@ -798,6 +849,7 @@ def format_row(cols: List[str]): return "\n".join(lines) def _collect_state_snapshot(self) -> Dict[str, Any]: + """Safely collect a snapshot of pool state for diagnostics.""" with self._stats_lock: per_done = dict(self._per_device_done) total_done = int(self._total_done) @@ -832,6 +884,7 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: return snap def _render_gc_table(self, snap: Dict[str, Any]) -> str: + """Build the ANSI table for the current snapshot.""" headers = [ "Device", "Type", "Index", "Workers", "Inflight", "Done", "Threshold", "NextGC", "Accel" @@ -924,6 +977,7 @@ def _janitor_loop(self): if self._stop_event.is_set(): break + # Debounce to coalesce bursty triggers; we keep draining the event during the window. if self._gc_debounce_s > 0: t_end = time.time() + self._gc_debounce_s while time.time() < t_end: @@ -983,6 +1037,7 @@ def _empty_all_caches(self): """ Call the captured originals if available; no redundant availability checks and no try/except around empty_cache (fail loud if backend misbehaves). + Only backends present in this pool are touched (prevents MPS backend errors). """ # CUDA if TORCH_CUDA_EMPTY_CACHE is not None: From 7189abca8d3e08bb35f459624e1f62c47efe24e9 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 22:17:36 +0000 Subject: [PATCH 3/7] remove thread rotation Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 223 ++++++++++++++----------------------- 1 file changed, 85 insertions(+), 138 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 3fc79df61..d8dff1d12 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -16,13 +16,10 @@ from ..utils.logger import setup_logger - log = setup_logger() - DeviceLike = Union[str, int, torch.device] - # --------------------------- Backend availability helpers --------------------------- def _mps_available() -> bool: @@ -32,9 +29,8 @@ def _mps_available() -> bool: and torch.backends.mps.is_available() ) - # --- HARD COPIES of original empty_cache callables (never auto-switched) --- -# We capture the original callables at import-time so later code cannot alias them to no-ops. +# Captured at import-time so later code cannot alias them to no-ops. 
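# Each stays None when its backend (or a callable empty_cache attr) is absent.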
TORCH_CUDA_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_XPU_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_MPS_EMPTY_CACHE: Optional[Callable[[], None]] = None @@ -59,14 +55,14 @@ def _mps_available() -> bool: TORCH_MPS_EMPTY_CACHE = None except Exception: TORCH_MPS_EMPTY_CACHE = None -# ------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- def _coerce_device(d: DeviceLike) -> torch.device: if isinstance(d, torch.device): return d if isinstance(d, int): - # Order: prefer CUDA -> XPU -> MPS -> CPU for numeric indices + # Prefer CUDA -> XPU -> MPS -> CPU for numeric indices if torch.cuda.is_available(): return torch.device("cuda", d) if hasattr(torch, "xpu") and torch.xpu.is_available(): @@ -76,10 +72,9 @@ def _coerce_device(d: DeviceLike) -> torch.device: return torch.device("cpu") return torch.device(d) - @contextlib.contextmanager def _device_ctx(dev: torch.device): - """Set the caller thread’s current device for CUDA/XPU so library handles match.""" + """Set the caller thread’s current device so library handles match.""" if dev.type == "cuda": with torch.cuda.device(dev.index): yield @@ -89,7 +84,6 @@ def _device_ctx(dev: torch.device): else: yield - def _activate_thread_device(dev: torch.device): """Pin the worker thread to the device.""" if dev.type == "cuda": @@ -98,7 +92,6 @@ def _activate_thread_device(dev: torch.device): torch.xpu.set_device(dev.index) # mps/cpu: nothing to pin - # --------------------------- Read-Write Lock (writer-preference) --------------------------- class _RWLock: @@ -113,7 +106,7 @@ class _RWLock: def __init__(self): self._cond = threading.Condition() self._readers = 0 - self._writer: Optional[int] = None # thread id that owns write + self._writer: Optional[int] = None self._writer_depth = 0 self._writers_waiting = 0 @@ -179,7 +172,6 @@ def reader(self): finally: self.release_read() - class _LockGroup(contextlib.AbstractContextManager): """Acquire multiple device write locks in deterministic order to avoid deadlocks.""" def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): @@ -195,7 +187,6 @@ def __exit__(self, exc_type, exc, tb): lk.release_write() return False - class _ReadLockGroup(contextlib.AbstractContextManager): """Acquire multiple device read locks in deterministic order.""" def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): @@ -211,7 +202,6 @@ def __exit__(self, exc_type, exc, tb): lk.release_read() return False - class _WaitAndLock(contextlib.AbstractContextManager): """ Context manager returned by pool.wait(scope, lock=True). @@ -220,6 +210,7 @@ class _WaitAndLock(contextlib.AbstractContextManager): On exit: releases locks. """ def __init__(self, pairs: List[tuple[str, _RWLock]]): + self._pairs = pairs self._group = _LockGroup(pairs) def __enter__(self): @@ -228,54 +219,36 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): return self._group.__exit__(exc_type, exc, tb) - # --------------------------- Worker Thread --------------------------- class _DeviceWorker: """ Single worker thread bound to one device. Queue entries: (is_task: bool, fn, args, kwargs, future) - Supports configurable lifecycle: after N tasks, stop accepting new work, - drain its queue, and exit; the pool will spawn a replacement. 
""" def __init__( self, device: torch.device, rwlock: _RWLock, on_task_finished: Callable[[str], None], - on_retire_request: Callable[[str, _DeviceWorker], None], - on_worker_exit: Callable[[str, _DeviceWorker], None], + on_worker_exit: Callable[[str, "_DeviceWorker"], None], name: Optional[str] = None, inference_mode: bool = False, - lifecycle_calls: int = 50, ): self.device = device self.rwlock = rwlock self._on_task_finished = on_task_finished - self._on_retire_request = on_retire_request self._on_worker_exit = on_worker_exit - self._lifecycle_limit = max(0, int(lifecycle_calls)) # 0 disables rotation - self._tasks_since_spawn = 0 self.key = f"{device.type}:{device.index}" if device.index is not None else device.type self.name = name or f"DPWorker-{self.key}" self._q: "queue.Queue[Tuple[bool, Callable[..., Any], tuple, dict, Future]]" = queue.Queue() self._stop = threading.Event() - self._retire_requested = False - self._accepting = True self._inference_mode = inference_mode self._thread = threading.Thread(target=self._run, name=self.name, daemon=True) self._thread.start() - # --- lifecycle / accept state --- - def is_accepting(self) -> bool: - return self._accepting and not self._stop.is_set() - - def request_stop(self): - self._stop.set() - self._q.put((False, lambda: None, (), {}, Future())) - # --- public API for pool --- def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: fut = Future() @@ -283,7 +256,8 @@ def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: return fut def stop(self): - self.request_stop() + self._stop.set() + self._q.put((False, lambda: None, (), {}, Future())) # sentinel def join(self): self._thread.join() @@ -294,13 +268,7 @@ def _run(self): maybe_inference = torch.inference_mode() if self._inference_mode else contextlib.nullcontext() with maybe_inference: while not self._stop.is_set(): - # If we're retiring and nothing is queued, exit gracefully - if self._retire_requested and self._q.empty(): - break - try: - is_task, fn, args, kwargs, fut = self._q.get(timeout=0.05) - except queue.Empty: - continue + is_task, fn, args, kwargs, fut = self._q.get() try: if not is_task: break # sentinel -> exit @@ -320,15 +288,7 @@ def _run(self): fut.set_exception(exc) finally: if is_task: - self._tasks_since_spawn += 1 self._on_task_finished(self.key) - # Lifecycle check: once we hit the limit, mark retiring (stop accepting) - if self._lifecycle_limit > 0 and self._tasks_since_spawn >= self._lifecycle_limit: - if not self._retire_requested: - self._retire_requested = True - self._accepting = False - # Notify pool to spawn a replacement now - self._on_retire_request(self.key, self) self._q.task_done() # Thread is exiting; notify pool for cleanup @@ -337,7 +297,6 @@ def _run(self): finally: pass - # --------------------------- Public Pool --------------------------- class DeviceThreadPool: @@ -350,9 +309,8 @@ class DeviceThreadPool: - Per-device RWLocks + global lock and family/all read-locks. - wait(scope, lock=False/True) to drain tasks (optionally with exclusive locks). - Per-device/global completed counters and in-flight counters. - - Janitor: triggers empty-cache after N completions on accelerator devices, under a global lock. + - Janitor: triggers empty-cache after N completions on accelerator devices (per-device lock). - GC diagnostics helpers. - - Worker lifecycle rotation: after N tasks (default 50), workers retire and are replaced. 
""" def __init__( @@ -367,7 +325,6 @@ def __init__( empty_cache_every_n: int = 50, # <=0 disables janitor workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3} gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC - worker_lifecycle_calls: int = 50, # <=0 disables lifecycle rotation ): """ Args: @@ -381,7 +338,6 @@ def __init__( - 'xpu:': N -> override for specific XPU index Unspecified devices default to 1 worker each. gc_debounce_seconds: short wait to coalesce multiple triggers. - worker_lifecycle_calls: number of tasks a worker handles before retiring (0=disabled). """ if devices is None: discovered: List[torch.device] = [] @@ -424,8 +380,6 @@ def __init__( # per-device watermark of "done" as of last GC that actually ran self._last_gc_done_per_device: Dict[str, int] = {} - # Worker lifecycle rotation - self._worker_lifecycle_calls = int(worker_lifecycle_calls) # Store inference mode for worker spawns self._inference_mode = bool(inference_mode) @@ -471,7 +425,7 @@ def __init__( ) self._janitor.start() - # --------------- Worker management (spawn/retire/cleanup) --------------- + # --------------- Worker management --------------- def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: key = self._key(dev) @@ -479,26 +433,11 @@ def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _Devic device=dev, rwlock=self._locks[key], on_task_finished=self._on_task_finished, - on_retire_request=self._on_worker_retire_request, on_worker_exit=self._on_worker_exit, name=name, inference_mode=self._inference_mode, - lifecycle_calls=self._worker_lifecycle_calls, ) - def _on_worker_retire_request(self, key: str, worker: _DeviceWorker) -> None: - """ - A worker hit its lifecycle limit. Mark it non-accepting (it already is), - and immediately spawn a replacement to maintain capacity. - """ - dev = self._devices_by_key[key] - with self._dispatch_lock: - group = self._worker_groups.get(key, []) - if worker in group: - replacement = self._spawn_worker(dev, name=f"{worker.name}.r{int(time.time()*1000)}") - group.append(replacement) - self._worker_groups[key] = group - def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: """Cleanup finished workers from the group.""" with self._dispatch_lock: @@ -588,7 +527,6 @@ def read_lock(self, device: DeviceLike | str): - 'all' for every device in the pool Returns a context manager. """ - # family string shortcut if isinstance(device, str): if device == "all": pairs = [(k, self._locks[k]) for k in self._ordered_keys] @@ -603,7 +541,6 @@ def read_lock(self, device: DeviceLike | str): dev = _coerce_device(device) key = self._key(dev) - # family device with index=None -> all devices of that type if dev.index is None: fam = dev.type keys = [k for k in self._ordered_keys if k.startswith(fam)] @@ -612,7 +549,6 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) - # concrete device lk = self._locks.get(key) if lk is None: raise ValueError(f"Unknown device for pool: {dev}") @@ -656,26 +592,27 @@ def wait( keys = self._resolve_scope_to_keys(scope) if lock: - # Build a context manager that, on __enter__, acquires writer locks for the scope. - # We additionally ensure the scope is drained *before* locks are taken, so that - # once the locks are held there is no in-flight work and no new readers can start. 
pairs = [(k, self._locks[k]) for k in sorted(keys)] - class _WaitThenLock(_WaitAndLock): - def __init__(self_outer, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock]], keys_local: List[str]): - self_outer._outer = outer - self_outer._pairs = pairs_local - super().__init__(pairs_local) + class _WaitThenLock(contextlib.AbstractContextManager): + def __init__(self, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock]], keys_local: List[str]): + self._outer = outer + self._pairs = pairs_local + self._keys = keys_local + self._group = _LockGroup(pairs_local) - def __enter__(self_outer): + def __enter__(self): # Drain first - for kk in keys: - cv = self_outer._outer._inflight_cv[kk] + for kk in self._keys: + cv = self._outer._inflight_cv[kk] with cv: - while self_outer._outer._inflight[kk] > 0: + while self._outer._inflight[kk] > 0: cv.wait() # Then acquire writer locks to block any new tasks on the scope - return super()._WaitAndLock__enter__() if hasattr(super(), "_WaitAndLock__enter__") else super().__enter__() + return self._group.__enter__() + + def __exit__(self, exc_type, exc, tb): + return self._group.__exit__(exc_type, exc, tb) return _WaitThenLock(self, pairs, keys) @@ -714,30 +651,32 @@ def _key(self, dev: torch.device) -> str: return f"{dev.type}{idx}" def _pick_worker(self, key: str) -> _DeviceWorker: - group = self._worker_groups.get(key) - if not group: - raise ValueError(f"Device {key} not part of this pool.") - + """ + Simple round-robin over the current group. No lifecycle rotation, + no accept-state, no spawning here (workers are created up front). + If the group is somehow empty, spawn one to recover. + """ with self._dispatch_lock: + group = self._worker_groups.get(key) + if not group: + dev = self._devices_by_key[key] + w = self._spawn_worker(dev, name=f"DPWorker-{key}#0") + self._worker_groups[key] = [w] + self._dispatch_rr[key] = 0 + return w + n = len(group) if n == 0: - raise ValueError(f"No workers available for device {key}") - start = self._dispatch_rr[key] % n - idx = start - # Find the next accepting worker - for _ in range(n): - w = group[idx] - if w.is_accepting(): - self._dispatch_rr[key] = (idx + 1) % n - return w - idx = (idx + 1) % n - # If none are accepting, spawn a fresh one and use it - dev = self._devices_by_key[key] - neww = self._spawn_worker(dev, name=f"DPWorker-{key}#hot") - group.append(neww) - self._worker_groups[key] = group - self._dispatch_rr[key] = (len(group) - 1 + 1) % len(group) - return neww + dev = self._devices_by_key[key] + w = self._spawn_worker(dev, name=f"DPWorker-{key}#0") + group.append(w) + self._dispatch_rr[key] = 0 + return w + + idx = self._dispatch_rr[key] % n + w = group[idx] + self._dispatch_rr[key] = (idx + 1) % n + return w def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) -> int: key = self._key(dev) @@ -761,7 +700,6 @@ def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: raise ValueError(f"Unknown device key in scope: {s}") keys.append(s) else: - # family: cuda/xpu/mps/cpu fam = s fam_keys = [k for k in self._ordered_keys if k.startswith(fam)] if not fam_keys: @@ -919,7 +857,7 @@ def _render_gc_table(self, snap: Dict[str, Any]) -> str: table_totals = self._ansi_table(totals_headers, totals_rows) return table_main + "\n" + table_totals - # ---- janitor (global empty-cache under lock) ---- + # ---- janitor (per-device empty-cache under exclusive writer lock) ---- def _synchronize_all(self): """ @@ -977,7 +915,7 @@ def _janitor_loop(self): if 
self._stop_event.is_set(): break - # Debounce to coalesce bursty triggers; we keep draining the event during the window. + # Debounce to coalesce bursty triggers; keep draining the event during the window. if self._gc_debounce_s > 0: t_end = time.time() + self._gc_debounce_s while time.time() < t_end: @@ -1013,33 +951,45 @@ def _janitor_loop(self): if not self._should_run_gc_from_snapshot(pre): continue - with self.lock(): # writer lock across ALL devices - t0 = time.time() - # Optional but often expensive: - # self._synchronize_all() - self._empty_all_caches() - t1 = time.time() + t0 = time.time() - self._gc_passes += 1 - self._last_gc_ts = t1 + # Optional sync (usually skipped for performance): + # self._synchronize_all() + # GC each device independently under its writer lock. + for key in sorted(self._ordered_keys): + dev = self._devices_by_key[key] + if dev.type not in ("cuda", "xpu", "mps"): + continue + + lk = self._locks[key] + with lk.writer(): + if dev.type == "cuda" and TORCH_CUDA_EMPTY_CACHE is not None: + with torch.cuda.device(dev.index): + TORCH_CUDA_EMPTY_CACHE() + elif dev.type == "xpu" and TORCH_XPU_EMPTY_CACHE is not None: + with torch.xpu.device(dev.index): + TORCH_XPU_EMPTY_CACHE() + elif dev.type == "mps" and TORCH_MPS_EMPTY_CACHE is not None: + TORCH_MPS_EMPTY_CACHE() + + t1 = time.time() + + self._gc_passes += 1 + self._last_gc_ts = t1 + + try: + post = self._collect_state_snapshot() + self._update_gc_watermarks(post) + log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") + except Exception as e: try: - post = self._collect_state_snapshot() - self._update_gc_watermarks(post) - log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") - except Exception as e: - try: - log.warn(f"Failed to render GC post-snapshot: {e!r}") - except Exception: - pass + log.warn(f"Failed to render GC post-snapshot: {e!r}") + except Exception: + pass + # Legacy helper retained for compatibility (not used by janitor). def _empty_all_caches(self): - """ - Call the captured originals if available; no redundant availability checks - and no try/except around empty_cache (fail loud if backend misbehaves). - Only backends present in this pool are touched (prevents MPS backend errors). 
- """ - # CUDA if TORCH_CUDA_EMPTY_CACHE is not None: for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -1047,9 +997,7 @@ def _empty_all_caches(self): continue with torch.cuda.device(dev.index): TORCH_CUDA_EMPTY_CACHE() - # log.debug(f"cuda empty cache called on {dev.index}") - # XPU if TORCH_XPU_EMPTY_CACHE is not None: for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -1058,7 +1006,6 @@ def _empty_all_caches(self): with torch.xpu.device(dev.index): TORCH_XPU_EMPTY_CACHE() - # MPS (only if this pool actually has an MPS device) if TORCH_MPS_EMPTY_CACHE is not None: has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) if has_mps_device: From 09e3ce9888d7903cef9ac3cae776147654e07176 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 22:36:06 +0000 Subject: [PATCH 4/7] fix thread shutdown and add debug Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 117 +++++++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 24 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index d8dff1d12..e41e371fd 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -6,6 +6,7 @@ from __future__ import annotations import contextlib +import os import queue import threading import time @@ -18,6 +19,9 @@ log = setup_logger() +# DEBUG guard (only emit debug logs if DEBUG=1/true/yes/on) +DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on") + DeviceLike = Union[str, int, torch.device] # --------------------------- Backend availability helpers --------------------------- @@ -106,7 +110,7 @@ class _RWLock: def __init__(self): self._cond = threading.Condition() self._readers = 0 - self._writer: Optional[int] = None + self._writer: Optional[int] = None # thread id that owns write self._writer_depth = 0 self._writers_waiting = 0 @@ -178,12 +182,15 @@ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs def __enter__(self): - for _, lk in self._pairs: + for name, lk in self._pairs: + if DEBUG_ON: log.debug(f"_LockGroup: acquiring write lock for {name}") lk.acquire_write() + if DEBUG_ON: log.debug(f"_LockGroup: acquired write lock for {name}") return self def __exit__(self, exc_type, exc, tb): - for _, lk in reversed(self._pairs): + for name, lk in reversed(self._pairs): + if DEBUG_ON: log.debug(f"_LockGroup: releasing write lock for {name}") lk.release_write() return False @@ -193,12 +200,15 @@ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs def __enter__(self): - for _, lk in self._pairs: + for name, lk in self._pairs: + if DEBUG_ON: log.debug(f"_ReadLockGroup: acquiring read lock for {name}") lk.acquire_read() + if DEBUG_ON: log.debug(f"_ReadLockGroup: acquired read lock for {name}") return self def __exit__(self, exc_type, exc, tb): - for _, lk in reversed(self._pairs): + for name, lk in reversed(self._pairs): + if DEBUG_ON: log.debug(f"_ReadLockGroup: releasing read lock for {name}") lk.release_read() return False @@ -248,18 +258,22 @@ def __init__( self._inference_mode = inference_mode self._thread = threading.Thread(target=self._run, name=self.name, daemon=True) self._thread.start() + if DEBUG_ON: log.debug(f"Spawned worker '{self.name}' for {self.key}") # --- public API for pool --- def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: fut = Future() self._q.put((True, fn, args, kwargs, fut)) + if DEBUG_ON: log.debug(f"{self.name}: task enqueued; 
qsize={self._q.qsize()}") return fut def stop(self): self._stop.set() self._q.put((False, lambda: None, (), {}, Future())) # sentinel + if DEBUG_ON: log.debug(f"{self.name}: stop requested; sentinel queued") def join(self): + if DEBUG_ON: log.debug(f"{self.name}: joining thread") self._thread.join() # --- internal main loop --- @@ -271,7 +285,9 @@ def _run(self): is_task, fn, args, kwargs, fut = self._q.get() try: if not is_task: + if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") break # sentinel -> exit + if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") # Tasks take a read lock so GC's writer lock can't interleave with self.rwlock.reader(): stream = kwargs.pop("_cuda_stream", None) @@ -283,19 +299,20 @@ def _run(self): result = fn(*args, **kwargs) if not fut.cancelled(): fut.set_result(result) + if DEBUG_ON: log.debug(f"{self.name}: task done") except BaseException as exc: if not fut.cancelled(): fut.set_exception(exc) + if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") finally: - if is_task: - self._on_task_finished(self.key) + self._on_task_finished(self.key) self._q.task_done() # Thread is exiting; notify pool for cleanup try: self._on_worker_exit(self.key, self) finally: - pass + if DEBUG_ON: log.debug(f"{self.name}: exited") # --------------------------- Public Pool --------------------------- @@ -424,12 +441,15 @@ def __init__( target=self._janitor_loop, name="DP-Janitor", daemon=True ) self._janitor.start() + if DEBUG_ON: log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})") + else: + if DEBUG_ON: log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)") # --------------- Worker management --------------- def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: key = self._key(dev) - return _DeviceWorker( + w = _DeviceWorker( device=dev, rwlock=self._locks[key], on_task_finished=self._on_task_finished, @@ -437,6 +457,7 @@ def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _Devic name=name, inference_mode=self._inference_mode, ) + return w def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: """Cleanup finished workers from the group.""" @@ -449,6 +470,7 @@ def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: self._dispatch_rr[key] %= len(group) else: self._dispatch_rr[key] = 0 + if DEBUG_ON: log.debug(f"Worker '{worker.name}' exited for {key}") # --------------- Public Work API --------------- @@ -471,6 +493,7 @@ def submit( if _cuda_stream is not None and dev.type != "cuda": raise ValueError("_cuda_stream is only valid for CUDA devices") + if DEBUG_ON: log.debug(f"submit: device={key} fn={getattr(fn, '__name__', repr(fn))}") # mark in-flight before enqueue to avoid races with wait() self._mark_scheduled(key) try: @@ -498,6 +521,7 @@ def shutdown(self, wait: bool = True): self._stop_event.set() self._gc_event.set() # wake janitor if self._janitor is not None and wait: + if DEBUG_ON: log.debug("Joining DP-Janitor thread…") self._janitor.join() for group in self._worker_groups.values(): @@ -507,6 +531,7 @@ def shutdown(self, wait: bool = True): for group in self._worker_groups.values(): for w in group: w.join() + if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete") # --------------- Public Lock API --------------- @@ -602,26 +627,31 @@ def __init__(self, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock self._group = _LockGroup(pairs_local) def 
__enter__(self): - # Drain first + if DEBUG_ON: log.debug(f"wait(lock=True) drain start: keys={self._keys}") for kk in self._keys: cv = self._outer._inflight_cv[kk] with cv: while self._outer._inflight[kk] > 0: + if DEBUG_ON: log.debug(f"wait(lock=True) blocking on inflight[{kk}]={self._outer._inflight[kk]}") cv.wait() - # Then acquire writer locks to block any new tasks on the scope + if DEBUG_ON: log.debug(f"wait(lock=True) acquire writer locks: keys={self._keys}") return self._group.__enter__() def __exit__(self, exc_type, exc, tb): + if DEBUG_ON: log.debug(f"wait(lock=True) releasing writer locks: keys={self._keys}") return self._group.__exit__(exc_type, exc, tb) return _WaitThenLock(self, pairs, keys) # Pure wait without lock: wait for inflight to reach zero for each key. + if DEBUG_ON: log.debug(f"wait(lock=False) drain start: keys={keys}") for k in keys: cv = self._inflight_cv[k] with cv: while self._inflight[k] > 0: + if DEBUG_ON: log.debug(f"wait(lock=False) blocking on inflight[{k}]={self._inflight[k]}") cv.wait() + if DEBUG_ON: log.debug(f"wait(lock=False) drain done: keys={keys}") return None # --------------- Public Stats API --------------- @@ -652,9 +682,8 @@ def _key(self, dev: torch.device) -> str: def _pick_worker(self, key: str) -> _DeviceWorker: """ - Simple round-robin over the current group. No lifecycle rotation, - no accept-state, no spawning here (workers are created up front). - If the group is somehow empty, spawn one to recover. + Simple round-robin over the current group. + If the group is empty (unlikely), spawn one to recover. """ with self._dispatch_lock: group = self._worker_groups.get(key) @@ -725,12 +754,18 @@ def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable def _mark_scheduled(self, key: str) -> None: cv = self._inflight_cv[key] with cv: - self._inflight[key] += 1 + self._inflight[key] = self._inflight.get(key, 0) + 1 + if DEBUG_ON: log.debug(f"inflight[{key}] ++ -> {self._inflight[key]}") def _mark_finished(self, key: str) -> None: cv = self._inflight_cv[key] with cv: - self._inflight[key] -= 1 + new_val = self._inflight.get(key, 0) - 1 + if new_val < 0: + if DEBUG_ON: log.debug(f"WARNING: inflight[{key}] underflow ({new_val}); clamping to 0") + new_val = 0 + self._inflight[key] = new_val + if DEBUG_ON: log.debug(f"inflight[{key}] -- -> {self._inflight[key]}") if self._inflight[key] == 0: cv.notify_all() @@ -740,13 +775,15 @@ def _on_task_finished(self, key: str) -> None: trigger_gc = False with self._stats_lock: - self._per_device_done[key] += 1 + self._per_device_done[key] = self._per_device_done.get(key, 0) + 1 self._total_done += 1 dev_type = self._devices_by_key[key].type if self._empty_cache_every_n > 0 and dev_type in ("cuda", "xpu", "mps"): n = self._per_device_done[key] if n % self._empty_cache_every_n == 0: trigger_gc = True + if DEBUG_ON: + log.debug(f"GC trigger set by {key}: per_device_done={n} threshold={self._empty_cache_every_n} total_done={self._total_done}") if trigger_gc: self._gc_event.set() @@ -910,24 +947,46 @@ def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: self._last_gc_done_per_device[k] = snap_after["per_done"].get(k, 0) def _janitor_loop(self): + # Use timeouts so we can honor stop_event without relying on a final trigger. + WAIT_TIMEOUT = 0.1 # seconds while True: - self._gc_event.wait() + if DEBUG_ON: log.debug("DP-Janitor: waiting for trigger…") + # Exit promptly if shutdown requested before/while waiting. 
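            # shutdown() sets both _stop_event and _gc_event, so either check
            # below can observe the request and exit the loop.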
+ if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop event set before wait; exiting") + break + + # Wait with a timeout so we can re-check stop_event periodically. + triggered = self._gc_event.wait(timeout=WAIT_TIMEOUT) + if not triggered: + # Timed out; loop to check stop again. + continue + + # Clear the event and re-check stop before doing anything else. + self._gc_event.clear() if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop event set after trigger; exiting") break # Debounce to coalesce bursty triggers; keep draining the event during the window. if self._gc_debounce_s > 0: t_end = time.time() + self._gc_debounce_s + if DEBUG_ON: log.debug(f"DP-Janitor: debounce window start ({self._gc_debounce_s:.3f}s)") while time.time() < t_end: - self._gc_event.clear() + # If stopping during debounce, honor it immediately + if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop during debounce; exiting") + return + # Drain subsequent triggers within the window self._gc_event.wait(timeout=max(0.0, t_end - time.time())) - self._gc_event.clear() - else: - self._gc_event.clear() + self._gc_event.clear() + if DEBUG_ON: log.debug("DP-Janitor: debounce window end") try: pre = self._collect_state_snapshot() - log.debug("GC trigger received; acquiring global exclusive lock…") + if DEBUG_ON: + log.debug(f"DP-Janitor: pre-snapshot taken: total_done={pre['total_done']}, threshold={pre['threshold']}, inflight={pre['inflight']}") + log.debug("GC trigger received; evaluating whether to run…") except Exception as e: try: log.warn(f"Failed to render GC pre-snapshot: {e!r}") @@ -949,11 +1008,12 @@ def _janitor_loop(self): } if not self._should_run_gc_from_snapshot(pre): + if DEBUG_ON: log.debug("DP-Janitor: skip GC (no device progressed by threshold since last pass)") continue t0 = time.time() - # Optional sync (usually skipped for performance): + # Optional sync (disabled by default as it can be costly) # self._synchronize_all() # GC each device independently under its writer lock. 
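            # Only one device's writer lock is held at a time, so the other
            # devices keep serving tasks while this one's cache is emptied.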
@@ -963,15 +1023,23 @@ def _janitor_loop(self): continue lk = self._locks[key] + if DEBUG_ON: log.debug(f"DP-Janitor: attempting writer lock for {key}") with lk.writer(): + if DEBUG_ON: log.debug(f"DP-Janitor: acquired writer lock for {key}") if dev.type == "cuda" and TORCH_CUDA_EMPTY_CACHE is not None: + if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(cuda) begin on {key}") with torch.cuda.device(dev.index): TORCH_CUDA_EMPTY_CACHE() + if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(cuda) done on {key}") elif dev.type == "xpu" and TORCH_XPU_EMPTY_CACHE is not None: + if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(xpu) begin on {key}") with torch.xpu.device(dev.index): TORCH_XPU_EMPTY_CACHE() + if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(xpu) done on {key}") elif dev.type == "mps" and TORCH_MPS_EMPTY_CACHE is not None: + if DEBUG_ON: log.debug("DP-Janitor: empty_cache(mps) begin") TORCH_MPS_EMPTY_CACHE() + if DEBUG_ON: log.debug("DP-Janitor: empty_cache(mps) done") t1 = time.time() @@ -982,6 +1050,7 @@ def _janitor_loop(self): post = self._collect_state_snapshot() self._update_gc_watermarks(post) log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") + if DEBUG_ON: log.debug(f"DP-Janitor: post-snapshot: inflight={post['inflight']} per_done={post['per_done']}") except Exception as e: try: log.warn(f"Failed to render GC post-snapshot: {e!r}") From 5f32d54445e85a595588ea043d02276518956ca2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 22:44:02 +0000 Subject: [PATCH 5/7] fix task result state order Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index e41e371fd..a9519e860 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -297,15 +297,20 @@ def _run(self): result = fn(*args, **kwargs) else: result = fn(*args, **kwargs) + # IMPORTANT: mark finished (decrement inflight & bump counters) + # BEFORE resolving the future so stats() reflects completion + # as soon as do()/submit().result() returns. + self._on_task_finished(self.key) if not fut.cancelled(): fut.set_result(result) if DEBUG_ON: log.debug(f"{self.name}: task done") except BaseException as exc: + # Also mark finished BEFORE surfacing exception to caller + self._on_task_finished(self.key) if not fut.cancelled(): fut.set_exception(exc) if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") finally: - self._on_task_finished(self.key) self._q.task_done() # Thread is exiting; notify pool for cleanup From 8e44ac78ea716f46ff2c7aff6af686f12001f088 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 29 Sep 2025 23:15:57 +0000 Subject: [PATCH 6/7] fix empty cache state Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 335 +++++++++++++++++++++++++++---------- 1 file changed, 245 insertions(+), 90 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a9519e860..928df3be5 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -19,12 +19,17 @@ log = setup_logger() -# DEBUG guard (only emit debug logs if DEBUG=1/true/yes/on) +# Debug logging is very chatty and can alter timings subtly in tests. +# We gate all extra diagnostics behind the DEBUG env (1/true/yes/on). DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on") +# DeviceLike allows ergonomic call sites: 'cuda:0', 0, torch.device('cuda', 0), etc. 
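# A bare int is resolved by _coerce_device(), preferring CUDA, then XPU, then
# MPS, then CPU (e.g. 1 -> torch.device('cuda', 1) when CUDA is present).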
DeviceLike = Union[str, int, torch.device] + # --------------------------- Backend availability helpers --------------------------- +# We keep these helpers small and side-effect free—only feature checks—so we can +# query once and rely on final results without redundant availability if-ladders. def _mps_available() -> bool: return ( @@ -33,8 +38,13 @@ def _mps_available() -> bool: and torch.backends.mps.is_available() ) + # --- HARD COPIES of original empty_cache callables (never auto-switched) --- -# Captured at import-time so later code cannot alias them to no-ops. +# IMPORTANT: Do NOT “optimize” these by directly calling torch.*.empty_cache. +# We intentionally capture a snapshot of the original functions to defend against +# later code mutating those attributes to a no-op. The janitor will prefer the +# *live* attribute if callable (so monkeypatching works), but falls back to these +# hard copies if the live attr is missing or non-callable. TORCH_CUDA_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_XPU_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_MPS_EMPTY_CACHE: Optional[Callable[[], None]] = None @@ -44,6 +54,7 @@ def _mps_available() -> bool: if TORCH_CUDA_EMPTY_CACHE is not None and not callable(TORCH_CUDA_EMPTY_CACHE): TORCH_CUDA_EMPTY_CACHE = None except Exception: + # If introspection fails, we keep the hard copy as None. TORCH_CUDA_EMPTY_CACHE = None try: @@ -60,13 +71,17 @@ def _mps_available() -> bool: except Exception: TORCH_MPS_EMPTY_CACHE = None -# ------------------------------------------------------------------------------- + +# --------------------------- Device coercion & context helpers --------------------------- def _coerce_device(d: DeviceLike) -> torch.device: + """ + Convert a DeviceLike into a concrete torch.device. For integers, we + interpret as accelerator indices if present, otherwise map to CPU/MPS. + """ if isinstance(d, torch.device): return d if isinstance(d, int): - # Prefer CUDA -> XPU -> MPS -> CPU for numeric indices if torch.cuda.is_available(): return torch.device("cuda", d) if hasattr(torch, "xpu") and torch.xpu.is_available(): @@ -74,11 +89,17 @@ def _coerce_device(d: DeviceLike) -> torch.device: if _mps_available(): return torch.device("mps") return torch.device("cpu") + # Accept strings like 'cuda:0', 'xpu:1', 'cpu', 'mps' return torch.device(d) + @contextlib.contextmanager def _device_ctx(dev: torch.device): - """Set the caller thread’s current device so library handles match.""" + """ + Set the caller thread’s *current* device while running a task so handles/streams + line up correctly. For CUDA/XPU we set the per-thread current device; CPU/MPS + do not require pinning here. + """ if dev.type == "cuda": with torch.cuda.device(dev.index): yield @@ -88,15 +109,23 @@ def _device_ctx(dev: torch.device): else: yield + def _activate_thread_device(dev: torch.device): - """Pin the worker thread to the device.""" + """ + Pin the worker thread to its device once, before entering its main loop. + CUDA/XPU require per-thread device activation for correct handle usage. + """ if dev.type == "cuda": torch.cuda.set_device(dev.index) elif dev.type == "xpu" and hasattr(torch, "xpu"): torch.xpu.set_device(dev.index) # mps/cpu: nothing to pin + # --------------------------- Read-Write Lock (writer-preference) --------------------------- +# We implement a writer-preference RWLock. 
Multiple readers may hold the lock, +# but when a writer is waiting we block new readers, ensuring GC (writer) can +# eventually acquire exclusivity even under task pressure. class _RWLock: """ @@ -118,7 +147,7 @@ def __init__(self): def acquire_write(self): me = threading.get_ident() with self._cond: - if self._writer == me: # re-entrant + if self._writer == me: # Re-entrant writer self._writer_depth += 1 return self._writers_waiting += 1 @@ -152,10 +181,12 @@ def writer(self): def acquire_read(self): me = threading.get_ident() with self._cond: - # writer can re-enter as reader + # The writer may re-enter as a reader; this keeps invariants simple + # for code that wants to read while already holding write. if self._writer == me: self._readers += 1 return + # If a writer is waiting, block new readers to give it priority. while self._writer is not None or self._writers_waiting > 0: self._cond.wait() self._readers += 1 @@ -176,8 +207,12 @@ def reader(self): finally: self.release_read() + class _LockGroup(contextlib.AbstractContextManager): - """Acquire multiple device write locks in deterministic order to avoid deadlocks.""" + """ + Acquire multiple device **write** locks in deterministic order to avoid deadlocks. + Helpful for GC passes or any multi-device exclusive operation. + """ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs @@ -194,8 +229,12 @@ def __exit__(self, exc_type, exc, tb): lk.release_write() return False + class _ReadLockGroup(contextlib.AbstractContextManager): - """Acquire multiple device read locks in deterministic order.""" + """ + Acquire multiple device **read** locks in deterministic order. + Useful for multi-device snapshots or read-only, consistent views. + """ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs @@ -212,6 +251,7 @@ def __exit__(self, exc_type, exc, tb): lk.release_read() return False + class _WaitAndLock(contextlib.AbstractContextManager): """ Context manager returned by pool.wait(scope, lock=True). @@ -229,12 +269,20 @@ def __enter__(self): def __exit__(self, exc_type, exc, tb): return self._group.__exit__(exc_type, exc, tb) + # --------------------------- Worker Thread --------------------------- +# Each worker is bound to a specific device and runs a single thread. Tasks are +# executed under the device’s read lock; GC acquires the writer lock to keep +# memory management steps from interleaving with tasks. class _DeviceWorker: """ Single worker thread bound to one device. + Queue entries: (is_task: bool, fn, args, kwargs, future) + - is_task=False is a sentinel to exit the thread loop. + - Tasks run within a device-scoped reader lock to prevent interleaving + with GC passes (which need a write lock). """ def __init__( self, @@ -260,24 +308,37 @@ def __init__( self._thread.start() if DEBUG_ON: log.debug(f"Spawned worker '{self.name}' for {self.key}") - # --- public API for pool --- def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: + """ + Enqueue a callable and return a Future that resolves with its result/exception. + """ fut = Future() self._q.put((True, fn, args, kwargs, fut)) if DEBUG_ON: log.debug(f"{self.name}: task enqueued; qsize={self._q.qsize()}") return fut def stop(self): + """ + Request thread exit by setting a sentinel work item. The run loop exits + ASAP after receiving it. 
+ """ self._stop.set() - self._q.put((False, lambda: None, (), {}, Future())) # sentinel + self._q.put((False, lambda: None, (), {}, Future())) if DEBUG_ON: log.debug(f"{self.name}: stop requested; sentinel queued") def join(self): + """ + Join the worker thread; for graceful shutdowns and tests. + """ if DEBUG_ON: log.debug(f"{self.name}: joining thread") self._thread.join() - # --- internal main loop --- def _run(self): + """ + Main loop: pull tasks, set device context, execute, mark completion, and + fulfill or fail the future. Completion is accounted BEFORE resolving the + future to make stats() deterministic even under test interleavings. + """ _activate_thread_device(self.device) maybe_inference = torch.inference_mode() if self._inference_mode else contextlib.nullcontext() with maybe_inference: @@ -286,9 +347,9 @@ def _run(self): try: if not is_task: if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") - break # sentinel -> exit + break if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") - # Tasks take a read lock so GC's writer lock can't interleave + # Tasks take a **read** lock so janitor's write lock can't interleave with self.rwlock.reader(): stream = kwargs.pop("_cuda_stream", None) with _device_ctx(self.device): @@ -297,29 +358,32 @@ def _run(self): result = fn(*args, **kwargs) else: result = fn(*args, **kwargs) - # IMPORTANT: mark finished (decrement inflight & bump counters) - # BEFORE resolving the future so stats() reflects completion - # as soon as do()/submit().result() returns. + # Counters must be updated before resolving futures to prevent + # tests reading stats mid-transition and seeing stale totals. self._on_task_finished(self.key) if not fut.cancelled(): fut.set_result(result) if DEBUG_ON: log.debug(f"{self.name}: task done") except BaseException as exc: - # Also mark finished BEFORE surfacing exception to caller + # Even on exception we must decrement inflight and update totals. self._on_task_finished(self.key) if not fut.cancelled(): fut.set_exception(exc) if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") finally: self._q.task_done() - - # Thread is exiting; notify pool for cleanup try: self._on_worker_exit(self.key, self) finally: if DEBUG_ON: log.debug(f"{self.name}: exited") + # --------------------------- Public Pool --------------------------- +# - Builds workers per device with per-device RWLocks +# - Tracks inflight counts (with condition vars) and completion counters +# - Provides wait() with optional lock=True for exclusive operations +# - Runs a janitor background thread that performs periodic empty_cache() under +# exclusive writer locks, coalescing triggers via a short debounce window. class DeviceThreadPool: """ @@ -332,7 +396,7 @@ class DeviceThreadPool: - wait(scope, lock=False/True) to drain tasks (optionally with exclusive locks). - Per-device/global completed counters and in-flight counters. - Janitor: triggers empty-cache after N completions on accelerator devices (per-device lock). - - GC diagnostics helpers. + - GC diagnostics helpers (snapshots and ANSI tables). """ def __init__( @@ -375,12 +439,13 @@ def __init__( discovered.append(torch.device("cpu")) devices = discovered + # Locks and device registry (keyed by "type[:index]" strings like 'cuda:0'). self._locks: Dict[str, _RWLock] = {} self._devices_by_key: Dict[str, torch.device] = {} - # Worker groups: key -> List[_DeviceWorker] + # Worker groups and RR dispatch bookkeeping. 
self._worker_groups: Dict[str, List[_DeviceWorker]] = {} - self._dispatch_rr: Dict[str, int] = {} # round-robin index per key + self._dispatch_rr: Dict[str, int] = {} self._dispatch_lock = threading.Lock() # Stats / GC / inflight control @@ -393,21 +458,22 @@ def __init__( self._stop_event = threading.Event() self._janitor: Optional[threading.Thread] = None - # in-flight (scheduled but not finished) counters + per-device CVs + # In-flight (scheduled but not finished) counters + per-device CVs. + # Each device has a condition variable to let wait() callers block + # until inflight hits zero for that device scope. self._inflight: Dict[str, int] = {} self._inflight_cv: Dict[str, threading.Condition] = {} - # GC dedupe/coalesce + # GC dedupe/coalesce: debounce window to absorb bursty triggers; + # per-device "done" watermark to skip redundant GC passes. self._gc_debounce_s = float(gc_debounce_seconds) - # per-device watermark of "done" as of last GC that actually ran self._last_gc_done_per_device: Dict[str, int] = {} - # Store inference mode for worker spawns self._inference_mode = bool(inference_mode) workers = workers or {} - # Build locks, inflight structs, and workers eagerly + # Eagerly build workers, locks, inflight tracking, and counters. for d in devices: dev = _coerce_device(d) if dev.type not in ("cuda", "xpu", "mps", "cpu"): @@ -431,14 +497,14 @@ def __init__( self._worker_groups[key] = group self._dispatch_rr[key] = 0 - # Canonical lock order + # A canonical ordering for multi-device lock acquisitions. self._ordered_keys = sorted(self._locks.keys()) # GC diagnostics counters self._gc_passes = 0 self._last_gc_ts: Optional[float] = None - # Start janitor if enabled and accelerators exist + # Start janitor if enabled and there exists at least one accelerator. if self._empty_cache_every_n > 0 and any( self._devices_by_key[k].type in ("cuda", "xpu", "mps") for k in self._ordered_keys ): @@ -446,13 +512,18 @@ def __init__( target=self._janitor_loop, name="DP-Janitor", daemon=True ) self._janitor.start() - if DEBUG_ON: log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})") + if DEBUG_ON: + log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})") else: - if DEBUG_ON: log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)") + if DEBUG_ON: + log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)") # --------------- Worker management --------------- def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: + """ + Create and start a worker bound to the provided device. + """ key = self._key(dev) w = _DeviceWorker( device=dev, @@ -465,7 +536,9 @@ def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _Devic return w def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: - """Cleanup finished workers from the group.""" + """ + Clean up worker bookkeeping after a thread exits. + """ with self._dispatch_lock: group = self._worker_groups.get(key, []) if worker in group: @@ -499,12 +572,12 @@ def submit( raise ValueError("_cuda_stream is only valid for CUDA devices") if DEBUG_ON: log.debug(f"submit: device={key} fn={getattr(fn, '__name__', repr(fn))}") - # mark in-flight before enqueue to avoid races with wait() + # Mark in-flight before enqueue to avoid races with wait(). 
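        # (If the increment happened after enqueue, wait() could observe zero
        # inflight and return while this task is still queued.)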
self._mark_scheduled(key) try: return worker.submit(fn, *args, _cuda_stream=_cuda_stream, **kwargs) except BaseException: - # roll back inflight if enqueue fails (rare) + # Roll back inflight if enqueue fails (rare) self._mark_finished(key) raise @@ -517,14 +590,18 @@ def do( _cuda_stream: Optional[torch.cuda.Stream] = None, **kwargs, ) -> Any: - """Synchronously schedule work and block for the result.""" + """ + Synchronously schedule work and block for the result. + """ fut = self.submit(device, fn, *args, _cuda_stream=_cuda_stream, **kwargs) return fut.result() def shutdown(self, wait: bool = True): - """Gracefully stop all workers and janitor.""" + """ + Gracefully stop all workers and janitor. + """ self._stop_event.set() - self._gc_event.set() # wake janitor + self._gc_event.set() # wake janitor if waiting if self._janitor is not None and wait: if DEBUG_ON: log.debug("Joining DP-Janitor thread…") self._janitor.join() @@ -541,7 +618,9 @@ def shutdown(self, wait: bool = True): # --------------- Public Lock API --------------- def device_lock(self, device: DeviceLike): - """Exclusive lock for a single device (blocks all its workers).""" + """ + Obtain an exclusive lock for a single device (blocks all its workers). + """ dev = _coerce_device(device) key = self._key(dev) lk = self._locks.get(key) @@ -551,12 +630,13 @@ def device_lock(self, device: DeviceLike): def read_lock(self, device: DeviceLike | str): """ - Shared/read lock. Accepts: - - concrete device: torch.device('cuda:0'), 'cuda:1' - - family device: torch.device('cuda'), 'cuda', 'xpu', 'mps', 'cpu' + Obtain a read/shared lock. Accepts: + - Concrete device: torch.device('cuda:0'), 'cuda:1' + - Family device: 'cuda', 'xpu', 'mps', 'cpu' - 'all' for every device in the pool Returns a context manager. """ + # Family string shortcut (e.g., "cuda" or "all") if isinstance(device, str): if device == "all": pairs = [(k, self._locks[k]) for k in self._ordered_keys] @@ -568,9 +648,11 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) + # torch.device / int / 'cuda:0' etc. dev = _coerce_device(device) key = self._key(dev) + # Family device with index=None -> all devices of that type if dev.index is None: fam = dev.type keys = [k for k in self._ordered_keys if k.startswith(fam)] @@ -579,6 +661,7 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) + # Concrete device lk = self._locks.get(key) if lk is None: raise ValueError(f"Unknown device for pool: {dev}") @@ -597,6 +680,9 @@ def lock(self, devices: Optional[Iterable[DeviceLike]] = None): return _LockGroup(pairs) # --------------- Public Wait API --------------- + # The wait() primitive blocks until inflight work drains for a scope. + # With lock=True, it returns a context manager that first drains, then + # acquires writer locks—handy for "drain-and-free" sequences. def wait( self, @@ -614,10 +700,8 @@ def wait( - torch.device or iterable of the above lock: - - False (default): block until drained, then return None. - - True: return a context manager that **waits for drain AND acquires - exclusive write locks** over the scope. Usage: - `with pool.wait("cuda", lock=True): ...` + - False: block until drained. + - True: return a context manager that drains then acquires writer locks. 
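
        Illustrative usage:
            pool.wait('cuda:0')                  # block until one device drains
            with pool.wait('cuda', lock=True):   # drain, then hold writer locks
                ...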
""" keys = self._resolve_scope_to_keys(scope) @@ -625,6 +709,10 @@ def wait( pairs = [(k, self._locks[k]) for k in sorted(keys)] class _WaitThenLock(contextlib.AbstractContextManager): + """ + Drain inflight for the given keys, then acquire writer locks + in canonical order; release on exit. + """ def __init__(self, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock]], keys_local: List[str]): self._outer = outer self._pairs = pairs_local @@ -648,7 +736,7 @@ def __exit__(self, exc_type, exc, tb): return _WaitThenLock(self, pairs, keys) - # Pure wait without lock: wait for inflight to reach zero for each key. + # Simple drain (no locks) if DEBUG_ON: log.debug(f"wait(lock=False) drain start: keys={keys}") for k in keys: cv = self._inflight_cv[k] @@ -662,7 +750,9 @@ def __exit__(self, exc_type, exc, tb): # --------------- Public Stats API --------------- def stats(self) -> Dict[str, Any]: - """Return counters snapshot: per-device and global.""" + """ + Return a snapshot of counters. Use under tests or ad-hoc diagnostics. + """ with self._stats_lock: return { "per_device": dict(self._per_device_done), @@ -671,11 +761,17 @@ def stats(self) -> Dict[str, Any]: } def device_completed(self, device: DeviceLike) -> int: + """ + Convenience accessor for per-device completed count (atomic snapshot). + """ key = self._key(_coerce_device(device)) with self._stats_lock: return int(self._per_device_done.get(key, 0)) def total_completed(self) -> int: + """ + Convenience accessor for global completed count (atomic snapshot). + """ with self._stats_lock: return int(self._total_done) @@ -687,8 +783,9 @@ def _key(self, dev: torch.device) -> str: def _pick_worker(self, key: str) -> _DeviceWorker: """ - Simple round-robin over the current group. - If the group is empty (unlikely), spawn one to recover. + Round-robin selection across available workers for a device key. + If no workers exist (should not happen under normal init), we + spawn one lazily for robustness. """ with self._dispatch_lock: group = self._worker_groups.get(key) @@ -713,6 +810,12 @@ def _pick_worker(self, key: str) -> _DeviceWorker: return w def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) -> int: + """ + Resolve worker count from policy table: + - exact key (e.g. 'cuda:0') overrides + - family-per (e.g. 'cuda:per') applies to all indices + - family singletons ('cpu', 'mps') apply to that single device + """ key = self._key(dev) if key in table: return int(table[key]) @@ -724,6 +827,11 @@ def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) return 1 def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: + """ + Normalize a scope specification into a sorted list of device keys. + Accepts strings ('cuda', 'cuda:0', 'all'), ints (device indices), + and torch.device objects. + """ keys: List[str] = [] for s in scope: if isinstance(s, str): @@ -748,6 +856,9 @@ def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: return keys def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None) -> List[str]: + """ + Helper for wait()/lock(): expand a scope into concrete device keys. 
+ """ if scope is None or (isinstance(scope, str) and scope == "all"): return list(self._ordered_keys) if isinstance(scope, (str, torch.device, int)): @@ -757,12 +868,19 @@ def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable # ---- inflight & completion accounting ---- def _mark_scheduled(self, key: str) -> None: + """ + Increment inflight for a device key and emit a debug breadcrumb if enabled. + """ cv = self._inflight_cv[key] with cv: self._inflight[key] = self._inflight.get(key, 0) + 1 if DEBUG_ON: log.debug(f"inflight[{key}] ++ -> {self._inflight[key]}") def _mark_finished(self, key: str) -> None: + """ + Decrement inflight for a device key, clamp at zero on underflow, and + notify waiters when the device drains. + """ cv = self._inflight_cv[key] with cv: new_val = self._inflight.get(key, 0) - 1 @@ -775,7 +893,10 @@ def _mark_finished(self, key: str) -> None: cv.notify_all() def _on_task_finished(self, key: str) -> None: - # inflight decrement + counters + potential GC trigger + """ + Called at the end of every task (success or failure). Updates counters + and signals the janitor if the per-device threshold is reached. + """ self._mark_finished(key) trigger_gc = False @@ -795,7 +916,10 @@ def _on_task_finished(self, key: str) -> None: # ---- ANSI table rendering for GC diagnostics ---- def _ansi_table(self, headers: List[str], rows: List[List[str]]) -> str: - """Render a simple ANSI/ASCII table with bold headers.""" + """ + Render a simple ANSI/ASCII table with bold headers. Used only for + human-readable diagnostics; not used in hot paths. + """ widths = [len(h) for h in headers] for r in rows: for i, cell in enumerate(r): @@ -829,7 +953,9 @@ def format_row(cols: List[str]): return "\n".join(lines) def _collect_state_snapshot(self) -> Dict[str, Any]: - """Safely collect a snapshot of pool state for diagnostics.""" + """ + Safely collect a snapshot of pool state for diagnostics and GC decisions. + """ with self._stats_lock: per_done = dict(self._per_device_done) total_done = int(self._total_done) @@ -864,7 +990,9 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: return snap def _render_gc_table(self, snap: Dict[str, Any]) -> str: - """Build the ANSI table for the current snapshot.""" + """ + Pretty-print a GC state table; used only when someone wants to log it. + """ headers = [ "Device", "Type", "Index", "Workers", "Inflight", "Done", "Threshold", "NextGC", "Accel" @@ -900,11 +1028,15 @@ def _render_gc_table(self, snap: Dict[str, Any]) -> str: return table_main + "\n" + table_totals # ---- janitor (per-device empty-cache under exclusive writer lock) ---- + # The janitor runs in the background. When a device completes N tasks, a trigger + # is set. The janitor debounces triggers, takes a snapshot, and if at least one + # accelerator device progressed by >= N tasks since the last pass, it iterates + # devices, acquiring each device's writer lock before calling empty_cache(). def _synchronize_all(self): """ - Ensure devices are idle before empty_cache() to avoid races with outstanding kernels. - Iterate discovered devices and only guard on attribute presence for backend sync. + Optionally ensure devices are idle before empty_cache() to avoid races with + outstanding kernels. Keeping this disabled by default for performance. 
""" # CUDA for key in self._ordered_keys: @@ -913,7 +1045,6 @@ def _synchronize_all(self): continue with torch.cuda.device(dev.index): torch.cuda.synchronize() - # XPU for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -922,7 +1053,6 @@ def _synchronize_all(self): if hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"): with torch.xpu.device(dev.index): torch.xpu.synchronize() - # MPS has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) if has_mps_device and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): @@ -947,52 +1077,59 @@ def _should_run_gc_from_snapshot(self, snap: Dict[str, Any]) -> bool: return False def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: - """Record 'done' counters as of a GC pass.""" + """ + Record 'done' counters as of a GC pass to require fresh progress + before a subsequent pass is allowed. + """ for k in snap_after["devices"]: self._last_gc_done_per_device[k] = snap_after["per_done"].get(k, 0) def _janitor_loop(self): - # Use timeouts so we can honor stop_event without relying on a final trigger. - WAIT_TIMEOUT = 0.1 # seconds + """ + Main janitor loop: + - Waits on a trigger (with short timeout to honor shutdowns promptly). + - Debounces additional triggers for a brief window. + - Takes a snapshot and decides whether to run. + - For each accelerator device, acquires its writer lock and calls + empty_cache() using the LIVE attribute if callable, otherwise the + HARD COPY captured at import time. + """ + WAIT_TIMEOUT = 0.1 while True: if DEBUG_ON: log.debug("DP-Janitor: waiting for trigger…") - # Exit promptly if shutdown requested before/while waiting. if self._stop_event.is_set(): if DEBUG_ON: log.debug("DP-Janitor: stop event set before wait; exiting") break - # Wait with a timeout so we can re-check stop_event periodically. triggered = self._gc_event.wait(timeout=WAIT_TIMEOUT) if not triggered: - # Timed out; loop to check stop again. continue - # Clear the event and re-check stop before doing anything else. self._gc_event.clear() if self._stop_event.is_set(): if DEBUG_ON: log.debug("DP-Janitor: stop event set after trigger; exiting") break - # Debounce to coalesce bursty triggers; keep draining the event during the window. + # Debounce window: absorb additional triggers before deciding. 
             if self._gc_debounce_s > 0:
                 t_end = time.time() + self._gc_debounce_s
                 if DEBUG_ON: log.debug(f"DP-Janitor: debounce window start ({self._gc_debounce_s:.3f}s)")
                 while time.time() < t_end:
-                    # If stopping during debounce, honor it immediately
                     if self._stop_event.is_set():
                         if DEBUG_ON: log.debug("DP-Janitor: stop during debounce; exiting")
                         return
-                    # Drain subsequent triggers within the window
                     self._gc_event.wait(timeout=max(0.0, t_end - time.time()))
                     self._gc_event.clear()
                 if DEBUG_ON: log.debug("DP-Janitor: debounce window end")

+            # Snapshot & decision
             try:
                 pre = self._collect_state_snapshot()
                 if DEBUG_ON:
                     log.debug(f"DP-Janitor: pre-snapshot taken: total_done={pre['total_done']}, threshold={pre['threshold']}, inflight={pre['inflight']}")
                     log.debug("GC trigger received; evaluating whether to run…")
             except Exception as e:
+                # Snapshot/logging failures must never crash the janitor (rare path).
                 try:
                     log.warn(f"Failed to render GC pre-snapshot: {e!r}")
                 except Exception:
                     pass
@@ -1017,11 +1154,10 @@ def _janitor_loop(self):
                 continue

             t0 = time.time()
-
-            # Optional sync (disabled by default as it can be costly)
+            # Optionally synchronize devices; often too slow to be worthwhile:
             # self._synchronize_all()

-            # GC each device independently under its writer lock.
+            # Per-device exclusive: acquire write lock, then call empty_cache().
             for key in sorted(self._ordered_keys):
                 dev = self._devices_by_key[key]
                 if dev.type not in ("cuda", "xpu", "mps"):
@@ -1031,26 +1167,41 @@ def _janitor_loop(self):
                 if DEBUG_ON: log.debug(f"DP-Janitor: attempting writer lock for {key}")
                 with lk.writer():
                     if DEBUG_ON: log.debug(f"DP-Janitor: acquired writer lock for {key}")
-                    if dev.type == "cuda" and TORCH_CUDA_EMPTY_CACHE is not None:
-                        if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(cuda) begin on {key}")
-                        with torch.cuda.device(dev.index):
-                            TORCH_CUDA_EMPTY_CACHE()
-                        if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(cuda) done on {key}")
-                    elif dev.type == "xpu" and TORCH_XPU_EMPTY_CACHE is not None:
-                        if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(xpu) begin on {key}")
-                        with torch.xpu.device(dev.index):
-                            TORCH_XPU_EMPTY_CACHE()
-                        if DEBUG_ON: log.debug(f"DP-Janitor: empty_cache(xpu) done on {key}")
-                    elif dev.type == "mps" and TORCH_MPS_EMPTY_CACHE is not None:
-                        if DEBUG_ON: log.debug("DP-Janitor: empty_cache(mps) begin")
-                        TORCH_MPS_EMPTY_CACHE()
-                        if DEBUG_ON: log.debug("DP-Janitor: empty_cache(mps) done")
-            t1 = time.time()
+                    if dev.type == "cuda":
+                        live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None
+                        use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE
+                        if DEBUG_ON:
+                            src = "live" if use_fn is live else "hardcopy"
+                            log.debug(f"DP-Janitor: empty_cache(cuda) using {src} on {key}")
+                        if use_fn is not None:
+                            with torch.cuda.device(dev.index):
+                                use_fn()
+
+                    elif dev.type == "xpu":
+                        live = getattr(torch.xpu, "empty_cache", None) if hasattr(torch, "xpu") else None
+                        use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE
+                        if DEBUG_ON:
+                            src = "live" if use_fn is live else "hardcopy"
+                            log.debug(f"DP-Janitor: empty_cache(xpu) using {src} on {key}")
+                        if use_fn is not None:
+                            with torch.xpu.device(dev.index):
+                                use_fn()
+
+                    elif dev.type == "mps":
+                        live = getattr(torch.mps, "empty_cache", None) if hasattr(torch, "mps") else None
+                        use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE
+                        if DEBUG_ON:
+                            src = "live" if use_fn is live else "hardcopy"
+                            log.debug(f"DP-Janitor: empty_cache(mps) using {src}")
+                        if use_fn is not None:
+                            use_fn()
+
+            t1 = time.time()
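+            # t0/t1 bracket the whole per-device pass; t1 also becomes
+            # _last_gc_ts in the post-pass accounting below.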
             self._gc_passes += 1
             self._last_gc_ts = t1

+            # Post-pass accounting & watermarks.
             try:
                 post = self._collect_state_snapshot()
                 self._update_gc_watermarks(post)
@@ -1062,8 +1213,14 @@ def _janitor_loop(self):
             except Exception:
                 pass

-    # Legacy helper retained for compatibility (not used by janitor).
+    # Legacy helper (not used by janitor). Kept for compatibility with any
+    # external callers that previously expected a "clear everything" helper.
     def _empty_all_caches(self):
+        """
+        Call the captured originals if available. This does not consult the live
+        attribute and therefore does not pick up monkeypatching. Prefer the janitor’s
+        per-device logic for production use.
+        """
         if TORCH_CUDA_EMPTY_CACHE is not None:
             for key in self._ordered_keys:
                 dev = self._devices_by_key[key]
@@ -1071,7 +1228,6 @@ def _empty_all_caches(self):
                     continue
                 with torch.cuda.device(dev.index):
                     TORCH_CUDA_EMPTY_CACHE()
-
         if TORCH_XPU_EMPTY_CACHE is not None:
             for key in self._ordered_keys:
                 dev = self._devices_by_key[key]
@@ -1079,7 +1235,6 @@ def _empty_all_caches(self):
                     continue
                 with torch.xpu.device(dev.index):
                     TORCH_XPU_EMPTY_CACHE()
-
         if TORCH_MPS_EMPTY_CACHE is not None:
             has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys)
             if has_mps_device:

From 4c2a139f01c3e32e06a5e55e3c58d4d1ba4249c6 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 29 Sep 2025 23:33:49 +0000
Subject: [PATCH 7/7] fix unsafe shutdown

Signed-off-by: Qubitium

---
 gptqmodel/utils/threadx.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py
index 928df3be5..04bbdd8ac 100644
--- a/gptqmodel/utils/threadx.py
+++ b/gptqmodel/utils/threadx.py
@@ -599,20 +599,35 @@ def do(
     def shutdown(self, wait: bool = True):
         """
         Gracefully stop all workers and janitor.
+
+        IMPORTANT: We snapshot groups before stopping/joining to avoid mutating
+        the lists while iterating (workers remove themselves on exit).
         """
         self._stop_event.set()
         self._gc_event.set()  # wake janitor if waiting
+
+        # Take stable snapshots under the dispatch lock.
+        with self._dispatch_lock:
+            group_snapshots: Dict[str, List[_DeviceWorker]] = {
+                key: list(group) for key, group in self._worker_groups.items()
+            }
+
+        # Stop janitor first so it won't grab locks while workers wind down.
         if self._janitor is not None and wait:
             if DEBUG_ON: log.debug("Joining DP-Janitor thread…")
             self._janitor.join()
-        for group in self._worker_groups.values():
-            for w in group:
+        # Issue stop to every worker from the snapshots (no mutation hazards).
+        for snapshot in group_snapshots.values():
+            for w in snapshot:
                 w.stop()
+
+        # Join everyone if requested.
         if wait:
-            for group in self._worker_groups.values():
-                for w in group:
+            for snapshot in group_snapshots.values():
+                for w in snapshot:
                     w.join()
+
+        if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete")

 # --------------- Public Lock API ---------------
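
Usage sketch (illustrative only; the constructor kwargs and device strings are
assumptions inferred from this series, while submit/do/wait/shutdown semantics
are those shown in the hunks above):

    import torch
    from gptqmodel.utils.threadx import DeviceThreadPool

    pool = DeviceThreadPool(devices=[torch.device("cuda", 0)])  # kwargs assumed

    def double(t: torch.Tensor) -> torch.Tensor:
        return t * 2  # runs on the worker thread pinned to cuda:0

    fut = pool.submit("cuda:0", double, torch.ones(4, device="cuda:0"))  # async
    out = pool.do("cuda:0", double, torch.ones(4, device="cuda:0"))      # sync
    assert fut.result().equal(out)

    pool.wait("cuda")                     # drain in-flight CUDA work
    with pool.wait("cuda:0", lock=True):  # drain, then hold the writer lock
        pass                              # e.g. safely swap shared buffers
    pool.shutdown(wait=True)              # snapshot-based stop/join (PATCH 7/7)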