Fix result dtype conversion in QuantLinear.forward() #390
Conversation
Fixes: AutoGPTQ#385 (comment)
Signed-Off-By: Vivek Khandelwal <vivek@nod-labs.com>
@fxmarty, can you please review this PR?
```diff
@@ -268,8 +268,8 @@ def forward(self, x: torch.Tensor):
             g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
             weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
         weights = torch.cat(weights, dim=1)
-        out = torch.matmul(x, weights)
-        out = out.to(dtype=weights.dtype).reshape(out_shape)
+        out = torch.matmul(x, weights).to(dtype=weights.dtype)
```
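For context, a minimal sketch of the dequantize-then-matmul path that this hunk touches (the function name, tensor names, and shapes below are illustrative assumptions, not the actual AutoGPTQ implementation):

```python
import torch

def forward_sketch(x, scales, zeros, int_weight, g_idx):
    """Illustrative sketch of the dequantize + matmul flow; not the real
    QuantLinear.forward(), only the shape/dtype handling the diff changes."""
    # Dequantize with per-group scales and zero-points selected via g_idx
    # (int_weight: [in_features, out_features], scales/zeros: [groups, out_features],
    #  g_idx: [in_features] mapping each input feature to its group).
    weights = scales[g_idx.long()] * (int_weight.to(scales.dtype) - zeros[g_idx.long()])
    out_shape = x.shape[:-1] + (weights.shape[-1],)
    # The patched line: cast the matmul result to the weight dtype ...
    out = torch.matmul(x, weights).to(dtype=weights.dtype)
    # ... and reshape to the expected output shape.
    return out.reshape(out_shape)
```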
I don't quite remember why these casts are needed in the first place. Shouldn't the activation and weight be of the same dtype (either both fp32 or both fp16)?
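To illustrate the question, a standalone reproduction (not code from this repository): if the activation and the dequantized weight already share a dtype, the trailing cast is a no-op, while a genuine mismatch fails inside `torch.matmul` before any cast on the result can help:

```python
import torch

x_fp16 = torch.randn(2, 4, dtype=torch.float16)
w_fp16 = torch.randn(4, 8, dtype=torch.float16)
w_fp32 = w_fp16.float()

# Same dtype: matmul works and .to(dtype=weights.dtype) changes nothing.
out = torch.matmul(x_fp16, w_fp16).to(dtype=w_fp16.dtype)
print(out.dtype)  # torch.float16

# Mismatched dtypes: matmul itself raises before the result cast runs.
try:
    torch.matmul(x_fp16, w_fp32)
except RuntimeError as err:
    print(err)  # e.g. a "same dtype" / "expected scalar type Half" error
```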
EDIT: The error is happening because of this change: a7d61ca#diff-c4c2bf0dd8440248a29510131f06affa3c2ab00d1bd7ca507dc0b7125a04f825R20

@fxmarty, I'm getting the following error:

For the following code:

Is this happening because of this commit: bcd1406?
Hi, thank you - superseded by #393. Note this bug in accelerate: huggingface/accelerate#2116