# QLoRA Fine-Tuning on Mistral 7B

by Benjamin Kissinger & Andreas Sünder

## Install required packages (only needed once)

```bash
%pip install -r requirements.txt
```

## Training setup

In [None]:
# --- User-supplied settings: fill these in before running the notebook ---

# Model to fine-tune; used to load both the base model and its tokenizer.
base_model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

# Dataset identifier handed to `load_dataset` (placeholder: empty).
dataset_name = ''

# Target used by the final push-to-hub step (placeholder: empty).
hub_model_id = ''

# Used as `max_length` when tokenizing (placeholder: -1).
max_input_length = -1

import os, wandb
# Weights & Biases settings are supplied through environment variables.
# WANDB_PROJECT is a placeholder to fill in; it is read back later when
# resuming a run.
os.environ.update({
  'WANDB_PROJECT': '',
  'WANDB_LOG_MODEL': 'checkpoint',
})


## Load Dataset

In [None]:
from datasets import load_dataset
# Load the dataset named in the setup cell.
dataset = load_dataset(dataset_name)

# Report the row count of the two splits used for training and evaluation.
for split_label, split_key in (('Train', 'train'), ('Validation', 'validation')):
  print(f'{split_label} dataset size:', dataset[split_key].num_rows)

## Load Base Model

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Use a single dtype for the 4-bit compute path and for the modules that stay
# unquantized, so the loaded model agrees with the `bf16=True` training
# arguments used later (previously `torch_dtype` was float16 while everything
# else was bfloat16).
compute_dtype = torch.bfloat16

# Load the base model with 4-bit NF4 quantization (double quantization on).
model = AutoModelForCausalLM.from_pretrained(
  base_model_id,
  quantization_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype
  ),
  torch_dtype=compute_dtype,
  device_map='auto'
)

# Tokenizer for the base model: left padding, with BOS/EOS added to every
# encoded sample.
tokenizer = AutoTokenizer.from_pretrained(
  base_model_id,
  padding_side='left',
  add_eos_token=True,
  add_bos_token=True
)

# Pad with a token that is distinct from EOS. The training cell uses
# DataCollatorForLanguageModeling(mlm=False), which masks out (-100) every
# label equal to the pad token id; padding with EOS would therefore also mask
# the real EOS appended by `add_eos_token=True`, so the model would never be
# trained to end its output. Fall back to EOS only if no unk token exists.
tokenizer.pad_token = (
  tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
)

## Tokenize dataset

In [None]:
def tokenize_sample(prompt, text_column='text'):
  """Tokenize text into fixed-length model inputs with matching labels.

  Args:
    prompt: A string, a list of strings, or a mapping of column name to
      values. `Dataset.map` passes a mapping (a whole batch of columns when
      `batched=True`), which the tokenizer cannot consume directly.
    text_column: Column to read when `prompt` is a mapping.
      NOTE(review): the default 'text' is an assumption about the dataset
      schema -- confirm it, or override it with
      `dataset.map(tokenize_sample, batched=True, fn_kwargs={'text_column': ...})`.

  Returns:
    The tokenizer output, padded/truncated to `max_input_length`, with
    `labels` set to a copy of `input_ids`.

  Raises:
    KeyError: If `prompt` is a mapping that lacks `text_column`.
  """
  from collections.abc import Mapping

  # Pull the raw text out of a dataset row/batch; pass plain text through.
  text = prompt[text_column] if isinstance(prompt, Mapping) else prompt

  result = tokenizer(
    text,
    padding='max_length',
    max_length=max_input_length,
    truncation=True,
  )
  # Causal-LM training predicts the input itself, so labels mirror input_ids.
  result['labels'] = result['input_ids'].copy()
  return result

# Run `tokenize_sample` over the loaded dataset in batched mode; the result is
# what the training cell indexes by split ('train' / 'validation').
tokenized_dataset = dataset.map(tokenize_sample, batched=True)

## Setup LoRA

In [None]:
from peft import prepare_model_for_kbit_training

