In [3]:
!pip install -q transformers datasets scikit-learn accelerate torch pandas bitsandbytes peft

import random
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import gc
import os
import shutil

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    BitsAndBytesConfig,
    set_seed,
    DataCollatorWithPadding
)
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    TaskType
)
from datasets import load_dataset
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight


# Run-wide configuration.
SEED = 42
OUTPUT_DIR = "./llama3_task2_evasion"
MAX_LEN = 512  # token budget per example handed to the tokenizer

# Start from an empty output directory so files left by an earlier run
# cannot be mixed in with this run's output.
if os.path.exists(OUTPUT_DIR):
    shutil.rmtree(OUTPUT_DIR)

# transformers.set_seed seeds Python's `random`, NumPy and PyTorch (CPU and
# every CUDA device) in one call, so seeding each library by hand beforehand
# is redundant and has been dropped.
set_seed(SEED)

dataset = load_dataset("ailsntua/QEvasion")

def preprocess(example):
    """Build the single text input the classifier is trained on.

    The clarity label (or ``"Unknown"`` when it is missing or empty), the
    question and the interview answer are joined into one prompt string.

    Args:
        example: Mapping with ``"question"``, ``"interview_answer"`` and
            ``"evasion_label"`` keys, and optionally ``"clarity_label"``.

    Returns:
        A dict with the assembled ``"text"`` and the untouched
        ``"evasion_label"``.

    Raises:
        KeyError: If a required key is absent from ``example``.
    """
    clarity = example.get("clarity_label", "Unknown")
    if not clarity:
        # None and "" are treated the same as a missing label.
        clarity = "Unknown"
    question = example["question"]
    answer = example["interview_answer"]
    return {
        "text": f"Context: {clarity} | Question: {question} Answer: {answer}",
        "evasion_label": example["evasion_label"],
    }

# Add the "text" column, then encode the string labels as class ids so the
# splits below can stratify on them.
full_data = dataset["train"].map(preprocess).class_encode_column("evasion_label")

# Both splits hold out 10% and keep the label distribution, using the same seed.
_STRATIFIED_SPLIT = {
    "test_size": 0.1,
    "seed": SEED,
    "stratify_by_column": "evasion_label",
}

# First split: carve the held-out test set off the full data.
split1 = full_data.train_test_split(**_STRATIFIED_SPLIT)
train_dev_ds, held_out_test_ds = split1["train"], split1["test"]

# Second split: carve the validation set off what remains.
split2 = train_dev_ds.train_test_split(**_STRATIFIED_SPLIT)
train_ds, eval_ds = split2["train"], split2["test"]

# Label vocabulary and the two lookup tables passed to the model config.
labels = train_ds.features["evasion_label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# "balanced" class weights computed from the training labels only.
y_train = train_ds["evasion_label"]
present_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=present_classes,
    y=y_train,
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float, device=device)


# Hugging Face Hub id of the checkpoint; the tokenizer is loaded from it here.
model_id = "unsloth/llama-3-8b-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token (presumably because
# this tokenizer ships without a dedicated pad token -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
# Padding is appended after the real tokens rather than before them.
tokenizer.padding_side = "right"

def tokenize_fn(batch):
    """Tokenize the ``"text"`` column of a batch into fixed-length inputs.

    Sequences longer than ``MAX_LEN`` tokens are truncated and shorter ones
    are padded, so every example comes back with exactly ``MAX_LEN`` tokens.

    Args:
        batch: Mapping whose ``"text"`` entry holds the strings to encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those strings.
    """
    texts = batch["text"]
    encoded = tokenizer(
        texts,
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",
    )
    return encoded

def _tokenize_and_relabel(ds):
    """Tokenize a split in batches and rename its label column to ``labels``."""
    tokenized = ds.map(tokenize_fn, batched=True)
    return tokenized.rename_column("evasion_label", "labels")


# The three splits get exactly the same treatment.
train_ds = _tokenize_and_relabel(train_ds)
eval_ds = _tokenize_and_relabel(eval_ds)
held_out_test_ds = _tokenize_and_relabel(held_out_test_ds)


