In [10]:
import os.path

from datasets import load_dataset, load_from_disk
from dotenv import load_dotenv

load_dotenv("../.env")
# Load MMLU Dataset
dataset_name = "LFrancis/MMLU-NoOp-Plus"
baseline_dataset_name = "cais/mmlu"
subset = "all_addition"
dataset = load_dataset(dataset_name, subset)["train"]

# VLLM API Configuration
BASE_URL = "http://134.76.18.30:8080/v1/chat/completions"
HEADERS = {"Content-Type": "application/json", "Authorization": "Bearer " + os.getenv("VLLM_API_KEY")}
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EVALUATED_MODEL_PATH = dataset_name + "_" + subset + "_evaluated_" + MODEL_NAME
BASELINE_MODEL_PATH = baseline_dataset_name + "_evaluated_" + MODEL_NAME
if not os.path.exists(EVALUATED_MODEL_PATH):
    dataset.save_to_disk(EVALUATED_MODEL_PATH)
#dataset.save_to_disk(EVALUATED_MODEL_PATH)
options = ["A", "B", "C", "D"]

train-00000-of-00001.parquet:   0%|          | 0.00/4.73M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14042 [00:00<?, ? examples/s]

In [11]:
from datasets import Dataset
import requests


# Helper Functions
def create_chat_messages(question, choices, subject, sys_msg):
    """
    Create a formatted list of chat messages for the chat model.
    """
    user_prompt = f"{question}\n" + "\n".join(
        [f"{opt}. {choice}" for opt, choice in zip(options, choices)]
    ) + "\nAnswer:"
    return [
        {"role": "system", "content": sys_msg.format(subject)},
        {"role": "user", "content": user_prompt}
    ]


def query_vllm_api(payload):
    """
    Send a query to the VLLM API and return the response.
    """
    response = requests.post(BASE_URL, json=payload, headers=HEADERS, timeout=120)
    response.raise_for_status()  # Raise an error for HTTP issues
    return response.json()


def evaluate_question(entry):
    # Step 1: Generate reasoning (CoT) response
    sys_msg = "The following are multiple choice questions (with answers) about {}."
    question, choices, subject = entry["question"], entry["choices"], entry["subject"]

    messages = create_chat_messages(question, choices, subject, sys_msg)

    cot_payload = {
        "model": MODEL_NAME,  # Specify model
        "messages": messages,
        "max_tokens": 200,
        "temperature": 0.0,
    }

    cot_response = query_vllm_api(cot_payload)
    if "object" in cot_response.keys() and cot_response["object"] == "error":
        raise Exception(cot_response["message"])

    cot_text = cot_response["choices"][0]["message"]["content"].strip()  # Extract CoT reasoning

    # Step 2: Calculate logprobs for each choice
    final_prompt = f"{cot_text}\n" + "\n".join(
        [f"{opt}. {choice}" for opt, choice in zip(options, choices)]
    ) + "Final Answer: "
    logprobs = {}
    for idx, option in enumerate(options):
        choice_messages = [
            *messages,
            {"role": "system", "content": sys_msg.format(subject)},
            {"role": "user", "content": final_prompt + f" {option}"}
        ]
        choice_payload = {
            "model": MODEL_NAME,
            "messages": choice_messages,
            "max_tokens": 1,
            "temperature": 0.0,
            "prompt_logprobs": 0
        }
        choice_response = query_vllm_api(choice_payload)
        if "prompt_logprobs" not in choice_response:
            raise Exception(f"No prompt logprobs found for {option}")
        logprobs[option] = list(choice_response["prompt_logprobs"][-6].values())[0]["logprob"]
    entry["logprobs"] = logprobs
    return entry


