diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py
index a9144ed19..7dbae76c9 100644
--- a/gptqmodel/models/definitions/ovis.py
+++ b/gptqmodel/models/definitions/ovis.py
@@ -12,9 +12,9 @@
 from ...utils.calibration import batched
 from ...utils.image import fetch_image
 from ...utils.model import MODALITY, move_to
+from ...utils.offload import offload_to_disk
 from .._const import CPU
 from ..base import BaseQModel
-from ...utils.offload import offload_to_disk
 
 
 class OvisQModel(BaseQModel):
diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py
index c68ee7afa..6e07fab60 100644
--- a/gptqmodel/models/definitions/qwen3_moe.py
+++ b/gptqmodel/models/definitions/qwen3_moe.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
-from ...quantization import METHOD
 from ..base import BaseQModel
 
 
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 20b2d6892..442e925be 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -1032,12 +1032,10 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_
 
 
 class AWQuantLinear(BaseQuantLinear):
     def __init__(self,
                  bias: bool = False,
-                 use_bf16: bool = False,
                  register_buffers: bool = False, **kwargs):
         super().__init__(bias=bias, register_buffers=False, **kwargs)
 
-        self.use_bf16 = use_bf16
         in_features = self.in_features
         out_features = self.out_features
@@ -1058,12 +1056,12 @@ def __init__(self,
             "scales",
             t.zeros(
                 (in_features // self.group_size, out_features),
-                dtype=t.bfloat16 if self.use_bf16 else t.float32,
+                dtype=t.float16,
             ),
         )
 
         if bias:
-            self.register_buffer("bias", t.zeros(out_features, dtype=t.bfloat16 if self.use_bf16 else t.float32,))
+            self.register_buffer("bias", t.zeros(out_features, dtype=t.float16))
         else:
             self.bias = None
 
diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py
index 4f041e794..e2b95100e 100644
--- a/gptqmodel/nn_modules/qlinear/awq_gemm.py
+++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py
@@ -31,7 +31,7 @@ class AwqGEMMQuantLinear(AWQuantLinear):
     SUPPORTS_PACK_DTYPES = [torch.int32]
     SUPPORTS_ADAPTERS = [Lora]
-    SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+    SUPPORTS_DTYPES = [torch.float16]
 
     REQUIRES_FORMAT_V2 = False
 
 
diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py
index 50add314f..150ad3b69 100644
--- a/gptqmodel/nn_modules/qlinear/awq_gemv.py
+++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py
@@ -34,7 +34,7 @@ class AwqGEMVQuantLinear(AWQuantLinear):
     SUPPORTS_PACK_DTYPES = [torch.int32]
     SUPPORTS_ADAPTERS = [Lora]
-    SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+    SUPPORTS_DTYPES = [torch.float16]
 
     # for transformers/optimum tests compat
     QUANT_TYPE = "awq_gemv"
diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py
index e7d576159..5b41a9d6f 100644
--- a/gptqmodel/nn_modules/qlinear/awq_marlin.py
+++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py
@@ -53,7 +53,7 @@ class AwqMarlinQuantLinear(AWQuantLinear):
     SUPPORTS_PACK_DTYPES = [torch.int32]
    SUPPORTS_ADAPTERS = [Lora]
-    SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+    SUPPORTS_DTYPES = [torch.float16]
 
     REQUIRES_FORMAT_V2 = False
 
 
diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py
new file mode 100644
index 000000000..5e298d7dc
--- /dev/null
+++ b/gptqmodel/nn_modules/qlinear/awq_torch.py
@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import torch
+
+from ...adapter.adapter import Adapter, Lora
+from ...models._const import DEVICE, PLATFORM
+from ...quantization.awq.utils.packing_utils import dequantize_gemm
+from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger
+from . import AWQuantLinear
+
+
+log = setup_logger()
+
+
+class AwqTorchQuantLinear(AWQuantLinear):
+    SUPPORTS_BITS = [4]
+    SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
+    SUPPORTS_DESC_ACT = [True, False]
+    SUPPORTS_SYM = [True, False]
+    SUPPORTS_SHARDS = True
+    SUPPORTS_TRAINING = True
+    SUPPORTS_AUTO_PADDING = False
+    SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
+    SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]
+
+    SUPPORTS_DEVICES = [DEVICE.ALL]
+    SUPPORTS_PLATFORM = [PLATFORM.ALL]
+    SUPPORTS_PACK_DTYPES = [torch.int32]
+    SUPPORTS_ADAPTERS = [Lora]
+
+    SUPPORTS_DTYPES = [torch.float16]
+
+    REQUIRES_FORMAT_V2 = False
+
+    QUANT_TYPE = "awq_torch"
+
+    def __init__(
+        self,
+        bits: int,
+        group_size: int,
+        sym: bool,
+        desc_act: bool,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        pack_dtype: torch.dtype = torch.int32,
+        adapter: Adapter = None,
+        register_buffers: bool = False,
+        **kwargs,
+    ):
+        super().__init__(
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            desc_act=desc_act,
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TORCH_AWQ),
+            adapter=adapter,
+            register_buffers=register_buffers,
+            **kwargs,
+        )
+
+    def post_init(self):
+        super().post_init()
+
+    def extra_repr(self) -> str:
+        return (
+            f"in_features={self.in_features}, out_features={self.out_features}, "
+            f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}"
+        )
+
+    def forward(self, x: torch.Tensor):
+        original_shape = x.shape[:-1] + (self.out_features,)
+        device = x.device
+        x_flat = x.reshape(-1, x.shape[-1])
+
+        weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size)
+        assert weight.dtype == torch.float16, f"weight {weight.dtype} is not float16"
+        if weight.dtype != x_flat.dtype or weight.device != device:
+            weight = weight.to(device=device, dtype=x_flat.dtype)
+
+        output = torch.matmul(x_flat, weight)
+
+        if self.bias is not None:
+            output = output + self.bias
+
+        if self.adapter:
+            output = self.adapter.apply(x=x_flat, out=output)
+
+        output = output.reshape(original_shape)
+
+        return output
+
+__all__ = ["AwqTorchQuantLinear"]
diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py
index 664e0e40c..0744c0511 100644
--- a/gptqmodel/utils/backend.py
+++ b/gptqmodel/utils/backend.py
@@ -29,6 +29,7 @@ class BACKEND(str, Enum):
     GEMM = "gemm"
     GEMV = "gemv"
     GEMV_FAST = "gemv_fast"
+    TORCH_AWQ = "torch_awq"
 
     # external
     VLLM = "vllm"  # External inference engine: CUDA + ROCm + IPEX
diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py
index 3bb46f107..0f72d4f92 100644
--- a/gptqmodel/utils/importer.py
+++ b/gptqmodel/utils/importer.py
@@ -20,6 +20,7 @@
 from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
 from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear
 from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear
+from ..nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
 from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
 from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
 from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
@@ -69,6 +70,7 @@
         BACKEND.GEMM: AwqGEMMQuantLinear,
         BACKEND.GEMV: AwqGEMVQuantLinear,
         BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear,
+        BACKEND.TORCH_AWQ: AwqTorchQuantLinear,
     }),
 }
 
@@ -83,7 +85,7 @@
         FORMAT.QQQ: [BACKEND.QQQ],
     },
     METHOD.AWQ: {
-        FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM],
+        FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ],
         FORMAT.GEMV: [BACKEND.GEMV],
         FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST],
         FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN],
@@ -314,6 +316,8 @@ def select_quant_linear(
             qlinear = AwqGEMVQuantLinear
         elif backend == BACKEND.GEMV_FAST:
            qlinear = AwqGEMVFastQuantLinear
+        elif backend == BACKEND.TORCH_AWQ:
+            qlinear = AwqTorchQuantLinear
         elif backend == BACKEND.TORCH:
             qlinear = TorchQuantLinear
         elif backend == BACKEND.TORCH_FUSED:
diff --git a/tests/test_awq_torch_kernel.py b/tests/test_awq_torch_kernel.py
new file mode 100644
index 000000000..a91dd88ef
--- /dev/null
+++ b/tests/test_awq_torch_kernel.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import pytest
+import torch
+
+from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
+from gptqmodel.quantization import FORMAT, METHOD
+from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm
+from gptqmodel.utils.backend import BACKEND
+from gptqmodel.utils.importer import select_quant_linear
+
+
+def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor:
+    pack_factor = 32 // bits
+    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+
+    assert unpacked.shape[1] % pack_factor == 0
+    packed = torch.zeros(
+        (unpacked.shape[0], unpacked.shape[1] // pack_factor),
+        dtype=torch.int32,
+    )
+    for col in range(unpacked.shape[1] // pack_factor):
+        for i, order in enumerate(order_map):
+            value = unpacked[:, col * pack_factor + order].to(torch.int32)
+            packed[:, col] |= value << (i * bits)
+    return packed
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_awq_torch_matches_manual_dequant(dtype):
+    if dtype not in AwqTorchQuantLinear.SUPPORTS_DTYPES:
+        pytest.skip(f"dtype {dtype} not supported by AwqTorchQuantLinear")
+    torch.manual_seed(0)
+
+    bits = 4
+    in_features = 32
+    out_features = 64
+    group_size = 16
+
+    assert out_features % (32 // bits) == 0
+    assert in_features % group_size == 0
+
+    groups = in_features // group_size
+    pack_cols = out_features
+
+    int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32)
+    zero_points = torch.randint(0, 2**bits, size=(groups, pack_cols), dtype=torch.int32)
+    scales = (torch.rand(groups, pack_cols, dtype=torch.float16) * 2.0) + 0.25
+    bias = torch.randn(out_features, dtype=torch.float16)
+
+    qweight = _pack_awq_tensor(int_weight, bits)
+    qzeros = _pack_awq_tensor(zero_points, bits)
+
+    module = AwqTorchQuantLinear(
+        bits=bits,
+        group_size=group_size,
+        sym=True,
+        desc_act=False,
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        register_buffers=True,
+    )
+
+    module.qweight.copy_(qweight)
+    module.qzeros.copy_(qzeros)
+    module.scales = module.scales.to(dtype=torch.float16)
+    module.scales.copy_(scales.to(torch.float16))
+    module.bias.copy_(bias)
+    module.post_init()
+    module.eval()
+
+    batch = 4
+    x = torch.randn(batch, in_features, dtype=dtype)
+
+    bias_expected = module.bias
+
+    dequant_weight = dequantize_gemm(
+        qweight=module.qweight,
+        qzeros=module.qzeros,
+        scales=module.scales,
+        bits=bits,
+        group_size=group_size,
+    ).to(dtype=dtype)
+
+    expected = torch.matmul(x.to(dtype), dequant_weight)
+    expected = expected + bias_expected
+
+    output_first = module(x)
+    output_second = module(x)
+
+    atol = 1e-4 if dtype == torch.float32 else 5e-3
+    rtol = 1e-4 if dtype == torch.float32 else 5e-3
+    torch.testing.assert_close(output_first, expected, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output_second, expected, atol=atol, rtol=rtol)
+
+
+def test_awq_torch_backend_selection():
+    qlinear_cls = select_quant_linear(
+        bits=4,
+        group_size=128,
+        desc_act=False,
+        sym=True,
+        device=None,
+        backend=BACKEND.TORCH_AWQ,
+        format=FORMAT.GEMM,
+        quant_method=METHOD.AWQ,
+        pack_dtype=torch.int32,
+    )
+    assert qlinear_cls is AwqTorchQuantLinear
diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py
index 95c8bbcda..7b3c0a5bf 100644
--- a/tests/test_kernel_output_awq.py
+++ b/tests/test_kernel_output_awq.py
@@ -21,6 +21,7 @@
     AwqMarlinQuantLinear,
     marlin_import_exception,
 )
+from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
 from gptqmodel.utils.marlin import marlin_make_workspace_new
 
 
@@ -40,13 +41,16 @@ class TestAwqKernelOutput(unittest.TestCase):
     TARGET = "model.layers.20.self_attn.v_proj"
     BITS = 4
     GROUP_SIZE = 128
-    SUPPORTED_DTYPES = (torch.float16, torch.bfloat16)
+    SUPPORTED_DTYPES = (torch.float16,)
+    baseline_backend = BACKEND.TORCH_AWQ
 
     backend_cases = [
-        (BACKEND.GEMM, torch.float16, 0.0),
-        (BACKEND.GEMM, torch.bfloat16, 0.0005),
+        (baseline_backend, torch.float16, 0.0),
+        # (baseline_backend, torch.bfloat16, 0.0),
+        (BACKEND.GEMM, torch.float16, 0.001),
+        # (BACKEND.GEMM, torch.bfloat16, 0.05),
         (BACKEND.MARLIN, torch.float16, 0.01),
-        (BACKEND.MARLIN, torch.bfloat16, 0.015),
+        # (BACKEND.MARLIN, torch.bfloat16, 0.05),
     ]
 
 
@@ -76,6 +80,10 @@ def setUpClass(cls) -> None:
 
         cls.modules: Dict[BACKEND, Optional[torch.nn.Module]] = {}
 
+        cls.modules[cls.baseline_backend] = cls._build_torch_awq_module(
+            qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
+        )
+
         cls.modules[BACKEND.GEMM] = cls._build_gemm_module(
             qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
         )
@@ -94,9 +102,20 @@ def setUpClass(cls) -> None:
                 for tensor in base_inputs
             ]
             cls.inputs[dtype] = converted_inputs
+            torch_module = cls.modules.get(cls.baseline_backend)
+            if torch_module is None:
+                raise unittest.SkipTest("Torch AWQ kernel unavailable for baseline.")
+
+            forward_kwargs = {}
+            if dtype == torch.bfloat16:
+                forward_kwargs = {
+                    "compute_dtype": torch.float16,
+                    "output_dtype": dtype,
+                }
             cls.reference_outputs[dtype] = cls._forward(
-                cls.modules[BACKEND.GEMM],
+                torch_module,
                 converted_inputs,
+                **forward_kwargs,
             )
 
 
@@ -151,8 +170,8 @@ def _build_gemm_module(
 
         module.qweight.copy_(qweight_cpu.to(cls.device))
         module.qzeros.copy_(qzeros_cpu.to(cls.device))
-        module.scales.copy_(scales_cpu.to(torch.float32).to(cls.device))
-        module.bias.copy_(bias_cpu.to(torch.float32).to(cls.device))
+        module.scales.copy_(scales_cpu.to(cls.device))
+        module.bias.copy_(bias_cpu.to(cls.device))
 
         module.eval()
         module.post_init()
@@ -199,6 +218,35 @@ def _build_marlin_module(
         module.post_init()
         return module
 
+    @classmethod
+    def _build_torch_awq_module(
+        cls,
+        qweight_cpu: torch.Tensor,
+        qzeros_cpu: torch.Tensor,
+        scales_cpu: torch.Tensor,
+        bias_cpu: torch.Tensor,
+    ) -> AwqTorchQuantLinear:
+        module = AwqTorchQuantLinear(
+            bits=cls.BITS,
+            group_size=cls.GROUP_SIZE,
+            sym=True,
+            desc_act=False,
+            in_features=cls.in_features,
+            out_features=cls.out_features,
+            bias=True,
+            adapter=None,
+            register_buffers=True,
+        ).to(cls.device)
+
+        module.qweight.copy_(qweight_cpu.to(cls.device))
+        module.qzeros.copy_(qzeros_cpu.to(cls.device))
+        module.scales.copy_(scales_cpu.to(cls.device))
+        module.bias.copy_(bias_cpu.to(cls.device))
+
+        module.eval()
+        module.post_init()
+        return module
+
     @classmethod
     def _generate_inputs(cls) -> List[torch.Tensor]:
         large_shapes = [(4, 32), (2, 64), (1, 96)]
@@ -237,12 +285,20 @@ def _forward(
         cls,
         module: torch.nn.Module,
         inputs: Iterable[torch.Tensor],
+        *,
+        compute_dtype: Optional[torch.dtype] = None,
+        output_dtype: Optional[torch.dtype] = None,
     ) -> List[torch.Tensor]:
         outputs: List[torch.Tensor] = []
         with torch.inference_mode():
             for tensor in inputs:
-                result = module(tensor)
-                outputs.append(result.detach().to(torch.float32).cpu())
+                local_tensor = tensor
+                if compute_dtype is not None and tensor.dtype != compute_dtype:
+                    local_tensor = tensor.to(dtype=compute_dtype)
+                result = module(local_tensor)
+                if output_dtype is not None and result.dtype != output_dtype:
+                    result = result.to(dtype=output_dtype)
+                outputs.append(result.detach().cpu())
         return outputs
 
     def _maybe_skip_backend(self, backend: BACKEND) -> None:
@@ -262,10 +318,15 @@ def _summarize_results(
     ) -> None:
         failures = []
         total = len(actual_outputs)
+        max_abs_diff = 0.0
+        mean_abs_diff = 0.0
 
         for idx, (reference, actual) in enumerate(zip(reference_outputs, actual_outputs)):
             reference_fp32 = reference.to(torch.float32)
             actual_fp32 = actual.to(torch.float32)
+            diff = torch.abs(reference_fp32 - actual_fp32)
+            max_abs_diff = max(max_abs_diff, float(diff.max().item()))
+            mean_abs_diff += float(diff.mean().item())
             is_close_tensor = torch.isclose(reference_fp32, actual_fp32, rtol=0.15, atol=atol)
             if not bool(torch.all(is_close_tensor)):
                 failures.append(
@@ -278,6 +339,7 @@
                 )
 
         status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}"
+        avg_abs_diff = mean_abs_diff / total if total else 0.0
         details = "\n\n".join(str(detail) for detail in failures) if failures else "-"
 
         table = tabulate(
@@ -286,6 +348,8 @@
                     backend.name,
                     str(dtype),
                     total,
+                    f"{max_abs_diff:.6f}",
+                    f"{avg_abs_diff:.6f}",
                     status,
                     len(failures),
                     details,
@@ -295,6 +359,8 @@
                     "Backend",
                     "DType",
                     "Samples",
+                    "MaxAbsDiff",
+                    "MeanAbsDiff",
                     "Status",
                     "Failures",
                     "Expected vs Actual",
@@ -318,7 +384,10 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl
         inputs = self.inputs[dtype]
         reference_outputs = self.reference_outputs[dtype]
 
-        actual_outputs = self._forward(module, inputs)
+        if backend == self.baseline_backend:
+            actual_outputs = reference_outputs
+        else:
+            actual_outputs = self._forward(module, inputs)
         self._summarize_results(
             reference_outputs=reference_outputs,
             actual_outputs=actual_outputs,
@@ -326,5 +395,5 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl
             dtype=dtype,
             atol=atol,
             title=f"AWQ Kernel Output {dtype}",
-            reference_label="AWQ GEMM output",
+            reference_label="Torch AWQ output",
         )
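
Reviewer note: the new BACKEND.TORCH_AWQ path dequantizes the GEMM-packed AWQ weights via dequantize_gemm and runs a plain torch.matmul, so it serves as the reference for the fused kernels. The helper in tests/test_awq_torch_kernel.py, _pack_awq_tensor, encodes the AWQ nibble interleave order [0, 2, 4, 6, 1, 3, 5, 7]. The sketch below is not part of the patch; the function name _unpack_awq_tensor is hypothetical and only mirrors that packing order, as a quick way to sanity-check the bit layout while reviewing.

import torch

def _unpack_awq_tensor(packed: torch.Tensor, bits: int) -> torch.Tensor:
    # Hypothetical inverse of tests/test_awq_torch_kernel.py::_pack_awq_tensor.
    # Each int32 word holds 32 // bits values, interleaved per the AWQ order map
    # (valid for bits=4, matching the test helper's assumption).
    pack_factor = 32 // bits
    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
    mask = (1 << bits) - 1

    unpacked = torch.zeros(
        (packed.shape[0], packed.shape[1] * pack_factor),
        dtype=torch.int32,
    )
    for col in range(packed.shape[1]):
        for i, order in enumerate(order_map):
            # Field i of the packed word belongs to logical column `order`.
            unpacked[:, col * pack_factor + order] = (packed[:, col] >> (i * bits)) & mask
    return unpacked

Round-tripping _pack_awq_tensor followed by this helper on random 4-bit values should reproduce the original tensor exactly.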