# Fine-tune Mistral-7b with SFT

> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)

❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).

You can run this notebook on a free-tier Google Colab (T4 GPU).

In [None]:
!pip install -qqq -U transformers datasets accelerate peft trl bitsandbytes wandb --progress-bar off

import gc
import os

import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer

# Model
base_model = "alpindale/Mistral-7B-v0.2-hf"
new_model = "mistral-7b-miniplatypus"

# Defined in the secrets tab in Google Colab
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

# Set torch dtype and attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

## Fine-tuning Mistral-7b

In [None]:
# Insert your dataset here
dataset_name = "mlabonne/mini-platypus"
dataset = load_dataset(dataset_name, split="all")
# Hold out 1% of the samples as the evaluation split
dataset = dataset.train_test_split(test_size=0.01)

# QLoRA config: 4-bit NF4 quantization with double quantization.
# The compute dtype follows `torch_dtype` (chosen from the GPU capability in
# the setup cell) rather than a hard-coded float16, so GPUs that selected
# bfloat16 actually compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config: rank-16 adapters (alpha 32, dropout 0.05, no bias training)
# attached to every projection module listed in `target_modules`
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Tokenizer: reuse the unknown token as the padding token
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.unk_token

# Load the quantized base model with the attention implementation selected
# in the setup cell, then prepare it for k-bit (QLoRA) training
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model = prepare_model_for_kbit_training(model)

In [None]:
# Hyperparameters, logging and evaluation cadence for the fine-tuning run
training_arguments = TrainingArguments(
    # Where checkpoints go and where metrics are reported
    output_dir="./results",
    report_to="wandb",
    # Optimizer and learning-rate schedule
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_steps=10,
    num_train_epochs=3,
    # Batching
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    gradient_accumulation_steps=1,
    # Log every step; evaluate on a step-based schedule.
    # NOTE(review): eval_steps is a float below 1, which the TrainingArguments
    # docs describe as a ratio of total training steps - confirm.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; verify against the installed version.
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
)

# Supervised fine-tuning trainer; the LoRA adapter is configured through
# `peft_config` and sequences are limited to 512 tokens.
# NOTE(review): only the "instruction" column is used as training text -
# confirm it holds the full prompt + response, otherwise responses are
# never trained on.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="instruction",
    max_seq_length=512,
)

trainer.train()

# Save the trained model, then the tokenizer, under the same directory name
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(new_model)

In [None]:
# Wrap the question in the "### Instruction / ### Response" prompt format
prompt = "What is a large language model?"
instruction = f"### Instruction:\n{prompt}\n\n### Response:\n"

# Generate with the fine-tuned model (max_length caps the generation at 128)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=128,
)
result = pipe(instruction)

# Print only what follows the prompt in the generated text
generated_text = result[0]["generated_text"]
print(generated_text[len(instruction):])

Merging the base model with the trained adapter.

In [None]:
# Flush memory: drop every reference to the training model before loading the
# float16 copy. The text-generation pipeline built in the inference cell also
# holds the model, so `pipe` is rebound too; a plain assignment is used so
# this works even when that cell was skipped.
pipe = None
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload the base model in float16, without quantization, on device 0
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
# Attach the adapter saved under `new_model`, then merge it into the base
# weights and drop the PEFT wrapper
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

# Reload tokenizer to save it. It is loaded exactly like the first tokenizer
# load in this notebook; `trust_remote_code` was not needed there, so it is
# not enabled here either.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.unk_token

Optional: pushing the model and tokenizer to the Hugging Face Hub.

In [None]:
# Upload the merged model first, then the tokenizer, to the same Hub repo
for artifact in (model, tokenizer):
    artifact.push_to_hub(new_model, use_temp_dir=False)

## Going further

* **DPO fine-tuning**: see [this article](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html)
* **Better fine-tuning tool**: see [Axolotl](https://mlabonne.github.io/blog/posts/A_Beginners_Guide_to_LLM_Finetuning.html)
* **Evaluation**: see the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) and the [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)
* **Quantization**: see [naive quantization](https://mlabonne.github.io/blog/posts/Introduction_to_Weight_Quantization.html), [GPTQ](https://mlabonne.github.io/blog/posts/4_bit_Quantization_with_GPTQ.html), [GGUF/llama.cpp](https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html), ExLlamaV2, and AWQ.