def is_correct(entry):
    """
    Determines if the choice with the lowest log probability corresponds to the correct answer.

    Args:
        entry (dict): A dictionary containing the question, choices, answer index, and logprobs.

    Returns:
        bool: True if the option with the lowest logprob matches the correct answer index, False otherwise.
    """
    # Extract logprobs and the correct answer index
    logprobs = entry['logprobs']
    if logprobs == {} or logprobs == {'A': None, 'B': None, 'C': None, 'D': None}:
        print("skip ", end="")
        return False
    correct_answer_index = entry['answer']

    # Find the key (A, B, C, D) with the lowest logprob
    highest_logprob_option = max(logprobs, key=logprobs.get)

    # Map the key to its corresponding index (0 for 'A', 1 for 'B', etc.)
    options = ["A", "B", "C", "D"]
    highest_logprob_index = options.index(highest_logprob_option)

    # Check if the lowest logprob index matches the correct answer index
    return highest_logprob_index == correct_answer_index


def process_dataset(dataset: Dataset, numproc=1):
    """
    Process the dataset using Dataset.map.
    """

    def process_entry(entry):
        if "logprobs" in entry.keys():
            # entry has been touched
            if  entry["logprobs"] != {} and entry["logprobs"] != {'A': None, 'B': None, 'C': None, 'D': None}:
                return entry
        try:
            return evaluate_question(entry)
        except Exception as e:
            print(f"Error processing entry: {entry}, Exception: {e}")
            entry["logprobs"] = {}
            return entry

    return dataset.map(process_entry, with_indices=False, num_proc=numproc)

In [12]:
def update_dataset(dataset, is_baseline=False):
    # Save the updated dataset to a temporary location
    temp_path = "temp"
    dataset.save_to_disk(temp_path)

    # Overwrite the original dataset directory
    import shutil
    original_path = EVALUATED_MODEL_PATH if is_baseline == False else BASELINE_MODEL_PATH

    # Remove the old dataset and replace it with the new one
    shutil.rmtree(original_path)  # Remove the old dataset directory
    shutil.move(temp_path, original_path)

In [13]:
def main(is_continue=False, is_baseline=False, numproc=1):
    """
    Main function to evaluate the dataset asynchronously.
    """
    if is_baseline:
        selected_dataset = load_from_disk(BASELINE_MODEL_PATH)
    elif is_continue:
        selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
    else:
        selected_dataset = dataset
    # Process the dataset asynchronously
    processed_dataset = process_dataset(selected_dataset, numproc)

    # Save the updated dataset
    update_dataset(processed_dataset, is_baseline)


# Run the script
main(True, False, 1)

Map:   0%|          | 0/14042 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/14042 [00:00<?, ? examples/s]

In [9]:
from converter.converter import save_value_to_json

# Save or process results as needed
for s in ["addition", "lexicon", "syntax", "", "naive", "typo", "scramble"]:
    if s == "":
        EVALUATED_MODEL_PATH = dataset_name + "_evaluated_" + MODEL_NAME
        s = "our_baseline"
    else:
        EVALUATED_MODEL_PATH = dataset_name + "_all_" + s + "_evaluated_" + MODEL_NAME
    if os.path.exists(EVALUATED_MODEL_PATH):
        selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
        print(selected_dataset)
        score = [is_correct(result) for result in selected_dataset]
        score = sum(score) / len(score)
        save_value_to_json(s, score, MODEL_NAME)
        print(s, "accuracy", score)

Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip addition accuracy 0.6734083463894032
Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip skip skip skip skip skip skip skip skip lexicon accuracy 0.6826662868537245
Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip syntax accuracy 0.6845890898732374
Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip skip skip our_baseline accuracy 0.7028913260219342
Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip skip naive accuracy 0.6786782509614016
Dataset({
    features: ['question', 'subject', 'choices', 'answer', 'logprobs'],
    num_rows: 14042
})
skip skip skip typo accuracy 0.6442814413901153
Dataset({
    features: ['question', 'subject', 'choices', 

In [31]:

baseline_dataset = load_from_disk(BASELINE_MODEL_PATH)
score = [is_correct(result) for result in baseline_dataset]
score = sum(score) / len(score)
save_value_to_json("baseline", score, MODEL_NAME)
print("Baseline Accuracy", score)

Baseline Accuracy 0.7026776812419884
