From 125c349a277647b776ef6f9becb456a01342dbf9 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 13:57:44 +0000 Subject: [PATCH 1/7] group threadx gc events --- gptqmodel/utils/threadx.py | 19 ++++++ tests/test_threadx_janitor.py | 114 ++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 tests/test_threadx_janitor.py diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 7bef93341..3a2550ca5 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -743,6 +743,8 @@ def __init__( # GC diagnostics counters self._gc_passes = 0 self._last_gc_ts: Optional[float] = None + self._gc_generation: int = 0 + self._last_consumed_gc_generation: int = 0 # Start janitor if enabled and there exists at least one accelerator. if self._empty_cache_every_n > 0 and any( @@ -1414,6 +1416,7 @@ def _on_task_finished(self, key: str) -> None: n = self._per_device_done[key] if n % self._empty_cache_every_n == 0: trigger_gc = True + self._gc_generation += 1 if DEBUG_ON: log.debug(f"GC trigger set by {key}: per_device_done={n} threshold={self._empty_cache_every_n} total_done={self._total_done}") if trigger_gc: @@ -1490,6 +1493,8 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: "total_inflight": sum(inflight.values()), "total_workers": sum(workers.values()), "gc_passes": int(self._gc_passes), + "gc_generation": int(self._gc_generation), + "gc_generation_consumed": int(self._last_consumed_gc_generation), "last_gc_ts": self._last_gc_ts, "now": time.time(), } @@ -1646,6 +1651,15 @@ def _janitor_loop(self): if DEBUG_ON: log.debug("DP-Janitor: stop event set during auto-GC wait; exiting") break + with self._stats_lock: + current_generation = self._gc_generation + last_generation = self._last_consumed_gc_generation + + if current_generation == last_generation: + if DEBUG_ON: + log.debug("DP-Janitor: trigger generation already consumed; skipping") + continue + # Snapshot & decision try: pre = self._collect_state_snapshot() @@ -1675,6 +1689,8 @@ def _janitor_loop(self): if not self._should_run_gc_from_snapshot(pre): if DEBUG_ON: log.debug("DP-Janitor: skip GC (no device progressed by threshold since last pass)") + with self._stats_lock: + self._last_consumed_gc_generation = current_generation continue t0 = time.time() @@ -1736,6 +1752,9 @@ def _janitor_loop(self): log.warn(f"Failed to render GC post-snapshot: {e!r}") except Exception: pass + finally: + with self._stats_lock: + self._last_consumed_gc_generation = self._gc_generation # Legacy helper (not used by janitor). Kept for compatibility with any # external callers that previously expected a "clear everything" helper. diff --git a/tests/test_threadx_janitor.py b/tests/test_threadx_janitor.py new file mode 100644 index 000000000..b85708f4e --- /dev/null +++ b/tests/test_threadx_janitor.py @@ -0,0 +1,114 @@ +import contextlib +import threading +import time + +import pytest +import torch + +from gptqmodel.utils import threadx as threadx_mod + + +DeviceThreadPool = threadx_mod.DeviceThreadPool + + +class _DummyLock: + @contextlib.contextmanager + def writer(self): + yield + + +def _make_pool(): + pool = DeviceThreadPool.__new__(DeviceThreadPool) + pool._gc_event = threading.Event() + pool._stop_event = threading.Event() + pool._auto_gc_disable_cv = threading.Condition() + pool._auto_gc_disable_count = 0 + pool._gc_debounce_s = 0.0 + pool._stats_lock = threading.Lock() + pool._per_device_done = {} + pool._total_done = 0 + pool._empty_cache_every_n = 3 + pool._devices_by_key = {} + pool._locks = {} + pool._ordered_keys = [] + pool._worker_groups = {} + pool._inflight = {} + pool._inflight_cv = {} + pool._last_gc_done_per_device = {} + pool._gc_passes = 0 + pool._last_gc_ts = None + pool._gc_generation = 0 + pool._last_consumed_gc_generation = 0 + pool._synchronize_all = lambda: None + pool._virtual_to_parent = {} + pool._family_keys = {} + pool._dispatch_lock = threading.Lock() + pool._warmup_lock = threading.Lock() + pool._warmup_ran_keys = set() + pool._worker_warmups = {} + pool._serial_workers = {} + pool._ordered_keys = [] + # Bind instance methods that rely on self + pool._collect_state_snapshot = DeviceThreadPool._collect_state_snapshot.__get__(pool, DeviceThreadPool) + pool._should_run_gc_from_snapshot = DeviceThreadPool._should_run_gc_from_snapshot.__get__(pool, DeviceThreadPool) + pool._update_gc_watermarks = DeviceThreadPool._update_gc_watermarks.__get__(pool, DeviceThreadPool) + pool._mark_finished = DeviceThreadPool._mark_finished.__get__(pool, DeviceThreadPool) + pool._on_task_finished = DeviceThreadPool._on_task_finished.__get__(pool, DeviceThreadPool) + return pool + + +@pytest.mark.parametrize("threshold_triggers", [3]) +def test_janitor_coalesces_pending_triggers(monkeypatch, threshold_triggers): + pool = _make_pool() + pool._empty_cache_every_n = threshold_triggers + + key = "cuda:0" + dev = torch.device("cuda", 0) + pool._devices_by_key[key] = dev + pool._locks[key] = _DummyLock() + pool._ordered_keys = [key] + pool._worker_groups[key] = [] + pool._inflight[key] = 0 + pool._inflight_cv[key] = threading.Condition() + pool._last_gc_done_per_device[key] = 0 + pool._per_device_done[key] = 0 + + calls = {"count": 0} + + def fake_empty_cache(): + calls["count"] += 1 + + monkeypatch.setattr(threadx_mod.torch.cuda, "empty_cache", fake_empty_cache, raising=False) + monkeypatch.setattr(threadx_mod, "TORCH_CUDA_EMPTY_CACHE", fake_empty_cache, raising=False) + + @contextlib.contextmanager + def fake_cuda_device(index): + yield + + monkeypatch.setattr(threadx_mod.torch.cuda, "device", fake_cuda_device, raising=False) + + # Simulate multiple threshold triggers before janitor runs. + for _ in range(threshold_triggers * 3): + pool._inflight[key] = pool._inflight.get(key, 0) + 1 + pool._on_task_finished(key) + + assert pool._gc_generation == 3 + assert pool._gc_event.is_set() + + janitor = threading.Thread(target=pool._janitor_loop, daemon=True) + janitor.start() + + start = time.time() + while calls["count"] < 1 and time.time() - start < 1.0: + time.sleep(0.01) + + # Allow janitor time to spin in case extra passes would occur. + time.sleep(0.1) + + pool._stop_event.set() + pool._gc_event.set() + janitor.join(timeout=1.0) + + assert calls["count"] == 1 + assert pool._gc_passes == 1 + assert pool._last_consumed_gc_generation == pool._gc_generation From ad536094ff9661f534f996ac9b082329d0338509 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 14:05:43 +0000 Subject: [PATCH 2/7] shorter gpu vram logs --- gptqmodel/looper/loop_processor.py | 31 ++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index 383b03239..9bc339c1b 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -400,8 +400,35 @@ def device_memory_report(self) -> str: snapshot = self._snapshot_device_memory_gib() if not snapshot: return "n/a" - parts = [f"{device_id}={value:.1f}GB" for device_id, value in snapshot.items()] - return ", ".join(parts) + + def _format_gib(value: float) -> str: + text = f"{value:.1f}" + if text.endswith(".0"): + text = text[:-2] + return f"{text}G" + + grouped: Dict[str, List[Tuple[str, float, int]]] = {} + for order, (device_id, value) in enumerate(snapshot.items()): + family, _, index = device_id.partition(":") + grouped.setdefault(family, []).append((index, value, order)) + + segments: List[str] = [] + for family, entries in grouped.items(): + if not entries: + continue + + def sort_key(item: Tuple[str, float, int]) -> Tuple[int, int]: + index, _, order = item + try: + return 0, int(index) + except (TypeError, ValueError): + return 1, order + + values = [_format_gib(value) for _, value, _ in sorted(entries, key=sort_key)] + segment = f"{family} " + ", ".join(values) + segments.append(segment) + + return " | ".join(segments) def _close_device_smi_handles(self) -> None: for handle in self._device_smi_handles.values(): From c7f9568e3c19a0e7cbde185bdbc02df787c7a1b7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 14:16:06 +0000 Subject: [PATCH 3/7] nogil patch safetensors --- gptqmodel/utils/nogil_patcher.py | 24 ++++++++++++++++++++++++ gptqmodel/utils/offload.py | 3 +++ 2 files changed, 27 insertions(+) create mode 100644 gptqmodel/utils/nogil_patcher.py diff --git a/gptqmodel/utils/nogil_patcher.py b/gptqmodel/utils/nogil_patcher.py new file mode 100644 index 000000000..140e0844b --- /dev/null +++ b/gptqmodel/utils/nogil_patcher.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Straightforward monkey patch helpers for nogil runtimes.""" + +from .safe import ThreadSafe + +_PATCHED_ATTR = "_gptqmodel_locked_save_file" + + +def patch_safetensors_save_file() -> None: + from safetensors import torch as safetensors_torch + + if getattr(safetensors_torch.save_file, _PATCHED_ATTR, False): + return + + wrapper = ThreadSafe(safetensors_torch).save_file + setattr(wrapper, _PATCHED_ATTR, True) + safetensors_torch.save_file = wrapper + + +__all__ = ["patch_safetensors_save_file"] diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index a81271547..abac70223 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -18,6 +18,9 @@ from accelerate import disk_offload from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules from accelerate.utils import align_module_device, has_offloaded_params +from .nogil_patcher import patch_safetensors_save_file + +patch_safetensors_save_file() from safetensors.torch import save_file as safetensors_save_file from torch import nn From 91279a08bdee2677a2970e2a4518cda21f338692 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 18 Oct 2025 00:00:33 +0000 Subject: [PATCH 4/7] nogil patch triton --- gptqmodel/__init__.py | 8 + gptqmodel/models/writer.py | 2 - gptqmodel/utils/nogil_patcher.py | 208 +++++++++++++++++++++++++- gptqmodel/utils/offload.py | 3 - tests/test_triton_autotune_threads.py | 76 ++++++++++ 5 files changed, 290 insertions(+), 7 deletions(-) create mode 100644 tests/test_triton_autotune_threads.py diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 07c28adcb..d6670c849 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -5,6 +5,14 @@ import os + +# isort: off +from .utils.nogil_patcher import patch_safetensors_save_file, patch_triton_autotuner # noqa: E402 +# isort: on + +patch_safetensors_save_file() +patch_triton_autotuner() + from .utils.env import env_flag diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index dc796d4e2..aa72cc1be 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -37,14 +37,12 @@ META_FIELD_V2_ENABLED, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, - METHOD, MIN_VERSION_WITH_V2, ) from ..utils.backend import BACKEND from ..utils.hf import sanitize_generation_config_file from ..utils.logger import setup_logger from ..utils.model import ( - convert_gptq_v2_to_v1_format, copy_py_files, find_modules, get_model_files_size, diff --git a/gptqmodel/utils/nogil_patcher.py b/gptqmodel/utils/nogil_patcher.py index 140e0844b..8ea35bbe2 100644 --- a/gptqmodel/utils/nogil_patcher.py +++ b/gptqmodel/utils/nogil_patcher.py @@ -5,13 +5,19 @@ """Straightforward monkey patch helpers for nogil runtimes.""" +import time + from .safe import ThreadSafe + _PATCHED_ATTR = "_gptqmodel_locked_save_file" def patch_safetensors_save_file() -> None: - from safetensors import torch as safetensors_torch + try: + from safetensors import torch as safetensors_torch + except ImportError: + return if getattr(safetensors_torch.save_file, _PATCHED_ATTR, False): return @@ -21,4 +27,202 @@ def patch_safetensors_save_file() -> None: safetensors_torch.save_file = wrapper -__all__ = ["patch_safetensors_save_file"] +__all__ = ["patch_safetensors_save_file", "patch_triton_autotuner"] + + +def patch_triton_autotuner() -> None: + try: + from triton.runtime import autotuner as module + except ImportError: + return + + autotuner_cls = module.Autotuner + if getattr(autotuner_cls, "_gptqmodel_threadsafe", False): + return + + builtins_mod = module.builtins + Config = module.Config + driver = module.driver + knobs = module.knobs + get_cache_manager = module.get_cache_manager + triton_key = module.triton_key + get_cache_invalidating_env_vars = module.get_cache_invalidating_env_vars + JITFunction = module.JITFunction + hashlib_mod = module.hashlib + json_mod = module.json + threading_mod = module.threading + + class CacheFuture: + __slots__ = ("event", "config", "error", "used_cached_result", "bench_time") + + def __init__(self): + self.event = threading_mod.Event() + self.config = None + self.error = None + self.used_cached_result = True + self.bench_time = None + + module.CacheFuture = CacheFuture + + original_init = autotuner_cls.__init__ + + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + cache_map = getattr(self, "cache", {}) + self._cache = dict(cache_map) + self.cache = self._cache + self._cache_lock = getattr(self, "_cache_lock", threading_mod.RLock()) + self._cache_futures = {} + + def patched_check_disk_cache(self, tuning_key, configs, bench_fn): + if not tuning_key or any(cfg.pre_hook for cfg in configs): + configs_timings, bench_time, best_config = bench_fn() + self.configs_timings = configs_timings + return False, bench_time, configs_timings, best_config + + from triton.compiler.compiler import make_backend + + fn = self.fn + while not isinstance(fn, JITFunction): + fn = fn.fn + + env_vars = get_cache_invalidating_env_vars() + cache_key = [ + triton_key(), + make_backend(driver.active.get_current_target()).hash(), + fn.cache_key, + str(sorted(env_vars.items())), + str(tuning_key), + ] + [str(c) for c in configs] + cache_key = hashlib_mod.sha256("-".join(cache_key).encode("utf-8")).hexdigest() + cache = get_cache_manager(cache_key) + file_name = f"{fn.__name__[:150]}.autotune.json" + path = cache.get_file(file_name) + if path: + with open(path, "r") as cached_configs: + timings = json_mod.load(cached_configs)["configs_timings"] + configs_timings = {Config(**config): timing for config, timing in timings} + self.configs_timings = configs_timings + best_config = builtins_mod.min(configs_timings, key=configs_timings.get) + return True, None, configs_timings, best_config + + configs_timings, bench_time, best_config = bench_fn() + self.configs_timings = configs_timings + cache.put( + json_mod.dumps({ + "key": tuning_key, + "configs_timings": [ + (config.__dict__, timings) + for config, timings in (configs_timings or {}).items() + if not config.pre_hook + ], + }), + file_name, + binary=False, + ) + return False, bench_time, configs_timings, best_config + + def _get_config_for_key(self, key, nargs, args, kwargs): + with self._cache_lock: + cached = self._cache.get(key) + if cached is not None: + return cached, True, None + + future = self._cache_futures.get(key) + if future is None: + future = CacheFuture() + self._cache_futures[key] = future + runner = True + else: + runner = False + + if not runner: + future.event.wait() + if future.error is not None: + raise future.error + return future.config, future.used_cached_result, future.bench_time + + pruned_configs = self.prune_configs(kwargs, nargs) + + def benchmark(): + bench_start = time.time() + timings = { + config: self._bench(nargs, *args, config=config, **kwargs) + for config in pruned_configs + } + bench_duration = time.time() - bench_start + best_config = builtins_mod.min(timings, key=timings.get) + full_nargs_local = {**nargs, **kwargs, **best_config.all_kwargs()} + self.pre_hook(full_nargs_local, reset_only=True) + return timings, bench_duration, best_config + + try: + if self.cache_results: + used_cached_result, bench_time, configs_timings, best_config = patched_check_disk_cache( + self, key, pruned_configs, benchmark + ) + else: + configs_timings, bench_time, best_config = benchmark() + used_cached_result = False + + if configs_timings is not None: + self.configs_timings = configs_timings + self.bench_time = bench_time + + if best_config is not None: + with self._cache_lock: + self._cache[key] = best_config + + future.config = best_config + future.used_cached_result = used_cached_result + future.bench_time = bench_time + return best_config, used_cached_result, bench_time + except BaseException as exc: + future.error = exc + raise + finally: + future.event.set() + with self._cache_lock: + self._cache_futures.pop(key, None) + + def patched_run(self, *args, **kwargs): + nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + bench_time = None + key = None + if len(self.configs) > 1: + all_args = {**nargs, **kwargs} + named_args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key_values = [named_args[name] for name in self.keys if name in named_args] + for _, arg in named_args.items(): + if hasattr(arg, "dtype"): + key_values.append(str(arg.dtype)) + key = tuple(key_values) + config, used_cached_result, bench_time = _get_config_for_key(self, key, nargs, args, kwargs) + else: + config = self.configs[0] + + self.cache = self._cache + self.best_config = config + if knobs.autotuning.print and key is not None and not used_cached_result: + bench_time_value = bench_time if bench_time is not None else (self.bench_time or 0.0) + print( + f"Triton autotuning for function {self.base_fn.__name__},\n" + f"with key as {key},\n" + f"finished after {bench_time_value:.2f}s,\n" + f"best config selected: {self.best_config};" + ) + full_nargs = {**nargs, **kwargs, **config.all_kwargs()} + if config.pre_hook is not None: + config.pre_hook(full_nargs) + return self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + + autotuner_cls.__init__ = patched_init + autotuner_cls.check_disk_cache = patched_check_disk_cache + autotuner_cls._get_config_for_key = _get_config_for_key + autotuner_cls.run = patched_run + autotuner_cls._gptqmodel_threadsafe = True diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index abac70223..a81271547 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -18,9 +18,6 @@ from accelerate import disk_offload from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules from accelerate.utils import align_module_device, has_offloaded_params -from .nogil_patcher import patch_safetensors_save_file - -patch_safetensors_save_file() from safetensors.torch import save_file as safetensors_save_file from torch import nn diff --git a/tests/test_triton_autotune_threads.py b/tests/test_triton_autotune_threads.py new file mode 100644 index 000000000..cf54976f4 --- /dev/null +++ b/tests/test_triton_autotune_threads.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import sys +import threading + +import pytest +import torch + +import gptqmodel # noqa: F401 # ensures monkey patches run before Triton import + + +try: + import triton + import triton.language as tl +except ImportError: # pragma: no cover - optional dependency + triton = None + tl = None + + +@pytest.mark.skipif(triton is None, reason="Triton is not installed") +def test_triton_autotune_threads_cuda(): + gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)() + if gil_enabled: + pytest.skip("Requires running with PYTHON_GIL=0") + if not torch.cuda.is_available(): + pytest.skip("CUDA backend required for Triton autotune threading test") + + device = "cuda" + N = 8192 + configs = [ + triton.Config(kwargs={"BLOCK": 128}, num_warps=2), + triton.Config(kwargs={"BLOCK": 256}, num_warps=4), + ] + + @triton.autotune(configs=configs, key=["N"]) + @triton.jit + def copy_kernel(dst, src, N, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < N + values = tl.load(src + offsets, mask=mask) + tl.store(dst + offsets, values, mask=mask) + + def grid(meta): + return (triton.cdiv(N, meta["BLOCK"]),) + num_threads = 8 + sync_ready = threading.Barrier(num_threads + 1) + sync_start = threading.Barrier(num_threads + 1) + errors = [] + + def worker(): + dst = torch.empty(N, device=device, dtype=torch.float32) + src = torch.randn_like(dst) + sync_ready.wait() + sync_start.wait() + try: + for _ in range(4): + dst.zero_() + copy_kernel[grid](dst, src, N) + except Exception as exc: # pragma: no cover - captured for assertion + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for thread in threads: + thread.start() + + sync_ready.wait() + sync_start.wait() + + for thread in threads: + thread.join() + + assert not errors From 35db795fbb819f3b68519fc3b641595267c37d11 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 18 Oct 2025 00:05:13 +0000 Subject: [PATCH 5/7] update depends --- pyproject.toml | 4 ++-- requirements.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 239f45dbd..b6e484ea4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,9 @@ classifiers = [ dependencies = [ "accelerate>=1.10.1", "numpy==2.2.6", - "torch>=2.8.0", + "torch>=2.9.0", "safetensors>=0.6.2", - "transformers>=4.57.0", + "transformers>=4.57.1", "threadpoolctl>=3.6.0", "packaging>=24.2", "device-smi>=0.5.1", diff --git a/requirements.txt b/requirements.txt index 0134a22e4..a9697fcb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ accelerate>=1.10.1 numpy==2.2.6 -torch>=2.8.0 +torch>=2.9.0 safetensors>=0.6.2 -transformers>=4.57.0 +transformers>=4.57.1 threadpoolctl>=3.6.0 packaging>=24.2 device-smi>=0.5.1 From 8b282dca30b9e01a8c2df9826e8b664385821a16 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 18 Oct 2025 00:25:44 +0000 Subject: [PATCH 6/7] offload needs more locking --- gptqmodel/utils/nogil_patcher.py | 12 ++++++++---- gptqmodel/utils/offload.py | 9 +++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/gptqmodel/utils/nogil_patcher.py b/gptqmodel/utils/nogil_patcher.py index 8ea35bbe2..7caf904fd 100644 --- a/gptqmodel/utils/nogil_patcher.py +++ b/gptqmodel/utils/nogil_patcher.py @@ -6,6 +6,7 @@ """Straightforward monkey patch helpers for nogil runtimes.""" import time +import threading from .safe import ThreadSafe @@ -33,9 +34,14 @@ def patch_safetensors_save_file() -> None: def patch_triton_autotuner() -> None: try: from triton.runtime import autotuner as module + import triton except ImportError: return + version = getattr(triton, "__version__", None) + if version is None or tuple(int(part) for part in version.split(".")[:3]) < (3, 5, 0): + return + autotuner_cls = module.Autotuner if getattr(autotuner_cls, "_gptqmodel_threadsafe", False): return @@ -50,13 +56,11 @@ def patch_triton_autotuner() -> None: JITFunction = module.JITFunction hashlib_mod = module.hashlib json_mod = module.json - threading_mod = module.threading - class CacheFuture: __slots__ = ("event", "config", "error", "used_cached_result", "bench_time") def __init__(self): - self.event = threading_mod.Event() + self.event = threading.Event() self.config = None self.error = None self.used_cached_result = True @@ -71,7 +75,7 @@ def patched_init(self, *args, **kwargs): cache_map = getattr(self, "cache", {}) self._cache = dict(cache_map) self.cache = self._cache - self._cache_lock = getattr(self, "_cache_lock", threading_mod.RLock()) + self._cache_lock = getattr(self, "_cache_lock", threading.RLock()) self._cache_futures = {} def patched_check_disk_cache(self, tuning_key, configs, bench_fn): diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index a81271547..97568bf0f 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -68,6 +68,7 @@ def is_meta_module(m: nn.Module) -> bool: return True return False +# Serialize access to module.state_dict(), which is not thread-safe under concurrent calls. _OFFLOAD_LOCK = Lock() @@ -91,7 +92,10 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict: tensors: dict[str, torch.Tensor] = {} with torch.inference_mode(): - for key, tensor in module.state_dict().items(): + with _OFFLOAD_LOCK: + state_items = list(module.state_dict().items()) + + for key, tensor in state_items: cpu_tensor = tensor.detach().to("cpu") tensors[key] = cpu_tensor.contiguous() entry = { @@ -190,7 +194,8 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): total_bytes = 0 try: - state_items = module.state_dict().values() + with _OFFLOAD_LOCK: + state_items = list(module.state_dict().values()) except Exception: state_items = [] From f9f9a20955bb81254c5f7459732f00d701a1e079 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 18 Oct 2025 03:00:51 +0000 Subject: [PATCH 7/7] update tests and scores --- gptqmodel/utils/threadx.py | 12 +++++++++++- tests/models/model_test.py | 22 ++++++++++++++++------ tests/models/test_llama3_2.py | 8 +++++--- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 3a2550ca5..a3a64fc47 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -12,6 +12,7 @@ import threading import time import traceback +from datetime import datetime, timezone from concurrent.futures import Future from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -1738,14 +1739,23 @@ def _janitor_loop(self): use_fn() t1 = time.time() + prev_gc_ts = self._last_gc_ts self._gc_passes += 1 self._last_gc_ts = t1 + gc_timestamp = datetime.fromtimestamp(t1, tz=timezone.utc).isoformat() + if prev_gc_ts is None: + since_last_gc = "since last GC: n/a" + else: + delta_s = t1 - prev_gc_ts + since_last_gc = f"since last GC: {delta_s:.3f}s ({delta_s * 1000:.1f}ms)" # Post-pass accounting & watermarks. try: post = self._collect_state_snapshot() self._update_gc_watermarks(post) - log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") + log.info( + f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; {since_last_gc}." + ) if DEBUG_ON: log.debug(f"DP-Janitor: post-snapshot: inflight={post['inflight']} per_done={post['per_done']}") except Exception as e: try: diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 01885ee81..a6900a1c2 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -277,7 +277,7 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): self.render_inference_summary(inference_records) self.render_arc_summary(arc_records) - return reuse_candidates + return reuse_candidates, arc_records @staticmethod def _human_size(num_bytes: int) -> str: @@ -576,6 +576,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne ) tokenizer = model.tokenizer + self._post_quant_arc_records = {} is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) @@ -607,7 +608,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne log.info(f"Quantized Model saved to tmp dir: {path}") target_backend = self.LOAD_BACKEND - reuse_candidates = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code) + reuse_candidates, arc_records = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code) + self._post_quant_arc_records = arc_records q_model = reuse_candidates.pop(target_backend, None) if q_model is None: @@ -808,10 +810,18 @@ def quant_lm_eval(self): self.check_kernel(self.model, self.KERNEL_INFERENCE) - task_results = self.lm_eval(model=self.SAVE_PATH if self.SAVE_PATH else self.model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL) + arc_records = getattr(self, "_post_quant_arc_records", {}) + target_backend = self.LOAD_BACKEND + if arc_records and len(arc_records) == 1 and target_backend in arc_records: + log.info("Reusing ARC results for backend `%s`; skipping duplicate lm_eval run", target_backend.name) + task_results = arc_records[target_backend] + else: + task_results = self.lm_eval( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + apply_chat_template=self.APPLY_CHAT_TEMPLATE, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL, + ) self.check_results(task_results) def check_results(self, task_results): diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 20cda35a2..195638167 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -7,9 +7,11 @@ # a100:7, MARLIN kernel -# desc_act = False, act_group_aware = False 0.3114/0.3413 -# desc_act = False, act_group_aware = True 0.3268/0.3558 -# desc_act = True, 0.3157/0.3498 +# desc_act = False, act_group_aware = False 0.3200/0.3447 +# desc_act = False, act_group_aware = True 0.3181/0.3481 +# desc_act = True, REGRESSION 0.3191/0.3601 +# a100:6+7: MARLIN kernel +# desc_act = False, act_group_aware = True 0.3217/0.3643 class TestLlama3_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" NATIVE_ARC_CHALLENGE_ACC = 0.3268