From d8eae1fc7477416c3415cc60f571dbda5acd3e1e Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Thu, 23 Oct 2025 14:38:14 +0000
Subject: [PATCH] bitblas has strange compat issues with pip nvidia libs

---
 gptqmodel/nn_modules/qlinear/bitblas.py | 103 ++++++++++++++++++++++--
 tests/test_bitblas.py                   |  42 ++++++++++
 2 files changed, 138 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_bitblas.py

diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py
index 2838c7deb..db9adadae 100644
--- a/gptqmodel/nn_modules/qlinear/bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/bitblas.py
@@ -4,8 +4,11 @@
 from __future__ import annotations
 
+import ctypes
 import os
 from dataclasses import dataclass
+from pathlib import Path
+import sys
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -20,7 +23,7 @@
 
 log = setup_logger()
 
-MINIMUM_BITBLAS_VERSION = "0.1.0"
+MINIMUM_BITBLAS_VERSION = "0.1.0.post1"
 BITBLAS_OPTIMIZE_FEATURES: List[int] = [1, 16, 32, 64, 128, 256, 512, 1024]
 BITBLAS_SUPPORTED_GROUP_SIZES: List[int] = [-1, 32, 64, 128]
 BITBLAS_SUPPORTED_BITS: List[int] = [1, 2, 4, 8]
@@ -31,12 +34,96 @@
 BITBLAS_TARGET = None
 BITBLAS_DATABASE_PATH = None
 
-try:
-    import bitblas  # noqa: F401
+# TODO FIXME. ugly hack to bypass nv lib loading for bitblas
+def _load_cuda_libraries() -> bool:
+    loaded_any = False
+    candidate_dirs = []
+
+    env_dirs = []
+    for var in ("LD_LIBRARY_PATH", "LIBRARY_PATH"):
+        paths = os.environ.get(var, "")
+        if paths:
+            env_dirs.extend(Path(p) for p in paths.split(":") if p)
+    candidate_dirs.extend(env_dirs)
+
+    candidate_dirs.extend(
+        [
+            Path("/usr/local/cuda/lib64"),
+            Path("/usr/local/cuda/lib"),
+            Path("/usr/lib/x86_64-linux-gnu"),
+        ]
+    )
+
+    try:
+        import nvidia  # noqa: F401
+    except Exception:  # pragma: no cover - optional dependency
+        nvidia_paths = []
+    else:
+        nvidia_paths = [Path(p) for p in getattr(nvidia, "__path__", [])]
+
+    for base in nvidia_paths:
+        candidate_dirs.extend(
+            [
+                base / "cuda_runtime" / "lib",
+                base / "cuda_nvrtc" / "lib",
+            ]
+        )
+        candidate_dirs.extend(path for path in base.glob("cu*/lib"))
+
+    site_packages = Path(sys.prefix) / "lib" / f"python{sys.version_info.major}.{sys.version_info.minor}" / "site-packages"
+    candidate_dirs.append(site_packages)
+
+    seen_dirs = set()
+    for directory in candidate_dirs:
+        if not directory or not directory.is_dir():
+            continue
+        resolved = directory.resolve()
+        if resolved in seen_dirs:
+            continue
+        seen_dirs.add(resolved)
+
+        for pattern in ("libcudart.so*", "libnvrtc.so*"):
+            for candidate in sorted(directory.glob(pattern)):
+                if not candidate.is_file():
+                    continue
+                try:
+                    ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
+                    loaded_any = True
+                except OSError:
+                    continue
+
+    return loaded_any
+
+
+def _is_bitblas_available() -> bool:
+    try:
+        import bitblas
+    except Exception as exc:
+        error_text = str(exc)
+        if "libcu" not in error_text:
+            log.debug("BitBLAS import failed: %s", exc)
+            return False
+        if not _load_cuda_libraries():
+            log.debug("CUDA libraries missing, BitBLAS import failed: %s", exc)
+            return False
+        try:
+            import bitblas
+        except Exception as retry_exc:
+            log.debug("BitBLAS import retry failed: %s", retry_exc)
+            return False
+    parsed_version = version.parse(bitblas.__version__)
+    minimum_version = version.parse(MINIMUM_BITBLAS_VERSION)
+    if parsed_version < minimum_version:
+        log.debug(
+            "BitBLAS version %s below minimum required %s",
+            bitblas.__version__,
+            MINIMUM_BITBLAS_VERSION,
+        )
+        return False
+    return True
+
 
-    BITBLAS_AVAILABLE = True
-except Exception:
-    BITBLAS_AVAILABLE = False
+BITBLAS_AVAILABLE = _is_bitblas_available()
 
 
 BITBLAS_INSTALL_HINT = (
@@ -50,7 +137,9 @@ def import_bitblas():
 
     import bitblas
 
-    if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
+    parsed_version = version.parse(bitblas.__version__)
+    minimum_version = version.parse(MINIMUM_BITBLAS_VERSION)
+    if parsed_version < minimum_version:
         raise ImportError(BITBLAS_INSTALL_HINT)
 
     bitblas.set_log_level("INFO")
diff --git a/tests/test_bitblas.py b/tests/test_bitblas.py
new file mode 100644
index 000000000..78facbe39
--- /dev/null
+++ b/tests/test_bitblas.py
@@ -0,0 +1,42 @@
+import os
+
+import pytest
+import torch
+
+from gptqmodel.nn_modules.qlinear.bitblas import (
+    BITBLAS_AVAILABLE,
+    BitblasQuantLinear,
+    import_bitblas,
+)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS")
+@pytest.mark.skipif(not BITBLAS_AVAILABLE, reason="BitBLAS backend is not available")
+def test_bitblas_forward_pass():
+    import_bitblas()
+
+    device_index = int(os.environ.get("BITBLAS_TEST_DEVICE", 0))
+    device = torch.device("cuda", device_index)
+    torch.cuda.set_device(device_index)
+
+    layer = BitblasQuantLinear(
+        bits=4,
+        group_size=32,
+        desc_act=False,
+        sym=True,
+        in_features=32,
+        out_features=32,
+        bias=False,
+    ).to(device)
+
+    with torch.no_grad():
+        layer.qweight.zero_()
+        layer.scales.zero_()
+        if layer.quant_config.with_zeros:
+            layer.qzeros.zero_()
+
+    x = torch.randn(2, 32, device=device, dtype=layer.TORCH_DTYPE)
+    y = layer(x)
+
+    assert y.shape == (2, 32)
+    assert torch.allclose(y, torch.zeros_like(y), atol=1e-4, rtol=1e-4)
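A minimal usage sketch (not part of the patch, assuming it is applied) of how downstream code can gate on the new module-level flag instead of wrapping the bitblas import itself. BITBLAS_AVAILABLE, BITBLAS_INSTALL_HINT, import_bitblas, and BitblasQuantLinear all come from gptqmodel/nn_modules/qlinear/bitblas.py above; the layer shape and device index are illustrative only.

import torch

from gptqmodel.nn_modules.qlinear.bitblas import (
    BITBLAS_AVAILABLE,
    BITBLAS_INSTALL_HINT,
    BitblasQuantLinear,
    import_bitblas,
)

if not BITBLAS_AVAILABLE:
    # _is_bitblas_available() has already attempted to preload libcudart/libnvrtc
    # via _load_cuda_libraries(), so a False flag means the backend is unusable.
    raise RuntimeError(BITBLAS_INSTALL_HINT)

import_bitblas()  # re-checks MINIMUM_BITBLAS_VERSION and sets the bitblas log level

# Illustrative layer construction mirroring tests/test_bitblas.py.
layer = BitblasQuantLinear(
    bits=4,
    group_size=32,
    desc_act=False,
    sym=True,
    in_features=32,
    out_features=32,
    bias=False,
).to(torch.device("cuda", 0))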