Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fix for CPU Inference #385

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auto_gptq/modeling/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def make_quant(
out_features = tmp.weight.shape[1]
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
new_layer = QuantLinear(
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable, weight_dtype=tmp.weight.dtype
)
else:
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable, weight_dtype=tmp.weight.dtype)
new_layer.device = ori_layer_device
setattr(module, attr, new_layer.to(ori_layer_device))
for name1, child in module.named_children():
Expand Down
17 changes: 9 additions & 8 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(
outfeatures,
bias,
kernel_switch_threshold=128,
trainable=False
trainable=False,
weight_dtype=torch.float16,
):
super().__init__()
global _autogptq_cuda_available
Expand All @@ -55,14 +56,14 @@ def __init__(
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=weight_dtype)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer('bias', torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None

Expand Down Expand Up @@ -105,9 +106,9 @@ def pack(self, linear, scales, zeros, g_idx=None):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
self.scales = scales.clone().to(dtype=linear.weight.dtype)
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)

intweight = []
for idx in range(self.infeatures):
Expand Down Expand Up @@ -267,10 +268,10 @@ def forward(self, x: torch.Tensor):
g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
weights = torch.cat(weights,dim=1)
out = torch.matmul(x.to(weights.dtype), weights)
out = out.half().reshape(out_shape)
out = torch.matmul(x, weights)
out = out.to(dtype=weights.dtype).reshape(out_shape)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems broken in master at the moment:

File "/opt/miniconda3/envs/text-gen-gptq/lib/python3.10/site-packages/auto_gptq/nn_modules/qlinear/qlinear_cuda.py", line 272, in forward
out = out.to(dtype=weights.dtype).reshape(out_shape)
UnboundLocalError: local variable 'weights' referenced before assignment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @chuyqa, thanks for pointing out this issue. I have added a fix here: #390.

out = out + self.bias if self.bias is not None else out
return out.to(x.dtype)
return out


__all__ = ["QuantLinear"]
17 changes: 9 additions & 8 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(
bias,
use_cuda_fp16=True,
kernel_switch_threshold=128,
trainable=False
trainable=False,
weight_dtype=torch.float16,
):
super().__init__()
global _autogptq_cuda_available
Expand All @@ -54,15 +55,15 @@ def __init__(
)
self.register_buffer(
'scales',
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=weight_dtype)
)
self.register_buffer(
'g_idx',
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
)

if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer('bias', torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None
self.half_indim = self.infeatures // 2
Expand Down Expand Up @@ -105,9 +106,9 @@ def pack(self, linear, scales, zeros, g_idx):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
self.scales = scales.clone().to(dtype=linear.weight.dtype)
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)

intweight = []
for idx in range(self.infeatures):
Expand Down Expand Up @@ -267,10 +268,10 @@ def forward(self, x):
weight = (scales * (weight - zeros))
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

out = torch.matmul(x.to(weight.dtype), weight)
out = out.half().reshape(out_shape)
out = torch.matmul(x, weight)
out = out.to(dtype=weight.dtype).reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out.to(x_dtype)
return out


__all__ = ["QuantLinear"]