Support QLoRA 4-bit finetuning with bitsandbytes #275

Merged 39 commits on Aug 21, 2023. The diff below shows changes from 26 of the 39 commits.

Commits:

- abe1e29: Fix LoRA Linear calls (carmocca, Jul 13, 2023)
- a476b9b: Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 13, 2023)
- c51d5cb: Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 14, 2023)
- 53166bb: Merge branch 'main' of https://github.com/Lightning-AI/lit-gpt (efii, Jul 18, 2023)
- 5f5b627: Enable fine tuning with QLORA 4-bit floating point and bitsandbytes i… (efii, Jul 18, 2023)
- 2cd401a: Made quantize an optional parameter to main to pass some tests. (efii, Jul 18, 2023)
- b2d47f0: Merge github.com:Lightning-AI/lit-gpt (efii, Jul 20, 2023)
- e3ff068: Merge branch 'main' into qlora (efii, Jul 20, 2023)
- eb917c6: Merge branch 'main' into qlora (rasbt, Aug 9, 2023)
- efc13c0: move all hparam prints to main for consistency (rasbt, Aug 9, 2023)
- b75b428: fix unit tests (rasbt, Aug 9, 2023)
- b05f3a2: Merge branch 'main' into qlora (rasbt, Aug 9, 2023)
- 5b46d4c: Update lora.py (rasbt, Aug 9, 2023)
- b26b734: switch to empty_init=True (rasbt, Aug 9, 2023)
- df8d7ca: add note about qlora-style quantization (rasbt, Aug 9, 2023)
- 356a53e: add qlora test (rasbt, Aug 9, 2023)
- ae7141f: merge (rasbt, Aug 9, 2023)
- 9d9d4a5: fix merge conflict (rasbt, Aug 9, 2023)
- 71af29d: revert to empty_init=False and add adapter tests (rasbt, Aug 9, 2023)
- 96e0310: run quantize test only on gpu (rasbt, Aug 9, 2023)
- 56f49f8: use paged optimizer and update tests (rasbt, Aug 10, 2023)
- 364dcbf: incorporate carlos suggestions (rasbt, Aug 10, 2023)
- c9f55e9: Merge branch 'main' into qlora (rasbt, Aug 10, 2023)
- d7695a2: revert adapter (rasbt, Aug 10, 2023)
- 651250c: Merge branch 'qlora' of https://github.com/patrickhwood/lit-gpt into … (rasbt, Aug 10, 2023)
- faf4c18: fix merge screwup (rasbt, Aug 10, 2023)
- 626d2e7: Merge branch 'main' into qlora (carmocca, Aug 11, 2023)
- bcd351b: Minor test change to skip (carmocca, Aug 11, 2023)
- c3ca58f: Reference qlora (carmocca, Aug 11, 2023)
- 49b28a8: Update finetune/lora.py (carmocca, Aug 11, 2023)
- e6ede61: Merge branch 'main' into qlora (rasbt, Aug 14, 2023)
- 86a6255: Pin latest bnb (carmocca, Aug 15, 2023)
- 661419e: Merge branch 'main' into qlora (carmocca, Aug 15, 2023)
- a672a65: Formatting (carmocca, Aug 15, 2023)
- f98af7f: Merge branch 'main' into qlora (carmocca, Aug 15, 2023)
- 2bc66dc: fix inference script (rasbt, Aug 18, 2023)
- 78fd897: update docs (rasbt, Aug 18, 2023)
- 0a5ac5d: fmt (carmocca, Aug 21, 2023)
- 288b4ee: Merge branch 'main' into qlora (carmocca, Aug 21, 2023)

2 changes: 1 addition & 1 deletion README.md
@@ -181,7 +181,7 @@ More details about each finetuning method and how you can apply it to your own d
These technical tutorials illustrate how to run the finetuning code.

- [Finetune with Adapters](tutorials/finetune_adapter.md)
- [Finetune with LoRA](tutorials/finetune_lora.md)
- [Finetune with LoRA or QLoRA](tutorials/finetune_lora.md)

### Understanding Finetuning -- Conceptual Tutorials

2 changes: 1 addition & 1 deletion finetune/adapter_v2.py
@@ -299,4 +299,4 @@ def save_adapter_v2_checkpoint(fabric, model, file_path: Path):

from jsonargparse import CLI

CLI(setup)
CLI(setup)
29 changes: 22 additions & 7 deletions finetune/lora.py
@@ -2,7 +2,7 @@
import sys
import time
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from typing import Optional, List, Dict, Tuple, Literal

import lightning as L
import torch
@@ -15,7 +15,7 @@
from generate.base import generate
from lit_gpt.lora import mark_only_lora_as_trainable, lora_filter, GPT, Config, Block
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import lazy_load, num_parameters, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy
from lit_gpt.utils import lazy_load, num_parameters, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy, quantization
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from scripts.prepare_alpaca import generate_prompt

@@ -31,7 +31,7 @@
# Hyperparameters
learning_rate = 3e-4
batch_size = 128
micro_batch_size = 4
micro_batch_size = 1
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_iters = 50000 # train dataset size
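
With these defaults the effective batch size stays at 128; a quick sketch of the arithmetic implied by the lines above (values copied from the diff):

```python
# Values as they appear in the diff above; micro_batch_size was changed from 4 to 1 in this PR.
batch_size = 128
micro_batch_size = 1
gradient_accumulation_iters = batch_size // micro_batch_size  # = 128 micro-batches per optimizer step
assert gradient_accumulation_iters > 0
```
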
@@ -56,11 +56,17 @@ def setup(
out_dir: Path = Path("out/lora/alpaca"),
precision: Optional[str] = None,
tpu: bool = False,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
):
if precision is None:
precision = "32-true" if tpu else "bf16-mixed"
fabric_devices = devices
if fabric_devices > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. "
"Please set devices=1 when using the --quantization flag."
)
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
fabric_devices = "auto"
@@ -79,10 +85,10 @@ def setup(
logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
fabric.print(hparams)
fabric.launch(main, data_dir, checkpoint_dir, out_dir)
fabric.launch(main, data_dir, checkpoint_dir, out_dir, quantize)


def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, quantize: Optional[str] = None):
check_valid_checkpoint_dir(checkpoint_dir)

speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -111,7 +117,7 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
)
checkpoint_path = checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=False):
with fabric.init_module(empty_init=False), quantization(quantize):
model = GPT(config)
model.apply(model._init_weights) # for the LoRA weights
with lazy_load(checkpoint_path) as checkpoint:
@@ -124,7 +130,11 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
if quantize and quantize.startswith("bnb."):
import bitsandbytes as bnb
optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
else:
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
model, optimizer = fabric.setup(model, optimizer)

fabric.seed_everything(1337 + fabric.global_rank)
@@ -226,6 +236,10 @@ def train(
checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
save_lora_checkpoint(fabric, model, checkpoint_path)

if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

@torch.no_grad()
def validate(
@@ -315,3 +329,4 @@ def save_lora_checkpoint(fabric, model, file_path: Path):
from jsonargparse import CLI

CLI(setup)

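The `with fabric.init_module(empty_init=False), quantization(quantize):` line above is what turns model instantiation into 4-bit instantiation. The actual `lit_gpt.utils.quantization` implementation is not part of this diff; the snippet below is only a minimal sketch of the general idea, assuming the context manager temporarily swaps `torch.nn.Linear` for bitsandbytes' `Linear4bit` and maps the `bnb.*` mode strings onto its `quant_type` and `compress_statistics` options.

```python
# Illustrative sketch only: not the lit_gpt.utils.quantization implementation.
from contextlib import contextmanager

import torch
import bitsandbytes as bnb


@contextmanager
def quantization_sketch(mode: str = "bnb.nf4"):
    # "bnb.nf4" / "bnb.fp4" select the 4-bit data type; the "-dq" suffix enables
    # double quantization of the quantization constants (as in the QLoRA paper).
    quant_type = "nf4" if "nf4" in mode else "fp4"
    double_quant = mode.endswith("-dq")
    original_linear = torch.nn.Linear

    class _Linear4bit(bnb.nn.Linear4bit):
        def __init__(self, in_features, out_features, bias=True, **_):
            super().__init__(
                in_features,
                out_features,
                bias=bias,
                quant_type=quant_type,
                compress_statistics=double_quant,
            )

    # Layers constructed through the torch.nn.Linear attribute inside the context
    # become 4-bit bitsandbytes layers; the original class is restored on exit.
    torch.nn.Linear = _Linear4bit
    try:
        yield
    finally:
        torch.nn.Linear = original_linear
```

The LoRA A and B matrices themselves stay in the compute precision and remain trainable, which is why the script pairs this with `bnb.optim.PagedAdamW` over just the trainable parameters.
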
58 changes: 58 additions & 0 deletions tests/test_lora.py
@@ -279,3 +279,61 @@ def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merge
assert not layer.merged
layer.merge()
assert layer.merged == expected_merged


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Quantization not supported on CPU. Skipping Test.")
def test_lora_merge_with_quantize():
from lit_gpt.lora import mark_only_lora_as_trainable, merge_lora_weights, GPT, Config
from lit_gpt.utils import quantization
import bitsandbytes as bnb

config = Config(
n_layer=1,
n_head=2,
n_embd=8,
block_size=8,
vocab_size=8,
r=8,
alpha=8,
dropout=0.1,
to_query=True,
to_value=True,
to_projection=True,
)
fabric = Fabric(devices=1, precision="bf16-mixed")
with fabric.init_module(empty_init=False), quantization("bnb.nf4"):
model = GPT(config)
model.apply(model._init_weights) # for the LoRA weights

optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1.0)
model, optimizer = fabric.setup(model, optimizer)

model.train()

initial_weight = model.transformer.h[0].attn.proj.weight.clone()
assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

# perform an update to the LoRA weights
mark_only_lora_as_trainable(model)

y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device))
y.sum().backward()
optimizer.step()
optimizer.zero_grad()
# the weight remains unchanged (only lora A and B change)
assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

# calling merge() multiple times in a row should not merge multiple times
merge_lora_weights(model)
assert model.transformer.h[0].attn.attn.merged
weight_after = model.transformer.h[0].attn.proj.weight.clone()
merge_lora_weights(model)
merge_lora_weights(model)
assert torch.equal(model.transformer.h[0].attn.proj.weight, weight_after)

# check that `W_after = W_initial + (A x B)`
a = model.transformer.h[0].attn.proj.lora_A
b = model.transformer.h[0].attn.proj.lora_B
scaling = model.transformer.h[0].attn.proj.scaling
delta_w = (b @ a) * scaling
torch.testing.assert_close(weight_after, initial_weight + delta_w)
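
The final `assert_close` above verifies the merge identity that LoRA relies on. As a standalone illustration (the tensor shapes and the `scaling = alpha / r` convention are assumptions for this sketch, not taken from the diff):

```python
# Standalone sketch of the identity checked at the end of the test above:
#   W_merged = W_frozen + (B @ A) * scaling
import torch

d_out, d_in, r, alpha = 8, 8, 2, 16
W = torch.randn(d_out, d_in)      # frozen pretrained weight
A = torch.randn(r, d_in) * 0.01   # trainable low-rank factor (lora_A)
B = torch.randn(d_out, r) * 0.01  # trainable low-rank factor (lora_B)
scaling = alpha / r               # assumed scaling convention

delta_w = (B @ A) * scaling       # update with the same shape as W
W_merged = W + delta_w            # what merge_lora_weights() folds back into W
assert torch.linalg.matrix_rank(delta_w) <= r  # the update is low-rank by construction
```

Because the update is folded into the base weight, the merged layer behaves like an ordinary linear layer at inference time.
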
16 changes: 14 additions & 2 deletions tutorials/finetune_lora.md
@@ -1,4 +1,4 @@
# Finetuning with LoRA
# Finetuning with LoRA / QLoRA

[Low-rank adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in an LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
We demonstrate this method by instruction-finetuning Lit-GPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
@@ -38,6 +38,18 @@ This script will save checkpoints periodically to the folder `out/`.
> According to the [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): "LoRA on all linear transformer block layers are required to match full finetuning performance".
> By default, LoRA is applied only to the `query` and `value` matrices. To apply LoRA to other weight matrices, change the variables in `finetune/lora.py` accordingly (a hypothetical `Config` sketch follows below).

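A hypothetical sketch of what that looks like, expressed with the `lit_gpt.lora.Config` fields that appear in this PR's test diff above (`r`, `alpha`, `dropout`, `to_query`, `to_value`, `to_projection`); the exact variable names exposed in `finetune/lora.py` are not shown in this diff, and the dimensions below are toy values:

```python
# Hypothetical sketch: extend LoRA beyond the default query/value matrices.
# Only Config fields seen in tests/test_lora.py are used; sizes are illustrative.
from lit_gpt.lora import Config

config = Config(
    n_layer=1,
    n_head=2,
    n_embd=8,
    block_size=8,
    vocab_size=8,
    r=8,
    alpha=16,
    dropout=0.05,
    to_query=True,       # LoRA on the query matrix (default target)
    to_value=True,       # LoRA on the value matrix (default target)
    to_projection=True,  # additionally apply LoRA to the attention projection
)
```
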
Optionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:

```bash
python finetune/lora.py --quantize "bnb.nf4"
```

and optionally with double-quantization:

```bash
python finetune/lora.py --quantize "bnb.nf4-dq"
```

## Test the model

You can test the finetuned model with your own instructions by running:
@@ -52,7 +64,7 @@ Output:
I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
```

If your GPU supports `bfloat16`, you can additionally pass `--precision bf16-true` to bring the memory consumption down to ~11 GB for StableLM-3B.
If your GPU supports `bfloat16`, you can additionally pass `--precision "bf16-true"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision "32-true"`). In addition, you can combine this with quantization; for example, `--precision "bf16-true" --quantize "bnb.nf4"` brings the memory consumption down further to ~4.4 GB for StableLM-3B.
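
Combining both of the flags mentioned above gives the lowest-memory configuration described in this paragraph:

```bash
python finetune/lora.py --precision "bf16-true" --quantize "bnb.nf4"
```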

## Tune on your dataset
