[core / QLinear] Support CPU inference #376
Conversation
out = torch.matmul(x.to(weights.dtype), weights)

# To support CPU inference
if weight.dtype == torch.float16 and weight.device.type == "cpu":
When does this case arise?
if you run:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
device = "cpu" # for GPU usage or "cpu" for CPU usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)
inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))
with the patch huggingface/transformers#26719 applied to transformers.
@@ -266,7 +266,12 @@ def forward(self, x):
weight = (scales * (weight - zeros))
To add more context, the weights are in fp16 because the scales are in fp16.
I'm not keen on merging this - scales should be in fp32 in the first place.
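For reference, a minimal sketch of the workaround under discussion (the helper name is illustrative, not AutoGPTQ's actual code): upcast the dequantized fp16 weight and the hidden states to fp32 before the matmul when on CPU, since fp16 matmul on CPU is unsupported or slow on many PyTorch builds.

import torch

def cpu_safe_matmul(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the guard shown in this PR's diff.
    if weight.dtype == torch.float16 and weight.device.type == "cpu":
        # Temporary upcast; as noted above, the cleaner fix would be to keep
        # the scales (and hence the dequantized weight) in fp32 from the start.
        return torch.matmul(x.to(torch.float32), weight.to(torch.float32))
    return torch.matmul(x.to(weight.dtype), weight)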
Hi @fxmarty, can we get this in with any changes?
@vivekkhandelwal1 Yes, happy to have it in the next release, but what is proposed in this PR is not the correct solution. What is needed is to correctly dispatch the module parameters / buffers (remove hard-coded …)
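As a rough illustration of that direction (a sketch under assumed buffer names, not AutoGPTQ's actual module), registering the quantization tensors as buffers lets model.to(device) dispatch them, so no device needs to be hard-coded in forward():

import torch
import torch.nn as nn

class QuantLinearSketch(nn.Module):
    # Illustrative only: the buffer names and packing layout are assumptions.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Registered buffers move with the module on .to("cpu") / .to("cuda").
        self.register_buffer("qweight", torch.zeros(in_features // 8, out_features, dtype=torch.int32))
        self.register_buffer("scales", torch.zeros(out_features, dtype=torch.float32))
        self.register_buffer("qzeros", torch.zeros(out_features, dtype=torch.int32))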
Yeah, you're correct. @younesbelkada, can you make the changes accordingly? Otherwise I can do that.
Hi @vivekkhandelwal1, thanks for offering to help, it would be great if you could quickly do that if possible 🙏
Hi @younesbelkada, I don't have push access to your repo, can you please grant me access?
@vivekkhandelwal1 @PanQiWei is the owner and would need to give you that. In the meantime, if you open a PR, I can review it and merge.
Closing as superseded by #385
On par with: huggingface/transformers#26719

This PR simply proposes to temporarily upcast the weights and hidden states to fp32 before performing the matmul when users are on CPU. I can confirm that, with huggingface/transformers#26719 and this PR applied, the script above runs fine on CPU (but is slow).
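For context, a quick way to see why the upcast is needed (behavior varies by PyTorch version: on older CPU builds fp16 matmul raises, on newer ones it is merely slow):

import torch

a = torch.randn(4, 8).half()
b = torch.randn(8, 4).half()
try:
    torch.matmul(a, b)  # historically raised "not implemented for 'Half'" on CPU
    print("fp16 CPU matmul supported on this build")
except RuntimeError as err:
    print("fp16 CPU matmul unsupported:", err)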