Support QLoRA 4-bit finetuning with bitsandbytes #275

Merged: 39 commits, Aug 21, 2023

Commits
abe1e29 Fix LoRA Linear calls (carmocca, Jul 13, 2023)
a476b9b Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 13, 2023)
c51d5cb Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 14, 2023)
53166bb Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 18, 2023)
5f5b627 Enable fine tuning with QLORA 4-bit floating point and bitsandbytes i… (efii, Jul 18, 2023)
2cd401a Made quantize an optional parameter to main to pass some tests. (efii, Jul 18, 2023)
b2d47f0 Merge github.com:Lightning-AI/lit-gpt (efii, Jul 20, 2023)
e3ff068 Merge branch 'main' into qlora (efii, Jul 20, 2023)
eb917c6 Merge branch 'main' into qlora (rasbt, Aug 9, 2023)
efc13c0 move all hparam prints to main for consistency (rasbt, Aug 9, 2023)
b75b428 fix unit tests (rasbt, Aug 9, 2023)
b05f3a2 Merge branch 'main' into qlora (rasbt, Aug 9, 2023)
5b46d4c Update lora.py (rasbt, Aug 9, 2023)
b26b734 switch to empty_init=True (rasbt, Aug 9, 2023)
df8d7ca add note about qlora-style quantization (rasbt, Aug 9, 2023)
356a53e add qlora test (rasbt, Aug 9, 2023)
ae7141f merge (rasbt, Aug 9, 2023)
9d9d4a5 fix merge conflict (rasbt, Aug 9, 2023)
71af29d revert to empty_init=False and add adapter tests (rasbt, Aug 9, 2023)
96e0310 run quantize test only on gpu (rasbt, Aug 9, 2023)
56f49f8 use paged optimizer and update tests (rasbt, Aug 10, 2023)
364dcbf incorporate carlos suggestions (rasbt, Aug 10, 2023)
c9f55e9 Merge branch 'main' into qlora (rasbt, Aug 10, 2023)
d7695a2 revert adapter (rasbt, Aug 10, 2023)
651250c Merge branch 'qlora' of https://github.com/patrickhwood/lit-gpt into … (rasbt, Aug 10, 2023)
faf4c18 fix merge screwup (rasbt, Aug 10, 2023)
626d2e7 Merge branch 'main' into qlora (carmocca, Aug 11, 2023)
bcd351b Minor test change to skip (carmocca, Aug 11, 2023)
c3ca58f Reference qlora (carmocca, Aug 11, 2023)
49b28a8 Update finetune/lora.py (carmocca, Aug 11, 2023)
e6ede61 Merge branch 'main' into qlora (rasbt, Aug 14, 2023)
86a6255 Pin latest bnb (carmocca, Aug 15, 2023)
661419e Merge branch 'main' into qlora (carmocca, Aug 15, 2023)
a672a65 Formatting (carmocca, Aug 15, 2023)
f98af7f Merge branch 'main' into qlora (carmocca, Aug 15, 2023)
2bc66dc fix inference script (rasbt, Aug 18, 2023)
78fd897 update docs (rasbt, Aug 18, 2023)
0a5ac5d fmt (carmocca, Aug 21, 2023)
288b4ee Merge branch 'main' into qlora (carmocca, Aug 21, 2023)
2 changes: 1 addition & 1 deletion README.md
@@ -193,7 +193,7 @@ More details about each finetuning method and how you can apply it to your own d
These technical tutorials illustrate how to run the finetuning code.

- [Finetune with Adapters](tutorials/finetune_adapter.md)
- [Finetune with LoRA](tutorials/finetune_lora.md)
- [Finetune with LoRA or QLoRA](tutorials/finetune_lora.md)

 

22 changes: 17 additions & 5 deletions finetune/lora.py
@@ -2,7 +2,7 @@
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple

import lightning as L
import torch
@@ -23,6 +23,7 @@
    get_default_supported_precision,
    lazy_load,
    num_parameters,
    quantization,
    step_csv_logger,
)
from scripts.prepare_alpaca import generate_prompt
@@ -63,11 +64,17 @@ def setup(
    out_dir: Path = Path("out/lora/alpaca"),
    precision: Optional[str] = None,
    tpu: bool = False,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
):
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    fabric_devices = devices
    if fabric_devices > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. "
                "Please set devices=1 when using the --quantize flag."
            )
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            fabric_devices = "auto"
@@ -85,10 +92,10 @@ def setup(
    logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir, quantize)


def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, quantize: Optional[str] = None):
    check_valid_checkpoint_dir(checkpoint_dir)

    speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -117,7 +124,7 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
    )
    checkpoint_path = checkpoint_dir / "lit_model.pth"
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    with fabric.init_module(empty_init=False):
    with fabric.init_module(empty_init=False), quantization(quantize):
        model = GPT(config)
    with lazy_load(checkpoint_path) as checkpoint:
        # strict=False because missing keys due to LoRA weights not contained in state dict
@@ -129,7 +136,12 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
    fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    if quantize and quantize.startswith("bnb."):
        import bitsandbytes as bnb

        optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    fabric.seed_everything(1337 + fabric.global_rank)
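The `quantization(quantize)` context manager wrapped around `GPT(config)` above is what enables the 4-bit path: while the model is being instantiated, linear layers are created as bitsandbytes 4-bit layers instead of regular `torch.nn.Linear`. Below is a rough, illustrative sketch of that general pattern, not the actual `lit_gpt.utils.quantization` implementation; it assumes `bnb.nn.Linear4bit` accepts `quant_type` and `compress_statistics` arguments, as recent bitsandbytes releases do:

```python
# Illustrative sketch only; the real lit_gpt.utils.quantization may be implemented differently.
import contextlib
from typing import Optional

import torch
import bitsandbytes as bnb


@contextlib.contextmanager
def quantization_sketch(mode: Optional[str] = None):
    """While active, newly created torch.nn.Linear layers become 4-bit bitsandbytes layers."""
    if mode is None:
        yield
        return
    quant_type = "nf4" if mode.startswith("bnb.nf4") else "fp4"
    double_quant = mode.endswith("-dq")  # "bnb.nf4-dq" / "bnb.fp4-dq"

    class _Linear4bit(bnb.nn.Linear4bit):
        def __init__(self, in_features, out_features, bias=True, **kwargs):
            super().__init__(
                in_features,
                out_features,
                bias=bias,
                quant_type=quant_type,
                compress_statistics=double_quant,
                **kwargs,
            )

    original_linear = torch.nn.Linear
    torch.nn.Linear = _Linear4bit  # models built inside the "with" block pick this up
    try:
        yield
    finally:
        torch.nn.Linear = original_linear
```

Usage would mirror the diff: `with fabric.init_module(empty_init=False), quantization_sketch("bnb.nf4"): model = GPT(config)`.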
65 changes: 65 additions & 0 deletions tests/test_lora.py
@@ -354,6 +354,71 @@ def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merge
    assert layer.merged == expected_merged


@pytest.mark.skipif(not torch.cuda.is_available(), reason="8bit requires CUDA")
# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4
@pytest.mark.xfail(raises=AttributeError, strict=False)
def test_lora_merge_with_quantize():
    from quantize.bnb import _BITSANDBYTES_AVAILABLE

    if not _BITSANDBYTES_AVAILABLE:
        pytest.skip("BNB not available")

    from lit_gpt.lora import GPT, Config, mark_only_lora_as_trainable, merge_lora_weights
    from lit_gpt.utils import quantization
    from quantize.bnb import bnb

    config = Config(
        n_layer=1,
        n_head=2,
        n_embd=8,
        block_size=8,
        vocab_size=8,
        r=8,
        alpha=8,
        dropout=0.1,
        to_query=True,
        to_value=True,
        to_projection=True,
    )
    fabric = Fabric(devices=1, precision="bf16-mixed")
    with fabric.init_module(empty_init=False), quantization("bnb.nf4"):
        model = GPT(config)
        model.apply(model._init_weights)

    optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1.0)
    model, optimizer = fabric.setup(model, optimizer)

    model.train()

    initial_weight = model.transformer.h[0].attn.proj.weight.clone()
    assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

    # perform an update to the LoRA weights
    mark_only_lora_as_trainable(model)

    y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device))
    y.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    # the weight remains unchanged (only lora A and B change)
    assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

    # calling merge() multiple times in a row should not merge multiple times
    merge_lora_weights(model)
    assert model.transformer.h[0].attn.attn.merged
    weight_after = model.transformer.h[0].attn.proj.weight.clone()
    merge_lora_weights(model)
    merge_lora_weights(model)
    assert torch.equal(model.transformer.h[0].attn.proj.weight, weight_after)

    # check that `W_after == W_initial + (B @ A) * scaling`
    a = model.transformer.h[0].attn.proj.lora_A
    b = model.transformer.h[0].attn.proj.lora_B
    scaling = model.transformer.h[0].attn.proj.scaling
    delta_w = (b @ a) * scaling
    torch.testing.assert_close(weight_after, initial_weight + delta_w)


@pytest.mark.parametrize(
    ("mode", "expected"),
    (
38 changes: 36 additions & 2 deletions tutorials/finetune_lora.md
@@ -1,4 +1,4 @@
# Finetuning with LoRA
# Finetuning with LoRA / QLoRA

[Low-rank adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique that approximates the updates to the linear layers in an LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
We demonstrate this method by instruction-finetuning Lit-GPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
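For intuition, here is a minimal, illustrative sketch of the idea behind a LoRA linear layer (a hypothetical `LoRALinearSketch`, not the `lit_gpt.lora` implementation, which additionally handles merging, dropout, and the fused QKV projection):

```python
import math

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """A frozen pretrained weight plus a trainable low-rank update (B @ A) * scaling."""

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 8):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        for p in self.linear.parameters():
            p.requires_grad_(False)  # the pretrained weight and bias stay frozen
        self.lora_A = nn.Parameter(torch.randn(r, in_features) / math.sqrt(r))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no change at step 0
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen path plus the trainable low-rank path
        return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
```

This is exactly the invariant the new test in this PR checks: after `merge_lora_weights`, the merged weight equals the initial weight plus `(B @ A) * scaling`.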
@@ -38,6 +38,40 @@ This script will save checkpoints periodically to the folder `out/`.
> According to the [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): "LoRA on all linear transformer block layers are required to match full finetuning performance".
> By default, LoRA is applied only to the `query` and `value` matrices. To apply LoRA to other weight matrices, change the corresponding variables in `finetune/lora.py` (a sketch follows below).
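As a hedged illustration of that note, using the `Config` field names from the test added in this PR (the hyperparameter names exposed by `finetune/lora.py` may differ), enabling LoRA for the attention output projection in addition to query and value might look like:

```python
from lit_gpt.lora import GPT, Config, mark_only_lora_as_trainable

# Toy-sized config mirroring the test in this PR; LoRA is enabled for query, value,
# and the attention output projection instead of only query/value.
config = Config(
    n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8,
    r=8, alpha=8, dropout=0.1,
    to_query=True, to_value=True, to_projection=True,
)
model = GPT(config)
mark_only_lora_as_trainable(model)  # freeze everything except the LoRA A/B matrices
```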

Optionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:

```bash
python finetune/lora.py --quantize "bnb.nf4"
```

and optionally with double-quantization:

```bash
python finetune/lora.py --quantize "bnb.nf4-dq"
```
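Here `nf4` stores the frozen weights in the 4-bit NormalFloat format, and the `-dq` variant additionally quantizes the per-block quantization constants themselves ("double quantization") for a small extra memory saving. As a hedged sketch of how the two options roughly map onto bitsandbytes layers (assuming the `quant_type` and `compress_statistics` arguments of recent `bnb.nn.Linear4bit` versions):

```python
import bitsandbytes as bnb

# "bnb.nf4": 4-bit NormalFloat weights, quantization constants left unquantized
nf4_linear = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4", compress_statistics=False)

# "bnb.nf4-dq": 4-bit NormalFloat weights plus double-quantized constants
nf4_dq_linear = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4", compress_statistics=True)
```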

The table below compares different settings for a StableLM 3B model finetuned with LoRA on Alpaca for 5,000 iterations with a microbatch size of 4:

| Settings | Training Memory | Training Time | Loss | Inference Memory |
|---------------------------------------------------------|------------------|----------------|-----------|------------------|
| Default (bfloat16-mixed) | 33.50 GB | 591.78s | 0.9207 | 7.61 GB |
| --precision "bf16-true" | 15.86 GB | 592.14s | 0.9180 | 7.61 GB |
| --quantize "bnb.nf4" | 22.34 GB | 944.93s | 0.9417 | 3.25 GB |
| --quantize "bnb.nf4-dq" | 22.18 GB | 962.23s | 0.9383 | 3.08 GB |
| --precision "bf16-true" --quantize "bnb.nf4" | 14.81 GB | 802.02s | 0.9408 | 3.25 GB |
| --precision "bf16-true" --quantize "bnb.nf4-dq" | 14.65 GB | 802.94s | 0.9384 | 3.08 GB |

The advantages of QLoRA-style quantization are more pronounced in larger models, such as Llama 2 7B. The table below summarizes the results for Llama 2 7B on Alpaca for 5,000 iterations using a microbatch size of 4:

| Settings | Training Memory | Training Time | Loss | Inference Memory |
|-----------------------------------------------------|------------------|----------------|--------|------------------|
| Default (bfloat16-mixed) | OutOfMemoryError | N/A | N/A | N/A |
| --precision "bf16-true" | 20.60 GB | 876.30s | 0.8696 | 13.82 GB |
| --quantize "bnb.nf4" | 19.62 GB | 1320.63s | 1.0178 | 4.66 GB |
| --quantize "bnb.nf4-dq" | 19.32 GB | 1359.10s | 1.0132 | 4.34 GB |
| --precision "bf16-true" --quantize "bnb.nf4" | 13.44 GB | 1089.79s | 1.0130 | 4.66 GB |
| --precision "bf16-true" --quantize "bnb.nf4-dq" | 13.15 GB | 1135.86s | 1.0124 | 4.34 GB |

## Test the model

You can test the finetuned model with your own instructions by running:
@@ -52,7 +86,7 @@ Output:
I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
```

If your GPU supports `bfloat16`, you can additionally pass `--precision bf16-true` to bring the memory consumption down to ~11 GB for StableLM-3B.
If your GPU supports `bfloat16`, you can additionally pass `--precision "bf16-true"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision "32-full"`). In addition, you can combine this with quantization; for example, `--precision "bf16-true" --quantize "bnb.nf4"` brings the memory consumption further down to ~4.4 GB for StableLM-3B.

## Tune on your dataset

2 changes: 1 addition & 1 deletion tutorials/quantize.md
@@ -5,7 +5,7 @@ This document provides different strategies for quantizing the various models av
**All the examples below were run on an A100 40GB GPU.**

> [!NOTE]\
> Quantization is only supported with inference (generate and chat scripts).
> Quantization also supports finetuning via [QLoRA](finetune_lora.md)


## Baseline