In [22]:
import os.path

from datasets import load_dataset, load_from_disk
from dotenv import load_dotenv
from typing_extensions import override

load_dotenv("../.env")
# Load GSM8k Dataset
dataset_name = "LFrancis/GSM8k-NoOp-Plus"
baseline_dataset_name = "openai/gsm8k"
subset = "main_lexicon"
dataset = load_dataset(dataset_name, subset)["train"]

# VLLM API Configuration
BASE_URL = "http://134.76.18.30:8081/v1/chat/completions"
HEADERS = {"Content-Type": "application/json", "Authorization": "Bearer "+os.getenv("VLLM_API_KEY")}
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EVALUATED_MODEL_PATH = dataset_name+"_"+subset + "_evaluated_" + MODEL_NAME
BASELINE_MODEL_PATH = baseline_dataset_name + "_evaluated_" + MODEL_NAME
if not os.path.exists(EVALUATED_MODEL_PATH):
    dataset.save_to_disk(EVALUATED_MODEL_PATH)
if not os.path.exists(BASELINE_MODEL_PATH):
    baseline_dataset = load_dataset(baseline_dataset_name, subset)["test"]
    baseline_dataset.save_to_disk(BASELINE_MODEL_PATH)

Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]

In [23]:
from datasets import Dataset
import requests


# Helper Functions
def create_chat_messages(question, sys_msg):
    """
    Create a formatted list of chat messages for the chat model.
    """
    user_prompt = f"{question}\n" + "\nCalculations:"
    return [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_prompt}
    ]


def query_vllm_api(payload):
    """
    Send a query to the VLLM API and return the response.
    """
    response = requests.post(BASE_URL, json=payload, headers=HEADERS, timeout=120)
    response.raise_for_status()  # Raise an error for HTTP issues
    return response.json()


def evaluate_question(entry):
    # Step 1: Generate reasoning (CoT) response
    sys_msg = "The following are gradeschool math questions. Think step by step and answer with a single number when prompted for a final answer. The final answer cannot include anything other than a single number, please do not include any additional calculations, clarifications or units.\nRemember: Always answer in the format '#### [number]'."
    question = entry["question"]
    messages = create_chat_messages(question, sys_msg)

    cot_payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 200,
        "temperature": 0.0,
    }

    cot_response = query_vllm_api(cot_payload)
    if "object" in cot_response.keys() and cot_response["object"] == "error":
        raise Exception(cot_response["message"])

    cot_text = cot_response["choices"][0]["message"]["content"].strip()  # Extract CoT reasoning

    # Step 2: Calculate logprobs for each choice
    final_prompt = "(Answer with a single number, starting after ####. Do not include anything other than #### followed by the number.) Final Answer: "
    choice_messages = [
        *messages,
        {"role": "system", "content": cot_text},
        {"role": "user", "content": final_prompt}
    ]
    final_payload = {
            "model": MODEL_NAME,
            "messages": choice_messages,
            "max_tokens": 10,
            "temperature": 0.0,
            "stop": ["\n"],
        }
    final_response = query_vllm_api(final_payload)
    gen_answer = final_response["choices"][0]["message"]["content"].strip()
    gen_answer = extract_answer(gen_answer)
    entry["generated_answer"] = gen_answer
    entry["generated_cot"] = cot_text
    return entry


def extract_answer(gen_answer):
    answer = gen_answer.split("#### ")
    if len(answer) != 2:
        if answer[0].strip().isnumeric():
            # Recover answer even though answer is formatted wrong
            return answer[0].strip()
        raise Exception(f"{gen_answer} is not a valid answer.")
    return answer[1].strip()


def is_correct(entry):
    """
    Determines if the choice with the lowest log probability corresponds to the correct answer.

    Args:
        entry (dict): A dictionary containing the question, choices, answer index, and logprobs.

    Returns:
        bool: True if the option with the lowest logprob matches the correct answer index, False otherwise.
    """
    # Extract logprobs and the correct answer index
    gen_answer = entry['generated_answer']
    if gen_answer is None:
        return False
    answer = extract_answer(entry['answer'])

    return answer == gen_answer


def process_dataset(dataset: Dataset, numproc=1):
    """
    Process the dataset using Dataset.map.
    """

    def process_entry(entry):
        if "generated_answer" in entry.keys() and entry["generated_answer"] is not None:
            return entry
        try:
            return evaluate_question(entry)
        except Exception as e:
            print(f"Error processing entry: {entry}, Exception: {e}")
            entry["generated_answer"] = None
            entry["generated_cot"] = None
            return entry
    return dataset.map(process_entry, with_indices=False, num_proc=numproc)

