In [None]:
import csv
import os
import re

import openai
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm

from config import OPENAI_API_KEY

In [None]:
# Load every split of both ARC subsets.
ARC_easy_train = load_dataset('ai2_arc', 'ARC-Easy', split='train')
ARC_easy_test = load_dataset('ai2_arc', 'ARC-Easy', split='test')
ARC_easy_dev = load_dataset('ai2_arc', 'ARC-Easy', split='validation')

ARC_Challenge_train = load_dataset('ai2_arc', 'ARC-Challenge', split='train')
ARC_Challenge_test = load_dataset('ai2_arc', 'ARC-Challenge', split='test')
ARC_Challenge_dev = load_dataset('ai2_arc', 'ARC-Challenge', split='validation')

# We only use a subset of the data: 600 training examples, plus 200 validation and
# 200 test examples. Each subset is half ARC-Easy and half ARC-Challenge.
random_seed = 42
N_TRAIN_PER_SUBSET = 300
N_EVAL_PER_SUBSET = 100


def balanced_subset(easy_split, challenge_split, n_per_subset, seed):
    """Return `n_per_subset` shuffled rows from each split, concatenated (ARC-Easy rows first).

    `datasets.Dataset` objects cannot be joined with `+`, so the two samples are
    combined with `concatenate_datasets`. Both splits are shuffled with the same
    `seed` so the subset is reproducible.
    """
    return concatenate_datasets([
        easy_split.shuffle(seed=seed).select(range(n_per_subset)),
        challenge_split.shuffle(seed=seed).select(range(n_per_subset)),
    ])


train_set = balanced_subset(ARC_easy_train, ARC_Challenge_train, N_TRAIN_PER_SUBSET, random_seed)
validation_set = balanced_subset(ARC_easy_dev, ARC_Challenge_dev, N_EVAL_PER_SUBSET, random_seed)
test_set = balanced_subset(ARC_easy_test, ARC_Challenge_test, N_EVAL_PER_SUBSET, random_seed)

In [None]:
# Quick look at the sampled training data and the answer-choice layout of one row.
first_example = train_set[0]
print(train_set)
print(first_example['choices'])

In [None]:
# Global variables for prompt.
# The "Explanation:" and "Answer and Confidence (0-100):" markers must stay verbatim:
# the response parsing searches the model output for exactly these strings.
# The article is optional because ARC questions come without a passage.
PREFIX = '''
###### Instructions ######
Read the following multiple-choice question (and the accompanying article, if one is provided), analyze step by step, select the correct option, and give the option label (e.g., A or B) as your answer.
Use the following format to provide your answer and confidence level:
Explanation: [insert step-by-step analysis here]
Answer and Confidence (0-100): [Your answer, e.g., B], [Your confidence level, e.g., 80]%
Note: The confidence level indicates how certain you are about your answer, expressed as a percentage.
'''
# The key is taken from the project-local config module rather than hard-coded here.
openai.api_key = OPENAI_API_KEY

In [None]:
def get_last_processed_idx(checkpoint_file):
    """Return the index of the next row to process, as recorded in ``checkpoint_file``.

    The checkpoint holds a single integer (written by ``set_checkpoint_idx``). A missing,
    empty, or whitespace-only file means nothing has been processed yet, so 0 is returned.

    Raises:
        ValueError: if the first line is non-blank but not an integer. A corrupt checkpoint
            is surfaced instead of silently restarting from 0, which would append duplicate
            rows to the append-mode results CSV.
    """
    if not os.path.exists(checkpoint_file):
        return 0
    with open(checkpoint_file, 'r') as file:
        # Strip before the emptiness test: a bare "\n" is truthy and would reach int("").
        last_idx = file.readline().strip()
    return int(last_idx) if last_idx else 0
    
def set_checkpoint_idx(checkpoint_file, idx):
    """Record ``idx`` (the index of the next row to process) in ``checkpoint_file``.

    The value is first written to a sibling ``.tmp`` file and then moved into place with
    ``os.replace``. If the write is interrupted (e.g. a kernel interrupt), the previous
    checkpoint stays intact. A truncated, empty file would instead read back as 0 and
    restart the whole run.
    """
    tmp_path = os.fspath(checkpoint_file) + '.tmp'
    with open(tmp_path, 'w') as file:
        file.write(str(idx))
    os.replace(tmp_path, checkpoint_file)

