In [None]:
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast


In [None]:
def evaluate_llm_classifier(
    model_path: str,
    test_data: pd.DataFrame,
    tokenizer: PreTrainedTokenizerFast,
    max_length: int = 512,
) -> Dict[str, Any]:
    """Evaluate a binary sequence-classification model on a labelled test set.

    Loads the model from ``model_path`` in bfloat16, scores every row of
    ``test_data`` one prompt at a time, plots a 2x2 confusion matrix and
    returns the metrics.

    Side effect: ``tokenizer`` is configured in place (its pad token and chat
    template are overwritten), so the caller's object is modified.

    Args:
        model_path: Path or hub id passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        test_data: DataFrame with ``system`` and ``user`` text columns and an
            ``assistant`` column holding the gold label ("True"/"False" in any
            case, or booleans as produced by ``pd.read_csv``).
        tokenizer: Tokenizer matching the model.
        max_length: Prompts longer than this many tokens are truncated.

    Returns:
        Dict with float ``accuracy``, ``precision``, ``recall`` and
        ``f1_score`` plus the per-row ``true_labels`` and ``predicted_labels``
        (lists of bool, in ``test_data`` row order).

    Raises:
        ValueError: If ``test_data`` has no rows.
    """
    # Check before the (expensive) model load: metrics are undefined with no rows.
    if test_data.empty:
        raise ValueError("test_data is empty; nothing to evaluate")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        num_labels=2,
    )

    # Reuse EOS as the pad token and tell the model config its id
    # (presumably the tokenizer has no pad token of its own — TODO confirm).
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

    # Render each message as <|start_header_id|>role<|end_header_id|>content<|eot_id|>.
    tokenizer.chat_template = """<|begin_of_text|>{% for message in messages %}<|start_header_id|>{{ message['role'] }}<|end_header_id|>{{ message['content'] }}<|eot_id|>{% endfor %}"""

    true_labels: list[bool] = []
    predicted_labels: list[bool] = []

    for _, row in test_data.iterrows():
        messages = [
            {"role": "system", "content": row["system"]},
            {"role": "user", "content": row["user"]},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False)

        # No padding: each forward pass holds a single prompt, so padding to
        # max_length only burns compute on pad tokens.
        # NOTE(review): the template already emits <|begin_of_text|> and this
        # call keeps the default add_special_tokens=True, which may prepend a
        # second BOS. Left unchanged so inference matches training-time
        # tokenization — confirm against the training code.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(model.device)

        with torch.no_grad():
            logits = model(**inputs).logits

        # argmax over raw logits is identical to argmax over their softmax.
        predicted_labels.append(bool(logits.argmax(dim=-1)[0]))
        # pd.read_csv turns a pure True/False column into bool, which has no
        # .lower(); going through str() handles both bools and strings.
        true_labels.append(str(row["assistant"]).strip().lower() == "true")

    metrics = {
        "accuracy": accuracy_score(true_labels, predicted_labels),
        "precision": precision_score(true_labels, predicted_labels),
        "recall": recall_score(true_labels, predicted_labels),
        "f1_score": f1_score(true_labels, predicted_labels),
        "true_labels": true_labels,
        "predicted_labels": predicted_labels,
    }

    # Fix the label order so the matrix is always 2x2, even if one class is absent.
    class_labels = [False, True]
    evaluation_confusion_matrix = confusion_matrix(
        true_labels, predicted_labels, labels=class_labels
    )
    tick_names = [str(label) for label in class_labels]
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        evaluation_confusion_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=tick_names,
        yticklabels=tick_names,
        ax=ax,
    )
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    plt.show()

    return metrics

In [None]:
test_df = pd.read_csv("../processed_data/test.csv")

# The evaluation reads the "system", "user" and "assistant" columns of every row;
# fail here with a clear message rather than with a KeyError mid-evaluation.
missing_columns = {"system", "user", "assistant"} - set(test_df.columns)
assert not missing_columns, f"test.csv is missing columns: {sorted(missing_columns)}"

In [None]:
# Local directory of the base Llama checkpoint; only its tokenizer is loaded from here.
base_model_path: str = "../models/llama_models/llama-3.2-3B"

In [None]:
# Llama is built into transformers, so no custom code needs to run; leaving
# trust_remote_code off avoids executing arbitrary Python shipped in the model
# directory and matches how the model itself is loaded in the evaluation.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

In [None]:
# TODO: point this at the directory the fine-tuned classifier was saved to.
finetuned_model_path = "FILL_THIS_IN_AFTER_SAVING"
assert finetuned_model_path != "FILL_THIS_IN_AFTER_SAVING", (
    "Set finetuned_model_path to the saved fine-tuned model before running this cell"
)

evaluation_metrics = evaluate_llm_classifier(
    model_path=finetuned_model_path,
    test_data=test_df,
    tokenizer=tokenizer,
)

# Display only the scalar scores; the per-row label lists are too long to show.
{name: value for name, value in evaluation_metrics.items() if not isinstance(value, list)}