In [24]:
def update_dataset(dataset, is_baseline = False):
    # Save the updated dataset to a temporary location
    temp_path = "temp"
    dataset.save_to_disk(temp_path)

    # Overwrite the original dataset directory
    import shutil
    original_path = EVALUATED_MODEL_PATH if is_baseline == False else BASELINE_MODEL_PATH

    # Remove the old dataset and replace it with the new one
    shutil.rmtree(original_path)  # Remove the old dataset directory
    shutil.move(temp_path, original_path)

In [26]:
def main(is_continue = False, is_baseline=False, numproc=1):
    """
    Main function to evaluate the dataset asynchronously.
    """
    if is_baseline:
        selected_dataset = load_from_disk(BASELINE_MODEL_PATH)
    elif is_continue:
        selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
    else:
        selected_dataset = dataset
    # Process the dataset asynchronously
    processed_dataset = process_dataset(selected_dataset, numproc)

    # Save the updated dataset
    update_dataset(processed_dataset, is_baseline)
# Run the script
main(True, False, 100)

Map (num_proc=100):   0%|          | 0/1319 [00:00<?, ? examples/s]

Error processing entry: {'question': 'Together, Sofie, Anne, and Fawn have 85 books. If Sofie has 25 than books than Anne, and Anne has 12 fewer books than Fawn does, how several books does Fawn have?', 'answer': 'Let x = the number of books Anne has.\nThen x + 25 + x + 12 + x = 85\n3x + 37 = 85\n3x = 48\nx = 48 / 3 = <<48/3=16>>16 books\nAnne has 12 fewer books than Fawn, so Fawn has 16 + 12 = 28 books.\n#### 28'}, Exception: 2F - 24 + F + 25 is not a valid answer.
Error processing entry: {'question': "The average age of Peter, Paul and Jean is 100 years combined. Find the age of Peter knowing that Paul is 10 years elder than John and that Peter’s age is equivalent to the sum of Paul and John's age.", 'answer': 'Let x be the age of John. Paul’s age is x + 10\nPeter’s age is x + (x + 10) = 2x + 10\nWe have (2x + 10) + (x + 1) + x = 100\nCombining like terms, we get 4x + 20 = 100\nSubtracting 20 from both sides and dividing by 4, we get x = 80/4 = 20 years old.\nJohn is 20 years old, Pa

Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]

In [27]:
import json


def save_value_to_json(label, value, file_path="results.json"):
    # Check if the JSON file exists
    if os.path.exists(file_path):
        # Read the existing JSON content
        with open(file_path, "r") as file:
            data = json.load(file)
    else:
        # Create a new dictionary if the file doesn't exist
        data = {"scores":{}, "model": MODEL_NAME}

    # Update the dictionary with the new key-value pair
    data["scores"][label] = value

    # Write the updated dictionary back to the JSON file
    with open(file_path, "w") as file:
        json.dump(data, file, indent=4)

In [28]:
for subset in [ "addition", "lexicon", "syntax", "","naive"]:
    if subset == "":
        EVALUATED_MODEL_PATH = dataset_name+"_main"+ "_evaluated_" + MODEL_NAME
        subset="our_baseline"
    else:
        EVALUATED_MODEL_PATH = dataset_name+"_main_"+subset + "_evaluated_" + MODEL_NAME
    if os.path.exists(EVALUATED_MODEL_PATH):
        selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
        print(selected_dataset)
        score = [is_correct(result) for result in selected_dataset]
        score = sum(score) / len(score)
        save_value_to_json(subset, score)
        print(subset,"accuracy", score)
    else:
        print("skipping", EVALUATED_MODEL_PATH)

Dataset({
    features: ['question', 'answer', 'generated_answer', 'generated_cot'],
    num_rows: 1319
})
addition accuracy 0.7733131159969674
Dataset({
    features: ['question', 'answer', 'generated_answer', 'generated_cot'],
    num_rows: 1319
})
lexicon accuracy 0.6671721000758151
Dataset({
    features: ['question', 'answer', 'generated_answer', 'generated_cot'],
    num_rows: 1319
})
syntax accuracy 0.7338893100833965
Dataset({
    features: ['question', 'answer', 'generated_answer', 'generated_cot'],
    num_rows: 1319
})
our_baseline accuracy 0.7884761182714177
Dataset({
    features: ['question', 'answer', 'generated_answer', 'generated_cot'],
    num_rows: 1319
})
naive accuracy 0.7740712661106899


In [11]:
baseline_dataset = load_from_disk(BASELINE_MODEL_PATH)
score = [is_correct(result) for result in baseline_dataset]
score = sum(score) / len(score)
save_value_to_json("baseline", score)
print("Baseline Accuracy", score)

Baseline Accuracy 0.7816527672479151


In [100]:
selected_dataset[0]["question"]

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Janet's love for environmental conservation is also evident in her extensive collection of exotic bird species in her backyard."

In [101]:
baseline_dataset[0]["question"]

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"