From aba134a0b4907a5eb1dd48a6ff055d96fdff646e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 22:06:10 +0000 Subject: [PATCH 1/3] cleanup --- tests/test_gc.py | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/test_gc.py b/tests/test_gc.py index a254f3e36..6f3485319 100644 --- a/tests/test_gc.py +++ b/tests/test_gc.py @@ -3,36 +3,59 @@ from queue import Queue import pytest - +from gptqmodel.utils.safe import gc as safe_gc torch = pytest.importorskip("torch", reason="requires PyTorch") _THREAD_COUNT = 16 -_ITERATIONS = 20_000 +_ITERATIONS = 5_000 _ALLOCATION_BYTES = 1024 * 1024 -def _worker(barrier: threading.Barrier, errors: Queue) -> None: +def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: try: barrier.wait() for _ in range(_ITERATIONS): - tensor = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) - del tensor - gc.collect() + t = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) + del t + if safe: + safe_gc.collect() + else: + gc.collect() except Exception as exc: # pragma: no cover - stress test safeguard errors.put(exc) -@pytest.mark.gc_stress -@pytest.mark.timeout(300) -def test_multithreaded_gc_collect(): +@pytest.mark.xfail +@pytest.mark.timeout(30) +def test_multithreaded_gc_collect_unsafe(): + """Stress test that repeated gc.collect calls do not crash under threading.""" + barrier = threading.Barrier(_THREAD_COUNT) + errors: Queue = Queue() + + threads = [ + threading.Thread(target=_worker, args=(barrier, False, errors), name=f"gc-worker-{i}") + for i in range(_THREAD_COUNT) + ] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + if not errors.empty(): + exc = errors.get() + pytest.fail(f"GC stress worker raised: {exc}") + +@pytest.mark.timeout(30) +def test_multithreaded_gc_collect_safe(): """Stress test that repeated gc.collect calls do not crash under threading.""" barrier = threading.Barrier(_THREAD_COUNT) errors: Queue = Queue() threads = [ - threading.Thread(target=_worker, args=(barrier, errors), name=f"gc-worker-{i}") + threading.Thread(target=_worker, args=(barrier, True, errors), name=f"gc-worker-{i}") for i in range(_THREAD_COUNT) ] for thread in threads: From 88be5476a42860a4054ed0343563f40147ae9fd5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 23:56:18 +0000 Subject: [PATCH 2/3] tests --- pyproject.toml | 1 + tests/pytest.ini | 1 + tests/test_gc.py | 37 ++++++++++++++++++++++++++----------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1ed7d767..5ce3b598d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ Homepage = "https://github.com/ModelCloud/GPTQModel" [project.optional-dependencies] test = [ "pytest>=8.3.5", + "pytest-timeout>=2.3.1", "parameterized", ] quality = [ diff --git a/tests/pytest.ini b/tests/pytest.ini index f8fe22115..82a0c02b1 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -6,5 +6,6 @@ markers = ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour cuda: Requires CUDA device inference: Inference workloads that replicate models across devices + timeout: Requires pytest-timeout plugin; retained for downstream compatibility filterwarnings = ignore:Warning only once for all operators.*:UserWarning diff --git a/tests/test_gc.py b/tests/test_gc.py index 6f3485319..3a8bca08d 100644 --- a/tests/test_gc.py +++ b/tests/test_gc.py @@ -1,5 +1,7 @@ import gc +import os import threading +import traceback from queue import Queue import pytest @@ -8,14 +10,16 @@ torch = pytest.importorskip("torch", reason="requires PyTorch") -_THREAD_COUNT = 16 -_ITERATIONS = 5_000 -_ALLOCATION_BYTES = 1024 * 1024 +_THREAD_COUNT = min(4, max(2, (os.cpu_count() or 2))) +_ITERATIONS = 8 +_ALLOCATION_BYTES = 32 * 1024 +_BARRIER_TIMEOUT_S = 5 +_JOIN_TIMEOUT_S = 30 def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: try: - barrier.wait() + barrier.wait(timeout=_BARRIER_TIMEOUT_S) for _ in range(_ITERATIONS): t = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) del t @@ -24,7 +28,8 @@ def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: else: gc.collect() except Exception as exc: # pragma: no cover - stress test safeguard - errors.put(exc) + # Preserve the traceback so the failing test shows context from worker threads. + errors.put(traceback.format_exc()) @pytest.mark.xfail @@ -42,11 +47,17 @@ def test_multithreaded_gc_collect_unsafe(): thread.start() for thread in threads: - thread.join() + thread.join(timeout=_JOIN_TIMEOUT_S) + if thread.is_alive(): + print("thread did not finish") + pytest.fail(f"GC stress worker {thread.name} did not finish") if not errors.empty(): - exc = errors.get() - pytest.fail(f"GC stress worker raised: {exc}") + failures = [] + while not errors.empty(): + failures.append(errors.get()) + pytest.fail("GC stress worker raised:\n" + "\n".join(failures)) + @pytest.mark.timeout(30) def test_multithreaded_gc_collect_safe(): @@ -62,8 +73,12 @@ def test_multithreaded_gc_collect_safe(): thread.start() for thread in threads: - thread.join() + thread.join(timeout=_JOIN_TIMEOUT_S) + if thread.is_alive(): + pytest.fail(f"GC stress worker {thread.name} did not finish") if not errors.empty(): - exc = errors.get() - pytest.fail(f"GC stress worker raised: {exc}") + failures = [] + while not errors.empty(): + failures.append(errors.get()) + pytest.fail("GC stress worker raised:\n" + "\n".join(failures)) From 0dfe933875ef84d25ad57ec990e2faf19c17d12d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 00:15:55 +0000 Subject: [PATCH 3/3] fix moe non-linear capture flushes --- gptqmodel/quantization/gptq.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index cca4416f0..ac6c3c79a 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -441,11 +441,20 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], return batch_token_size, xtx, canonical_device - def _flush_pending_updates_locked(self) -> None: - while True: + def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None: + while self._pending_updates: update = self._pending_updates.pop(self._next_batch_index, None) if update is None: - break + if not allow_gaps: + break + + next_index = min(self._pending_updates.keys()) + if next_index != self._next_batch_index: + self._next_batch_index = next_index + + update = self._pending_updates.pop(self._next_batch_index, None) + if update is None: + break batch_token_size, xtx, device = update @@ -547,7 +556,7 @@ def quantize( start = time.time() with self.lock: - self._flush_pending_updates_locked() + self._flush_pending_updates_locked(allow_gaps=True) if self._pending_updates: raise RuntimeError( f"Pending Hessian updates remain for module '{self.name}' before quantization."