In [1]:
import time
import random
from typing import Dict

# Instruction Finetuning LLMs with QLoRA for RAG

Large Language Models are typically trained to do one thing: predict the next word in a sequence. This makes them very powerful, but base models don't usually come equipped with certain behaviors, such as following instructions. In this lab, we will demonstrate how to fine-tune a base Large Language Model to better respond to instructions that come with context, which is a requirement for RAG. Fine-tuning the model in this way can teach it to stop generating at the right point, hallucinate less, and generally behave in a more desirable way.

## **Important Note**

***We are fine-tuning a base model for RAG for instructional purposes, to show how fine-tuning can change a model's behavior. In practice, many model families also release instruction fine-tuned variants. These will give better RAG results than we can produce here because they were trained on many more examples, for example mistralai/Mistral-7B-Instruct-v0.1 vs. the base mistralai/Mistral-7B-v0.1, and meta-llama/Llama-2-7b-chat-hf vs. the base meta-llama/Llama-2-7b-hf. Try to get the best performance out of the fine-tuning, but don't expect it to work perfectly.***

- [Preparing the Dataset](#preparing-the-dataset)
- [Selecting the Base Pre-trained Model](#selecting-the-base-pre-trained-model)
- [Finetuning the Model](#finetuning-the-model)

## Preparing the Dataset

Fine-tuning LLMs is primarily used for teaching the model new behavior, such as better responding to instructions, responding with certain tones, or acting more as a conversational chatbot.  

The dataset for finetuning LLMs consists of text entries formatted ***THE WAY WE WANT AN INTERACTION WITH THE MODEL TO LOOK***. For example, if we want the model to better follow instructions that come with context, we should provide a dataset of examples in which it follows instructions provided with context. **This is very similar to few-shot prompting, except that the behavior is reinforced further by actually updating model weights.**

A few tips from ChatGPT:

Generative Dataset:

1. Include a dataset of input queries or prompts along with human-generated responses. This is your generative dataset.
2. Make sure that the responses are diverse, well-written, and contextually appropriate for the given queries.
3. It's important to have a variety of responses to encourage the model to generate creative and contextually relevant answers.

Training Data Quality:

1. Ensure that your training dataset is of high quality and accurately represents the task you are fine-tuning for.
2. Remove any instances that contain incorrect or misleading information.
3. Filter out instances in your training data where the model is likely to hallucinate or generate incorrect information.
4. Manually review and filter out examples that may lead to misinformation.
5. Use data augmentation techniques to artificially increase the diversity of your dataset. However, be cautious with augmentation to ensure that the generated samples remain contextually relevant and accurate.
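Some of these tips can be checked automatically. As a rough sketch, assuming each row is a plain dict with `instruction`, `context`, and `response` keys, a hypothetical `basic_quality_filter` helper could drop rows with blank fields and exact duplicates. Catching misleading or hallucination-prone examples still needs the manual review described above.

```python
from typing import Dict, List

def basic_quality_filter(rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Drop rows with blank fields and exact duplicates (hypothetical helper)."""
    seen = set()
    kept = []
    for row in rows:
        fields = (row.get("instruction", "").strip(),
                  row.get("context", "").strip(),
                  row.get("response", "").strip())
        # Skip rows missing an instruction, a context passage, or a response
        if not all(fields):
            continue
        # Skip exact duplicates (compared after trimming whitespace)
        if fields in seen:
            continue
        seen.add(fields)
        kept.append(row)
    return kept

rows = [
    {"instruction": "When was X founded?", "context": "X was founded in 1999.", "response": "1999."},
    {"instruction": "When was X founded?", "context": "X was founded in 1999.", "response": "1999."},
    {"instruction": "Who runs Y?", "context": "", "response": "Nobody knows."},
]
print(len(basic_quality_filter(rows)))  # 1: one duplicate and one empty-context row removed
```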

### Dataset using `datasets`

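The dataset that we will use for instruction fine-tuning is "dolly-15k" (`databricks/databricks-dolly-15k`), an instruction-following dataset hand-curated by Databricks. Below, we keep only the `closed_qa`, `information_extraction`, and `open_qa` rows that have a non-empty context, then take the first 1,000 of them.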

In [2]:
from datasets import load_dataset, Dataset
import pandas as pd

def load_modified_dataset() -> Dataset:
    """Load dolly-15k and keep only QA-style rows that come with a context passage."""
    dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
    df = dataset.to_pandas()
    df['keep'] = True  # helper flag column (not used by the filter below)

    # Keep question-answering / extraction categories whose context is non-empty;
    # the regex ".{1,}" matches any string with at least one non-newline character
    df = df[df['category'].isin(['closed_qa', 'information_extraction', 'open_qa'])
            & df["context"].str.contains(".{1,}")]
    print(df.shape)

    # Only the three text columns are needed for training
    return Dataset.from_pandas(
        df[['instruction', 'context', 'response']],
        preserve_index=False)

dataset = load_modified_dataset()
dataset = dataset.select(range(1000))  # use the first 1,000 examples to keep training short

(3279, 5)
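The `.{1,}` pattern in the filter is easy to misread. A small sketch with Python's `re` module, assuming pandas applies the pattern as a regex search (its default for `str.contains`), shows which contexts pass. It behaves like a non-empty check, except that a context consisting only of newlines is also rejected.

```python
import re

# Series.str.contains(pattern) searches with a regex by default, so ".{1,}"
# keeps any context containing at least one non-newline character.
pattern = re.compile(".{1,}")

def has_context(context: str) -> bool:
    return pattern.search(context) is not None

print(has_context("Virgin Australia ..."))  # True
print(has_context(""))                      # False
print(has_context("\n\n"))                  # False: "." does not match newlines
```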


In [2]:
len(dataset)

1000

#### **AP**: Inspect the Databricks dataset

In [2]:
pd.options.display.max_colwidth = 500
dataset_dolly = load_dataset("databricks/databricks-dolly-15k", split = "train")
df_dolly = dataset_dolly.to_pandas()
df_dolly.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15011 entries, 0 to 15010
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   instruction  15011 non-null  object
 1   context      15011 non-null  object
 2   response     15011 non-null  object
 3   category     15011 non-null  object
dtypes: object(4)
memory usage: 469.2+ KB


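In [5]:
# Rebuild the filtered frame so this cell also runs when the notebook is executed top to bottom
df_dolly_v2 = df_dolly[(df_dolly['category'].isin(['closed_qa', 'information_extraction', 'open_qa'])) & df_dolly["context"].str.contains(".{1,}")]
df_dolly_v2["context"].str.contains(".{1,}")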

0        True
4        True
5        True
6        True
9        True
         ... 
14990    True
14993    True
15001    True
15003    True
15005    True
Name: context, Length: 3279, dtype: bool

In [4]:
df_dolly['keep'] = True

df_dolly_v2 = df_dolly[(df_dolly['category'].isin(['closed_qa', 'information_extraction', 'open_qa'])) & df_dolly["context"].str.contains(".{1,}")]
df_dolly_v2[df_dolly_v2["context"].str.contains(".{1,}")]

Unnamed: 0,instruction,context,response,category,keep
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbour...","Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.",closed_qa,True
4,When was Tomoaki Komorida born?,"Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In Se...","Tomoaki Komorida was born on July 10,1981.",closed_qa,True
5,"If I have more pieces at the time of stalemate, have I won?","Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and o...",No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position,information_extraction,True
6,"Given a reference text about Lollapalooza, where does it take place, who started it and what is it?","Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an esti...","Lollapalooze is an annual musical festival held in Grant Park in Chicago, Illinois. It was started in 1991 as a farewell tour by Perry Farrell, singe of the group Jane's Addiction. The festival includes an array of musical genres including alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. The festivals welcomes an estimated 400,000 people each year and sells out annually. Some notable headliners include: the Red Hot Chili Peppers, Chance the Rapper, Metallica, an...",closed_qa,True
9,Who was John Moses Browning?,"John Moses Browning (January 23, 1855 – November 26, 1926) was an American firearm designer who developed many varieties of military and civilian firearms, cartridges, and gun mechanisms – many of which are still in use around the world. He made his first firearm at age 13 in his father's gun shop and was awarded the first of his 128 firearm patents on October 7, 1879, at the age of 24. He is regarded as one of the most successful firearms designers of the 19th and 20th centuries and pioneer...","John Moses Browning is one of the most well-known designer of modern firearms. He started building firearms in his father's shop at the age of 13, and was awarded his first patent when he was 24.\n\nHe designed the first reliable automatic pistol, and the first gas-operated firearm, as well inventing or improving single-shot, lever-action, and pump-action rifles and shotguns.\n\nToday, he is most well-known for the M1911 pistol, the Browning Automatic Rifle, and the Auto-5 shotgun, all of ...",information_extraction,True
...,...,...,...,...,...
14990,What is the history of coffee?,"The history of coffee dates back to centuries of old oral tradition in modern day Ethiopia. However, neither where coffee was first cultivated nor direct evidence of its consumption prior to the 15th century have been found. Sufi monasteries in Yemen employed coffee as an aid to concentration during prayers. Coffee later spread to the Levant and Persia in the early 16th century; it caused some controversy on whether it was halal in Ottoman and Mamluk society. Coffee arrived in Italy the seco...","Much of the early history of coffee was tied to oral tradition in modern-day Ethiopia. Presently, there is no evidence of coffee consumption and cultivation before the 15th century. Yemeni monks drank coffee to aid in concentration during prayers. In the early 16th century, coffee spread to the Levant and Persia. Later in the 16th-century coffee arrived in Italy via Mediterranean trade routes. The Ottomans brought it to Central and Eastern Europe. It reached India by the mid-17th century.Eng...",closed_qa,True
14993,When did Phil Knight announce he would step down as chairman of Nike,"Throughout the 1980s, Nike expanded its product line to encompass many sports and regions throughout the world. In 1990, Nike moved into its eight-building World Headquarters campus in Beaverton, Oregon. The first Nike retail store, dubbed Niketown, opened in downtown Portland in November of that year. Phil Knight announced in mid-2015 that he would step down as chairman of Nike in 2016. He officially stepped down from all duties with the company on June 30, 2016. In a company public annou...",Phil Knight announced he would step down in 2015 as chairman and offically stepped down in 2016,information_extraction,True
15001,What are common florals found in Zigalga National Park?,"Zigalga National Park (Russian: Национальный парк «Зигальга») is located on the high Zigalga Ridge of the Southern Ural Mountains in Russia, on the transition between Europe and Siberia. Much of the territory is untouched by human activity and so supports Ice Age relict floral communities through all altitude zones - pine and birch forest, dark coniferous taiga, alpine meadows and mountain tundra. The park was officially created in 2019. The park is located in the Katav-Ivanovsky District of...","Zigalga National Park has the majority of its territory untouched by human activity and includes pine and birch forest, dark coniferous taiga, alpine meadows and mountain tundra.",closed_qa,True
15003,What is linux Bootloader,"A bootloader, also spelled as boot loader or called boot manager and bootstrap loader, is a computer program that is responsible for booting a computer.\n\nWhen a computer is turned off, its software including operating systems, application code, and data‍—‌remains stored on non-volatile memory. When the computer is powered on, it typically does not have an operating system or its loader in random-access memory (RAM). The computer first executes a relatively small program stored in read-only...",A bootloader is a program written in machine code that loads the operating system into RAM during the boot process.,closed_qa,True


In [8]:
df_dolly_v2[df_dolly_v2["context"].str.contains(".{1,}")].head(5)

Unnamed: 0,instruction,context,response,category,keep
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbour...","Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.",closed_qa,True
4,When was Tomoaki Komorida born?,"Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In Se...","Tomoaki Komorida was born on July 10,1981.",closed_qa,True
5,"If I have more pieces at the time of stalemate, have I won?","Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and o...",No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position,information_extraction,True
6,"Given a reference text about Lollapalooza, where does it take place, who started it and what is it?","Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an esti...","Lollapalooze is an annual musical festival held in Grant Park in Chicago, Illinois. It was started in 1991 as a farewell tour by Perry Farrell, singe of the group Jane's Addiction. The festival includes an array of musical genres including alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. The festivals welcomes an estimated 400,000 people each year and sells out annually. Some notable headliners include: the Red Hot Chili Peppers, Chance the Rapper, Metallica, an...",closed_qa,True
9,Who was John Moses Browning?,"John Moses Browning (January 23, 1855 – November 26, 1926) was an American firearm designer who developed many varieties of military and civilian firearms, cartridges, and gun mechanisms – many of which are still in use around the world. He made his first firearm at age 13 in his father's gun shop and was awarded the first of his 128 firearm patents on October 7, 1879, at the age of 24. He is regarded as one of the most successful firearms designers of the 19th and 20th centuries and pioneer...","John Moses Browning is one of the most well-known designer of modern firearms. He started building firearms in his father's shop at the age of 13, and was awarded his first patent when he was 24.\n\nHe designed the first reliable automatic pistol, and the first gas-operated firearm, as well inventing or improving single-shot, lever-action, and pump-action rifles and shotguns.\n\nToday, he is most well-known for the M1911 pistol, the Browning Automatic Rifle, and the Auto-5 shotgun, all of ...",information_extraction,True


In [12]:
df_dolly_v2

Unnamed: 0,instruction,context,response,category,keep
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbour...","Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.",closed_qa,True
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.,open_qa,True
3,"Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?",,The name of the third daughter is Alice,open_qa,True
4,When was Tomoaki Komorida born?,"Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In Se...","Tomoaki Komorida was born on July 10,1981.",closed_qa,True
5,"If I have more pieces at the time of stalemate, have I won?","Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and o...",No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position,information_extraction,True
...,...,...,...,...,...
14999,Who is the creator of Python?,,Guido van Rossum is the father of Python. And the 1st version of Python was released in 1991.,open_qa,True
15001,What are common florals found in Zigalga National Park?,"Zigalga National Park (Russian: Национальный парк «Зигальга») is located on the high Zigalga Ridge of the Southern Ural Mountains in Russia, on the transition between Europe and Siberia. Much of the territory is untouched by human activity and so supports Ice Age relict floral communities through all altitude zones - pine and birch forest, dark coniferous taiga, alpine meadows and mountain tundra. The park was officially created in 2019. The park is located in the Katav-Ivanovsky District of...","Zigalga National Park has the majority of its territory untouched by human activity and includes pine and birch forest, dark coniferous taiga, alpine meadows and mountain tundra.",closed_qa,True
15003,What is linux Bootloader,"A bootloader, also spelled as boot loader or called boot manager and bootstrap loader, is a computer program that is responsible for booting a computer.\n\nWhen a computer is turned off, its software including operating systems, application code, and data‍—‌remains stored on non-volatile memory. When the computer is powered on, it typically does not have an operating system or its loader in random-access memory (RAM). The computer first executes a relatively small program stored in read-only...",A bootloader is a program written in machine code that loads the operating system into RAM during the boot process.,closed_qa,True
15005,What is one-child policy?,"The term one-child policy refers to a population planning initiative in China implemented between 1980 and 2015 to curb the country's population growth by restricting many families to a single child. That initiative was part of a much broader effort to control population growth that began in 1970 and ended in 2021, a half century program that included minimum ages at marriage and childbearing, two-child limits for many couples, minimum time intervals between births, heavy surveillance, and s...","The ""one-child policy"" was a Chinese population planning initiative that was implemented from 1980 to 2015 to curb population growth by limiting many families to only one child. It was part of a larger effort to control population growth that began in 1970 and ended in 2021, which included setting minimum ages for marriage and childbearing, imposing two-child limits on many couples, requiring minimum intervals between births, close monitoring, and imposing stiff fines for non-compliance. The...",closed_qa,True


The base dataset contains an `instruction` column, an optional `context` column, and a `response` column holding the answer we want the model to give. However, to feed it into the model for fine-tuning, we need to combine the columns so that 1 sample corresponds to 1 example interaction with the model.

This 1 sample should be an example to the LLM about:

1. How we wish to interact with the model (prompt)
2. How we want the model to respond

Remember, these generative LLMs are trained to read in a provided prompt, and essentially auto-complete the text!

### IMPORTANT ###

In [3]:
def format_instruction(sample : Dict) -> str:
    """Combine one dataset row into a single prompt/response training string."""
    return f"""### Context:
{sample['context']}

### Question:
Using only the context above, {sample['instruction']}

### Response:
{sample['response']}
"""

We will provide this as the entire prompt to the model for training, using the Causal Language Modeling objective for loss.

```
### Context:
{context}

### Question:
Using only the context above, {instruction}

### Response:
{response}
```

## Selecting the Base Pre-trained Model

Once we have the data, we can select the base model that we would like to fine-tune for this behavior.

The model that we will select is the `mistralai/Mistral-7B-v0.1` base model. This is a 7.3B-parameter model, quite small in the grand scheme of LLMs, but one that produces good-quality results, especially compared to many other open-source models.

### Quantization using `bitsandbytes`

LLMs are extremely memory-intensive. One trick commonly used to reduce memory usage for both inference and training, and in some cases to speed up computation, is to reduce the precision of the weights from full-precision 32-bit floating point (fp32) to lower precisions such as int8, fp4, or nf4. This is known as quantization. Research has shown that quantization often has minimal impact on the quality of generations, but this varies from case to case.
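To put rough numbers on this, the weight storage of a 7.3B-parameter model (the size of the model used below) can be estimated as parameters × bytes per parameter. This is back-of-the-envelope arithmetic only: it ignores activations, optimizer state, and quantization overhead such as per-block scales.

```python
# Approximate weight memory for a 7.3B-parameter model at different precisions.
# Weights only: activations, optimizer state, KV cache, and quantization scales are ignored.
NUM_PARAMS = 7.3e9

BYTES_PER_PARAM = {
    "fp32": 4.0,
    "fp16/bf16": 2.0,
    "int8": 1.0,
    "4-bit (fp4/nf4)": 0.5,
}

def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Return weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:>16}: {weight_memory_gb(NUM_PARAMS, nbytes):5.1f} GB")
```

That works out to roughly 29 GB in fp32 versus under 4 GB at 4 bits, which is why the 4-bit weights leave plenty of room on a 24 GB GPU such as the A10G used later.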

In this example, we will quantize and fine-tune using 4-bit NormalFloat (NF4). Behind the scenes, the quantization is handled by the `bitsandbytes` library, which `transformers` exposes through `BitsAndBytesConfig`.
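NF4 itself places its 16 levels at quantiles of a normal distribution and quantizes weights in small blocks, each with its own scale. The sketch below is a deliberately simpler stand-in, plain symmetric absmax 4-bit integer quantization of a single block, and is neither NF4 nor the `bitsandbytes` kernels. It does show the core idea: store small integer codes plus one float scale per block, and multiply back to floats at compute time.

```python
from typing import List, Tuple

def absmax_quantize_4bit(block: List[float]) -> Tuple[List[int], float]:
    """Symmetric absmax quantization of one block to integers in [-7, 7]."""
    absmax = max(abs(x) for x in block)
    scale = absmax / 7 if absmax > 0 else 1.0  # one float scale per block
    codes = [max(-7, min(7, round(x / scale))) for x in block]
    return codes, scale

def dequantize(codes: List[int], scale: float) -> List[float]:
    """Map integer codes back to approximate float weights."""
    return [c * scale for c in codes]

weights = [0.12, -0.5, 0.03, 0.91, -0.27, 0.0]
codes, scale = absmax_quantize_4bit(weights)
restored = dequantize(codes, scale)
print(codes)
print(max(abs(w - r) for w, r in zip(weights, restored)))  # bounded by scale / 2
```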

In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM

# Hugging Face Base Model ID
model_id = "mistralai/Mistral-7B-v0.1"
# Set to True to reload a saved LoRA adapter; model_id must then point to the adapter directory
is_peft = False

# Store weights in 4-bit NF4, run the matmuls in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

if is_peft:
    # load base LLM model with PEFT Adapter
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        quantization_config=bnb_config
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,  # match bnb_4bit_compute_dtype and bf16 training
        quantization_config=bnb_config,
        attn_implementation="flash_attention_2"  # replaces the deprecated use_flash_attention_2=True
    )

model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Mistral has no pad token; reuse EOS
tokenizer.add_eos_token = True  # append EOS to each example so the model learns where a response ends

With the model loaded up, we are ready to finetune using our dataset.

## Finetuning the Model

There are two main ways to finetune a large language model:

1. Pre-training/Full Finetuning

    In this situation, all of the model weights (all 7b of them) are set to be trainable and tweaked during training.  This can lead to the most dramatic changes in model behavior but is also the most computationally expensive.  
    
    When a model is first trained (pre-training), all weights are necessarily trainable, and this is where the extreme computational costs show up (e.g. 500 A100 80GB GPUs training for 10,000 hours).

2. Parameter Efficient Fine-Tuning (PEFT)

    Parameter-efficient fine-tuning methods are an alternative to full fine-tuning: instead of updating the parameters of the pre-trained model, a small set of new parameters is trained while the base model weights stay frozen. These new trainable parameters are injected into the model at different points to change its outputs. A handful of methods use this approach, such as Prompt Tuning, P-Tuning, and Low-Rank Adaptation (LoRA). For this lab, we will focus on LoRA.

    LoRA introduces a set of trainable rank-decomposition matrices (update matrices) whose product is added to the existing weights of the pre-trained model. They are typically placed in the attention layers, and because attention is used well beyond language models, LoRA is not exclusive to LLMs. The size of these update matrices is controlled by their rank (`r` in `peft`'s `LoraConfig`, often written `lora_r`), with a smaller rank giving smaller matrices and thus fewer trainable parameters. During fine-tuning, only these update matrices are trained, which often makes the number of trainable parameters a very small fraction of the model's total weights.
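To make the update concrete: for a frozen weight `W` of shape `d_out × d_in`, LoRA learns `A` (`r × d_in`) and `B` (`d_out × r`) and computes `W x + (lora_alpha / r) * B A x`, which is how `peft` scales LoRA updates by default. The pure-Python sketch below uses a toy layer size with the lab's `r=8` and `lora_alpha=16`. The Gaussian initialization of `A` follows the LoRA paper rather than `peft`'s exact initializer; as in both, `B` starts at zero, so the adapted model initially behaves exactly like the base model.

```python
import random
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of an (n x k) and a (k x m) matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

# Toy 32 x 32 layer with the lab's r=8 and lora_alpha=16
d_in, d_out, r, lora_alpha = 32, 32, 8, 16
scaling = lora_alpha / r

random.seed(0)
W = [[random.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(d_out)]  # frozen, d_out x d_in
A = [[random.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]     # trainable, r x d_in
B = [[0.0] * r for _ in range(d_out)]                                      # trainable, d_out x r, zero init

def lora_forward(x: Matrix) -> Matrix:
    """h = W x + (lora_alpha / r) * B (A x); W itself is never updated."""
    base = matmul(W, x)
    update = matmul(B, matmul(A, x))
    return [[h + scaling * u for h, u in zip(hr, ur)] for hr, ur in zip(base, update)]

x = [[float(i % 5) - 2.0] for i in range(d_in)]  # toy input column vector
print(lora_forward(x) == matmul(W, x))  # True: B starts at zero, so the adapter is a no-op

# Parameter counts for one adapted 4096 x 4096 projection at r = 8
full_update = 4096 * 4096         # training the dense weight directly
lora_update = 8 * (4096 + 4096)   # A (8 x 4096) + B (4096 x 8)
print(full_update, lora_update)
```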

### Finetuning using `peft`

To configure the model for parameter-efficient fine-tuning with LoRA, we will use the `peft` package. Specifically, we will define our LoRA parameters and set the task type to `CAUSAL_LM` to train the model for generative purposes. Because we also quantized the model to 4-bit, this combination is a method called Quantized LoRA (QLoRA): the frozen base weights stay in 4-bit while only the LoRA update matrices are trained, which saves a large amount of memory.


In [7]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

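if is_peft:
    # Adapter already attached: prepare the 4-bit model, then re-enable gradients
    # only on the LoRA parameters (prepare_model_for_kbit_training freezes everything)
    model = prepare_model_for_kbit_training(model)
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name
else:
    # LoRA config for QLoRA
    peft_config = LoraConfig(
        lora_alpha=16,     # updates are scaled by lora_alpha / r
        lora_dropout=0.1,  # dropout applied on the LoRA branch
        r=8,               # rank of the update matrices
        bias="none",       # do not train bias terms
        task_type="CAUSAL_LM",
        # attention (q/k/v/o) and MLP (gate/up/down) projections in every decoder layer
        target_modules=['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj']
    )

    # prepare the quantized model for training (freezes base weights, enables gradient checkpointing)
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)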
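Because this config adapts the MLP projections as well as attention, it is worth checking how many parameters it actually trains. LoRA adds `r * (d_in + d_out)` parameters per adapted linear layer. The sketch below assumes the published Mistral-7B-v0.1 shapes (hidden size 4096, MLP size 14336, 32 layers, grouped-query attention with 8 key/value heads of size 128); `model.print_trainable_parameters()` on the PEFT model should report a trainable count close to this, well under 1% of the base model.

```python
from typing import Iterable

# Assumed Mistral-7B-v0.1 shapes; k_proj/v_proj output 8 KV heads x 128 dims = 1024
HIDDEN, INTERMEDIATE, LAYERS, KV_DIM = 4096, 14336, 32, 8 * 128

LINEAR_SHAPES = {  # module name -> (d_in, d_out)
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_trainable_params(r: int, target_modules: Iterable[str]) -> int:
    """Count LoRA A + B parameters across all decoder layers."""
    targets = set(target_modules)
    per_layer = sum(r * (d_in + d_out)
                    for name, (d_in, d_out) in LINEAR_SHAPES.items() if name in targets)
    return per_layer * LAYERS

targets = ['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj']
print(lora_trainable_params(r=8, target_modules=targets))  # 20971520
```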

### Running the trainer with `trl`

Now that we have prepared the data, loaded the model in 4-bit, and configured our LoRA finetuning according to our model, we are ready to train the model. Training of LLMs for generative purposes uses the causal language modeling objective.  Briefly, this specifies that when calculating attention, the model should only be able to consider things "to the left".  So for a sentence, it can only decide what to generate by looking at all of the words that came before it.  
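In other words, one training text yields one prediction per position: each token is predicted from the tokens before it. A minimal pure-Python sketch with made-up token ids shows these (prefix, target) pairs and the lower-triangular attention mask that enforces "look only to the left". Real implementations compute all positions in a single pass by shifting the labels by one and masking attention, rather than building prefixes explicitly.

```python
from typing import List, Tuple

def next_token_pairs(token_ids: List[int]) -> List[Tuple[List[int], int]]:
    """For each position t, pair the visible prefix tokens[0..t] with the target tokens[t+1]."""
    return [(token_ids[: t + 1], token_ids[t + 1]) for t in range(len(token_ids) - 1)]

def causal_mask(n: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j (only j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

tokens = [101, 7, 42, 9, 2]  # hypothetical token ids for a short training text
for prefix, target in next_token_pairs(tokens):
    print(prefix, "->", target)

for row in causal_mask(4):
    print(["x" if allowed else "." for allowed in row])
```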

A very useful wrapper for training transformer-based models is the Supervised Fine-Tuning Trainer (`SFTTrainer`) provided by the `trl` library. Supervised fine-tuning is often the first stage of reinforcement-learning pipelines such as RLHF, but for our purposes it simply means training the model on examples of inputs and the responses we want. All of the actual training, including computing gradients, stepping the optimizer, batching the data, and evaluation, is done behind the scenes by `SFTTrainer`, which runs the fine-tuning once we pass in the dataset and hyperparameters. This is far less error-prone than writing our own training loop.

In [8]:
from transformers import TrainingArguments
from trl import SFTTrainer

args = TrainingArguments(
    output_dir="./mistral-7b-int4-dolly_SMALL_V2",
    num_train_epochs=1, # number of training epochs
    per_device_train_batch_size=5, # sequences per device per forward pass
    gradient_accumulation_steps=2, # effective batch size = 5 * 2 = 10 sequences per device
    gradient_checkpointing=True, # recompute activations in the backward pass to save memory
    gradient_checkpointing_kwargs={'use_reentrant':True},
    optim="paged_adamw_32bit", # bitsandbytes paged AdamW (pages optimizer state to CPU under memory pressure)
    logging_steps=1, # log the training loss every step
    save_strategy="steps",
    save_total_limit=2, # keep only the 2 most recent checkpoints
    ignore_data_skip=True,
    save_steps=2, # save a checkpoint every 2 steps
    learning_rate=1e-3,
    bf16=True,
    tf32=True, # requires an Ampere-or-newer GPU (e.g. A10G)
    max_grad_norm=1.0,
    warmup_steps=5, # ignored by the "constant" scheduler; use "constant_with_warmup" to apply it
    lr_scheduler_type="constant",
    disable_tqdm=True
)

# https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-
# max seq length for packing: formatted examples are concatenated and cut into 2048-token chunks
max_seq_length = 2048
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    packing=True,
    formatting_func=format_instruction, # our formatting function which takes a dataset row and maps it to str
    args=args,
)
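With `packing=True`, the trainer does not pad each formatted example to `max_seq_length`; it concatenates examples and slices the stream into fixed-length blocks, so one block can hold the end of one example and the start of the next. The hypothetical `pack_sequences` function below is a simplified sketch of that idea, not `trl`'s actual implementation; it assumes examples are separated by an EOS token and that an incomplete final block is dropped.

```python
from typing import List

def pack_sequences(tokenized: List[List[int]], eos_id: int, seq_length: int) -> List[List[int]]:
    """Concatenate examples (EOS-separated) and cut the stream into fixed-length blocks.

    Simplified sketch: leftover tokens shorter than seq_length are dropped.
    """
    stream: List[int] = []
    for ids in tokenized:
        stream.extend(ids + [eos_id])
    return [stream[i : i + seq_length]
            for i in range(0, len(stream) - seq_length + 1, seq_length)]

examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]  # hypothetical token ids
print(pack_sequences(examples, eos_id=2, seq_length=4))

# Tokens per optimizer step per device with the arguments above:
# batch 5 x gradient accumulation 2 x 2048 tokens per packed sequence
print(5 * 2 * 2048)
```

Because every packed sequence is exactly 2048 tokens, each optimizer step in this lab sees 20,480 tokens per device.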



With all of the configuration done, we can now run our training. On an A10G, this takes about an hour, after which the LoRA adapter weights are saved to the `./mistral-7b-int4-dolly_SMALL_V2` directory (the `output_dir` set above).

In [9]:
start = time.time()
trainer.train(resume_from_checkpoint=False) # with packing, the reported step/epoch progress is only an estimate
trainer.save_model()
end = time.time()
print(f"{end - start}s")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


{'loss': 1.4692, 'learning_rate': 0.001, 'epoch': 0.05}
{'loss': 1.4218, 'learning_rate': 0.001, 'epoch': 0.1}
{'loss': 1.3393, 'learning_rate': 0.001, 'epoch': 0.15}
{'loss': 1.4556, 'learning_rate': 0.001, 'epoch': 0.21}
{'loss': 1.392, 'learning_rate': 0.001, 'epoch': 0.26}
{'loss': 1.2689, 'learning_rate': 0.001, 'epoch': 0.31}
{'loss': 1.3043, 'learning_rate': 0.001, 'epoch': 0.36}
{'loss': 1.271, 'learning_rate': 0.001, 'epoch': 0.41}
{'loss': 1.2808, 'learning_rate': 0.001, 'epoch': 0.46}
{'loss': 1.4023, 'learning_rate': 0.001, 'epoch': 0.51}
{'loss': 1.3415, 'learning_rate': 0.001, 'epoch': 0.56}
{'loss': 1.3304, 'learning_rate': 0.001, 'epoch': 0.62}
{'loss': 1.4002, 'learning_rate': 0.001, 'epoch': 0.67}
{'loss': 1.3517, 'learning_rate': 0.001, 'epoch': 0.72}
{'loss': 1.1826, 'learning_rate': 0.001, 'epoch': 0.77}
{'loss': 1.4071, 'learning_rate': 0.001, 'epoch': 0.82}
{'loss': 1.3463, 'learning_rate': 0.001, 'epoch': 0.87}
{'loss': 1.3806, 'learning_rate': 0.001, 'epoch': 0

After the model has finished training, it is ready to be used.  Now, hopefully, when the model sees the prompt that we crafted before, it will know how to respond.
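At inference time, the prompt should follow the training template but stop right after `### Response:` so the model writes the answer itself. The hypothetical `format_inference_prompt` helper below mirrors `format_instruction` with the response left empty. The commented generation calls are a sketch assuming the fine-tuned `model` and `tokenizer` from above are loaded. Because `tokenizer.add_eos_token` was set to `True` for training, it should be switched off first; otherwise an EOS token is appended to the prompt.

```python
from typing import Dict

def format_inference_prompt(sample: Dict[str, str]) -> str:
    """Build the training template up to "### Response:" (hypothetical helper)."""
    return f"""### Context:
{sample['context']}

### Question:
Using only the context above, {sample['instruction']}

### Response:
"""

prompt = format_inference_prompt({
    "context": "Virgin Australia commenced services on 31 August 2000 as Virgin Blue.",
    "instruction": "When did Virgin Australia start operating?",
})
print(prompt)

# Sketch of generation with the fine-tuned model (not executed here):
# tokenizer.add_eos_token = False  # do not append EOS to the prompt
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# output = model.generate(**inputs, max_new_tokens=128, eos_token_id=tokenizer.eos_token_id)
# answer = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
```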