diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index 91bdc62cc..77740d442 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -24,7 +24,7 @@
 from ..utils.torch import ALL_DEVICES, CPU, torch_sync
 
-USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", False)
+USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True)
 
 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)
diff --git a/gptqmodel/utils/safe.py b/gptqmodel/utils/safe.py
index 05e9551f1..dc448b297 100644
--- a/gptqmodel/utils/safe.py
+++ b/gptqmodel/utils/safe.py
@@ -8,6 +8,7 @@
 from __future__ import annotations
 
+import gc
 import threading
 from functools import wraps
 from types import ModuleType
@@ -99,9 +100,10 @@ def __repr__(self):
 
 TORCH_LINALG = ThreadSafe(torch.linalg)
 THREADPOOLCTL = ThreadSafe(_threadpoolctl)
-
+GC = ThreadSafe(gc)
 __all__ = [
     "ThreadSafe",
     "TORCH_LINALG",
     "THREADPOOLCTL",
+    "GC",
 ]
diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py
index 24803ef26..cddca4d46 100644
--- a/gptqmodel/utils/torch.py
+++ b/gptqmodel/utils/torch.py
@@ -4,7 +4,6 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 import contextlib
-import gc as py_gc
 import time
 from contextlib import contextmanager
 from enum import Enum
@@ -15,6 +14,7 @@
 from torch.cpu import StreamContext
 
 from ..utils.logger import setup_logger
+from ..utils.safe import GC
 from . import gte_python_3_13_3, gte_python_3_14, has_gil_disabled, log_gil_requirements_for
@@ -46,22 +46,15 @@ class BalanceStrategy(str, Enum):
 log = setup_logger()
 
 
-def _format_gc_call_suffix(args: tuple, kwargs: dict) -> str:
-    parts: list[str] = []
-    if args:
-        parts.extend(repr(arg) for arg in args)
-    if kwargs:
-        parts.extend(f"{key}={repr(val)}" for key, val in kwargs.items())
-    return f"({', '.join(parts)})" if parts else "()"
-
-
-def timed_gc_collect(*args, **kwargs) -> int:
+def timed_gc_collect() -> int:
     """Run ``gc.collect`` and log the elapsed time along with reclaimed object count."""
-    suffix = _format_gc_call_suffix(args, kwargs)
     start = time.perf_counter()
-    collected = py_gc.collect(*args, **kwargs)
+
+    # Python 3.14 removed gen1, so there are only gen0 and gen2
+    collected = GC.collect()
+
     duration = time.perf_counter() - start
-    log.info(f"gc.collect{suffix} reclaimed {collected} objects in {duration:.3f}s")
+    log.info(f"gc.collect() reclaimed {collected} objects in {duration:.3f}s")
     return collected
 
 # reset dynamo cache on each model load since during ci loop model inference may exhuast cache
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index df66fb391..01885ee81 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -49,6 +49,13 @@
 from transformers import AutoProcessor, AutoTokenizer  # noqa: E402
+
+try:  # noqa: E402
+    from transformers.utils import is_flash_attn_2_available  # noqa: E402
+except Exception:  # pragma: no cover - availability check
+    def is_flash_attn_2_available():  # type: ignore
+        return False
+
 from gptqmodel import BACKEND, GPTQModel  # noqa: E402
 from gptqmodel.models.base import BaseQModel  # noqa: E402
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear  # noqa: E402
@@ -238,11 +245,21 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False):
         for backend in compare_backends:
             log.info(f"Loading post-quant model with backend `{backend.name}`")
-            model = self.loadQuantModel(
-                model_path,
-                trust_remote_code=trust_remote_code,
-                backend=backend,
-            )
+            # Pin post-quant loads to the first CUDA device to avoid auto sharding across GPUs.
+            use_cuda_map = torch.cuda.is_available() and backend != BACKEND.TORCH_FUSED
+            if use_cuda_map:
+                model = self.loadQuantModel(
+                    model_path,
+                    trust_remote_code=trust_remote_code,
+                    backend=backend,
+                    device_map={"": "cuda:0"},
+                )
+            else:
+                model = self.loadQuantModel(
+                    model_path,
+                    trust_remote_code=trust_remote_code,
+                    backend=backend,
+                )
             tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code)
             inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend)
@@ -541,7 +558,10 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
         args = kwargs if kwargs else {}
 
         if self.USE_FLASH_ATTN:
-            args["attn_implementation"] = "flash_attention_2"
+            if is_flash_attn_2_available():
+                args["attn_implementation"] = "flash_attention_2"
+            else:
+                log.warn("flash-attn requested but not available; falling back to framework defaults")
 
         log.info(f"args: {args}")
@@ -591,7 +611,16 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
         q_model = reuse_candidates.pop(target_backend, None)
 
         if q_model is None:
-            q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
+            # Ensure the post-quant reload stays on a single CUDA device when available.
+            use_cuda_map = torch.cuda.is_available() and self.LOAD_BACKEND != BACKEND.TORCH_FUSED
+            if use_cuda_map:
+                q_model = self.loadQuantModel(
+                    path,
+                    trust_remote_code=trust_remote_code,
+                    device_map={"": "cuda:0"},
+                )
+            else:
+                q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
         else:
             log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`")
@@ -620,7 +649,10 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa
         load_kwargs = dict(args)
 
         if self.USE_FLASH_ATTN:
-            load_kwargs["attn_implementation"] = "flash_attention_2"
+            if is_flash_attn_2_available():
+                load_kwargs["attn_implementation"] = "flash_attention_2"
+            else:
+                log.warn("flash-attn requested but not available; falling back to framework defaults")
 
         active_backend = backend if backend is not None else self.LOAD_BACKEND
diff --git a/tests/test_gc.py b/tests/test_gc.py
new file mode 100644
index 000000000..a254f3e36
--- /dev/null
+++ b/tests/test_gc.py
@@ -0,0 +1,46 @@
+import gc
+import threading
+from queue import Queue
+
+import pytest
+
+
+torch = pytest.importorskip("torch", reason="requires PyTorch")
+
+
+_THREAD_COUNT = 16
+_ITERATIONS = 20_000
+_ALLOCATION_BYTES = 1024 * 1024
+
+
+def _worker(barrier: threading.Barrier, errors: Queue) -> None:
+    try:
+        barrier.wait()
+        for _ in range(_ITERATIONS):
+            tensor = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8)
+            del tensor
+            gc.collect()
+    except Exception as exc:  # pragma: no cover - stress test safeguard
+        errors.put(exc)
+
+
+@pytest.mark.gc_stress
+@pytest.mark.timeout(300)
+def test_multithreaded_gc_collect():
+    """Stress test that repeated gc.collect calls do not crash under threading."""
+    barrier = threading.Barrier(_THREAD_COUNT)
+    errors: Queue = Queue()
+
+    threads = [
+        threading.Thread(target=_worker, args=(barrier, errors), name=f"gc-worker-{i}")
+        for i in range(_THREAD_COUNT)
+    ]
+    for thread in threads:
+        thread.start()
+
+    for thread in threads:
+        thread.join()
+
+    if not errors.empty():
+        exc = errors.get()
+        pytest.fail(f"GC stress worker raised: {exc}")
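
For illustration only (not part of the patch): the sketch below shows one way a lock-guarded module proxy can serialize gc.collect() calls across threads, which is the behaviour the new GC = ThreadSafe(gc) wrapper relies on and tests/test_gc.py stresses. The class _SerializedModule and its RLock scheme are assumptions for this sketch; the real gptqmodel.utils.safe.ThreadSafe implementation may differ.

# Minimal sketch, assuming ThreadSafe behaves like a lock-guarded module proxy.
# _SerializedModule is a hypothetical stand-in, not the gptqmodel implementation.
import gc
import threading
from types import ModuleType


class _SerializedModule:
    """Forward attribute access to a module, serializing calls behind one re-entrant lock."""

    def __init__(self, module: ModuleType) -> None:
        self._module = module
        self._lock = threading.RLock()

    def __getattr__(self, name):
        attr = getattr(self._module, name)
        if not callable(attr):
            return attr

        def _locked(*args, **kwargs):
            # Only one thread at a time reaches the wrapped callable.
            with self._lock:
                return attr(*args, **kwargs)

        return _locked


GC = _SerializedModule(gc)

if __name__ == "__main__":
    # Concurrent collections now run one at a time instead of racing each other.
    workers = [threading.Thread(target=GC.collect) for _ in range(8)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    print("collected without overlap:", GC.collect())

In the patch itself the same effect comes from GC = ThreadSafe(gc) in gptqmodel/utils/safe.py; timed_gc_collect() in gptqmodel/utils/torch.py then routes every collection through that single wrapper.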