In [None]:
from datasets import load_dataset
from langchain_ollama import ChatOllama
import re

In [None]:
# Load the processed test set from a JSON Lines file and shuffle it with a
# fixed seed so the example order is the same on every run.
# NOTE(review): split="train" is requested although the file holds test data —
# presumably the "json" loader exposes a bare `data_files` argument under that
# split name; confirm against the `datasets` documentation.
test_dataset = load_dataset("json", data_files="../data/processed_test.jsonl", split="train").shuffle(seed=512)

In [None]:
def extract_score(generated_text):
    """
    Extract the first digit found in the generated text.

    Args:
        generated_text: Text produced by the model.

    Returns:
        The first decimal digit in ``generated_text`` as an int, or -1 if
        the text contains no digit (parsing error).
    """
    # Match exactly one digit. The previous pattern wrapped r"\d+" inside a
    # character class, producing "[\d+]", which also matched a literal "+";
    # int("+") then raised ValueError instead of reporting a parsing error.
    match = re.search(r"\d", generated_text)
    if match is None:
        return -1  # Parsing error
    return int(match.group(0))

In [None]:
# Print the first example of the shuffled test set, one message at a time:
# the role on its own line, then the content followed by a blank line.
example = test_dataset[0]
for message in example["messages"]:
    print(f"{message['role']}:")
    print(f"{message['content']}\n")

In [None]:
# Chat model accessed through the ChatOllama wrapper.
# NOTE(review): temperature=1 presumably means sampled (non-greedy) output, so
# predictions may differ between runs — confirm this is intended for evaluation.
llm = ChatOllama(
    model="hf.co/hugo-haldi/mistral-7b-dqi-justification:BF16", temperature=1
)

# Spot-check the first few examples of the shuffled test set.
n_examples = 5
for example in test_dataset.select(range(n_examples)):
    # Every message except the last forms the prompt, as (role, content) pairs.
    query = [
        (message["role"], message["content"]) for message in example["messages"][:-1]
    ]
    # The content of the last message is used as the reference answer (raw text).
    y_true = example["messages"][-1]["content"]
    llm_response = llm.invoke(query)
    # Reduce the model output to a single integer (-1 when nothing is parsed).
    y_pred = extract_score(llm_response.content)
    print(f"Ground truth: {y_true} | LLM prediction: {y_pred}")