From 8bd7bf05b28d2c82f07e8fcb60a98b905b50a00e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 10 Oct 2025 12:24:41 +0000 Subject: [PATCH 1/3] enable cpu torch fused op Signed-off-by: jiqing-feng --- gptqmodel/nn_modules/qlinear/torch_fused.py | 97 ++++++++++++++------- gptqmodel/utils/torch.py | 2 +- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index b2932fcfd..ba8b6b440 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -13,27 +13,27 @@ from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear from ...utils.backend import BACKEND from ...utils.logger import setup_logger -from ...utils.torch import TORCH_HAS_XPU_FUSED_OPS +from ...utils.torch import TORCH_HAS_FUSED_OPS log = setup_logger() -# TODO: not yet working for cuda/cpu fused int4 ops -# def pack_scales_and_zeros(scales, zeros): -# assert scales.shape == zeros.shape -# # assert scales.dtype == torch.bfloat16 -# # assert zeros.dtype == torch.bfloat16 -# return ( -# torch.cat( -# [ -# scales.reshape(scales.size(0), scales.size(1), 1), -# zeros.reshape(zeros.size(0), zeros.size(1), 1), -# ], -# 2, -# ) -# .transpose(0, 1) -# .contiguous() -# ) +# TODO: CPU works, not yet working for cuda fused int4 ops +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + # assert scales.dtype == torch.bfloat16 + # assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] @@ -45,9 +45,7 @@ class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_AUTO_PADDING = True SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] - - # optimized for XPU but should run on all - SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] # change this to XPU to limit to Intel XPU + SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] SUPPORTS_PLATFORM = [PLATFORM.ALL] SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] @@ -129,7 +127,7 @@ def train(self, mode: bool = True): return super().train(mode=mode) - def transform(self, dtype): + def transform_xpu(self, dtype): self.scales = self.scales.clone().to(dtype).contiguous() # Unpack qzeros zeros = torch.bitwise_right_shift( @@ -166,6 +164,39 @@ def transform(self, dtype): self.qweight = packed.contiguous() self.qzeros = zeros.contiguous() + def transform_cpu(self, dtype): + self.scales = self.scales.clone().to(dtype).contiguous() + # Unpack and reorder qweight + weight = torch.bitwise_and( + torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), + self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) + ).to(self.dequant_dtype), + self.maxq + ) + self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) + groups = self.g_idx.shape[0] // self.group_size + remainder = self.g_idx.shape[0] % self.group_size + g_idx_2 = self.g_idx * self.group_size + if remainder > 0: + g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) + arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) + for i in range(groups): + g_idx_2[self.g_idx == i] += arange_tensor + self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() + self.qzeros = torch.zeros_like(self.scales).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + def transform(self, dtype, device): + if device == "xpu": + self.transform_xpu(dtype) + elif: device == "cpu": + self.transform_cpu(dtype) + else: + raise NotImplementedError + def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) @@ -175,25 +206,25 @@ def forward(self, x: torch.Tensor): def _forward(self, x, out_shape): num_itr = self.g_idx.shape[0] // x.shape[-1] - if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS and "xpu" == x.device.type: + if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: # one-time transform per module for xpu aten fused ops - self.transform(x.dtype) + self.transform(x.dtype, x.device.type) self.transformed = True if self.transformed: x = x[:, self.ret_idx].contiguous() # fused ops optimized for xpu using torch.ops # note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only - out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( - x, self.qweight, self.group_size, self.scales, self.qzeros - ).reshape(out_shape) - - # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format - # scales + zeros and pass as one tensor - # scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) - # out = torch.ops.aten._weight_int4pack_mm( - # x, self.qweight, self.group_size, scales_and_zeros - # ).reshape(out_shape) + if x.device.type == "xpu": + out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( + x, self.qweight, self.group_size, self.scales, self.qzeros + ).reshape(out_shape) + elif x.device.type == "cpu": + out = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, self.qweight, self.group_size, self.scales_and_zeros + ).reshape(out_shape) + else: + raise NotImplementedError else: # make sure dequant dtype matches input x weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 6a931c82a..d8c323c93 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -21,7 +21,7 @@ # pytorch 2.6.0 fixes many compilation errors TORCH_HAS_COMPILE = version.parse(torch.__version__).release >= version.Version('2.6').release TORCH_GTE_28 = version.parse(torch.__version__).release >= version.Version('2.8').release -TORCH_HAS_XPU_FUSED_OPS = version.parse(torch.__version__).release >= version.Version('2.8').release +TORCH_HAS_FUSED_OPS = version.parse(torch.__version__).release >= version.Version('2.8').release HAS_CUDA = False HAS_XPU = False From b9187c9fb56b7039faa71bcaef1c757586c5d9a9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 10 Oct 2025 12:27:26 +0000 Subject: [PATCH 2/3] fix format Signed-off-by: jiqing-feng --- gptqmodel/nn_modules/qlinear/torch_fused.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index ba8b6b440..e89491e88 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -192,7 +192,7 @@ def transform_cpu(self, dtype): def transform(self, dtype, device): if device == "xpu": self.transform_xpu(dtype) - elif: device == "cpu": + elif device == "cpu": self.transform_cpu(dtype) else: raise NotImplementedError From 69b64366e1c8c4fbddef0a1306d160a8a48d3a8c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 10 Oct 2025 12:29:07 +0000 Subject: [PATCH 3/3] fix pack scales and zeros Signed-off-by: jiqing-feng --- gptqmodel/nn_modules/qlinear/torch_fused.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index e89491e88..f36e47e07 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -31,7 +31,6 @@ def pack_scales_and_zeros(scales, zeros): ], 2, ) - .transpose(0, 1) .contiguous() )