From add3bd630a040e1fe3c769e5191e09f0287b5932 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 26 Aug 2025 11:58:09 +0000 Subject: [PATCH 1/5] add _convert_weight_to_int4pack() Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/torch_fused.py | 114 +++++++++++++------- 1 file changed, 76 insertions(+), 38 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index c2a9738a6..46d819214 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -29,21 +29,46 @@ log = setup_logger() # TODO: not yet working for cuda/cpu fused int4 ops -# def pack_scales_and_zeros(scales, zeros): -# assert scales.shape == zeros.shape -# # assert scales.dtype == torch.bfloat16 -# # assert zeros.dtype == torch.bfloat16 -# return ( -# torch.cat( -# [ -# scales.reshape(scales.size(0), scales.size(1), 1), -# zeros.reshape(zeros.size(0), zeros.size(1), 1), -# ], -# 2, -# ) -# .transpose(0, 1) -# .contiguous() -# ) +def pack_scales_and_zeros(scales, zeros): + print("scales", scales.shape, zeros.shape) + # assert scales.shape == zeros.shape + # assert scales.dtype == torch.bfloat16 + # assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + +def gptq_int32_to_uint8(qweight: torch.Tensor) -> torch.Tensor: + """ + Convert GPTQ qweight (int32, each element packs 8 int4 values) + into (uint8, each element packs 2 int4 values). + + Input: [n, k_int32] int32 + Output: [n, k_int32 * 4] uint8 # since each int32 becomes 4 uint8 + """ + assert qweight.dtype == torch.int32 + + # Unpack into 8 int4 values + q_unpack = torch.stack([ + (qweight >> (4 * i)) & 0xF for i in range(8) + ], dim=-1) # shape: [n, k_int32, 8] + + # Repack into uint8 (each uint8 holds two int4 values) + q_even = q_unpack[..., 0::2] # [n, k_int32, 4] + q_odd = q_unpack[..., 1::2] # [n, k_int32, 4] + q_uint8 = (q_even | (q_odd << 4)).to(torch.uint8) + + # Reshape to [n, k_uint8], where k_uint8 = k_int32 * 4 + q_uint8 = q_uint8.reshape(qweight.shape[0], -1) + return q_uint8 class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] @@ -57,7 +82,7 @@ class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] # optimized for XPU but should run on all - SUPPORTS_DEVICES = [DEVICE.XPU] # change this to XPU to limit to Intel XPU + SUPPORTS_DEVICES = [DEVICE.XPU, DEVICE.CUDA] # change this to XPU to limit to Intel XPU SUPPORTS_PLATFORM = [PLATFORM.ALL] SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] @@ -146,13 +171,13 @@ def transform(self, dtype): ).to(self.dequant_dtype) zeros = torch.bitwise_and(zeros, self.maxq).reshape(zeros.shape[0], -1) # Unpack and reorder qweight - weight = torch.bitwise_and( - torch.bitwise_right_shift( - torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), - self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) - ).to(self.dequant_dtype), - self.maxq - ) + # weight = torch.bitwise_and( + # torch.bitwise_right_shift( + # torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), + # self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) + # ).to(self.dequant_dtype), + # self.maxq + # ) self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) groups = self.g_idx.shape[0] // self.group_size remainder = self.g_idx.shape[0] % self.group_size @@ -163,15 +188,15 @@ def transform(self, dtype): for i in range(groups): g_idx_2[self.g_idx == i] += arange_tensor self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + # weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() # Pack qweight - packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) - for col in range(weight.shape[1] // self.pack_factor): - for i in range(self.pack_factor): - packed_col = weight[:, col * self.pack_factor + i].to(torch.int32) - packed[:, col] |= packed_col << (i * self.bits) + # packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) + # for col in range(weight.shape[1] // self.pack_factor): + # for i in range(self.pack_factor): + # packed_col = weight[:, col * self.pack_factor + i].to(torch.int32) + # packed[:, col] |= packed_col << (i * self.bits) - self.qweight = packed.contiguous() + # self.qweight = packed.contiguous() self.qzeros = zeros.contiguous() def forward(self, x: torch.Tensor): @@ -185,23 +210,36 @@ def _forward(self, x, out_shape): if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS: # one-time transform per module for xpu aten fused ops + print("ssss 1", self.qweight.shape, self.scales.shape, self.qzeros.shape) self.transform(x.dtype) + print("ssss 2", self.qweight.shape, self.scales.shape, self.qzeros.shape) + # raise Exception("Test") self.transformed = True if self.transformed: - x = x[:, self.ret_idx].contiguous() + # x = x[:, self.ret_idx].contiguous() # fused ops optimized for xpu using torch.ops # note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only - out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( - x, self.qweight, self.group_size, self.scales, self.qzeros - ).reshape(out_shape) + # out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( + # x, self.qweight, self.group_size, self.scales, self.qzeros + # ).reshape(out_shape) # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format # scales + zeros and pass as one tensor - # scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) - # out = torch.ops.aten._weight_int4pack_mm( - # x, self.qweight, self.group_size, scales_and_zeros - # ).reshape(out_shape) + scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + q_uint8 = gptq_int32_to_uint8(self.qweight) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + q_uint8, 8 + ) + print("q_uint8", self.qweight.shape, q_uint8.shape, weight_int4pack.shape) + B_innerKTiles = weight_int4pack.size(3) * 2 + kKTileSize = 16 + k = x.size(1) + print("weight_int4pack",weight_int4pack.shape, weight_int4pack.size(1), k / (B_innerKTiles * kKTileSize)) + print("B_innerKTiles",k , B_innerKTiles, kKTileSize) + out = torch.ops.aten._weight_int4pack_mm( + x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros + ).reshape(out_shape) else: # make sure dequant dtype matches input x weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) From d8d7a0476e94da53c80f6e454b994919674affa1 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 27 Aug 2025 04:29:01 +0000 Subject: [PATCH 2/5] fix qweight int32 to uint8 Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/torch_fused.py | 38 ++++++--------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 46d819214..6ec2de186 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -46,30 +46,6 @@ def pack_scales_and_zeros(scales, zeros): .contiguous() ) -def gptq_int32_to_uint8(qweight: torch.Tensor) -> torch.Tensor: - """ - Convert GPTQ qweight (int32, each element packs 8 int4 values) - into (uint8, each element packs 2 int4 values). - - Input: [n, k_int32] int32 - Output: [n, k_int32 * 4] uint8 # since each int32 becomes 4 uint8 - """ - assert qweight.dtype == torch.int32 - - # Unpack into 8 int4 values - q_unpack = torch.stack([ - (qweight >> (4 * i)) & 0xF for i in range(8) - ], dim=-1) # shape: [n, k_int32, 8] - - # Repack into uint8 (each uint8 holds two int4 values) - q_even = q_unpack[..., 0::2] # [n, k_int32, 4] - q_odd = q_unpack[..., 1::2] # [n, k_int32, 4] - q_uint8 = (q_even | (q_odd << 4)).to(torch.uint8) - - # Reshape to [n, k_uint8], where k_uint8 = k_int32 * 4 - q_uint8 = q_uint8.reshape(qweight.shape[0], -1) - return q_uint8 - class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -227,16 +203,24 @@ def _forward(self, x, out_shape): # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format # scales + zeros and pass as one tensor scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) - q_uint8 = gptq_int32_to_uint8(self.qweight) + + inner_ktiles = 2 + # convert to uint8 + # see https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/tensor/weights/tinygemm/packed.py#L61 + t = self.qweight + q_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + + print("qweight", self.qweight.shape, self.qweight.dtype) weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - q_uint8, 8 + q_uint8, inner_ktiles ) print("q_uint8", self.qweight.shape, q_uint8.shape, weight_int4pack.shape) B_innerKTiles = weight_int4pack.size(3) * 2 kKTileSize = 16 k = x.size(1) - print("weight_int4pack",weight_int4pack.shape, weight_int4pack.size(1), k / (B_innerKTiles * kKTileSize)) + print("weight_int4pack",weight_int4pack.shape,weight_int4pack.dtype, weight_int4pack.size(1), k / (B_innerKTiles * kKTileSize)) print("B_innerKTiles",k , B_innerKTiles, kKTileSize) + print("x",x.shape) out = torch.ops.aten._weight_int4pack_mm( x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros ).reshape(out_shape) From 6074bc7d356211ae34a3b1e42810e1e2da1b169e Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 27 Aug 2025 08:34:36 +0000 Subject: [PATCH 3/5] add compress_scales() Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/torch_fused.py | 65 +++++++++++++++++++-- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 6ec2de186..f3b9aa882 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -28,6 +28,28 @@ log = setup_logger() + +def compress_scales(scales: torch.Tensor, pack_dtype_bits: int, bits: int) -> torch.Tensor: + """ + Compress per-channel scales to match qzeros shape. + + Args: + scales: [num_groups, out_features] + pack_dtype_bits: int, e.g., 32 + bits: int, e.g., 4 + + Returns: + compressed_scales: [num_groups, out_features // (pack_dtype_bits // bits)] + """ + num_groups, out_features = scales.shape + pack_ratio = pack_dtype_bits // bits + assert out_features % pack_ratio == 0, "out_features must be divisible by pack ratio" + + # reshape [num_groups, out_features // pack_ratio, pack_ratio] + reshaped = scales.reshape(num_groups, out_features // pack_ratio, pack_ratio) + compressed = reshaped[..., 0] # take first in each pack + return compressed + # TODO: not yet working for cuda/cpu fused int4 ops def pack_scales_and_zeros(scales, zeros): print("scales", scales.shape, zeros.shape) @@ -46,6 +68,22 @@ def pack_scales_and_zeros(scales, zeros): .contiguous() ) +def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): + # _guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) + # _guard_dtype_size(zeros, "zeros", dtype=dtype) + dim = scales.dim() + return ( + torch.cat( + [ + scales.unsqueeze(-1), + zeros.unsqueeze(-1), + ], + dim, + ) + .transpose(-3, -2) + .contiguous() + ) + class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -187,7 +225,7 @@ def _forward(self, x, out_shape): if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS: # one-time transform per module for xpu aten fused ops print("ssss 1", self.qweight.shape, self.scales.shape, self.qzeros.shape) - self.transform(x.dtype) + # self.transform(x.dtype) print("ssss 2", self.qweight.shape, self.scales.shape, self.qzeros.shape) # raise Exception("Test") self.transformed = True @@ -202,7 +240,10 @@ def _forward(self, x, out_shape): # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format # scales + zeros and pass as one tensor - scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + scales_compressed = compress_scales(self.scales, self.pack_dtype_bits, self.bits) + # scales_and_zeros = pack_scales_and_zeros(scales_compressed, self.qzeros) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales_compressed.transpose(0, 1).contiguous(), self.qzeros.transpose(0, 1).contiguous()) + inner_ktiles = 2 # convert to uint8 @@ -221,9 +262,23 @@ def _forward(self, x, out_shape): print("weight_int4pack",weight_int4pack.shape,weight_int4pack.dtype, weight_int4pack.size(1), k / (B_innerKTiles * kKTileSize)) print("B_innerKTiles",k , B_innerKTiles, kKTileSize) print("x",x.shape) - out = torch.ops.aten._weight_int4pack_mm( - x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros - ).reshape(out_shape) + + kMTileSize = 16 + m = x.size(0) + # like c++ divUp + def div_up(a: int, b: int) -> int: + return (a + b - 1) // b + mTiles = div_up(m, kMTileSize) + kNTileSize = 8 + kNTileSizeTensor = 8 + nTileScaleFactor = (kNTileSize / kNTileSizeTensor) + nTiles = (weight_int4pack.size(0) / nTileScaleFactor) + n = nTiles * kNTileSize + print("qScaleAndZeros", scales_and_zeros.shape, n) + print(x.shape, weight_int4pack.shape, self.group_size, scales_and_zeros.shape) + out = torch.ops.aten._weight_int4pack_mm(x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros) + print("out", out.shape) + out = out.reshape(out_shape) else: # make sure dequant dtype matches input x weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) From e4e0bcdf7209529c1a771c7d133e117f6503b75f Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 28 Aug 2025 05:53:24 +0000 Subject: [PATCH 4/5] fix _weight_int4pack_mm output shape Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/torch_fused.py | 108 +++++--------------- 1 file changed, 23 insertions(+), 85 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index f3b9aa882..31bbda41a 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -28,61 +28,27 @@ log = setup_logger() - -def compress_scales(scales: torch.Tensor, pack_dtype_bits: int, bits: int) -> torch.Tensor: +def gptq_qweight_to_uint8(qweight_int32: torch.Tensor, in_features: int) -> torch.Tensor: + """ + Convert GPTQ qweight (int32, [in//8, out]) into uint8 [out, in]. + Each int32 stores 8x 4-bit weights (nibbles). + Each uint8 will store 2x 4-bit weights: low nibble + high nibble. """ - Compress per-channel scales to match qzeros shape. + packed_k, out_features = qweight_int32.shape + assert packed_k * 8 == in_features, "qweight shape mismatch with in_features" - Args: - scales: [num_groups, out_features] - pack_dtype_bits: int, e.g., 32 - bits: int, e.g., 4 + # Unpack int32 -> [out, in] + qweight_int32 = qweight_int32.permute(1, 0).contiguous() # [out, in//8] + qweight_int32 = qweight_int32.unsqueeze(-1) # [out, in//8, 1] + unpacked = torch.empty((out_features, in_features), dtype=torch.uint8, device=qweight_int32.device) - Returns: - compressed_scales: [num_groups, out_features // (pack_dtype_bits // bits)] - """ - num_groups, out_features = scales.shape - pack_ratio = pack_dtype_bits // bits - assert out_features % pack_ratio == 0, "out_features must be divisible by pack ratio" - - # reshape [num_groups, out_features // pack_ratio, pack_ratio] - reshaped = scales.reshape(num_groups, out_features // pack_ratio, pack_ratio) - compressed = reshaped[..., 0] # take first in each pack - return compressed - -# TODO: not yet working for cuda/cpu fused int4 ops -def pack_scales_and_zeros(scales, zeros): - print("scales", scales.shape, zeros.shape) - # assert scales.shape == zeros.shape - # assert scales.dtype == torch.bfloat16 - # assert zeros.dtype == torch.bfloat16 - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - -def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): - # _guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) - # _guard_dtype_size(zeros, "zeros", dtype=dtype) - dim = scales.dim() - return ( - torch.cat( - [ - scales.unsqueeze(-1), - zeros.unsqueeze(-1), - ], - dim, - ) - .transpose(-3, -2) - .contiguous() - ) + for i in range(8): + nibble = (qweight_int32 >> (i * 4)) & 0xF + unpacked[:, i::8] = nibble.squeeze(-1).to(torch.uint8) + + # Pack 2 nibbles into one uint8: [out, in//2] + packed_uint8 = (unpacked[:, 0::2] & 0xF) | ((unpacked[:, 1::2] & 0xF) << 4) + return packed_uint8.contiguous() class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] @@ -225,7 +191,7 @@ def _forward(self, x, out_shape): if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS: # one-time transform per module for xpu aten fused ops print("ssss 1", self.qweight.shape, self.scales.shape, self.qzeros.shape) - # self.transform(x.dtype) + self.transform(x.dtype) print("ssss 2", self.qweight.shape, self.scales.shape, self.qzeros.shape) # raise Exception("Test") self.transformed = True @@ -240,41 +206,13 @@ def _forward(self, x, out_shape): # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format # scales + zeros and pass as one tensor - scales_compressed = compress_scales(self.scales, self.pack_dtype_bits, self.bits) - # scales_and_zeros = pack_scales_and_zeros(scales_compressed, self.qzeros) - scales_and_zeros = pack_tinygemm_scales_and_zeros(scales_compressed.transpose(0, 1).contiguous(), self.qzeros.transpose(0, 1).contiguous()) + scales_and_zeros = torch.stack([self.scales, self.qzeros], dim=-1).contiguous() - - inner_ktiles = 2 - # convert to uint8 - # see https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/tensor/weights/tinygemm/packed.py#L61 - t = self.qweight - q_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) - - print("qweight", self.qweight.shape, self.qweight.dtype) + q_uint8 = gptq_qweight_to_uint8(self.qweight, self.in_features) weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - q_uint8, inner_ktiles + q_uint8, 8 ) - print("q_uint8", self.qweight.shape, q_uint8.shape, weight_int4pack.shape) - B_innerKTiles = weight_int4pack.size(3) * 2 - kKTileSize = 16 - k = x.size(1) - print("weight_int4pack",weight_int4pack.shape,weight_int4pack.dtype, weight_int4pack.size(1), k / (B_innerKTiles * kKTileSize)) - print("B_innerKTiles",k , B_innerKTiles, kKTileSize) - print("x",x.shape) - - kMTileSize = 16 - m = x.size(0) - # like c++ divUp - def div_up(a: int, b: int) -> int: - return (a + b - 1) // b - mTiles = div_up(m, kMTileSize) - kNTileSize = 8 - kNTileSizeTensor = 8 - nTileScaleFactor = (kNTileSize / kNTileSizeTensor) - nTiles = (weight_int4pack.size(0) / nTileScaleFactor) - n = nTiles * kNTileSize - print("qScaleAndZeros", scales_and_zeros.shape, n) + print(x.shape, weight_int4pack.shape, self.group_size, scales_and_zeros.shape) out = torch.ops.aten._weight_int4pack_mm(x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros) print("out", out.shape) From aae28e7e689685b356b15265de929d2623fb6837 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Fri, 29 Aug 2025 03:42:27 +0000 Subject: [PATCH 5/5] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/torch_fused.py | 87 +++++++++------------ 1 file changed, 39 insertions(+), 48 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 31bbda41a..195e8b334 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -28,28 +28,6 @@ log = setup_logger() -def gptq_qweight_to_uint8(qweight_int32: torch.Tensor, in_features: int) -> torch.Tensor: - """ - Convert GPTQ qweight (int32, [in//8, out]) into uint8 [out, in]. - Each int32 stores 8x 4-bit weights (nibbles). - Each uint8 will store 2x 4-bit weights: low nibble + high nibble. - """ - packed_k, out_features = qweight_int32.shape - assert packed_k * 8 == in_features, "qweight shape mismatch with in_features" - - # Unpack int32 -> [out, in] - qweight_int32 = qweight_int32.permute(1, 0).contiguous() # [out, in//8] - qweight_int32 = qweight_int32.unsqueeze(-1) # [out, in//8, 1] - unpacked = torch.empty((out_features, in_features), dtype=torch.uint8, device=qweight_int32.device) - - for i in range(8): - nibble = (qweight_int32 >> (i * 4)) & 0xF - unpacked[:, i::8] = nibble.squeeze(-1).to(torch.uint8) - - # Pack 2 nibbles into one uint8: [out, in//2] - packed_uint8 = (unpacked[:, 0::2] & 0xF) | ((unpacked[:, 1::2] & 0xF) << 4) - return packed_uint8.contiguous() - class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -62,7 +40,7 @@ class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] # optimized for XPU but should run on all - SUPPORTS_DEVICES = [DEVICE.XPU, DEVICE.CUDA] # change this to XPU to limit to Intel XPU + SUPPORTS_DEVICES = [DEVICE.CUDA] # change this to XPU to limit to Intel XPU SUPPORTS_PLATFORM = [PLATFORM.ALL] SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] @@ -143,6 +121,7 @@ def train(self, mode: bool = True): return super().train(mode=mode) def transform(self, dtype): + print("www 0", self.qweight.shape, self.qweight) self.scales = self.scales.clone().to(dtype).contiguous() # Unpack qzeros zeros = torch.bitwise_right_shift( @@ -151,13 +130,14 @@ def transform(self, dtype): ).to(self.dequant_dtype) zeros = torch.bitwise_and(zeros, self.maxq).reshape(zeros.shape[0], -1) # Unpack and reorder qweight - # weight = torch.bitwise_and( - # torch.bitwise_right_shift( - # torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), - # self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) - # ).to(self.dequant_dtype), - # self.maxq - # ) + weight = torch.bitwise_and( + torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), + self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) + ).to(self.dequant_dtype), + self.maxq + ) + print("www 1", weight.shape, weight) self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) groups = self.g_idx.shape[0] // self.group_size remainder = self.g_idx.shape[0] % self.group_size @@ -168,7 +148,9 @@ def transform(self, dtype): for i in range(groups): g_idx_2[self.g_idx == i] += arange_tensor self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - # weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + print("www 2",weight.shape, weight) + # Pack qweight # packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) # for col in range(weight.shape[1] // self.pack_factor): @@ -176,9 +158,27 @@ def transform(self, dtype): # packed_col = weight[:, col * self.pack_factor + i].to(torch.int32) # packed[:, col] |= packed_col << (i * self.bits) - # self.qweight = packed.contiguous() + out_features, in_features = weight.shape + assert in_features % 2 == 0, "in_features 必须是偶数才能两两pack" + + + low = weight[:, 0::2].to(torch.uint8) # 低 4bit + high = weight[:, 1::2].to(torch.uint8) # 高 4bit + + # [out_features, in_features // 2], dtype=uint8 + packed_uint8 = low | (high << 4) # 每两个int4打包成一个uint8 + + # input is [n][k / 2](uint8 dtype) + # output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2](int32dtype) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + packed_uint8.contiguous(), 2 + ) + + self.qweight = weight_int4pack.contiguous() self.qzeros = zeros.contiguous() + self.scales_and_zeros = torch.stack([self.scales, self.qzeros], dim=-1).contiguous() + def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) @@ -190,14 +190,11 @@ def _forward(self, x, out_shape): if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS: # one-time transform per module for xpu aten fused ops - print("ssss 1", self.qweight.shape, self.scales.shape, self.qzeros.shape) self.transform(x.dtype) - print("ssss 2", self.qweight.shape, self.scales.shape, self.qzeros.shape) - # raise Exception("Test") self.transformed = True if self.transformed: - # x = x[:, self.ret_idx].contiguous() + x = x[:, self.ret_idx].contiguous() # fused ops optimized for xpu using torch.ops # note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only # out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( @@ -206,17 +203,11 @@ def _forward(self, x, out_shape): # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format # scales + zeros and pass as one tensor - scales_and_zeros = torch.stack([self.scales, self.qzeros], dim=-1).contiguous() - - q_uint8 = gptq_qweight_to_uint8(self.qweight, self.in_features) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - q_uint8, 8 - ) - - print(x.shape, weight_int4pack.shape, self.group_size, scales_and_zeros.shape) - out = torch.ops.aten._weight_int4pack_mm(x.to(torch.bfloat16), weight_int4pack, self.group_size, scales_and_zeros) - print("out", out.shape) - out = out.reshape(out_shape) + # scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + print("xxx", self.scales_and_zeros.shape) + out = torch.ops.aten._weight_int4pack_mm( + x, self.qweight, self.group_size, self.scales_and_zeros + ).reshape(out_shape) else: # make sure dequant dtype matches input x weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) @@ -265,4 +256,4 @@ def dequantize_model(model: PreTrainedModel): return model -__all__ = ["TorchFusedQuantLinear", "dequantize_model"] +__all__ = ["TorchFusedQuantLinear", "dequantize_model"] \ No newline at end of file