diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a827ad48c..04bbdd8ac 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -6,6 +6,7 @@ from __future__ import annotations import contextlib +import os import queue import threading import time @@ -16,14 +17,19 @@ from ..utils.logger import setup_logger - log = setup_logger() +# Debug logging is very chatty and can alter timings subtly in tests. +# We gate all extra diagnostics behind the DEBUG env (1/true/yes/on). +DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on") +# DeviceLike allows ergonomic call sites: 'cuda:0', 0, torch.device('cuda', 0), etc. DeviceLike = Union[str, int, torch.device] # --------------------------- Backend availability helpers --------------------------- +# We keep these helpers small and side-effect free—only feature checks—so we can +# query once and rely on final results without redundant availability if-ladders. def _mps_available() -> bool: return ( @@ -34,6 +40,11 @@ def _mps_available() -> bool: # --- HARD COPIES of original empty_cache callables (never auto-switched) --- +# IMPORTANT: Do NOT “optimize” these by directly calling torch.*.empty_cache. +# We intentionally capture a snapshot of the original functions to defend against +# later code mutating those attributes to a no-op. The janitor will prefer the +# *live* attribute if callable (so monkeypatching works), but falls back to these +# hard copies if the live attr is missing or non-callable. TORCH_CUDA_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_XPU_EMPTY_CACHE: Optional[Callable[[], None]] = None TORCH_MPS_EMPTY_CACHE: Optional[Callable[[], None]] = None @@ -43,6 +54,7 @@ def _mps_available() -> bool: if TORCH_CUDA_EMPTY_CACHE is not None and not callable(TORCH_CUDA_EMPTY_CACHE): TORCH_CUDA_EMPTY_CACHE = None except Exception: + # If introspection fails, we keep the hard copy as None. TORCH_CUDA_EMPTY_CACHE = None try: @@ -58,10 +70,15 @@ def _mps_available() -> bool: TORCH_MPS_EMPTY_CACHE = None except Exception: TORCH_MPS_EMPTY_CACHE = None -# ------------------------------------------------------------------------------- +# --------------------------- Device coercion & context helpers --------------------------- + def _coerce_device(d: DeviceLike) -> torch.device: + """ + Convert a DeviceLike into a concrete torch.device. For integers, we + interpret as accelerator indices if present, otherwise map to CPU/MPS. + """ if isinstance(d, torch.device): return d if isinstance(d, int): @@ -72,12 +89,17 @@ def _coerce_device(d: DeviceLike) -> torch.device: if _mps_available(): return torch.device("mps") return torch.device("cpu") + # Accept strings like 'cuda:0', 'xpu:1', 'cpu', 'mps' return torch.device(d) @contextlib.contextmanager def _device_ctx(dev: torch.device): - """Set the caller thread’s current device for CUDA/XPU so library handles match.""" + """ + Set the caller thread’s *current* device while running a task so handles/streams + line up correctly. For CUDA/XPU we set the per-thread current device; CPU/MPS + do not require pinning here. + """ if dev.type == "cuda": with torch.cuda.device(dev.index): yield @@ -89,7 +111,10 @@ def _device_ctx(dev: torch.device): def _activate_thread_device(dev: torch.device): - """Pin the worker thread to the device.""" + """ + Pin the worker thread to its device once, before entering its main loop. + CUDA/XPU require per-thread device activation for correct handle usage. 
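# Editor's sketch, not part of this patch: how the _coerce_device helper defined
# above maps the DeviceLike forms, assuming the module is importable as
# gptqmodel.utils.threadx. The integer form depends on which backends exist on
# the host (cuda, then xpu, then mps, else cpu).
import torch
from gptqmodel.utils.threadx import _coerce_device

assert _coerce_device(torch.device("cpu")) == torch.device("cpu")  # passthrough
assert _coerce_device("cuda:1") == torch.device("cuda", 1)         # string form
dev = _coerce_device(0)   # int index: first available accelerator family, else cpu
print(dev)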
+ """ if dev.type == "cuda": torch.cuda.set_device(dev.index) elif dev.type == "xpu" and hasattr(torch, "xpu"): @@ -98,6 +123,9 @@ def _activate_thread_device(dev: torch.device): # --------------------------- Read-Write Lock (writer-preference) --------------------------- +# We implement a writer-preference RWLock. Multiple readers may hold the lock, +# but when a writer is waiting we block new readers, ensuring GC (writer) can +# eventually acquire exclusivity even under task pressure. class _RWLock: """ @@ -119,7 +147,7 @@ def __init__(self): def acquire_write(self): me = threading.get_ident() with self._cond: - if self._writer == me: # re-entrant + if self._writer == me: # Re-entrant writer self._writer_depth += 1 return self._writers_waiting += 1 @@ -153,10 +181,12 @@ def writer(self): def acquire_read(self): me = threading.get_ident() with self._cond: - # writer can re-enter as reader + # The writer may re-enter as a reader; this keeps invariants simple + # for code that wants to read while already holding write. if self._writer == me: self._readers += 1 return + # If a writer is waiting, block new readers to give it priority. while self._writer is not None or self._writers_waiting > 0: self._cond.wait() self._readers += 1 @@ -179,33 +209,45 @@ def reader(self): class _LockGroup(contextlib.AbstractContextManager): - """Acquire multiple device write locks in deterministic order to avoid deadlocks.""" + """ + Acquire multiple device **write** locks in deterministic order to avoid deadlocks. + Helpful for GC passes or any multi-device exclusive operation. + """ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs def __enter__(self): - for _, lk in self._pairs: + for name, lk in self._pairs: + if DEBUG_ON: log.debug(f"_LockGroup: acquiring write lock for {name}") lk.acquire_write() + if DEBUG_ON: log.debug(f"_LockGroup: acquired write lock for {name}") return self def __exit__(self, exc_type, exc, tb): - for _, lk in reversed(self._pairs): + for name, lk in reversed(self._pairs): + if DEBUG_ON: log.debug(f"_LockGroup: releasing write lock for {name}") lk.release_write() return False class _ReadLockGroup(contextlib.AbstractContextManager): - """Acquire multiple device read locks in deterministic order.""" + """ + Acquire multiple device **read** locks in deterministic order. + Useful for multi-device snapshots or read-only, consistent views. + """ def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): self._pairs = ordered_pairs def __enter__(self): - for _, lk in self._pairs: + for name, lk in self._pairs: + if DEBUG_ON: log.debug(f"_ReadLockGroup: acquiring read lock for {name}") lk.acquire_read() + if DEBUG_ON: log.debug(f"_ReadLockGroup: acquired read lock for {name}") return self def __exit__(self, exc_type, exc, tb): - for _, lk in reversed(self._pairs): + for name, lk in reversed(self._pairs): + if DEBUG_ON: log.debug(f"_ReadLockGroup: releasing read lock for {name}") lk.release_read() return False @@ -218,6 +260,7 @@ class _WaitAndLock(contextlib.AbstractContextManager): On exit: releases locks. """ def __init__(self, pairs: List[tuple[str, _RWLock]]): + self._pairs = pairs self._group = _LockGroup(pairs) def __enter__(self): @@ -228,81 +271,85 @@ def __exit__(self, exc_type, exc, tb): # --------------------------- Worker Thread --------------------------- +# Each worker is bound to a specific device and runs a single thread. 
Tasks are +# executed under the device’s read lock; GC acquires the writer lock to keep +# memory management steps from interleaving with tasks. class _DeviceWorker: """ Single worker thread bound to one device. + Queue entries: (is_task: bool, fn, args, kwargs, future) - Supports configurable lifecycle: after N tasks, stop accepting new work, - drain its queue, and exit; the pool will spawn a replacement. + - is_task=False is a sentinel to exit the thread loop. + - Tasks run within a device-scoped reader lock to prevent interleaving + with GC passes (which need a write lock). """ def __init__( self, device: torch.device, rwlock: _RWLock, on_task_finished: Callable[[str], None], - on_retire_request: Callable[[str, _DeviceWorker], None], - on_worker_exit: Callable[[str, _DeviceWorker], None], + on_worker_exit: Callable[[str, "_DeviceWorker"], None], name: Optional[str] = None, inference_mode: bool = False, - lifecycle_calls: int = 50, ): self.device = device self.rwlock = rwlock self._on_task_finished = on_task_finished - self._on_retire_request = on_retire_request self._on_worker_exit = on_worker_exit - self._lifecycle_limit = max(0, int(lifecycle_calls)) # 0 disables rotation - self._tasks_since_spawn = 0 self.key = f"{device.type}:{device.index}" if device.index is not None else device.type self.name = name or f"DPWorker-{self.key}" self._q: "queue.Queue[Tuple[bool, Callable[..., Any], tuple, dict, Future]]" = queue.Queue() self._stop = threading.Event() - self._retire_requested = False - self._accepting = True self._inference_mode = inference_mode self._thread = threading.Thread(target=self._run, name=self.name, daemon=True) self._thread.start() + if DEBUG_ON: log.debug(f"Spawned worker '{self.name}' for {self.key}") - # --- lifecycle / accept state --- - def is_accepting(self) -> bool: - return self._accepting and not self._stop.is_set() - - def request_stop(self): - self._stop.set() - self._q.put((False, lambda: None, (), {}, Future())) - - # --- public API for pool --- def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: + """ + Enqueue a callable and return a Future that resolves with its result/exception. + """ fut = Future() self._q.put((True, fn, args, kwargs, fut)) + if DEBUG_ON: log.debug(f"{self.name}: task enqueued; qsize={self._q.qsize()}") return fut def stop(self): - self.request_stop() + """ + Request thread exit by setting a sentinel work item. The run loop exits + ASAP after receiving it. + """ + self._stop.set() + self._q.put((False, lambda: None, (), {}, Future())) + if DEBUG_ON: log.debug(f"{self.name}: stop requested; sentinel queued") def join(self): + """ + Join the worker thread; for graceful shutdowns and tests. + """ + if DEBUG_ON: log.debug(f"{self.name}: joining thread") self._thread.join() - # --- internal main loop --- def _run(self): + """ + Main loop: pull tasks, set device context, execute, mark completion, and + fulfill or fail the future. Completion is accounted BEFORE resolving the + future to make stats() deterministic even under test interleavings. 
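# Editor's sketch, not part of this patch: the queue / sentinel / Future pattern
# the worker loop uses, reduced to the standard library so it runs anywhere. The
# completion counter is bumped before the Future resolves, mirroring the ordering
# the docstring above calls out.
import queue
import threading
from concurrent.futures import Future

q = queue.Queue()
completed = 0

def worker_loop():
    global completed
    while True:
        is_task, fn, fut = q.get()
        if not is_task:               # sentinel entry -> exit the loop
            break
        try:
            result = fn()
            completed += 1            # account completion first...
            fut.set_result(result)    # ...then resolve the Future
        except BaseException as exc:
            completed += 1
            fut.set_exception(exc)

t = threading.Thread(target=worker_loop, daemon=True)
t.start()
fut = Future()
q.put((True, lambda: 21 * 2, fut))
print(fut.result())                   # 42
q.put((False, None, Future()))        # sentinel: ask the loop to exit
t.join()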
+ """ _activate_thread_device(self.device) maybe_inference = torch.inference_mode() if self._inference_mode else contextlib.nullcontext() with maybe_inference: while not self._stop.is_set(): - # If we're retiring and nothing is queued, exit gracefully - if self._retire_requested and self._q.empty(): - break - try: - is_task, fn, args, kwargs, fut = self._q.get(timeout=0.05) - except queue.Empty: - continue + is_task, fn, args, kwargs, fut = self._q.get() try: if not is_task: - break # sentinel -> exit - # Tasks take a read lock so GC's writer lock can't interleave + if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") + break + if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") + # Tasks take a **read** lock so janitor's write lock can't interleave with self.rwlock.reader(): stream = kwargs.pop("_cuda_stream", None) with _device_ctx(self.device): @@ -311,32 +358,32 @@ def _run(self): result = fn(*args, **kwargs) else: result = fn(*args, **kwargs) + # Counters must be updated before resolving futures to prevent + # tests reading stats mid-transition and seeing stale totals. + self._on_task_finished(self.key) if not fut.cancelled(): fut.set_result(result) + if DEBUG_ON: log.debug(f"{self.name}: task done") except BaseException as exc: + # Even on exception we must decrement inflight and update totals. + self._on_task_finished(self.key) if not fut.cancelled(): fut.set_exception(exc) + if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") finally: - if is_task: - self._tasks_since_spawn += 1 - self._on_task_finished(self.key) - # Lifecycle check: once we hit the limit, mark retiring (stop accepting) - if self._lifecycle_limit > 0 and self._tasks_since_spawn >= self._lifecycle_limit: - if not self._retire_requested: - self._retire_requested = True - self._accepting = False - # Notify pool to spawn a replacement now - self._on_retire_request(self.key, self) self._q.task_done() - - # Thread is exiting; notify pool for cleanup try: self._on_worker_exit(self.key, self) finally: - pass + if DEBUG_ON: log.debug(f"{self.name}: exited") # --------------------------- Public Pool --------------------------- +# - Builds workers per device with per-device RWLocks +# - Tracks inflight counts (with condition vars) and completion counters +# - Provides wait() with optional lock=True for exclusive operations +# - Runs a janitor background thread that performs periodic empty_cache() under +# exclusive writer locks, coalescing triggers via a short debounce window. class DeviceThreadPool: """ @@ -348,9 +395,8 @@ class DeviceThreadPool: - Per-device RWLocks + global lock and family/all read-locks. - wait(scope, lock=False/True) to drain tasks (optionally with exclusive locks). - Per-device/global completed counters and in-flight counters. - - Janitor: triggers empty-cache after N completions on accelerator devices, under a global lock. - - GC diagnostics helpers. - - Worker lifecycle rotation: after N tasks (default 50), workers retire and are replaced. + - Janitor: triggers empty-cache after N completions on accelerator devices (per-device lock). + - GC diagnostics helpers (snapshots and ANSI tables). """ def __init__( @@ -365,7 +411,6 @@ def __init__( empty_cache_every_n: int = 50, # <=0 disables janitor workers: Optional[Dict[str, int]] = None, # e.g. 
{'cpu':4, 'cuda:per':1, 'cuda:0':3} gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC - worker_lifecycle_calls: int = 50, # <=0 disables lifecycle rotation ): """ Args: @@ -379,7 +424,6 @@ def __init__( - 'xpu:': N -> override for specific XPU index Unspecified devices default to 1 worker each. gc_debounce_seconds: short wait to coalesce multiple triggers. - worker_lifecycle_calls: number of tasks a worker handles before retiring (0=disabled). """ if devices is None: discovered: List[torch.device] = [] @@ -395,12 +439,13 @@ def __init__( discovered.append(torch.device("cpu")) devices = discovered + # Locks and device registry (keyed by "type[:index]" strings like 'cuda:0'). self._locks: Dict[str, _RWLock] = {} self._devices_by_key: Dict[str, torch.device] = {} - # Worker groups: key -> List[_DeviceWorker] + # Worker groups and RR dispatch bookkeeping. self._worker_groups: Dict[str, List[_DeviceWorker]] = {} - self._dispatch_rr: Dict[str, int] = {} # round-robin index per key + self._dispatch_rr: Dict[str, int] = {} self._dispatch_lock = threading.Lock() # Stats / GC / inflight control @@ -413,23 +458,22 @@ def __init__( self._stop_event = threading.Event() self._janitor: Optional[threading.Thread] = None - # in-flight (scheduled but not finished) counters + per-device CVs + # In-flight (scheduled but not finished) counters + per-device CVs. + # Each device has a condition variable to let wait() callers block + # until inflight hits zero for that device scope. self._inflight: Dict[str, int] = {} self._inflight_cv: Dict[str, threading.Condition] = {} - # GC dedupe/coalesce + # GC dedupe/coalesce: debounce window to absorb bursty triggers; + # per-device "done" watermark to skip redundant GC passes. self._gc_debounce_s = float(gc_debounce_seconds) - # per-device watermark of "done" as of last GC that actually ran self._last_gc_done_per_device: Dict[str, int] = {} - # Worker lifecycle rotation - self._worker_lifecycle_calls = int(worker_lifecycle_calls) - # Store inference mode for worker spawns self._inference_mode = bool(inference_mode) workers = workers or {} - # Build locks, inflight structs, and workers eagerly + # Eagerly build workers, locks, inflight tracking, and counters. for d in devices: dev = _coerce_device(d) if dev.type not in ("cuda", "xpu", "mps", "cpu"): @@ -453,14 +497,14 @@ def __init__( self._worker_groups[key] = group self._dispatch_rr[key] = 0 - # Canonical lock order + # A canonical ordering for multi-device lock acquisitions. self._ordered_keys = sorted(self._locks.keys()) # GC diagnostics counters self._gc_passes = 0 self._last_gc_ts: Optional[float] = None - # Start janitor if enabled and accelerators exist + # Start janitor if enabled and there exists at least one accelerator. if self._empty_cache_every_n > 0 and any( self._devices_by_key[k].type in ("cuda", "xpu", "mps") for k in self._ordered_keys ): @@ -468,37 +512,33 @@ def __init__( target=self._janitor_loop, name="DP-Janitor", daemon=True ) self._janitor.start() + if DEBUG_ON: + log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})") + else: + if DEBUG_ON: + log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)") - # --------------- Worker management (spawn/retire/cleanup) --------------- + # --------------- Worker management --------------- def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: + """ + Create and start a worker bound to the provided device. 
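# Editor's sketch, not part of this patch: building a pool with the worker policy
# table shown above. The keyword names follow this diff; the concrete numbers are
# illustrative only.
from gptqmodel.utils.threadx import DeviceThreadPool

pool = DeviceThreadPool(
    devices=None,                  # None: discover cuda/xpu/mps, else fall back to cpu
    empty_cache_every_n=50,        # janitor threshold; <= 0 disables the janitor
    workers={"cpu": 4, "cuda:per": 1, "cuda:0": 3},  # exact key beats 'cuda:per'
    gc_debounce_seconds=0.02,      # coalesce bursty GC triggers
)
print(pool.stats())                # e.g. {'per_device': {...}, ...}
pool.shutdown(wait=True)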
+ """ key = self._key(dev) - return _DeviceWorker( + w = _DeviceWorker( device=dev, rwlock=self._locks[key], on_task_finished=self._on_task_finished, - on_retire_request=self._on_worker_retire_request, on_worker_exit=self._on_worker_exit, name=name, inference_mode=self._inference_mode, - lifecycle_calls=self._worker_lifecycle_calls, ) + return w - def _on_worker_retire_request(self, key: str, worker: _DeviceWorker) -> None: + def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: """ - A worker hit its lifecycle limit. Mark it non-accepting (it already is), - and immediately spawn a replacement to maintain capacity. + Clean up worker bookkeeping after a thread exits. """ - dev = self._devices_by_key[key] - with self._dispatch_lock: - group = self._worker_groups.get(key, []) - if worker in group: - replacement = self._spawn_worker(dev, name=f"{worker.name}.r{int(time.time()*1000)}") - group.append(replacement) - self._worker_groups[key] = group - - def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: - """Cleanup finished workers from the group.""" with self._dispatch_lock: group = self._worker_groups.get(key, []) if worker in group: @@ -508,6 +548,7 @@ def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: self._dispatch_rr[key] %= len(group) else: self._dispatch_rr[key] = 0 + if DEBUG_ON: log.debug(f"Worker '{worker.name}' exited for {key}") # --------------- Public Work API --------------- @@ -530,12 +571,13 @@ def submit( if _cuda_stream is not None and dev.type != "cuda": raise ValueError("_cuda_stream is only valid for CUDA devices") - # mark in-flight before enqueue to avoid races with wait() + if DEBUG_ON: log.debug(f"submit: device={key} fn={getattr(fn, '__name__', repr(fn))}") + # Mark in-flight before enqueue to avoid races with wait(). self._mark_scheduled(key) try: return worker.submit(fn, *args, _cuda_stream=_cuda_stream, **kwargs) except BaseException: - # roll back inflight if enqueue fails (rare) + # Roll back inflight if enqueue fails (rare) self._mark_finished(key) raise @@ -548,29 +590,52 @@ def do( _cuda_stream: Optional[torch.cuda.Stream] = None, **kwargs, ) -> Any: - """Synchronously schedule work and block for the result.""" + """ + Synchronously schedule work and block for the result. + """ fut = self.submit(device, fn, *args, _cuda_stream=_cuda_stream, **kwargs) return fut.result() def shutdown(self, wait: bool = True): - """Gracefully stop all workers and janitor.""" + """ + Gracefully stop all workers and janitor. + + IMPORTANT: We snapshot groups before stopping/joining to avoid mutating + the lists while iterating (workers remove themselves on exit). + """ self._stop_event.set() - self._gc_event.set() # wake janitor + self._gc_event.set() # wake janitor if waiting + + # Take stable snapshots under the dispatch lock. + with self._dispatch_lock: + group_snapshots: Dict[str, List[_DeviceWorker]] = { + key: list(group) for key, group in self._worker_groups.items() + } + + # Stop janitor first so it won't grab locks while workers wind down. if self._janitor is not None and wait: + if DEBUG_ON: log.debug("Joining DP-Janitor thread…") self._janitor.join() - for group in self._worker_groups.values(): - for w in group: + # Issue stop to every worker from the snapshots (no mutation hazards). + for key, snapshot in group_snapshots.items(): + for w in snapshot: w.stop() + + # Join everyone if requested. 
if wait: - for group in self._worker_groups.values(): - for w in group: + for key, snapshot in group_snapshots.items(): + for w in snapshot: w.join() + if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete") + # --------------- Public Lock API --------------- def device_lock(self, device: DeviceLike): - """Exclusive lock for a single device (blocks all its workers).""" + """ + Obtain an exclusive lock for a single device (blocks all its workers). + """ dev = _coerce_device(device) key = self._key(dev) lk = self._locks.get(key) @@ -580,12 +645,13 @@ def device_lock(self, device: DeviceLike): def read_lock(self, device: DeviceLike | str): """ - Shared/read lock. Accepts: - - concrete device: torch.device('cuda:0'), 'cuda:1' - - family device: torch.device('cuda'), 'cuda', 'xpu', 'mps', 'cpu' + Obtain a read/shared lock. Accepts: + - Concrete device: torch.device('cuda:0'), 'cuda:1' + - Family device: 'cuda', 'xpu', 'mps', 'cpu' - 'all' for every device in the pool Returns a context manager. """ + # Family string shortcut (e.g., "cuda" or "all") if isinstance(device, str): if device == "all": pairs = [(k, self._locks[k]) for k in self._ordered_keys] @@ -597,9 +663,11 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) + # torch.device / int / 'cuda:0' etc. dev = _coerce_device(device) key = self._key(dev) + # Family device with index=None -> all devices of that type if dev.index is None: fam = dev.type keys = [k for k in self._ordered_keys if k.startswith(fam)] @@ -608,6 +676,7 @@ def read_lock(self, device: DeviceLike | str): pairs = [(k, self._locks[k]) for k in keys] return _ReadLockGroup(pairs) + # Concrete device lk = self._locks.get(key) if lk is None: raise ValueError(f"Unknown device for pool: {dev}") @@ -626,23 +695,79 @@ def lock(self, devices: Optional[Iterable[DeviceLike]] = None): return _LockGroup(pairs) # --------------- Public Wait API --------------- + # The wait() primitive blocks until inflight work drains for a scope. + # With lock=True, it returns a context manager that first drains, then + # acquires writer locks—handy for "drain-and-free" sequences. - def wait(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> None | _WaitAndLock: + def wait( + self, + scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None, + *, + lock: bool = False, + ) -> None | _WaitAndLock: """ Wait until in-flight tasks for `scope` drain to zero. + + scope: + - None or 'all' -> all devices + - 'cuda' | 'xpu' | 'mps' | 'cpu' -> all devices of that type + - 'cuda:0' | 'xpu:1' -> specific device key + - torch.device or iterable of the above + + lock: + - False: block until drained. + - True: return a context manager that drains then acquires writer locks. """ keys = self._resolve_scope_to_keys(scope) + + if lock: + pairs = [(k, self._locks[k]) for k in sorted(keys)] + + class _WaitThenLock(contextlib.AbstractContextManager): + """ + Drain inflight for the given keys, then acquire writer locks + in canonical order; release on exit. 
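# Editor's sketch, not part of this patch: the lock scopes described above, on a
# cpu-only pool so the snippet runs anywhere. read_lock() is the shared form;
# device_lock() and lock() are the exclusive (writer) forms.
from gptqmodel.utils.threadx import DeviceThreadPool

pool = DeviceThreadPool(devices=["cpu"])

with pool.read_lock("cpu"):     # shared: workers keep executing tasks
    snapshot = pool.stats()

with pool.device_lock("cpu"):   # exclusive: blocks this device's workers
    pass                        # e.g. swap weights or resize buffers here

with pool.lock():               # exclusive across every device in the pool
    pass

pool.shutdown(wait=True)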
+ """ + def __init__(self, outer: DeviceThreadPool, pairs_local: List[tuple[str, _RWLock]], keys_local: List[str]): + self._outer = outer + self._pairs = pairs_local + self._keys = keys_local + self._group = _LockGroup(pairs_local) + + def __enter__(self): + if DEBUG_ON: log.debug(f"wait(lock=True) drain start: keys={self._keys}") + for kk in self._keys: + cv = self._outer._inflight_cv[kk] + with cv: + while self._outer._inflight[kk] > 0: + if DEBUG_ON: log.debug(f"wait(lock=True) blocking on inflight[{kk}]={self._outer._inflight[kk]}") + cv.wait() + if DEBUG_ON: log.debug(f"wait(lock=True) acquire writer locks: keys={self._keys}") + return self._group.__enter__() + + def __exit__(self, exc_type, exc, tb): + if DEBUG_ON: log.debug(f"wait(lock=True) releasing writer locks: keys={self._keys}") + return self._group.__exit__(exc_type, exc, tb) + + return _WaitThenLock(self, pairs, keys) + + # Simple drain (no locks) + if DEBUG_ON: log.debug(f"wait(lock=False) drain start: keys={keys}") for k in keys: cv = self._inflight_cv[k] with cv: while self._inflight[k] > 0: + if DEBUG_ON: log.debug(f"wait(lock=False) blocking on inflight[{k}]={self._inflight[k]}") cv.wait() + if DEBUG_ON: log.debug(f"wait(lock=False) drain done: keys={keys}") return None # --------------- Public Stats API --------------- def stats(self) -> Dict[str, Any]: - """Return counters snapshot: per-device and global.""" + """ + Return a snapshot of counters. Use under tests or ad-hoc diagnostics. + """ with self._stats_lock: return { "per_device": dict(self._per_device_done), @@ -651,11 +776,17 @@ def stats(self) -> Dict[str, Any]: } def device_completed(self, device: DeviceLike) -> int: + """ + Convenience accessor for per-device completed count (atomic snapshot). + """ key = self._key(_coerce_device(device)) with self._stats_lock: return int(self._per_device_done.get(key, 0)) def total_completed(self) -> int: + """ + Convenience accessor for global completed count (atomic snapshot). + """ with self._stats_lock: return int(self._total_done) @@ -666,32 +797,40 @@ def _key(self, dev: torch.device) -> str: return f"{dev.type}{idx}" def _pick_worker(self, key: str) -> _DeviceWorker: - group = self._worker_groups.get(key) - if not group: - raise ValueError(f"Device {key} not part of this pool.") - + """ + Round-robin selection across available workers for a device key. + If no workers exist (should not happen under normal init), we + spawn one lazily for robustness. 
+ """ with self._dispatch_lock: + group = self._worker_groups.get(key) + if not group: + dev = self._devices_by_key[key] + w = self._spawn_worker(dev, name=f"DPWorker-{key}#0") + self._worker_groups[key] = [w] + self._dispatch_rr[key] = 0 + return w + n = len(group) if n == 0: - raise ValueError(f"No workers available for device {key}") - start = self._dispatch_rr[key] % n - idx = start - # Find the next accepting worker - for _ in range(n): - w = group[idx] - if w.is_accepting(): - self._dispatch_rr[key] = (idx + 1) % n - return w - idx = (idx + 1) % n - # If none are accepting, spawn a fresh one and use it - dev = self._devices_by_key[key] - neww = self._spawn_worker(dev, name=f"DPWorker-{key}#hot") - group.append(neww) - self._worker_groups[key] = group - self._dispatch_rr[key] = (len(group) - 1 + 1) % len(group) - return neww + dev = self._devices_by_key[key] + w = self._spawn_worker(dev, name=f"DPWorker-{key}#0") + group.append(w) + self._dispatch_rr[key] = 0 + return w + + idx = self._dispatch_rr[key] % n + w = group[idx] + self._dispatch_rr[key] = (idx + 1) % n + return w def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) -> int: + """ + Resolve worker count from policy table: + - exact key (e.g. 'cuda:0') overrides + - family-per (e.g. 'cuda:per') applies to all indices + - family singletons ('cpu', 'mps') apply to that single device + """ key = self._key(dev) if key in table: return int(table[key]) @@ -703,6 +842,11 @@ def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) return 1 def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: + """ + Normalize a scope specification into a sorted list of device keys. + Accepts strings ('cuda', 'cuda:0', 'all'), ints (device indices), + and torch.device objects. + """ keys: List[str] = [] for s in scope: if isinstance(s, str): @@ -726,7 +870,10 @@ def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: keys.append(k) return keys - def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> List[str]: + def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]] = None) -> List[str]: + """ + Helper for wait()/lock(): expand a scope into concrete device keys. + """ if scope is None or (isinstance(scope, str) and scope == "all"): return list(self._ordered_keys) if isinstance(scope, (str, torch.device, int)): @@ -736,35 +883,58 @@ def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable # ---- inflight & completion accounting ---- def _mark_scheduled(self, key: str) -> None: + """ + Increment inflight for a device key and emit a debug breadcrumb if enabled. + """ cv = self._inflight_cv[key] with cv: - self._inflight[key] += 1 + self._inflight[key] = self._inflight.get(key, 0) + 1 + if DEBUG_ON: log.debug(f"inflight[{key}] ++ -> {self._inflight[key]}") def _mark_finished(self, key: str) -> None: + """ + Decrement inflight for a device key, clamp at zero on underflow, and + notify waiters when the device drains. 
+ """ cv = self._inflight_cv[key] with cv: - self._inflight[key] -= 1 + new_val = self._inflight.get(key, 0) - 1 + if new_val < 0: + if DEBUG_ON: log.debug(f"WARNING: inflight[{key}] underflow ({new_val}); clamping to 0") + new_val = 0 + self._inflight[key] = new_val + if DEBUG_ON: log.debug(f"inflight[{key}] -- -> {self._inflight[key]}") if self._inflight[key] == 0: cv.notify_all() def _on_task_finished(self, key: str) -> None: + """ + Called at the end of every task (success or failure). Updates counters + and signals the janitor if the per-device threshold is reached. + """ self._mark_finished(key) trigger_gc = False with self._stats_lock: - self._per_device_done[key] += 1 + self._per_device_done[key] = self._per_device_done.get(key, 0) + 1 self._total_done += 1 dev_type = self._devices_by_key[key].type if self._empty_cache_every_n > 0 and dev_type in ("cuda", "xpu", "mps"): n = self._per_device_done[key] if n % self._empty_cache_every_n == 0: trigger_gc = True + if DEBUG_ON: + log.debug(f"GC trigger set by {key}: per_device_done={n} threshold={self._empty_cache_every_n} total_done={self._total_done}") if trigger_gc: self._gc_event.set() # ---- ANSI table rendering for GC diagnostics ---- def _ansi_table(self, headers: List[str], rows: List[List[str]]) -> str: + """ + Render a simple ANSI/ASCII table with bold headers. Used only for + human-readable diagnostics; not used in hot paths. + """ widths = [len(h) for h in headers] for r in rows: for i, cell in enumerate(r): @@ -798,6 +968,9 @@ def format_row(cols: List[str]): return "\n".join(lines) def _collect_state_snapshot(self) -> Dict[str, Any]: + """ + Safely collect a snapshot of pool state for diagnostics and GC decisions. + """ with self._stats_lock: per_done = dict(self._per_device_done) total_done = int(self._total_done) @@ -832,6 +1005,9 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: return snap def _render_gc_table(self, snap: Dict[str, Any]) -> str: + """ + Pretty-print a GC state table; used only when someone wants to log it. + """ headers = [ "Device", "Type", "Index", "Workers", "Inflight", "Done", "Threshold", "NextGC", "Accel" @@ -866,12 +1042,16 @@ def _render_gc_table(self, snap: Dict[str, Any]) -> str: table_totals = self._ansi_table(totals_headers, totals_rows) return table_main + "\n" + table_totals - # ---- janitor (global empty-cache under lock) ---- + # ---- janitor (per-device empty-cache under exclusive writer lock) ---- + # The janitor runs in the background. When a device completes N tasks, a trigger + # is set. The janitor debounces triggers, takes a snapshot, and if at least one + # accelerator device progressed by >= N tasks since the last pass, it iterates + # devices, acquiring each device's writer lock before calling empty_cache(). def _synchronize_all(self): """ - Ensure devices are idle before empty_cache() to avoid races with outstanding kernels. - Iterate discovered devices and only guard on attribute presence for backend sync. + Optionally ensure devices are idle before empty_cache() to avoid races with + outstanding kernels. Keeping this disabled by default for performance. 
""" # CUDA for key in self._ordered_keys: @@ -880,7 +1060,6 @@ def _synchronize_all(self): continue with torch.cuda.device(dev.index): torch.cuda.synchronize() - # XPU for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -889,7 +1068,6 @@ def _synchronize_all(self): if hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"): with torch.xpu.device(dev.index): torch.xpu.synchronize() - # MPS has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) if has_mps_device and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): @@ -914,29 +1092,59 @@ def _should_run_gc_from_snapshot(self, snap: Dict[str, Any]) -> bool: return False def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: - """Record 'done' counters as of a GC pass.""" + """ + Record 'done' counters as of a GC pass to require fresh progress + before a subsequent pass is allowed. + """ for k in snap_after["devices"]: self._last_gc_done_per_device[k] = snap_after["per_done"].get(k, 0) def _janitor_loop(self): + """ + Main janitor loop: + - Waits on a trigger (with short timeout to honor shutdowns promptly). + - Debounces additional triggers for a brief window. + - Takes a snapshot and decides whether to run. + - For each accelerator device, acquires its writer lock and calls + empty_cache() using the LIVE attribute if callable, otherwise the + HARD COPY captured at import time. + """ + WAIT_TIMEOUT = 0.1 while True: - self._gc_event.wait() + if DEBUG_ON: log.debug("DP-Janitor: waiting for trigger…") + if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop event set before wait; exiting") + break + + triggered = self._gc_event.wait(timeout=WAIT_TIMEOUT) + if not triggered: + continue + + self._gc_event.clear() if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop event set after trigger; exiting") break + # Debounce window: absorb additional triggers before deciding. 
if self._gc_debounce_s > 0: t_end = time.time() + self._gc_debounce_s + if DEBUG_ON: log.debug(f"DP-Janitor: debounce window start ({self._gc_debounce_s:.3f}s)") while time.time() < t_end: - self._gc_event.clear() + if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop during debounce; exiting") + return self._gc_event.wait(timeout=max(0.0, t_end - time.time())) - self._gc_event.clear() - else: - self._gc_event.clear() + self._gc_event.clear() + if DEBUG_ON: log.debug("DP-Janitor: debounce window end") + # Snapshot & decision try: pre = self._collect_state_snapshot() - log.debug("GC trigger received; acquiring global exclusive lock…") + if DEBUG_ON: + log.debug(f"DP-Janitor: pre-snapshot taken: total_done={pre['total_done']}, threshold={pre['threshold']}, inflight={pre['inflight']}") + log.debug("GC trigger received; evaluating whether to run…") except Exception as e: + # Fallback snapshot (unlikely path; logging should not crash janitor) try: log.warn(f"Failed to render GC pre-snapshot: {e!r}") except Exception: @@ -957,34 +1165,77 @@ def _janitor_loop(self): } if not self._should_run_gc_from_snapshot(pre): + if DEBUG_ON: log.debug("DP-Janitor: skip GC (no device progressed by threshold since last pass)") continue - with self.lock(): # writer lock across ALL devices - t0 = time.time() - # Optional but often expensive: - # self._synchronize_all() - self._empty_all_caches() - t1 = time.time() + t0 = time.time() + # Optionally synchronize devices; often too slow to be worthwhile: + # self._synchronize_all() - self._gc_passes += 1 - self._last_gc_ts = t1 + # Per-device exclusive: acquire write lock, then call empty_cache(). + for key in sorted(self._ordered_keys): + dev = self._devices_by_key[key] + if dev.type not in ("cuda", "xpu", "mps"): + continue + lk = self._locks[key] + if DEBUG_ON: log.debug(f"DP-Janitor: attempting writer lock for {key}") + with lk.writer(): + if DEBUG_ON: log.debug(f"DP-Janitor: acquired writer lock for {key}") + + if dev.type == "cuda": + live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None + use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE + if DEBUG_ON: + src = "live" if use_fn is live else "hardcopy" + log.debug(f"DP-Janitor: empty_cache(cuda) using {src} on {key}") + if use_fn is not None: + with torch.cuda.device(dev.index): + use_fn() + + elif dev.type == "xpu": + live = getattr(torch.xpu, "empty_cache", None) if hasattr(torch, "xpu") else None + use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE + if DEBUG_ON: + src = "live" if use_fn is live else "hardcopy" + log.debug(f"DP-Janitor: empty_cache(xpu) using {src} on {key}") + if use_fn is not None: + with torch.xpu.device(dev.index): + use_fn() + + elif dev.type == "mps": + live = getattr(torch.mps, "empty_cache", None) if hasattr(torch, "mps") else None + use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE + if DEBUG_ON: + src = "live" if use_fn is live else "hardcopy" + log.debug(f"DP-Janitor: empty_cache(mps) using {src}") + if use_fn is not None: + use_fn() + + t1 = time.time() + self._gc_passes += 1 + self._last_gc_ts = t1 + + # Post-pass accounting & watermarks. 
+ try: + post = self._collect_state_snapshot() + self._update_gc_watermarks(post) + log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") + if DEBUG_ON: log.debug(f"DP-Janitor: post-snapshot: inflight={post['inflight']} per_done={post['per_done']}") + except Exception as e: try: - post = self._collect_state_snapshot() - self._update_gc_watermarks(post) - log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") - except Exception as e: - try: - log.warn(f"Failed to render GC post-snapshot: {e!r}") - except Exception: - pass + log.warn(f"Failed to render GC post-snapshot: {e!r}") + except Exception: + pass + # Legacy helper (not used by janitor). Kept for compatibility with any + # external callers that previously expected a "clear everything" helper. def _empty_all_caches(self): """ - Call the captured originals if available; no redundant availability checks - and no try/except around empty_cache (fail loud if backend misbehaves). + Call the captured originals if available. This does not consult the live + attribute and therefore does not pick up monkeypatching. Prefer the janitor’s + per-device logic for production use. """ - # CUDA if TORCH_CUDA_EMPTY_CACHE is not None: for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -992,9 +1243,6 @@ def _empty_all_caches(self): continue with torch.cuda.device(dev.index): TORCH_CUDA_EMPTY_CACHE() - log.debug(f"cuda empty cache called on {dev.index}") - - # XPU if TORCH_XPU_EMPTY_CACHE is not None: for key in self._ordered_keys: dev = self._devices_by_key[key] @@ -1002,8 +1250,6 @@ def _empty_all_caches(self): continue with torch.xpu.device(dev.index): TORCH_XPU_EMPTY_CACHE() - - # MPS (only if this pool actually has an MPS device) if TORCH_MPS_EMPTY_CACHE is not None: has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) if has_mps_device: