[core/ QLinear] Support CPU inference #376

Closed
wants to merge 2 commits

Conversation

younesbelkada

Goes together with: huggingface/transformers#26719

This PR proposes to temporarily upcast the weights and hidden states to fp32 before performing the matmul when users are on CPU. I can confirm that, with both huggingface/transformers#26719 and this PR applied, the following script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

runs fine on CPU (but is slow).
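
For illustration, here is a minimal sketch of the upcast-before-matmul idea. This is not the exact diff in this PR; the function name and the cast-back step are assumptions.

import torch

def matmul_with_cpu_fallback(x, weight):
    # Sketch only: if the dequantized weight is fp16 on CPU, temporarily upcast
    # both operands to fp32 (fp16 matmul on CPU is unsupported or very slow in
    # many PyTorch versions), then cast the result back to fp16.
    if weight.dtype == torch.float16 and weight.device.type == "cpu":
        out = torch.matmul(x.to(torch.float32), weight.to(torch.float32))
        return out.to(torch.float16)
    return torch.matmul(x.to(weight.dtype), weight)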

@younesbelkada (Author)

cc @PanQiWei @fxmarty @TheBloke

out = torch.matmul(x.to(weights.dtype), weights)

# To support CPU inference
if weight.dtype == torch.float16 and weight.device.type == "cpu":
Collaborator

When does this case arise?

Author

This case arises if you run the script from the PR description above with the patch huggingface/transformers#26719 applied on transformers.

@@ -266,7 +266,12 @@ def forward(self, x):
weight = (scales * (weight - zeros))
Contributor

To add more context, the weights are in fp16 because the scales are in fp16.
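
As a quick illustration of that dtype propagation (illustrative tensors only, not the actual QuantLinear buffers):

import torch

int_weight = torch.randint(0, 16, (4, 4), dtype=torch.int32)  # unpacked 4-bit values
zeros = torch.full((4, 4), 8, dtype=torch.int32)
scales = torch.rand(4, 4).to(torch.float16)                   # scales stored in fp16

# Integer * fp16 promotes to fp16, so the dequantized weight inherits the scales' dtype.
weight = scales * (int_weight - zeros)
print(weight.dtype)  # torch.float16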

@fxmarty (Collaborator) commented Oct 24, 2023

I'm not keen on merging this - scales should be in fp32 in the first place.

@vivekkhandelwal1 (Contributor) commented Oct 26, 2023

I'm not keen on merging this - scales should be in fp32 in the first place.

Hi @fxmarty, can we get this in with any changes?

@fxmarty (Collaborator) commented Oct 26, 2023

@vivekkhandelwal1 Yes, happy to have it in the next release, but what is proposed in this PR is not the correct solution. What is needed is to dispatch the module parameters / buffers correctly (remove the hard-coded .to(torch.float16) in the init).
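
For what it's worth, a rough sketch of that direction, assuming a simplified QuantLinear-style module (class name, buffer names, and shapes are illustrative, not the actual auto-gptq code):

import torch
import torch.nn as nn

class QuantLinearSketch(nn.Module):
    def __init__(self, bits, infeatures, outfeatures):
        super().__init__()
        # Register packed weights and quantization parameters as buffers and let the
        # usual .to(device) / .to(dtype) / .half() calls dispatch them, instead of
        # hard-coding .to(torch.float16) in the init.
        self.register_buffer("qweight", torch.zeros((infeatures * bits // 32, outfeatures), dtype=torch.int32))
        self.register_buffer("qzeros", torch.zeros((1, outfeatures * bits // 32), dtype=torch.int32))
        self.register_buffer("scales", torch.zeros((1, outfeatures)))  # default fp32; no forced fp16 cast

With the scales left to the caller's dtype dispatch, a CPU model kept in fp32 would dequantize straight into fp32 and never hit the fp16 matmul path, while calling .half() on a GPU model would still move the scales to fp16.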

@vivekkhandelwal1 (Contributor)

@vivekkhandelwal1 Yes, happy to have it in the next release, but what is proposed in this PR is not the correct solution. What is needed is to dispatch the module parameters / buffers correctly (remove the hard-coded .to(torch.float16) in the init).

Yeah, you're correct. @younesbelkada, can you make the changes accordingly? Otherwise, I can do that.

@younesbelkada (Author)

Hi @vivekkhandelwal1, thanks for offering to help. It would be great if you could do that quickly, if possible 🙏

@vivekkhandelwal1 (Contributor)

Hi @younesbelkada, I don't have push access to your repo. Can you please grant me that?

@fxmarty (Collaborator) commented Oct 27, 2023

@vivekkhandelwal1 @PanQiWei is the owner and would need to give you that. In the meantime, if you open a PR I can review and merge it.

@vivekkhandelwal1 (Contributor)

@vivekkhandelwal1 @PanQiWei is the owner and would need to give you that. In the meantime, if you open a PR I can review and merge it.

@fxmarty Please review this PR: #385

@fxmarty (Collaborator) commented Oct 31, 2023

Closing as superseded by #385

fxmarty closed this on Oct 31, 2023.