# Fine-tuning Gemma 2 for Stance Analysis

This notebook fine-tunes the Gemma 2 (9B) model to detect the stance expressed in COVID-19-era comments about mask wearing. It uses parameter-efficient fine-tuning: LoRA adapters trained on top of a 4-bit quantized base model.

## 1. Setup


### 1.1 Install Required Packages


In [None]:
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes
!pip uninstall unsloth -y
!pip install --upgrade --force-reinstall --no-cache-dir git+https://github.com/unslothai/unsloth.git

### 1.2 Import Required Libraries

In [None]:
import os

# CUDA reads CUDA_VISIBLE_DEVICES only when it is first initialised, and GPU libraries may
# initialise CUDA at import time — so this must be set BEFORE importing torch/unsloth.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import re

import datasets  # module namespace (e.g. datasets.concatenate_datasets, datasets.Dataset)
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from unsloth import FastLanguageModel

# Release any cached GPU memory left over from a previous run in this kernel.
torch.cuda.empty_cache()

### 1.3 Log In to the Hugging Face Hub

In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (needed for gated/private models and datasets).
# Non-interactive runs can supply a token through the HF_TOKEN environment variable; when it is
# unset, `token=None` falls back to the original interactive login prompt.
login(token=os.environ.get("HF_TOKEN"))

## 2. Model and Dataset Preparation

### 2.1 Load Pretrained Model

In [None]:
# Base-model settings. `max_seq_length` is reused later by the SFTTrainer configuration.
max_seq_length = 1600  # maximum number of tokens per training example
dtype = None           # None: defer the compute-dtype choice to Unsloth (NOTE(review): confirm auto-detect)
load_in_4bit = True    # load the base weights 4-bit quantized

# Download Gemma 2 9B together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    **{
        "model_name": "unsloth/gemma-2-9b",
        "max_seq_length": max_seq_length,
        "dtype": dtype,
        "load_in_4bit": load_in_4bit,
    }
)

### 2.2 Configure LoRA for Fine-tuning

