From aba134a0b4907a5eb1dd48a6ff055d96fdff646e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 22:06:10 +0000 Subject: [PATCH 1/2] cleanup --- tests/test_gc.py | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/test_gc.py b/tests/test_gc.py index a254f3e36..6f3485319 100644 --- a/tests/test_gc.py +++ b/tests/test_gc.py @@ -3,36 +3,59 @@ from queue import Queue import pytest - +from gptqmodel.utils.safe import gc as safe_gc torch = pytest.importorskip("torch", reason="requires PyTorch") _THREAD_COUNT = 16 -_ITERATIONS = 20_000 +_ITERATIONS = 5_000 _ALLOCATION_BYTES = 1024 * 1024 -def _worker(barrier: threading.Barrier, errors: Queue) -> None: +def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: try: barrier.wait() for _ in range(_ITERATIONS): - tensor = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) - del tensor - gc.collect() + t = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) + del t + if safe: + safe_gc.collect() + else: + gc.collect() except Exception as exc: # pragma: no cover - stress test safeguard errors.put(exc) -@pytest.mark.gc_stress -@pytest.mark.timeout(300) -def test_multithreaded_gc_collect(): +@pytest.mark.xfail +@pytest.mark.timeout(30) +def test_multithreaded_gc_collect_unsafe(): + """Stress test that repeated gc.collect calls do not crash under threading.""" + barrier = threading.Barrier(_THREAD_COUNT) + errors: Queue = Queue() + + threads = [ + threading.Thread(target=_worker, args=(barrier, False, errors), name=f"gc-worker-{i}") + for i in range(_THREAD_COUNT) + ] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + if not errors.empty(): + exc = errors.get() + pytest.fail(f"GC stress worker raised: {exc}") + +@pytest.mark.timeout(30) +def test_multithreaded_gc_collect_safe(): """Stress test that repeated gc.collect calls do not crash under threading.""" barrier = threading.Barrier(_THREAD_COUNT) errors: Queue = Queue() threads = [ - threading.Thread(target=_worker, args=(barrier, errors), name=f"gc-worker-{i}") + threading.Thread(target=_worker, args=(barrier, True, errors), name=f"gc-worker-{i}") for i in range(_THREAD_COUNT) ] for thread in threads: From 88be5476a42860a4054ed0343563f40147ae9fd5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 23:56:18 +0000 Subject: [PATCH 2/2] tests --- pyproject.toml | 1 + tests/pytest.ini | 1 + tests/test_gc.py | 37 ++++++++++++++++++++++++++----------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1ed7d767..5ce3b598d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ Homepage = "https://github.com/ModelCloud/GPTQModel" [project.optional-dependencies] test = [ "pytest>=8.3.5", + "pytest-timeout>=2.3.1", "parameterized", ] quality = [ diff --git a/tests/pytest.ini b/tests/pytest.ini index f8fe22115..82a0c02b1 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -6,5 +6,6 @@ markers = ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour cuda: Requires CUDA device inference: Inference workloads that replicate models across devices + timeout: Requires pytest-timeout plugin; retained for downstream compatibility filterwarnings = ignore:Warning only once for all operators.*:UserWarning diff --git a/tests/test_gc.py b/tests/test_gc.py index 6f3485319..3a8bca08d 100644 --- a/tests/test_gc.py +++ b/tests/test_gc.py @@ -1,5 +1,7 @@ import gc +import os import threading +import traceback from queue import Queue import pytest @@ -8,14 +10,16 @@ torch = pytest.importorskip("torch", reason="requires PyTorch") -_THREAD_COUNT = 16 -_ITERATIONS = 5_000 -_ALLOCATION_BYTES = 1024 * 1024 +_THREAD_COUNT = min(4, max(2, (os.cpu_count() or 2))) +_ITERATIONS = 8 +_ALLOCATION_BYTES = 32 * 1024 +_BARRIER_TIMEOUT_S = 5 +_JOIN_TIMEOUT_S = 30 def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: try: - barrier.wait() + barrier.wait(timeout=_BARRIER_TIMEOUT_S) for _ in range(_ITERATIONS): t = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) del t @@ -24,7 +28,8 @@ def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: else: gc.collect() except Exception as exc: # pragma: no cover - stress test safeguard - errors.put(exc) + # Preserve the traceback so the failing test shows context from worker threads. + errors.put(traceback.format_exc()) @pytest.mark.xfail @@ -42,11 +47,17 @@ def test_multithreaded_gc_collect_unsafe(): thread.start() for thread in threads: - thread.join() + thread.join(timeout=_JOIN_TIMEOUT_S) + if thread.is_alive(): + print("thread did not finish") + pytest.fail(f"GC stress worker {thread.name} did not finish") if not errors.empty(): - exc = errors.get() - pytest.fail(f"GC stress worker raised: {exc}") + failures = [] + while not errors.empty(): + failures.append(errors.get()) + pytest.fail("GC stress worker raised:\n" + "\n".join(failures)) + @pytest.mark.timeout(30) def test_multithreaded_gc_collect_safe(): @@ -62,8 +73,12 @@ def test_multithreaded_gc_collect_safe(): thread.start() for thread in threads: - thread.join() + thread.join(timeout=_JOIN_TIMEOUT_S) + if thread.is_alive(): + pytest.fail(f"GC stress worker {thread.name} did not finish") if not errors.empty(): - exc = errors.get() - pytest.fail(f"GC stress worker raised: {exc}") + failures = [] + while not errors.empty(): + failures.append(errors.get()) + pytest.fail("GC stress worker raised:\n" + "\n".join(failures))