In [12]:
import os.path

from datasets import load_dataset, load_from_disk
from dotenv import load_dotenv
from typing_extensions import override

load_dotenv("../.env")
# Load GSM8k Dataset
dataset_name = "LFrancis/GSM8k-NoOp-Plus"
baseline_dataset_name = "openai/gsm8k"
subset = "main_lexicon"
dataset = load_dataset(dataset_name, subset)["train"]

# VLLM API Configuration
BASE_URL = "http://134.76.18.30:8085/v1/chat/completions"
HEADERS = {"Content-Type": "application/json", "Authorization": "Bearer "+os.getenv("VLLM_API_KEY")}
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
EVALUATED_MODEL_PATH = dataset_name+"_"+subset + "_evaluated_" + MODEL_NAME
BASELINE_MODEL_PATH = baseline_dataset_name + "_evaluated_" + MODEL_NAME
if not os.path.exists(EVALUATED_MODEL_PATH):
    dataset.save_to_disk(EVALUATED_MODEL_PATH)
if not os.path.exists(BASELINE_MODEL_PATH):
    baseline_dataset = load_dataset(baseline_dataset_name, subset)["test"]
    baseline_dataset.save_to_disk(BASELINE_MODEL_PATH)

train-00000-of-00001.parquet:   0%|          | 0.00/420k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]

In [13]:
from datasets import Dataset
import requests


# Helper Functions
def create_chat_messages(question, sys_msg):
    """
    Create a formatted list of chat messages for the chat model.
    """
    user_prompt = f"{question}\n" + "\nCalculations:"
    return [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_prompt}
    ]


def query_vllm_api(payload):
    """
    Send a query to the VLLM API and return the response.
    """
    response = requests.post(BASE_URL, json=payload, headers=HEADERS, timeout=120)
    response.raise_for_status()  # Raise an error for HTTP issues
    return response.json()


def evaluate_question(entry):
    # Step 1: Generate reasoning (CoT) response
    sys_msg = "The following are gradeschool math questions. Think step by step and answer with a single number when prompted for a final answer. The final answer cannot include anything other than a single number, please do not include any additional calculations, clarifications or units.\nRemember: Always answer in the format '#### [number]'."
    question = entry["question"]
    messages = create_chat_messages(question, sys_msg)

    cot_payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 200,
        "temperature": 0.0,
    }

    cot_response = query_vllm_api(cot_payload)
    if "object" in cot_response.keys() and cot_response["object"] == "error":
        raise Exception(cot_response["message"])

    cot_text = cot_response["choices"][0]["message"]["content"].strip()  # Extract CoT reasoning

    # Step 2: Calculate logprobs for each choice
    final_prompt = "(Answer with a single number, starting after ####. Do not include anything other than #### followed by the number.) Final Answer: "
    choice_messages = [
        *messages,
        {"role": "system", "content": cot_text},
        {"role": "user", "content": final_prompt}
    ]
    final_payload = {
            "model": MODEL_NAME,
            "messages": choice_messages,
            "max_tokens": 10,
            "temperature": 0.0,
            "stop": ["\n"],
        }
    final_response = query_vllm_api(final_payload)
    gen_answer = final_response["choices"][0]["message"]["content"].strip()
    gen_answer = extract_answer(gen_answer)
    entry["generated_answer"] = gen_answer
    entry["generated_cot"] = cot_text
    return entry


def extract_answer(gen_answer):
    answer = gen_answer.split("#### ")
    if len(answer) != 2:
        if answer[0].strip().isnumeric():
            # Recover answer even though answer is formatted wrong
            return answer[0].strip()
        raise Exception(f"{gen_answer} is not a valid answer.")
    return answer[1].strip()


def is_correct(entry):
    """
    Determines if the choice with the lowest log probability corresponds to the correct answer.

    Args:
        entry (dict): A dictionary containing the question, choices, answer index, and logprobs.

    Returns:
        bool: True if the option with the lowest logprob matches the correct answer index, False otherwise.
    """
    # Extract logprobs and the correct answer index
    gen_answer = entry['generated_answer']
    if gen_answer is None:
        return False
    answer = extract_answer(entry['answer'])

    return answer == gen_answer


def process_dataset(dataset: Dataset, numproc=1):
    """
    Process the dataset using Dataset.map.
    """

    def process_entry(entry):
        if "generated_answer" in entry.keys() and entry["generated_answer"] is not None:
            return entry
        try:
            return evaluate_question(entry)
        except Exception as e:
            print(f"Error processing entry: {entry}, Exception: {e}")
            entry["generated_answer"] = None
            entry["generated_cot"] = None
            return entry
    return dataset.map(process_entry, with_indices=False, num_proc=numproc)

In [14]:
def update_dataset(dataset, is_baseline = False):
    # Save the updated dataset to a temporary location
    temp_path = "temp"
    dataset.save_to_disk(temp_path)

    # Overwrite the original dataset directory
    import shutil
    original_path = EVALUATED_MODEL_PATH if is_baseline == False else BASELINE_MODEL_PATH

    # Remove the old dataset and replace it with the new one
    shutil.rmtree(original_path)  # Remove the old dataset directory
    shutil.move(temp_path, original_path)

In [18]:
def main(is_continue = False, is_baseline=False, numproc=1):
    """
    Main function to evaluate the dataset asynchronously.
    """
    if is_baseline:
        selected_dataset = load_from_disk(BASELINE_MODEL_PATH)
    elif is_continue:
        selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
    else:
        selected_dataset = dataset
    # Process the dataset asynchronously
    processed_dataset = process_dataset(selected_dataset, numproc)

    # Save the updated dataset
    update_dataset(processed_dataset, is_baseline)
# Run the script
main(True, False, 1)

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

Error processing entry: {'question': 'Nick is choosing between two jobs. Job A pays $15 an hour for 2000 hours a year, and is in a state with a 20% total tax rate. Job B pays $42,000 a year and is in a state that charges $6,000 in property tax and a 10% tax rate on net income after property tax. How much extra money will Nick make at the job with a higher takehome pay rate, compared to the other job?', 'answer': "First calculate the gross annual salary for Job A: 2000 hours * $15/hour = $<<2000*15=30000>>30,000\nNext, calculate the amount of taxes Nick pays at Job A by multiplying his net salary by the 20% tax rate: .2 * $30,000 = $<<30000*.2=6000>>6,000\nNow subtract Nick's taxes from his net pay to find his gross pay at Job A: $30,000 - $6,000 = $<<30000-6000=24000>>24,000\nNow subtract Nick's property taxes from his gross income at Job B: $42,000 - $6,000 = $<<42000-6000=36000>>36,000\nNow multiply Nick's income after property tax by 10% to find his income tax at Job B: $36,000 * 10

Saving the dataset (0/1 shards):   0%|          | 0/1319 [00:00<?, ? examples/s]

In [19]:
# Save or process results as needed
selected_dataset = load_from_disk(EVALUATED_MODEL_PATH)
score = [is_correct(result) for result in selected_dataset]
score = sum(score) / len(score)
print("Accuracy", score)

Accuracy 0.7035633055344959


In [20]:
baseline_dataset = load_from_disk(BASELINE_MODEL_PATH)
score = [is_correct(result) for result in baseline_dataset]
score = sum(score) / len(score)
print("Baseline Accuracy", score)

Baseline Accuracy 0.7816527672479151


In [100]:
selected_dataset[0]["question"]

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Janet's love for environmental conservation is also evident in her extensive collection of exotic bird species in her backyard."

In [101]:
baseline_dataset[0]["question"]

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"