From 8b71d483c6caf4319bf53f06305f7dea98096f8c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 02:47:44 +0000 Subject: [PATCH 01/26] use torch.ops.aten fused ops for awq --- gptqmodel/nn_modules/qlinear/torch_fused.py | 58 ++-- .../nn_modules/qlinear/torch_fused_awq.py | 296 ++++++++++++++++++ gptqmodel/utils/backend.py | 1 + gptqmodel/utils/importer.py | 12 +- tests/test_torch_fused_awq.py | 100 ++++++ 5 files changed, 444 insertions(+), 23 deletions(-) create mode 100644 gptqmodel/nn_modules/qlinear/torch_fused_awq.py create mode 100644 tests/test_torch_fused_awq.py diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index c17ef2c73..085ef9a5c 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -111,6 +111,38 @@ def optimize(self): super().optimize() + def _build_ret_idx(self) -> torch.Tensor: + existing = getattr(self, "ret_idx", None) + total = self.g_idx.shape[0] + if isinstance(existing, torch.Tensor) and existing.numel() == total: + return existing + + device = self.g_idx.device + ret_idx = torch.zeros(total, dtype=torch.int32, device=device) + group_size = max(int(self.group_size), 1) + groups = total // group_size + remainder = total % group_size + g_idx = self.g_idx.to(torch.int32) + g_idx_2 = g_idx * group_size + + if remainder > 0: + mask = g_idx == groups + if mask.any(): + g_idx_2[mask] += torch.arange(remainder, device=device, dtype=torch.int32) + + if groups > 0: + base = torch.arange(group_size, device=device, dtype=torch.int32) + for i in range(groups): + mask = g_idx == i + if not mask.any(): + continue + count = int(mask.sum().item()) + g_idx_2[mask] += base[:count] + + ret_idx[g_idx_2] = torch.arange(total, device=device, dtype=torch.int32) + self.ret_idx = ret_idx + return ret_idx + def train(self, mode: bool = True): old_train = self.training if mode == old_train: @@ -156,17 +188,8 @@ def transform_xpu(self, dtype): ).to(self.dequant_dtype), self.maxq ) - self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) - groups = self.g_idx.shape[0] // self.group_size - remainder = self.g_idx.shape[0] % self.group_size - g_idx_2 = self.g_idx * self.group_size - if remainder > 0: - g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) - arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) - for i in range(groups): - g_idx_2[self.g_idx == i] += arange_tensor - self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + ret_idx = self._build_ret_idx() + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t() # Pack qweight packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) for col in range(weight.shape[1] // self.pack_factor): @@ -187,17 +210,8 @@ def transform_cpu(self, dtype): ).to(self.dequant_dtype), self.maxq ) - self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) - groups = self.g_idx.shape[0] // self.group_size - remainder = self.g_idx.shape[0] % self.group_size - g_idx_2 = self.g_idx * self.group_size - if remainder > 0: - g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) - 
arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) - for i in range(groups): - g_idx_2[self.g_idx == i] += arange_tensor - self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + ret_idx = self._build_ret_idx() + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t() self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() self.qzeros = torch.zeros_like(self.scales).contiguous() self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py new file mode 100644 index 000000000..2601e8665 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch + +from ...adapter.adapter import Adapter +from ...quantization.awq.utils.packing_utils import ( + dequantize_gemm, + reverse_awq_order, + unpack_awq, +) +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from ...utils.torch import TORCH_HAS_FUSED_OPS +from .torch_fused import Int4PackedOp, TorchFusedQuantLinear, pack_scales_and_zeros + + +log = setup_logger() + + +class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): + """Torch fused AWQ variant that reuses the GPTQ fused kernels via CPU int4 packing.""" + + QUANT_TYPE = "torch_fused_awq" + SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS + SUPPORTS_GROUP_SIZE = TorchFusedQuantLinear.SUPPORTS_GROUP_SIZE + SUPPORTS_DESC_ACT = TorchFusedQuantLinear.SUPPORTS_DESC_ACT + SUPPORTS_SYM = TorchFusedQuantLinear.SUPPORTS_SYM + SUPPORTS_SHARDS = TorchFusedQuantLinear.SUPPORTS_SHARDS + SUPPORTS_TRAINING = TorchFusedQuantLinear.SUPPORTS_TRAINING + SUPPORTS_AUTO_PADDING = TorchFusedQuantLinear.SUPPORTS_AUTO_PADDING + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY + SUPPORTS_DEVICES = TorchFusedQuantLinear.SUPPORTS_DEVICES + SUPPORTS_PLATFORM = TorchFusedQuantLinear.SUPPORTS_PLATFORM + SUPPORTS_PACK_DTYPES = TorchFusedQuantLinear.SUPPORTS_PACK_DTYPES + SUPPORTS_ADAPTERS = TorchFusedQuantLinear.SUPPORTS_ADAPTERS + + SUPPORTS_DTYPES = [torch.float16] + REQUIRES_FORMAT_V2 = False + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + # AWQ bookkeeping must exist before the base class registers buffers. 
+ self._awq_layout = False + self._awq_dense_weight = None + self._awq_qweight = None + self._awq_qzeros = None + self._awq_scales = None + self._awq_qweight_ptr = None + self._awq_qzeros_ptr = None + self._awq_scales_ptr = None + + kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=register_buffers, + **kwargs, + ) + + def post_init(self): + super().post_init() + self._maybe_capture_awq_state(force=True) + + def register_buffer(self, name, tensor, persistent=True): + super().register_buffer(name, tensor, persistent=persistent) + if name in {"qweight", "qzeros", "scales"}: + self._maybe_capture_awq_state() + + def optimize(self): + if self.optimized: + return + super().optimize() + self._uses_awq_layout() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + qweight_key = prefix + "qweight" + awq_tensor = None + if qweight_key in state_dict: + candidate = state_dict[qweight_key] + if self._is_awq_qweight_tensor(candidate): + awq_tensor = candidate.to(self.pack_dtype).clone() + self._awq_qweight = awq_tensor.clone() + placeholder = getattr(self, "qweight", None) + if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): + state_dict[qweight_key] = torch.zeros_like(placeholder) + else: + rows = max(1, self.in_features // self.pack_factor) + cols = self.out_features + state_dict[qweight_key] = torch.zeros( + (rows, cols), + dtype=self.pack_dtype, + device=awq_tensor.device, + ) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + if awq_tensor is not None: + state_dict[qweight_key] = awq_tensor + device = getattr(self, "qweight", awq_tensor).device + self.register_buffer( + "qweight", + awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(), + persistent=True, + ) + self._awq_layout = True + + def _awq_qweight_shape(self): + pack_cols = self.out_features // self.pack_factor + return self.in_features, pack_cols + + def _is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: + if tensor is None or not torch.is_tensor(tensor) or tensor.dim() != 2: + return False + rows, cols = tensor.shape + exp_rows, exp_cols = self._awq_qweight_shape() + return rows == exp_rows and cols == exp_cols + + def _maybe_capture_awq_state(self, force: bool = False): + qweight = getattr(self, "qweight", None) + qzeros = getattr(self, "qzeros", None) + scales = getattr(self, "scales", None) + if ( + qweight is None + or qzeros is None + or scales is None + or not torch.is_tensor(qweight) + or not torch.is_tensor(qzeros) + or not torch.is_tensor(scales) + or not self._is_awq_qweight_tensor(qweight) + ): + return + + qweight_ptr = qweight.data_ptr() + qzeros_ptr = qzeros.data_ptr() + scales_ptr = scales.data_ptr() + if ( + not force + and self._awq_qweight is not None + and self._awq_qzeros is not None + and self._awq_scales is not None + and self._awq_qweight_ptr == qweight_ptr + and self._awq_qzeros_ptr == qzeros_ptr + and self._awq_scales_ptr == scales_ptr + ): + return + + self._awq_qweight = qweight.clone() + self._awq_qzeros = qzeros.clone() + scale_clone = scales.clone() + if scale_clone.dtype != torch.float16: + scale_clone = scale_clone.to(torch.float16) + 
self._awq_scales = scale_clone + self._awq_dense_weight = None + self._awq_qweight_ptr = qweight_ptr + self._awq_qzeros_ptr = qzeros_ptr + self._awq_scales_ptr = scales_ptr + self._awq_layout = True + + def _uses_awq_layout(self) -> bool: + if self._awq_layout: + return True + self._maybe_capture_awq_state() + return self._awq_layout + + def _transform_cpu_awq(self, dtype): + self.scales = self.scales.clone().to(dtype).contiguous() + scale_fp32 = self.scales.to(torch.float32) + iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) + iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) + max_val = (1 << self.bits) - 1 + iweight = torch.bitwise_and(iweight, max_val) + if izeros is not None: + izeros = torch.bitwise_and(izeros, max_val) + ret_idx = self._build_ret_idx() + weight = iweight.index_select(0, ret_idx).t().contiguous() + self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() + + if izeros is None: + zeros = torch.zeros_like(scale_fp32) + else: + zero_offset = 1 << (self.bits - 1) + zeros = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) + zeros = zeros * scale_fp32 + self.scales = scale_fp32.to(dtype=dtype) + self.qzeros = zeros.to(dtype=dtype) + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + self._awq_layout = True + + def _awq_weight_dense(self, device, dtype): + if self._awq_dense_weight is None: + assert self._awq_qweight is not None + assert self._awq_qzeros is not None + assert self._awq_scales is not None + dense = dequantize_gemm( + self._awq_qweight, + self._awq_qzeros, + self._awq_scales, + self.bits, + self.group_size, + ).to(device=device, dtype=torch.float32) + self._awq_dense_weight = dense + return self._awq_dense_weight.to(device=device, dtype=dtype) + + def transform_cpu(self, dtype): + if self._uses_awq_layout(): + self._transform_cpu_awq(dtype) + return + super().transform_cpu(dtype) + + def transform(self, dtype, device): + if device == "xpu" and self._uses_awq_layout(): + raise NotImplementedError("TorchFusedAwqQuantLinear AWQ layout is currently supported on CPU only.") + super().transform(dtype, device) + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: + self.transform(x_flat.dtype, x_flat.device.type) + self.transformed = True + if x_flat.device.type == "cpu": + self.torch_fused_op = Int4PackedOp( + self.qweight, self.scales_and_zeros, self.group_size + ).eval() + import torch._inductor.config as config + config.freezing = True + config.max_autotune = True + + if self.transformed: + out = self._fused_op_forward(x_flat) + else: + if self._uses_awq_layout(): + weight = self._awq_weight_dense(device=x_flat.device, dtype=x_flat.dtype) + out = torch.matmul(x_flat, weight) + else: + num_itr = self.g_idx.shape[0] // x_flat.shape[-1] + weights = self.dequantize_weight(num_itr=num_itr).to(x_flat.dtype) + out = torch.matmul(x_flat, weights) + + if self.bias is not None: + out.add_(self.bias) + if self.adapter: + out = self.adapter.apply(x=x_flat, out=out) + + return out.reshape(out_shape) + + @torch.no_grad + def _fused_op_forward(self, x): + awq_active = self._uses_awq_layout() + use_awq_fallback = awq_active and x.device.type == "cpu" + if use_awq_fallback: + compute_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype + weight = self._awq_weight_dense(device=x.device, dtype=compute_dtype) + 
return torch.matmul(x.to(compute_dtype), weight).to(x.dtype) + return super()._fused_op_forward(x) + + +__all__ = ["TorchFusedAwqQuantLinear"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 0744c0511..2d2746bee 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -12,6 +12,7 @@ class BACKEND(str, Enum): # gptq TORCH_FUSED = "torch_fused" # optimized for Intel XPU + TORCH_FUSED_AWQ = "torch_fused_awq" # AWQ variant of torch fused kernel TORCH = "torch" # GOOD: about 80% of triton TRITON = "triton" # VERY GOOD: all-around kernel EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1 diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 0f72d4f92..03805504f 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -30,6 +30,7 @@ from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear +from ..nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear from ..quantization import FORMAT, METHOD from ..utils.logger import setup_logger @@ -70,6 +71,7 @@ BACKEND.GEMM: AwqGEMMQuantLinear, BACKEND.GEMV: AwqGEMVQuantLinear, BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, + BACKEND.TORCH_FUSED_AWQ: TorchFusedAwqQuantLinear, BACKEND.TORCH_AWQ: AwqTorchQuantLinear, }), } @@ -85,7 +87,15 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], + FORMAT.GEMM: [ + BACKEND.MACHETE, + BACKEND.MARLIN, + BACKEND.EXLLAMA_V2, + BACKEND.EXLLAMA_V1, + BACKEND.GEMM, + BACKEND.TORCH_FUSED_AWQ, + BACKEND.TORCH_AWQ, + ], FORMAT.GEMV: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py new file mode 100644 index 000000000..999fbcafc --- /dev/null +++ b/tests/test_torch_fused_awq.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear +from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS + + +def pack_awq(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +@pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") +@pytest.mark.parametrize("dtype", [torch.float16], ids=["float16"]) +def test_torch_fused_awq_matches_baseline_torch_kernel(dtype): + torch.manual_seed(0) + + bits = 4 + in_features = 64 + out_features = 128 + group_size = 32 + batch = 4 + + groups = in_features // group_size + + int_weight = torch.randint(0, 2**bits, size=(in_features, 
out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, out_features), dtype=torch.int32) + scales = (torch.rand(groups, out_features, dtype=torch.float16) * 1.5) + 0.25 + bias = torch.randn(out_features, dtype=torch.float16) + + qweight = pack_awq(int_weight, bits) + qzeros = pack_awq(zero_points, bits) + + awq_module = AwqTorchQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ) + awq_module.qweight.copy_(qweight) + awq_module.qzeros.copy_(qzeros) + awq_module.scales.copy_(scales) + awq_module.bias.copy_(bias) + awq_module.post_init() + awq_module.eval() + + fused_module = TorchFusedAwqQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ) + fused_module.register_buffer("qweight", qweight.clone(), persistent=True) + fused_module.qzeros.copy_(qzeros) + fused_module.scales.copy_(scales) + fused_module.bias.copy_(bias) + fused_module.post_init() + fused_module.eval() + + x = torch.randn(batch, in_features, dtype=dtype) + baseline = awq_module(x.to(torch.float16)).to(dtype) + fused_out = fused_module(x) + + tol_map = { + torch.float16: 5e-3, + torch.bfloat16: 1.1, + } + tol = tol_map[dtype] + abs_diff = (fused_out - baseline).abs() + rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) + + torch.testing.assert_close(fused_out, baseline, rtol=tol, atol=tol) + + header = f"{'dtype':<10} {'rtol':<10} {'atol':<10} {'abs_max':<12} {'rel_max':<12}" + row = f"{str(dtype):<10} {tol:<10.4g} {tol:<10.4g} {abs_diff.max().item():<12.4e} {rel_diff.max().item():<12.4e}" + print(f"{header}\n{row}") From 25e64f9e4be78c06517664c02ce36e5a8989bd6b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 02:57:48 +0000 Subject: [PATCH 02/26] cleanup --- .../nn_modules/qlinear/torch_fused_awq.py | 100 +++++++----------- tests/test_torch_fused_awq.py | 2 +- 2 files changed, 42 insertions(+), 60 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 2601e8665..2023c2f3b 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -55,15 +55,7 @@ def __init__( register_buffers: bool = True, **kwargs, ): - # AWQ bookkeeping must exist before the base class registers buffers. 
- self._awq_layout = False - self._awq_dense_weight = None - self._awq_qweight = None - self._awq_qzeros = None - self._awq_scales = None - self._awq_qweight_ptr = None - self._awq_qzeros_ptr = None - self._awq_scales_ptr = None + self._awq_buffers_ready = False kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) super().__init__( @@ -79,15 +71,20 @@ def __init__( register_buffers=register_buffers, **kwargs, ) + self.register_buffer("awq_qweight_src", None, persistent=False) + self.register_buffer("awq_qzeros_src", None, persistent=False) + self.register_buffer("awq_scales_src", None, persistent=False) + self._awq_buffers_ready = True + self._update_awq_buffers() def post_init(self): super().post_init() - self._maybe_capture_awq_state(force=True) + self._update_awq_buffers() def register_buffer(self, name, tensor, persistent=True): super().register_buffer(name, tensor, persistent=persistent) - if name in {"qweight", "qzeros", "scales"}: - self._maybe_capture_awq_state() + if getattr(self, "_awq_buffers_ready", False) and name in {"qweight", "qzeros", "scales"}: + self._update_awq_buffers() def optimize(self): if self.optimized: @@ -111,7 +108,7 @@ def _load_from_state_dict( candidate = state_dict[qweight_key] if self._is_awq_qweight_tensor(candidate): awq_tensor = candidate.to(self.pack_dtype).clone() - self._awq_qweight = awq_tensor.clone() + self.awq_qweight_src = awq_tensor.clone() placeholder = getattr(self, "qweight", None) if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): state_dict[qweight_key] = torch.zeros_like(placeholder) @@ -140,7 +137,7 @@ def _load_from_state_dict( awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(), persistent=True, ) - self._awq_layout = True + self._update_awq_buffers() def _awq_qweight_shape(self): pack_cols = self.out_features // self.pack_factor @@ -153,7 +150,9 @@ def _is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: exp_rows, exp_cols = self._awq_qweight_shape() return rows == exp_rows and cols == exp_cols - def _maybe_capture_awq_state(self, force: bool = False): + def _update_awq_buffers(self): + if not getattr(self, "_awq_buffers_ready", False): + return qweight = getattr(self, "qweight", None) qzeros = getattr(self, "qzeros", None) scales = getattr(self, "scales", None) @@ -167,43 +166,26 @@ def _maybe_capture_awq_state(self, force: bool = False): or not self._is_awq_qweight_tensor(qweight) ): return - - qweight_ptr = qweight.data_ptr() - qzeros_ptr = qzeros.data_ptr() - scales_ptr = scales.data_ptr() - if ( - not force - and self._awq_qweight is not None - and self._awq_qzeros is not None - and self._awq_scales is not None - and self._awq_qweight_ptr == qweight_ptr - and self._awq_qzeros_ptr == qzeros_ptr - and self._awq_scales_ptr == scales_ptr - ): - return - - self._awq_qweight = qweight.clone() - self._awq_qzeros = qzeros.clone() + self.awq_qweight_src = qweight.clone() + self.awq_qzeros_src = qzeros.clone() scale_clone = scales.clone() if scale_clone.dtype != torch.float16: scale_clone = scale_clone.to(torch.float16) - self._awq_scales = scale_clone - self._awq_dense_weight = None - self._awq_qweight_ptr = qweight_ptr - self._awq_qzeros_ptr = qzeros_ptr - self._awq_scales_ptr = scales_ptr - self._awq_layout = True + self.awq_scales_src = scale_clone def _uses_awq_layout(self) -> bool: - if self._awq_layout: - return True - self._maybe_capture_awq_state() - return self._awq_layout + return self.awq_qweight_src is not None def _transform_cpu_awq(self, dtype): - self.scales = 
self.scales.clone().to(dtype).contiguous() - scale_fp32 = self.scales.to(torch.float32) - iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) + if ( + self.awq_qweight_src is None + or self.awq_qzeros_src is None + or self.awq_scales_src is None + ): + raise RuntimeError("AWQ state unavailable for CPU transform.") + self.scales = self.awq_scales_src.clone().to(dtype).contiguous() + scale_fp32 = self.awq_scales_src.to(torch.float32) + iweight, izeros = unpack_awq(self.awq_qweight_src, self.awq_qzeros_src, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 iweight = torch.bitwise_and(iweight, max_val) @@ -222,22 +204,22 @@ def _transform_cpu_awq(self, dtype): self.scales = scale_fp32.to(dtype=dtype) self.qzeros = zeros.to(dtype=dtype) self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) - self._awq_layout = True def _awq_weight_dense(self, device, dtype): - if self._awq_dense_weight is None: - assert self._awq_qweight is not None - assert self._awq_qzeros is not None - assert self._awq_scales is not None - dense = dequantize_gemm( - self._awq_qweight, - self._awq_qzeros, - self._awq_scales, - self.bits, - self.group_size, - ).to(device=device, dtype=torch.float32) - self._awq_dense_weight = dense - return self._awq_dense_weight.to(device=device, dtype=dtype) + if ( + self.awq_qweight_src is None + or self.awq_qzeros_src is None + or self.awq_scales_src is None + ): + raise RuntimeError("AWQ dense weight requested without cached tensors.") + dense = dequantize_gemm( + self.awq_qweight_src, + self.awq_qzeros_src, + self.awq_scales_src, + self.bits, + self.group_size, + ).to(device=device, dtype=torch.float32) + return dense.to(device=device, dtype=dtype) def transform_cpu(self, dtype): if self._uses_awq_layout(): diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index 999fbcafc..edc4d6d7e 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -27,7 +27,7 @@ def pack_awq(unpacked: torch.Tensor, bits: int) -> torch.Tensor: @pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") -@pytest.mark.parametrize("dtype", [torch.float16], ids=["float16"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["float16", "bfloat16"]) def test_torch_fused_awq_matches_baseline_torch_kernel(dtype): torch.manual_seed(0) From 784181b5b61b6471461235b392b4e4f097ec6108 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 03:09:43 +0000 Subject: [PATCH 03/26] cleanup2 --- .../nn_modules/qlinear/torch_fused_awq.py | 64 ++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 2023c2f3b..d7608a449 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -55,7 +55,9 @@ def __init__( register_buffers: bool = True, **kwargs, ): - self._awq_buffers_ready = False + self._awq_qweight = None + self._awq_qzeros = None + self._awq_scales = None kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) super().__init__( @@ -71,20 +73,16 @@ def __init__( register_buffers=register_buffers, **kwargs, ) - self.register_buffer("awq_qweight_src", None, persistent=False) - self.register_buffer("awq_qzeros_src", None, persistent=False) - self.register_buffer("awq_scales_src", None, persistent=False) - self._awq_buffers_ready = True - 
self._update_awq_buffers() + self._refresh_awq_cache() def post_init(self): super().post_init() - self._update_awq_buffers() + self._refresh_awq_cache() def register_buffer(self, name, tensor, persistent=True): super().register_buffer(name, tensor, persistent=persistent) - if getattr(self, "_awq_buffers_ready", False) and name in {"qweight", "qzeros", "scales"}: - self._update_awq_buffers() + if name in {"qweight", "qzeros", "scales"}: + self._refresh_awq_cache() def optimize(self): if self.optimized: @@ -108,7 +106,7 @@ def _load_from_state_dict( candidate = state_dict[qweight_key] if self._is_awq_qweight_tensor(candidate): awq_tensor = candidate.to(self.pack_dtype).clone() - self.awq_qweight_src = awq_tensor.clone() + self._awq_qweight = awq_tensor.clone() placeholder = getattr(self, "qweight", None) if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): state_dict[qweight_key] = torch.zeros_like(placeholder) @@ -137,7 +135,7 @@ def _load_from_state_dict( awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(), persistent=True, ) - self._update_awq_buffers() + self._refresh_awq_cache() def _awq_qweight_shape(self): pack_cols = self.out_features // self.pack_factor @@ -150,12 +148,15 @@ def _is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: exp_rows, exp_cols = self._awq_qweight_shape() return rows == exp_rows and cols == exp_cols - def _update_awq_buffers(self): - if not getattr(self, "_awq_buffers_ready", False): - return + def _refresh_awq_cache(self): qweight = getattr(self, "qweight", None) qzeros = getattr(self, "qzeros", None) scales = getattr(self, "scales", None) + if ( + self._awq_qweight is not None + and (not torch.is_tensor(qweight) or not self._is_awq_qweight_tensor(qweight)) + ): + return if ( qweight is None or qzeros is None @@ -165,27 +166,30 @@ def _update_awq_buffers(self): or not torch.is_tensor(scales) or not self._is_awq_qweight_tensor(qweight) ): + self._awq_qweight = None + self._awq_qzeros = None + self._awq_scales = None return - self.awq_qweight_src = qweight.clone() - self.awq_qzeros_src = qzeros.clone() + self._awq_qweight = qweight.clone() + self._awq_qzeros = qzeros.clone() scale_clone = scales.clone() if scale_clone.dtype != torch.float16: scale_clone = scale_clone.to(torch.float16) - self.awq_scales_src = scale_clone + self._awq_scales = scale_clone def _uses_awq_layout(self) -> bool: - return self.awq_qweight_src is not None + return self._awq_qweight is not None def _transform_cpu_awq(self, dtype): if ( - self.awq_qweight_src is None - or self.awq_qzeros_src is None - or self.awq_scales_src is None + self._awq_qweight is None + or self._awq_qzeros is None + or self._awq_scales is None ): raise RuntimeError("AWQ state unavailable for CPU transform.") - self.scales = self.awq_scales_src.clone().to(dtype).contiguous() - scale_fp32 = self.awq_scales_src.to(torch.float32) - iweight, izeros = unpack_awq(self.awq_qweight_src, self.awq_qzeros_src, self.bits) + self.scales = self._awq_scales.clone().to(dtype).contiguous() + scale_fp32 = self._awq_scales.to(torch.float32) + iweight, izeros = unpack_awq(self._awq_qweight, self._awq_qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 iweight = torch.bitwise_and(iweight, max_val) @@ -207,15 +211,15 @@ def _transform_cpu_awq(self, dtype): def _awq_weight_dense(self, device, dtype): if ( - self.awq_qweight_src is None - or self.awq_qzeros_src is None - or self.awq_scales_src is None + self._awq_qweight is 
None + or self._awq_qzeros is None + or self._awq_scales is None ): raise RuntimeError("AWQ dense weight requested without cached tensors.") dense = dequantize_gemm( - self.awq_qweight_src, - self.awq_qzeros_src, - self.awq_scales_src, + self._awq_qweight, + self._awq_qzeros, + self._awq_scales, self.bits, self.group_size, ).to(device=device, dtype=torch.float32) From f7a5cb5f45ccde51fd798d1a514d3b2d1eaf7fba Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 03:45:06 +0000 Subject: [PATCH 04/26] float16 only --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 13 ++++++++++--- tests/test_torch_fused_awq.py | 12 ++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index d7608a449..463ffde51 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -239,6 +239,7 @@ def transform(self, dtype, device): def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) + self._assert_supported_dtype(x_flat.dtype) if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: self.transform(x_flat.dtype, x_flat.device.type) self.transformed = True @@ -273,10 +274,16 @@ def _fused_op_forward(self, x): awq_active = self._uses_awq_layout() use_awq_fallback = awq_active and x.device.type == "cpu" if use_awq_fallback: - compute_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype - weight = self._awq_weight_dense(device=x.device, dtype=compute_dtype) - return torch.matmul(x.to(compute_dtype), weight).to(x.dtype) + weight = self._awq_weight_dense(device=x.device, dtype=x.dtype) + return torch.matmul(x, weight) return super()._fused_op_forward(x) + def _assert_supported_dtype(self, dtype: torch.dtype): + if dtype not in self.SUPPORTS_DTYPES: + supported = ", ".join(str(d) for d in self.SUPPORTS_DTYPES) + raise TypeError( + f"{self.__class__.__name__} only supports input dtypes [{supported}], but received {dtype}." 
+ ) + __all__ = ["TorchFusedAwqQuantLinear"] diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index edc4d6d7e..1244026c7 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -27,9 +27,9 @@ def pack_awq(unpacked: torch.Tensor, bits: int) -> torch.Tensor: @pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["float16", "bfloat16"]) -def test_torch_fused_awq_matches_baseline_torch_kernel(dtype): +def test_torch_fused_awq_matches_baseline_torch_kernel(): torch.manual_seed(0) + dtype = torch.float16 bits = 4 in_features = 64 @@ -82,14 +82,10 @@ def test_torch_fused_awq_matches_baseline_torch_kernel(dtype): fused_module.eval() x = torch.randn(batch, in_features, dtype=dtype) - baseline = awq_module(x.to(torch.float16)).to(dtype) + baseline = awq_module(x) fused_out = fused_module(x) - tol_map = { - torch.float16: 5e-3, - torch.bfloat16: 1.1, - } - tol = tol_map[dtype] + tol = 5e-3 abs_diff = (fused_out - baseline).abs() rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) From f100c519194a63cd5acb6aaa1cb08dc9d8132f31 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 04:14:48 +0000 Subject: [PATCH 05/26] log --- .../nn_modules/qlinear/torch_fused_awq.py | 94 +++++-------------- 1 file changed, 25 insertions(+), 69 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 463ffde51..bf4372976 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -55,10 +55,6 @@ def __init__( register_buffers: bool = True, **kwargs, ): - self._awq_qweight = None - self._awq_qzeros = None - self._awq_scales = None - kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) super().__init__( bits=bits, @@ -73,22 +69,6 @@ def __init__( register_buffers=register_buffers, **kwargs, ) - self._refresh_awq_cache() - - def post_init(self): - super().post_init() - self._refresh_awq_cache() - - def register_buffer(self, name, tensor, persistent=True): - super().register_buffer(name, tensor, persistent=persistent) - if name in {"qweight", "qzeros", "scales"}: - self._refresh_awq_cache() - - def optimize(self): - if self.optimized: - return - super().optimize() - self._uses_awq_layout() def _load_from_state_dict( self, @@ -106,7 +86,6 @@ def _load_from_state_dict( candidate = state_dict[qweight_key] if self._is_awq_qweight_tensor(candidate): awq_tensor = candidate.to(self.pack_dtype).clone() - self._awq_qweight = awq_tensor.clone() placeholder = getattr(self, "qweight", None) if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): state_dict[qweight_key] = torch.zeros_like(placeholder) @@ -135,7 +114,6 @@ def _load_from_state_dict( awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(), persistent=True, ) - self._refresh_awq_cache() def _awq_qweight_shape(self): pack_cols = self.out_features // self.pack_factor @@ -148,48 +126,20 @@ def _is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: exp_rows, exp_cols = self._awq_qweight_shape() return rows == exp_rows and cols == exp_cols - def _refresh_awq_cache(self): - qweight = getattr(self, "qweight", None) - qzeros = getattr(self, "qzeros", None) - scales = getattr(self, "scales", None) - if ( - self._awq_qweight is not None - and (not torch.is_tensor(qweight) or not self._is_awq_qweight_tensor(qweight)) - ): - return - if ( - 
qweight is None - or qzeros is None - or scales is None - or not torch.is_tensor(qweight) - or not torch.is_tensor(qzeros) - or not torch.is_tensor(scales) - or not self._is_awq_qweight_tensor(qweight) - ): - self._awq_qweight = None - self._awq_qzeros = None - self._awq_scales = None - return - self._awq_qweight = qweight.clone() - self._awq_qzeros = qzeros.clone() - scale_clone = scales.clone() - if scale_clone.dtype != torch.float16: - scale_clone = scale_clone.to(torch.float16) - self._awq_scales = scale_clone - def _uses_awq_layout(self) -> bool: - return self._awq_qweight is not None + qweight = getattr(self, "qweight", None) + return torch.is_tensor(qweight) and self._is_awq_qweight_tensor(qweight) def _transform_cpu_awq(self, dtype): - if ( - self._awq_qweight is None - or self._awq_qzeros is None - or self._awq_scales is None - ): + if not self._uses_awq_layout(): raise RuntimeError("AWQ state unavailable for CPU transform.") - self.scales = self._awq_scales.clone().to(dtype).contiguous() - scale_fp32 = self._awq_scales.to(torch.float32) - iweight, izeros = unpack_awq(self._awq_qweight, self._awq_qzeros, self.bits) + src_scales = self.scales + if src_scales.dtype != torch.float16: + src_scales = src_scales.to(torch.float16) + src_scales = src_scales.contiguous() + self.scales = src_scales.clone().to(dtype).contiguous() + scale_fp32 = src_scales.to(torch.float32) + iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 iweight = torch.bitwise_and(iweight, max_val) @@ -210,16 +160,12 @@ def _transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def _awq_weight_dense(self, device, dtype): - if ( - self._awq_qweight is None - or self._awq_qzeros is None - or self._awq_scales is None - ): + if not self._uses_awq_layout(): raise RuntimeError("AWQ dense weight requested without cached tensors.") dense = dequantize_gemm( - self._awq_qweight, - self._awq_qzeros, - self._awq_scales, + self.qweight, + self.qzeros, + self.scales, self.bits, self.group_size, ).to(device=device, dtype=torch.float32) @@ -240,7 +186,12 @@ def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) self._assert_supported_dtype(x_flat.dtype) - if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: + if ( + not self.training + and not self.transformed + and TORCH_HAS_FUSED_OPS + and not self._uses_awq_layout() + ): self.transform(x_flat.dtype, x_flat.device.type) self.transformed = True if x_flat.device.type == "cpu": @@ -252,8 +203,10 @@ def forward(self, x: torch.Tensor): config.max_autotune = True if self.transformed: + log.debug("awq calling fused op") out = self._fused_op_forward(x_flat) else: + log.debug("awq dense path") if self._uses_awq_layout(): weight = self._awq_weight_dense(device=x_flat.device, dtype=x_flat.dtype) out = torch.matmul(x_flat, weight) @@ -274,8 +227,11 @@ def _fused_op_forward(self, x): awq_active = self._uses_awq_layout() use_awq_fallback = awq_active and x.device.type == "cpu" if use_awq_fallback: + log.debug("awq unfused fallback") weight = self._awq_weight_dense(device=x.device, dtype=x.dtype) return torch.matmul(x, weight) + else: + log.debug("awq fused") return super()._fused_op_forward(x) def _assert_supported_dtype(self, dtype: torch.dtype): From 28ee4973a70400c6aa873c76eb15ee6c64ddb887 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 
04:19:49 +0000 Subject: [PATCH 06/26] fused path --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index bf4372976..a2ce67fc6 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -186,12 +186,7 @@ def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) self._assert_supported_dtype(x_flat.dtype) - if ( - not self.training - and not self.transformed - and TORCH_HAS_FUSED_OPS - and not self._uses_awq_layout() - ): + if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: self.transform(x_flat.dtype, x_flat.device.type) self.transformed = True if x_flat.device.type == "cpu": From 82e4d6225c791c669a77520f9af0ad70c16e5834 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 05:50:21 +0000 Subject: [PATCH 07/26] add gptq torch fused doc on layout --- docs/torch_fused_int4_transformations.md | 238 +++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 docs/torch_fused_int4_transformations.md diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md new file mode 100644 index 000000000..82ac5b665 --- /dev/null +++ b/docs/torch_fused_int4_transformations.md @@ -0,0 +1,238 @@ +# Torch Fused INT4 Transformations + +This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu` +do to GPTQ-format tensors before calling the fused `torch.ops.aten` kernels. +The goal is to document the exact tensor shapes, the axis permutations, and the +bit packing order expected by `aten._weight_int4pack_mm_*` so you do not need to +reverse engineer the loops in `gptqmodel/nn_modules/qlinear/torch_fused.py:175-219`. + +## Terminology and starting layout + +Let: + +* `I` – number of input features. +* `O` – number of output features. +* `B` – quantization bits (always 4 here). +* `W` – number of bits stored per lane in `pack_dtype` (`W = 32` by default). +* `pack_factor = W / B` – how many quantized values share one lane (8 when `B=4`). +* `group_size` – number of input channels that share one `(scale, zero)` pair. +* `G = ceil(I / group_size)` – number of groups (and rows in `scales`/`qzeros`). + +Immediately after loading a GPTQ v2 checkpoint: + +``` +qweight : [I / pack_factor, O] dtype = pack_dtype (int32) +qzeros : [G, O / pack_factor] dtype = pack_dtype (int32) +scales : [G, O] dtype = fp16 +g_idx : [I] dtype = int32 (maps input channel -> group id) +``` + +Each entry of `qweight`/`qzeros` is a 32-bit lane that packs `pack_factor` +4-bit nibbles. Conceptually, a single column of `qweight` (one output channel) +looks like this before unpacking: + +``` +raw lane bits (int32) → [in_{k+7}] [in_{k+6}] … [in_{k+1}] [in_{k}] +bit positions → 31..28 27..24 7..4 3..0 +``` + +## `transform_xpu(dtype)` + +The XPU path needs tensors that match +`aten._weight_int4pack_mm_with_scales_and_zeros`. The routine performs five +steps: + +1. **Scales cast** – `self.scales = self.scales.clone().to(dtype)`. No layout changes. +2. **Unpack `qzeros`** – expand each 32-bit lane into `pack_factor` nibbles, mask + with `0xF`, then reshape to `[G, O]`. 
+ + ``` + Before unpack (per group g): + qzeros[g] = [ lane_0, lane_1, … ] (each lane holds 8 outputs) + After unpack: + zeros[g] = [ z_{0}, z_{1}, …, z_{O-1} ] + + lane layout + ┌──────────── 32 bits ────────────┐ + | z_{b+7} | … | z_{b+1} | z_{b} | + └────────────────────────────────┘ ← reshaped into consecutive columns + ``` + +3. **Unpack and reorder `qweight`** – identical nibble extraction produces a + tensor shaped `[I, O]`. It is then re-indexed with `ret_idx` so that input + rows follow the `g_idx` schedule used during quantization, and finally + transposed to `[O, I]`. At this point every row corresponds to one output + channel and every column corresponds to an *unpacked* input channel. + + ``` + weight_full (after transpose): + input columns → + ┌───────────────────────────────────────────┐ + out0│ w00 w01 w02 w03 w04 w05 w06 w07 … w0(I-1) │ + out1│ w10 w11 w12 w13 w14 w15 w16 w17 … w1(I-1) │ + │ ⋮ │ + ``` + +4. **Pack rows into XPU layout** – the double `for` loop rebuilds `int32` + lanes, but now the rows are `O` (output channels) instead of packed input + clusters. The resulting tensor has shape `[O, I / pack_factor]`. + + ``` + packed[row=j, col=k] stores inputs (8 values) = + weight_full[j, 8k + i] for i = 0..7 + + 31..28 27..24 23..20 19..16 15..12 11..8 7..4 3..0 + [in+7] [in+6] [in+5] [in+4] [in+3] [in+2] [in+1] [in+0] + ``` + +5. **Finalize buffers** – `self.qweight = packed.contiguous()` (int32) and + `self.qzeros = zeros.contiguous()` (float, `[G, O]`). These, together with + `self.scales`, match the signature of + `aten._weight_int4pack_mm_with_scales_and_zeros(x, qweight, group_size, scales, qzeros)`. + +For XPU execution, `_fused_op_forward` also permutes activations before the +matmul: + +``` +x = x[:, ret_idx] +``` + +This applies the inverse of the group-wise reordering performed in step 3, +ensuring that column `i` of `qweight` always multiplies the same logical input +channel the calibration used. + +### Visual summary (XPU) + +``` + ┌─────────────┐ unpack+permute ┌─────────────┐ +raw qweight →│ I/8 × O │ ───────────────────────→ │ O × I │ + └─────────────┘ └─────────────┘ + pack rows ↓ + ┌─────────────┐ + │ O × (I/8) │ int32 lanes + └─────────────┘ + +raw qzeros → [G × O/8] lanes ──unpack──► zeros [G × O] +scales → [G × O] (cast to `dtype`) +``` + +## `transform_cpu(dtype)` + +The CPU path shares the unpack/reorder logic but delegates the final packing to +PyTorch’s helper so the layout matches +`aten._weight_int4pack_mm_for_cpu`. Steps: + +1. **Scales cast** – identical to the XPU path. +2. **Unpack + reorder `qweight`** – same as step 3 above, yielding + `weight_full = [O, I]` with 4-bit integers. +3. **Convert to int4pack** – `torch.ops.aten._convert_weight_to_int4pack_for_cpu` + repacks that matrix into `torch.uint8` tiles of shape `[O, I * B / 8]` + (i.e., `I/2` columns when `B=4`). Each byte stores two adjacent inputs. + + ``` + byte layout (per output row j): + bits 7..4 → weight_full[j, 2k+1] + bits 3..0 → weight_full[j, 2k] + ``` + + The helper currently requires both `O` and `I` to be multiples of 16; the op + raises `_convert_weight_to_int4pack_cpu : expect N to be dividable by 16` + otherwise. + +4. 
**Merge scales and zeros** – The fused CPU kernel expects scale and zero + offsets in a single tensor, so `pack_scales_and_zeros` stacks them along the + last dimension: + + ``` + scales_and_zeros[g, o] = [ scale[g, o], zero[g, o] ] shape = [G, O, 2] + + group g + ┌──────── out dimension ────────┐ + │ [ s, z ] [ s, z ] … [ s, z ] │ + └─────────────────────────────────┘ + ``` + + The current GPTQ fused path only uses symmetric int4, so `self.qzeros` is + zeroed before packing (`zero[g, o] = 0`). Non-symmetric per-group offsets + would require extending this block. + +5. **Buffers used at runtime** – `self.qweight` is now the `uint8` + int4pack tensor, `self.scales_and_zeros` stores the merged metadata, and + `_fused_op_forward` calls + `aten._weight_int4pack_mm_for_cpu(x, qweight_uint8, group_size, scales_and_zeros)`. + +### Visual summary (CPU) + +``` +weight_full (O × I, ints) ──_convert_weight_to_int4pack_for_cpu──► +┌──────────────┐ ┌──────────────┐ +│ O × I │ │ O × (I/2) │ uint8 +└──────────────┘ └──────────────┘ + ↑ ↑ + └───────── unpack & transpose from raw qweight ───────────┘ + +scales (G × O, dtype `dtype`) +qzeros (G × O, zeroed) ──► scales_and_zeros (G × O × 2) +``` + +## Activation permutation and fused matmul + +Both device paths rely on the same activation permutation: + +1. `ret_idx` is built once from `g_idx` so that unpacked rows can be restored to + the calibration order. +2. Before calling any fused matmul, `_fused_op_forward` applies `x = x[:, ret_idx]`. +3. The matmul then multiplies `x` with the packed `qweight`: + + * XPU: `aten._weight_int4pack_mm_with_scales_and_zeros` + consumes `qweight[int32][O, I/8]`, `scales[G, O]`, and `qzeros[G, O]`. + * CPU: `aten._weight_int4pack_mm_for_cpu` + consumes `qweight[uint8][O, I/2]` and `scales_and_zeros[G, O, 2]`. + +Because the same `ret_idx` is used for both the unpacked weight (during packing) +and the activation tensor (during inference), every nibble in the packed matrix +aligns with the correct logical input column. + +## Comparing XPU vs CPU transformations + +Although both device paths share the same unpack → reorder → transpose steps, +they diverge in how the packed tensors are laid out and what the fused matmul +expects afterward. The table below highlights the key differences for quick +debugging. 
+ +| Aspect | XPU (`transform_xpu`) | CPU (`transform_cpu`) | +|----------------------------|---------------------------------------------------------------|-------------------------------------------------------------------| +| Packed `qweight` shape | `[O, I / 8]`, dtype `int32` | `[O, I / 2]`, dtype `uint8` | +| Bits per storage lane | 32-bit lane packs 8 inputs; nibble order `[in+7 … in+0]` | 8-bit lane packs 2 inputs; high nibble = odd, low nibble = even | +| Packing direction | Manual double-loop packs along **columns** of `weight_full` | `_convert_weight_to_int4pack_for_cpu` packs along **columns** into bytes | +| Per-group zeros | Unpacked to full `[G, O]` tensor and passed separately | Forced to zero and merged with scales via `pack_scales_and_zeros` | +| Scale format | One tensor per group (`scales[G, O]`) | Concatenated `[..., 0] = scale`, `[..., 1] = zero` (`float`) | +| Fused kernel call | `_weight_int4pack_mm_with_scales_and_zeros(x, qW, gsz, s, z)` | `_weight_int4pack_mm_for_cpu(x, qW, gsz, scales_and_zeros)` | +| Alignment requirements | Determined by manual pack loop (only needs `I % 8 == 0`) | Kernel enforces `I % 16 == 0` and `O % 16 == 0` | +| Activation permutation | `x = x[:, ret_idx]` prior to matmul (same code path) | Same permutation reuse | + +Visually, you can think of the difference as *row-major lane packing* (XPU) +versus *byte-tiling* (CPU): + +``` +XPU: | int32 lane | = [w7][w6][w5][w4][w3][w2][w1][w0] +CPU: | uint8 lane | = [w1][w0] +``` + +Both forms originate from the same `[O, I]` intermediate; the divergence is only +in the final storage type, accompanying metadata, and fused operator ABI. + +## Quick reference + +| Stage | Shape / dtype (int4) | Notes | +|--------------------------------|-----------------------------------------------------------|------------------------------------------------| +| Raw `qweight` | `[I / 8, O]`, `int32` | 8 nibbles per lane | +| After unpack + transpose | `[O, I]`, `int8` (values in `[0, 15]`) | Used by both device paths | +| Packed XPU `qweight` | `[O, I / 8]`, `int32` | Bits `[3:0]` hold the lowest-numbered channel | +| Packed CPU `qweight` | `[O, I / 2]`, `uint8` | High nibble = odd input, low nibble = even | +| `qzeros` (post-XPU transform) | `[G, O]`, matches `scales` | Passed separately to the XPU fused op | +| `scales_and_zeros` (CPU only) | `[G, O, 2]`, float | `[..., 0] = scale`, `[..., 1] = zero` | + +These details mirror the expectations of the Intel XPU and CPU fused matmul +kernels, and the ASCII layouts above describe how rows/columns line up inside +every packed tensor before the fused matmul executes. 
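
## Worked packing example

The following sketch is illustrative only, not library code; it assumes nothing beyond
plain PyTorch integer ops. It packs one output row of int4 values into both lane layouts
described above and round-trips the XPU form to confirm the nibble order.

```
import torch

bits, pack_factor = 4, 8
row = torch.arange(16, dtype=torch.int32) % 16            # one output row, I = 16 int4 values

# XPU-style lanes: 8 nibbles per int32, lowest-numbered input in bits [3:0]
lanes = row.view(-1, pack_factor)                          # [I/8, 8]
xpu_packed = torch.zeros(lanes.shape[0], dtype=torch.int32)
for i in range(pack_factor):
    xpu_packed |= lanes[:, i] << (i * bits)

# CPU-style bytes: 2 nibbles per uint8, high nibble = odd input, low nibble = even
pairs = row.view(-1, 2)                                    # [I/2, 2]
cpu_packed = (pairs[:, 0] | (pairs[:, 1] << 4)).to(torch.uint8)

# Round-trip the XPU lanes to confirm the nibble order
shifts = torch.arange(pack_factor, dtype=torch.int32) * bits
recovered = (xpu_packed.unsqueeze(1) >> shifts) & 0xF
assert torch.equal(recovered.reshape(-1), row)
```

The real transforms differ only in scale: they pack `[O, I]` matrices instead of a single
row and attach the per-group scale/zero metadata described earlier.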
From a8b2ca65e4e7f817536302b5f78f637d2ef9ac43 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 06:36:09 +0000 Subject: [PATCH 08/26] fix awq transformation --- .../nn_modules/qlinear/torch_fused_awq.py | 49 ++++-- tests/test_torch_fused_awq.py | 144 ++++++++++++------ 2 files changed, 132 insertions(+), 61 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index a2ce67fc6..fa1cddee6 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -137,26 +137,49 @@ def _transform_cpu_awq(self, dtype): if src_scales.dtype != torch.float16: src_scales = src_scales.to(torch.float16) src_scales = src_scales.contiguous() - self.scales = src_scales.clone().to(dtype).contiguous() - scale_fp32 = src_scales.to(torch.float32) + + # Cache unpacked AWQ tensors iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 iweight = torch.bitwise_and(iweight, max_val) if izeros is not None: izeros = torch.bitwise_and(izeros, max_val) - ret_idx = self._build_ret_idx() - weight = iweight.index_select(0, ret_idx).t().contiguous() - self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() - if izeros is None: - zeros = torch.zeros_like(scale_fp32) - else: - zero_offset = 1 << (self.bits - 1) - zeros = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) - zeros = zeros * scale_fp32 - self.scales = scale_fp32.to(dtype=dtype) - self.qzeros = zeros.to(dtype=dtype) + # Precompute the per-group zero offsets (kept in float16 for parity with AWQ reference) + scale_fp16 = src_scales.clone() + scale_fp32 = scale_fp16.to(torch.float32) + zero_offset = 1 << (self.bits - 1) + zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) + zeros_fp16 = (zeros_fp16 * scale_fp32).to(torch.float16) + + # Repack AWQ-per-output rows into GPTQ-style per-input packs so the base + # TorchFusedQuantLinear path can handle the conversion to int4pack. + in_features, out_features = iweight.shape + pack_factor = int(self.pack_factor) + if in_features % pack_factor != 0: + raise ValueError( + f"AWQ in_features={in_features} must be divisible by pack_factor={pack_factor}." + ) + + rows = iweight.view(in_features // pack_factor, pack_factor, out_features) + gptq_qweight = torch.zeros( + (rows.shape[0], out_features), + dtype=self.pack_dtype, + device=iweight.device, + ) + bit_shifts = list(range(0, pack_factor * self.bits, self.bits)) + for lane, shift in enumerate(bit_shifts): + gptq_qweight |= rows[:, lane, :].to(torch.int32) << shift + self.qweight = gptq_qweight.contiguous() + + # Reuse the GPTQ CPU transformation to convert into int4pack layout. + self.scales = scale_fp16.clone() + super().transform_cpu(dtype) + + # Restore AWQ-specific scale/zero metadata for the fused op. 
+ self.scales = scale_fp16.to(dtype=dtype) + self.qzeros = zeros_fp16.to(dtype=dtype) self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def _awq_weight_dense(self, device, dtype): diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index 1244026c7..4786d0a43 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -3,49 +3,99 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import json +import os +from functools import lru_cache +from pathlib import Path + import pytest import torch +from safetensors import safe_open from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS -def pack_awq(unpacked: torch.Tensor, bits: int) -> torch.Tensor: - pack_factor = 32 // bits - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - assert unpacked.shape[1] % pack_factor == 0 - packed = torch.zeros( - (unpacked.shape[0], unpacked.shape[1] // pack_factor), - dtype=torch.int32, - ) - for col in range(unpacked.shape[1] // pack_factor): - for i, order in enumerate(order_map): - value = unpacked[:, col * pack_factor + order].to(torch.int32) - packed[:, col] |= value << (i * bits) - return packed - - -@pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") -def test_torch_fused_awq_matches_baseline_torch_kernel(): - torch.manual_seed(0) - dtype = torch.float16 +CHECKPOINT_DIR = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") +CHECKPOINT_MODULE = os.environ.get( + "GPTQMODEL_AWQ_TEST_MODULE", "model.layers.0.mlp.up_proj" +) + + +@lru_cache(maxsize=1) +def _load_awq_checkpoint_module(): + if not CHECKPOINT_DIR.exists(): + pytest.skip(f"AWQ checkpoint not available at {CHECKPOINT_DIR}") + + index_path = CHECKPOINT_DIR / "model.safetensors.index.json" + if not index_path.exists(): + pytest.skip(f"Missing model index at {index_path}") + + with index_path.open("r", encoding="utf-8") as fh: + index_data = json.load(fh) + weight_map = index_data["weight_map"] + + config_path = CHECKPOINT_DIR / "config.json" + with config_path.open("r", encoding="utf-8") as fh: + config = json.load(fh) + quant_cfg = config.get("quantization_config", {}) + bits = int(quant_cfg.get("bits", 4)) + group_size = int(quant_cfg.get("group_size", 128)) + + suffixes = ["qweight", "qzeros", "scales", "bias"] + tensors = {} + file_to_keys = {} + for suffix in suffixes: + full_key = f"{CHECKPOINT_MODULE}.{suffix}" + filename = weight_map.get(full_key) + if filename is None: + if suffix == "bias": + continue + raise KeyError(f"Missing tensor '{full_key}' in checkpoint index.") + file_to_keys.setdefault(filename, []).append(full_key) + + for filename, keys in file_to_keys.items(): + tensor_path = CHECKPOINT_DIR / filename + with safe_open(tensor_path, framework="pt", device="cpu") as handle: + for key in keys: + tensors[key] = handle.get_tensor(key).clone() + + qweight = tensors[f"{CHECKPOINT_MODULE}.qweight"].to(torch.int32).contiguous() + qzeros = tensors[f"{CHECKPOINT_MODULE}.qzeros"].to(torch.int32).contiguous() + scales = tensors[f"{CHECKPOINT_MODULE}.scales"].to(torch.float16).contiguous() + bias_key = f"{CHECKPOINT_MODULE}.bias" + bias = tensors.get(bias_key) + if bias is not None: + bias = bias.to(torch.float16).contiguous() - bits = 4 - in_features = 64 - out_features = 128 - group_size = 32 - batch = 4 + pack_factor = 32 // bits + in_features = 
qweight.shape[0] + out_features = qweight.shape[1] * pack_factor - groups = in_features // group_size + return { + "bits": bits, + "group_size": group_size, + "in_features": in_features, + "out_features": out_features, + "qweight": qweight, + "qzeros": qzeros, + "scales": scales, + "bias": bias, + } - int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) - zero_points = torch.randint(0, 2**bits, size=(groups, out_features), dtype=torch.int32) - scales = (torch.rand(groups, out_features, dtype=torch.float16) * 1.5) + 0.25 - bias = torch.randn(out_features, dtype=torch.float16) - qweight = pack_awq(int_weight, bits) - qzeros = pack_awq(zero_points, bits) +@pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") +def test_torch_fused_awq_matches_checkpoint_module(): + module_data = _load_awq_checkpoint_module() + bits = module_data["bits"] + group_size = module_data["group_size"] + in_features = module_data["in_features"] + out_features = module_data["out_features"] + qweight = module_data["qweight"] + qzeros = module_data["qzeros"] + scales = module_data["scales"] + bias = module_data["bias"] awq_module = AwqTorchQuantLinear( bits=bits, @@ -54,16 +104,9 @@ def test_torch_fused_awq_matches_baseline_torch_kernel(): desc_act=False, in_features=in_features, out_features=out_features, - bias=True, + bias=bias is not None, register_buffers=True, ) - awq_module.qweight.copy_(qweight) - awq_module.qzeros.copy_(qzeros) - awq_module.scales.copy_(scales) - awq_module.bias.copy_(bias) - awq_module.post_init() - awq_module.eval() - fused_module = TorchFusedAwqQuantLinear( bits=bits, group_size=group_size, @@ -71,26 +114,31 @@ def test_torch_fused_awq_matches_baseline_torch_kernel(): desc_act=False, in_features=in_features, out_features=out_features, - bias=True, + bias=bias is not None, register_buffers=True, ) + + awq_module.qweight.copy_(qweight) + awq_module.qzeros.copy_(qzeros) + awq_module.scales.copy_(scales) + if bias is not None: + awq_module.bias.copy_(bias) + awq_module.post_init() + awq_module.eval() + fused_module.register_buffer("qweight", qweight.clone(), persistent=True) fused_module.qzeros.copy_(qzeros) fused_module.scales.copy_(scales) - fused_module.bias.copy_(bias) + if bias is not None: + fused_module.bias.copy_(bias) fused_module.post_init() fused_module.eval() + dtype = torch.float16 + batch = 4 x = torch.randn(batch, in_features, dtype=dtype) baseline = awq_module(x) fused_out = fused_module(x) tol = 5e-3 - abs_diff = (fused_out - baseline).abs() - rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) - torch.testing.assert_close(fused_out, baseline, rtol=tol, atol=tol) - - header = f"{'dtype':<10} {'rtol':<10} {'atol':<10} {'abs_max':<12} {'rel_max':<12}" - row = f"{str(dtype):<10} {tol:<10.4g} {tol:<10.4g} {abs_diff.max().item():<12.4e} {rel_diff.max().item():<12.4e}" - print(f"{header}\n{row}") From 22cf785ed0ab651ea5c0e8577a4376100ec23cd5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 06:42:44 +0000 Subject: [PATCH 09/26] log rtol/atol --- tests/test_torch_fused_awq.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index 4786d0a43..be5c869e7 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -11,6 +11,7 @@ import pytest import torch from safetensors import safe_open +from tabulate import tabulate from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear from 
gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear @@ -141,4 +142,19 @@ def test_torch_fused_awq_matches_checkpoint_module(): fused_out = fused_module(x) tol = 5e-3 + abs_diff = (fused_out - baseline).abs() + rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) + summary = tabulate( + [ + [ + str(dtype), + f"{tol:.4g}", + f"{tol:.4g}", + f"{abs_diff.max().item():.4e}", + f"{rel_diff.max().item():.4e}", + ] + ], + headers=["dtype", "rtol", "atol", "abs_max", "rel_max"], + ) + print(summary) torch.testing.assert_close(fused_out, baseline, rtol=tol, atol=tol) From fa451ae07c53dc91108849afcb6e3c669f82d3a7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 06:57:05 +0000 Subject: [PATCH 10/26] cleanup --- docs/torch_fused_int4_transformations.md | 43 +++++++++++++++++++ .../nn_modules/qlinear/torch_fused_awq.py | 38 ++++++++-------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md index 82ac5b665..2fd039e01 100644 --- a/docs/torch_fused_int4_transformations.md +++ b/docs/torch_fused_int4_transformations.md @@ -222,6 +222,46 @@ CPU: | uint8 lane | = [w1][w0] Both forms originate from the same `[O, I]` intermediate; the divergence is only in the final storage type, accompanying metadata, and fused operator ABI. +## AWQ compatibility (`torch_fused_awq.py`) + +`TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`) +reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ +tooling. An "AWQ layout" is detected whenever `qweight` has shape +`[in_features, out_features / pack_factor]` (i.e., rows are raw input channels +instead of packed groups). When that layout is present, `_transform_cpu_awq` +performs an extra shim before the standard CPU packing runs: + +1. **Unpack AWQ rows** – `unpack_awq` expands each column lane into eight + outputs, yielding `iweight[int8][I, O]` and `izeros[int8][G, O]`. Both + tensors are then permuted with `reverse_awq_order` (the inverse of + `quantization.awq.utils.packing_utils.AWQ_ORDER`) so the columns match the + logical transformer layout expected by GPTQ. +2. **Normalize zero codes** – AWQ stores integer zero points per output channel. + `_transform_cpu_awq` converts them into floating offsets compatible with the + fused kernel using + `zeros_fp16 = (2^{bits-1} - izeros) * scales_fp32`, keeping the result in + `float16` so the metadata matches the original AWQ calibration statistics. +3. **Repack into GPTQ lanes** – The unpacked `iweight` matrix is reshaped to + `[I / pack_factor, pack_factor, O]` and re-packed along the `pack_factor` + dimension so each row once again represents eight inputs inside a 32-bit + lane. After this step `self.qweight` is indistinguishable from a GPTQ v2 + tensor, which means the regular `transform_cpu` logic can run unchanged. +4. **Delegate to the base CPU transform** – Calling `super().transform_cpu` + converts the temporary GPTQ-formatted `qweight` into the `[O, I/2]` `uint8` + int4pack layout and produces `scales_and_zeros` from the (temporarily zeroed) + metadata. +5. **Restore AWQ metadata** – Immediately afterward, the AWQ shim reinstates + the real `float16` scales and the converted zero offsets, then rebuilds + `scales_and_zeros = pack_scales_and_zeros(scales, zeros_fp16)`. This ensures + `_weight_int4pack_mm_for_cpu` receives the same affine parameters the AWQ + calibration solved for. 
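+
+A minimal sketch of steps 2 and 3, assuming 4-bit codes packed into `int32`
+lanes (`pack_factor = 8`); the helper name and signature below are
+illustrative, not an actual module method:
+
+```python
+import torch
+
+def awq_to_gptq_repack(iweight: torch.Tensor, izeros: torch.Tensor,
+                       scales: torch.Tensor, bits: int = 4):
+    # iweight: [I, O] integer codes, izeros: [G, O] integer codes, scales: [G, O] float16
+    pack_factor = 32 // bits
+    zero_offset = 1 << (bits - 1)
+    # Step 2: integer zero points -> float offsets expected by the fused kernel.
+    zeros = ((zero_offset - izeros).to(torch.float32)
+             * scales.to(torch.float32)).to(torch.float16)
+    # Step 3: fold pack_factor unpacked input rows back into one int32 lane each.
+    in_features, out_features = iweight.shape
+    rows = iweight.reshape(in_features // pack_factor, pack_factor, out_features)
+    qweight = torch.zeros(rows.shape[0], out_features, dtype=torch.int32)
+    for lane in range(pack_factor):
+        qweight |= rows[:, lane, :].to(torch.int32) << (lane * bits)
+    return qweight, zeros
+```
+
+The in-module version additionally checks that `in_features` is divisible by
+`pack_factor` before repacking and keeps the packed tensor on the source
+device.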
+ +Because the shim runs entirely on the CPU path, `TorchFusedAwqQuantLinear` +currently raises `NotImplementedError` when asked to transform an AWQ layout on +`xpu` devices. Inference still benefits from the fused CPU kernel; if the module +cannot be transformed (e.g., due to dtype mismatch or missing fused ops) it +falls back to the dense AWQ matmul defined in `_awq_weight_dense`. + ## Quick reference | Stage | Shape / dtype (int4) | Notes | @@ -232,6 +272,9 @@ in the final storage type, accompanying metadata, and fused operator ABI. | Packed CPU `qweight` | `[O, I / 2]`, `uint8` | High nibble = odd input, low nibble = even | | `qzeros` (post-XPU transform) | `[G, O]`, matches `scales` | Passed separately to the XPU fused op | | `scales_and_zeros` (CPU only) | `[G, O, 2]`, float | `[..., 0] = scale`, `[..., 1] = zero` | +| Raw AWQ `qweight` | `[I, O / 8]`, `int32` | Rows are single inputs packed across outputs | +| Unpacked AWQ weights/zeros | `iweight[I, O]`, `izeros[G, O]`, `int8` | Produced by `unpack_awq` + `reverse_awq_order` | +| AWQ zero offsets (final) | `[G, O]`, `float16` | `(2^{bits-1} - izeros) * scales`; merged via `pack_scales_and_zeros` | These details mirror the expectations of the Intel XPU and CPU fused matmul kernels, and the ASCII layouts above describe how rows/columns line up inside diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index fa1cddee6..39c356ea9 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -84,7 +84,7 @@ def _load_from_state_dict( awq_tensor = None if qweight_key in state_dict: candidate = state_dict[qweight_key] - if self._is_awq_qweight_tensor(candidate): + if self.is_awq_qweight_tensor(candidate): awq_tensor = candidate.to(self.pack_dtype).clone() placeholder = getattr(self, "qweight", None) if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): @@ -115,23 +115,23 @@ def _load_from_state_dict( persistent=True, ) - def _awq_qweight_shape(self): + def awq_qweight_shape(self): pack_cols = self.out_features // self.pack_factor return self.in_features, pack_cols - def _is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: + def is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: if tensor is None or not torch.is_tensor(tensor) or tensor.dim() != 2: return False rows, cols = tensor.shape - exp_rows, exp_cols = self._awq_qweight_shape() + exp_rows, exp_cols = self.awq_qweight_shape() return rows == exp_rows and cols == exp_cols - def _uses_awq_layout(self) -> bool: + def uses_awq_layout(self) -> bool: qweight = getattr(self, "qweight", None) - return torch.is_tensor(qweight) and self._is_awq_qweight_tensor(qweight) + return torch.is_tensor(qweight) and self.is_awq_qweight_tensor(qweight) - def _transform_cpu_awq(self, dtype): - if not self._uses_awq_layout(): + def transform_cpu_awq(self, dtype): + if not self.uses_awq_layout(): raise RuntimeError("AWQ state unavailable for CPU transform.") src_scales = self.scales if src_scales.dtype != torch.float16: @@ -182,8 +182,8 @@ def _transform_cpu_awq(self, dtype): self.qzeros = zeros_fp16.to(dtype=dtype) self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) - def _awq_weight_dense(self, device, dtype): - if not self._uses_awq_layout(): + def awq_weight_dequantize(self, device, dtype): + if not self.uses_awq_layout(): raise RuntimeError("AWQ dense weight requested without cached tensors.") dense = dequantize_gemm( 
self.qweight, @@ -195,20 +195,20 @@ def _awq_weight_dense(self, device, dtype): return dense.to(device=device, dtype=dtype) def transform_cpu(self, dtype): - if self._uses_awq_layout(): - self._transform_cpu_awq(dtype) + if self.uses_awq_layout(): + self.transform_cpu_awq(dtype) return super().transform_cpu(dtype) def transform(self, dtype, device): - if device == "xpu" and self._uses_awq_layout(): + if device == "xpu" and self.uses_awq_layout(): raise NotImplementedError("TorchFusedAwqQuantLinear AWQ layout is currently supported on CPU only.") super().transform(dtype, device) def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) - self._assert_supported_dtype(x_flat.dtype) + self.assert_supported_dtype(x_flat.dtype) if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: self.transform(x_flat.dtype, x_flat.device.type) self.transformed = True @@ -225,8 +225,8 @@ def forward(self, x: torch.Tensor): out = self._fused_op_forward(x_flat) else: log.debug("awq dense path") - if self._uses_awq_layout(): - weight = self._awq_weight_dense(device=x_flat.device, dtype=x_flat.dtype) + if self.uses_awq_layout(): + weight = self.awq_weight_dequantize(device=x_flat.device, dtype=x_flat.dtype) out = torch.matmul(x_flat, weight) else: num_itr = self.g_idx.shape[0] // x_flat.shape[-1] @@ -242,17 +242,17 @@ def forward(self, x: torch.Tensor): @torch.no_grad def _fused_op_forward(self, x): - awq_active = self._uses_awq_layout() + awq_active = self.uses_awq_layout() use_awq_fallback = awq_active and x.device.type == "cpu" if use_awq_fallback: log.debug("awq unfused fallback") - weight = self._awq_weight_dense(device=x.device, dtype=x.dtype) + weight = self.awq_weight_dequantize(device=x.device, dtype=x.dtype) return torch.matmul(x, weight) else: log.debug("awq fused") return super()._fused_op_forward(x) - def _assert_supported_dtype(self, dtype: torch.dtype): + def assert_supported_dtype(self, dtype: torch.dtype): if dtype not in self.SUPPORTS_DTYPES: supported = ", ".join(str(d) for d in self.SUPPORTS_DTYPES) raise TypeError( From 3ecbdbefea8560cd74e27901c90ac64d84563366 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 07:10:48 +0000 Subject: [PATCH 11/26] cleanup2 --- docs/torch_fused_int4_transformations.md | 18 ++-- .../nn_modules/qlinear/torch_fused_awq.py | 87 +++++++------------ 2 files changed, 38 insertions(+), 67 deletions(-) diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md index 2fd039e01..5d82f19d2 100644 --- a/docs/torch_fused_int4_transformations.md +++ b/docs/torch_fused_int4_transformations.md @@ -226,10 +226,10 @@ in the final storage type, accompanying metadata, and fused operator ABI. `TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`) reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ -tooling. An "AWQ layout" is detected whenever `qweight` has shape -`[in_features, out_features / pack_factor]` (i.e., rows are raw input channels -instead of packed groups). When that layout is present, `_transform_cpu_awq` -performs an extra shim before the standard CPU packing runs: +tooling. The module always expects `qweight` to be stored in the AWQ layout +`[in_features, out_features / pack_factor]`, meaning each row corresponds to a +single logical input channel. `transform_cpu_awq` performs a fixed shim before +the standard CPU packing runs: 1. 
**Unpack AWQ rows** – `unpack_awq` expands each column lane into eight outputs, yielding `iweight[int8][I, O]` and `izeros[int8][G, O]`. Both @@ -237,7 +237,7 @@ performs an extra shim before the standard CPU packing runs: `quantization.awq.utils.packing_utils.AWQ_ORDER`) so the columns match the logical transformer layout expected by GPTQ. 2. **Normalize zero codes** – AWQ stores integer zero points per output channel. - `_transform_cpu_awq` converts them into floating offsets compatible with the + `transform_cpu_awq` converts them into floating offsets compatible with the fused kernel using `zeros_fp16 = (2^{bits-1} - izeros) * scales_fp32`, keeping the result in `float16` so the metadata matches the original AWQ calibration statistics. @@ -257,10 +257,10 @@ performs an extra shim before the standard CPU packing runs: calibration solved for. Because the shim runs entirely on the CPU path, `TorchFusedAwqQuantLinear` -currently raises `NotImplementedError` when asked to transform an AWQ layout on -`xpu` devices. Inference still benefits from the fused CPU kernel; if the module -cannot be transformed (e.g., due to dtype mismatch or missing fused ops) it -falls back to the dense AWQ matmul defined in `_awq_weight_dense`. +currently raises `NotImplementedError` when asked to run the fused transform on +`xpu` devices. If the module has not been transformed yet (or fused ops are +unavailable), inference falls back to the dense AWQ matmul computed by +`awq_weight_dequantize`, which simply dequantizes the cached AWQ tensors on the fly. ## Quick reference diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 39c356ea9..de1f5df05 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -84,19 +84,27 @@ def _load_from_state_dict( awq_tensor = None if qweight_key in state_dict: candidate = state_dict[qweight_key] - if self.is_awq_qweight_tensor(candidate): - awq_tensor = candidate.to(self.pack_dtype).clone() - placeholder = getattr(self, "qweight", None) - if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): - state_dict[qweight_key] = torch.zeros_like(placeholder) - else: - rows = max(1, self.in_features // self.pack_factor) - cols = self.out_features - state_dict[qweight_key] = torch.zeros( - (rows, cols), - dtype=self.pack_dtype, - device=awq_tensor.device, - ) + if not torch.is_tensor(candidate): + raise TypeError(f"{qweight_key} must be a tensor to load AWQ weights.") + awq_tensor = candidate.to(self.pack_dtype).clone() + expected_rows = self.in_features + expected_cols = max(1, self.out_features // self.pack_factor) + if awq_tensor.shape != (expected_rows, expected_cols): + raise ValueError( + f"{self.__class__.__name__} expects AWQ qweight shape " + f"{(expected_rows, expected_cols)}, but received {tuple(awq_tensor.shape)}." 
+ ) + placeholder = getattr(self, "qweight", None) + if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): + state_dict[qweight_key] = torch.zeros_like(placeholder) + else: + rows = max(1, self.in_features // self.pack_factor) + cols = self.out_features + state_dict[qweight_key] = torch.zeros( + (rows, cols), + dtype=self.pack_dtype, + device=awq_tensor.device, + ) super()._load_from_state_dict( state_dict, prefix, @@ -115,24 +123,7 @@ def _load_from_state_dict( persistent=True, ) - def awq_qweight_shape(self): - pack_cols = self.out_features // self.pack_factor - return self.in_features, pack_cols - - def is_awq_qweight_tensor(self, tensor: torch.Tensor) -> bool: - if tensor is None or not torch.is_tensor(tensor) or tensor.dim() != 2: - return False - rows, cols = tensor.shape - exp_rows, exp_cols = self.awq_qweight_shape() - return rows == exp_rows and cols == exp_cols - - def uses_awq_layout(self) -> bool: - qweight = getattr(self, "qweight", None) - return torch.is_tensor(qweight) and self.is_awq_qweight_tensor(qweight) - def transform_cpu_awq(self, dtype): - if not self.uses_awq_layout(): - raise RuntimeError("AWQ state unavailable for CPU transform.") src_scales = self.scales if src_scales.dtype != torch.float16: src_scales = src_scales.to(torch.float16) @@ -183,8 +174,6 @@ def transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def awq_weight_dequantize(self, device, dtype): - if not self.uses_awq_layout(): - raise RuntimeError("AWQ dense weight requested without cached tensors.") dense = dequantize_gemm( self.qweight, self.qzeros, @@ -195,15 +184,14 @@ def awq_weight_dequantize(self, device, dtype): return dense.to(device=device, dtype=dtype) def transform_cpu(self, dtype): - if self.uses_awq_layout(): - self.transform_cpu_awq(dtype) - return - super().transform_cpu(dtype) + self.transform_cpu_awq(dtype) def transform(self, dtype, device): - if device == "xpu" and self.uses_awq_layout(): - raise NotImplementedError("TorchFusedAwqQuantLinear AWQ layout is currently supported on CPU only.") - super().transform(dtype, device) + if device != "cpu": + raise NotImplementedError( + "TorchFusedAwqQuantLinear only supports fused transforms on CPU devices." 
+ ) + self.transform_cpu(dtype) def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) @@ -225,13 +213,8 @@ def forward(self, x: torch.Tensor): out = self._fused_op_forward(x_flat) else: log.debug("awq dense path") - if self.uses_awq_layout(): - weight = self.awq_weight_dequantize(device=x_flat.device, dtype=x_flat.dtype) - out = torch.matmul(x_flat, weight) - else: - num_itr = self.g_idx.shape[0] // x_flat.shape[-1] - weights = self.dequantize_weight(num_itr=num_itr).to(x_flat.dtype) - out = torch.matmul(x_flat, weights) + weight = self.awq_weight_dequantize(device=x_flat.device, dtype=x_flat.dtype) + out = torch.matmul(x_flat, weight) if self.bias is not None: out.add_(self.bias) @@ -240,18 +223,6 @@ def forward(self, x: torch.Tensor): return out.reshape(out_shape) - @torch.no_grad - def _fused_op_forward(self, x): - awq_active = self.uses_awq_layout() - use_awq_fallback = awq_active and x.device.type == "cpu" - if use_awq_fallback: - log.debug("awq unfused fallback") - weight = self.awq_weight_dequantize(device=x.device, dtype=x.dtype) - return torch.matmul(x, weight) - else: - log.debug("awq fused") - return super()._fused_op_forward(x) - def assert_supported_dtype(self, dtype: torch.dtype): if dtype not in self.SUPPORTS_DTYPES: supported = ", ".join(str(d) for d in self.SUPPORTS_DTYPES) From 8cc736452434e57e6182f2723d90ce4d79c09c0b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 07:32:52 +0000 Subject: [PATCH 12/26] cleanup 3 --- .../nn_modules/qlinear/torch_fused_awq.py | 114 ++++++++++-------- tests/test_kernel_output_awq.py | 64 +++++++++- 2 files changed, 125 insertions(+), 53 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index de1f5df05..e4b6d5483 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import math + import torch from ...adapter.adapter import Adapter @@ -21,7 +23,7 @@ class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): - """Torch fused AWQ variant that reuses the GPTQ fused kernels via CPU int4 packing.""" + """Torch fused AWQ variant based on GPTQ fused kernels via CPU int4 packing.""" QUANT_TYPE = "torch_fused_awq" SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS @@ -66,62 +68,72 @@ def __init__( bias=bias, pack_dtype=pack_dtype, adapter=adapter, - register_buffers=register_buffers, + register_buffers=False, **kwargs, ) + if register_buffers: + qweight_shape = self._awq_qweight_shape() + group_size = max(int(self.group_size), 1) + group_rows = self._awq_group_count() + pack_cols = qweight_shape[1] - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - qweight_key = prefix + "qweight" - awq_tensor = None - if qweight_key in state_dict: - candidate = state_dict[qweight_key] - if not torch.is_tensor(candidate): - raise TypeError(f"{qweight_key} must be a tensor to load AWQ weights.") - awq_tensor = candidate.to(self.pack_dtype).clone() - expected_rows = self.in_features - expected_cols = max(1, self.out_features // self.pack_factor) - if awq_tensor.shape != (expected_rows, expected_cols): - raise ValueError( - f"{self.__class__.__name__} expects AWQ qweight shape " - f"{(expected_rows, expected_cols)}, but received {tuple(awq_tensor.shape)}." 
- ) - placeholder = getattr(self, "qweight", None) - if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel(): - state_dict[qweight_key] = torch.zeros_like(placeholder) - else: - rows = max(1, self.in_features // self.pack_factor) - cols = self.out_features - state_dict[qweight_key] = torch.zeros( - (rows, cols), - dtype=self.pack_dtype, - device=awq_tensor.device, - ) - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - if awq_tensor is not None: - state_dict[qweight_key] = awq_tensor - device = getattr(self, "qweight", awq_tensor).device self.register_buffer( "qweight", - awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(), - persistent=True, + torch.zeros(qweight_shape, dtype=self.pack_dtype), ) + self.register_buffer( + "qzeros", + torch.zeros((group_rows, pack_cols), dtype=self.pack_dtype), + ) + self.register_buffer( + "scales", + torch.zeros((group_rows, self.out_features), dtype=torch.float16), + ) + g_idx = torch.arange(self.in_features, dtype=torch.int32) // group_size + self.register_buffer("g_idx", g_idx) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16)) + else: + self.bias = None + + def _awq_qweight_shape(self): + pack_cols = max(1, self.out_features // self.pack_factor) + return self.in_features, pack_cols + + def _awq_group_count(self): + group_size = max(int(self.group_size), 1) + return max(1, math.ceil(self.in_features / group_size)) + + # def _load_from_state_dict( + # self, + # state_dict, + # prefix, + # local_metadata, + # strict, + # missing_keys, + # unexpected_keys, + # error_msgs, + # ): + # self.register_awq_buffers() + # super()._load_from_state_dict( + # state_dict, + # prefix, + # local_metadata, + # strict, + # missing_keys, + # unexpected_keys, + # error_msgs, + # ) + # qweight = getattr(self, "qweight", None) + # if torch.is_tensor(qweight): + # expected_shape = self._awq_qweight_shape() + # if tuple(qweight.shape) != expected_shape: + # raise ValueError( + # f"{self.__class__.__name__} only loads AWQ-formatted qweight tensors with " + # f"shape {expected_shape}, but received {tuple(qweight.shape)}." 
+ # ) + # if qweight.dtype != self.pack_dtype: + # self.qweight = qweight.to(dtype=self.pack_dtype).contiguous() def transform_cpu_awq(self, dtype): src_scales = self.scales diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 9e446beeb..eecc8f8d7 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -22,6 +22,7 @@ marlin_import_exception, ) from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.marlin import marlin_make_workspace_new @@ -30,6 +31,7 @@ log = LogBar.shared() DEVICE = torch.device("cuda:0") +CPU_DEVICE = torch.device("cpu") GREEN = "\033[32m" RED = "\033[31m" @@ -50,6 +52,7 @@ class TestAwqKernelOutput(unittest.TestCase): (BACKEND.GEMM, torch.float16, 0.004), # (BACKEND.GEMM, torch.bfloat16, 0.05), (BACKEND.MARLIN, torch.float16, 0.006), + (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004), # (BACKEND.MARLIN, torch.bfloat16, 0.05), ] @@ -92,6 +95,16 @@ def setUpClass(cls) -> None: qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) + try: + cls.modules[BACKEND.TORCH_FUSED_AWQ] = cls._build_torch_fused_awq_module( + qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + ) + except Exception as exc: + cls.backend_skip_reason[BACKEND.TORCH_FUSED_AWQ] = ( + f"Torch fused AWQ kernel unavailable: {exc}" + ) + cls.modules[BACKEND.TORCH_FUSED_AWQ] = None + base_inputs = cls._generate_inputs() cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {} cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {} @@ -247,6 +260,35 @@ def _build_torch_awq_module( module.post_init() return module + @classmethod + def _build_torch_fused_awq_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> TorchFusedAwqQuantLinear: + module = TorchFusedAwqQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(CPU_DEVICE) + + module.qweight.copy_(qweight_cpu.to(CPU_DEVICE)) + module.qzeros.copy_(qzeros_cpu.to(CPU_DEVICE)) + module.scales.copy_(scales_cpu.to(torch.float16).to(CPU_DEVICE)) + module.bias.copy_(bias_cpu.to(torch.float16).to(CPU_DEVICE)) + + module.eval() + module.post_init() + return module + @classmethod def _generate_inputs(cls) -> List[torch.Tensor]: large_shapes = [(4, 32), (2, 64), (1, 96)] @@ -288,19 +330,37 @@ def _forward( *, compute_dtype: Optional[torch.dtype] = None, output_dtype: Optional[torch.dtype] = None, + target_device: Optional[torch.device] = None, ) -> List[torch.Tensor]: + if target_device is None: + target_device = cls._infer_module_device(module) outputs: List[torch.Tensor] = [] with torch.inference_mode(): for tensor in inputs: local_tensor = tensor - if compute_dtype is not None and tensor.dtype != compute_dtype: - local_tensor = tensor.to(dtype=compute_dtype) + if local_tensor.device != target_device: + local_tensor = local_tensor.to(device=target_device) + if compute_dtype is not None and local_tensor.dtype != compute_dtype: + local_tensor = local_tensor.to(dtype=compute_dtype) result = module(local_tensor) if output_dtype is not None and result.dtype != output_dtype: result = result.to(dtype=output_dtype) outputs.append(result.detach().cpu()) return outputs + @staticmethod + def _infer_module_device(module: torch.nn.Module) -> torch.device: + try: + tensor = 
next(module.parameters()) + return tensor.device + except StopIteration: + pass + try: + tensor = next(module.buffers()) + return tensor.device + except StopIteration: + return torch.device("cpu") + def _maybe_skip_backend(self, backend: BACKEND) -> None: reason = self.backend_skip_reason.get(backend) if reason: From a4f2a694456a8043ae006736d8f6ae40ee1b5352 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 07:34:54 +0000 Subject: [PATCH 13/26] remove unused --- .../nn_modules/qlinear/torch_fused_awq.py | 41 +++---------------- 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index e4b6d5483..3ce0cd4af 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -72,9 +72,9 @@ def __init__( **kwargs, ) if register_buffers: - qweight_shape = self._awq_qweight_shape() + qweight_shape = self.awq_qweight_shape() group_size = max(int(self.group_size), 1) - group_rows = self._awq_group_count() + group_rows = self.awq_group_count() pack_cols = qweight_shape[1] self.register_buffer( @@ -96,52 +96,21 @@ def __init__( else: self.bias = None - def _awq_qweight_shape(self): + def awq_qweight_shape(self): pack_cols = max(1, self.out_features // self.pack_factor) return self.in_features, pack_cols - def _awq_group_count(self): + def awq_group_count(self): group_size = max(int(self.group_size), 1) return max(1, math.ceil(self.in_features / group_size)) - # def _load_from_state_dict( - # self, - # state_dict, - # prefix, - # local_metadata, - # strict, - # missing_keys, - # unexpected_keys, - # error_msgs, - # ): - # self.register_awq_buffers() - # super()._load_from_state_dict( - # state_dict, - # prefix, - # local_metadata, - # strict, - # missing_keys, - # unexpected_keys, - # error_msgs, - # ) - # qweight = getattr(self, "qweight", None) - # if torch.is_tensor(qweight): - # expected_shape = self._awq_qweight_shape() - # if tuple(qweight.shape) != expected_shape: - # raise ValueError( - # f"{self.__class__.__name__} only loads AWQ-formatted qweight tensors with " - # f"shape {expected_shape}, but received {tuple(qweight.shape)}." - # ) - # if qweight.dtype != self.pack_dtype: - # self.qweight = qweight.to(dtype=self.pack_dtype).contiguous() - def transform_cpu_awq(self, dtype): src_scales = self.scales if src_scales.dtype != torch.float16: src_scales = src_scales.to(torch.float16) src_scales = src_scales.contiguous() - # Cache unpacked AWQ tensors + # Unpack AWQ tensors iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 From 70e4b0455965526e023846db44036b63a9bba316 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 07:38:32 +0000 Subject: [PATCH 14/26] inline methods --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 3ce0cd4af..e4534ac3b 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -72,10 +72,12 @@ def __init__( **kwargs, ) if register_buffers: - qweight_shape = self.awq_qweight_shape() + # AWQ packs each input row into pack_factor-wide columns for int4 lanes. 
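+            # e.g. bits=4 with int32 packing gives pack_factor = 32 // 4 = 8,
+            # so qweight has shape [in_features, out_features // 8].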
+ pack_cols = max(1, self.out_features // self.pack_factor) + qweight_shape = (self.in_features, pack_cols) group_size = max(int(self.group_size), 1) - group_rows = self.awq_group_count() - pack_cols = qweight_shape[1] + # Each group holds group_size input rows; ceil ensures trailing rows are captured. + group_rows = max(1, math.ceil(self.in_features / group_size)) self.register_buffer( "qweight", @@ -96,14 +98,6 @@ def __init__( else: self.bias = None - def awq_qweight_shape(self): - pack_cols = max(1, self.out_features // self.pack_factor) - return self.in_features, pack_cols - - def awq_group_count(self): - group_size = max(int(self.group_size), 1) - return max(1, math.ceil(self.in_features / group_size)) - def transform_cpu_awq(self, dtype): src_scales = self.scales if src_scales.dtype != torch.float16: From 158002f0f52c0160610000a11c019d9b381c7024 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 07:40:27 +0000 Subject: [PATCH 15/26] remove debug logs --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index e4534ac3b..fe883a390 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -184,10 +184,10 @@ def forward(self, x: torch.Tensor): config.max_autotune = True if self.transformed: - log.debug("awq calling fused op") + # log.debug("awq calling fused op") out = self._fused_op_forward(x_flat) else: - log.debug("awq dense path") + # log.debug("awq dense path") weight = self.awq_weight_dequantize(device=x_flat.device, dtype=x_flat.dtype) out = torch.matmul(x_flat, weight) From 3b02a37304ed84efa377666dbc09f4da35ceef0d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 10:44:59 +0000 Subject: [PATCH 16/26] avoid clone --- gptqmodel/nn_modules/qlinear/torch_fused.py | 4 ++-- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 085ef9a5c..ab31ea47b 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -173,7 +173,7 @@ def train(self, mode: bool = True): return super().train(mode=mode) def transform_xpu(self, dtype): - self.scales = self.scales.clone().to(dtype).contiguous() + self.scales = self.scales.to(dtype).contiguous() # Unpack qzeros zeros = torch.bitwise_right_shift( torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), @@ -201,7 +201,7 @@ def transform_xpu(self, dtype): self.qzeros = zeros.contiguous() def transform_cpu(self, dtype): - self.scales = self.scales.clone().to(dtype).contiguous() + self.scales = self.scales.to(dtype).contiguous() # Unpack and reorder qweight weight = torch.bitwise_and( torch.bitwise_right_shift( diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index fe883a390..89e7ec722 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -76,7 +76,7 @@ def __init__( pack_cols = max(1, self.out_features // self.pack_factor) qweight_shape = (self.in_features, pack_cols) group_size = max(int(self.group_size), 1) - # Each group holds group_size input rows; ceil ensures trailing rows are captured. 
+ # Each group holds group_size input rows; ceil ensures remaining rows are included. group_rows = max(1, math.ceil(self.in_features / group_size)) self.register_buffer( @@ -112,7 +112,7 @@ def transform_cpu_awq(self, dtype): if izeros is not None: izeros = torch.bitwise_and(izeros, max_val) - # Precompute the per-group zero offsets (kept in float16 for parity with AWQ reference) + # Precompute the per-group zero offsets scale_fp16 = src_scales.clone() scale_fp32 = scale_fp16.to(torch.float32) zero_offset = 1 << (self.bits - 1) @@ -140,7 +140,6 @@ def transform_cpu_awq(self, dtype): self.qweight = gptq_qweight.contiguous() # Reuse the GPTQ CPU transformation to convert into int4pack layout. - self.scales = scale_fp16.clone() super().transform_cpu(dtype) # Restore AWQ-specific scale/zero metadata for the fused op. From 1ccab24bdcfc6bb74b468bb1951eff7d8b726442 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 11:00:44 +0000 Subject: [PATCH 17/26] merge code with gptq torch fused --- gptqmodel/nn_modules/qlinear/torch_fused.py | 8 +++++--- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index ab31ea47b..160076411 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -200,7 +200,7 @@ def transform_xpu(self, dtype): self.qweight = packed.contiguous() self.qzeros = zeros.contiguous() - def transform_cpu(self, dtype): + def transform_cpu(self, dtype, do_scales_and_zeros: bool = True): self.scales = self.scales.to(dtype).contiguous() # Unpack and reorder qweight weight = torch.bitwise_and( @@ -213,8 +213,10 @@ def transform_cpu(self, dtype): ret_idx = self._build_ret_idx() weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t() self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() - self.qzeros = torch.zeros_like(self.scales).contiguous() - self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + if do_scales_and_zeros: + self.qzeros = torch.zeros_like(self.scales).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform(self, dtype, device): if device == "xpu": diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 89e7ec722..6ccb0c22a 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -140,11 +140,11 @@ def transform_cpu_awq(self, dtype): self.qweight = gptq_qweight.contiguous() # Reuse the GPTQ CPU transformation to convert into int4pack layout. - super().transform_cpu(dtype) + super().transform_cpu(dtype, do_scales_and_zeros=False) - # Restore AWQ-specific scale/zero metadata for the fused op. - self.scales = scale_fp16.to(dtype=dtype) - self.qzeros = zeros_fp16.to(dtype=dtype) + # AWQ-specific scale/zero metadata for the fused op. 
+ self.scales = scale_fp16.to(dtype=dtype).contiguous() + self.qzeros = zeros_fp16.to(dtype=dtype).contiguous() self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def awq_weight_dequantize(self, device, dtype): From 0ab12fdd7234302df980ddc66fce654054300bf0 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 11:12:34 +0000 Subject: [PATCH 18/26] cleanup, add XPU todo --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 6ccb0c22a..78094b945 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -112,8 +112,8 @@ def transform_cpu_awq(self, dtype): if izeros is not None: izeros = torch.bitwise_and(izeros, max_val) - # Precompute the per-group zero offsets - scale_fp16 = src_scales.clone() + # Precompute the per-group zero offsets; reuse the contiguous fp16 copy to avoid extra clones. + scale_fp16 = src_scales scale_fp32 = scale_fp16.to(torch.float32) zero_offset = 1 << (self.bits - 1) zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) @@ -148,18 +148,18 @@ def transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def awq_weight_dequantize(self, device, dtype): - dense = dequantize_gemm( + return dequantize_gemm( self.qweight, self.qzeros, self.scales, self.bits, self.group_size, - ).to(device=device, dtype=torch.float32) - return dense.to(device=device, dtype=dtype) + ).to(device=device, dtype=dtype) def transform_cpu(self, dtype): self.transform_cpu_awq(dtype) + # TODO: add XPU def transform(self, dtype, device): if device != "cpu": raise NotImplementedError( @@ -167,6 +167,7 @@ def transform(self, dtype, device): ) self.transform_cpu(dtype) + # TODO: add XPU def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) From 507abd219174d7b351b257a98cd4b90aaead4279 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 11:42:20 +0000 Subject: [PATCH 19/26] make sure to test both xpu and cpu --- tests/test_kernel_output_torch_fused.py | 79 +++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/test_kernel_output_torch_fused.py b/tests/test_kernel_output_torch_fused.py index f299046db..3a7c858be 100644 --- a/tests/test_kernel_output_torch_fused.py +++ b/tests/test_kernel_output_torch_fused.py @@ -19,6 +19,10 @@ log = LogBar.shared() +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + class TestKernelOutput(unittest.TestCase): model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128" target_qliner_map = { @@ -88,3 +92,78 @@ class TestKernelOutputXPU(TestKernelOutput): class TestKernelOutputXPUBFloat16(TestKernelOutputXPU): dtype = torch.bfloat16 + + +class TestTorchFusedKernelDevices(unittest.TestCase): + model_path = TestKernelOutput.model_path + target_qliner_map = TestKernelOutput.target_qliner_map + target = TestKernelOutput.target + dtype = torch.float16 + m = [1, 16, 64, 256] + k = 2048 + input_samples_each_size = 5 + r_tolerance = 0.0 + a_tolerance = 0.01 + reference_backend = BACKEND.TORCH + reference_device = "cpu" + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + cls.inputs = [] + for dim_0 in cls.m: + for _ in range(cls.input_samples_each_size): + 
cls.inputs.append(torch.rand((dim_0, cls.k), dtype=cls.dtype)) + + reference_model = GPTQModel.load( + cls.model_path, + backend=cls.reference_backend, + device=cls.reference_device, + dtype=cls.dtype, + ) + cls.reference_outputs = [ + cls.forward(reference_model, sample, cls.reference_backend) + for sample in cls.inputs + ] + del reference_model + + @classmethod + def forward(cls, model, x, backend: BACKEND): + target_qlinear_cls = cls.target_qliner_map[backend] + modules = find_modules(model.model, layers=[target_qlinear_cls]) + result = None + for name, module in modules.items(): + if name == cls.target: + result = module(x.to(model.device)) + break + + assert result is not None + return result + + def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005): + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + @parameterized.expand([ + ("cpu", "cpu"), + ("xpu", "xpu:0"), + ]) + def test_torch_fused_matches_cpu_reference(self, _name: str, device: str): + if device.startswith("xpu") and not _xpu_available(): + self.skipTest("Test requires XPU") + + model = GPTQModel.load( + self.model_path, + backend=BACKEND.TORCH_FUSED, + device=device, + dtype=self.dtype, + ) + for idx, sample in enumerate(self.inputs): + model_input = sample.to(model.device) + fused_out = self.forward(model, model_input, BACKEND.TORCH_FUSED) + reference = self.reference_outputs[idx] + self.assert_on_mismatch( + reference.to("cpu"), + fused_out.to("cpu"), + self.r_tolerance, + self.a_tolerance, + ) From 573b51b9ff142458a6678615776b168b05a89443 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 11:58:38 +0000 Subject: [PATCH 20/26] xpu tests --- .../nn_modules/qlinear/torch_fused_awq.py | 82 +++++++++++----- tests/test_kernel_output_awq.py | 98 +++++++++++++++---- 2 files changed, 133 insertions(+), 47 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 78094b945..030e1cdf9 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -98,55 +98,87 @@ def __init__( else: self.bias = None - def transform_cpu_awq(self, dtype): + def _prepare_awq_fused_tensors(self): src_scales = self.scales if src_scales.dtype != torch.float16: src_scales = src_scales.to(torch.float16) src_scales = src_scales.contiguous() - # Unpack AWQ tensors iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) max_val = (1 << self.bits) - 1 iweight = torch.bitwise_and(iweight, max_val) - if izeros is not None: - izeros = torch.bitwise_and(izeros, max_val) + if izeros is None: + raise RuntimeError("AWQ fused kernel requires zero points.") + izeros = torch.bitwise_and(izeros, max_val) - # Precompute the per-group zero offsets; reuse the contiguous fp16 copy to avoid extra clones. scale_fp16 = src_scales scale_fp32 = scale_fp16.to(torch.float32) zero_offset = 1 << (self.bits - 1) zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) zeros_fp16 = (zeros_fp16 * scale_fp32).to(torch.float16) - # Repack AWQ-per-output rows into GPTQ-style per-input packs so the base - # TorchFusedQuantLinear path can handle the conversion to int4pack. 
+ gptq_qweight = self._pack_awq_qweight(iweight) + gptq_qzeros = self._pack_awq_qzeros(izeros) + return gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 + + def _pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: in_features, out_features = iweight.shape pack_factor = int(self.pack_factor) if in_features % pack_factor != 0: raise ValueError( f"AWQ in_features={in_features} must be divisible by pack_factor={pack_factor}." ) - rows = iweight.view(in_features // pack_factor, pack_factor, out_features) - gptq_qweight = torch.zeros( + packed = torch.zeros( (rows.shape[0], out_features), dtype=self.pack_dtype, device=iweight.device, ) - bit_shifts = list(range(0, pack_factor * self.bits, self.bits)) - for lane, shift in enumerate(bit_shifts): - gptq_qweight |= rows[:, lane, :].to(torch.int32) << shift - self.qweight = gptq_qweight.contiguous() + shifts = range(0, pack_factor * self.bits, self.bits) + for lane, shift in enumerate(shifts): + packed |= rows[:, lane, :].to(torch.int32) << shift + return packed.contiguous() - # Reuse the GPTQ CPU transformation to convert into int4pack layout. - super().transform_cpu(dtype, do_scales_and_zeros=False) + def _pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: + pack_factor = int(self.pack_factor) + if izeros.shape[1] % pack_factor != 0: + raise ValueError( + f"AWQ qzeros dimension {izeros.shape[1]} must be divisible by pack_factor={pack_factor}." + ) + cols = izeros.view(izeros.shape[0], izeros.shape[1] // pack_factor, pack_factor) + packed = torch.zeros( + (cols.shape[0], cols.shape[1]), + dtype=self.pack_dtype, + device=izeros.device, + ) + shifts = range(0, pack_factor * self.bits, self.bits) + for lane, shift in enumerate(shifts): + packed |= cols[:, :, lane].to(torch.int32) << shift + return packed.contiguous() - # AWQ-specific scale/zero metadata for the fused op. - self.scales = scale_fp16.to(dtype=dtype).contiguous() - self.qzeros = zeros_fp16.to(dtype=dtype).contiguous() + def transform_cpu_awq(self, dtype): + gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self._prepare_awq_fused_tensors() + self.qweight = gptq_qweight + self.qzeros = gptq_qzeros + super().transform_cpu(dtype, do_scales_and_zeros=False) + device = self.qweight.device + self.scales = scale_fp16.to(device=device, dtype=dtype).contiguous() + self.qzeros = zeros_fp16.to(device=device, dtype=dtype).contiguous() self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + def transform_xpu_awq(self, dtype): + gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self._prepare_awq_fused_tensors() + self.qweight = gptq_qweight + self.qzeros = gptq_qzeros + super().transform_xpu(dtype) + device = self.qweight.device + self.scales = scale_fp16.to(device=device, dtype=dtype).contiguous() + self.qzeros = zeros_fp16.to(device=device, dtype=dtype).contiguous() + + def transform_cpu(self, dtype): + self.transform_cpu_awq(dtype) + def awq_weight_dequantize(self, device, dtype): return dequantize_gemm( self.qweight, @@ -156,18 +188,16 @@ def awq_weight_dequantize(self, device, dtype): self.group_size, ).to(device=device, dtype=dtype) - def transform_cpu(self, dtype): - self.transform_cpu_awq(dtype) - - # TODO: add XPU def transform(self, dtype, device): - if device != "cpu": + if device == "cpu": + self.transform_cpu(dtype) + elif device == "xpu": + self.transform_xpu_awq(dtype) + else: raise NotImplementedError( - "TorchFusedAwqQuantLinear only supports fused transforms on CPU devices." 
+ "TorchFusedAwqQuantLinear only supports fused transforms on CPU or XPU devices." ) - self.transform_cpu(dtype) - # TODO: add XPU def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index eecc8f8d7..18ce04e2d 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -38,6 +38,10 @@ RESET = "\033[0m" +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + class TestAwqKernelOutput(unittest.TestCase): MODEL_PATH = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") TARGET = "model.layers.20.self_attn.v_proj" @@ -58,13 +62,14 @@ class TestAwqKernelOutput(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is required for AWQ kernel output checks.") - - cls.device = DEVICE + cls.cuda_available = torch.cuda.is_available() + cls.device = DEVICE if cls.cuda_available else CPU_DEVICE cls.log = log cls._weight_map = cls._load_weight_map() cls.backend_skip_reason: Dict[BACKEND, str] = {} + if not cls.cuda_available: + cls.backend_skip_reason[BACKEND.GEMM] = "CUDA is required for GEMM backend." + cls.backend_skip_reason[BACKEND.MARLIN] = "CUDA is required for AWQ Marlin kernel." try: tensors = cls._load_awq_tensors(cls.TARGET) @@ -77,6 +82,10 @@ def setUpClass(cls) -> None: scales_cpu, bias_cpu, ) = tensors + cls.qweight_cpu = qweight_cpu + cls.qzeros_cpu = qzeros_cpu + cls.scales_cpu = scales_cpu + cls.bias_cpu = bias_cpu cls.in_features = qweight_cpu.shape[0] cls.out_features = qweight_cpu.shape[1] * (32 // cls.BITS) @@ -87,12 +96,16 @@ def setUpClass(cls) -> None: qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) - cls.modules[BACKEND.GEMM] = cls._build_gemm_module( - qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + cls.modules[BACKEND.GEMM] = ( + cls._build_gemm_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if cls.cuda_available + else None ) - cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( - qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + cls.modules[BACKEND.MARLIN] = ( + cls._build_marlin_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if cls.cuda_available + else None ) try: @@ -136,7 +149,8 @@ def tearDownClass(cls) -> None: for module in getattr(cls, "modules", {}).values(): if module is not None: del module - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() @classmethod def _load_weight_map(cls) -> Dict[str, str]: @@ -267,6 +281,8 @@ def _build_torch_fused_awq_module( qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, + *, + device: torch.device = CPU_DEVICE, ) -> TorchFusedAwqQuantLinear: module = TorchFusedAwqQuantLinear( bits=cls.BITS, @@ -278,12 +294,12 @@ def _build_torch_fused_awq_module( bias=True, adapter=None, register_buffers=True, - ).to(CPU_DEVICE) + ).to(device) - module.qweight.copy_(qweight_cpu.to(CPU_DEVICE)) - module.qzeros.copy_(qzeros_cpu.to(CPU_DEVICE)) - module.scales.copy_(scales_cpu.to(torch.float16).to(CPU_DEVICE)) - module.bias.copy_(bias_cpu.to(torch.float16).to(CPU_DEVICE)) + module.qweight.copy_(qweight_cpu.to(device)) + module.qzeros.copy_(qzeros_cpu.to(device)) + module.scales.copy_(scales_cpu.to(torch.float16).to(device)) + module.bias.copy_(bias_cpu.to(torch.float16).to(device)) module.eval() module.post_init() @@ -295,13 +311,15 @@ def _generate_inputs(cls) -> List[torch.Tensor]: 
medium_shapes = [(2, 32), (1, 48), (1, 32)] small_shapes = [(1, 32), (1, 24), (1, 16)] - try: - total_mem_gb = ( - torch.cuda.get_device_properties(cls.device).total_memory - / (1024 ** 3) - ) - except Exception: # pragma: no cover - total_mem_gb = 0.0 + total_mem_gb = 0.0 + if cls.device.type == "cuda": + try: + total_mem_gb = ( + torch.cuda.get_device_properties(cls.device).total_memory + / (1024 ** 3) + ) + except Exception: # pragma: no cover + total_mem_gb = 0.0 if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": shapes = small_shapes @@ -457,3 +475,41 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl title=f"AWQ Kernel Output {dtype}", reference_label="Torch AWQ output", ) + + @parameterized.expand( + [ + ("cpu", "cpu"), + ("xpu", "xpu:0"), + ] + ) + def test_torch_fused_awq_devices(self, _label: str, device_str: str) -> None: + self._maybe_skip_backend(BACKEND.TORCH_FUSED_AWQ) + if device_str.startswith("xpu") and not _xpu_available(): + self.skipTest("Torch fused AWQ XPU test requires Intel XPU runtime.") + + device = torch.device(device_str) + module = self._build_torch_fused_awq_module( + self.qweight_cpu, + self.qzeros_cpu, + self.scales_cpu, + self.bias_cpu, + device=device, + ) + + try: + actual_outputs = self._forward( + module, + self.inputs[torch.float16], + target_device=device, + ) + self._summarize_results( + reference_outputs=self.reference_outputs[torch.float16], + actual_outputs=actual_outputs, + backend=BACKEND.TORCH_FUSED_AWQ, + dtype=torch.float16, + atol=0.004, + title=f"Torch Fused AWQ Device {device_str}", + reference_label="Torch AWQ output", + ) + finally: + del module From 1c4f5a656fcc2d1e6e426f07a041c1d5e2c8f1de Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 12:06:28 +0000 Subject: [PATCH 21/26] fix xpu transform --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 030e1cdf9..11d128254 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -168,13 +168,12 @@ def transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform_xpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self._prepare_awq_fused_tensors() + gptq_qweight, gptq_qzeros, scale_fp16, _zeros_fp16 = self._prepare_awq_fused_tensors() self.qweight = gptq_qweight self.qzeros = gptq_qzeros super().transform_xpu(dtype) device = self.qweight.device self.scales = scale_fp16.to(device=device, dtype=dtype).contiguous() - self.qzeros = zeros_fp16.to(device=device, dtype=dtype).contiguous() def transform_cpu(self, dtype): self.transform_cpu_awq(dtype) From 85b5bb21c8c639df340cfc6f5b13cf258bad6788 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 12:11:20 +0000 Subject: [PATCH 22/26] cleanup --- gptqmodel/nn_modules/qlinear/torch_fused_awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index 11d128254..aa8dbabf4 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -168,7 +168,7 @@ def transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform_xpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, 
scale_fp16, _zeros_fp16 = self._prepare_awq_fused_tensors() + gptq_qweight, gptq_qzeros, scale_fp16, _ = self._prepare_awq_fused_tensors() self.qweight = gptq_qweight self.qzeros = gptq_qzeros super().transform_xpu(dtype) From 1b20353e39506efbeff22c11ca95c84864a7f88f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 8 Nov 2025 12:15:38 +0000 Subject: [PATCH 23/26] cleanup2 --- .../nn_modules/qlinear/torch_fused_awq.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index aa8dbabf4..d48371474 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -98,7 +98,7 @@ def __init__( else: self.bias = None - def _prepare_awq_fused_tensors(self): + def prepare_awq_fused_tensors(self, need_zeros_fp16: bool = True): src_scales = self.scales if src_scales.dtype != torch.float16: src_scales = src_scales.to(torch.float16) @@ -114,15 +114,17 @@ def _prepare_awq_fused_tensors(self): scale_fp16 = src_scales scale_fp32 = scale_fp16.to(torch.float32) - zero_offset = 1 << (self.bits - 1) - zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) - zeros_fp16 = (zeros_fp16 * scale_fp32).to(torch.float16) - gptq_qweight = self._pack_awq_qweight(iweight) - gptq_qzeros = self._pack_awq_qzeros(izeros) - return gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 + if need_zeros_fp16: + zero_offset = 1 << (self.bits - 1) + zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) + zeros_fp16 = (zeros_fp16 * scale_fp32).to(torch.float16) - def _pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: + gptq_qweight = self.pack_awq_qweight(iweight) + gptq_qzeros = self.pack_awq_qzeros(izeros) + return gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 if need_zeros_fp16 else None + + def pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: in_features, out_features = iweight.shape pack_factor = int(self.pack_factor) if in_features % pack_factor != 0: @@ -140,7 +142,7 @@ def _pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: packed |= rows[:, lane, :].to(torch.int32) << shift return packed.contiguous() - def _pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: + def pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: pack_factor = int(self.pack_factor) if izeros.shape[1] % pack_factor != 0: raise ValueError( @@ -158,7 +160,7 @@ def _pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: return packed.contiguous() def transform_cpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self._prepare_awq_fused_tensors() + gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self.prepare_awq_fused_tensors() self.qweight = gptq_qweight self.qzeros = gptq_qzeros super().transform_cpu(dtype, do_scales_and_zeros=False) @@ -168,7 +170,7 @@ def transform_cpu_awq(self, dtype): self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform_xpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, scale_fp16, _ = self._prepare_awq_fused_tensors() + gptq_qweight, gptq_qzeros, scale_fp16, _ = self.prepare_awq_fused_tensors(need_zeros_fp16=False) self.qweight = gptq_qweight self.qzeros = gptq_qzeros super().transform_xpu(dtype) From 52243d69c59d2314a7a4b7899c2e575a3aab23a6 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 9 Nov 2025 00:57:01 +0000 Subject: [PATCH 24/26] cleanup 3 plus test logs --- 
docs/torch_fused_int4_transformations.md | 7 +++ .../nn_modules/qlinear/torch_fused_awq.py | 52 +++++++++---------- tests/test_kernel_output_awq.py | 5 ++ tests/test_torch_fused_awq.py | 39 +++++++++++--- 4 files changed, 70 insertions(+), 33 deletions(-) diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md index 5d82f19d2..13636ff5d 100644 --- a/docs/torch_fused_int4_transformations.md +++ b/docs/torch_fused_int4_transformations.md @@ -1,3 +1,10 @@ +``` +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +``` + # Torch Fused INT4 Transformations This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu` diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index d48371474..aaa075fd2 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -26,6 +26,8 @@ class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): """Torch fused AWQ variant based on GPTQ fused kernels via CPU int4 packing.""" QUANT_TYPE = "torch_fused_awq" + + # inherit from torch fused SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS SUPPORTS_GROUP_SIZE = TorchFusedQuantLinear.SUPPORTS_GROUP_SIZE SUPPORTS_DESC_ACT = TorchFusedQuantLinear.SUPPORTS_DESC_ACT @@ -39,9 +41,10 @@ class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): SUPPORTS_PLATFORM = TorchFusedQuantLinear.SUPPORTS_PLATFORM SUPPORTS_PACK_DTYPES = TorchFusedQuantLinear.SUPPORTS_PACK_DTYPES SUPPORTS_ADAPTERS = TorchFusedQuantLinear.SUPPORTS_ADAPTERS + REQUIRES_FORMAT_V2 = TorchFusedQuantLinear.REQUIRES_FORMAT_V2 + # AWQ kernels are only accuracy validate for float16 for now SUPPORTS_DTYPES = [torch.float16] - REQUIRES_FORMAT_V2 = False def __init__( self, @@ -68,9 +71,12 @@ def __init__( bias=bias, pack_dtype=pack_dtype, adapter=adapter, + # Skip base buffer init, we need to manually init buffers for awq register_buffers=False, **kwargs, ) + + # Create awq buffers if register_buffers: # AWQ packs each input row into pack_factor-wide columns for int4 lanes. 
pack_cols = max(1, self.out_features // self.pack_factor) @@ -83,26 +89,26 @@ def __init__( "qweight", torch.zeros(qweight_shape, dtype=self.pack_dtype), ) + self.register_buffer( "qzeros", torch.zeros((group_rows, pack_cols), dtype=self.pack_dtype), ) + self.register_buffer( "scales", torch.zeros((group_rows, self.out_features), dtype=torch.float16), ) - g_idx = torch.arange(self.in_features, dtype=torch.int32) // group_size - self.register_buffer("g_idx", g_idx) + + self.register_buffer("g_idx", torch.arange(self.in_features, dtype=torch.int32) // group_size) + if bias: self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16)) else: self.bias = None - def prepare_awq_fused_tensors(self, need_zeros_fp16: bool = True): - src_scales = self.scales - if src_scales.dtype != torch.float16: - src_scales = src_scales.to(torch.float16) - src_scales = src_scales.contiguous() + def prepare_awq_fused_tensors(self, need_zeros: bool = True): + self.scales.to(torch.float16).contiguous() iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) @@ -112,17 +118,13 @@ def prepare_awq_fused_tensors(self, need_zeros_fp16: bool = True): raise RuntimeError("AWQ fused kernel requires zero points.") izeros = torch.bitwise_and(izeros, max_val) - scale_fp16 = src_scales - scale_fp32 = scale_fp16.to(torch.float32) - - if need_zeros_fp16: + if need_zeros: zero_offset = 1 << (self.bits - 1) - zeros_fp16 = (zero_offset - izeros.reshape_as(scale_fp32)).to(dtype=scale_fp32.dtype) - zeros_fp16 = (zeros_fp16 * scale_fp32).to(torch.float16) + zeros = (zero_offset - izeros.reshape_as(self.scales)) * self.scales gptq_qweight = self.pack_awq_qweight(iweight) gptq_qzeros = self.pack_awq_qzeros(izeros) - return gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 if need_zeros_fp16 else None + return gptq_qweight, gptq_qzeros, self.scales, zeros if need_zeros else None def pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: in_features, out_features = iweight.shape @@ -160,22 +162,20 @@ def pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: return packed.contiguous() def transform_cpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, scale_fp16, zeros_fp16 = self.prepare_awq_fused_tensors() - self.qweight = gptq_qweight - self.qzeros = gptq_qzeros + self.qweight, self.qzeros, scales, zeros = self.prepare_awq_fused_tensors() + super().transform_cpu(dtype, do_scales_and_zeros=False) - device = self.qweight.device - self.scales = scale_fp16.to(device=device, dtype=dtype).contiguous() - self.qzeros = zeros_fp16.to(device=device, dtype=dtype).contiguous() + + self.scales = scales.to(device=self.qweight.device, dtype=dtype).contiguous() + self.qzeros = zeros.to(device=self.qweight.device, dtype=dtype).contiguous() self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform_xpu_awq(self, dtype): - gptq_qweight, gptq_qzeros, scale_fp16, _ = self.prepare_awq_fused_tensors(need_zeros_fp16=False) - self.qweight = gptq_qweight - self.qzeros = gptq_qzeros + self.qweight, self.qzeros, scales, _ = self.prepare_awq_fused_tensors(need_zeros=False) + super().transform_xpu(dtype) - device = self.qweight.device - self.scales = scale_fp16.to(device=device, dtype=dtype).contiguous() + + self.scales = scales.to(device=self.qweight.device, dtype=dtype).contiguous() def transform_cpu(self, dtype): self.transform_cpu_awq(dtype) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 
18ce04e2d..22186f6d4 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -393,6 +393,7 @@ def _summarize_results( atol: float, title: str, reference_label: str, + device: Optional[torch.device] = None, ) -> None: failures = [] total = len(actual_outputs) @@ -419,12 +420,14 @@ def _summarize_results( status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" avg_abs_diff = mean_abs_diff / total if total else 0.0 details = "\n\n".join(str(detail) for detail in failures) if failures else "-" + device_label = str(device) if device is not None else "-" table = tabulate( [ [ backend.name, str(dtype), + device_label, total, f"{max_abs_diff:.6f}", f"{avg_abs_diff:.6f}", @@ -436,6 +439,7 @@ def _summarize_results( headers=[ "Backend", "DType", + "Device", "Samples", "MaxAbsDiff", "MeanAbsDiff", @@ -510,6 +514,7 @@ def test_torch_fused_awq_devices(self, _label: str, device_str: str) -> None: atol=0.004, title=f"Torch Fused AWQ Device {device_str}", reference_label="Torch AWQ output", + device=device, ) finally: del module diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index be5c869e7..35f102838 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -24,6 +24,10 @@ ) +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + @lru_cache(maxsize=1) def _load_awq_checkpoint_module(): if not CHECKPOINT_DIR.exists(): @@ -87,7 +91,20 @@ def _load_awq_checkpoint_module(): @pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") -def test_torch_fused_awq_matches_checkpoint_module(): +@pytest.mark.parametrize( + "device_str", + [ + pytest.param("cpu", id="cpu"), + pytest.param( + "xpu:0", + id="xpu", + marks=pytest.mark.skipif( + not _xpu_available(), reason="Torch fused AWQ XPU test requires Intel XPU runtime." 
+ ), + ), + ], +) +def test_torch_fused_awq_matches_checkpoint_module(device_str: str): module_data = _load_awq_checkpoint_module() bits = module_data["bits"] group_size = module_data["group_size"] @@ -98,6 +115,8 @@ def test_torch_fused_awq_matches_checkpoint_module(): scales = module_data["scales"] bias = module_data["bias"] + device = torch.device(device_str) + awq_module = AwqTorchQuantLinear( bits=bits, group_size=group_size, @@ -135,26 +154,32 @@ def test_torch_fused_awq_matches_checkpoint_module(): fused_module.post_init() fused_module.eval() + awq_module.to(device) + fused_module.to(device) + dtype = torch.float16 batch = 4 - x = torch.randn(batch, in_features, dtype=dtype) + x = torch.randn(batch, in_features, dtype=dtype, device=device) baseline = awq_module(x) fused_out = fused_module(x) - tol = 5e-3 + rtol = 5e-3 + atol = 5e-3 abs_diff = (fused_out - baseline).abs() rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) summary = tabulate( [ [ + device_str, str(dtype), - f"{tol:.4g}", - f"{tol:.4g}", + f"{rtol:.4g}", + f"{atol:.4g}", f"{abs_diff.max().item():.4e}", f"{rel_diff.max().item():.4e}", ] ], - headers=["dtype", "rtol", "atol", "abs_max", "rel_max"], + headers=["Device", "DType", "RTol", "ATol", "AbsMaxDiff", "RelMaxDiff"], + tablefmt="github", ) print(summary) - torch.testing.assert_close(fused_out, baseline, rtol=tol, atol=tol) + torch.testing.assert_close(fused_out, baseline, rtol=rtol, atol=atol) From 0736ed4ab17baba46f4e5985d3356ac8c2eefeec Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 9 Nov 2025 01:24:31 +0000 Subject: [PATCH 25/26] tabulate logs --- tests/test_kernel_output_torch_fused.py | 50 ++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/test_kernel_output_torch_fused.py b/tests/test_kernel_output_torch_fused.py index 3a7c858be..aa1ddff02 100644 --- a/tests/test_kernel_output_torch_fused.py +++ b/tests/test_kernel_output_torch_fused.py @@ -9,6 +9,7 @@ from logbar import LogBar from parameterized import parameterized from torch import Tensor +from tabulate import tabulate from gptqmodel import BACKEND, GPTQModel from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear @@ -157,13 +158,50 @@ def test_torch_fused_matches_cpu_reference(self, _name: str, device: str): device=device, dtype=self.dtype, ) + failures = [] for idx, sample in enumerate(self.inputs): model_input = sample.to(model.device) fused_out = self.forward(model, model_input, BACKEND.TORCH_FUSED) reference = self.reference_outputs[idx] - self.assert_on_mismatch( - reference.to("cpu"), - fused_out.to("cpu"), - self.r_tolerance, - self.a_tolerance, - ) + try: + self.assert_on_mismatch( + reference.to("cpu"), + fused_out.to("cpu"), + self.r_tolerance, + self.a_tolerance, + ) + except AssertionError as exc: + failures.append(f"Sample {idx}: {str(exc).splitlines()[0]}") + + status = "PASS" if not failures else "FAIL" + table = tabulate( + [ + [ + BACKEND.TORCH_FUSED.name, + str(self.dtype), + device, + len(self.inputs), + f"{self.r_tolerance:.2e}", + f"{self.a_tolerance:.2e}", + status, + len(failures), + "\n\n".join(failures) if failures else "-", + ] + ], + headers=[ + "Backend", + "DType", + "Device", + "Samples", + "RTol", + "ATol", + "Status", + "Failures", + "Details", + ], + tablefmt="github", + ) + log.info("\nTorch Fused vs CPU Reference\n" + table) + + if failures: + raise AssertionError(f"{len(failures)} mismatched samples on device {device}") From e9d1e6e7768612b87d2895500b160349069684b7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 
9 Nov 2025 01:33:55 +0000 Subject: [PATCH 26/26] prepare for v5.4.0 release --- README.md | 5 +++-- gptqmodel/version.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d6e8cce3a..6453ac01c 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@

 ## Latest News
+* 11/9/2025 [5.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.4.0): ✨New hardware-optimized AWQ `TorchFusedAWQ` kernel for Intel CPU and XPU. Torch Fused kernels are now compatible with `torch.compile`. Fixed AWQ MoE model compatibility and reduced VRAM usage.
 * 11/3/2025 [5.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.2.0): 🎉Minimax M2 support with [ModelCloud BF16 M2 Model](https://huggingface.co/ModelCloud/MiniMax-M2-BF16). New `VramStrategy.Balanced` quantization property for reduced memory usage for large MoE on multi-3090 (24GB) devices. ✨Marin model. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor. 🎉AWQ support out of beta with full feature support including multi-gpu quant and MoE vram saving. ✨Brumby (attention free) model support. ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
 * 10/24/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and inlined with quantization. `Vram` pressure for large models reduced during quantization.
@@ -202,8 +203,8 @@ GPT-QModel is validated for Linux, MacOS, and Windows 11:
 |-----------------|---------------| --- | -------------- |-----------------------------------------------|
 | 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exllama V1, Triton, Torch |
 | 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exllama V1, Torch |
-| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch |
-| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused (Python 2.8+), Torch |
+| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | TorchFused, TorchFusedAWQ, Torch |
+| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | TorchFused, TorchFusedAWQ, Torch |
 | 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
 | 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |
diff --git a/gptqmodel/version.py b/gptqmodel/version.py
index b62cbd7ea..1051a72aa 100644
--- a/gptqmodel/version.py
+++ b/gptqmodel/version.py
@@ -7,4 +7,4 @@
 # even minor versions are release
 # 5.2.0 => release, 5.1.0 => devel
 # micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release
-__version__ = "5.3.0"
+__version__ = "5.4.0"
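
For readers of the 5.4.0 notes above, a minimal usage sketch of the fused path follows. This is not part of the patch series: `GPTQModel.load` and `BACKEND.TORCH_FUSED` mirror what the tests in this series import, but the keyword names (`backend`, `device`), the `model.tokenizer`/`model.generate` helpers, and the checkpoint id are assumptions based on the public gptqmodel API; AWQ checkpoints are expected to resolve to `TorchFusedAwqQuantLinear` through the importer changes in these patches.

```python
# Hedged usage sketch (assumptions noted above): route a quantized checkpoint
# through the fused torch.ops.aten int4 kernels on Intel CPU or XPU.
import torch

from gptqmodel import BACKEND, GPTQModel

# Prefer an Intel XPU if the runtime exposes one; otherwise fall back to CPU.
device = "xpu:0" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

model = GPTQModel.load(
    "ModelCloud/example-awq-4bit",  # placeholder AWQ 4-bit checkpoint id
    backend=BACKEND.TORCH_FUSED,    # select the fused Torch kernel family
    device=device,
)

# Simple generation round-trip to confirm the fused module path is exercised.
tokens = model.tokenizer("Hello, fused AWQ!", return_tensors="pt").to(device)
output = model.generate(**tokens, max_new_tokens=8)
print(model.tokenizer.decode(output[0], skip_special_tokens=True))
```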