5 changes: 3 additions & 2 deletions gptqmodel/eora/eora.py
@@ -21,6 +21,7 @@

from ..utils.logger import setup_logger
from ..utils.rocm import IS_ROCM
+from ..utils.safe import TORCH_LINALG

log = setup_logger()

@@ -58,7 +59,7 @@ def eora_compute_lora(
original_backend = torch.backends.cuda.preferred_linalg_library()
torch.backends.cuda.preferred_linalg_library(backend="magma")

-L, Q = torch.linalg.eigh(raw_scaling_diag_matrix)
+L, Q = TORCH_LINALG.eigh(raw_scaling_diag_matrix)

if (L < 0).any():
## When expanding the calibration data size for EoRA, I suggest maintaining the balance by allocating 50% to general input (C4) and the remaining 50% to downstream task data.
@@ -76,7 +77,7 @@

delta_scale = torch.matmul(w_wq_delta, scaling_diag_matrix)

-U, S, V = torch.linalg.svd(delta_scale, full_matrices=False)
+U, S, V = TORCH_LINALG.svd(delta_scale, full_matrices=False)
lowrank_r = rank
truc_s = S[:lowrank_r]
truc_u = U[:, :lowrank_r]
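For orientation, these two hunks sit inside `eora_compute_lora`: the eigendecomposition turns calibration statistics into a scaling matrix, and the truncated SVD compresses the scaled quantization error into a rank-`r` adapter. A minimal sketch of that second step, under illustrative names (`w`, `wq`, `scaling` are not the function's actual signature):

import torch

def lowrank_correction(w: torch.Tensor, wq: torch.Tensor,
                       scaling: torch.Tensor, rank: int):
    # Scale the quantization error with the calibration-derived matrix,
    # then keep only the top-`rank` singular triplets.
    delta_scale = (w - wq) @ scaling
    u, s, vh = torch.linalg.svd(delta_scale, full_matrices=False)
    b = u[:, :rank] * s[:rank]                    # (out_features, rank)
    a = vh[:rank, :] @ torch.linalg.inv(scaling)  # undo the right-hand scaling
    return a, b                                   # w ≈ wq + b @ a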
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/bitblas.py
@@ -17,6 +17,7 @@
from ...utils import BACKEND
from ...utils.logger import setup_logger

+
log = setup_logger()

MINIMUM_BITBLAS_VERSION = "0.1.0"
13 changes: 3 additions & 10 deletions gptqmodel/quantization/gptq.py
@@ -22,7 +22,7 @@
from ..quantization import QuantizeConfig
from ..utils.device import get_device
from ..utils.logger import setup_logger
-from ..utils.torch import HAS_CUDA, HAS_XPU
+from ..utils.safe import TORCH_LINALG
from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
from .quantizer import HF_OPTIMUM, Quantizer

@@ -31,13 +31,6 @@

lock = threading.Lock()

-# TODO: is there a buffer init threading init bug in torch.linalg?
-# bypass strange threading bug by warming up torch.linalg.cholesky to setup internal setup calls
-if HAS_CUDA or HAS_XPU:
-    tmp_eye = torch.eye(64, dtype=torch.float32, device="cuda" if HAS_CUDA else "xpu")
-    torch.linalg.cholesky(tmp_eye)
-    del tmp_eye


def get_number_of_rows_and_cols(layer: nn.Module):
# return layer.weight.shape[0], np.prod(layer.weight.shape[1:])
@@ -254,8 +247,8 @@ def hessian_inverse(self, H: torch.Tensor):
H2 = H.clone()
H2[diag, diag] += damp * mean
# TODO call to torch.linalg is not threadsafe? Why not? This is very bad.
-H2 = torch.linalg.cholesky(H2)
-Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
+H2 = TORCH_LINALG.cholesky(H2)
+Hinv = TORCH_LINALG.cholesky(torch.cholesky_inverse(H2), upper=True)
del H, H2
break
except torch._C._LinAlgError as e:
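For reference, the replaced calls implement the damped Cholesky inverse that `hessian_inverse` computes (and that `qqq.py` repeats below): dampen the diagonal, factor, invert the factor, then re-factor the inverse as an upper-triangular matrix so columns can be eliminated left to right. A self-contained sketch of that sequence, with the original's retry-on-`torch._C._LinAlgError` loop elided:

import torch

def damped_hessian_inverse(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    # Dampen the diagonal so the often near-singular Hessian factors cleanly.
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(H.shape[0], device=H.device)
    H = H.clone()
    H[diag, diag] += damp
    # H = L L^T, then H^-1 via cholesky_inverse, then the upper factor of H^-1.
    L = torch.linalg.cholesky(H)
    return torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)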
5 changes: 3 additions & 2 deletions gptqmodel/quantization/qqq.py
@@ -15,6 +15,7 @@
from ..looper.named_module import NamedModule
from ..quantization.quantizer import HF_OPTIMUM
from ..utils import setup_logger
+from ..utils.safe import TORCH_LINALG
from .gptq import get_number_of_rows_and_cols


@@ -316,9 +317,9 @@ def quantize(
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
-H = torch.linalg.cholesky(H)
+H = TORCH_LINALG.cholesky(H)
H = torch.cholesky_inverse(H)
-H = torch.linalg.cholesky(H, upper=True)
+H = TORCH_LINALG.cholesky(H, upper=True)
Hinv = H

for i1 in range(0, self.columns, blocksize):
3 changes: 2 additions & 1 deletion gptqmodel/quantization/rotation/rotation.py
@@ -11,6 +11,7 @@

from ...utils.logger import setup_logger
from ...utils.model import get_module_by_name_prefix
+from ...utils.safe import TORCH_LINALG
from ...utils.torch import torch_empty_cache
from .hadamard_utils import apply_exact_had_to_linear, random_hadamard_matrix

@@ -90,7 +91,7 @@ def random_orthogonal_matrix(size, device):
"""
torch.cuda.empty_cache()
random_matrix = torch.randn(size, size, dtype=torch.float64).to(device)
-q, r = torch.linalg.qr(random_matrix)
+q, r = TORCH_LINALG.qr(random_matrix)
q *= torch.sign(torch.diag(r)).unsqueeze(0)
return q

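One note on `random_orthogonal_matrix` above: QR of an i.i.d. Gaussian matrix yields a Haar-uniform orthogonal `q` only once the factorization is normalized so that `r` has a positive diagonal, which is exactly what the column-wise sign flip does. A quick sanity check, as a sketch (size is arbitrary):

import torch

size = 512
m = torch.randn(size, size, dtype=torch.float64)
q, r = torch.linalg.qr(m)
q *= torch.sign(torch.diag(r)).unsqueeze(0)  # flip columns so diag(r) > 0

# q stays orthogonal to float64 precision after the sign fix.
assert torch.allclose(q @ q.T, torch.eye(size, dtype=torch.float64), atol=1e-10)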
1 change: 1 addition & 0 deletions gptqmodel/utils/ctx.py
@@ -8,6 +8,7 @@
from contextlib import AbstractContextManager, ExitStack, contextmanager
from typing import Any, Iterator

+
ContextArg = AbstractContextManager[Any] | None


2 changes: 1 addition & 1 deletion gptqmodel/utils/model.py
@@ -18,7 +18,6 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
-from .ctx import ctx
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union

@@ -55,6 +54,7 @@
from ..quantization.config import FORMAT_FIELD_CHECKPOINT, METHOD, dynamic_get
from . import has_gil_disabled
from .backend import BACKEND
+from .ctx import ctx
from .device import get_device
from .importer import select_quant_linear
from .logger import setup_logger
100 changes: 100 additions & 0 deletions gptqmodel/utils/safe.py
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium


"""Thread-safe utilities and module wrappers used across Transformers."""

from __future__ import annotations

import threading
from functools import wraps
from types import ModuleType

import torch


class ThreadSafe(ModuleType):
"""Generic proxy that exposes a module through a shared (non-reentrant) lock."""

def __init__(self, module: ModuleType):
super().__init__(module.__name__)
self._module = module
self._lock = threading.Lock()
self._callable_cache: dict[str, object] = {}
# Keep core module metadata available so tools relying on attributes
# like __doc__ or __spec__ see the original values.
self.__dict__.update(
{
"__doc__": module.__doc__,
"__package__": module.__package__,
"__file__": getattr(module, "__file__", None),
"__spec__": getattr(module, "__spec__", None),
}
)

def __getattr__(self, name: str):
attr = getattr(self._module, name)
if callable(attr):
cached = self._callable_cache.get(name)
if cached is not None and getattr(cached, "__wrapped__", None) is attr:
return cached

@wraps(attr)
def locked(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)

locked.__wrapped__ = attr
self._callable_cache[name] = locked
return locked
return attr

def __dir__(self):
return sorted(set(super().__dir__()) | set(dir(self._module)))


class _ThreadSafeProxy:
"""Lightweight proxy that serializes access to an object with a shared lock."""

def __init__(self, value, lock):
object.__setattr__(self, "_value", value)
object.__setattr__(self, "_lock", lock)
object.__setattr__(self, "_callable_cache", {})
object.__setattr__(self, "__wrapped__", value)

def __getattr__(self, name: str):
attr = getattr(self._value, name)
if callable(attr):
cached = self._callable_cache.get(name)
if cached is not None and getattr(cached, "__wrapped__", None) is attr:
return cached

@wraps(attr)
def locked(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)

locked.__wrapped__ = attr
self._callable_cache[name] = locked
return locked
return attr

def __setattr__(self, name, value):
setattr(self._value, name, value)

def __dir__(self):
return dir(self._value)

def __repr__(self):
return repr(self._value)



TORCH_LINALG = ThreadSafe(torch.linalg)

__all__ = [
"ThreadSafe",
"TORCH_LINALG",
]
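For illustration, a minimal usage sketch: any callable fetched from `TORCH_LINALG` comes back wrapped, so the call body runs under the module-wide lock and concurrent workers serialize instead of racing inside the backend (thread count and matrix size here are arbitrary):

import threading

import torch

from gptqmodel.utils.safe import TORCH_LINALG

def worker():
    h = torch.eye(64, dtype=torch.float32)
    # Attribute lookup returns a cached, lock-wrapped callable;
    # the factorization itself runs while holding the shared lock.
    TORCH_LINALG.cholesky(h)

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()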
2 changes: 1 addition & 1 deletion gptqmodel/utils/threadx.py
@@ -15,8 +15,8 @@
import torch

from .. import DEBUG_ON
-from ..utils.logger import setup_logger
from ..utils.ctx import ctx
+from ..utils.logger import setup_logger


log = setup_logger()
3 changes: 2 additions & 1 deletion tests/models/test_llama3_2_awq.py
@@ -4,7 +4,8 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from model_test import ModelTest
-from gptqmodel.quantization import METHOD, FORMAT
+
+from gptqmodel.quantization import FORMAT, METHOD


# a100:0
3 changes: 2 additions & 1 deletion tests/test_bitblas_quant.py
@@ -13,8 +13,9 @@

from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear
from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear, marlin_import_exception
-from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
+from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear

+
RTOL = 5e-2
ATOL = 5e-2
4 changes: 1 addition & 3 deletions tests/test_kernel_output.py
@@ -3,9 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

-import unittest
-
import os
+import unittest

import torch
from logbar import LogBar
@@ -15,7 +14,6 @@
from gptqmodel import BACKEND, GPTQModel
from gptqmodel.adapter.adapter import Adapter, AdapterCache, Lora
from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear
-from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
6 changes: 4 additions & 2 deletions tests/test_linalg.py
@@ -15,6 +15,8 @@
from logbar import LogBar # noqa: E402
from parameterized import parameterized # noqa: E402

+from gptqmodel.utils.safe import TORCH_LINALG # noqa: E402
+

log = LogBar.shared()

@@ -33,7 +35,7 @@ class Test(unittest.TestCase):
)
def test_linalg_eigh(self, dtype: torch.dtype, size: int):
matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
-torch.linalg.eigh(matrix)
+TORCH_LINALG.eigh(matrix)

@parameterized.expand(
[
@@ -49,6 +51,6 @@ def test_linalg_eigh_magma(self, dtype: torch.dtype, size: int):
torch.backends.cuda.preferred_linalg_library(backend="magma")

matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
-torch.linalg.eigh(matrix)
+TORCH_LINALG.eigh(matrix)

torch.backends.cuda.preferred_linalg_library(backend=original_backend)
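
A closing observation: both `eora.py` and this test switch the CUDA linalg backend to MAGMA and restore it by hand. That save/restore pattern could be wrapped in a context manager; a sketch of such a helper (hypothetical, not part of this PR):

from contextlib import contextmanager

import torch

@contextmanager
def preferred_linalg(backend: str):
    # Save the current backend, switch, and restore on exit even if the body raises.
    original = torch.backends.cuda.preferred_linalg_library()
    torch.backends.cuda.preferred_linalg_library(backend=backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_linalg_library(backend=original)

Used as `with preferred_linalg("magma"): TORCH_LINALG.eigh(matrix)`, it would make the restore in these tests exception-safe.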