In [None]:
# LoRA adapter configuration: adapters on the q/k/v/o and gate/up/down projection modules.
lora_settings = dict(
    r=32,
    lora_alpha=32,
    lora_dropout=0.4,  # NOTE(review): unusually high for LoRA; Unsloth documents 0 as its optimised setting
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Wrap the base model so that only the LoRA adapter weights are trained.
model = FastLanguageModel.get_peft_model(model, **lora_settings)

### 2.3 Load and Format Datasets

In [None]:
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Appended to every example so the model is trained to emit an end-of-sequence after its answer.
EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of examples into Alpaca-style training text.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict with parallel lists under ``"instruction"``, ``"input"`` and
            ``"output"``.

    Returns:
        ``{"text": [...]}`` — one prompt string per example, each ending with ``EOS_TOKEN``.
    """
    # Loop names avoid shadowing the builtin `input`.
    return {
        "text": [
            alpaca_prompt.format(instruction, context, response) + EOS_TOKEN
            for instruction, context, response in zip(
                examples["instruction"], examples["input"], examples["output"]
            )
        ]
    }

import os

from datasets import concatenate_datasets

# Never hard-code Hub tokens in a notebook. Read one from the environment if provided; with
# `token=None` the Hub client falls back to the credentials saved by `login()`.
hf_token = os.environ.get("HF_TOKEN")

main_dataset = load_dataset("Supakrit65/llama3-clean-general-stance", split="train")
kiki_mask_dataset = load_dataset("Supakrit65/kiki-mask-stance-labelled-db", token=hf_token, split="train")
eval_dataset = load_dataset("Supakrit65/llama3-clean-general-stance", split="test")

# Add the rendered prompt as a "text" column on every split.
main_dataset = main_dataset.map(formatting_prompts_func, batched=True)
kiki_mask_dataset = kiki_mask_dataset.map(formatting_prompts_func, batched=True)

# Merge both training sources and shuffle deterministically.
dataset = concatenate_datasets([main_dataset, kiki_mask_dataset]).shuffle(seed=42)
eval_dataset = eval_dataset.map(formatting_prompts_func, batched=True)

### 2.4 Balance the Dataset

In [None]:
from collections import Counter


def undersample_to_minority(ds, label_column="output", seed=42):
    """Return a class-balanced copy of ``ds`` by undersampling every label to the minority size.

    For each label (in order of first appearance), the first ``m`` rows carrying that label are
    kept, where ``m`` is the size of the smallest class. The result is shuffled with ``seed``.

    Raises:
        ValueError: if ``ds`` has no rows.
    """
    labels = ds[label_column]
    counts = Counter(labels)
    if not counts:
        raise ValueError("Cannot balance an empty dataset.")
    minority_size = min(counts.values())

    # One pass over the label column instead of a full `filter` per class; selecting by index
    # also keeps the original column features instead of re-inferring them from Python dicts.
    per_class = {label: [] for label in counts}
    for idx, label in enumerate(labels):
        bucket = per_class[label]
        if len(bucket) < minority_size:
            bucket.append(idx)
    keep = [idx for label in counts for idx in per_class[label]]
    return ds.select(keep).shuffle(seed=seed)


initial_output_counts = Counter(dataset["output"])
print("Initial Class Distribution:")
for output_class, count in initial_output_counts.items():
    print(f"- {output_class}: {count} examples")

balanced_dataset = undersample_to_minority(dataset, "output", seed=42)

output_counts_balanced = Counter(balanced_dataset["output"])
print("\nClass Distribution in Balanced Dataset:")
for output_class, count in output_counts_balanced.items():
    print(f"- {output_class}: {count} examples")

## 3. Fine-tuning and Evaluation

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(pred, ignore_index=-100):
    """Weighted precision / recall / F1 (plus accuracy) for a Hugging Face ``EvalPrediction``.

    Args:
        pred: object exposing ``label_ids`` and ``predictions``. ``predictions`` may be:
            - a tuple/list, in which case its first element is used;
            - an array of predicted ids with the same shape as ``label_ids``;
            - scores/logits with one extra trailing axis, reduced with ``argmax``.
        ignore_index: label value dropped before scoring. Hugging Face uses -100 for
            masked/padded positions.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1`` (weighted averages).

    Raises:
        ValueError: if predictions and labels do not contain the same number of elements.

    NOTE(review): this function is not passed to the SFTTrainer in this notebook. For a causal
    LM, the logits at position t predict token t+1, so predictions/labels would also need a
    one-position shift before token-level scores are meaningful. Confirm before wiring it in.
    """
    preds = pred.predictions
    # Trainer returns a tuple when the model emits several outputs; the first holds the scores.
    if isinstance(preds, (tuple, list)):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(pred.label_ids)

    # Reduce per-class scores/logits to predicted ids.
    if preds.ndim == labels.ndim + 1:
        preds = preds.argmax(axis=-1)

    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    if labels.shape != preds.shape:
        raise ValueError(f"predictions ({preds.shape}) and labels ({labels.shape}) do not align")

    # Drop masked positions so they do not count as a spurious class.
    keep = labels != ignore_index
    labels, preds = labels[keep], preds[keep]

    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, average='weighted', zero_division=0),
        'recall': recall_score(labels, preds, average='weighted', zero_division=0),
        'f1': f1_score(labels, preds, average='weighted', zero_division=0),
    }

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback
from unsloth import is_bfloat16_supported

# Evaluate and checkpoint on the same cadence. `load_best_model_at_end` can only reload a saved
# checkpoint, so with the default save_steps (500) most of the 100-step evaluations could never
# be selected as "best".
EVAL_SAVE_STEPS = 100
use_bf16 = is_bfloat16_supported()

# NOTE(review): `compute_metrics` is intentionally not passed. For a causal LM it would make
# evaluation accumulate full-vocabulary logits for every eval example (very memory-hungry).
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=balanced_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        num_train_epochs=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,
        warmup_steps=100,
        max_steps=-1,
        learning_rate=1e-4,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=EVAL_SAVE_STEPS,
        save_strategy="steps",
        save_steps=EVAL_SAVE_STEPS,
        # Checkpointing more often: cap disk usage. The best checkpoint (by
        # metric_for_best_model) is always retained when load_best_model_at_end=True.
        save_total_limit=2,
        optim="adamw_8bit",
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="outputs-v3",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    ),
)

In [None]:
# Run fine-tuning. `trainer_stats` keeps the object returned by `trainer.train()`
# (presumably a transformers TrainOutput with step count, loss and runtime metrics).
trainer_stats = trainer.train()

In [None]:
# Evaluate the trainer's current model on `eval_dataset`. The trainer was configured with
# load_best_model_at_end=True, so this should be the best checkpoint by eval_loss.
# NOTE(review): that same split was used to pick the checkpoint, so this loss is optimistically
# biased. Use a separate held-out split for an unbiased final number.
final_eval_results = trainer.evaluate()
print("Final evaluation results:", final_eval_results)

In [None]:
def _to_gib(num_bytes):
    """Convert a byte count to GiB, rounded to 3 decimals."""
    return round(num_bytes / 1024 / 1024 / 1024, 3)


gpu_stats = torch.cuda.get_device_properties(0)
# max_memory_reserved() is the PEAK memory reserved by PyTorch's caching allocator since the
# program started (or the last peak reset). After training, that is the training peak, not a
# "start" value.
peak_gpu_memory = _to_gib(torch.cuda.max_memory_reserved())
start_gpu_memory = peak_gpu_memory  # original name kept for backward compatibility
max_memory = _to_gib(gpu_stats.total_memory)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{peak_gpu_memory} GB of memory reserved (peak, {round(100 * peak_gpu_memory / max_memory, 3)}% of total).")

In [None]:
# Persist the fine-tuned model and its tokenizer side by side in one directory.
# NOTE(review): "XXX" is a placeholder — replace it before running. The model was wrapped by
# get_peft_model, so this presumably writes only the LoRA adapter weights; verify.
save_dir = "XXX"
for artifact in (trainer.model, trainer.tokenizer):
    artifact.save_pretrained(save_dir)