# Turn on gradient checkpointing on the quantized base model, then pass it
# through PEFT's k-bit training preparation helper. The returned model
# replaces `model` and is what the LoRA adapters are attached to next.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
def print_trainable_parameters(model):
  """Print how many of `model`'s parameters are trainable.

  Walks `model.named_parameters()` once, counting elements with `numel()`,
  and prints the trainable count (parameters with `requires_grad` set), the
  total count, and the trainable share as a percentage.
  """
  # One pass over the parameters: (element count, is trainable) per tensor.
  counts = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]
  all_param = sum(size for size, _ in counts)
  trainable_params = sum(size for size, needs_grad in counts if needs_grad)

  print(f'trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param: .2f}%')

In [None]:
from peft import LoraConfig, get_peft_model
from peft import TaskType

# Layers that receive LoRA adapters, matched by module name.
lora_target_modules = [
  # attention projections
  'q_proj', 'k_proj', 'v_proj', 'o_proj',
  # MLP projections
  'gate_proj', 'up_proj', 'down_proj',
]

# Rank-8 adapters with alpha 8, 5% dropout, and no bias terms trained.
config = LoraConfig(
  task_type=TaskType.CAUSAL_LM,
  r=8,
  lora_alpha=8,
  lora_dropout=0.05,
  bias='none',
  target_modules=lora_target_modules,
)

# Wrap the prepared model with the adapters and report what is trainable.
model = get_peft_model(model, config)
print_trainable_parameters(model)

## Resume training (optional)

In [None]:
import wandb

# Id of the W&B run to continue (placeholder: fill in before running).
last_run_id = ''

# Re-attach to that run in the project configured in the setup cell;
# resume='must' is passed so the call targets the existing run id.
run = wandb.init(
  id=last_run_id,
  project=os.environ['WANDB_PROJECT'],
  resume='must',
)

Fetch the latest checkpoint from Weights & Biases:

In [None]:
# Artifact version to download: 'latest' for the most recent checkpoint, or an
# explicit version alias (e.g. 'v3') to pin an earlier one. Previously the
# artifact name contained a literal '<version>' placeholder that was never
# substituted.
checkpoint_version = 'latest'

# Checkpoint artifacts are looked up by the id of the run being resumed.
latest_checkpoint = f'checkpoint-{last_run_id}:{checkpoint_version}'
artifact = run.use_artifact(latest_checkpoint, type='model')

# Local directory containing the downloaded checkpoint; can be passed to
# `trainer.train(resume_from_checkpoint=...)`.
artifact_dir = artifact.download()

## Run Training

In [None]:
from transformers import (DataCollatorForLanguageModeling,
                          EarlyStoppingCallback, Trainer, TrainingArguments)

# Hyperparameters plus logging / checkpoint / evaluation cadence for the run.
training_args = TrainingArguments(
    # Output locations and experiment tracking
    output_dir='./output',
    logging_dir='./logs',
    report_to='wandb',
    # Batch sizes and memory savers
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    bf16=True,
    # Optimisation schedule
    optim='paged_adamw_8bit',
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=1,
    # Logging
    logging_strategy='steps',
    logging_steps=25,
    # Evaluation
    do_eval=True,
    evaluation_strategy='steps',
    eval_steps=500,
    # Checkpointing: keep save_steps in step with eval_steps so the
    # checkpoint picked as "best" corresponds to an evaluated step.
    save_strategy='steps',
    save_steps=500,
    save_total_limit=2,
    # Best-model selection: lowest evaluation loss wins.
    load_best_model_at_end=True,
    metric_for_best_model='loss',
    greater_is_better=False,
)

# Causal-LM collation (no masked-language-model objective).
causal_lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    data_collator=causal_lm_collator,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

# Start a fresh run ...
trainer.train()
# ... or continue from the checkpoint downloaded in the "Resume training" section:
# trainer.train(resume_from_checkpoint=artifact_dir)

In [None]:
# Mark the current Weights & Biases run as finished now that training is done.
wandb.finish()

## Push to hub

In [None]:
# Upload the trained model to the `hub_model_id` repository.
#
# `Trainer.push_to_hub` takes the commit message as its first positional
# parameter, not the repository id, so `trainer.push_to_hub(hub_model_id)`
# used the id as a commit message and pushed to a repo named after
# `output_dir`. The model's own `push_to_hub` takes the repository id first.
# NOTE(review): this uploads the model weights/config only; push the tokenizer
# separately (`tokenizer.push_to_hub(hub_model_id)`) if it should ship too.
trainer.model.push_to_hub(hub_model_id)