From 235183cae14e5dac9422f2b1995226fad562a1a3 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 19:56:43 +0000 Subject: [PATCH] update bitblas kernel Signed-off-by: Qubitium --- gptqmodel/nn_modules/qlinear/bitblas.py | 453 +++++++++++++----------- pyproject.toml | 2 +- tests/models/model_test.py | 76 ++++ tests/test_bitblas_quant.py | 262 ++++++++++++++ tests/test_kernel_output.py | 181 +++++++--- 5 files changed, 706 insertions(+), 268 deletions(-) create mode 100644 tests/test_bitblas_quant.py diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index c8307e078..71157d68c 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -1,17 +1,15 @@ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import ctypes -import operator +from __future__ import annotations + import os -from functools import reduce +from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import numpy as np import torch -import torch.nn as nn +from packaging import version from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM @@ -19,45 +17,60 @@ from ...utils import BACKEND from ...utils.logger import setup_logger - log = setup_logger() +MINIMUM_BITBLAS_VERSION = "0.1.0" +BITBLAS_OPTIMIZE_FEATURES: List[int] = [1, 16, 32, 64, 128, 256, 512, 1024] +BITBLAS_SUPPORTED_GROUP_SIZES: List[int] = [-1, 32, 64, 128] +BITBLAS_SUPPORTED_BITS: List[int] = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM: List[bool] = [False, True] +BITBLAS_DEFAULT_ZEROS_MODE = "quantized" +BITBLAS_PROPAGATE_WEIGHTS = False + BITBLAS_TARGET = None BITBLAS_DATABASE_PATH = None -BITBLAS_PROPAGATE_WEIGHTS = False try: import bitblas # noqa: F401 + BITBLAS_AVAILABLE = True except Exception: BITBLAS_AVAILABLE = False 
-BITBLAS_INSTALL_HINT = "bitblas is not installed. Please install via `pip install bitblas`." + +BITBLAS_INSTALL_HINT = ( + "bitblas is not installed or the version is incompatible. " + f"Please install via `pip install bitblas>={MINIMUM_BITBLAS_VERSION}`." +) def import_bitblas(): - # print("import_bitblas() called") global BITBLAS_DATABASE_PATH, BITBLAS_TARGET - # guard against bitblas pip whl incompatible env` import bitblas + if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): + raise ImportError(BITBLAS_INSTALL_HINT) + bitblas.set_log_level("INFO") if BITBLAS_TARGET is None: from .bitblas_target_detector import patched_auto_detect_nvidia_target bitblas.auto_detect_nvidia_target = patched_auto_detect_nvidia_target - BITBLAS_TARGET = patched_auto_detect_nvidia_target(int(os.environ.get("CUDA_VISIBLE_DEVICES", "0"))) + visible = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]) + BITBLAS_TARGET = patched_auto_detect_nvidia_target(visible) os.environ["TVM_TARGET"] = f"{BITBLAS_TARGET}" - print(f"BITBLAS_TARGET {BITBLAS_TARGET}") + log.debug("BITBLAS_TARGET %s", BITBLAS_TARGET) if BITBLAS_DATABASE_PATH is None: from bitblas.cache import get_database_path + BITBLAS_DATABASE_PATH = f"{get_database_path()}_{bitblas.__version__}" - print(f"BITBLAS_DATABASE_PATH: {BITBLAS_DATABASE_PATH}") + log.debug("BITBLAS_DATABASE_PATH %s", BITBLAS_DATABASE_PATH) -def unpack_qzeros(qzeros, bits): + +def unpack_gptq_qzeros(qzeros: torch.Tensor, bits: int, is_gptq_v2: bool = False) -> torch.Tensor: qzeros = qzeros.view(torch.int32) elems_per_int32 = 32 // bits unpacked_zeros = torch.zeros( @@ -71,14 +84,73 @@ def unpack_qzeros(qzeros, bits): i = col % elems_per_int32 unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 return unpacked_zeros -class BitBLASQuantLinear(BaseQuantLinear): - SUPPORTS_BITS = [1, 2, 4] - SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] - 
SUPPORTS_DESC_ACT = [False] - SUPPORTS_SYM = [True, False] +def unpack_gptq_qweight(qweight: torch.Tensor, bits: int) -> torch.Tensor: + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) + + +def _num_groups(group_size: int, in_features: int) -> int: + if group_size in (-1, in_features): + return 1 + return in_features // group_size + + +@dataclass +class BitblasQuantizationConfig: + weight_bits: int + group_size: int + desc_act: bool + is_sym: bool + zeros_mode: str = BITBLAS_DEFAULT_ZEROS_MODE + storage_dtype: str = "int8" + quant_method: str = "gptq" + + def __post_init__(self) -> None: + if self.desc_act and self.group_size == -1: + self.desc_act = False + if self.weight_bits not in BITBLAS_SUPPORTED_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Supported values: {BITBLAS_SUPPORTED_BITS}." + ) + if self.is_sym not in BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support sym = {self.is_sym}. " + f"Supported values: {BITBLAS_SUPPORTED_SYM}." 
+ ) + if 32 % self.weight_bits != 0: + raise ValueError("weight_bits must divide 32 for GPTQ packing") + self.pack_factor = 32 // self.weight_bits + self.torch_storage_dtype = getattr(torch, self.storage_dtype) + self.torch_dtype = torch.float16 + + @property + def with_zeros(self) -> bool: + return not self.is_sym and self.zeros_mode == "quantized" + + +class BitblasQuantLinear(BaseQuantLinear): + SUPPORTS_BITS = BITBLAS_SUPPORTED_BITS + SUPPORTS_GROUP_SIZE = BITBLAS_SUPPORTED_GROUP_SIZES + SUPPORTS_DESC_ACT = [False, True] + SUPPORTS_SYM = BITBLAS_SUPPORTED_SYM SUPPORTS_SHARDS = True SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False @@ -90,21 +162,13 @@ class BitBLASQuantLinear(BaseQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + QUANT_TYPE = "gptq_bitblas" + SUPPORTS_QUANT_METHODS = ["gptq"] - OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] - zeros_mode = "quantized" # "original" or "rescale" or "quantized" + OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES TORCH_DTYPE = torch.float16 - STORAGE_DTYPE = "int8" # assume int8 storage - TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) - BITBLAS_DTYPES = { - torch.float32: "float32", - torch.float16: "float16", - torch.half: "float16", - torch.int8: "int8", - } - # for transformers/optimum tests compat - QUANT_TYPE = "bitblas" def __init__( self, @@ -117,14 +181,14 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, - enable_tuning: bool = True, - fast_decoding: bool = True, + enable_tuning: bool = False, + fast_decoding: bool = True, # kept for API compatibility propagate_b: bool = BITBLAS_PROPAGATE_WEIGHTS, opt_features: Union[int, List[int]] = OPT_FEATURES, layout: str = "nt", register_buffers: bool = False, **kwargs, - ): + ) -> None: super().__init__( bits=bits, group_size=group_size, @@ -137,98 +201,131 @@ def __init__( 
backend=kwargs.pop("backend", BACKEND.BITBLAS), adapter=adapter, register_buffers=False, - **kwargs) + **kwargs, + ) - import_bitblas() + del fast_decoding # unused, kept for signature compatibility - self._validate_parameters(group_size, in_features, out_features) + if not BITBLAS_AVAILABLE: + raise ImportError(BITBLAS_INSTALL_HINT) + + self.quant_config = BitblasQuantizationConfig( + weight_bits=bits, + group_size=group_size, + desc_act=desc_act, + is_sym=sym, + ) + self.enable_tuning = enable_tuning + self.layout = layout + self.opt_features = list(opt_features) if isinstance(opt_features, list) else [opt_features] + self.propagate_b = propagate_b + + import_bitblas() - self.opt_features = opt_features - self.target = BITBLAS_TARGET + self._validate_parameters(in_features, out_features) self._configure_bitblas_matmul( - enable_tuning, fast_decoding, bias, propagate_b, layout, bits + in_features, + out_features, + self.TORCH_DTYPE, + enable_tuning, + bias, + layout, + bits, ) self._initialize_buffers(in_features, out_features, bias) - self.reset_parameters() @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: if not BITBLAS_AVAILABLE: return False, ValueError(BITBLAS_INSTALL_HINT) + try: + import_bitblas() + except Exception as exc: # pragma: no cover - import errors handled above + return False, exc return cls._validate(**args) - def _validate_parameters( - self, group_size: int, in_features: int, out_features: int - ): - if in_features % group_size != 0: + def _validate_parameters(self, in_features: int, out_features: int) -> None: + if in_features % 16 != 0: + raise ValueError("`in_features` must be divisible by 16 for BitBLAS") + if out_features % 16 != 0: + raise ValueError("`out_features` must be divisible by 16 for BitBLAS") + if self.group_size not in (-1, in_features) and in_features % self.group_size != 0: raise ValueError("`in_features` must be divisible by `group_size`.") - def _initialize_buffers(self, in_features: int, 
out_features: int, bias: bool): + def _buffer_device(self) -> torch.device: + for tensor in self.list_buffers(): + if isinstance(tensor, torch.Tensor) and not tensor.is_meta: + return tensor.device + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") + + def _initialize_buffers(self, in_features: int, out_features: int, bias: bool) -> None: + num_groups = _num_groups(self.group_size, in_features) + storage_dtype = self.quant_config.torch_storage_dtype + + weight_shape = self.bitblas_matmul.retrieve_weight_shape() self.register_buffer( "qweight", - torch.zeros( - self.bitblas_matmul.retrieve_weight_shape(), - dtype=self.TORCH_STORAGE_DTYPE, - ), + torch.empty(weight_shape, dtype=storage_dtype), ) self.register_buffer( "scales", - torch.zeros( - (out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE - ), + torch.empty((out_features, num_groups), dtype=self.TORCH_DTYPE), ) - if self.zeros_mode == "quantized": - storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit())) + + if self.quant_config.with_zeros: + zeros_shape = (num_groups, out_features // self.quant_config.pack_factor) self.register_buffer( - "zeros", - torch.zeros( - (in_features // self.group_size, out_features // storage_nbit * self.bits), dtype=self.TORCH_STORAGE_DTYPE - ), + "qzeros", + torch.empty(zeros_shape, dtype=storage_dtype), ) else: - self.register_buffer( - "zeros", - torch.zeros( - (out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE - ), - ) + self.register_buffer("qzeros", torch.empty(0, dtype=storage_dtype)) if bias: - self.register_buffer( - "bias", torch.zeros((out_features), dtype=self.TORCH_DTYPE) - ) + self.register_buffer("bias", torch.zeros((out_features,), dtype=self.TORCH_DTYPE)) else: self.bias = None + # Backward compatibility with older code paths expecting `zeros`. 
+ self.zeros = self.qzeros + def list_buffers(self) -> List: buf = super().list_buffers() - if hasattr(self, "zeros") and self.zeros is not None: - buf.append(self.zeros) + if hasattr(self, "qzeros") and self.qzeros is not None: + buf.append(self.qzeros) return buf def _configure_bitblas_matmul( - self, enable_tuning: bool, fast_decoding: bool, bias: bool, propagate_b, layout, bits: int - ): + self, + infeatures: int, + outfeatures: int, + params_dtype: torch.dtype, + enable_tuning: bool, + bias: bool, + layout: str, + bits: int, + ) -> None: from bitblas import MatmulConfig - # Assuming MatmulWeightOnlyDequantizeConfig and MatmulWeightOnlyDequantize are defined elsewhere - bitblas_dtype = self.BITBLAS_DTYPES[self.TORCH_DTYPE] - W_dtype = f"uint{bits}" + bitblas_dtype = "float16" if params_dtype == torch.float16 else "bfloat16" + W_dtype = f"uint{bits}" if self.quant_config.is_sym is False else f"int{bits}" matmul_config = MatmulConfig( M=self.opt_features, - N=self.out_features, - K=self.in_features, + N=outfeatures, + K=infeatures, A_dtype=bitblas_dtype, W_dtype=W_dtype, out_dtype=bitblas_dtype, accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, - storage_dtype=self.STORAGE_DTYPE, + storage_dtype=self.quant_config.storage_dtype, with_scaling=True, - with_zeros=True, + with_zeros=self.quant_config.with_zeros, group_size=self.group_size, with_bias=bias, layout=layout, - zeros_mode=self.zeros_mode, + zeros_mode=self.quant_config.zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( matmul_config, enable_tuning @@ -245,7 +342,7 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=self.target) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=enable_tuning) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) @@ 
-253,158 +350,88 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): BITBLAS_DATABASE_PATH, BITBLAS_TARGET ) log.info( - "BitBLAS Tuning done, appended operator to global_operator_cache." + "BitBLAS operator tuned and added to cache for %s", config ) - else: - log.info("BitBLAS Operator created.") else: - log.info("BitBLAS Operator found in global_operator_cache.") + log.debug("BitBLAS operator cache hit for %s", config) return bitblas_matmul - def reset_parameters(self): - # init for char - self.qweight = torch.randint_like( - self.qweight, - 0, - 2 ** (self.bits - 1) - 1, - dtype=torch.int8, - device=self.qweight.device, - ) - nn.init.normal_(self.scales) - nn.init.zeros_(self.zeros) - if self.bias is not None: - nn.init.zeros_(self.bias) - self.q_params = None + def reset_parameters(self) -> None: + if hasattr(self, "qweight") and isinstance(self.qweight, torch.Tensor) and not self.qweight.is_meta: + self.qweight.zero_() + if hasattr(self, "scales") and isinstance(self.scales, torch.Tensor) and not self.scales.is_meta: + self.scales.zero_() + if hasattr(self, "qzeros") and isinstance(self.qzeros, torch.Tensor) and not self.qzeros.is_meta: + self.qzeros.zero_() + if self.bias is not None and not self.bias.is_meta: + self.bias.zero_() - def post_init(self): - # eliminate runtime overhead like exllama state - param_list = [self.qweight, self.scales, self.zeros] - if self.bitblas_matmul.config.with_bias: - param_list.append(self.bias) - self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] + def post_init(self) -> None: + super().post_init() - def pack(self, linear, scales, zeros, g_idx=None): - from bitblas.quantization.utils import general_compress + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dtype not in (torch.float16, torch.bfloat16): + x = x.to(self.TORCH_DTYPE) - W = linear.weight.data.clone() + orig_shape = x.shape[:-1] + x_2d = x.reshape(-1, x.shape[-1]) - scales = scales.t().contiguous() - zeros = 
zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias: - self.bias = linear.bias.clone().half() + args = [x_2d, self.qweight, self.scales] + if self.quant_config.with_zeros: + args.append(self.qzeros) + out_2d = self.bitblas_matmul(*args) - intweight = torch.round((W + scale_zeros[g_idx].T) / scales[g_idx].T).to(torch.int) - - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) + if self.bias is not None: + out_2d = out_2d + self.bias - i = 0 - row = 0 - qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) - while row < qweight.shape[0]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 + out = out_2d.view(*orig_shape, self.out_features) - qweight = qweight.astype(np.int32) - qweight = torch.from_numpy(qweight) - qweight = qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) - if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() - self.qweight = qweight - - scales = self.scales.T.contiguous().view(self.TORCH_DTYPE) - self.scales = scales - - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) - i = 0 - col = 0 - while col < qzeros.shape[1]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - intzeros = unpack_qzeros(self.qzeros, self.bits).T.contiguous() - if self.bitblas_matmul.config.zeros_mode == "original": - self.zeros = intzeros.to(torch.float16).contiguous() - elif self.bitblas_matmul.config.zeros_mode == "rescale": - self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :] - elif self.bitblas_matmul.config.zeros_mode == 
"quantized": - self.zeros = ( - torch.Tensor( - general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits) - ) - .to(self.qweight.device) - .to(self.zeros.dtype) - .contiguous() - ) - else: - raise ValueError( - f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}" - ) + if self.adapter: + out = self.adapter.apply(x=x, out=out) - if self.bias is not None: - self.bias = self.bias.data.to(torch.float16).contiguous() + return out - def repack_from_gptq(self, gptq_module): + def repack_from_gptq(self, gptq_module: BaseQuantLinear) -> None: from bitblas.quantization.utils import general_compress - # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. - qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + device = self._buffer_device() + + bits = self.bits + packed_weight = ( + gptq_module.qweight.detach().T.contiguous().view(self.quant_config.torch_storage_dtype) + ) + intweight = unpack_gptq_qweight(packed_weight, bits).contiguous() + if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() - self.qweight = qweight - # scales in gptq old quant linear stored with (infeatures // group_size, outfeatures), should be transposed. - scales = gptq_module.scales.T.contiguous().view(self.TORCH_DTYPE) - self.scales = scales - # qzeros should be de-quantized to int zeros. 
- intzeros = unpack_qzeros(gptq_module.qzeros, self.bits).T.contiguous() - if self.bitblas_matmul.config.zeros_mode == "original": - self.zeros = intzeros.to(torch.float16).contiguous() - elif self.bitblas_matmul.config.zeros_mode == "rescale": - self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :] - elif self.bitblas_matmul.config.zeros_mode == "quantized": - self.zeros = ( - torch.Tensor( - general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits) - ) - .to(self.qweight.device) - .to(self.zeros.dtype) - .contiguous() - ) + qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device) else: - raise ValueError( - f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}" + from bitblas.quantization.utils import general_compress + + compressed = general_compress(intweight.cpu().numpy(), bits) + qweight = torch.from_numpy(compressed).to( + device=device, dtype=self.quant_config.torch_storage_dtype ) - if self.bias is not None: - self.bias = gptq_module.bias.data.to(torch.float16).contiguous() - def forward(self, A): - if A.dtype != torch.float16: - A = A.half() + self._buffers["qweight"] = qweight.contiguous() - C = torch.empty( - A.shape[:-1] + (self.scales.shape[0],), dtype=A.dtype, device=A.device - ) + scales = gptq_module.scales.detach().T.contiguous().to(self.TORCH_DTYPE) + self._buffers["scales"] = scales.to(device) - # m is the product of the last n - 1 dimensions of A - m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) - self.bitblas_matmul.call_lib( - ctypes.c_void_p(A.data_ptr()) , *self.q_params, ctypes.c_void_p(C.data_ptr()), m - ) + if self.quant_config.with_zeros and hasattr(gptq_module, "qzeros") and gptq_module.qzeros is not None: + intzeros = unpack_gptq_qzeros(gptq_module.qzeros.detach(), bits).T.contiguous() + intzeros = intzeros - 1 # GPTQ stores qzeros offset by +1 + compressed = general_compress(intzeros.T.contiguous().cpu().numpy(), bits) + zeros = 
torch.from_numpy(compressed).to(device=device, dtype=self.quant_config.torch_storage_dtype) + self._buffers["qzeros"] = zeros.contiguous() + else: + self._buffers["qzeros"] = torch.empty(0, dtype=self.quant_config.torch_storage_dtype, device=device) - if self.adapter: - C = self.adapter.apply(x=A, out=C) + if self.bias is not None and hasattr(gptq_module, "bias") and gptq_module.bias is not None: + self._buffers["bias"] = gptq_module.bias.detach().to(device=device, dtype=self.TORCH_DTYPE) + + self.zeros = self.qzeros - return C +BitBLASQuantLinear = BitblasQuantLinear -__all__ = ["BitBLASQuantLinear"] +__all__ = ["BitblasQuantLinear", "BitBLASQuantLinear"] diff --git a/pyproject.toml b/pyproject.toml index dda50e769..9593ed91e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ sglang = [ "flashinfer-python>=0.3.1", ] bitblas = [ - "bitblas==0.0.1-dev13", + "bitblas==0.1.0.post1", ] hf = [ "optimum>=1.21.2", diff --git a/tests/models/model_test.py b/tests/models/model_test.py index d3345ade9..0626214be 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -31,6 +31,7 @@ sys.path.insert(0, f"{str(Path(__file__).resolve().parent.parent)}/models") # noqa: E402 import contextlib # noqa: E402 +import json # noqa: E402 import shutil # noqa: E402 import tempfile # noqa: E402 import textwrap # noqa: E402 @@ -240,6 +241,80 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): self.render_inference_summary(inference_records) self.render_arc_summary(arc_records) + @staticmethod + def _human_size(num_bytes: int) -> str: + step = 1024.0 + units = ["B", "KB", "MB", "GB", "TB"] + value = float(num_bytes) + for unit in units: + if value < step or unit == units[-1]: + return f"{value:.2f}{unit}" + value /= step + return f"{num_bytes}B" + + @staticmethod + def _print_post_quant_artifacts(root_path: str) -> None: + path = Path(root_path) + if not path.exists(): + log.warn(f"Post-quant artifact path missing: {root_path}") 
+ return + + reset = "\033[0m" + depth_colors = [ + "\033[36m", + "\033[33m", + "\033[35m", + "\033[32m", + "\033[34m", + "\033[31m", + ] + + def colorize(name: str, depth: int, is_dir: bool) -> str: + if not sys.stdout.isatty(): + return name + if is_dir: + code = depth_colors[depth % len(depth_colors)] + else: + code = "\033[37m" + return f"{code}{name}{reset}" + + def walk(directory: Path, prefix: str, depth: int) -> None: + entries = sorted(directory.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())) + for idx, entry in enumerate(entries): + connector = "└──" if idx == len(entries) - 1 else "├──" + display_name = entry.name + ("/" if entry.is_dir() else "") + line = f"{prefix}{connector} {colorize(display_name, depth, entry.is_dir())}" + if entry.is_file(): + try: + size = entry.stat().st_size + line += f" ({ModelTest._human_size(size)})" + except OSError: + pass + print(line) + if entry.is_dir(): + extension = " " if idx == len(entries) - 1 else "│ " + walk(entry, prefix + extension, depth + 1) + + header = f"Post-quant artifacts: {path.resolve()}" + print(f"\n{colorize(header, 0, True)}") + walk(path, "", 1) + + index_files = sorted(path.rglob("*.safetensors.index.json")) + if not index_files: + fallback = sorted(path.glob("*.index.json")) + index_files = fallback + + for idx_file in index_files: + try: + with idx_file.open("r", encoding="utf-8") as fh: + content = json.load(fh) + except (OSError, json.JSONDecodeError) as exc: + log.warn(f"Failed to read index `{idx_file}`: {exc}") + continue + rel_name = idx_file.relative_to(path) + print(f"\n{colorize(f'Index file: {rel_name}', 0, False)}") + print(json.dumps(content, indent=2, sort_keys=True)) + @staticmethod def _colorize(text, matched): color = "\033[92m" if matched else "\033[91m" @@ -486,6 +561,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne model.save(path) tokenizer.save_pretrained(path) + self._print_post_quant_artifacts(path) log.info(f"Quantized Model 
saved to tmp dir: {path}") self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code) q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code) diff --git a/tests/test_bitblas_quant.py b/tests/test_bitblas_quant.py new file mode 100644 index 000000000..f1d12b093 --- /dev/null +++ b/tests/test_bitblas_quant.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import time +from statistics import mean, pstdev + +import pytest +import torch +import torch.nn as nn +from parameterized import parameterized +from tabulate import tabulate + +from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear, marlin_import_exception +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + +RTOL = 5e-2 +ATOL = 5e-2 + + +def _mock_gptq_linear(bits: int, group_size: int, in_features: int, out_features: int): + maxq = (1 << (bits - 1)) - 1 + weight = torch.randn((in_features, out_features), dtype=torch.float32) + + if group_size != -1: + reshaped = weight.view(in_features // group_size, group_size, out_features) + w_g = reshaped.permute(1, 0, 2).reshape(group_size, -1) + else: + w_g = weight + + scales = torch.maximum( + w_g.abs().max(dim=0, keepdim=True).values, + torch.full((1, w_g.shape[1]), 1e-6, device=w_g.device), + ) + scales = scales / maxq + + q = torch.round(w_g / scales).clamp_(-maxq, maxq) + ref = (q * scales).to(dtype=torch.float16) + + if group_size != -1: + def _reshape_back(tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.reshape(group_size, -1, out_features) + return tensor.permute(1, 0, 2).reshape(in_features, out_features) + + ref = _reshape_back(ref) + q = _reshape_back(q) + + linear = nn.Linear(in_features, out_features, bias=False) + linear.weight.data = 
ref.t().contiguous() + + scales = scales.reshape(-1, out_features).contiguous() + zeros = torch.zeros_like(scales, dtype=torch.int32) + g_idx = torch.arange(in_features, dtype=torch.int32) // ( + group_size if group_size != -1 else in_features + ) + + return linear, scales, zeros, g_idx + + +def _benchmark(module: nn.Module, x: torch.Tensor, warmup: int = 2, iters: int = 5) -> list[float]: + times_ms: list[float] = [] + torch.cuda.synchronize() + with torch.inference_mode(): + for _ in range(warmup): + module(x) + torch.cuda.synchronize() + for _ in range(iters): + start = time.perf_counter() + module(x) + torch.cuda.synchronize() + end = time.perf_counter() + times_ms.append((end - start) * 1000.0) + return times_ms + + +def _format_pass(pass_ok: bool) -> str: + if pass_ok: + return "PASS" + return "\033[91mFAIL\033[0m" + + +@pytest.mark.cuda +@parameterized.expand([ + ("bs1_fp16", 1, torch.float16, "float16"), + ("bs2_fp16", 2, torch.float16, "float16"), + ("bs4_fp16", 4, torch.float16, "float16"), + ("bs8_fp16", 8, torch.float16, "float16"), + ("bs16_fp16", 16, torch.float16, "float16"), + ("bs1_bf16", 1, torch.bfloat16, "bfloat16"), + ("bs2_bf16", 2, torch.bfloat16, "bfloat16"), + ("bs4_bf16", 4, torch.bfloat16, "bfloat16"), + ("bs8_bf16", 8, torch.bfloat16, "bfloat16"), + ("bs16_bf16", 16, torch.bfloat16, "bfloat16"), +]) +def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): + try: + pytest.importorskip("bitblas") + except Exception as exc: + pytest.skip(f"bitblas unavailable: {exc}") + if marlin_import_exception is not None: + pytest.skip(f"marlin unavailable: {marlin_import_exception}") + if not torch.cuda.is_available(): + pytest.skip("CUDA device required") + + torch.manual_seed(0) + + bits = 4 + group_size = 128 + in_features = 8192 + out_features = 8192 + + linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) + + torch_linear = TorchQuantLinear( + bits=bits, + group_size=group_size, + 
sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + torch_linear.pack(linear, scales.T, zeros.T, g_idx=g_idx) + torch_linear.post_init() + + bitblas_linear = BitblasQuantLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + enable_tuning=False, + ) + bitblas_linear.repack_from_gptq(torch_linear) + bitblas_linear.post_init() + + device = torch.device("cuda") + torch_linear = torch_linear.to(device=device, dtype=dtype) + bitblas_linear = bitblas_linear.to(device=device, dtype=dtype) + + marlin_linear = MarlinQuantLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ).to(device=device, dtype=dtype) + with torch.no_grad(): + marlin_linear.qweight.copy_(torch_linear.qweight.to(device)) + marlin_linear.scales.copy_(torch_linear.scales.to(device)) + marlin_linear.g_idx.copy_(torch_linear.g_idx.to(device)) + marlin_linear.qzeros.zero_() + marlin_linear.post_init() + + try: + triton_linear = TritonV2QuantLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + except ValueError as err: + pytest.skip(f"triton unavailable: {err}") + + triton_linear.pack(linear, scales.T, zeros.T, g_idx=g_idx) + triton_linear.post_init() + triton_linear = triton_linear.to(device=device, dtype=dtype).eval() + + modules = { + "Torch": torch_linear.eval(), + "BitBLAS": bitblas_linear.eval(), + "Marlin": marlin_linear.eval(), + "TritonV2": triton_linear, + } + + x = torch.randn((batch, in_features), dtype=dtype, device=device) + + results = [] + reference_out = None + outputs: dict[str, torch.Tensor] = {} + errors: dict[str, str] = {} + + for name, module in 
modules.items(): + try: + with torch.inference_mode(): + outputs[name] = module(x).to(torch.float32) + if reference_out is None: + reference_out = outputs[name] + except Exception as exc: # pragma: no cover - diagnostic path + errors[name] = str(exc) + + for name, module in modules.items(): + err = errors.get(name) + if err: + results.append([ + dtype_name, + batch, + name, + "-", + "-", + "-", + "-", + "\033[91mERR\033[0m", + ]) + continue + + out = outputs[name] + if name == "Torch" or reference_out is None: + max_abs = 0.0 + mean_abs = 0.0 + max_rel = 0.0 + pass_ok = True + else: + diff = (out - reference_out).abs() + max_abs = float(diff.max().item()) + mean_abs = float(diff.mean().item()) + max_rel = float((diff / (reference_out.abs() + 1e-6)).max().item()) + pass_ok = max_abs <= ATOL and max_rel <= RTOL + + times = _benchmark(module, x) + mean_ms = mean(times) + std_ms = pstdev(times) if len(times) > 1 else 0.0 + + results.append([ + dtype_name, + batch, + name, + f"{mean_ms:.3f}", + f"{std_ms:.3f}", + f"{max_abs:.4f}", + f"{mean_abs:.4f}", + f"{max_rel:.4f}", + _format_pass(pass_ok), + ]) + + headers = [ + "dtype", + "batch", + "Kernel", + "Mean ms", + "Std ms", + "Max |Δ|", + "Mean |Δ|", + "Max Rel Δ", + "Accuracy", + ] + print(tabulate(results, headers=headers, tablefmt="github")) + + # Table highlights failing kernels in red; no hard assertion to keep report informative. 
# Methods of ``TestKernelOutput`` (tests/test_kernel_output.py) as they read
# once this patch is applied, restored from the collapsed diff text. They are
# shown dedented because the collapsed text carries no indentation; each one
# keeps ``self`` as its first parameter. The ``@parameterized.expand([...])``
# decorators of the four tests are not repeated here: their tolerance tables
# are split across diff hunks and only partly visible.