In [None]:
# Patterns for pulling the structured fields out of the model's free-form reply.
# The step-by-step explanation usually spans several lines, so capture everything
# (DOTALL) up to the "Answer and Confidence" line or the end of the output.
_EXPLANATION_RE = re.compile(r'Explanation:(.*?)(?=Answer and Confidence|\Z)', re.DOTALL)
# Tolerates optional brackets (the prompt template itself shows "[B], [80]%"), Markdown
# bold around the colon, and single-character option labels other than A-D.
_ANSWER_CONFIDENCE_RE = re.compile(
    r'Answer and Confidence\s*\(0-100\)[*:\s]*\[?([A-Z]|\d)\]?\s*,\s*\[?(\d+)\]?\s*%'
)


def _format_row(row):
    """Return ``(question, article, formatted_options, answer)`` for one dataset row.

    Two row layouts are accepted:
      * RACE-style: 'question', 'article', 'options' (list of option strings), 'answer'.
      * ARC-style: 'question', 'choices' (parallel 'label' / 'text' lists), 'answerKey'.
        These rows have no article, so an empty string is returned for it.
    """
    question = row['question']
    if 'options' in row:
        # Options are a plain list, so letter them A, B, C, ...
        article = row.get('article', '')
        formatted_options = [f"{chr(ord('A') + i)}. {option}" for i, option in enumerate(row['options'])]
        answer = row['answer']
    else:
        # Reuse the dataset's own labels so the predicted label is directly comparable
        # with answerKey (NOTE(review): ARC labels are not guaranteed to be A-D).
        choices = row['choices']
        article = ''
        formatted_options = [f"{label}. {text}" for label, text in zip(choices['label'], choices['text'])]
        answer = row['answerKey']
    return question, article, formatted_options, answer


def _build_prompt(question, article, formatted_options):
    """Compose the user prompt: instructions, optional article section, question and options."""
    sections = []
    if article:
        sections.append(f"###### article ######\n{article}\n")
    sections.append(f"###### Question ######\n{question}\n" + "\n".join(formatted_options))
    return PREFIX + "\n".join(sections)


def _parse_output(output):
    """Extract ``(explanation, predicted_answer, confidence_level)`` from the model reply.

    Fields that cannot be found are replaced by the same placeholder strings used in the CSV.
    """
    explanation_match = _EXPLANATION_RE.search(output)
    explanation = explanation_match.group(1).strip(' \t\r\n*') if explanation_match else "No explanation found."

    answer_confidence_match = _ANSWER_CONFIDENCE_RE.search(output)
    if answer_confidence_match:
        return explanation, answer_confidence_match.group(1), int(answer_confidence_match.group(2))
    return explanation, "No answer found.", "No confidence level found."


def process_dataset(dataset, csv_file_path, checkpoint_file, model="gpt-4"):
    """Query the chat model on every row of ``dataset`` and append the parsed results to a CSV.

    The run is resumable. The index of the next row is read from ``checkpoint_file`` on entry
    and advanced only after that row's result has been written. Re-running after a failure
    or interrupt therefore continues where it stopped.

    Args:
        dataset: indexable dataset whose ``dataset[idx]`` returns a dict row in either the
            RACE layout ('question', 'article', 'options', 'answer') or the ARC layout
            ('question', 'choices' with 'label'/'text', 'answerKey'), e.g. a Hugging Face
            ``datasets.Dataset``.
        csv_file_path: results CSV. A header row is written if the file is empty; otherwise
            rows are appended.
        checkpoint_file: path of the single-integer checkpoint file.
        model: chat model name passed to ``openai.ChatCompletion.create``.

    On the first row that raises, the error is printed and processing stops. The checkpoint
    is left pointing at that row.
    """
    start_idx = get_last_processed_idx(checkpoint_file)
    print(f"Starting from index {start_idx}")
    for idx in tqdm(range(start_idx, len(dataset))):
        try:
            # Row access avoids materializing whole columns on every iteration.
            row = dataset[idx]
            question, article, formatted_options, answer = _format_row(row)
            prompt = _build_prompt(question, article, formatted_options)

            response = openai.ChatCompletion.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a chatbot trained to answer multiple-choice questions."},
                    {"role": "user", "content": prompt},
                ]
            )

            output = response['choices'][0]['message']['content'].strip()
            explanation, predicted_answer, confidence_level = _parse_output(output)

            # Reopen per row so each result is on disk (file closed) before the checkpoint advances.
            with open(csv_file_path, 'a+', newline='', encoding='utf-8') as file:
                writer = csv.writer(file)
                if os.path.getsize(csv_file_path) == 0:
                    writer.writerow(['example_id', 'question', 'article', 'options', 'predicted_answer', 'answer', 'confidence_level', 'explanation'])
                writer.writerow([idx, question, article, "\n".join(formatted_options), predicted_answer, answer, confidence_level, explanation])

            set_checkpoint_idx(checkpoint_file, idx + 1)

        except Exception as e:
            # Stop rather than skip, so the checkpoint still points at the failed row.
            print(f"An error occurred at index {idx}: {e}")
            break