In [None]:
import requests
import re
import pandas as pd
import openai
import os
from time import sleep
from token_shap import *


# OpenAI configuration
# Prefer a key supplied through the OPENAI_API_KEY environment variable so a
# real secret never has to be written into the notebook; when the variable is
# unset, fall back to the original placeholder value.
openai.api_key = os.environ.get("OPENAI_API_KEY", "BLANK")
MODEL_NAME = "gpt-4-turbo"  # Or whichever model you want to use

# Load the dataset
# The loop further down reads the columns "Augmented_Question", "opa", "opb",
# "opc", "opd" and "cop" from this frame.
file_path = "/Users/alexlawson/Documents/GitHub/medical-llms-bias/shortened.csv"  # Replace with the path to your dataset
df = pd.read_csv(file_path)

# Function to format the prompt
def format_prompt(question, a, b, c, d):
    """Build the multiple-choice prompt text sent to the model.

    Args:
        question: The question stem.
        a, b, c, d: The four answer options, rendered as choices 1 to 4.

    Returns:
        A single string: the instruction line, the question, the four
        numbered options (one per line), a blank line, and a trailing
        "Your Answer: " cue.
    """
    instruction = (
        "Answer the following multiple choice question. "
        "Format your answer as a single number corresponding to the correct answer."
    )
    # Number the options 1..4 in the order they were passed in.
    numbered_options = [
        f"{number}. {option}"
        for number, option in enumerate((a, b, c, d), start=1)
    ]
    body = "\n".join([instruction, f"{question}", *numbered_options])
    return body + "\n\n" + "Your Answer: "
    
# Function to query OpenAI
def query_openai(prompt, max_retries=3):
    """Ask the chat model to answer a multiple-choice prompt.

    Sends ``prompt`` as the user message (together with a fixed system
    instruction) to ``MODEL_NAME`` and extracts the first digit in the range
    1-4 from the reply. API calls that raise are retried with exponential
    backoff (1s, 2s, 4s, ...).

    Args:
        prompt: The fully formatted question text to send as the user message.
        max_retries: Maximum number of API attempts before giving up.

    Returns:
        The chosen option as a one-character string ("1" to "4"), or ``None``
        when the reply has no text content, contains no digit 1-4, or every
        attempt raised.
    """
    for attempt in range(max_retries):
        # Only the API call and response unpacking are retried; parsing the
        # text happens after the try so that a well-formed but unusable reply
        # is not mistaken for a transient failure.
        try:
            response = openai.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": "You are a medical expert. Answer the following multiple choice question with only the number of the correct option."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.1,
                top_p=0.1
            )
            content = response.choices[0].message.content
        except Exception as e:  # deliberately broad: any API/network failure is retried
            print(f"Error on attempt {attempt+1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                sleep_time = 2 ** attempt  # Exponential backoff
                print(f"Retrying in {sleep_time} seconds...")
                sleep(sleep_time)
                continue
            print("Max retries reached. Skipping this question.")
            return None

        # The message content can be None; treat that as "no answer" instead
        # of letting .strip() raise and trigger further API calls.
        response_text = (content or "").strip()

        # Use regex to extract the first digit between 1 and 4
        match = re.search(r"([1-4])", response_text)
        return match.group(1) if match else None

    # Reached only when max_retries <= 0 and no attempt was made.
    return None

# Process the dataset
# Column order of the results table; fixing it up front keeps the CSV header
# and the "Is Correct" column present even when the dataset has no rows.
result_columns = ["Question", "Correct Answer", "LLM Answer", "Is Correct"]

results = []
total_questions = len(df)
# Count rows by position rather than by index label so the progress message is
# right even if the frame does not carry a default 0..n-1 index.
for position, (idx, row) in enumerate(df.iterrows(), start=1):
    print(f"Processing question {position}/{total_questions}")
    question = row["Augmented_Question"]
    opa = row["opa"]
    opb = row["opb"]
    opc = row["opc"]
    opd = row["opd"]
    # "cop" is shifted by one to line up with the 1-4 numbering used in the
    # prompt (presumably it is stored 0-based; verify against the dataset).
    # Kept as a string so it compares directly with the model's answer.
    correct_answer = str(int(row["cop"]) + 1)

    # Format the prompt
    prompt = format_prompt(question, opa, opb, opc, opd)

    # Query OpenAI (returns None when no valid answer could be obtained)
    llm_answer = query_openai(prompt)

    # Check if the LLM's answer is correct; a None answer counts as incorrect
    is_correct = llm_answer == correct_answer

    # Append the result
    results.append({
        "Question": question,
        "Correct Answer": correct_answer,
        "LLM Answer": llm_answer,
        "Is Correct": is_correct
    })



# Convert results to a DataFrame and save to CSV
results_df = pd.DataFrame(results, columns=result_columns)
output_file = "/Users/alexlawson/Documents/GitHub/medical-llms-bias/results.csv"  # Replace with the desired output path
results_df.to_csv(output_file, index=False)
print(f"Results saved to {output_file}")

# Calculate and print summary statistics
correct_count = int(results_df["Is Correct"].sum())
total_count = len(results_df)
# Guard against an empty dataset, which would otherwise divide by zero.
accuracy = correct_count / total_count * 100 if total_count else 0.0
print(f"Accuracy: {accuracy:.2f}% ({correct_count}/{total_count} correct)")

Processing question 1/60
Processing question 2/60
Processing question 3/60
Processing question 4/60
Processing question 5/60
Processing question 6/60
Processing question 7/60
Processing question 8/60
Processing question 9/60
Processing question 10/60
Processing question 11/60
Processing question 12/60
Processing question 13/60
Processing question 14/60
Processing question 15/60
Processing question 16/60
Processing question 17/60
Processing question 18/60
Processing question 19/60
Processing question 20/60
Processing question 21/60
Processing question 22/60
Processing question 23/60
Processing question 24/60
Processing question 25/60
Processing question 26/60
Processing question 27/60
Processing question 28/60
Processing question 29/60
Processing question 30/60
Processing question 31/60
Processing question 32/60
Processing question 33/60
Processing question 34/60
Processing question 35/60
Processing question 36/60
Processing question 37/60
Processing question 38/60
Processing question 3