def _skip_unless_bitblas_enabled(self, backend: BACKEND) -> None:
    """Skip the current test when it targets BitBLAS without explicit opt-in.

    BitBLAS cases only run when the ``RUN_BITBLAS_TESTS`` environment variable
    is set to ``"1"``; every other backend passes straight through.
    """
    if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1":
        self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)")


def _summarize_results(
    self,
    reference_outputs,
    actual_outputs,
    backend: BACKEND,
    dtype: torch.dtype,
    atol: float,
    title: str,
    reference_label: str,
):
    """Compare kernel outputs with a reference, log a one-row table, fail on mismatch.

    Args:
        reference_outputs: Sequence of reference tensors, one per sample.
        actual_outputs: Sequence of tensors produced by ``backend``.
        backend: Backend under test; its ``name`` is shown in the table.
        dtype: Dtype of the run, shown in the table and the failure message.
        atol: Absolute tolerance passed to ``torch.isclose`` (``rtol`` is 0.15).
        title: Heading logged above the table.
        reference_label: Label for the reference values in failure details.

    Raises:
        AssertionError: If the two sequences differ in length, or if any
            sample differs in shape or is not element-wise close.
    """
    total = len(actual_outputs)
    if len(reference_outputs) != total:
        # Pairing by position is meaningless when the counts disagree; report
        # that directly instead of an IndexError or a silently short check.
        raise AssertionError(
            f"backend `{backend}` produced {total} outputs but the reference "
            f"has {len(reference_outputs)}"
        )

    failures = []
    for idx, (reference, actual) in enumerate(zip(reference_outputs, actual_outputs)):
        if reference.shape != actual.shape:
            # torch.isclose broadcasts its operands, so a shape mismatch could
            # otherwise pass (or surface as an unrelated RuntimeError).
            failures.append(
                f"Sample {idx}:\nShape mismatch: expected ({reference_label}) "
                f"{tuple(reference.shape)}, actual {tuple(actual.shape)}"
            )
            continue

        if bool(torch.isclose(reference, actual, rtol=0.15, atol=atol).all()):
            continue

        # NOTE: the complete tensors are rendered into the table cell.
        failures.append(
            "Sample {idx}:\nExpected ({ref_label}) = {expected}\nActual = {actual_val}".format(
                idx=idx,
                ref_label=reference_label,
                expected=reference.detach().cpu().tolist(),
                actual_val=actual.detach().cpu().tolist(),
            )
        )

    status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}"
    details = "\n\n".join(failures) if failures else "-"

    table = tabulate(
        [
            [
                backend.name,
                str(dtype),
                total,
                status,
                len(failures),
                details,
            ]
        ],
        headers=[
            "Backend",
            "DType",
            "Samples",
            "Status",
            "Failures",
            "Expected vs Actual",
        ],
        tablefmt="github",
    )
    log.info("\n" + title + "\n" + table)

    if failures:
        raise AssertionError(
            f"{len(failures)} mismatched outputs for backend `{backend}` and dtype `{dtype}`"
        )


