diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py
index c2a9738a6..195e8b334 100644
--- a/gptqmodel/nn_modules/qlinear/torch_fused.py
+++ b/gptqmodel/nn_modules/qlinear/torch_fused.py
@@ -28,23 +28,6 @@
 log = setup_logger()
 
 
-# TODO: not yet working for cuda/cpu fused int4 ops
-# def pack_scales_and_zeros(scales, zeros):
-#     assert scales.shape == zeros.shape
-#     # assert scales.dtype == torch.bfloat16
-#     # assert zeros.dtype == torch.bfloat16
-#     return (
-#         torch.cat(
-#             [
-#                 scales.reshape(scales.size(0), scales.size(1), 1),
-#                 zeros.reshape(zeros.size(0), zeros.size(1), 1),
-#             ],
-#             2,
-#         )
-#         .transpose(0, 1)
-#         .contiguous()
-#     )
-
 class TorchFusedQuantLinear(PackableQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
@@ -57,7 +40,7 @@ class TorchFusedQuantLinear(PackableQuantLinear):
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]
 
     # optimized for XPU but should run on all
-    SUPPORTS_DEVICES = [DEVICE.XPU]  # change this to XPU to limit to Intel XPU
+    SUPPORTS_DEVICES = [DEVICE.CUDA]  # change this to XPU to limit to Intel XPU
     SUPPORTS_PLATFORM = [PLATFORM.ALL]
     SUPPORTS_PACK_DTYPES = [torch.int32]
     SUPPORTS_ADAPTERS = [Lora]
@@ -138,6 +121,7 @@ def train(self, mode: bool = True):
         return super().train(mode=mode)
 
     def transform(self, dtype):
+        print("www 0", self.qweight.shape, self.qweight)
         self.scales = self.scales.clone().to(dtype).contiguous()
         # Unpack qzeros
         zeros = torch.bitwise_right_shift(
@@ -153,6 +137,7 @@ def transform(self, dtype):
             ).to(self.dequant_dtype),
             self.maxq
         )
+        print("www 1", weight.shape, weight)
         self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device)
         groups = self.g_idx.shape[0] // self.group_size
         remainder = self.g_idx.shape[0] % self.group_size
@@ -164,16 +149,36 @@ def transform(self, dtype):
             g_idx_2[self.g_idx == i] += arange_tensor
         self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype)
         weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t()
+        print("www 2", weight.shape, weight)
+
         # Pack qweight
-        packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device)
-        for col in range(weight.shape[1] // self.pack_factor):
-            for i in range(self.pack_factor):
-                packed_col = weight[:, col * self.pack_factor + i].to(torch.int32)
-                packed[:, col] |= packed_col << (i * self.bits)
+        # packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device)
+        # for col in range(weight.shape[1] // self.pack_factor):
+        #     for i in range(self.pack_factor):
+        #         packed_col = weight[:, col * self.pack_factor + i].to(torch.int32)
+        #         packed[:, col] |= packed_col << (i * self.bits)
+
+        out_features, in_features = weight.shape
+        assert in_features % 2 == 0, "in_features must be even so int4 values can be packed in pairs"
+
-        self.qweight = packed.contiguous()
+        low = weight[:, 0::2].to(torch.uint8)   # low 4 bits
+        high = weight[:, 1::2].to(torch.uint8)  # high 4 bits
+
+        # [out_features, in_features // 2], dtype=uint8
+        packed_uint8 = low | (high << 4)  # pack every two int4 values into one uint8
+
+        # input is [n][k / 2] (uint8 dtype)
+        # output is [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2] (int32 dtype)
+        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+            packed_uint8.contiguous(), 2
+        )
+
+        self.qweight = weight_int4pack.contiguous()
         self.qzeros = zeros.contiguous()
+        self.scales_and_zeros = torch.stack([self.scales, self.qzeros], dim=-1).contiguous()
+
 
     def forward(self, x: torch.Tensor):
         out_shape = x.shape[:-1] + (self.out_features,)
         x = x.reshape(-1, x.shape[-1])
@@ -192,16 +197,17 @@ def _forward(self, x, out_shape):
             x = x[:, self.ret_idx].contiguous()
             # fused ops optimized for xpu using torch.ops
             # note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only
-            out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
-                x, self.qweight, self.group_size, self.scales, self.qzeros
-            ).reshape(out_shape)
+            # out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+            #     x, self.qweight, self.group_size, self.scales, self.qzeros
+            # ).reshape(out_shape)
 
             # TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format
             # scales + zeros and pass as one tensor
             # scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros)
-            # out = torch.ops.aten._weight_int4pack_mm(
-            #     x, self.qweight, self.group_size, scales_and_zeros
-            # ).reshape(out_shape)
+            print("xxx", self.scales_and_zeros.shape)
+            out = torch.ops.aten._weight_int4pack_mm(
+                x, self.qweight, self.group_size, self.scales_and_zeros
+            ).reshape(out_shape)
         else:
             # make sure dequant dtype matches input x
             weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype)
@@ -250,4 +256,4 @@ def dequantize_model(model: PreTrainedModel):
     return model
 
 
-__all__ = ["TorchFusedQuantLinear", "dequantize_model"]
+__all__ = ["TorchFusedQuantLinear", "dequantize_model"]
\ No newline at end of file
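
For reference, below is a minimal standalone sketch of the CUDA int4 call path this patch moves transform()/_forward() onto. It is not part of the patch: the tensor sizes are arbitrary examples, it assumes a PyTorch build where the CUDA _convert_weight_to_int4pack op accepts the pre-packed uint8 [n, k // 2] tensor (as the new comments in the patch state), and it uses placeholder zero points without verifying that GPTQ's integer zero points match the floating-point zero convention _weight_int4pack_mm expects.

import torch

n, k, group_size, inner_k_tiles = 64, 256, 128, 2  # example sizes only
device = "cuda"

# Random int4 weight values in [0, 15], laid out [out_features, in_features],
# mirroring `weight` right before the packing step in transform().
w_int4 = torch.randint(0, 16, (n, k), dtype=torch.uint8, device=device)

# Pack two int4 values per byte the same way the patch does:
# even columns -> low nibble, odd columns -> high nibble.
packed_uint8 = (w_int4[:, 0::2] | (w_int4[:, 1::2] << 4)).contiguous()  # [n, k // 2]

# Repack into the tiled layout consumed by the fused matmul kernel.
w_int4pack = torch.ops.aten._convert_weight_to_int4pack(packed_uint8, inner_k_tiles)

# Group-wise scales and zero points stacked into one [k // group_size, n, 2]
# tensor, analogous to self.scales_and_zeros in the patch (zeros here are
# placeholder values, not real GPTQ zero points).
scales = torch.rand(k // group_size, n, dtype=torch.bfloat16, device=device)
zeros = torch.zeros_like(scales)
scales_and_zeros = torch.stack([scales, zeros], dim=-1).contiguous()

x = torch.randn(8, k, dtype=torch.bfloat16, device=device)
out = torch.ops.aten._weight_int4pack_mm(x, w_int4pack, group_size, scales_and_zeros)
print(out.shape)  # torch.Size([8, 64])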