
Support QLoRA 4-bit finetuning with bitsandbytes #275

Merged (39 commits merged into Lightning-AI:main on Aug 21, 2023)

Conversation

@patrickhwood (Contributor) commented Jul 18, 2023

Added a --quantize command-line option. Allowed values are "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", and "bnb.fp4-dq". The GPTQ int4 format supports only inference, not training.
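For context, here is a minimal sketch of what these four modes correspond to in bitsandbytes: "nf4" vs. "fp4" selects the 4-bit data type, and the "-dq" suffix enables double quantization of the quantization constants. This is an illustration only, not the PR's actual implementation; the helper `make_4bit_linear` is made up.

```python
# Illustrative only: how the four quantize modes map onto bitsandbytes' Linear4bit.
# `make_4bit_linear` is a hypothetical helper, not part of this PR.
import torch
import bitsandbytes as bnb


def make_4bit_linear(in_features: int, out_features: int, mode: str) -> torch.nn.Module:
    mode = mode.removeprefix("bnb.")          # "bnb.nf4-dq" -> "nf4-dq"
    quant_type, _, dq = mode.partition("-")   # "nf4-dq" -> ("nf4", "dq")
    return bnb.nn.Linear4bit(
        in_features,
        out_features,
        bias=False,
        compute_dtype=torch.bfloat16,         # weights are dequantized to this dtype in the forward pass
        compress_statistics=(dq == "dq"),     # "-dq" = double quantization of the quantization constants
        quant_type=quant_type,                # "nf4" or "fp4"
    )


layer = make_4bit_linear(4096, 4096, "bnb.nf4-dq")  # weights are quantized when moved to CUDA: layer.cuda()
```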

Note that fabric.init_module needed empty_init=True for bnb.int8 training; otherwise a RuntimeError: "normal_kernel_cuda" not implemented for 'Char' is thrown.
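A sketch of that workaround, assuming Lightning Fabric's init_module API; the model construction and import path follow lit-gpt conventions but are illustrative, not the exact code in this PR.

```python
# Sketch only: create the module with uninitialized (empty) weights so that random
# init (normal_) never runs on quantized parameter dtypes, which would otherwise
# raise: RuntimeError: "normal_kernel_cuda" not implemented for 'Char'
import lightning as L
from lit_gpt import GPT, Config  # lit-gpt model/config classes (assumed import path)

fabric = L.Fabric(devices=1, precision="bf16-true")
fabric.launch()

with fabric.init_module(empty_init=True):
    model = GPT(Config.from_name("Llama-2-7b-hf"))
# ...then load the real checkpoint afterwards instead of relying on random init
```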

Only lora.py was extensively tested (see #242 (comment) for some results). adapter.py was only tested to the point of running about 1K iterations for each of the quantization types.

Closes #277
Closes #242
Closes #240
Closes #165
Closes #207
Closes #198
Fixes #176

Not sure if the tests themselves should be updated instead.
@carmocca (Member)

Nice. I wanted to get this done this week. Do you mind if I push commits to your branch directly?

@patrickhwood (Contributor, Author)

> Nice. I wanted to get this done this week. Do you mind if I push commits to your branch directly?

Sure, go ahead. I'm not familiar with the test framework, so I don't know how to resolve the check errors.

@patrickhwood (Contributor, Author)

> Nice. I wanted to get this done this week. Do you mind if I push commits to your branch directly?

I just merged your latest changes with my branch and tested bnb.int8 and bnb.nf4-dq with Llama 2. Do you want me to push it?

@aniketmaurya (Contributor)

> I just merged your latest changes with my branch and tested bnb.int8 and bnb.nf4-dq with Llama 2. Do you want me to push it?

cc: @carmocca

@rasbt (Collaborator) commented Aug 8, 2023

Sorry, but would it be possible to prioritize this among the other PRs, @carmocca? It is very relevant for the NeurIPS competition and people are requesting it 😅

@carmocca (Member)

@rasbt You should be able to proceed now

@rasbt (Collaborator) commented Aug 15, 2023

Thanks to the latest PRs, it works (again)! The performance of the non-quantized runs is also not impacted. Will post a table with the latest numbers later today once I have all the results.

@rasbt (Collaborator) commented Aug 15, 2023

Here are the results for the fixed quantized runs:

| Setting | Iterations | Paged optimizer | Training Memory | Training Time | Loss |
|---|---|---|---|---|---|
| Default (bfloat16-mixed) | 20,000 | No | 33.51 GB | 5205.75s | 3.2899 |
| --quantize "bnb.nf4" | 20,000 | No | 22.35 GB | 6046.32s | 3.3615 |
| --quantize "bnb.nf4-dq" | 20,000 | No | 22.19 GB | 5997.37s | 3.4891 |
| --quantize "bnb.nf4" | 20,000 | Yes | 22.35 GB | 6012.66s | 3.4583 |
| --quantize "bnb.nf4-dq" | 20,000 | Yes | 22.19 GB | 6015.96s | 3.4139 |
| --precision "bf16-true" | 20,000 | No | 15.86 GB | 4806.86s | 3.2664 |
| --precision "bf16-true" --quantize "bnb.nf4" | 20,000 | Yes | 14.82 GB | 5543.79s | 3.1176 |
| --precision "bf16-true" --quantize "bnb.nf4-dq" | 20,000 | Yes | 15.73 GB | 4742.72s | 0.7459 |
  1. The loss does go down a lot; it starts at 10-12. What's a bit strange is that quantization doesn't help much with memory and speed compared to bf16-true. Maybe the advantage is stronger when comparing against 32-mixed.

  2. The other observation is that the double-quantization run has a very low loss; it looks like an outlier. I am running it again to confirm.

  3. Inference is really bad. For

python generate/lora.py  --lora_path out/lora/bf16-true-nf4-dq/lit_model_lora_finetuned.pth --precision "bf16-true" --quantize "bnb.nf4-dq" 

for example, the timing looks good but the generated text is gobbledygook:

  • speed: 37.76 tokens/sec
  • memory: 3.08 GB
  • text: VB responded parted MiningvaeWBLyku tigerب Ath Shel Numbers citesExpandcliffeTy CAN insultingJackaldreviewedκ tossed selecting

Same for bnb.nf4 without double quantization. Now, I wanted to run without the quantize flags:

python generate/lora.py  --lora_path out/lora/bf16-true-nf4-dq/lit_model_lora_finetuned.pth --precision "bf16-true" 

or

python generate/lora.py  --lora_path out/lora/bf16-true-nf4-dq/lit_model_lora_finetuned.pth

but it results in Unexpected key(s) in state_dict:. Is this expected to happen? Or is this maybe due to the two recent PRs that updated things?

@rasbt (Collaborator) commented Aug 15, 2023

I was able to reproduce the row at the bottom. Very interesting. Inference also looks normal now:

Lamas eat a wide variety of foods from apples to seafood. They are usually vegetarian, but they also eat plenty of fruit, vegetables, nuts, lean meats, and other animal-based foods. Lamas are known for their distinctive flavors, so it is not uncommon to see them at sushi bars or restaurants.

The previous issues could have been due to stale results or to an issue from merging main; I'm not sure.

Other than that, I think everything is pretty good now, except that the memory savings are not as large as I expected. Any thoughts?

@patrickhwood (Contributor, Author) commented Aug 15, 2023

I ran fine-tuning for 5000 iterations on Llama-2-7b-hf with the Alpaca dataset and micro_batch_size = 1 on an RTX 4090 (24 GB; can't run unquantized with larger settings):

--precision bf16-true:
Training time: 337.50s
Memory used: 21.27 GB

--precision bf16-true --quantize bnb.nf4:
Training time: 808.14s
Memory used: 13.38 GB

Significant difference in both speed and memory usage, which is what I've seen with quantized QLoRA models all along across various frameworks. Currently using bitsandbytes 0.40.0.post4; I'll upgrade to 0.41.1 once my bnb.nf4-dq run completes.

@rasbt (Collaborator) commented Aug 15, 2023

Thanks for this! This gives me hope!

I am using 0.41.1. Please let me know what you find.

Another thing is that I just used the default model, which is StableLM 3B. Let me rerun all the experiments tomorrow with a different model.

@patrickhwood (Contributor, Author)

--precision bf16-true --quantize bnb.nf4-dq:
Training time: 822.48s
Memory used: 13.07 GB

I tried newer bitsandbytes versions, but from 0.40.1 on I got runtime errors claiming libcudart.so could not be found, even though it is in /usr/local/cuda/lib64, which is in my LD_LIBRARY_PATH. There's a note in the bitsandbytes changelog for 0.40.1 about relying on the PyTorch CUDA libraries in CUDA SETUP, and another in 0.40.2 about handling a missing LD_LIBRARY_PATH, so I'll need to dig into this more tomorrow.

@patrickhwood (Contributor, Author)

Same results with bnb 0.41.1. Needed to set BNB_CUDA_VERSION=122 for CUDA version 12.2, in case anyone has a similar problem.
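For anyone hitting the same libcudart.so issue, here is a small sketch of pinning the bitsandbytes CUDA runtime from Python before the first import; it has the same effect as exporting BNB_CUDA_VERSION=122 in the shell.

```python
# Sketch: BNB_CUDA_VERSION must be set before bitsandbytes runs its CUDA setup,
# i.e. before the first `import bitsandbytes`. "122" corresponds to CUDA 12.2.
import os

os.environ.setdefault("BNB_CUDA_VERSION", "122")

import bitsandbytes as bnb  # noqa: E402  (imported after the env var on purpose)
```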

@rasbt (Collaborator) commented Aug 16, 2023

Happy to report that I am getting similar results now, @Andrei-Aksionov. I think StableLM 3B was a bad test case (I reran everything and can confirm the earlier results). Bottom line: the advantage is more obvious for 7B models.

Now:

  • How do we deal with that in the documentation? Should we just use Llama 2 as an example instead of StableLM 3B? @carmocca (The problem is we have to reduce the micro-batch size, which requires a change to the script.)
  • The loss after 10k iterations is still bad (around 3). I need to do a full 50k run to see whether the model actually produces good results.

Other than that, I think things look good.

| Model | Settings | Iterations | Micro batch size | Paged optimizer | Training Memory | Training Time |
|---|---|---|---|---|---|---|
| StableLM 3B (default) | Default (bfloat16-mixed) | 10,000 | 4 (default) | No | 33.50 GB | 1707.06s |
| StableLM 3B (default) | --quantize "bnb.nf4" | 10,000 | 4 (default) | Yes | 22.35 GB | 2227.86s |
| StableLM 3B (default) | --quantize "bnb.nf4-dq" | 10,000 | 4 (default) | Yes | 22.19 GB | 2217.41s |
| StableLM 3B (default) | --precision "bf16-true" | 10,000 | 4 (default) | No | 15.86 GB | 1551.36s |
| StableLM 3B (default) | --precision "bf16-true" --quantize "bnb.nf4" | 10,000 | 4 (default) | Yes | 14.82 GB | 1945.42s |
| StableLM 3B (default) | --precision "bf16-true" --quantize "bnb.nf4-dq" | 10,000 | 4 (default) | Yes | 14.66 GB | 1945.77s |

| Model | Settings | Iterations | Micro batch size | Paged optimizer | Training Memory | Training Time |
|---|---|---|---|---|---|---|
| Falcon 7B | Default (bfloat16-mixed) | 10,000 | 1 | No | torch.cuda.OutOfMemoryError | N/A |
| Falcon 7B | --quantize "bnb.nf4" | 10,000 | 1 | Yes | 19.64 GB | 2965.31s |
| Falcon 7B | --quantize "bnb.nf4-dq" | 10,000 | 1 | Yes | 19.29 GB | 3030.29s |
| Falcon 7B | --precision "bf16-true" | 10,000 | 1 | No | 22.93 GB | 1924.77s |
| Falcon 7B | --precision "bf16-true" --quantize "bnb.nf4" | 10,000 | 1 | Yes | 15.65 GB | 2492.93s |
| Falcon 7B | --precision "bf16-true" --quantize "bnb.nf4-dq" | 10,000 | 1 | Yes | 15.29 GB | 2571.58s |

| Model | Settings | Iterations | Micro batch size | Paged optimizer | Training Memory | Training Time |
|---|---|---|---|---|---|---|
| Llama 2 7B | Default (bfloat16-mixed) | 10,000 | 1 | No | torch.cuda.OutOfMemoryError | N/A |
| Llama 2 7B | --quantize "bnb.nf4" | 10,000 | 1 | Yes | 19.96 GB | 2973.71s |
| Llama 2 7B | --quantize "bnb.nf4-dq" | 10,000 | 1 | Yes | 19.65 GB | 3021.13s |
| Llama 2 7B | --precision "bf16-true" | 10,000 | 1 | No | 20.66 GB | 2101.01s |
| Llama 2 7B | --precision "bf16-true" --quantize "bnb.nf4" | 10,000 | 1 | Yes | 13.52 GB | 2503.53s |
| Llama 2 7B | --precision "bf16-true" --quantize "bnb.nf4-dq" | 10,000 | 1 | Yes | 13.21 GB | 2593.44s |

@carmocca (Member)

That's very thorough! Awesome work, Seb.

I suggest using falcon-7b; you can just document any hyperparameter changes. This is what I did in https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md.

At some point we should replace StableLM as the default. A related idea is #142.

@patrickhwood (Contributor, Author)

You may want to revisit #286. The micro_batch_size setting has a significant impact on memory usage.

@Andrei-Aksionov (Collaborator)

Thanks @rasbt for the table, I can imagine it took a while to compile.
There are some things I understand and some I don't:

  1. Quantization brings less memory improvement for Falcon-7B than for Llama-7B: that makes total sense, as Falcon uses grouped queries, which means the combined QKV matrix is smaller than in Llama.
  2. Quantized training takes more time: that also makes sense, since during the forward pass the quantized weights are dequantized to the compute dtype (NF4 to bfloat16), and that takes some time (if I understood the paper/code correctly).
  3. What I don't understand is that the benefit of quantization is less pronounced than I anticipated, especially considering that in Hugging Face's notebook GPT-NeoX-20B is fine-tuned on a T4 with 15 GB of VRAM.
    The biggest difference I see is that the notebook uses PagedAdamW8bit while you used PagedAdamW.
    Since "Adam has 2 additional optimizer parameters (a mean and a variance) for each model parameter" (which I learnt from your article, btw 👍), maybe you want to run some quick tests one more time (see the optimizer sketch below)? 😄
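A minimal sketch of the swap suggested in point 3, assuming the optimizer is constructed inside the finetuning script; the tiny `model` and the hyperparameters here are placeholders, not the script's actual values.

```python
# Sketch: swap the 32-bit paged AdamW for the 8-bit paged variant so the two Adam
# state tensors (mean and variance) per trainable parameter are stored in 8 bit.
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(8, 8)  # stand-in for the LoRA-wrapped model in finetune/lora.py

# 32-bit paged AdamW (what the runs above used):
# optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# 8-bit paged AdamW (the proposed comparison):
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=3e-4, weight_decay=0.01)
```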

@rasbt (Collaborator) commented Aug 17, 2023

@Andrei-Aksionov

Thanks for the feedback. And absolutely, now that we have these baselines, PagedAdamW8bit is something I want to try.

Computational performance is one thing; modeling performance is the other important aspect, so I am looking into that as well.

@rasbt (Collaborator) commented Aug 17, 2023

@patrickhwood

> You may want to revisit #286. The micro_batch_size setting has a significant impact on memory usage.

Yes, I agree, it makes a big difference. But as you can see in the tables above, I kept it consistent within each suite of runs. The reason I started with 4 for the StableLM 3B models is that it's the default setting.

Then I decreased it to 1 for the 7B models, because otherwise it would not work with most settings.

@patrickhwood (Contributor, Author)

@rasbt

> You may want to revisit #286. The micro_batch_size setting has a significant impact on memory usage.

> Yes, I agree, it makes a big difference. But as you can see in the tables above, I kept it consistent within each suite of runs. The reason I started with 4 for the StableLM 3B models is that it's the default setting.

> Then I decreased it to 1 for the 7B models, because otherwise it would not work with most settings.

I've also tried adjusting it from 1 to 4 with different models and quantization to find the best fit for each combination. Setting it on the command line would simplify automated runs on a matrix of models and settings.
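A hypothetical sketch of what that could look like, following the jsonargparse-style CLIs that the lit-gpt scripts use; the `setup` signature here is illustrative only, not the actual script's.

```python
# Hypothetical: expose micro_batch_size (currently a module-level constant in the
# finetune scripts) as a CLI argument, so a matrix of models/settings can be scripted.
from typing import Optional


def setup(
    precision: str = "bf16-true",
    quantize: Optional[str] = None,   # e.g. "bnb.nf4", "bnb.nf4-dq"
    micro_batch_size: int = 4,        # proposed new argument
) -> None:
    print(precision, quantize, micro_batch_size)  # a real script would launch training here


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(setup)
```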

@rasbt (Collaborator) commented Aug 18, 2023

I think there was a bug in the lit-gpt/lora.py file. Results (loss) look much better now. Will share new results later today.

@rasbt (Collaborator) commented Aug 18, 2023

OK, here is a new batch of results. The 8-bit optimizer didn't make a huge difference, but things work well overall regarding inference performance etc. I think you can get bigger performance differences with bigger models. This should be good to merge, imho, unless there are any issues you find or suggestions you have.

| Model | Settings | Iterations | Dataset | Micro batch size | Paged optimizer | Training Memory | Training Time | Loss | Inference perf |
|---|---|---|---|---|---|---|---|---|---|
| StableLM 3B (default) | Default (bfloat16-mixed) | 5,000 | Alpaca | 4 | No | 33.50 GB | 591.78s | 0.9207 | Time for inference: 1.07 sec total, 68.35 tokens/sec. Memory used: 7.61 GB |
| StableLM 3B (default) | --precision bf16-true | 5,000 | Alpaca | 4 | No | 15.86 GB | 592.14s | 0.9180 | Time for inference: 1.33 sec total, 71.48 tokens/sec. Memory used: 7.61 GB |
| StableLM 3B (default) | --quantize "bnb.nf4" | 5,000 | Alpaca | 4 | 8Bit | 22.34 GB | 944.93s | 0.9417 | Time for inference: 0.78 sec total, 44.88 tokens/sec. Memory used: 3.25 GB |
| StableLM 3B (default) | --quantize "bnb.nf4-dq" | 5,000 | Alpaca | 4 | 8Bit | 22.18 GB | 962.23s | 0.9383 | Time for inference: 0.70 sec total, 34.11 tokens/sec. Memory used: 3.08 GB |
| StableLM 3B (default) | --precision "bf16-true" --quantize "bnb.nf4" | 5,000 | Alpaca | 4 | 8Bit | 14.81 GB | 802.02s | 0.9408 | Time for inference: 1.09 sec total, 50.46 tokens/sec. Memory used: 3.25 GB |
| StableLM 3B (default) | --precision "bf16-true" --quantize "bnb.nf4-dq" | 5,000 | Alpaca | 4 | 8Bit | 14.65 GB | 802.94s | 0.9384 | Time for inference: 2.01 sec total, 49.29 tokens/sec. Memory used: 3.08 GB |

| Model | Settings | Iterations | Dataset | Micro batch size | Paged optimizer | Training Memory | Training Time | Loss | Inference perf |
|---|---|---|---|---|---|---|---|---|---|
| Llama 2 7B | Default (bfloat16-mixed) | 5,000 | Alpaca | 1 | No | N/A | N/A | N/A | N/A |
| Llama 2 7B | --precision bf16-true | 5,000 | Alpaca | 1 | No | 20.60 GB | 876.30s | 0.8696 | Time for inference: 0.37 sec total, 16.28 tokens/sec. Memory used: 13.82 GB |
| Llama 2 7B | --quantize "bnb.nf4" | 5,000 | Alpaca | 1 | 8Bit | 19.62 GB | 1320.63s | 1.0178 | Time for inference: 0.64 sec total, 10.89 tokens/sec. Memory used: 4.66 GB |
| Llama 2 7B | --quantize "bnb.nf4-dq" | 5,000 | Alpaca | 1 | 8Bit | 19.32 GB | 1359.10s | 1.0132 | Time for inference: 2.62 sec total, 2.29 tokens/sec. Memory used: 4.34 GB |
| Llama 2 7B | --precision "bf16-true" --quantize "bnb.nf4" | 5,000 | Alpaca | 1 | 8Bit | 13.44 GB | 1089.79s | 1.0130 | Time for inference: 1.33 sec total, 23.32 tokens/sec. Memory used: 4.66 GB |
| Llama 2 7B | --precision "bf16-true" --quantize "bnb.nf4-dq" | 5,000 | Alpaca | 1 | 8Bit | 13.15 GB | 1135.86s | 1.0124 | Time for inference: 0.58 sec total, 10.34 tokens/sec. Memory used: 4.34 GB |

@carmocca (Member)

Merging 🚀

@carmocca merged commit 064fd52 into Lightning-AI:main on Aug 21, 2023 (5 checks passed)
@Andrei-Aksionov (Collaborator)

It was surprisingly difficult to push this PR over the line.

Great job everyone!

@rasbt (Collaborator) commented Aug 21, 2023

Awesome, exciting that this is finally merged! Thanks for getting this started and for all your help, @patrickhwood, @carmocca, and @Andrei-Aksionov!
