In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast
from typing import Dict, Tuple
from transformers import AutoModelForCausalLM


In [None]:
# Chat template used to render evaluation prompts.
# NOTE(review): assumed to match the template used during fine-tuning (this notebook
# previously hard-coded the same message layout); confirm against the training notebook.
# The `add_generation_prompt` branch opens an assistant turn, so the first tokens the
# model generates are the answer itself rather than a role header.
_EVAL_CHAT_TEMPLATE = (
    "<|begin_of_text|>"
    "{% for message in messages %}"
    "<|start_header_id|>{{ message['role'] }}<|end_header_id|>"
    "{{ message['content'] }}<|eot_id|>"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>{% endif %}"
)

# The only completions counted as well-formed answers.
_VALID_ANSWERS = frozenset({"true", "false"})


def _generate_answer(model, tokenizer: PreTrainedTokenizerFast, prompt: str, max_new_tokens: int) -> str:
    """Greedily generate a short completion for `prompt`; return only the new text, stripped and lower-cased."""
    # The template already emits <|begin_of_text|>; stop the tokenizer from prepending a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic evaluation, independent of the checkpoint's sampling defaults
            pad_token_id=pad_token_id,
        )
    # generate() returns prompt + continuation; decode only the continuation so that
    # words appearing in the prompt cannot leak into the prediction.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return completion.strip().lower()


def _compute_metrics(true_labels: list[bool], predicted_labels: list[bool], n_format_violations: int) -> Dict[str, float | int]:
    """Binary classification metrics (positive class = True) plus the format-violation count."""
    return {
        "accuracy": accuracy_score(true_labels, predicted_labels),
        # zero_division=0 avoids warnings/NaN when a class is never predicted or never present.
        "precision": precision_score(true_labels, predicted_labels, zero_division=0),
        "recall": recall_score(true_labels, predicted_labels, zero_division=0),
        "f1_score": f1_score(true_labels, predicted_labels, zero_division=0),
        "format_violations": n_format_violations,
    }


def _plot_confusion_matrix(true_labels: list[bool], predicted_labels: list[bool]) -> None:
    """Draw the false/true confusion matrix as an annotated heatmap."""
    # Fixing `labels` keeps the matrix 2x2 even when one class never occurs.
    matrix = confusion_matrix(true_labels, predicted_labels, labels=[False, True])
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["false", "true"],
        yticklabels=["false", "true"],
        ax=ax,
    )
    ax.set(title="Confusion Matrix", xlabel="Predicted label", ylabel="True label")
    plt.show()


def evaluate_llm_classifier(
    model_path: str,
    test_data: pd.DataFrame,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 5,
) -> Tuple[list[list[Dict[str, str]]], Dict[str, float | int]]:
    """Evaluate a causal LM used as a true/false classifier on a held-out test set.

    Each row's ``system`` and ``user`` texts are rendered with ``_EVAL_CHAT_TEMPLATE``.
    The model greedily generates up to ``max_new_tokens`` tokens, and a row is predicted
    True when the completion contains "true". The ground truth is whether the row's
    ``assistant`` value equals "true" (case-insensitive). The confusion matrix is shown
    once all rows are scored.

    The tokenizer passed in is not modified; the template is supplied per call.

    Args:
        model_path: Local directory or hub id of the model to load.
        test_data: DataFrame with ``system``, ``user`` and ``assistant`` columns.
        tokenizer: Tokenizer compatible with the model.
        max_new_tokens: Generation budget per example (default 5).

    Returns:
        ``(format_violations, metrics)``.
        ``format_violations`` holds the message lists whose completion was not exactly
        "true" or "false".
        ``metrics`` contains accuracy, precision, recall, f1_score and the
        format_violations count.

    Raises:
        ValueError: If required columns are missing or ``test_data`` has no rows.
    """
    missing = {"system", "user", "assistant"} - set(test_data.columns)
    if missing:
        raise ValueError(f"test_data is missing required columns: {sorted(missing)}")
    if test_data.empty:
        raise ValueError("test_data has no rows to evaluate")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    true_labels: list[bool] = []
    predicted_labels: list[bool] = []
    format_violations: list[list[Dict[str, str]]] = []

    for system, user, answer in zip(test_data["system"], test_data["user"], test_data["assistant"]):
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            chat_template=_EVAL_CHAT_TEMPLATE,
            add_generation_prompt=True,
            tokenize=False,
        )
        prediction = _generate_answer(model, tokenizer, prompt, max_new_tokens)

        if prediction not in _VALID_ANSWERS:
            format_violations.append(messages)

        predicted_labels.append("true" in prediction)
        # str() guards against pandas having parsed the column into bools (True -> "True").
        true_labels.append(str(answer).strip().lower() == "true")

    # Metrics and the plot summarize the whole test set, so they run after the loop.
    metrics = _compute_metrics(true_labels, predicted_labels, len(format_violations))
    _plot_confusion_matrix(true_labels, predicted_labels)

    return format_violations, metrics


In [None]:
# Keep the label column as text: by default pandas parses "true"/"false" into bools,
# which breaks string handling of the labels downstream.
test_df = pd.read_csv("../processed_data/test.csv", dtype={"assistant": str})
assert {"system", "user", "assistant"} <= set(test_df.columns), "test.csv is missing expected columns"

In [None]:
# Base checkpoint whose tokenizer is reused for evaluation (model family dir + checkpoint name).
base_model_path = (
    "../models/llama_models/"
    "llama-3.2-3B"
)

In [None]:
# Llama 3.2 is natively supported by transformers, so no repository-supplied Python is
# needed; omitting trust_remote_code avoids executing arbitrary code from the model dir.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

In [None]:
# TODO: set this to the directory the fine-tuned checkpoint was saved to.
finetuned_model_path = "FILL_THIS_IN_AFTER_SAVING"
if finetuned_model_path == "FILL_THIS_IN_AFTER_SAVING":
    raise ValueError("Set finetuned_model_path to the saved fine-tuned model directory before running this cell.")

format_violations, metrics = evaluate_llm_classifier(
    model_path=finetuned_model_path,
    test_data=test_df,
    tokenizer=tokenizer,
)
# Show only the summary; format_violations can be large and is kept for inspection.
metrics