In [None]:
# Project root relative to this notebook; variables.env is read from this folder.
BASE_PATH = ".."
# Model identifier. NOTE(review): not referenced by any cell visible here — confirm it is still needed.
MODEL_ID = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
# Presumably the cap on generated tokens — only used below to derive MAX_SEQ_LENGTH; confirm.
MAX_NEW_TOKENS = 8192
# What is left of a 32768-token total after setting aside MAX_NEW_TOKENS (= 24576).
MAX_SEQ_LENGTH = 32768 - MAX_NEW_TOKENS

In [None]:
import os
import json

import torch  # type: ignore
import numpy as np  # type: ignore

from datasets import DatasetDict, Dataset  # type: ignore

from unsloth import FastLanguageModel  # type: ignore

from tqdm.auto import tqdm  # type: ignore

from datasets import Dataset, DatasetDict  # type: ignore

from groq import Groq  # type: ignore
from dotenv import load_dotenv # type: ignore

# Compare with Groq

In [None]:
# Prompt templates for the judge model, keyed by task type.
# Placeholders: "compare_predictions" expects {pred} and {actual};
# "compare_answers" expects {answer1} and {answer2}.
# Doubled braces ({{ }}) are literal braces once str.format() has been applied.
PROMPTS = {
    "compare_predictions": """Compare the following prediction with the actual answer:

Prediction: {pred}
Actual Answer: {actual}

Evaluate the prediction's accuracy and provide a brief explanation. 
Rate the prediction on a scale of 1-5, where 1 is completely incorrect and 5 is perfectly accurate.

Response format:
{{
    "rating": [1-5],
    "explanation": [Your explanation here]
}}
""",
    "compare_answers": """Compare the following two answers:

Answer 1: {answer1}
Answer 2: {answer2}

Evaluate the accuracy and provide a brief explanation. 
Rate the answer on a scale of 1-5, where 1 is completely incorrect and 5 is perfectly accurate.

Response format:
{{
    "rating": [1-5],
    "explanation": [Your explanation here]
}}
""",
}

In [None]:
# Read the project's variables.env file so its entries become environment variables
load_dotenv(f"{BASE_PATH}/variables.env")

# Fetch the Groq credential from the environment (None when it is not set)
groq_api_key = os.getenv("GROQ_API_KEY")

# Report whether a non-empty key was found
print(
    "GROQ API key loaded successfully."
    if groq_api_key
    else "Failed to load GROQ API key."
)

In [None]:
# Groq API client shared by the comparison code below; built with the key read from the environment.
client = Groq(api_key=groq_api_key)

In [None]:
def load_data(file_path):
    """Read a JSON file and return its parsed content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` (list, dict, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding, which would mis-decode non-ASCII text on e.g. Windows.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def to_dataset(data):
    """Turn a list of Q&A records into a column-oriented ``Dataset``.

    Each record must provide ``"question"``, a ``"citation"`` iterable whose
    items carry a ``"summary"``, and either ``"text"`` or ``"answer"``
    (``"text"`` wins when both are present).

    Args:
        data: Iterable of record dicts as described above.

    Returns:
        A ``Dataset`` with the columns ``question``, ``resources`` (the
        citation summaries joined with newlines) and ``answer``.
    """
    questions, resources, answers = [], [], []

    for record in data:
        questions.append(record["question"])
        # Records may store the answer body under "text" or under "answer"
        answers.append(record["text"] if "text" in record else record["answer"])
        resources.append("\n".join(citation["summary"] for citation in record["citation"]))

    return Dataset.from_dict(
        {
            "question": questions,
            "resources": resources,
            "answer": answers,
        }
    )


def prepare_dataset(base_path=None):
    """Load the test sets and the stored predictions as ``Dataset`` objects.

    Args:
        base_path: Folder that contains the ``data`` directory. When omitted
            (``None``) the notebook-level ``BASE_PATH`` is used; previously
            ``None`` was interpolated into the path as the literal "None".

    Returns:
        Dict with the keys ``"cars"``, ``"sleep"``, ``"cars_predictions"`` and
        ``"sleep_predictions"``, each mapped to the dataset built by
        ``to_dataset`` from the corresponding JSON file.
    """
    if base_path is None:
        base_path = BASE_PATH

    # Result key -> file name inside <base_path>/data
    file_names = {
        "cars": "test_qa_car.json",
        "sleep": "test_qa_sleep.json",
        "cars_predictions": "cars_predictions.json",
        "sleep_predictions": "sleep_predictions.json",
    }

    return {
        key: to_dataset(load_data(f"{base_path}/data/{file_name}"))
        for key, file_name in file_names.items()
    }

In [None]:
# Three comparison modes are supported:
# 1 - the prompt states which response is the ground truth
#     (task_type="compare_predictions")
# 2 - the prompt only asks to compare two answers, without naming the ground truth
#     (task_type="compare_answers")
# 3 - same as 2, but the caller passes predictions and ground truth in swapped order
def compare_predictions(
    predictions,
    actual_answers,
    task_type="compare_predictions",
    model="mixtral-8x7b-32768",
    temperature=0.5,
    max_tokens=256,
):
    """Ask the Groq chat model to rate each prediction against its reference.

    Args:
        predictions: Iterable of predicted answers.
        actual_answers: Iterable of reference answers, paired positionally
            with ``predictions`` (pairing stops at the shorter iterable).
        task_type: Key into ``PROMPTS`` selecting the prompt template.
        model: Model name passed to the chat-completions call.
            NOTE(review): confirm this default model id is still served by Groq.
        temperature: Sampling temperature passed to the chat-completions call.
        max_tokens: Response token limit passed to the chat-completions call.

    Returns:
        List with the raw message content of the first choice of every
        response, in input order. The content is not parsed.

    Raises:
        KeyError: If ``task_type`` is not a key of ``PROMPTS``.
    """
    template = PROMPTS[task_type]

    judgements = []
    for pred, actual in zip(predictions, actual_answers):
        # The templates use different placeholder names ({pred}/{actual} vs
        # {answer1}/{answer2}); supply both sets so every task type formats
        # without a KeyError. str.format ignores the unused keywords.
        prompt = template.format(pred=pred, actual=actual, answer1=pred, answer2=actual)
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        judgements.append(response.choices[0].message.content)

    return judgements

In [None]:
# NOTE(review): `results` is not defined in any cell visible in this notebook. It must be
# created before this cell runs and be indexable by "predictions" and "answers" —
# confirm where it comes from, otherwise Restart & Run All fails here.
comparison_results = compare_predictions(results["predictions"], results["answers"])

# Show every judge response under a numbered heading, followed by a separator line
for number, result in enumerate(comparison_results, start=1):
    print(f"Comparison {number}:")
    print(result)
    print("-" * 50)