Skip to content
Merged

Tests #2039

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Homepage = "https://github.com/ModelCloud/GPTQModel"
[project.optional-dependencies]
test = [
"pytest>=8.3.5",
"pytest-timeout>=2.3.1",
"parameterized",
]
quality = [
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ markers =
ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour
cuda: Requires CUDA device
inference: Inference workloads that replicate models across devices
timeout: Requires pytest-timeout plugin; retained for downstream compatibility
filterwarnings =
ignore:Warning only once for all operators.*:UserWarning
72 changes: 55 additions & 17 deletions tests/test_gc.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,84 @@
import gc
import os
import threading
import traceback
from queue import Queue

import pytest

from gptqmodel.utils.safe import gc as safe_gc

torch = pytest.importorskip("torch", reason="requires PyTorch")


# Stress-test tuning knobs. Kept deliberately small so the suite stays fast
# while still exercising concurrent collection from multiple threads.
_THREAD_COUNT = min(4, max(2, (os.cpu_count() or 2)))  # 2-4 workers, scaled to host CPUs
_ITERATIONS = 8  # allocate/collect cycles per worker
_ALLOCATION_BYTES = 32 * 1024  # per-iteration tensor size in bytes
_BARRIER_TIMEOUT_S = 5  # max wait for all workers to line up at the start barrier
_JOIN_TIMEOUT_S = 30  # max wait for each worker thread to finish


def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None:
    """Allocate and free tensors while triggering GC repeatedly from one thread.

    Args:
        barrier: Start barrier so every worker begins hammering GC at the
            same moment, maximizing contention.
        safe: When ``True`` call the project's ``safe_gc.collect``; otherwise
            call the stdlib ``gc.collect`` directly.
        errors: Queue that collects formatted tracebacks from worker failures
            so the main thread can report them.
    """
    try:
        barrier.wait(timeout=_BARRIER_TIMEOUT_S)
        for _ in range(_ITERATIONS):
            t = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8)
            del t
            if safe:
                safe_gc.collect()
            else:
                gc.collect()
    except Exception:  # pragma: no cover - stress test safeguard
        # Preserve the traceback so the failing test shows context from worker threads.
        errors.put(traceback.format_exc())


@pytest.mark.xfail(reason="concurrent stdlib gc.collect is expected to be unreliable under threading")
@pytest.mark.timeout(30)
def test_multithreaded_gc_collect_unsafe():
    """Stress test that repeated raw ``gc.collect`` calls do not crash under threading."""
    barrier = threading.Barrier(_THREAD_COUNT)
    errors: Queue = Queue()

    threads = [
        threading.Thread(target=_worker, args=(barrier, False, errors), name=f"gc-worker-{i}")
        for i in range(_THREAD_COUNT)
    ]
    for thread in threads:
        thread.start()

    # Bounded join so a wedged worker fails the test instead of hanging it.
    for thread in threads:
        thread.join(timeout=_JOIN_TIMEOUT_S)
        if thread.is_alive():
            pytest.fail(f"GC stress worker {thread.name} did not finish")

    if not errors.empty():
        # Drain every traceback so the failure message shows all workers that broke.
        failures = []
        while not errors.empty():
            failures.append(errors.get())
        pytest.fail("GC stress worker raised:\n" + "\n".join(failures))


@pytest.mark.timeout(30)
def test_multithreaded_gc_collect_safe():
    """Stress test that repeated ``safe_gc.collect`` calls do not crash under threading."""
    barrier = threading.Barrier(_THREAD_COUNT)
    errors: Queue = Queue()

    threads = [
        threading.Thread(target=_worker, args=(barrier, True, errors), name=f"gc-worker-{i}")
        for i in range(_THREAD_COUNT)
    ]
    for thread in threads:
        thread.start()

    # Bounded join so a wedged worker fails the test instead of hanging it.
    for thread in threads:
        thread.join(timeout=_JOIN_TIMEOUT_S)
        if thread.is_alive():
            pytest.fail(f"GC stress worker {thread.name} did not finish")

    if not errors.empty():
        # Drain every traceback so the failure message shows all workers that broke.
        failures = []
        while not errors.empty():
            failures.append(errors.get())
        pytest.fail("GC stress worker raised:\n" + "\n".join(failures))