diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py
index a9d56c206..ba53e2f10 100644
--- a/gptqmodel/eora/eora.py
+++ b/gptqmodel/eora/eora.py
@@ -21,6 +21,7 @@
 from ..utils.logger import setup_logger
 from ..utils.rocm import IS_ROCM
+from ..utils.safe import TORCH_LINALG


 log = setup_logger()

@@ -58,7 +59,7 @@ def eora_compute_lora(
         original_backend = torch.backends.cuda.preferred_linalg_library()
         torch.backends.cuda.preferred_linalg_library(backend="magma")

-    L, Q = torch.linalg.eigh(raw_scaling_diag_matrix)
+    L, Q = TORCH_LINALG.eigh(raw_scaling_diag_matrix)
     if (L < 0).any():
         ## When expanding the calibration data size for EoRA, I suggest maintaining the balance by allocating 50% to general input (C4) and the remaining 50% to downstream task data.

@@ -76,7 +77,7 @@
     delta_scale = torch.matmul(w_wq_delta, scaling_diag_matrix)

-    U, S, V = torch.linalg.svd(delta_scale, full_matrices=False)
+    U, S, V = TORCH_LINALG.svd(delta_scale, full_matrices=False)
     lowrank_r = rank
     truc_s = S[:lowrank_r]
     truc_u = U[:, :lowrank_r]
diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py
index 71157d68c..31fde7924 100644
--- a/gptqmodel/nn_modules/qlinear/bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/bitblas.py
@@ -17,6 +17,7 @@
 from ...utils import BACKEND
 from ...utils.logger import setup_logger

+
 log = setup_logger()

 MINIMUM_BITBLAS_VERSION = "0.1.0"
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 2689ee2f7..e7692cd20 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -22,7 +22,7 @@
 from ..quantization import QuantizeConfig
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
-from ..utils.torch import HAS_CUDA, HAS_XPU
+from ..utils.safe import TORCH_LINALG
 from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
 from .quantizer import HF_OPTIMUM, Quantizer

@@ -31,13 +31,6 @@
 lock = threading.Lock()

-# TODO: is there a buffer init threading init bug in torch.linalg?
-# bypass strange threading bug by warming up torch.linalg.cholesky to setup internal setup calls
-if HAS_CUDA or HAS_XPU:
-    tmp_eye = torch.eye(64, dtype=torch.float32, device="cuda" if HAS_CUDA else "xpu")
-    torch.linalg.cholesky(tmp_eye)
-    del tmp_eye
-


 def get_number_of_rows_and_cols(layer: nn.Module):
     # return layer.weight.shape[0], np.prod(layer.weight.shape[1:])
@@ -254,8 +247,8 @@ def hessian_inverse(self, H: torch.Tensor):
                 H2 = H.clone()
                 H2[diag, diag] += damp * mean
                 # TODO call to torch.linalg is not threadsafe? Porque no? Esta muy mal.
-                H2 = torch.linalg.cholesky(H2)
-                Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
+                H2 = TORCH_LINALG.cholesky(H2)
+                Hinv = TORCH_LINALG.cholesky(torch.cholesky_inverse(H2), upper=True)
                 del H, H2
                 break
             except torch._C._LinAlgError as e:
diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py
index b91dbc717..05021b1b7 100644
--- a/gptqmodel/quantization/qqq.py
+++ b/gptqmodel/quantization/qqq.py
@@ -15,6 +15,7 @@
 from ..looper.named_module import NamedModule
 from ..quantization.quantizer import HF_OPTIMUM
 from ..utils import setup_logger
+from ..utils.safe import TORCH_LINALG
 from .gptq import get_number_of_rows_and_cols


@@ -316,9 +317,9 @@ def quantize(
         damp = percdamp * torch.mean(torch.diag(H))
         diag = torch.arange(self.columns, device=self.dev)
         H[diag, diag] += damp
-        H = torch.linalg.cholesky(H)
+        H = TORCH_LINALG.cholesky(H)
         H = torch.cholesky_inverse(H)
-        H = torch.linalg.cholesky(H, upper=True)
+        H = TORCH_LINALG.cholesky(H, upper=True)
        Hinv = H

         for i1 in range(0, self.columns, blocksize):
diff --git a/gptqmodel/quantization/rotation/rotation.py b/gptqmodel/quantization/rotation/rotation.py
index 110e45110..7b74b6f38 100644
--- a/gptqmodel/quantization/rotation/rotation.py
+++ b/gptqmodel/quantization/rotation/rotation.py
@@ -11,6 +11,7 @@
 from ...utils.logger import setup_logger
 from ...utils.model import get_module_by_name_prefix
+from ...utils.safe import TORCH_LINALG
 from ...utils.torch import torch_empty_cache
 from .hadamard_utils import apply_exact_had_to_linear, random_hadamard_matrix


@@ -90,7 +91,7 @@ def random_orthogonal_matrix(size, device):
     """
     torch.cuda.empty_cache()
     random_matrix = torch.randn(size, size, dtype=torch.float64).to(device)
-    q, r = torch.linalg.qr(random_matrix)
+    q, r = TORCH_LINALG.qr(random_matrix)
     q *= torch.sign(torch.diag(r)).unsqueeze(0)
     return q

diff --git a/gptqmodel/utils/ctx.py b/gptqmodel/utils/ctx.py
index b06abf9ef..745febc3e 100644
--- a/gptqmodel/utils/ctx.py
+++ b/gptqmodel/utils/ctx.py
@@ -8,6 +8,7 @@
 from contextlib import AbstractContextManager, ExitStack, contextmanager
 from typing import Any, Iterator

+
 ContextArg = AbstractContextManager[Any] | None


diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 5a31382c7..5a847c8cb 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -18,7 +18,6 @@
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from .ctx import ctx
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

@@ -55,6 +54,7 @@
 from ..quantization.config import FORMAT_FIELD_CHECKPOINT, METHOD, dynamic_get
 from . import has_gil_disabled
 from .backend import BACKEND
+from .ctx import ctx
 from .device import get_device
 from .importer import select_quant_linear
 from .logger import setup_logger
diff --git a/gptqmodel/utils/safe.py b/gptqmodel/utils/safe.py
new file mode 100644
index 000000000..3effcfa3d
--- /dev/null
+++ b/gptqmodel/utils/safe.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+
+"""Thread-safe utilities and module wrappers used across Transformers."""
+
+from __future__ import annotations
+
+import threading
+from functools import wraps
+from types import ModuleType
+
+import torch
+
+
+class ThreadSafe(ModuleType):
+    """Generic proxy that exposes a module through a shared (non-reentrant) lock."""
+
+    def __init__(self, module: ModuleType):
+        super().__init__(module.__name__)
+        self._module = module
+        self._lock = threading.Lock()
+        self._callable_cache: dict[str, object] = {}
+        # Keep core module metadata available so tools relying on attributes
+        # like __doc__ or __spec__ see the original values.
+        self.__dict__.update(
+            {
+                "__doc__": module.__doc__,
+                "__package__": module.__package__,
+                "__file__": getattr(module, "__file__", None),
+                "__spec__": getattr(module, "__spec__", None),
+            }
+        )
+
+    def __getattr__(self, name: str):
+        attr = getattr(self._module, name)
+        if callable(attr):
+            cached = self._callable_cache.get(name)
+            if cached is not None and getattr(cached, "__wrapped__", None) is attr:
+                return cached
+
+            @wraps(attr)
+            def locked(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            locked.__wrapped__ = attr
+            self._callable_cache[name] = locked
+            return locked
+        return attr
+
+    def __dir__(self):
+        return sorted(set(super().__dir__()) | set(dir(self._module)))
+
+
+class _ThreadSafeProxy:
+    """Lightweight proxy that serializes access to an object with a shared lock."""
+
+    def __init__(self, value, lock):
+        object.__setattr__(self, "_value", value)
+        object.__setattr__(self, "_lock", lock)
+        object.__setattr__(self, "_callable_cache", {})
+        object.__setattr__(self, "__wrapped__", value)
+
+    def __getattr__(self, name: str):
+        attr = getattr(self._value, name)
+        if callable(attr):
+            cached = self._callable_cache.get(name)
+            if cached is not None and getattr(cached, "__wrapped__", None) is attr:
+                return cached
+
+            @wraps(attr)
+            def locked(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            locked.__wrapped__ = attr
+            self._callable_cache[name] = locked
+            return locked
+        return attr
+
+    def __setattr__(self, name, value):
+        setattr(self._value, name, value)
+
+    def __dir__(self):
+        return dir(self._value)
+
+    def __repr__(self):
+        return repr(self._value)
+
+
+
+TORCH_LINALG = ThreadSafe(torch.linalg)
+
+__all__ = [
+    "ThreadSafe",
+    "TORCH_LINALG",
+]
diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py
index dc3757938..c0bdb6e60 100644
--- a/gptqmodel/utils/threadx.py
+++ b/gptqmodel/utils/threadx.py
@@ -15,8 +15,8 @@
 import torch

 from .. import DEBUG_ON
-from ..utils.logger import setup_logger
 from ..utils.ctx import ctx
+from ..utils.logger import setup_logger

 log = setup_logger()

diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py
index 35a07728a..e903c4bb9 100644
--- a/tests/models/test_llama3_2_awq.py
+++ b/tests/models/test_llama3_2_awq.py
@@ -4,7 +4,8 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

 from model_test import ModelTest
-from gptqmodel.quantization import METHOD, FORMAT
+
+from gptqmodel.quantization import FORMAT, METHOD


 # a100:0
diff --git a/tests/test_bitblas_quant.py b/tests/test_bitblas_quant.py
index f1d12b093..232351d49 100644
--- a/tests/test_bitblas_quant.py
+++ b/tests/test_bitblas_quant.py
@@ -13,8 +13,9 @@
 from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear
 from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear, marlin_import_exception
-from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear
 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
+from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear
+

 RTOL = 5e-2
 ATOL = 5e-2

diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py
index 2369ced4a..e33bef026 100644
--- a/tests/test_kernel_output.py
+++ b/tests/test_kernel_output.py
@@ -3,9 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

-import unittest
-
 import os
+import unittest

 import torch
 from logbar import LogBar
@@ -15,7 +14,6 @@
 from gptqmodel import BACKEND, GPTQModel
 from gptqmodel.adapter.adapter import Adapter, AdapterCache, Lora
 from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear
-from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
 from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
 from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear
 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
diff --git a/tests/test_linalg.py b/tests/test_linalg.py
index 31fec6585..10908c076 100644
--- a/tests/test_linalg.py
+++ b/tests/test_linalg.py
@@ -15,6 +15,8 @@
 from logbar import LogBar  # noqa: E402
 from parameterized import parameterized  # noqa: E402

+from gptqmodel.utils.safe import TORCH_LINALG  # noqa: E402
+
 log = LogBar.shared()

@@ -33,7 +35,7 @@ class Test(unittest.TestCase):
     )
     def test_linalg_eigh(self, dtype: torch.dtype, size: int):
         matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
-        torch.linalg.eigh(matrix)
+        TORCH_LINALG.eigh(matrix)

     @parameterized.expand(
         [
@@ -49,6 +51,6 @@ def test_linalg_eigh_magma(self, dtype: torch.dtype, size: int):
         torch.backends.cuda.preferred_linalg_library(backend="magma")

         matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
-        torch.linalg.eigh(matrix)
+        TORCH_LINALG.eigh(matrix)

         torch.backends.cuda.preferred_linalg_library(backend=original_backend)
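
A minimal usage sketch of the new TORCH_LINALG proxy (illustrative only; the threaded driver below is not part of the patch). ThreadSafe intercepts attribute lookups on the wrapped torch.linalg module, returns each callable wrapped so the shared lock is held for the duration of the call, and caches the wrapper per attribute name, re-wrapping only when the underlying function object changes; non-callable attributes pass through unlocked. This serialization is what replaces the torch.linalg.cholesky warm-up workaround removed from gptqmodel/quantization/gptq.py.

# Illustrative sketch -- exercises the ThreadSafe proxy added in gptqmodel/utils/safe.py.
import threading

import torch

from gptqmodel.utils.safe import TORCH_LINALG


def worker(scale: int) -> None:
    # Each TORCH_LINALG.<fn> lookup returns a wrapper that takes the shared lock,
    # so concurrent cholesky/eigh calls are serialized instead of racing inside torch.linalg.
    mat = torch.eye(64, dtype=torch.float32) * scale  # positive definite by construction
    TORCH_LINALG.cholesky(mat)
    TORCH_LINALG.eigh(mat)


threads = [threading.Thread(target=worker, args=(i + 1,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()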