Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 96 additions & 7 deletions gptqmodel/nn_modules/qlinear/bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

from __future__ import annotations

import ctypes
import os
from dataclasses import dataclass
from pathlib import Path
import sys
from typing import List, Optional, Tuple, Union

import torch
Expand All @@ -20,7 +23,7 @@

log = setup_logger()

MINIMUM_BITBLAS_VERSION = "0.1.0"
MINIMUM_BITBLAS_VERSION = "0.1.0.post1"
BITBLAS_OPTIMIZE_FEATURES: List[int] = [1, 16, 32, 64, 128, 256, 512, 1024]
BITBLAS_SUPPORTED_GROUP_SIZES: List[int] = [-1, 32, 64, 128]
BITBLAS_SUPPORTED_BITS: List[int] = [1, 2, 4, 8]
Expand All @@ -31,12 +34,96 @@
BITBLAS_TARGET = None
BITBLAS_DATABASE_PATH = None

try:
import bitblas # noqa: F401
# TODO FIXME. ugly hack to bypass nv lib loadig for bitlbas
def _load_cuda_libraries() -> bool:
loaded_any = False
candidate_dirs = []

env_dirs = []
for var in ("LD_LIBRARY_PATH", "LIBRARY_PATH"):
paths = os.environ.get(var, "")
if paths:
env_dirs.extend(Path(p) for p in paths.split(":") if p)
candidate_dirs.extend(env_dirs)

candidate_dirs.extend(
[
Path("/usr/local/cuda/lib64"),
Path("/usr/local/cuda/lib"),
Path("/usr/lib/x86_64-linux-gnu"),
]
)

try:
import nvidia # noqa: F401
except Exception: # pragma: no cover - optional dependency
nvidia_paths = []
else:
nvidia_paths = [Path(p) for p in getattr(nvidia, "__path__", [])]

for base in nvidia_paths:
candidate_dirs.extend(
[
base / "cuda_runtime" / "lib",
base / "cuda_nvrtc" / "lib",
]
)
candidate_dirs.extend(path for path in base.glob("cu*/lib"))

site_packages = Path(sys.prefix) / "lib" / f"python{sys.version_info.major}.{sys.version_info.minor}" / "site-packages"
candidate_dirs.append(site_packages)

seen_dirs = set()
for directory in candidate_dirs:
if not directory or not directory.is_dir():
continue
resolved = directory.resolve()
if resolved in seen_dirs:
continue
seen_dirs.add(resolved)

for pattern in ("libcudart.so*", "libnvrtc.so*"):
for candidate in sorted(directory.glob(pattern)):
if not candidate.is_file():
continue
try:
ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
loaded_any = True
except OSError:
continue

return loaded_any


def _is_bitblas_available() -> bool:
try:
import bitblas
except Exception as exc:
error_text = str(exc)
if "libcu" not in error_text:
log.debug("BitBLAS import failed: %s", exc)
return False
if not _load_cuda_libraries():
log.debug("CUDA libraries missing, BitBLAS import failed: %s", exc)
return False
try:
import bitblas
except Exception as retry_exc:
log.debug("BitBLAS import retry failed: %s", retry_exc)
return False
parsed_version = version.parse(bitblas.__version__)
minimum_version = version.parse(MINIMUM_BITBLAS_VERSION)
if parsed_version < minimum_version:
log.debug(
"BitBLAS version %s below minimum required %s",
bitblas.__version__,
MINIMUM_BITBLAS_VERSION,
)
return False
return True


BITBLAS_AVAILABLE = True
except Exception:
BITBLAS_AVAILABLE = False
BITBLAS_AVAILABLE = _is_bitblas_available()


BITBLAS_INSTALL_HINT = (
Expand All @@ -50,7 +137,9 @@ def import_bitblas():

import bitblas

if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
parsed_version = version.parse(bitblas.__version__)
minimum_version = version.parse(MINIMUM_BITBLAS_VERSION)
if parsed_version < minimum_version:
raise ImportError(BITBLAS_INSTALL_HINT)

bitblas.set_log_level("INFO")
Expand Down
42 changes: 42 additions & 0 deletions tests/test_bitblas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

import pytest
import torch

from gptqmodel.nn_modules.qlinear.bitblas import (
BITBLAS_AVAILABLE,
BitblasQuantLinear,
import_bitblas,
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS")
@pytest.mark.skipif(not BITBLAS_AVAILABLE, reason="BitBLAS backend is not available")
def test_bitblas_forward_pass():
import_bitblas()

device_index = int(os.environ.get("BITBLAS_TEST_DEVICE", 0))
device = torch.device("cuda", device_index)
torch.cuda.set_device(device_index)

layer = BitblasQuantLinear(
bits=4,
group_size=32,
desc_act=False,
sym=True,
in_features=32,
out_features=32,
bias=False,
).to(device)

with torch.no_grad():
layer.qweight.zero_()
layer.scales.zero_()
if layer.quant_config.with_zeros:
layer.qzeros.zero_()

x = torch.randn(2, 32, device=device, dtype=layer.TORCH_DTYPE)
y = layer(x)

assert y.shape == (2, 32)
assert torch.allclose(y, torch.zeros_like(y), atol=1e-4, rtol=1e-4)