2 changes: 1 addition & 1 deletion gptqmodel/utils/looper_helpers.py
@@ -24,7 +24,7 @@
 from ..utils.torch import ALL_DEVICES, CPU, torch_sync


-USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", False)
+USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True)


 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)
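The only change here is the default passed to env_flag, so behavior shifts only when GPTQMODEL_USE_TORCH_REPLICATE is not set in the environment. A minimal sketch of how such a boolean flag helper typically resolves, illustrative only; the real env_flag in gptqmodel may parse values differently:

    import os

    def env_flag(name: str, default: bool = False) -> bool:
        # Unset variable -> use the caller-supplied default, which this commit
        # flips from False to True for GPTQMODEL_USE_TORCH_REPLICATE.
        raw = os.environ.get(name)
        if raw is None:
            return default
        # Treat common truthy spellings as True; everything else as False.
        return raw.strip().lower() in ("1", "true", "yes", "on")

    USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True)

Under that assumed behavior, torch replicate stays enabled unless a user explicitly exports a falsy value such as GPTQMODEL_USE_TORCH_REPLICATE=0.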
4 changes: 3 additions & 1 deletion gptqmodel/utils/safe.py
@@ -8,6 +8,7 @@

 from __future__ import annotations

+import gc
 import threading
 from functools import wraps
 from types import ModuleType
@@ -99,9 +100,10 @@ def __repr__(self):

 TORCH_LINALG = ThreadSafe(torch.linalg)
 THREADPOOLCTL = ThreadSafe(_threadpoolctl)
-
+GC = ThreadSafe(gc)
 __all__ = [
     "ThreadSafe",
     "TORCH_LINALG",
     "THREADPOOLCTL",
+    "GC",
 ]
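GC = ThreadSafe(gc) publishes the gc module behind the same lock-guarded proxy already used for torch.linalg and threadpoolctl, so concurrent callers serialize their collector calls. A rough sketch of the proxy idea, assuming a simple one-lock-per-wrapper design; the actual ThreadSafe class in gptqmodel/utils/safe.py may be more elaborate:

    import gc
    import threading
    from types import ModuleType

    class _ThreadSafeProxy:
        """Illustrative stand-in: serialize calls into a wrapped module behind one lock."""

        def __init__(self, module: ModuleType):
            self._module = module
            self._lock = threading.RLock()

        def __getattr__(self, name):
            attr = getattr(self._module, name)
            if not callable(attr):
                return attr

            def locked(*args, **kwargs):
                # Hold the lock for the duration of the wrapped call, e.g. gc.collect().
                with self._lock:
                    return attr(*args, **kwargs)

            return locked

    GC = _ThreadSafeProxy(gc)
    print(GC.collect())  # runs gc.collect() while holding the proxy lock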
21 changes: 7 additions & 14 deletions gptqmodel/utils/torch.py
@@ -4,7 +4,6 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

 import contextlib
-import gc as py_gc
 import time
 from contextlib import contextmanager
 from enum import Enum
@@ -15,6 +14,7 @@
 from torch.cpu import StreamContext

 from ..utils.logger import setup_logger
+from ..utils.safe import GC
 from . import gte_python_3_13_3, gte_python_3_14, has_gil_disabled, log_gil_requirements_for


@@ -46,22 +46,15 @@ class BalanceStrategy(str, Enum):
 log = setup_logger()


-def _format_gc_call_suffix(args: tuple, kwargs: dict) -> str:
-    parts: list[str] = []
-    if args:
-        parts.extend(repr(arg) for arg in args)
-    if kwargs:
-        parts.extend(f"{key}={repr(val)}" for key, val in kwargs.items())
-    return f"({', '.join(parts)})" if parts else "()"
-
-
-def timed_gc_collect(*args, **kwargs) -> int:
+def timed_gc_collect() -> int:
     """Run ``gc.collect`` and log the elapsed time along with reclaimed object count."""
-    suffix = _format_gc_call_suffix(args, kwargs)
     start = time.perf_counter()
-    collected = py_gc.collect(*args, **kwargs)
+
+    # Python 3.14 removed gen1 so there is only gen0 and gen2
+    collected = GC.collect()

     duration = time.perf_counter() - start
-    log.info(f"gc.collect{suffix} reclaimed {collected} objects in {duration:.3f}s")
+    log.info(f"gc.collect() reclaimed {collected} objects in {duration:.3f}s")
     return collected

 # reset dynamo cache on each model load since during ci loop model inference may exhuast cache
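With the *args/**kwargs pass-through gone, timed_gc_collect() always performs a full collection through the lock-guarded GC proxy. A self-contained approximation of the call pattern, using plain gc and print in place of the project's GC wrapper and logger:

    import gc
    import time

    start = time.perf_counter()
    collected = gc.collect()  # no generation argument: always a full collection
    duration = time.perf_counter() - start
    print(f"gc.collect() reclaimed {collected} objects in {duration:.3f}s")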
48 changes: 40 additions & 8 deletions tests/models/model_test.py
Expand Up @@ -49,6 +49,13 @@

from transformers import AutoProcessor, AutoTokenizer # noqa: E402


try: # noqa: E402
from transformers.utils import is_flash_attn_2_available # noqa: E402
except Exception: # pragma: no cover - availability check
def is_flash_attn_2_available(): # type: ignore
return False

from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.models.base import BaseQModel # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
Expand Down Expand Up @@ -238,11 +245,21 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False):

for backend in compare_backends:
log.info(f"Loading post-quant model with backend `{backend.name}`")
model = self.loadQuantModel(
model_path,
trust_remote_code=trust_remote_code,
backend=backend,
)
# Pin post-quant loads to the first CUDA device to avoid auto sharding across GPUs.
use_cuda_map = torch.cuda.is_available() and backend != BACKEND.TORCH_FUSED
if use_cuda_map:
model = self.loadQuantModel(
model_path,
trust_remote_code=trust_remote_code,
backend=backend,
device_map={"": "cuda:0"},
)
else:
model = self.loadQuantModel(
model_path,
trust_remote_code=trust_remote_code,
backend=backend,
)
tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code)
inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend)

Expand Down Expand Up @@ -541,7 +558,10 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
args = kwargs if kwargs else {}

if self.USE_FLASH_ATTN:
args["attn_implementation"] = "flash_attention_2"
if is_flash_attn_2_available():
args["attn_implementation"] = "flash_attention_2"
else:
log.warn("flash-attn requested but not available; falling back to framework defaults")


log.info(f"args: {args}")
Expand Down Expand Up @@ -591,7 +611,16 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne

q_model = reuse_candidates.pop(target_backend, None)
if q_model is None:
q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
# Ensure the post-quant reload stays on a single CUDA device when available.
use_cuda_map = torch.cuda.is_available() and self.LOAD_BACKEND != BACKEND.TORCH_FUSED
if use_cuda_map:
q_model = self.loadQuantModel(
path,
trust_remote_code=trust_remote_code,
device_map={"": "cuda:0"},
)
else:
q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
else:
log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`")

Expand Down Expand Up @@ -620,7 +649,10 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa
load_kwargs = dict(args)

if self.USE_FLASH_ATTN:
load_kwargs["attn_implementation"] = "flash_attention_2"
if is_flash_attn_2_available():
load_kwargs["attn_implementation"] = "flash_attention_2"
else:
log.warn("flash-attn requested but not available; falling back to framework defaults")

active_backend = backend if backend is not None else self.LOAD_BACKEND

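The test changes repeat two patterns: only request flash_attention_2 when transformers reports it is available, and pin single-device loads with device_map={"": "cuda:0"} so accelerate does not shard the model across every visible GPU. A hedged, standalone illustration using transformers' AutoModelForCausalLM; the tests themselves go through GPTQModel/loadQuantModel, whose signatures may differ, and the model id below is only a placeholder:

    import torch
    from transformers import AutoModelForCausalLM

    try:
        from transformers.utils import is_flash_attn_2_available
    except Exception:  # older transformers builds may not expose this helper
        def is_flash_attn_2_available() -> bool:
            return False

    load_kwargs = {}
    if is_flash_attn_2_available():
        # Only opt in when the flash-attn kernels are actually importable.
        load_kwargs["attn_implementation"] = "flash_attention_2"

    if torch.cuda.is_available():
        # Mapping the empty module prefix "" places the whole model on one device
        # instead of letting accelerate shard it across every visible GPU
        # (device_map requires the accelerate package).
        load_kwargs["device_map"] = {"": "cuda:0"}

    # "facebook/opt-125m" is just a placeholder model id for this sketch.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", **load_kwargs)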
46 changes: 46 additions & 0 deletions tests/test_gc.py
@@ -0,0 +1,46 @@
+import gc
+import threading
+from queue import Queue
+
+import pytest
+
+
+torch = pytest.importorskip("torch", reason="requires PyTorch")
+
+
+_THREAD_COUNT = 16
+_ITERATIONS = 20_000
+_ALLOCATION_BYTES = 1024 * 1024
+
+
+def _worker(barrier: threading.Barrier, errors: Queue) -> None:
+    try:
+        barrier.wait()
+        for _ in range(_ITERATIONS):
+            tensor = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8)
+            del tensor
+            gc.collect()
+    except Exception as exc:  # pragma: no cover - stress test safeguard
+        errors.put(exc)
+
+
+@pytest.mark.gc_stress
+@pytest.mark.timeout(300)
+def test_multithreaded_gc_collect():
+    """Stress test that repeated gc.collect calls do not crash under threading."""
+    barrier = threading.Barrier(_THREAD_COUNT)
+    errors: Queue = Queue()
+
+    threads = [
+        threading.Thread(target=_worker, args=(barrier, errors), name=f"gc-worker-{i}")
+        for i in range(_THREAD_COUNT)
+    ]
+    for thread in threads:
+        thread.start()
+
+    for thread in threads:
+        thread.join()
+
+    if not errors.empty():
+        exc = errors.get()
+        pytest.fail(f"GC stress worker raised: {exc}")
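The new test leans on a custom gc_stress marker and on pytest.mark.timeout, which is only enforced when the pytest-timeout plugin is installed. The repository's pytest configuration is not part of this diff; the following conftest.py sketch shows one way the marker could be registered and the test selected or skipped, an assumption rather than the project's actual setup:

    # conftest.py (illustrative; the repo may already register markers elsewhere)
    def pytest_configure(config):
        # Register the custom marker so strict-marker and warning-free runs accept it.
        config.addinivalue_line(
            "markers",
            "gc_stress: long-running multithreaded gc.collect() stress tests",
        )

    # Example invocations (pytest-timeout is needed for @pytest.mark.timeout(300) to take effect):
    #   pytest tests/test_gc.py -m gc_stress
    #   pytest -m "not gc_stress"   # skip the stress test during quick CI runs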