<a href="https://colab.research.google.com/github/FurkanP/Mistral-7B-Fine-Tune/blob/main/mistral_7b_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Load necessary libraries from Hugging Face's Transformers and other modules
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers import BitsAndBytesConfig  # For model quantization to reduce size
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model  # For LoRA fine-tuning
from huggingface_hub import notebook_login  # For Hugging Face login
from datasets import load_dataset  # Load dataset
from trl import SFTTrainer  # Trainer for fine-tuning using supervised data
import torch


In [None]:
# Hugging Face login for accessing models and datasets
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Base model: Mistral 7B
base_model = "mistralai/Mistral-7B-v0.1"

In [None]:
# Load the tokenizer for the base model
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    padding_side="right",  # Padding added to the right
    add_eos_token=True  # End of sequence token added at the end of text
)

# Set padding token to the end of sequence token
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Add beginning-of-sequence and end-of-sequence tokens (no additional settings here)
tokenizer.add_bos_token, tokenizer.add_eos_token

(True, True)

In [None]:
# Configure the model to use 4-bit quantization with specific settings to optimize memory and computation
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Load the model in 4-bit
    bnb_4bit_quant_type="nf4",  # Type of 4-bit quantization (Normalized Float 4)
    bnb_4bit_use_double_quant=False,  # No double quantization to save memory
    bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for flexible computation
)

In [None]:
# Load the model with quantization settings, set it to use the appropriate device
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Automatically choose the best device (CPU or GPU)
)

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [None]:
# Load a dataset from Hugging Face
dataset_name = "databricks/databricks-dolly-15k"
train_dataset = load_dataset(dataset_name, split="train[0:800]")  # First 800 examples for training
eval_dataset = load_dataset(dataset_name, split="train[800:1000]")  # Next 200 examples for evaluation


Downloading readme:   0%|          | 0.00/8.20k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15011 [00:00<?, ? examples/s]

In [None]:
# Convert the training dataset to pandas DataFrame for easier handling
train_dataset.to_pandas()

Unnamed: 0,instruction,context,response,category
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin A...",Virgin Australia commenced services on 31 Augu...,closed_qa
1,Which is a species of fish? Tope or Rope,,Tope,classification
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them...,open_qa
3,"Alice's parents have three daughters: Amy, Jes...",,The name of the third daughter is Alice,open_qa
4,When was Tomoaki Komorida born?,Komorida was born in Kumamoto Prefecture on Ju...,"Tomoaki Komorida was born on July 10,1981.",closed_qa
...,...,...,...,...
795,Who is the founder of the Communist Party?,,Lenin,open_qa
796,What is gardening?,Gardening is the practice of growing and culti...,Gardening is laying out and caring for a plot ...,information_extraction
797,What are your thoughts of Michael Jackson as a...,,Michael Jackson is acclaimed as the greatest p...,creative_writing
798,What is the largest pollutant?,,Carbon dioxide (CO2) - a greenhouse gas emitte...,general_qa


In [None]:
# Prompt formatter: This function structures the data sample into a format the model can process
def generate_prompt(sample):
    full_prompt = f"""<s>[INST]{sample['instruction']}
    {f"Here is some context: {sample['context']}" if len(sample['context']) > 0 else ''}
    [/INST] {sample['response']}</s>
    """
    return {"text": full_prompt}

In [None]:
# Check the first sample of the training dataset
train_dataset[0]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [None]:
# Test the prompt formatting function with a sample
generate_prompt(train_dataset[0])

{'text': "<s>[INST]When did Virgin Australia start operating?\n    Here is some context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.\n    [/INST] Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s>\n    "}

In [None]:
# Apply the prompt formatting function to the entire training and evaluation datasets
generated_train_dataset = train_dataset.map(generate_prompt, remove_columns=list(train_dataset.features))
generated_eval_dataset = eval_dataset.map(generate_prompt, remove_columns=list(eval_dataset.features))


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [None]:
# Check the generated prompt for the 5th sample in the dataset
generated_train_dataset[5]