# 4-bit NF4 quantization settings; computation is carried out in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the checkpoint with a sequence-classification head sized to the label
# set; the label maps are stored in the model config.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=num_labels,
    quantization_config=bnb_config,
    id2label=id2label,
    label2id=label2id,
    # Device placement is left to the library.
    device_map="auto"
)
# Keep the model's padding id in sync with the tokenizer's.
model.config.pad_token_id = tokenizer.pad_token_id
# NOTE(review): presumably disabled because the generation cache is of no use
# while fine-tuning -- confirm.
model.config.use_cache = False
# NOTE(review): presumably selects the non-tensor-parallel code path -- confirm.
model.config.pretraining_tp = 1

# Prepare the quantized model for training before attaching adapters.
model = prepare_model_for_kbit_training(model)

# LoRA adapters are attached to the attention projections and to the MLP
# projections; the combined list keeps the original ordering.
_ATTENTION_PROJECTIONS = ["q_proj", "v_proj", "k_proj", "o_proj"]
_MLP_PROJECTIONS = ["gate_proj", "up_proj", "down_proj"]

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,  # adapter rank
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=_ATTENTION_PROJECTIONS + _MLP_PROJECTIONS,
)

model = get_peft_model(model, peft_config)
# Report how many parameters remain trainable after wrapping.
model.print_trainable_parameters()


def compute_metrics(eval_pred):
    """Score predictions with accuracy and macro-averaged F1.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``; the
            predicted class is the argmax over the last axis of the logits.

    Returns:
        A dict with ``"accuracy"`` and ``"macro_f1"``.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    accuracy = (predicted == gold).mean()
    macro_f1 = f1_score(gold, predicted, average="macro")
    return {"accuracy": accuracy, "macro_f1": macro_f1}

class WeightedPeftTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    The per-class weights are read from the module-level
    ``class_weights_tensor``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model to run the forward pass on.
            inputs: Batch mapping; must contain ``"labels"`` alongside the
                model inputs. It is not modified.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Accepted to match the ``Trainer`` signature;
                not used.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is true. ``outputs`` comes from a forward pass without labels, so
            it carries no model-computed loss.
        """
        labels = inputs.get("labels")
        # Forward without the labels: otherwise the model also computes its
        # own unweighted loss, which would only be thrown away.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")
        # The weights were placed on a device chosen at module level; move
        # them (and the labels) to wherever the logits actually are so the
        # loss does not depend on how the model was dispatched.
        weights = class_weights_tensor.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weights)
        loss = loss_fct(
            logits.view(-1, self.model.config.num_labels),
            labels.view(-1).to(logits.device),
        )
        return (loss, outputs) if return_outputs else loss

# Hyperparameters and checkpointing policy for the fine-tuning run.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=2e-4,
    # 8 examples per device x 2 accumulation steps = 16 examples per
    # optimizer update on each device.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    # Upper bound on the number of epochs.
    num_train_epochs=10,
    bf16=True,
    logging_steps=10,
    # Evaluate and checkpoint once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    # Reload the checkpoint with the best macro F1 once training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",
    # Do not report to any external experiment tracker.
    report_to="none",
    seed=SEED
)

# Early stopping is configured with a patience of 3.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = WeightedPeftTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()

# Final score on the held-out test split; the metric keys read below carry
# the "eval_" prefix.
test_results = trainer.evaluate(eval_dataset=held_out_test_ds)

print(f"Macro F1: {test_results['eval_macro_f1']:.4f}")
print(f"Accuracy: {test_results['eval_accuracy']:.4f}")

Map:   0%|          | 0/3448 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/3448 [00:00<?, ? examples/s]

Map:   0%|          | 0/2792 [00:00<?, ? examples/s]

Map:   0%|          | 0/311 [00:00<?, ? examples/s]

Map:   0%|          | 0/345 [00:00<?, ? examples/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at unsloth/llama-3-8b-bnb-4bit and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 41,979,904 || all params: 7,546,941,440 || trainable%: 0.5563


  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,3.4215,1.642021,0.498392,0.330015
2,1.9234,1.343947,0.530547,0.438143
3,2.1649,1.229609,0.536977,0.498367
4,1.1698,1.793938,0.562701,0.475127
5,1.2488,1.550135,0.633441,0.561801
6,0.1085,2.312504,0.617363,0.530448
7,0.2275,3.958944,0.588424,0.486411
8,0.2049,3.291965,0.614148,0.542063


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


Macro F1: 0.5401
Accuracy: 0.6029
