# Efficiently scale distributed training with PyTorch FSDP and Q-Lora

Open LLMs like Meta [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), Mistral AI's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, or TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b) have brought a lot of progress in the last year. We went from no OpenAI competitor to a whole zoo of LLMs that can be used for a variety of tasks. However, most of the time you need to fine-tune a model on your own data or use case to get its full potential. Fine-tuning smaller LLMs like Mistral has become very accessible on a single GPU thanks to Q-Lora. But efficiently fine-tuning bigger models like Llama 2 70b or Mixtral has remained a challenge until now.

In a collaboration between [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html), Tim Dettmers ([Q-Lora creator](https://github.com/TimDettmers/bitsandbytes)) and [Hugging Face](https://huggingface.co/), we are proud to announce support for Q-Lora with PyTorch FSDP (Fully Sharded Data Parallel). Together, FSDP and Q-Lora now allow you to fine-tune Llama 2 70b or Mixtral 8x7B on 2x consumer GPUs (24GB). If you want to learn more about the background of this collaboration, take a look at [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html).

The TL;DR:
* [PyTorch FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) is a data/model parallelism technique that shards the model across GPUs, reducing memory requirements and enabling larger models to be trained more efficiently.
* Q-LoRA is a fine-tuning method that leverages quantization and Low-Rank Adapters to efficiently reduce computational requirements and memory footprint.
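
To make the memory argument concrete, here is a back-of-the-envelope sketch of how 4-bit quantization and sharding shrink the weight memory per GPU. It only counts the base model weights and assumes they are split evenly across GPUs. Activations, LoRA adapters, optimizer state and quantization constants are ignored, and the helper name `weight_memory_gb` is made up for this illustration.

```python
def weight_memory_gb(n_params: float, bits_per_param: int, n_gpus: int = 1) -> float:
    """Approximate weight memory per GPU in GB (1 GB = 1e9 bytes).

    Assumes the weights are sharded evenly across GPUs and ignores
    activations, adapters, optimizer state and any other overhead.
    """
    total_bytes = n_params * bits_per_param / 8
    return total_bytes / n_gpus / 1e9


n_params = 70e9  # Llama 2 70b

print(weight_memory_gb(n_params, bits_per_param=16))           # bf16 on a single GPU
print(weight_memory_gb(n_params, bits_per_param=4))            # 4-bit on a single GPU
print(weight_memory_gb(n_params, bits_per_param=4, n_gpus=2))  # 4-bit, sharded over 2 GPUs
```

Under these assumptions the 4-bit weights of a 70b model come to roughly 17.5 GB per GPU when sharded over two GPUs, which is what brings 24GB cards into reach.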

 
This blog post walks you through how to fine-tune an LLM using PyTorch FSDP and Q-Lora with the help of Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index). In addition to FSDP, we will use [Flash Attention v2 through the PyTorch SDPA](https://pytorch.org/blog/pytorch2-2/) implementation.

1. [Setup development environment](#1-setup-development-environment)
2. [Create and prepare the dataset](#2-create-and-prepare-the-dataset)
3. [Fine-tune the LLM with PyTorch FSDP, Q-Lora and SDPA](#3-fine-tune-the-llm-with-pytorch-fsdp-q-lora-and-sdpa)
4. [Test Model and run Inference](#4-test-model-and-run-inference)

_Note: This blog was created and validated on NVIDIA H100 and NVIDIA A10G GPUs._

As an example use case, we will fine-tune the Llama 2 70b model on the [HuggingFaceH4/no_robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots) dataset. No Robots is a high-quality dataset of 10,000 instructions and demonstrations created by skilled human annotators. This data can be used for supervised fine-tuning (SFT) to make language models follow instructions better. No Robots was modelled after the instruction dataset described in OpenAI's [InstructGPT paper](https://huggingface.co/papers/2203.02155), and consists mostly of single-turn instructions across the following categories:
* Generation (e.g., "Generate a summary of the following text.")
* Open QA (e.g., "What is the capital of France?")
* Brainstorm (e.g., "List three ways to improve your memory.")
* Chat (e.g., "Tell me about your day.")
* Rewrite (e.g., "Rewrite the following sentence in a more formal tone.")
* Summarize (e.g., "Summarize the following text.")
* Coding (e.g., "Write a function that takes a list of integers and returns the sum of the list.")
* Classify (e.g., "Is the following sentence a fact or an opinion?")
* Closed QA (e.g., "What is the capital of France?")
* Extract (e.g., "Extract the name of the author from the following text.")

## 1. Setup development environment

Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets that makes it easier to fine-tune, apply RLHF to, and align open LLMs.

In [None]:
# Install Pytorch for FSDP and FA/SDPA
!pip install "torch==2.2.1" tensorboard

# Install Hugging Face libraries
!pip install  --upgrade \
  "transformers==4.39.1" \
  "datasets==2.18.0" \
  "accelerate==0.28.0" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.43.0" \
  "trl==0.8.1"  \
  "peft==0.10.0" 

We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You must register on [Hugging Face](https://huggingface.co/join) for this. Once you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on the disk.



In [None]:
from huggingface_hub import login

login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

## 2. Create and prepare the dataset

Once you have determined that fine-tuning is the right solution, we need to create a dataset to fine-tune our model. The dataset should be a diverse set of demonstrations of the task you want to solve. There are several ways to create such a dataset, including:
* Using existing open-source datasets, e.g., [Spider](https://huggingface.co/datasets/spider)
* Using LLMs to create synthetic datasets, e.g., [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)
* Using humans to create datasets, e.g., [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
* Using a combination of the above methods, e.g., [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)

Each of the methods has its own advantages and disadvantages, and the right choice depends on the budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4](https://arxiv.org/abs/2306.02707).

In our example we will use an existing dataset called [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context), which contains samples of natural language instructions, schema definitions and the corresponding SQL queries.

With the latest release of `trl` we now support popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. Those formats include:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```
* instruction format

```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
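
As a quick sanity check before training, the two layouts above can be told apart by their keys. The following is a minimal sketch of such a check, written for this post. It is not how `trl` detects formats internally, and the helper name `detect_format` and the two sample lines are hypothetical.

```python
import json


def detect_format(record: dict) -> str:
    """Return "conversational" or "instruction" for a single dataset record."""
    if "messages" in record:
        for message in record["messages"]:
            # Every message needs at least a role and its content
            if not {"role", "content"} <= set(message):
                raise ValueError(f"message is missing 'role' or 'content': {message}")
        return "conversational"
    if {"prompt", "completion"} <= set(record):
        return "instruction"
    raise ValueError(f"unsupported record keys: {sorted(record)}")


# Two hypothetical JSONL lines, one per format
lines = [
    '{"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}',
    '{"prompt": "Translate to French: cat", "completion": "chat"}',
]
formats = [detect_format(json.loads(line)) for line in lines]
print(formats)
```

Running a check like this over the whole file catches malformed records early, before they surface as confusing errors inside the trainer.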

In our example we are going to load our open-source dataset using the 🤗 Datasets library and then convert it into the conversational format, where we include the schema definition in the system message for our assistant. We'll then save the dataset as a jsonl file, which we can then use to fine-tune our model. We randomly downsample the dataset to 12,500 samples, which are then split into 10,000 training samples and 2,500 test samples.

_Note: This step can be different for your use case. For example, if you already have a dataset, e.g. from working with OpenAI, you can skip this step and go directly to the fine-tuning step._

In [None]:
from datasets import load_dataset

# System message template, the schema is filled in per sample
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
{schema}"""

def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message.format(schema=sample["context"])},
      {"role": "user", "content": sample["question"]},
      {"role": "assistant", "content": sample["answer"]}
    ]
  }  

# Load dataset from the hub
dataset = load_dataset("b-mc2/sql-create-context", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

print(dataset["train"][345]["messages"])

# save datasets to disk 
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")
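
To see what a single converted record looks like without downloading the dataset, the conversion can be tried on one hand-written sample. The sample below is hypothetical and only mirrors the three fields (`context`, `question`, `answer`) used above. The last lines work out the split sizes implied by `test_size=2500/12500`.

```python
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
{schema}"""


def create_conversation(sample):
    return {
        "messages": [
            {"role": "system", "content": system_message.format(schema=sample["context"])},
            {"role": "user", "content": sample["question"]},
            {"role": "assistant", "content": sample["answer"]},
        ]
    }


# Hypothetical sample with the same three fields as the dataset
sample = {
    "context": "CREATE TABLE head (age INTEGER)",
    "question": "How many heads of the departments are older than 56?",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}
conversation = create_conversation(sample)
for message in conversation["messages"]:
    print(f"{message['role']}: {message['content']}")

# Split sizes implied by test_size=2500/12500 on 12,500 samples
n_total = 12500
n_test = round(n_total * (2500 / 12500))
n_train = n_total - n_test
print(n_train, n_test)
```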

## 3. Fine-tune the LLM with PyTorch FSDP, Q-Lora and SDPA

We are now ready to fine-tune our model. We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality of life features, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
 
We will use the dataset formatting, packing and PEFT features in our example. As the PEFT method, we will use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning without sacrificing performance. If you want to learn more about QLoRA and how it works, check out the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.
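
Packing is the least obvious of these features, so here is a minimal sketch of the idea: tokenized samples are concatenated into one stream and cut into chunks of a fixed length, so that no compute is wasted on padding. This is a simplified illustration and not the implementation used by `trl`. The function name `pack_sequences`, the token ids and the choice to drop the trailing partial chunk are assumptions made for this example.

```python
from typing import Iterable, Iterator, List


def pack_sequences(
    tokenized: Iterable[List[int]], seq_length: int, eos_token_id: int
) -> Iterator[List[int]]:
    """Concatenate tokenized samples, separated by EOS, and cut fixed-size chunks.

    Simplification: a trailing partial chunk is dropped.
    """
    buffer: List[int] = []
    for token_ids in tokenized:
        buffer.extend(token_ids + [eos_token_id])
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]


# Hypothetical token ids for three short samples, 0 stands in for the EOS token
samples = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
packed = list(pack_sequences(samples, seq_length=4, eos_token_id=0))
print(packed)  # [[11, 12, 13, 0], [21, 22, 0, 31], [32, 33, 34, 0]]
```

Note how the second chunk contains the end of one sample and the start of the next. That is the trade-off of packing: every sequence is completely filled, but sample boundaries no longer line up with sequence boundaries.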

Now, let's get started! 🚀

First, we need to load our dataset from disk. 

In [None]:
from datasets import load_dataset

# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")

Next, we will load our LLM. For our use case we are going to use CodeLlama 7B. CodeLlama is a Llama model trained for general code synthesis and understanding. 
But we can easily swap out the model for another model, e.g. [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII [Falcon](https://huggingface.co/tiiuae/falcon-40b), or any other LLMs by changing our `model_id` variable. We will use bitsandbytes to quantize our model to 4-bit.

_Note: Be aware that the bigger the model, the more memory it will require. In our example we will use the 7B version, which can be tuned on 24GB GPUs._


Correctly preparing the LLM and tokenizer for training chat/conversational models is crucial. We need to add new special tokens to the tokenizer and model and teach the model to understand the different roles in a conversation. In `trl` we have a convenient method called [setup_chat_format](https://huggingface.co/docs/trl/main/en/sft_trainer#add-special-tokens-for-chat-format), which:
* Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of each message in a conversation.
* Resizes the model’s embedding layer to accommodate the new tokens.
* Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import setup_chat_format

# Hugging Face model id
model_id = "codellama/CodeLlama-7b-hf" # or `mistralai/Mistral-7B-v0.1`

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="sdpa",  # PyTorch SDPA; "flash_attention_2" would need the separate flash-attn package
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right' # to prevent warnings

# set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)
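
To make the chat template more tangible, the sketch below renders a conversation in the ChatML layout by hand. It is a hand-written approximation for illustration only. The real template is a Jinja string stored on the tokenizer and may differ in details, so in practice you would call `tokenizer.apply_chat_template` instead. The helper name `to_chatml` is hypothetical.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} messages as a ChatML-style string."""
    text = ""
    for message in messages:
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer
        text += "<|im_start|>assistant\n"
    return text


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
prompt = to_chatml(messages, add_generation_prompt=True)
print(prompt)
```

Each message is wrapped in `<|im_start|>` and `<|im_end|>` together with its role, which is how the model learns to tell system, user and assistant turns apart.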

The `SFTTrainer` supports a native integration with `peft`, which makes it super easy to efficiently tune LLMs using, e.g. QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [QLoRA paper](https://arxiv.org/pdf/2305.14314.pdf) and Sebastian Raschka's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms).

In [None]:
from peft import LoraConfig

# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM", 
)
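
To get a feeling for what `r=256` and `lora_alpha=128` mean, the sketch below counts the trainable parameters LoRA adds to a single linear layer and computes the scaling factor applied to the adapter output. The layer size of 4096 x 4096 is an assumption (a typical attention projection in a 7B Llama-style model), and `lora_params` is a helper made up for this example.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer.

    The adapter consists of two matrices: A with shape (r, d_in)
    and B with shape (d_out, r).
    """
    return r * (d_in + d_out)


r, lora_alpha = 256, 128
d_in = d_out = 4096  # assumed layer size, not read from the model config

added = lora_params(d_in, d_out, r)
frozen = d_in * d_out
print(f"LoRA params per layer: {added:,}")
print(f"Relative to the frozen weight: {added / frozen:.1%}")
print(f"Scaling factor lora_alpha / r: {lora_alpha / r}")
```

With these assumed sizes the adapter adds about 2.1M trainable parameters per layer, or 12.5% of the frozen weight matrix, and its output is scaled by 0.5.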

Before we can start our training we need to define the hyperparameters (`TrainingArguments`) we want to use.

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="code-llama-7b-text-to-sql", # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=3,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of steps to accumulate gradients before an optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
)
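
The batch size and gradient accumulation settings interact, so it can help to work out the effective batch size per optimizer update. The sketch below does this for the values above on a single GPU. It also multiplies by a sequence length of 3072, the `max_seq_length` used with packing in this post, to estimate the number of tokens per update. The helper name `effective_batch_size` is hypothetical.

```python
def effective_batch_size(per_device: int, grad_accum: int, n_gpus: int = 1) -> int:
    """Sequences that contribute to a single optimizer update."""
    return per_device * grad_accum * n_gpus


per_device_train_batch_size = 3
gradient_accumulation_steps = 2
max_seq_length = 3072  # with packing, every sequence is filled up to this length

batch = effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps)
tokens_per_update = batch * max_seq_length
print(batch, tokens_per_update)  # 6 18432
```

If you train on more GPUs, pass `n_gpus` accordingly. The effective batch size grows linearly with the number of devices, so you may want to lower `gradient_accumulation_steps` to keep it constant.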

We now have every building block we need to create our `SFTTrainer` and start training our model.

In [None]:
from trl import SFTTrainer

max_seq_length = 3072 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)

Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model.

In [None]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

# save model 
trainer.save_model()

Training with Flash Attention for 3 epochs on a dataset of 10k samples took 01:29:58 on a `g5.2xlarge`. The instance costs `$1.212/h`, which brings us to a total cost of only about `$1.8`.
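
The cost figure follows directly from the training time and the hourly price. Here is the arithmetic as a small sketch, where `training_cost` is a helper written for this example:

```python
def training_cost(duration: str, price_per_hour: float) -> float:
    """Cost in dollars for a run of `duration` (HH:MM:SS) at an hourly price."""
    hours, minutes, seconds = (int(part) for part in duration.split(":"))
    total_hours = hours + minutes / 60 + seconds / 3600
    return total_hours * price_per_hour


cost = training_cost("01:29:58", price_per_hour=1.212)
print(f"${cost:.2f}")  # $1.82
```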

In [None]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()

### _Optional: Merge LoRA adapter into the original model_

When using QLoRA, we only train adapters and not the full model. This means when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a standard model, which can be used for inference.

_Note: You might require > 30GB CPU Memory._

In [None]:

#### UNCOMMENT TO MERGE PEFT AND BASE MODEL ####
# from peft import AutoPeftModelForCausalLM

# # Load PEFT model on CPU
# model = AutoPeftModelForCausalLM.from_pretrained(
#     args.output_dir,
#     torch_dtype=torch.float16,
#     low_cpu_mem_usage=True,
# )  
# # Merge LoRA and base model and save
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
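
Conceptually, merging folds the low-rank update into the base weight: `W_merged = W + (lora_alpha / r) * B @ A`. The toy example below shows this with plain Python lists and made-up numbers. It only illustrates the math and is not how `peft` implements `merge_and_unload`, which additionally has to deal with tensor dtypes and devices.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, scaling):
    """Return W + scaling * (B @ A) for plain Python matrices."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scaling * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy 2x2 layer with a rank-1 adapter (all values are made up)
weight = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight W, shape (d_out, d_in)
lora_a = [[0.5, -1.0]]             # A, shape (r, d_in)
lora_b = [[2.0], [4.0]]            # B, shape (d_out, r)
scaling = 0.5                      # lora_alpha / r

merged = merge_lora(weight, lora_a, lora_b, scaling)
print(merged)  # [[1.5, 1.0], [4.0, 2.0]]
```

After merging, a single matrix multiplication with `merged` gives the same result as running the base layer and the adapter side by side, which is why the merged model no longer needs `peft` at inference time.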

## 4. Test Model and run Inference

After the training is done we want to evaluate and test our model. We will load different samples from the original dataset and evaluate the model on those samples, using a simple loop and accuracy as our metric. 

_Note: Evaluating generative AI models is not a trivial task since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the [Evaluate LLMs and RAG a practical example using Langchain and Hugging Face](https://www.philschmid.de/evaluate-llm) blog post._



In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline 

peft_model_id = "./code-llama-7b-text-to-sql"
# peft_model_id = args.output_dir

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
  peft_model_id,
  device_map="auto",
  torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Let’s load our test dataset and try to generate a SQL query.

In [None]:
from datasets import load_dataset 
from random import randint


# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
rand_idx = randint(0, len(eval_dataset) - 1)  # randint includes both endpoints

# Test on sample 
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
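
To go from a single sample to the simple accuracy loop mentioned at the start of this section, one option is exact-match accuracy over the test set. The sketch below shows the idea with a stand-in generator so that it runs without a GPU. `exact_match_accuracy`, `fake_generate` and the two samples are hypothetical, and in practice `generate_fn` would wrap the `pipe` call from above.

```python
def exact_match_accuracy(samples, generate_fn):
    """Fraction of samples where the generated answer equals the reference exactly."""
    if not samples:
        return 0.0
    hits = 0
    for sample in samples:
        prompt_messages = sample["messages"][:2]      # system + user message
        reference = sample["messages"][2]["content"]  # assistant answer
        prediction = generate_fn(prompt_messages)
        hits += int(prediction.strip() == reference.strip())
    return hits / len(samples)


# Hypothetical stand-in for the model so the sketch runs without a GPU
def fake_generate(messages):
    canned = {"How many users are there?": "SELECT COUNT(*) FROM users"}
    return canned.get(messages[1]["content"], "")


samples = [
    {"messages": [
        {"role": "system", "content": "SCHEMA: CREATE TABLE users (id INTEGER)"},
        {"role": "user", "content": "How many users are there?"},
        {"role": "assistant", "content": "SELECT COUNT(*) FROM users"},
    ]},
    {"messages": [
        {"role": "system", "content": "SCHEMA: CREATE TABLE users (id INTEGER)"},
        {"role": "user", "content": "What is the highest user id?"},
        {"role": "assistant", "content": "SELECT MAX(id) FROM users"},
    ]},
]
accuracy = exact_match_accuracy(samples, fake_generate)
print(f"Accuracy: {accuracy:.0%}")  # Accuracy: 50%
```

Keep in mind that exact string matching is a strict metric: two SQL queries that return the same result but are written differently count as a miss, so the reported accuracy is a lower bound on how often the model is actually right.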