From 70d5feb8f7ef124d1e4a01a7443824d7b3a0916d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 18:41:40 +0000 Subject: [PATCH 1/9] use replicate --- gptqmodel/utils/looper_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index 91bdc62cc..77740d442 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -24,7 +24,7 @@ from ..utils.torch import ALL_DEVICES, CPU, torch_sync -USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", False) +USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True) _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel) From 6c7e1099933ca9fee7f7306912d1c333a920933d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 18:55:00 +0000 Subject: [PATCH 2/9] disable fa2 for now --- tests/models/test_llama3_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 20cda35a2..b4e339bf5 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -23,7 +23,7 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 - USE_FLASH_ATTN = True + USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", From 35ca61cdcd8ad221a6423ab39e6dd084c2000aa0 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 19:08:12 +0000 Subject: [PATCH 3/9] use first gpu for eval --- tests/models/model_test.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index df66fb391..036e4bf4c 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -238,11 +238,21 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") - model = self.loadQuantModel( - model_path, - trust_remote_code=trust_remote_code, - backend=backend, - ) + # Pin post-quant loads to the first CUDA device to avoid auto sharding across GPUs. + use_cuda_map = torch.cuda.is_available() and backend != BACKEND.TORCH_FUSED + if use_cuda_map: + model = self.loadQuantModel( + model_path, + trust_remote_code=trust_remote_code, + backend=backend, + device_map={"": "cuda:0"}, + ) + else: + model = self.loadQuantModel( + model_path, + trust_remote_code=trust_remote_code, + backend=backend, + ) tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend) @@ -591,7 +601,16 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne q_model = reuse_candidates.pop(target_backend, None) if q_model is None: - q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code) + # Ensure the post-quant reload stays on a single CUDA device when available. 
+ use_cuda_map = torch.cuda.is_available() and self.LOAD_BACKEND != BACKEND.TORCH_FUSED + if use_cuda_map: + q_model = self.loadQuantModel( + path, + trust_remote_code=trust_remote_code, + device_map={"": "cuda:0"}, + ) + else: + q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code) else: log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`") From bbcf27e5b2e0ef1310dd516cc763359105b16c1a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 19:10:38 +0000 Subject: [PATCH 4/9] only gc.collect gen1 --- gptqmodel/utils/torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 24803ef26..73f4fb431 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -55,11 +55,11 @@ def _format_gc_call_suffix(args: tuple, kwargs: dict) -> str: return f"({', '.join(parts)})" if parts else "()" -def timed_gc_collect(*args, **kwargs) -> int: +def timed_gc_collect() -> int: """Run ``gc.collect`` and log the elapsed time along with reclaimed object count.""" suffix = _format_gc_call_suffix(args, kwargs) start = time.perf_counter() - collected = py_gc.collect(*args, **kwargs) + collected = py_gc.collect(1) duration = time.perf_counter() - start log.info(f"gc.collect{suffix} reclaimed {collected} objects in {duration:.3f}s") return collected From bf17cdc4325b5e50c113f5a8f5d48005b431e8ac Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 19:21:12 +0000 Subject: [PATCH 5/9] only gc.collect gen0 --- gptqmodel/utils/torch.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 73f4fb431..e56d1366e 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -5,6 +5,7 @@ import contextlib import gc as py_gc +import threading import time from contextlib import contextmanager from enum import Enum @@ -45,6 +46,7 @@ class BalanceStrategy(str, Enum): log = setup_logger() +GC_LOCK = threading.RLock() def _format_gc_call_suffix(args: tuple, kwargs: dict) -> str: parts: list[str] = [] @@ -59,7 +61,12 @@ def timed_gc_collect() -> int: """Run ``gc.collect`` and log the elapsed time along with reclaimed object count.""" suffix = _format_gc_call_suffix(args, kwargs) start = time.perf_counter() - collected = py_gc.collect(1) + + with GC_LOCK: + # TODO FIXME..lock to gen0 release for GIL=0 safety + # Python 3.14 removed gen1 so there is only gen0 and gen2 + collected = py_gc.collect(0) + duration = time.perf_counter() - start log.info(f"gc.collect{suffix} reclaimed {collected} objects in {duration:.3f}s") return collected From 69226191955c8202119ee6092b8a680f68140703 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 19:21:12 +0000 Subject: [PATCH 6/9] only gc.collect gen0 --- gptqmodel/utils/torch.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 73f4fb431..80959e215 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -5,6 +5,7 @@ import contextlib import gc as py_gc +import threading import time from contextlib import contextmanager from enum import Enum @@ -45,23 +46,20 @@ class BalanceStrategy(str, Enum): log = setup_logger() - -def _format_gc_call_suffix(args: tuple, kwargs: dict) -> str: - parts: list[str] = [] - if args: - parts.extend(repr(arg) for arg in args) - if kwargs: - parts.extend(f"{key}={repr(val)}" for key, val in kwargs.items()) - 
return f"({', '.join(parts)})" if parts else "()" +GC_LOCK = threading.RLock() def timed_gc_collect() -> int: """Run ``gc.collect`` and log the elapsed time along with reclaimed object count.""" - suffix = _format_gc_call_suffix(args, kwargs) start = time.perf_counter() - collected = py_gc.collect(1) + + with GC_LOCK: + # TODO FIXME..lock to gen0 release for GIL=0 safety + # Python 3.14 removed gen1 so there is only gen0 and gen2 + collected = py_gc.collect(0) + duration = time.perf_counter() - start - log.info(f"gc.collect{suffix} reclaimed {collected} objects in {duration:.3f}s") + log.info(f"gc.collect(0) reclaimed {collected} objects in {duration:.3f}s") return collected # reset dynamo cache on each model load since during ci loop model inference may exhuast cache From abb7fe753b13902c546df34b4360dfcefa3c95c4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 20:53:01 +0000 Subject: [PATCH 7/9] enable fa --- tests/models/model_test.py | 15 +++++++++++++-- tests/models/test_llama3_2.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 036e4bf4c..139b9a637 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -48,6 +48,11 @@ pass from transformers import AutoProcessor, AutoTokenizer # noqa: E402 +try: # noqa: E402 + from transformers.utils import is_flash_attn_2_available # noqa: E402 +except Exception: # pragma: no cover - availability check + def is_flash_attn_2_available(): # type: ignore + return False from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.models.base import BaseQModel # noqa: E402 @@ -551,7 +556,10 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne args = kwargs if kwargs else {} if self.USE_FLASH_ATTN: - args["attn_implementation"] = "flash_attention_2" + if is_flash_attn_2_available(): + args["attn_implementation"] = "flash_attention_2" + else: + log.warn("flash-attn requested but not available; falling back to framework defaults") log.info(f"args: {args}") @@ -639,7 +647,10 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa load_kwargs = dict(args) if self.USE_FLASH_ATTN: - load_kwargs["attn_implementation"] = "flash_attention_2" + if is_flash_attn_2_available(): + load_kwargs["attn_implementation"] = "flash_attention_2" + else: + log.warn("flash-attn requested but not available; falling back to framework defaults") active_backend = backend if backend is not None else self.LOAD_BACKEND diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index b4e339bf5..20cda35a2 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -23,7 +23,7 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 - USE_FLASH_ATTN = False + USE_FLASH_ATTN = True # EORA = Lora( # # for quant, path is save path. 
for load, it is loading path # path="./eora_test", From 5635a068e3acd8ac0fe945067eab0195b5002ca2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 21:45:17 +0000 Subject: [PATCH 8/9] add gc crash test --- tests/test_gc.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_gc.py diff --git a/tests/test_gc.py b/tests/test_gc.py new file mode 100644 index 000000000..a254f3e36 --- /dev/null +++ b/tests/test_gc.py @@ -0,0 +1,46 @@ +import gc +import threading +from queue import Queue + +import pytest + + +torch = pytest.importorskip("torch", reason="requires PyTorch") + + +_THREAD_COUNT = 16 +_ITERATIONS = 20_000 +_ALLOCATION_BYTES = 1024 * 1024 + + +def _worker(barrier: threading.Barrier, errors: Queue) -> None: + try: + barrier.wait() + for _ in range(_ITERATIONS): + tensor = torch.empty(_ALLOCATION_BYTES, dtype=torch.uint8) + del tensor + gc.collect() + except Exception as exc: # pragma: no cover - stress test safeguard + errors.put(exc) + + +@pytest.mark.gc_stress +@pytest.mark.timeout(300) +def test_multithreaded_gc_collect(): + """Stress test that repeated gc.collect calls do not crash under threading.""" + barrier = threading.Barrier(_THREAD_COUNT) + errors: Queue = Queue() + + threads = [ + threading.Thread(target=_worker, args=(barrier, errors), name=f"gc-worker-{i}") + for i in range(_THREAD_COUNT) + ] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + if not errors.empty(): + exc = errors.get() + pytest.fail(f"GC stress worker raised: {exc}") From 7d2d551783b833463023adaac6fedd9d3b9ff40f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 21:48:30 +0000 Subject: [PATCH 9/9] wrap gc with safe --- gptqmodel/utils/safe.py | 4 +++- gptqmodel/utils/torch.py | 13 ++++--------- tests/models/model_test.py | 2 ++ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/gptqmodel/utils/safe.py b/gptqmodel/utils/safe.py index 05e9551f1..dc448b297 100644 --- a/gptqmodel/utils/safe.py +++ b/gptqmodel/utils/safe.py @@ -8,6 +8,7 @@ from __future__ import annotations +import gc import threading from functools import wraps from types import ModuleType @@ -99,9 +100,10 @@ def __repr__(self): TORCH_LINALG = ThreadSafe(torch.linalg) THREADPOOLCTL = ThreadSafe(_threadpoolctl) - +GC = ThreadSafe(gc) __all__ = [ "ThreadSafe", "TORCH_LINALG", "THREADPOOLCTL", + "GC", ] diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 80959e215..cddca4d46 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -4,8 +4,6 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import contextlib -import gc as py_gc -import threading import time from contextlib import contextmanager from enum import Enum @@ -16,6 +14,7 @@ from torch.cpu import StreamContext from ..utils.logger import setup_logger +from ..utils.safe import GC from . 
import gte_python_3_13_3, gte_python_3_14, has_gil_disabled, log_gil_requirements_for @@ -46,20 +45,16 @@ class BalanceStrategy(str, Enum): log = setup_logger() -GC_LOCK = threading.RLock() - def timed_gc_collect() -> int: """Run ``gc.collect`` and log the elapsed time along with reclaimed object count.""" start = time.perf_counter() - with GC_LOCK: - # TODO FIXME..lock to gen0 release for GIL=0 safety - # Python 3.14 removed gen1 so there is only gen0 and gen2 - collected = py_gc.collect(0) + # Python 3.14 removed gen1 so there is only gen0 and gen2 + collected = GC.collect() duration = time.perf_counter() - start - log.info(f"gc.collect(0) reclaimed {collected} objects in {duration:.3f}s") + log.info(f"gc.collect() reclaimed {collected} objects in {duration:.3f}s") return collected # reset dynamo cache on each model load since during ci loop model inference may exhuast cache diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 139b9a637..01885ee81 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -48,6 +48,8 @@ pass from transformers import AutoProcessor, AutoTokenizer # noqa: E402 + + try: # noqa: E402 from transformers.utils import is_flash_attn_2_available # noqa: E402 except Exception: # pragma: no cover - availability check
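
Note on the GC changes: patches 4 through 9 converge on one rule, never let more than one thread drive the collector at a time, which matters most on free-threaded (GIL=0) builds. Patch 8 adds a stress test that hammers raw gc.collect() from 16 threads, and patch 9 drops the hand-rolled GC_LOCK in favor of a ThreadSafe(gc) proxy imported from gptqmodel/utils/safe.py. The ThreadSafe class itself is not shown in this series, so the snippet below is only a sketch of that pattern under that assumption; ThreadSafeSketch and the locked wrapper are illustrative names, not the repo's API.

    # Illustrative stand-in only: the real ThreadSafe lives in
    # gptqmodel/utils/safe.py and is not part of this diff, so its actual
    # behavior may differ. The idea used by PATCH 9 is to route module calls
    # through a proxy holding one re-entrant lock, so GC.collect() is
    # serialized across threads.
    import gc
    import threading
    from types import ModuleType
    from typing import Any


    class ThreadSafeSketch:
        """Proxy a module so every callable attribute runs under one RLock."""

        def __init__(self, target: ModuleType) -> None:
            self._target = target
            self._lock = threading.RLock()

        def __getattr__(self, name: str) -> Any:
            attr = getattr(self._target, name)
            if not callable(attr):
                # Non-callable attributes (e.g. gc.garbage) pass through as-is.
                return attr

            def locked(*args: Any, **kwargs: Any) -> Any:
                # Serialize the underlying call behind the shared lock.
                with self._lock:
                    return attr(*args, **kwargs)

            return locked


    GC = ThreadSafeSketch(gc)

    if __name__ == "__main__":
        # Same call shape as timed_gc_collect() after PATCH 9: a plain,
        # serialized full collection instead of an explicit GC_LOCK block.
        print(GC.collect())

Routing calls through a proxy keeps call sites reading like the unwrapped module (GC.collect() vs gc.collect()), which is why patch 9 can delete the explicit lock and the generation argument without reshaping timed_gc_collect() beyond its log message.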