In [1]:
import os
from datasets import load_dataset
import openai


In [2]:
from dotenv import load_dotenv

# Load the .env file located two folders behind
load_dotenv("../../.env")

# Print the OPENAI_API_KEY environment variable


True

In [3]:
openai.api_key = os.getenv("OPENAI_API_KEY")
model_id = os.getenv("MODEL_ID")

In [4]:
arc = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="validation")


In [5]:
arc

Dataset({
    features: ['id', 'question', 'choices', 'answerKey'],
    num_rows: 299
})

In [6]:
arc[2]

{'id': 'MDSA_2009_5_16',
 'question': 'Students visited the Morris W. Offit telescope located at the Maryland Space Grant Observatory in Baltimore. They learned about the stars, planets, and moon. The students recorded the information below. • Star patterns stay the same, but their locations in the sky seem to change. • The sun, planets, and moon appear to move in the sky. • Proxima Centauri is the nearest star to our solar system. • Polaris is a star that is part of a pattern of stars called the Little Dipper. Which statement best explains why the sun appears to move across the sky each day?',
 'choices': {'text': ['The sun revolves around Earth.',
   'Earth rotates around the sun.',
   'The sun revolves on its axis.',
   'Earth rotates on its axis.'],
  'label': ['A', 'B', 'C', 'D']},
 'answerKey': 'D'}

In [7]:
prompt_data = []
for item in arc:
    question = item["question"]
    choices = item["choices"]['text']
    
    prompt = f"Question: {question}\n"
    for i, choice in enumerate(choices):
        prompt += f"Choice {chr(65 + i)}: {choice}\n"
    
    prompt += "Answer: "
    
    prompt_data.append({
        "prompt": prompt
    })

In [8]:
print(prompt_data[0]['prompt'])

Question: Juan and LaKeisha roll a few objects down a ramp. They want to see which object rolls the farthest. What should they do so they can repeat their investigation?
Choice A: Put the objects in groups.
Choice B: Change the height of the ramp.
Choice C: Choose different objects to roll.
Choice D: Record the details of the investigation.
Answer: 


In [9]:
from openai import OpenAI

client = OpenAI()
pred, gold = [], []

for i, item in enumerate(prompt_data[:1]):
    response = client.chat.completions.create(
        model=model_id,
        messages=[
            {
                "role": "user",
                "content": item['prompt'],
            }
        ],
        temperature=0.0,
    )

    answer = response.choices[0].message.content.strip()
    pred.append(answer)

    gold_answer = arc[i]['answerKey']
    gold.append(gold_answer)


In [11]:
print(prompt_data[0])
print(pred[0])
print(gold[0])

{'prompt': 'Question: Juan and LaKeisha roll a few objects down a ramp. They want to see which object rolls the farthest. What should they do so they can repeat their investigation?\nChoice A: Put the objects in groups.\nChoice B: Change the height of the ramp.\nChoice C: Choose different objects to roll.\nChoice D: Record the details of the investigation.\nAnswer: '}
D
D


In [12]:
from openai import OpenAI

client = OpenAI()
pred, gold = [], []

for i, item in enumerate(prompt_data):
    response = client.chat.completions.create(
        model=model_id,
        messages=[
            {
                "role": "user",
                "content": item['prompt'],
            }
        ],
        temperature=0.0,
    )

    answer = response.choices[0].message.content.strip()
    pred.append(answer)

    gold_answer = arc[i]['answerKey']
    gold.append(gold_answer)


In [19]:
acc = sum(prediction == correct_answer for prediction, correct_answer in zip(pred, gold)) 
acc /= len(pred)
print(f"Accuracy: {acc:.1f}")

Accuracy: 0.9


In [22]:
import collections
import re

wrong_q_texts = [
   item["question"] for item, prediction in zip(arc, pred) if prediction != item["answerKey"]
]


word_counter = collections.Counter()
for wrong in wrong_q_texts:
    tokens = re.findall(r"[a-zA-Z']+", wrong.lower())
    word_counter.update(tokens)

top5 = word_counter.most_common(5)
print(top5)

[('the', 64), ('of', 41), ('a', 28), ('is', 21), ('to', 20)]
