Support QLoRA 4-bit finetuning with bitsandbytes (#275)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Pat Wood <Pat.Wood@efi.com>
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
4 people committed Aug 21, 2023
1 parent dc69b47 commit 064fd52
Showing 5 changed files with 120 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -193,7 +193,7 @@ More details about each finetuning method and how you can apply it to your own d
These technical tutorials illustrate how to run the finetuning code.

- [Finetune with Adapters](tutorials/finetune_adapter.md)
- [Finetune with LoRA](tutorials/finetune_lora.md)
- [Finetune with LoRA or QLoRA](tutorials/finetune_lora.md)

&nbsp;

22 changes: 17 additions & 5 deletions finetune/lora.py
@@ -2,7 +2,7 @@
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple

import lightning as L
import torch
@@ -23,6 +23,7 @@
    get_default_supported_precision,
    lazy_load,
    num_parameters,
    quantization,
    step_csv_logger,
)
from scripts.prepare_alpaca import generate_prompt
@@ -63,11 +64,17 @@ def setup(
    out_dir: Path = Path("out/lora/alpaca"),
    precision: Optional[str] = None,
    tpu: bool = False,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
):
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    fabric_devices = devices
    if fabric_devices > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. "
                "Please set devices=1 when using the --quantize flag."
            )
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            fabric_devices = "auto"
@@ -85,10 +92,10 @@ def setup(
    logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir, quantize)


def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, quantize: Optional[str] = None):
    check_valid_checkpoint_dir(checkpoint_dir)

    speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -117,7 +124,7 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
    )
    checkpoint_path = checkpoint_dir / "lit_model.pth"
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    with fabric.init_module(empty_init=False):
    with fabric.init_module(empty_init=False), quantization(quantize):
        model = GPT(config)
    with lazy_load(checkpoint_path) as checkpoint:
        # strict=False because missing keys due to LoRA weights not contained in state dict
@@ -129,7 +136,12 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
    fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    if quantize and quantize.startswith("bnb."):
        import bitsandbytes as bnb

        optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    fabric.seed_everything(1337 + fabric.global_rank)
65 changes: 65 additions & 0 deletions tests/test_lora.py
@@ -354,6 +354,71 @@ def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merge
    assert layer.merged == expected_merged


@pytest.mark.skipif(not torch.cuda.is_available(), reason="8bit requires CUDA")
# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4
@pytest.mark.xfail(raises=AttributeError, strict=False)
def test_lora_merge_with_quantize():
    from quantize.bnb import _BITSANDBYTES_AVAILABLE

    if not _BITSANDBYTES_AVAILABLE:
        pytest.skip("BNB not available")

    from lit_gpt.lora import GPT, Config, mark_only_lora_as_trainable, merge_lora_weights
    from lit_gpt.utils import quantization
    from quantize.bnb import bnb

    config = Config(
        n_layer=1,
        n_head=2,
        n_embd=8,
        block_size=8,
        vocab_size=8,
        r=8,
        alpha=8,
        dropout=0.1,
        to_query=True,
        to_value=True,
        to_projection=True,
    )
    fabric = Fabric(devices=1, precision="bf16-mixed")
    with fabric.init_module(empty_init=False), quantization("bnb.nf4"):
        model = GPT(config)
        model.apply(model._init_weights)

    optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1.0)
    model, optimizer = fabric.setup(model, optimizer)

    model.train()

    initial_weight = model.transformer.h[0].attn.proj.weight.clone()
    assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

    # perform an update to the LoRA weights
    mark_only_lora_as_trainable(model)

    y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device))
    y.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    # the weight remains unchanged (only the LoRA A and B matrices change)
    assert torch.equal(model.transformer.h[0].attn.proj.weight, initial_weight)

    # calling merge() multiple times in a row should not merge multiple times
    merge_lora_weights(model)
    assert model.transformer.h[0].attn.attn.merged
    weight_after = model.transformer.h[0].attn.proj.weight.clone()
    merge_lora_weights(model)
    merge_lora_weights(model)
    assert torch.equal(model.transformer.h[0].attn.proj.weight, weight_after)

    # check that `W_after = W_initial + (B @ A) * scaling`
    a = model.transformer.h[0].attn.proj.lora_A
    b = model.transformer.h[0].attn.proj.lora_B
    scaling = model.transformer.h[0].attn.proj.scaling
    delta_w = (b @ a) * scaling
    torch.testing.assert_close(weight_after, initial_weight + delta_w)


@pytest.mark.parametrize(
    ("mode", "expected"),
    (
38 changes: 36 additions & 2 deletions tutorials/finetune_lora.md
@@ -1,4 +1,4 @@
# Finetuning with LoRA
# Finetuning with LoRA / QLoRA

[Low-rank adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in an LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
We demonstrate this method by instruction-finetuning Lit-GPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
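
To make the low-rank idea concrete, the sketch below (illustrative only, not code from this repository) shows how LoRA augments a frozen weight matrix with two small trainable factors; the `alpha / r` scaling and the merge formula match the convention checked in the test added by this commit:

```python
import torch

# Toy shapes: a frozen pretrained linear weight plus its low-rank LoRA factors.
out_features, in_features, r, alpha = 64, 64, 8, 16
W = torch.randn(out_features, in_features)                  # frozen, never updated
A = torch.nn.Parameter(torch.randn(r, in_features) * 0.01)  # trainable, Gaussian init
B = torch.nn.Parameter(torch.zeros(out_features, r))        # trainable, zero init -> no change at start
scaling = alpha / r

def lora_forward(x: torch.Tensor) -> torch.Tensor:
    # frozen path + low-rank update path; merging collapses this to W + (B @ A) * scaling
    return x @ W.T + (x @ A.T) @ B.T * scaling
```

Only `A` and `B` receive gradients, which is why the number of trainable parameters stays small; after training, the product `(B @ A) * scaling` can be merged back into `W`, as `merge_lora_weights` verifies in `tests/test_lora.py` below.
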
@@ -38,6 +38,40 @@ This script will save checkpoints periodically to the folder `out/`.
> According to the [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): "LoRA on all linear transformer block layers are required to match full finetuning performance".
> By default, LoRA is applied only to the `query` and `value` matrices. To apply LoRA to other weight matrices, change the corresponding variables in `finetune/lora.py` accordingly (see the sketch below).
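
As a hypothetical illustration, the per-matrix switches below mirror the `to_*` flags of the LoRA `Config` used in `tests/test_lora.py` further down in this diff; the exact module-level variable names in `finetune/lora.py` are not shown here, so treat this as a sketch rather than the script's actual code:

```python
from lit_gpt.lora import Config

# Hypothetical example: adapt the attention output projection in addition to the
# default query/value targets (flag names and toy sizes mirror the test in this commit).
config = Config(
    n_layer=1,
    n_head=2,
    n_embd=8,
    block_size=8,
    vocab_size=8,
    r=8,
    alpha=8,
    dropout=0.1,
    to_query=True,       # default LoRA target
    to_value=True,       # default LoRA target
    to_projection=True,  # additionally apply LoRA to the attention projection
)
```
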
Optionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:

```bash
python finetune/lora.py --quantize "bnb.nf4"
```

and optionally with double-quantization:

```bash
python finetune/lora.py --quantize "bnb.nf4-dq"
```

The table below compares different settings for a StableLM 3B model finetuned with LoRA on Alpaca for 5,000 iterations with a microbatch size of 4 (an example command for the combined flags follows the table):

| Settings | Training Memory | Training Time | Loss | Inference Memory |
|---------------------------------------------------------|------------------|----------------|-----------|------------------|
| Default (bfloat16-mixed) | 33.50 GB | 591.78s | 0.9207 | 7.61 GB |
| --precision "bf16-true" | 15.86 GB | 592.14s | 0.9180 | 7.61 GB |
| --quantize "bnb.nf4" | 22.34 GB | 944.93s | 0.9417 | 3.25 GB |
| --quantize "bnb.nf4-dq" | 22.18 GB | 962.23s | 0.9383 | 3.08 GB |
| --precision "bf16-true" --quantize "bnb.nf4" | 14.81 GB | 802.02s | 0.9408 | 3.25 GB |
| --precision "bf16-true" --quantize "bnb.nf4-dq" | 14.65 GB | 802.94s | 0.9384 | 3.08 GB |

The advantages of QLoRA-style quantization are more pronounced in larger models, such as Llama 2 7B. The table below summarizes the results for Llama 2 7B on Alpaca for 5,000 iterations using a microbatch size of 4:

| Settings | Training Memory | Training Time | Loss | Inference Memory |
|-----------------------------------------------------|------------------|----------------|--------|------------------|
| Default (bfloat16-mixed) | OutOfMemoryError | N/A | N/A | N/A |
| --precision "bf16-true" | 20.60 GB | 876.30s | 0.8696 | 13.82 GB |
| --quantize "bnb.nf4" | 19.62 GB | 1320.63s | 1.0178 | 4.66 GB |
| --quantize "bnb.nf4-dq" | 19.32 GB | 1359.10s | 1.0132 | 4.34 GB |
| --precision "bf16-true" --quantize "bnb.nf4" | 13.44 GB | 1089.79s | 1.0130 | 4.66 GB |
| --precision "bf16-true" --quantize "bnb.nf4-dq" | 13.15 GB | 1135.86s | 1.0124 | 4.34 GB |

## Test the model

You can test the finetuned model with your own instructions by running:
@@ -52,7 +86,7 @@ Output:
I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
```

If your GPU supports `bfloat16`, you can additionally pass `--precision bf16-true` to bring the memory consumption down to ~11 GB for StableLM-3B.
If your GPU supports `bfloat16`, you can additionally pass `--precision "bf16-true"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision "32-full"`). In addition, you may use quantization methods, for example `--precision "bf16-true" --quantize "bnb.nf4"` brings the memory consumption further down to ~4.4 GB for StableLM-3B.
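
A sketch of the corresponding generation command with these flags (the `generate/lora.py` script path and its `--prompt` argument are assumed here, since the full command is collapsed in this diff):

```bash
python generate/lora.py \
  --prompt "Recommend a movie for me to watch during the weekend." \
  --precision "bf16-true" \
  --quantize "bnb.nf4"
```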

## Tune on your dataset

2 changes: 1 addition & 1 deletion tutorials/quantize.md
@@ -5,7 +5,7 @@ This document provides different strategies for quantizing the various models av
**All the examples below were run on an A100 40GB GPU.**

> [!NOTE]\
> Quantization is only supported with inference (generate and chat scripts).
> Quantization is also supported for finetuning via [QLoRA](finetune_lora.md), as shown in the example below.
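
For example, QLoRA-style finetuning as described in that tutorial (command mirrored from `tutorials/finetune_lora.md`):

```bash
python finetune/lora.py --quantize "bnb.nf4"
```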

## Baseline