def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float):
    """Check ``backend`` float16 outputs against the stored Torch outputs."""
    self._skip_unless_bitblas_enabled(backend)

    data = self.data[dtype]
    out = self.forward(backend=backend, dtype=dtype)

    self._summarize_results(
        reference_outputs=data.torch_kernel_out,
        actual_outputs=out,
        backend=backend,
        dtype=dtype,
        atol=a_tolerance,
        title=f"Kernel Output {dtype}",
        reference_label="Torch output",
    )


def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float):
    """Check ``backend`` bfloat16 outputs against the stored Torch outputs."""
    self._skip_unless_bitblas_enabled(backend)

    data = self.data[dtype]
    out = self.forward(backend=backend, dtype=dtype)

    self._summarize_results(
        reference_outputs=data.torch_kernel_out,
        actual_outputs=out,
        backend=backend,
        dtype=dtype,
        atol=a_tolerance,
        title=f"Kernel Output {dtype}",
        reference_label="Torch output",
    )


def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float):
    """Check ``backend`` float16 outputs with the adapter against Torch-with-Lora outputs."""
    self._skip_unless_bitblas_enabled(backend)

    data = self.data[dtype]
    out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter)

    self._summarize_results(
        reference_outputs=data.torch_kernel_out_with_lora,
        actual_outputs=out,
        backend=backend,
        dtype=dtype,
        atol=a_tolerance,
        title=f"Kernel Output With Lora {dtype}",
        reference_label="Torch with Lora output",
    )


def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float):
    """Check ``backend`` bfloat16 outputs with the adapter against Torch-with-Lora outputs."""
    self._skip_unless_bitblas_enabled(backend)

    data = self.data[dtype]
    out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter)

    self._summarize_results(
        reference_outputs=data.torch_kernel_out_with_lora,
        actual_outputs=out,
        backend=backend,
        dtype=dtype,
        atol=a_tolerance,
        title=f"Kernel Output With Lora {dtype}",
        reference_label="Torch with Lora output",
    )