68 changes: 37 additions & 31 deletions gptqmodel/nn_modules/qlinear/torch_fused.py
@@ -28,23 +28,6 @@

log = setup_logger()

# TODO: not yet working for cuda/cpu fused int4 ops
# def pack_scales_and_zeros(scales, zeros):
# assert scales.shape == zeros.shape
# # assert scales.dtype == torch.bfloat16
# # assert zeros.dtype == torch.bfloat16
# return (
# torch.cat(
# [
# scales.reshape(scales.size(0), scales.size(1), 1),
# zeros.reshape(zeros.size(0), zeros.size(1), 1),
# ],
# 2,
# )
# .transpose(0, 1)
# .contiguous()
# )

class TorchFusedQuantLinear(PackableQuantLinear):
SUPPORTS_BITS = [4]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
@@ -57,7 +40,7 @@ class TorchFusedQuantLinear(PackableQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

# optimized for XPU but should run on all
SUPPORTS_DEVICES = [DEVICE.XPU] # change this to XPU to limit to Intel XPU
SUPPORTS_DEVICES = [DEVICE.CUDA] # change this to XPU to limit to Intel XPU
SUPPORTS_PLATFORM = [PLATFORM.ALL]
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]
@@ -138,6 +121,7 @@ def train(self, mode: bool = True):
return super().train(mode=mode)

def transform(self, dtype):
print("www 0", self.qweight.shape, self.qweight)
self.scales = self.scales.clone().to(dtype).contiguous()
# Unpack qzeros
zeros = torch.bitwise_right_shift(
@@ -153,6 +137,7 @@ def transform(self, dtype):
).to(self.dequant_dtype),
self.maxq
)
print("www 1", weight.shape, weight)
self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device)
groups = self.g_idx.shape[0] // self.group_size
remainder = self.g_idx.shape[0] % self.group_size
@@ -164,16 +149,36 @@
g_idx_2[self.g_idx == i] += arange_tensor
self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype)
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t()
print("www 2",weight.shape, weight)

# Pack qweight
packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device)
for col in range(weight.shape[1] // self.pack_factor):
for i in range(self.pack_factor):
packed_col = weight[:, col * self.pack_factor + i].to(torch.int32)
packed[:, col] |= packed_col << (i * self.bits)
# packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device)
# for col in range(weight.shape[1] // self.pack_factor):
# for i in range(self.pack_factor):
# packed_col = weight[:, col * self.pack_factor + i].to(torch.int32)
# packed[:, col] |= packed_col << (i * self.bits)

out_features, in_features = weight.shape
assert in_features % 2 == 0, "in_features must be even to pack two int4 values per byte"


self.qweight = packed.contiguous()
low = weight[:, 0::2].to(torch.uint8)  # low 4 bits
high = weight[:, 1::2].to(torch.uint8)  # high 4 bits

# [out_features, in_features // 2], dtype=uint8
packed_uint8 = low | (high << 4)  # pack each pair of int4 values into one uint8

# input is [n][k / 2] (uint8 dtype)
# output is [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2] (int32 dtype)
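# innerKTiles is 2 below; upstream aten accepts 2, 4, or 8 for this op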
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
packed_uint8.contiguous(), 2
)

self.qweight = weight_int4pack.contiguous()
self.qzeros = zeros.contiguous()

self.scales_and_zeros = torch.stack([self.scales, self.qzeros], dim=-1).contiguous()
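
As an aside, a minimal self-contained sketch of the packing performed above (assuming a CUDA build of PyTorch recent enough that aten._convert_weight_to_int4pack accepts a uint8 [n][k / 2] tensor; all sizes below are toy values, not the module's real shapes):

import torch

n, k = 32, 64  # toy out_features / in_features; k even and divisible by innerKTiles * 16
w = torch.randint(0, 16, (n, k), dtype=torch.uint8, device="cuda")  # stand-in int4 values in [0, 15]
low, high = w[:, 0::2], w[:, 1::2]
packed_uint8 = low | (high << 4)  # [n, k // 2], two int4 values per byte
# round trip to confirm the nibble layout
assert torch.equal(packed_uint8 & 0xF, low)
assert torch.equal(packed_uint8 >> 4, high)
# tile the packed bytes into the fused-kernel layout (innerKTiles=2, as in the diff)
qweight = torch.ops.aten._convert_weight_to_int4pack(packed_uint8.contiguous(), 2)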

def forward(self, x: torch.Tensor):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
@@ -192,16 +197,17 @@ def _forward(self, x, out_shape):
x = x[:, self.ret_idx].contiguous()
# fused ops optimized for xpu using torch.ops
# note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only
out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
x, self.qweight, self.group_size, self.scales, self.qzeros
).reshape(out_shape)
# out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
# x, self.qweight, self.group_size, self.scales, self.qzeros
# ).reshape(out_shape)

# TODO: torch.ops.aten._weight_int4pack_mm is the fused int4 matmul op, but the weight
# format must be transformed/aligned and scales + zeros passed as a single tensor
# scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros)
# out = torch.ops.aten._weight_int4pack_mm(
# x, self.qweight, self.group_size, scales_and_zeros
# ).reshape(out_shape)
print("xxx", self.scales_and_zeros.shape)
out = torch.ops.aten._weight_int4pack_mm(
x, self.qweight, self.group_size, self.scales_and_zeros
).reshape(out_shape)
else:
# make sure dequant dtype matches input x
weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype)
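
For reference, a minimal end-to-end sketch of the fused CUDA path taken in the branch above (assuming aten._weight_int4pack_mm is available; group size, shapes, and dtypes are illustrative, and the scales/zeros are random placeholders rather than real quantization parameters):

import torch

n, k, group_size = 32, 64, 32
w = torch.randint(0, 16, (n, k), dtype=torch.uint8, device="cuda")
packed_uint8 = w[:, 0::2] | (w[:, 1::2] << 4)  # [n, k // 2]
qweight = torch.ops.aten._convert_weight_to_int4pack(packed_uint8.contiguous(), 2)
scales = torch.rand(k // group_size, n, dtype=torch.bfloat16, device="cuda")
zeros = torch.zeros_like(scales)
scales_and_zeros = torch.stack([scales, zeros], dim=-1).contiguous()  # [k // group_size, n, 2]
x = torch.randn(4, k, dtype=torch.bfloat16, device="cuda")
out = torch.ops.aten._weight_int4pack_mm(x, qweight, group_size, scales_and_zeros)  # [4, n]
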
@@ -250,4 +256,4 @@ def dequantize_model(model: PreTrainedModel):
return model


__all__ = ["TorchFusedQuantLinear", "dequantize_model"]
__all__ = ["TorchFusedQuantLinear", "dequantize_model"]