Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gptqmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import os


# Package-wide debug switch: true when the DEBUG environment variable is set to
# 1/true/yes/on (case-insensitive), false when it is unset or anything else.
# NOTE(review): modules such as looper.module_looper and utils.threadx do
# `from .. import DEBUG_ON`, so this presumably has to stay above the submodule
# imports below -- confirm the import order before moving it.
DEBUG_ON = os.environ.get("DEBUG", "").lower() in {"1", "true", "yes", "on"}

from .models import GPTQModel, get_best_device
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import BACKEND
Expand Down
9 changes: 9 additions & 0 deletions gptqmodel/looper/module_looper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..models._const import SUPPORTS_MODULE_TYPES, DEVICE
from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear,
StopForward, replace_module_with_hooked_legacy)
from .. import DEBUG_ON
from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask
from ..utils.device import get_device, get_device_new
from ..utils.logger import setup_logger
Expand Down Expand Up @@ -193,18 +194,26 @@ def _select_forward_devices(self, base_device: Optional[torch.device]) -> List[t

def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch.device]) -> Dict[torch.device, torch.nn.Module]:
    """Build one independent replica of ``module`` for every entry in ``devices``.

    Non-picklable state is stripped from ``module`` before copying and is
    always put back afterwards, even if cloning raises. Each replica is a
    ``copy.deepcopy`` of ``module`` that is moved to its device, switched to
    eval mode, passed through ``_rehome_module_to_device`` (buffers only) and
    has its own non-picklable state cleared.

    When ``DEBUG_ON`` is set, the wall-clock time spent producing each
    replica is collected and emitted in a single debug log line.

    Args:
        module: The source module to replicate.
        devices: Target devices; one replica is produced per device.

    Returns:
        Mapping from each device to the replica living on it.
    """
    replicas_by_device: Dict[torch.device, torch.nn.Module] = {}
    module_label = getattr(module, "full_name", module.__class__.__name__)
    # Timing bookkeeping only exists in debug mode; None means "do not measure".
    timings = [] if DEBUG_ON else None

    # deepcopy cannot handle some attached state, so detach it for the duration.
    stripped_attrs = self._clear_non_picklable_state(module)
    try:
        for target in devices:
            started_at = time.perf_counter() if DEBUG_ON else None

            clone = copy.deepcopy(module).to(target)
            clone.eval()
            _rehome_module_to_device(clone, target, move_parameters=False, move_buffers=True)
            self._clear_non_picklable_state(clone)
            replicas_by_device[target] = clone

            if timings is None or started_at is None:
                continue
            timings.append((target, time.perf_counter() - started_at))
    finally:
        # Put the source module back the way we found it, success or failure.
        self._restore_non_picklable_state(stripped_attrs)

    if timings:
        timing_str = ", ".join([f"{str(dev)}={duration * 1000:.2f}ms" for dev, duration in timings])
        log.debug(f"ModuleLooper: deepcopy {module_label} -> {timing_str}")
    return replicas_by_device

def _clear_non_picklable_state(self, module: torch.nn.Module):
Expand Down
21 changes: 8 additions & 13 deletions gptqmodel/utils/threadx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import contextlib
import os
import queue
import threading
import time
Expand All @@ -15,14 +14,14 @@

import torch

from .. import DEBUG_ON
from ..utils.logger import setup_logger


log = setup_logger()

# Debug logging is very chatty and can alter timings subtly in tests.
# We gate all extra diagnostics behind the DEBUG env (1/true/yes/on).
DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on")

# DeviceLike allows ergonomic call sites: 'cuda:0', 0, torch.device('cuda', 0), etc.
DeviceLike = Union[str, int, torch.device]
Expand Down Expand Up @@ -1169,20 +1168,16 @@ def _janitor_loop(self):
empty_cache() using the LIVE attribute if callable, otherwise the
HARD COPY captured at import time.
"""
WAIT_TIMEOUT = 0.1
while True:
if DEBUG_ON: log.debug("DP-Janitor: waiting for trigger…")
if self._stop_event.is_set():
if DEBUG_ON: log.debug("DP-Janitor: stop event set before wait; exiting")
break

triggered = self._gc_event.wait(timeout=WAIT_TIMEOUT)
if not triggered:
continue
if DEBUG_ON:
log.debug("DP-Janitor: waiting for trigger…")

self._gc_event.wait()
self._gc_event.clear()

if self._stop_event.is_set():
if DEBUG_ON: log.debug("DP-Janitor: stop event set after trigger; exiting")
if DEBUG_ON:
log.debug("DP-Janitor: stop event set; exiting")
break

# Debounce window: absorb additional triggers before deciding.
Expand All @@ -1201,7 +1196,7 @@ def _janitor_loop(self):
while self._auto_gc_disable_count > 0 and not self._stop_event.is_set():
if DEBUG_ON:
log.debug("DP-Janitor: auto-GC disabled; waiting…")
self._auto_gc_disable_cv.wait(timeout=WAIT_TIMEOUT)
self._auto_gc_disable_cv.wait()
if self._stop_event.is_set():
if DEBUG_ON: log.debug("DP-Janitor: stop event set during auto-GC wait; exiting")
break
Expand Down