96 changes: 63 additions & 33 deletions gptqmodel/nn_modules/qlinear/torch_fused.py
@@ -13,27 +13,26 @@
from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger
from ...utils.torch import TORCH_HAS_XPU_FUSED_OPS
from ...utils.torch import TORCH_HAS_FUSED_OPS


log = setup_logger()

# TODO: not yet working for cuda/cpu fused int4 ops
# def pack_scales_and_zeros(scales, zeros):
# assert scales.shape == zeros.shape
# # assert scales.dtype == torch.bfloat16
# # assert zeros.dtype == torch.bfloat16
# return (
# torch.cat(
# [
# scales.reshape(scales.size(0), scales.size(1), 1),
# zeros.reshape(zeros.size(0), zeros.size(1), 1),
# ],
# 2,
# )
# .transpose(0, 1)
# .contiguous()
# )
# TODO: CPU works; not yet working for CUDA fused int4 ops
def pack_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
# assert scales.dtype == torch.bfloat16
# assert zeros.dtype == torch.bfloat16
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.contiguous()
)
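The re-enabled helper drops the .transpose(0, 1) that the commented-out version carried, so scales and zeros stay interleaved along a trailing dimension of size 2. A minimal sketch of the resulting layout, with made-up sizes (in_features=128, out_features=64, group_size=32); the dimension names are my reading of the group-wise quantization code, not something stated in the diff:

```python
# Illustrative sketch (not part of this PR): the layout pack_scales_and_zeros produces,
# assuming group-wise 4-bit quantization with in_features=128, out_features=64, group_size=32.
import torch

n_groups, out_features = 128 // 32, 64
scales = torch.rand(n_groups, out_features, dtype=torch.bfloat16)
zeros = torch.zeros(n_groups, out_features, dtype=torch.bfloat16)

scales_and_zeros = torch.cat(
    [scales.reshape(n_groups, out_features, 1), zeros.reshape(n_groups, out_features, 1)],
    2,
).contiguous()
print(scales_and_zeros.shape)  # torch.Size([4, 64, 2]) -- [K // group_size, N, 2], no transpose
```

This [K // group_size, N, 2] tensor is what _forward later hands to the CPU fused matmul via self.scales_and_zeros.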

class TorchFusedQuantLinear(PackableQuantLinear):
SUPPORTS_BITS = [4]
@@ -45,9 +44,7 @@ class TorchFusedQuantLinear(PackableQuantLinear):
SUPPORTS_AUTO_PADDING = True
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

# optimized for XPU but should run on all
SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] # change this to XPU to limit to Intel XPU
SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU]
SUPPORTS_PLATFORM = [PLATFORM.ALL]
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]
@@ -129,7 +126,7 @@ def train(self, mode: bool = True):

return super().train(mode=mode)

def transform(self, dtype):
def transform_xpu(self, dtype):
self.scales = self.scales.clone().to(dtype).contiguous()
# Unpack qzeros
zeros = torch.bitwise_right_shift(
@@ -166,6 +163,39 @@ def transform(self, dtype):
self.qweight = packed.contiguous()
self.qzeros = zeros.contiguous()
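For context on the shift-and-mask unpacking used in transform_xpu above and transform_cpu below: each int32 in qweight/qzeros holds pack_factor = 32 // bits nibbles, and (per the inline comment) wf_unsqueeze_neg_one is the same shift tensor reshaped for broadcasting. A minimal sketch with a made-up packed value:

```python
# Illustrative sketch (not part of this PR): unpacking eight 4-bit values from one int32,
# the same shift-and-mask idea applied to qweight/qzeros (pack_factor = 32 // bits = 8).
import torch

packed = torch.tensor([0x76543210], dtype=torch.int32)  # eight nibbles: 0..7
shifts = torch.arange(0, 32, 4, dtype=torch.int32)      # 0, 4, ..., 28
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed.unsqueeze(-1), shifts), 0xF)
print(unpacked)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
```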

def transform_cpu(self, dtype):
self.scales = self.scales.clone().to(dtype).contiguous()
# Unpack and reorder qweight
weight = torch.bitwise_and(
torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1)
).to(self.dequant_dtype),
self.maxq
)
self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device)
groups = self.g_idx.shape[0] // self.group_size
remainder = self.g_idx.shape[0] % self.group_size
g_idx_2 = self.g_idx * self.group_size
if remainder > 0:
g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(g_idx_2.device).to(g_idx_2.dtype)
arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype)
for i in range(groups):
g_idx_2[self.g_idx == i] += arange_tensor
self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype)
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t()
self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous()
self.qzeros = torch.zeros_like(self.scales).contiguous()
self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros)
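The ret_idx block above turns an act-order g_idx (channel-to-group mapping in arbitrary order) into a permutation that makes every quantization group contiguous, which the repacked CPU weight layout assumes; _forward applies the same permutation to the activations. A toy sketch of the idea, with group_size=2 and a four-channel g_idx I made up:

```python
# Illustrative sketch (not part of this PR): building the ret_idx permutation that makes each
# quantization group contiguous, assuming an act-order g_idx and group_size=2 for brevity.
import torch

group_size = 2
g_idx = torch.tensor([1, 0, 1, 0])   # channel -> group, out of order (act-order)
n = g_idx.shape[0]

dest = g_idx * group_size            # base slot of each channel's group
for g in range(n // group_size):
    dest[g_idx == g] += torch.arange(int((g_idx == g).sum()))  # rank within the group

ret_idx = torch.empty(n, dtype=torch.long)
ret_idx[dest] = torch.arange(n)      # inverse permutation: slot -> original channel

print(ret_idx)                       # tensor([1, 3, 0, 2]); group 0 channels first, then group 1
# At inference the activations are gathered the same way: x = x[:, ret_idx]
```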

def transform(self, dtype, device):
if device == "xpu":
self.transform_xpu(dtype)
elif device == "cpu":
self.transform_cpu(dtype)
else:
raise NotImplementedError

def forward(self, x: torch.Tensor):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
@@ -175,25 +205,25 @@ def forward(self, x: torch.Tensor):
def _forward(self, x, out_shape):
num_itr = self.g_idx.shape[0] // x.shape[-1]

if not self.training and not self.transformed and TORCH_HAS_XPU_FUSED_OPS and "xpu" == x.device.type:
if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS:
# one-time transform per module for xpu/cpu aten fused ops
self.transform(x.dtype)
self.transform(x.dtype, x.device.type)
self.transformed = True

if self.transformed:
x = x[:, self.ret_idx].contiguous()
# fused ops optimized for xpu using torch.ops
# note _weight_int4pack_mm_with_scales_and_zeros is added by intel for xpu only
out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
x, self.qweight, self.group_size, self.scales, self.qzeros
).reshape(out_shape)

# TODO: torch.ops _weight_int4pack_mm has fused aten op for int4 matmul but we need to transform and align format
# scales + zeros and pass as one tensor
# scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros)
# out = torch.ops.aten._weight_int4pack_mm(
# x, self.qweight, self.group_size, scales_and_zeros
# ).reshape(out_shape)
if x.device.type == "xpu":
out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
x, self.qweight, self.group_size, self.scales, self.qzeros
).reshape(out_shape)
elif x.device.type == "cpu":
out = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.qweight, self.group_size, self.scales_and_zeros
).reshape(out_shape)
else:
raise NotImplementedError
else:
# make sure dequant dtype matches input x
weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype)
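Taken together, the torch_fused.py changes route CPU inference through two aten ops: _convert_weight_to_int4pack_for_cpu at transform time and _weight_int4pack_mm_for_cpu at matmul time. A self-contained sketch of that flow with made-up shapes (M=4, K=64, N=32, group_size=32); it assumes a PyTorch build new enough to ship these CPU ops, which is what the TORCH_HAS_FUSED_OPS gate in gptqmodel/utils/torch.py below checks:

```python
# Illustrative sketch (not part of this PR): the CPU fused int4 path, end to end, with made-up shapes.
import torch

M, K, N, group_size = 4, 64, 32, 32
x = torch.rand(M, K, dtype=torch.bfloat16)

# int4 weight held as int32 values in [0, 15], shape [N, K], packed for the CPU kernel
w_int4 = torch.randint(0, 16, (N, K), dtype=torch.int32)
w_packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int4, 1)

# scales/zeros packed as [K // group_size, N, 2], the layout pack_scales_and_zeros produces
scales = torch.rand(K // group_size, N, dtype=torch.bfloat16)
zeros = torch.zeros(K // group_size, N, dtype=torch.bfloat16)
scales_and_zeros = torch.cat([scales.unsqueeze(-1), zeros.unsqueeze(-1)], dim=-1).contiguous()

out = torch.ops.aten._weight_int4pack_mm_for_cpu(x, w_packed, group_size, scales_and_zeros)
print(out.shape)  # torch.Size([4, 32])
```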
2 changes: 1 addition & 1 deletion gptqmodel/utils/torch.py
@@ -21,7 +21,7 @@
# pytorch 2.6.0 fixes many compilation errors
TORCH_HAS_COMPILE = version.parse(torch.__version__).release >= version.Version('2.6').release
TORCH_GTE_28 = version.parse(torch.__version__).release >= version.Version('2.8').release
TORCH_HAS_XPU_FUSED_OPS = version.parse(torch.__version__).release >= version.Version('2.8').release
TORCH_HAS_FUSED_OPS = version.parse(torch.__version__).release >= version.Version('2.8').release
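Comparing only the .release tuples means pre-release and nightly builds of 2.8 also satisfy the gate, where a full Version comparison would reject them. A quick sketch (the version strings are made up):

```python
# Illustrative sketch (not part of this PR): why the flags compare .release tuples.
from packaging import version

print(version.parse("2.8.0.dev20250415").release >= version.Version("2.8").release)  # True
print(version.parse("2.8.0.dev20250415") >= version.Version("2.8"))                  # False
```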

HAS_CUDA = False
HAS_XPU = False