Falcon support #111
Conversation
quant code:

```python
import os
import random

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "tiiuae/falcon-7b"
quantized_model_dir = "falcon-7b-4bit-128g"

# os.makedirs(quantized_model_dir, exist_ok=True)


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    # set seeds for reproducible sampling
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # load dataset and preprocess
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # draw nsamples random windows of length seqlen as calibration examples
    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({'input_ids': inp, 'attention_mask': attention_mask})
    return traindataset, testenc


def main():
    # Falcon only ships a fast tokenizer, so fall back to use_fast=True if the slow one fails
    try:
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

    quantize_config = BaseQuantizeConfig(
        bits=4,          # quantize model to 4-bit
        group_size=64,   # it is recommended to set the value to 128
        desc_act=False,  # desc_act and group_size only work with the triton backend
    )

    # load un-quantized model; the model is always force-loaded onto CPU
    model = AutoGPTQForCausalLM.from_pretrained(
        pretrained_model_dir,
        quantize_config,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )

    # get model maximum sequence length from the config
    model_config = model.config.to_dict()
    seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
    if any(k in model_config for k in seq_len_keys):
        for key in seq_len_keys:
            if key in model_config:
                model.seqlen = model_config[key]
                break
    else:
        model.seqlen = 2048

    # load train dataset for quantization
    traindataset, testenc = get_wikitext2(128, 0, model.seqlen, tokenizer)

    # quantize model; the examples should be a list of dicts whose keys contain
    # "input_ids" and "attention_mask" with torch.LongTensor values
    model.quantize(traindataset, use_triton=False)

    # save quantized model
    model.save_quantized(quantized_model_dir)

    # save quantized model using safetensors
    model.save_quantized(quantized_model_dir, use_safetensors=True)

    # load quantized model; currently only CPU or a single GPU is supported
    model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir,
        device="cuda:0",
        use_triton=False,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

    # quick generation sanity check; Falcon's model code does not accept token_type_ids
    token = tokenizer("test is", return_tensors="pt").to("cuda:0")
    del token['token_type_ids']
    print(tokenizer.decode(model.generate(**token).cpu().tolist()[0]))


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
```
|
Thank you very much for such a fast implementation to support Falcon! 🥳 Here is my question: does this PR only focus on falcon-7b, or will falcon-40b also be considered in this PR? |
I think both will work. But currently only tested on 7b. |
I see the model_type in config.json of the 40b and 7b models is different: in 40b it's RefinedWeb while in 7b it's RefinedWebModel, so maybe both need to be added in auto-gptq's relevant code in order to support both models. (A quick way to check this is sketched below.) |
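A minimal check of the model_type difference mentioned above, using only the public transformers API (the repo IDs are the ones from this thread):

```python
from transformers import AutoConfig

# Print the model_type string each Falcon config reports; AutoGPTQ picks its
# wrapper class based on this value, so both strings would need an entry.
for repo in ("tiiuae/falcon-7b", "tiiuae/falcon-40b"):
    cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    print(repo, cfg.model_type)  # expected: RefinedWebModel for 7B, RefinedWeb for 40B
```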
This is the current draft. |
Hi, I just converted this PR to draft mode based on this information; you can convert it back to ready-for-review mode once everything is done. |
Confirmed that 7B works. 40B results in OOM on my hardware, but it otherwise seems to work. |
Amazing @qwopqwop200 thank you! I am trying it now |
Will merge, thank you very much!
The Falcon datatype is torch.bfloat16, so why did you use float32 to load the model? Also, it's a huge dataset (2.8T), but we probably only need the testing part. |
This is just code to make sure it works. |
I can confirm that the model loads fine. I have made the 7B model no problem, and it works well: https://huggingface.co/TheBloke/falcon-7b-instruct-GPTQ
I am having a bit of trouble making 40B. As @qwopqwop200 found, it uses a lot of VRAM: it peaks at around 32GB, so a 24GB card is no good. So I made it on an A6000 with 48GB. It took over 2.5 hours to quantise all the layers... then this happened!!!
Out of RAM! Argh! :) I had 167GB RAM, so that was a hell of a lot of RAM it needed. I am trying again now on a server with an L40 40GB GPU and 250GB RAM. I hope it will be enough! I am also going to try low_cpu_mem_usage, roughly as sketched below, to see if this helps. What do you think @qwopqwop200 @PanQiWei? Will that reduce RAM requirements for quantizing, do you think? Any other ideas? |
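A rough sketch of the low_cpu_mem_usage idea; it assumes AutoGPTQ's from_pretrained forwards extra keyword arguments to transformers' model loading, so treat it as illustrative rather than the exact call used in the thread:

```python
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(bits=4, group_size=-1, desc_act=False)

# low_cpu_mem_usage=True asks transformers to initialise the model on meta tensors
# and load checkpoint shards one by one, which should reduce peak host RAM at load time.
model = AutoGPTQForCausalLM.from_pretrained(
    "tiiuae/falcon-40b-instruct",
    quantize_config,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,  # assumption: passed through to AutoModelForCausalLM.from_pretrained
)
```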
How do you properly set up the device map for the Hugging Face from_pretrained method, i.e. set the maximum VRAM for each device and the maximum CPU RAM? |
If there's a large enough disk for it on the machine, you can make a big swapfile for it. I got it running on 64GB of RAM that way (then ran into the VRAM needs you mentioned). Also, this was using a modified version of GPTQ-for-LLaMa so it might not be relevant, but since Falcon's model file uses its own version of nn.Linear, if you add their Linear class to the list of layer types to quantize, it packs down a bit smaller (see the sketch below). That could be unstable, though. |
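A hypothetical illustration of the layer-collection idea described above; this is a generic find_layers-style helper, not the actual GPTQ-for-LLaMa or AutoGPTQ code, and FalconLinear in the usage comment is a stand-in for whatever Linear class Falcon's remote modelling code defines:

```python
import torch.nn as nn

def collect_quantizable_layers(module, layer_types=(nn.Linear,), prefix=""):
    # Exact type matching (rather than isinstance) mirrors why a model's own
    # Linear subclass has to be listed explicitly before it gets quantized/packed.
    if type(module) in layer_types:
        return {prefix: module}
    found = {}
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        found.update(collect_quantizable_layers(child, layer_types, child_prefix))
    return found

# Usage idea (FalconLinear is hypothetical here):
# layers = collect_quantizable_layers(model, layer_types=(nn.Linear, FalconLinear))
```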
Hmm, yeah, I guess I should have tried that. I'm running in a Docker container and am not 100% sure I'm able to add swap. But it'd be worth a try. Oh well, I'll know soon enough. It's on layer 56 of 60 of quantizing.. fingers crossed 250GB RAM will be enough to pack it! I can see it's using 130GB during the quantising phase, so I just hope it doesn't need 2x that to pack... |
OK I tried it and I guess it can't be done in a docker :(
Just crossing my fingers it's not going to die in about 2 mins.. |
Doing better!
And only using 144GB RAM so maybe low_cpu_mem_usage=True did help |
Hehe, this is going to take forever. It took 17 minutes to pack the first 6 layers, so it looks like it'll take around 3 hours to do the whole thing - much longer than it took to quantise! @PanQiWei @qwopqwop200 One feature I would really love to see is GPU acceleration for packing. I know it might be difficult though. I looked at the code once and saw it references uint32, which isn't available in torch yet? |
@TheBloke Did you try 128 group size? For me the quant succeeded, but evaluation failed due to a layer size mismatch. After changing back to 64 group size, everything goes fine (and matches your quant result for the 7B model). |
Oh interesting. No, I did not. I tried 64 first and it worked fine, so I left it at that. I am using group_size = -1 (no group size) for 40B though, based on past experience with Llama models, to reduce VRAM usage as much as possible for models of 30B or greater. |
@qwopqwop200 Maybe a stupid question: how did you manage to pull off large-model GPTQ quantization like 65B without OOM in your GPTQ-for-LLaMa repo? AutoGPTQ seems to always load more weight onto cuda:0 during quantization (cuda branch), and ends up failing with OOM at some point on my dual 24GB VRAM setup. @PanQiWei I also found that memory-limit settings (something along the lines of the sketch below) can be passed into AutoGPTQ's from_pretrained method; torch actually reserves the correct amount of VRAM and truly offloads the model to CPU (shown by the output log), but they are ignored during quantization. The VRAM on GPU 0 goes up until OOM even with these size settings. |
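A sketch of the kind of call being described; the max_memory values are placeholders, and the exact arguments the commenter used are not shown in the thread:

```python
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(bits=4, group_size=-1, desc_act=False)

# Cap each GPU and the CPU; accelerate uses this to build a device map at load time.
max_memory = {0: "20GiB", 1: "20GiB", "cpu": "100GiB"}  # placeholder sizes

model = AutoGPTQForCausalLM.from_pretrained(
    "tiiuae/falcon-40b",
    quantize_config,
    max_memory=max_memory,  # honored while loading, but reportedly ignored during quantization
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
```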
Update: the 40B model worked and is uploaded at https://huggingface.co/TheBloke/falcon-40b-instruct-GPTQ
Even with group_size -1, it requires a bit more than 24GB VRAM, which is a shame. But the main problem is that it is REALLY slow. So is the 7B model. Example:
Is there any possibility of improving this performance? I suppose it's mostly because of the custom code provided by RefinedWeb. But is there any chance there are optimisations that could be made in the AutoGPTQ code related to RefinedWeb? |
😮 |
4-bit QLoRA might be the end: https://www.reddit.com/r/LocalLLaMA/comments/13uvbxe/testing_the_new_bnb_4bit_or_qlora_vs_gptq_cuda/. It is supposed to be even faster than GPTQ in the near future, while being able to do both inferencing and fine-tuning. |
I get an error when I run the quant code for the 'falcon-7b-instruct' model: torch._C._LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 16967 is not positive-definite). Have you met the same error? I cannot find a solution.... code below:
|
@ltm920716 Yes, I have had this problem in the past. Two solutions I have found: increase the number of calibration samples, or raise the damp percent. I see you are using 128 samples of wikitext2, so I'm surprised you got this error, but it can likely be solved either by using 256 samples, or by setting damp percent to 0.1, or both. Specify damp percent roughly as in the sketch below. |
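A sketch of setting damp_percent via BaseQuantizeConfig; damp_percent is the Hessian damping factor GPTQ applies before the Cholesky factorization (AutoGPTQ's default is 0.01), and raising it makes the factorization more robust:

```python
from auto_gptq import BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=64,
    desc_act=False,
    damp_percent=0.1,  # up from the 0.01 default; helps avoid the non-positive-definite error
)
```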
By the way, I already quantised Falcon 7B Instruct with AutoGPTQ + Wikitext2, here: https://huggingface.co/TheBloke/falcon-7b-instruct-GPTQ - so you could just use that! |
Hello @TheBloke. Maybe I need to dive deep into studying the GPTQ paper to understand the reason for this error. |
Add falcon
Added dtype. This was added because Falcon currently does not support float16.
Also, the input dimension of 7B is not divisible by 256, so Triton is not supported. This is a problem to be addressed later. (A quick way to check this is sketched below.)
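A quick check of the divisibility constraint mentioned above, using only the model config; the hidden_size attribute name is assumed to be what Falcon's custom config exposes:

```python
from transformers import AutoConfig

# Falcon-7B reports a hidden size of 4544 = 71 * 64, which is a multiple of 64
# but not of 256, so Triton kernels requiring 256-divisible dimensions cannot be used.
cfg = AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
print(cfg.hidden_size, cfg.hidden_size % 256, cfg.hidden_size % 64)
```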