# QLoRA on Mistral-7B for Topic Labeling

by Andreas Sünder

## Setup

In [None]:
%pip install -q -U bitsandbytes
%pip install -q -U git+https://github.com/huggingface/transformers.git
%pip install -q -U git+https://github.com/huggingface/peft.git
%pip install -q -U git+https://github.com/huggingface/accelerate.git
%pip install -q -U datasets scipy

## Load Dataset

In [None]:
from datasets import load_dataset

# Fetch the training and validation splits of the topic-labeling dataset.
train_dataset = load_dataset("textminr/topic-labeling", split="train")
val_dataset = load_dataset("textminr/topic-labeling", split="validation")

## Load Base Model

In [None]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# 4-bit NF4 quantization with double quantization and a bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): torch_dtype is float16 here while the quantization compute
# dtype above is bfloat16 -- confirm the mismatch is intentional.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

## Setup Formatting & Tokenization

In [None]:
def formatting_func(example):
    """Build the training prompt text for one dataset example.

    Every value of ``example`` from the third one onward (in the mapping's
    own order) is stringified and comma-joined as the topic words; the
    ``'label'`` entry is appended as the topic label.

    NOTE(review): assumes the first two columns are not topic words and that
    column order is stable -- confirm against the dataset schema.
    """
    column_values = list(dict(example).values())
    topic_words = ','.join(map(str, column_values[2:]))
    topic_label = example['label']
    return f"### Topic Words: {topic_words}\n ### Topic Label: {topic_label}"

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the base model: left padding, BOS and EOS added to each input.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_eos_token=True,
    padding_side="left",
)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Fixed token length every prompt is padded or truncated to.
max_length = 500

def generate_and_tokenize_prompt(prompt):
    """Format one example as prompt text and tokenize it to ``max_length``.

    The text produced by ``formatting_func`` is tokenized with truncation and
    padding to ``max_length``. A copy of ``input_ids`` is stored under
    ``"labels"`` and the tokenizer output is returned.
    """
    text = formatting_func(prompt)
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

In [None]:
# Run prompt formatting + tokenization over both splits.
tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = val_dataset.map(generate_and_tokenize_prompt)

## Setup LoRA

In [None]:
from peft import prepare_model_for_kbit_training

# Enable gradient checkpointing and input gradients on the quantized model,
# then let PEFT prepare it for k-bit training.
# NOTE(review): prepare_model_for_kbit_training may already perform the two
# calls above by default -- confirm against the installed PEFT version.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model = prepare_model_for_kbit_training(model)

In [None]:
def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Walks ``model.named_parameters()`` and reports the element count of the
    parameters with ``requires_grad`` set, the element count of all
    parameters, and the trainable share as a percentage.
    """
    sizes = [
        (param.numel(), param.requires_grad)
        for _, param in model.named_parameters()
    ]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    share = 100 * trainable_params / all_param
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {share}"
    )

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapter settings: rank 32, alpha 64, 5% dropout, no bias training,
# applied to the attention and MLP projections plus the LM head.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
)

# Wrap the base model with the adapters and report the trainable share.
model = get_peft_model(model, config)
print_trainable_parameters(model)

## Run Training

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datetime import datetime

# Names identifying this fine-tuning run and its checkpoint directory.
project = "topic-labeling"
base_model_name = "mistral-7b"
# str.join accepts exactly one iterable argument, so the parts are passed
# as a list; passing them as separate arguments raises TypeError.
run_name = '-'.join([base_model_name, project])
output_dir = "./" + run_name

# Trainer setup: 500 optimizer steps at batch size 2, with logging,
# checkpointing and evaluation every 25 steps.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version.
trainer = Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=TrainingArguments(
        output_dir=output_dir,
        warmup_steps=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        max_steps=500,
        learning_rate=2.5e-5,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_steps=25,
        logging_dir="./logs",
        save_strategy="steps",
        save_steps=25,
        evaluation_strategy="steps",
        eval_steps=25,
        do_eval=True,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Turn off the generation KV cache on the model config before training starts.
model.config.use_cache = False
trainer.train()