{'text': "<s>[INST]If I have more pieces at the time of stalemate, have I won?\n    Here is some context: Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess fami

In [None]:
# Enable gradient checkpointing to save memory during training
# It saves only strategic values in memory instead of all values, reducing memory usage during backpropagation
model.gradient_checkpointing_enable()

In [None]:
# Prepare the model for 4-bit quantization with LoRA (Low-Rank Adaptation)
model = prepare_model_for_kbit_training(model)

In [None]:
# Print the number of trainable parameters (helps in tracking model size and efficiency)
def print_trainable_parameters(model):
    trainable_params = 0
    all_params = 0

    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Trainable parameters: {trainable_params} || Total parameters: {all_params} || Trainable %: {100 * trainable_params / all_params}")


In [None]:
# LoRA configuration: Sets parameters for low-rank adaptation (LoRA) fine-tuning
# r is the rank, lora_alpha controls scaling, target_modules specify which layers to apply LoRA to, and dropout is a regularization technique
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # LoRA will be applied to these projection layers
        "gate_proj", "up_proj", "down_proj", "lm_head"
    ],
    bias="none",  # No bias is added to the model
    lora_dropout=0.05,  # Dropout rate to prevent overfitting
    task_type="CAUSAL_LM"  # The task type is causal language modeling (predicting the next word)
)


In [None]:
# Apply LoRA to the model
model = get_peft_model(model, lora_config)

In [None]:
# Print the number of trainable parameters after LoRA is applied
print_trainable_parameters(model)

Trainable parameters: 21260288 || Total parameters: 3773331456 || Trainable %: 0.5634354746703705


In [None]:
# Display the model architecture (useful for debugging and understanding model structure)
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_p

In [None]:
# Hugging Face login again for pushing the fine-tuned model to the hub
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Define training arguments: Controls how training is done
training_arguments = TrainingArguments(
    output_dir="./results",  # Directory to save the model and results
    num_train_epochs=1,  # Number of epochs to train
    per_device_train_batch_size=4,  # Batch size per device (GPU/CPU)
    gradient_accumulation_steps=1,  # Accumulate gradients over 1 step
    optim="paged_adamw_32bit",  # Optimization algorithm (AdamW with memory paging)
    save_strategy="steps",  # Save the model after a certain number of steps
    save_steps=25,  # Save the model every 25 steps
    learning_rate=2e-4,  # Learning rate for training
    weight_decay=0.001,  # Weight decay to prevent overfitting
    max_steps=20,  # Maximum number of training steps
    evaluation_strategy="steps",  # Evaluate the model after a certain number of steps
    eval_steps=25,  # Evaluate every 25 steps
    do_eval=True,  # Enable evaluation
    report_to="none",  # No reporting to any external tools (like WandB)
)



In [None]:
# Initialize the trainer for supervised fine-tuning (SFT) with the training arguments
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    train_dataset=generated_train_dataset,  # Use the generated training dataset
    eval_dataset=generated_eval_dataset,  # Use the generated evaluation dataset
    peft_config=lora_config,  # Use the LoRA configuration for fine-tuning
    dataset_text_field="text",  # The field containing text data for training
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [None]:
# Disable cache during training to save memory
model.config.use_cache = False

In [None]:
# Start the training process
trainer.train()

  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss




TrainOutput(global_step=20, training_loss=1.5851390838623047, metrics={'train_runtime': 466.0666, 'train_samples_per_second': 0.172, 'train_steps_per_second': 0.043, 'total_flos': 1433687163273216.0, 'train_loss': 1.5851390838623047, 'epoch': 0.1})

In [None]:
# After training, push the fine-tuned model to Hugging Face Hub
my_finetuned_model = "mistral-7b-fine-tune"
trainer.model.push_to_hub(my_finetuned_model)



adapter_model.safetensors:   0%|          | 0.00/609M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/FurkanPirinc/mistral-7b-fine-tune/commit/12958605eeda4c5d1a0963d8a4293092ee9c8856', commit_message='Upload model', commit_description='', oid='12958605eeda4c5d1a0963d8a4293092ee9c8856', pr_url=None, pr_revision=None, pr_num=None)