From 769411d48453006cbc48dc4e50336bc6f463a43c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:01:52 +0000 Subject: [PATCH 1/9] torch kernel for awq --- gptqmodel/nn_modules/qlinear/awq_torch.py | 141 ++++++++++++++++++++++ gptqmodel/utils/backend.py | 1 + gptqmodel/utils/importer.py | 6 +- tests/test_awq_torch_kernel.py | 108 +++++++++++++++++ 4 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 gptqmodel/nn_modules/qlinear/awq_torch.py create mode 100644 tests/test_awq_torch_kernel.py diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py new file mode 100644 index 000000000..0c197554d --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization.awq.utils.packing_utils import dequantize_gemm +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from . import AWQuantLinear + + +log = setup_logger() + + +class AwqTorchQuantLinear(AWQuantLinear): + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True, False] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + REQUIRES_FORMAT_V2 = False + + QUANT_TYPE = "awq_torch" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = False, + **kwargs, + ): + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.TORCH_AWQ), + adapter=adapter, + register_buffers=register_buffers, + **kwargs, + ) + + self._cached_weights: dict[tuple[torch.device, torch.dtype], torch.Tensor] = {} + + def _invalidate_cache(self) -> None: + self._cached_weights.clear() + + def post_init(self): + self._invalidate_cache() + super().post_init() + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}" + ) + + def _materialize_weight(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + qweight = self.qweight.to(device=device, non_blocking=True) + qzeros = self.qzeros.to(device=device, non_blocking=True) if self.qzeros is not None else None + scales = self.scales.to(device=device, non_blocking=True) + + weight = dequantize_gemm(qweight, qzeros, scales, self.bits, self.group_size) + return weight.to(dtype=dtype) + + def _get_dequantized_weight(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + key = (device, dtype) + cached = self._cached_weights.get(key) + if cached is None: + cached = 
self._materialize_weight(device=device, dtype=dtype) + self._cached_weights[key] = cached + return cached + + def _get_bias(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None: + if self.bias is None: + return None + return self.bias.to(device=device, dtype=dtype, non_blocking=True) + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + original_dtype = x.dtype + device = x.device + + target_dtype = x.dtype + x_flat = x.reshape(-1, x.shape[-1]).to(dtype=target_dtype) + + weight = self._get_dequantized_weight(device=device, dtype=target_dtype) + output = torch.matmul(x_flat, weight) + + bias = self._get_bias(device=device, dtype=output.dtype) + if bias is not None: + output = output + bias + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + output = output.reshape(original_shape) + + return output + + def load_state_dict(self, state_dict, strict=True): + result = super().load_state_dict(state_dict, strict=strict) + self._invalidate_cache() + return result + + def to(self, *args, **kwargs): + module = super().to(*args, **kwargs) + self._invalidate_cache() + return module + + +__all__ = ["AwqTorchQuantLinear"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 664e0e40c..0744c0511 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -29,6 +29,7 @@ class BACKEND(str, Enum): GEMM = "gemm" GEMV = "gemv" GEMV_FAST = "gemv_fast" + TORCH_AWQ = "torch_awq" # external VLLM = "vllm" # External inference engine: CUDA + ROCm + IPEX diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 3bb46f107..0f72d4f92 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -20,6 +20,7 @@ from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear +from ..nn_modules.qlinear.awq_torch import AwqTorchQuantLinear from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear @@ -69,6 +70,7 @@ BACKEND.GEMM: AwqGEMMQuantLinear, BACKEND.GEMV: AwqGEMVQuantLinear, BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, + BACKEND.TORCH_AWQ: AwqTorchQuantLinear, }), } @@ -83,7 +85,7 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM], + FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], FORMAT.GEMV: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], @@ -314,6 +316,8 @@ def select_quant_linear( qlinear = AwqGEMVQuantLinear elif backend == BACKEND.GEMV_FAST: qlinear = AwqGEMVFastQuantLinear + elif backend == BACKEND.TORCH_AWQ: + qlinear = AwqTorchQuantLinear elif backend == BACKEND.TORCH: qlinear = TorchQuantLinear elif backend == BACKEND.TORCH_FUSED: diff --git a/tests/test_awq_torch_kernel.py b/tests/test_awq_torch_kernel.py new file mode 100644 index 000000000..590240c0a --- /dev/null +++ b/tests/test_awq_torch_kernel.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +import torch + +from 
gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import select_quant_linear + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_awq_torch_matches_manual_dequant(dtype): + torch.manual_seed(0) + + bits = 4 + in_features = 32 + out_features = 64 + group_size = 16 + + assert out_features % (32 // bits) == 0 + assert in_features % group_size == 0 + + groups = in_features // group_size + pack_cols = out_features + + int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, pack_cols), dtype=torch.int32) + scales = (torch.rand(groups, pack_cols, dtype=torch.float32) * 2.0) + 0.25 + bias = torch.randn(out_features, dtype=torch.float32) + + qweight = _pack_awq_tensor(int_weight, bits) + qzeros = _pack_awq_tensor(zero_points, bits) + + module = AwqTorchQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ) + + module.qweight.copy_(qweight) + module.qzeros.copy_(qzeros) + module.scales.copy_(scales) + module.bias.copy_(bias) + module.post_init() + module.eval() + + batch = 4 + x = torch.randn(batch, in_features, dtype=dtype) + + dequant_weight = dequantize_gemm( + qweight=qweight.to(torch.int32), + qzeros=qzeros.to(torch.int32), + scales=scales, + bits=bits, + group_size=group_size, + ).to(dtype=dtype) + + expected = torch.matmul(x.to(dtype), dequant_weight) + expected = expected + bias.to(dtype) + + output_first = module(x) + output_second = module(x) + + atol = 1e-4 if dtype == torch.float32 else 5e-3 + rtol = 1e-4 if dtype == torch.float32 else 5e-3 + torch.testing.assert_close(output_first, expected, atol=atol, rtol=rtol) + torch.testing.assert_close(output_second, expected, atol=atol, rtol=rtol) + + +def test_awq_torch_backend_selection(): + qlinear_cls = select_quant_linear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + device=None, + backend=BACKEND.TORCH_AWQ, + format=FORMAT.GEMM, + quant_method=METHOD.AWQ, + pack_dtype=torch.int32, + ) + assert qlinear_cls is AwqTorchQuantLinear From c27d5ef32e9273e58bf156f716c954ebf67b81fa Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:12:03 +0000 Subject: [PATCH 2/9] cleanup --- gptqmodel/nn_modules/qlinear/awq_torch.py | 54 ++++++----------------- tests/test_awq_torch_kernel.py | 7 ++- 2 files changed, 18 insertions(+), 43 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py index 0c197554d..b6429dfa2 100644 --- a/gptqmodel/nn_modules/qlinear/awq_torch.py +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -3,8 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, 
x.com/qubitium -from __future__ import annotations - import torch from ...adapter.adapter import Adapter, Lora @@ -69,13 +67,7 @@ def __init__( **kwargs, ) - self._cached_weights: dict[tuple[torch.device, torch.dtype], torch.Tensor] = {} - - def _invalidate_cache(self) -> None: - self._cached_weights.clear() - def post_init(self): - self._invalidate_cache() super().post_init() def extra_repr(self) -> str: @@ -84,39 +76,23 @@ def extra_repr(self) -> str: f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}" ) - def _materialize_weight(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - qweight = self.qweight.to(device=device, non_blocking=True) - qzeros = self.qzeros.to(device=device, non_blocking=True) if self.qzeros is not None else None - scales = self.scales.to(device=device, non_blocking=True) - - weight = dequantize_gemm(qweight, qzeros, scales, self.bits, self.group_size) - return weight.to(dtype=dtype) - - def _get_dequantized_weight(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - key = (device, dtype) - cached = self._cached_weights.get(key) - if cached is None: - cached = self._materialize_weight(device=device, dtype=dtype) - self._cached_weights[key] = cached - return cached - - def _get_bias(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None: - if self.bias is None: - return None - return self.bias.to(device=device, dtype=dtype, non_blocking=True) - def forward(self, x: torch.Tensor): original_shape = x.shape[:-1] + (self.out_features,) original_dtype = x.dtype device = x.device - target_dtype = x.dtype + self.ensure_buffer_dtype(original_dtype) + + target_dtype = original_dtype x_flat = x.reshape(-1, x.shape[-1]).to(dtype=target_dtype) - weight = self._get_dequantized_weight(device=device, dtype=target_dtype) + weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size).to(dtype=target_dtype) + output = torch.matmul(x_flat, weight) - bias = self._get_bias(device=device, dtype=output.dtype) + bias = None + if self.bias is not None: + bias = self.bias.to(device=device, dtype=target_dtype, non_blocking=True) if bias is not None: output = output + bias @@ -127,15 +103,11 @@ def forward(self, x: torch.Tensor): return output - def load_state_dict(self, state_dict, strict=True): - result = super().load_state_dict(state_dict, strict=strict) - self._invalidate_cache() - return result - - def to(self, *args, **kwargs): - module = super().to(*args, **kwargs) - self._invalidate_cache() - return module + def ensure_buffer_dtype(self, dtype: torch.dtype) -> None: + if self.scales.dtype != dtype: + self.scales = self.scales.to(dtype=dtype) + if self.bias is not None and self.bias.dtype != dtype: + self.bias = self.bias.to(dtype=dtype) __all__ = ["AwqTorchQuantLinear"] diff --git a/tests/test_awq_torch_kernel.py b/tests/test_awq_torch_kernel.py index 590240c0a..671e931af 100644 --- a/tests/test_awq_torch_kernel.py +++ b/tests/test_awq_torch_kernel.py @@ -73,16 +73,19 @@ def test_awq_torch_matches_manual_dequant(dtype): batch = 4 x = torch.randn(batch, in_features, dtype=dtype) + scales_expected = scales.to(dtype=dtype) + bias_expected = bias.to(dtype=dtype) + dequant_weight = dequantize_gemm( qweight=qweight.to(torch.int32), qzeros=qzeros.to(torch.int32), - scales=scales, + scales=scales_expected, bits=bits, group_size=group_size, ).to(dtype=dtype) expected = torch.matmul(x.to(dtype), dequant_weight) - expected = expected + bias.to(dtype) + expected = expected + bias_expected 
output_first = module(x) output_second = module(x) From 639bde0df4fe56ec0b8db0fd5380c1ed32cab152 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:17:35 +0000 Subject: [PATCH 3/9] add torch awq to kernel outpt test --- tests/test_kernel_output_awq.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 95c8bbcda..fcd666dc5 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -17,6 +17,7 @@ from gptqmodel import BACKEND from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.awq_marlin import ( AwqMarlinQuantLinear, marlin_import_exception, @@ -47,6 +48,8 @@ class TestAwqKernelOutput(unittest.TestCase): (BACKEND.GEMM, torch.bfloat16, 0.0005), (BACKEND.MARLIN, torch.float16, 0.01), (BACKEND.MARLIN, torch.bfloat16, 0.015), + (BACKEND.TORCH_AWQ, torch.float16, 0.001), + (BACKEND.TORCH_AWQ, torch.bfloat16, 0.05), ] @classmethod @@ -83,6 +86,9 @@ def setUpClass(cls) -> None: cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) + cls.modules[BACKEND.TORCH_AWQ] = cls._build_torch_awq_module( + qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + ) base_inputs = cls._generate_inputs() cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {} @@ -199,6 +205,35 @@ def _build_marlin_module( module.post_init() return module + @classmethod + def _build_torch_awq_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> AwqTorchQuantLinear: + module = AwqTorchQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(cls.device) + + module.qweight.copy_(qweight_cpu.to(cls.device)) + module.qzeros.copy_(qzeros_cpu.to(cls.device)) + module.scales.copy_(scales_cpu.to(torch.float32).to(cls.device)) + module.bias.copy_(bias_cpu.to(torch.float32).to(cls.device)) + + module.eval() + module.post_init() + return module + @classmethod def _generate_inputs(cls) -> List[torch.Tensor]: large_shapes = [(4, 32), (2, 64), (1, 96)] From 8d6e905315c2b2330537581e4e9dce3d4f2b7fa4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:24:50 +0000 Subject: [PATCH 4/9] torch should be the baseline --- tests/test_kernel_output_awq.py | 41 ++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index fcd666dc5..c5abfe3ca 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -43,13 +43,14 @@ class TestAwqKernelOutput(unittest.TestCase): GROUP_SIZE = 128 SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) + baseline_backend = BACKEND.TORCH_AWQ backend_cases = [ - (BACKEND.GEMM, torch.float16, 0.0), - (BACKEND.GEMM, torch.bfloat16, 0.0005), + (baseline_backend, torch.float16, 0.0), + (baseline_backend, torch.bfloat16, 0.0), + (BACKEND.GEMM, torch.float16, 0.001), + (BACKEND.GEMM, torch.bfloat16, 0.05), (BACKEND.MARLIN, torch.float16, 0.01), - (BACKEND.MARLIN, torch.bfloat16, 0.015), - (BACKEND.TORCH_AWQ, torch.float16, 0.001), - (BACKEND.TORCH_AWQ, torch.bfloat16, 0.05), + (BACKEND.MARLIN, torch.bfloat16, 0.05), ] @classmethod @@ -79,14 +80,15 @@ def 
setUpClass(cls) -> None: cls.modules: Dict[BACKEND, Optional[torch.nn.Module]] = {} - cls.modules[BACKEND.GEMM] = cls._build_gemm_module( + cls.modules[cls.baseline_backend] = cls._build_torch_awq_module( qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) - cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( + cls.modules[BACKEND.GEMM] = cls._build_gemm_module( qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) - cls.modules[BACKEND.TORCH_AWQ] = cls._build_torch_awq_module( + + cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) @@ -100,8 +102,12 @@ def setUpClass(cls) -> None: for tensor in base_inputs ] cls.inputs[dtype] = converted_inputs + torch_module = cls.modules.get(cls.baseline_backend) + if torch_module is None: + raise unittest.SkipTest("Torch AWQ kernel unavailable for baseline.") + cls.reference_outputs[dtype] = cls._forward( - cls.modules[BACKEND.GEMM], + torch_module, converted_inputs, ) @@ -297,10 +303,15 @@ def _summarize_results( ) -> None: failures = [] total = len(actual_outputs) + max_abs_diff = 0.0 + mean_abs_diff = 0.0 for idx, (reference, actual) in enumerate(zip(reference_outputs, actual_outputs)): reference_fp32 = reference.to(torch.float32) actual_fp32 = actual.to(torch.float32) + diff = torch.abs(reference_fp32 - actual_fp32) + max_abs_diff = max(max_abs_diff, float(diff.max().item())) + mean_abs_diff += float(diff.mean().item()) is_close_tensor = torch.isclose(reference_fp32, actual_fp32, rtol=0.15, atol=atol) if not bool(torch.all(is_close_tensor)): failures.append( @@ -313,6 +324,7 @@ def _summarize_results( ) status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" + avg_abs_diff = mean_abs_diff / total if total else 0.0 details = "\n\n".join(str(detail) for detail in failures) if failures else "-" table = tabulate( @@ -321,6 +333,8 @@ def _summarize_results( backend.name, str(dtype), total, + f"{max_abs_diff:.6f}", + f"{avg_abs_diff:.6f}", status, len(failures), details, @@ -330,6 +344,8 @@ def _summarize_results( "Backend", "DType", "Samples", + "MaxAbsDiff", + "MeanAbsDiff", "Status", "Failures", "Expected vs Actual", @@ -353,7 +369,10 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl inputs = self.inputs[dtype] reference_outputs = self.reference_outputs[dtype] - actual_outputs = self._forward(module, inputs) + if backend == self.baseline_backend: + actual_outputs = reference_outputs + else: + actual_outputs = self._forward(module, inputs) self._summarize_results( reference_outputs=reference_outputs, actual_outputs=actual_outputs, @@ -361,5 +380,5 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl dtype=dtype, atol=atol, title=f"AWQ Kernel Output {dtype}", - reference_label="AWQ GEMM output", + reference_label="Torch AWQ output", ) From c4ceb5c2609188f014bd8a45f8c1c90005b2179a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:33:45 +0000 Subject: [PATCH 5/9] awq is float16 only..bfloat16 has high error loss --- gptqmodel/nn_modules/qlinear/awq_gemm.py | 2 +- gptqmodel/nn_modules/qlinear/awq_gemv.py | 2 +- gptqmodel/nn_modules/qlinear/awq_marlin.py | 2 +- gptqmodel/nn_modules/qlinear/awq_torch.py | 18 ++++------------ tests/test_kernel_output_awq.py | 25 +++++++++++++++++----- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index 4f041e794..e2b95100e 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ 
b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -31,7 +31,7 @@ class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_DTYPES = [torch.float16] REQUIRES_FORMAT_V2 = False diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index 50add314f..150ad3b69 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -34,7 +34,7 @@ class AwqGEMVQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_DTYPES = [torch.float16] # for transformers/optimum tests compat QUANT_TYPE = "awq_gemv" diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index e7d576159..5b41a9d6f 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -53,7 +53,7 @@ class AwqMarlinQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_DTYPES = [torch.float16] REQUIRES_FORMAT_V2 = False diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py index b6429dfa2..35ce77f60 100644 --- a/gptqmodel/nn_modules/qlinear/awq_torch.py +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -32,7 +32,7 @@ class AwqTorchQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_DTYPES = [torch.float16] REQUIRES_FORMAT_V2 = False @@ -78,21 +78,18 @@ def extra_repr(self) -> str: def forward(self, x: torch.Tensor): original_shape = x.shape[:-1] + (self.out_features,) - original_dtype = x.dtype device = x.device - self.ensure_buffer_dtype(original_dtype) - target_dtype = original_dtype - x_flat = x.reshape(-1, x.shape[-1]).to(dtype=target_dtype) + x_flat = x.reshape(-1, x.shape[-1]) - weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size).to(dtype=target_dtype) + weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size) output = torch.matmul(x_flat, weight) bias = None if self.bias is not None: - bias = self.bias.to(device=device, dtype=target_dtype, non_blocking=True) + bias = self.bias.to(device=device) if bias is not None: output = output + bias @@ -103,11 +100,4 @@ def forward(self, x: torch.Tensor): return output - def ensure_buffer_dtype(self, dtype: torch.dtype) -> None: - if self.scales.dtype != dtype: - self.scales = self.scales.to(dtype=dtype) - if self.bias is not None and self.bias.dtype != dtype: - self.bias = self.bias.to(dtype=dtype) - - __all__ = ["AwqTorchQuantLinear"] diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index c5abfe3ca..6cd10538b 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -41,16 +41,16 @@ class TestAwqKernelOutput(unittest.TestCase): TARGET = "model.layers.20.self_attn.v_proj" BITS = 4 GROUP_SIZE = 128 - SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) + SUPPORTED_DTYPES = (torch.float16,) baseline_backend = BACKEND.TORCH_AWQ backend_cases = [ (baseline_backend, torch.float16, 0.0), - (baseline_backend, torch.bfloat16, 0.0), + # (baseline_backend, torch.bfloat16, 0.0), (BACKEND.GEMM, torch.float16, 0.001), - (BACKEND.GEMM, torch.bfloat16, 0.05), + # 
(BACKEND.GEMM, torch.bfloat16, 0.05), (BACKEND.MARLIN, torch.float16, 0.01), - (BACKEND.MARLIN, torch.bfloat16, 0.05), + # (BACKEND.MARLIN, torch.bfloat16, 0.05), ] @classmethod @@ -106,9 +106,16 @@ def setUpClass(cls) -> None: if torch_module is None: raise unittest.SkipTest("Torch AWQ kernel unavailable for baseline.") + forward_kwargs = {} + if dtype == torch.bfloat16: + forward_kwargs = { + "compute_dtype": torch.float16, + "output_dtype": dtype, + } cls.reference_outputs[dtype] = cls._forward( torch_module, converted_inputs, + **forward_kwargs, ) @classmethod @@ -278,11 +285,19 @@ def _forward( cls, module: torch.nn.Module, inputs: Iterable[torch.Tensor], + *, + compute_dtype: Optional[torch.dtype] = None, + output_dtype: Optional[torch.dtype] = None, ) -> List[torch.Tensor]: outputs: List[torch.Tensor] = [] with torch.inference_mode(): for tensor in inputs: - result = module(tensor) + local_tensor = tensor + if compute_dtype is not None and tensor.dtype != compute_dtype: + local_tensor = tensor.to(dtype=compute_dtype) + result = module(local_tensor) + if output_dtype is not None and result.dtype != output_dtype: + result = result.to(dtype=output_dtype) outputs.append(result.detach().to(torch.float32).cpu()) return outputs From d73b16178a115a2774e7005a04a31f23d4c45ba9 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 10:38:04 +0000 Subject: [PATCH 6/9] awq is float16 only..bfloat16 has high error loss --- gptqmodel/nn_modules/qlinear/awq_torch.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py index 35ce77f60..6fd6deda2 100644 --- a/gptqmodel/nn_modules/qlinear/awq_torch.py +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -79,19 +79,17 @@ def extra_repr(self) -> str: def forward(self, x: torch.Tensor): original_shape = x.shape[:-1] + (self.out_features,) device = x.device - - x_flat = x.reshape(-1, x.shape[-1]) weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size) + assert weight.dtype == torch.float16 + if weight.dtype != x_flat.dtype or weight.device != device: + weight = weight.to(device=device, dtype=x_flat.dtype) output = torch.matmul(x_flat, weight) - bias = None if self.bias is not None: - bias = self.bias.to(device=device) - if bias is not None: - output = output + bias + output = output + self.bias if self.adapter: output = self.adapter.apply(x=x_flat, out=output) From d305428fd373ddf24947c1a9b47f4fb1a0c424d4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 11:02:56 +0000 Subject: [PATCH 7/9] cleanup --- gptqmodel/nn_modules/qlinear/awq_torch.py | 2 +- tests/test_awq_torch_kernel.py | 20 +++++++++++--------- tests/test_kernel_output_awq.py | 10 +++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py index 6fd6deda2..5e298d7dc 100644 --- a/gptqmodel/nn_modules/qlinear/awq_torch.py +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -82,7 +82,7 @@ def forward(self, x: torch.Tensor): x_flat = x.reshape(-1, x.shape[-1]) weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size) - assert weight.dtype == torch.float16 + assert weight.dtype == torch.float16, f"weight {weight.dtype} is not float16" if weight.dtype != x_flat.dtype or weight.device != device: weight = weight.to(device=device, dtype=x_flat.dtype) diff --git a/tests/test_awq_torch_kernel.py 
b/tests/test_awq_torch_kernel.py index 671e931af..03096104e 100644 --- a/tests/test_awq_torch_kernel.py +++ b/tests/test_awq_torch_kernel.py @@ -29,8 +29,10 @@ def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: return packed -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) def test_awq_torch_matches_manual_dequant(dtype): + if dtype not in AwqTorchQuantLinear.SUPPORTS_DTYPES: + pytest.skip(f"dtype {dtype} not supported by AwqTorchQuantLinear") torch.manual_seed(0) bits = 4 @@ -46,8 +48,8 @@ def test_awq_torch_matches_manual_dequant(dtype): int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) zero_points = torch.randint(0, 2**bits, size=(groups, pack_cols), dtype=torch.int32) - scales = (torch.rand(groups, pack_cols, dtype=torch.float32) * 2.0) + 0.25 - bias = torch.randn(out_features, dtype=torch.float32) + scales = (torch.rand(groups, pack_cols, dtype=torch.float16) * 2.0) + 0.25 + bias = torch.randn(out_features, dtype=torch.float16) qweight = _pack_awq_tensor(int_weight, bits) qzeros = _pack_awq_tensor(zero_points, bits) @@ -65,7 +67,8 @@ def test_awq_torch_matches_manual_dequant(dtype): module.qweight.copy_(qweight) module.qzeros.copy_(qzeros) - module.scales.copy_(scales) + module.scales = module.scales.to(dtype=torch.float16) + module.scales.copy_(scales.to(torch.float16)) module.bias.copy_(bias) module.post_init() module.eval() @@ -73,13 +76,12 @@ def test_awq_torch_matches_manual_dequant(dtype): batch = 4 x = torch.randn(batch, in_features, dtype=dtype) - scales_expected = scales.to(dtype=dtype) - bias_expected = bias.to(dtype=dtype) + bias_expected = module.bias dequant_weight = dequantize_gemm( - qweight=qweight.to(torch.int32), - qzeros=qzeros.to(torch.int32), - scales=scales_expected, + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, bits=bits, group_size=group_size, ).to(dtype=dtype) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 6cd10538b..5fa1081ed 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -170,8 +170,8 @@ def _build_gemm_module( module.qweight.copy_(qweight_cpu.to(cls.device)) module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(torch.float32).to(cls.device)) - module.bias.copy_(bias_cpu.to(torch.float32).to(cls.device)) + module.scales.copy_(scales_cpu.to(cls.device)) + module.bias.copy_(bias_cpu.to(cls.device)) module.eval() module.post_init() @@ -240,8 +240,8 @@ def _build_torch_awq_module( module.qweight.copy_(qweight_cpu.to(cls.device)) module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(torch.float32).to(cls.device)) - module.bias.copy_(bias_cpu.to(torch.float32).to(cls.device)) + module.scales.copy_(scales_cpu.to(cls.device)) + module.bias.copy_(bias_cpu.to(cls.device)) module.eval() module.post_init() @@ -298,7 +298,7 @@ def _forward( result = module(local_tensor) if output_dtype is not None and result.dtype != output_dtype: result = result.to(dtype=output_dtype) - outputs.append(result.detach().to(torch.float32).cpu()) + outputs.append(result.detach().cpu()) return outputs def _maybe_skip_backend(self, backend: BACKEND) -> None: From 6708218b8b43753d98dd9df5d87aaa32cbc599bb Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 11:09:43 +0000 Subject: [PATCH 8/9] fix base awq quant linear has wrong float32 dtype for scales/bis! 
--- gptqmodel/nn_modules/qlinear/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 20b2d6892..442e925be 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -1032,12 +1032,10 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_ class AWQuantLinear(BaseQuantLinear): def __init__(self, bias: bool = False, - use_bf16: bool = False, register_buffers: bool = False, **kwargs): super().__init__(bias=bias, register_buffers=False, **kwargs) - self.use_bf16 = use_bf16 in_features = self.in_features out_features = self.out_features @@ -1058,12 +1056,12 @@ def __init__(self, "scales", t.zeros( (in_features // self.group_size, out_features), - dtype=t.bfloat16 if self.use_bf16 else t.float32, + dtype=t.float16, ), ) if bias: - self.register_buffer("bias", t.zeros(out_features, dtype=t.bfloat16 if self.use_bf16 else t.float32,)) + self.register_buffer("bias", t.zeros(out_features, dtype=t.float16)) else: self.bias = None From 6fc22c93fd92f6cb676c33cba7f652de834d255e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 11:13:46 +0000 Subject: [PATCH 9/9] ruff --- gptqmodel/models/definitions/ovis.py | 2 +- gptqmodel/models/definitions/qwen3_moe.py | 1 - tests/test_awq_torch_kernel.py | 2 +- tests/test_kernel_output_awq.py | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py index a9144ed19..7dbae76c9 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -12,9 +12,9 @@ from ...utils.calibration import batched from ...utils.image import fetch_image from ...utils.model import MODALITY, move_to +from ...utils.offload import offload_to_disk from .._const import CPU from ..base import BaseQModel -from ...utils.offload import offload_to_disk class OvisQModel(BaseQModel): diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py index c68ee7afa..6e07fab60 100644 --- a/gptqmodel/models/definitions/qwen3_moe.py +++ b/gptqmodel/models/definitions/qwen3_moe.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from ...quantization import METHOD from ..base import BaseQModel diff --git a/tests/test_awq_torch_kernel.py b/tests/test_awq_torch_kernel.py index 03096104e..a91dd88ef 100644 --- a/tests/test_awq_torch_kernel.py +++ b/tests/test_awq_torch_kernel.py @@ -7,8 +7,8 @@ import torch from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear -from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from gptqmodel.utils.backend import BACKEND from gptqmodel.utils.importer import select_quant_linear diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 5fa1081ed..7b3c0a5bf 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -17,11 +17,11 @@ from gptqmodel import BACKEND from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.awq_marlin import ( AwqMarlinQuantLinear, marlin_import_exception, ) +from gptqmodel.nn_modules.qlinear.awq_torch import 
AwqTorchQuantLinear from gptqmodel.utils.marlin import marlin_make_workspace_new
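
Usage sketch for the new TORCH_AWQ backend. Not part of the patches above; it mirrors tests/test_awq_torch_kernel.py and assumes only the classes, constructor arguments, and selection path introduced in this series (the feature sizes are arbitrary example values):

    import torch
    from gptqmodel.quantization import FORMAT, METHOD
    from gptqmodel.utils.backend import BACKEND
    from gptqmodel.utils.importer import select_quant_linear

    # Resolve the pure-torch AWQ kernel through the normal selection path.
    qlinear_cls = select_quant_linear(
        bits=4,
        group_size=128,
        desc_act=False,
        sym=True,
        device=None,
        backend=BACKEND.TORCH_AWQ,
        format=FORMAT.GEMM,
        quant_method=METHOD.AWQ,
        pack_dtype=torch.int32,
    )

    # register_buffers=True creates qweight/qzeros/scales (float16 after
    # PATCH 8/9); real checkpoints would copy packed tensors into them
    # before post_init(), as the unit test does.
    module = qlinear_cls(
        bits=4, group_size=128, sym=True, desc_act=False,
        in_features=256, out_features=512, bias=False,
        register_buffers=True,
    )
    module.post_init()
    module.eval()

    # The kernel is float16-only after this series: dequantize_gemm + matmul.
    x = torch.randn(2, 256, dtype=torch.float16)
    with torch.inference_mode():
        y = module(x)
    assert y.shape == (2, 512) and y.dtype == torch.float16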