# Fine-Tuning Mistral

In [2]:
!pip install -q -U transformers bitsandbytes peft datasets accelerate trl

# Loading Tokenizer

In [3]:
from transformers import AutoTokenizer

base_model = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(
    base_model, 
    padding_side = "right",
    add_eos_token = True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_bos_token, tokenizer.add_eos_token

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

(True, True)

In [4]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

# Loading the Model

In [5]:
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [8]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

# Loading the Dataset

In [9]:
from datasets import load_dataset

dataset_name = "databricks/databricks-dolly-15k"

train_dataset = load_dataset(dataset_name, split="train[0:800]")
eval_dataset = load_dataset(dataset_name, split="train[800:1000]")

Downloading readme:   0%|          | 0.00/8.20k [00:00<?, ?B/s]



Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

# Understanding the Model

In [10]:
train_dataset

Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 800
})

In [11]:
train_dataset.to_pandas()

Unnamed: 0,instruction,context,response,category
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin A...",Virgin Australia commenced services on 31 Augu...,closed_qa
1,Which is a species of fish? Tope or Rope,,Tope,classification
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them...,open_qa
3,"Alice's parents have three daughters: Amy, Jes...",,The name of the third daughter is Alice,open_qa
4,When was Tomoaki Komorida born?,Komorida was born in Kumamoto Prefecture on Ju...,"Tomoaki Komorida was born on July 10,1981.",closed_qa
...,...,...,...,...
795,Who is the founder of the Communist Party?,,Lenin,open_qa
796,What is gardening?,Gardening is the practice of growing and culti...,Gardening is laying out and caring for a plot ...,information_extraction
797,What are your thoughts of Michael Jackson as a...,,Michael Jackson is acclaimed as the greatest p...,creative_writing
798,What is the largest pollutant?,,Carbon dioxide (CO2) - a greenhouse gas emitte...,general_qa


In [12]:
train_dataset.to_pandas().dtypes

instruction    object
context        object
response       object
category       object
dtype: object

In [13]:
train_dataset.to_pandas().value_counts("category")

category
open_qa                   202
general_qa                132
classification            111
brainstorming              95
closed_qa                  90
information_extraction     68
summarization              63
creative_writing           39
Name: count, dtype: int64

# Generating the Prompt Format

In [14]:
def generate_prompt(sample):
    full_prompt =f"""<s>[INST]{sample['instruction']}
    {f"Here is some context: {sample['context']}" if len(sample["context"]) > 0 else None}
    [/INST] {sample['response']}</s>"""
    return {"text": full_prompt}

In [15]:
train_dataset[0]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [16]:
print(generate_prompt(train_dataset[0]))

{'text': "<s>[INST]When did Virgin Australia start operating?\n    Here is some context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.\n    [/INST] Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s>"}


In [17]:
generated_train_dataset = train_dataset.map(
    generate_prompt, remove_columns=list(train_dataset.features))
generated_val_dataset = eval_dataset.map(
    generate_prompt, remove_columns=list(train_dataset.features))

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [18]:
generated_train_dataset

Dataset({
    features: ['text'],
    num_rows: 800
})

In [19]:
generated_train_dataset[5]["text"]

"<s>[INST]If I have more pieces at the time of stalemate, have I won?\n    Here is some context: Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess family.\n    

In [20]:
tokenizer(generated_train_dataset[5]["text"])

{'input_ids': [1, 1, 733, 16289, 28793, 3381, 315, 506, 680, 7769, 438, 272, 727, 302, 341, 282, 366, 380, 28725, 506, 315, 1747, 28804, 13, 2287, 4003, 349, 741, 2758, 28747, 662, 282, 366, 380, 349, 264, 4620, 297, 997, 819, 970, 272, 4385, 4636, 1527, 378, 349, 298, 2318, 349, 459, 297, 1877, 304, 659, 708, 5648, 2318, 28723, 662, 282, 366, 380, 2903, 297, 264, 3924, 28723, 6213, 272, 948, 8835, 28725, 341, 282, 366, 380, 349, 264, 3715, 369, 541, 8234, 272, 4385, 395, 272, 24797, 2840, 298, 3924, 272, 2039, 3210, 821, 6788, 28723, 560, 680, 4630, 9161, 28725, 341, 282, 366, 380, 349, 1188, 408, 283, 263, 28725, 4312, 3344, 272, 1221, 302, 264, 1719, 507, 291, 369, 9481, 28713, 865, 513, 272, 11352, 2081, 349, 297, 1061, 308, 495, 20011, 28717, 5174, 3236, 28793, 662, 282, 366, 380, 349, 835, 264, 3298, 7335, 297, 948, 8835, 7193, 304, 799, 997, 819, 4418, 28723, 13, 13, 1014, 14120, 302, 264, 341, 282, 366, 380, 403, 4787, 1332, 390, 264, 3924, 297, 272, 28705, 28740, 28774, 362, 5

# LoRA Configuration

In [21]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()

model = prepare_model_for_kbit_training(model)

In [22]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [23]:
from peft import LoraConfig, get_peft_model
    
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05, 
    task_type="CAUSAL_LM",
)

In [24]:
from peft import get_peft_model

model = get_peft_model(model, lora_config)

print_trainable_parameters(model)

trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705


In [25]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linea

# Model Training

In [27]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("HF")

!huggingface-cli login --token $secret_value_0

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [28]:
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_strategy="steps",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_steps=50,
    evaluation_strategy="steps", 
    eval_steps=25,       
    do_eval=True,               
    report_to="none",
)

In [29]:
from trl import SFTTrainer

# Setting sft parameters
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    train_dataset=generated_train_dataset,
    eval_dataset=generated_val_dataset,
    peft_config=lora_config,
    dataset_text_field="text",   
)



Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [30]:
model.config.use_cache = False
trainer.train()

Step,Training Loss,Validation Loss
25,1.5154,1.501748
50,1.5216,1.467335




TrainOutput(global_step=50, training_loss=1.5184689331054688, metrics={'train_runtime': 2032.8466, 'train_samples_per_second': 0.098, 'train_steps_per_second': 0.025, 'total_flos': 3719610284752896.0, 'train_loss': 1.5184689331054688, 'epoch': 0.25})

In [31]:
my_finetuned_model = "mistral-7b-dolly"

trainer.model.push_to_hub(my_finetuned_model)



adapter_model.safetensors:   0%|          | 0.00/609M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ocaklisemih/mistral-7b-dolly/commit/b8db1c078370faf536f712d445cee1abd063f98f', commit_message='Upload model', commit_description='', oid='b8db1c078370faf536f712d445cee1abd063f98f', pr_url=None, pr_revision=None, pr_num=None)