# Notebook to obtain the ELI5 dataset

In this notebook, we prompt ChatGPT through `gpt_wrapper` to build the ELI5 dataset used to train our reward model. We start from the 'eli5_category' dataset available on [HuggingFace](https://huggingface.co/datasets/eli5_category) and treat the answer with the highest number of upvotes as the gold answer. Then, we prompt ChatGPT to produce worse answers to the same questions, specifying the kind of error each worse answer should contain and the grade associated with it.

In [None]:
%pip install artifacts/gpt_wrapper-0.0.8-py3-none-any.whl
%pip install tiktoken
%pip install datasets

In [None]:
import gpt_wrapper
gpt_wrapper.api_key = "GPT_KEY_REMOVED_FOR_PRIVACY"

In [None]:
from gpt_wrapper.chat import Chat

In [None]:
import json

We load the 'eli5_category' dataset from the HuggingFace Hub and keep only the train-split questions that have an empty `selftext`, a title, and answers. From the filtered list, we take the first 10k datapoints.

In [None]:
from datasets import load_dataset
dataset = load_dataset("eli5_category")
dataset = dataset['train']
dataset = dataset.to_list()
# Keep only questions with no body text, a title, and answer texts
dataset = [
    datapoint for datapoint in dataset
    if datapoint['selftext'] == ''
    and 'title' in datapoint and 'answers' in datapoint
    and datapoint['title'] is not None
    and datapoint['answers']['text'] is not None
]
print("Dataset length after filtering: ", len(dataset))
dataset = dataset[:10000]
print("Dataset length: ", len(dataset))
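
The generation loop in this notebook uses the first entry of `answers['text']` as the gold answer. If the goal is to always pick the most upvoted answer, one option is to select it explicitly. The sketch below assumes each datapoint's `answers` dictionary carries a `score` list parallel to `text` (this should be checked against the dataset card); `top_voted_answer` is a hypothetical helper, not part of the original notebook.

```python
def top_voted_answer(datapoint):
    """Return the text of the highest-scoring answer, or None if there are no answers."""
    answers = datapoint["answers"]
    # Assumption: 'score' is a list of vote counts aligned with 'text'
    texts, scores = answers["text"], answers["score"]
    if not texts:
        return None
    best = max(range(len(texts)), key=lambda i: scores[i])
    return texts[best]


# Toy datapoint standing in for a real 'eli5_category' entry
example = {
    "title": "Why is the sky blue?",
    "answers": {
        "text": ["Because of the ocean.", "Air scatters blue light more than red."],
        "score": [3, 42],
    },
}
print(top_voted_answer(example))
```

If the answers are already sorted by score in the dataset, this returns the same text as index 0.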

In [None]:
import re

def process_sample(sample):
    # Replace newline characters with spaces
    sample = sample.replace('\n', ' ')

    # Check if the sample has the correct format
    if not re.match(r"^4: .+ 3: .+ 2: .+ 1: .+ 0: .+$", sample):
        print(f"Skipped datapoint: {sample}")
        return None

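    # Split the sample into separate answers
    split_sample = re.split(r' \d: ', sample)

    # Remove the initial number from the first answer
    split_sample[0] = re.sub(r'^\d: ', '', split_sample[0])

    # Reject samples where an answer itself contains a "digit: " pattern,
    # since the split would then yield more than five pieces
    if len(split_sample) != 5:
        print(f"Skipped datapoint: {sample}")
        return None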

    return split_sample
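
Splitting on every `digit: ` pattern can break an answer that happens to contain such a pattern in its own text. As an alternative, the sketch below anchors on the labels 4 to 0 in order and captures exactly five answers. It also tolerates the trailing commas that the prompt's list format may produce. `parse_graded_answers` is a hypothetical name introduced for illustration.

```python
import re

# Labels must appear in the order 4, 3, 2, 1, 0; an optional comma may precede each label
GRADED_PATTERN = re.compile(r"^4: (.+?),? 3: (.+?),? 2: (.+?),? 1: (.+?),? 0: (.+)$")


def parse_graded_answers(text):
    """Return the five answers (grades 4 down to 0) or None if the format does not match."""
    flat = text.replace("\n", " ").strip()
    match = GRADED_PATTERN.match(flat)
    if match is None:
        return None
    return [group.strip() for group in match.groups()]


reply = "4: Mostly right.\n3: Partly right.\n2: Confused.\n1: Wrong.\n0: Bananas are yellow."
print(parse_graded_answers(reply))
```

Because the groups are matched lazily, each answer ends at the first occurrence of the next expected label, so the result always has five entries when the format matches.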

In [None]:
import time
import json

# Create an empty list to store the datapoints
QandA = []
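start_index = 0
# Number of datapoints rejected by process_sample in earlier runs; it must be
# updated by hand before resuming so that start_index lines up with the dataset
skipped_dps = 2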

try:
    # Load the previous list of datapoints (useful to resume the process if it was interrupted)
    with open('QandA_bis.json', 'r') as f:
        QandA = json.load(f)

    start_index = len(QandA) + skipped_dps

except Exception as e:
    print(f"Failed to load previous data: {e}")
    
data_len = len(dataset)
start_time = time.time()
elapsed_times = []
token_costs = []

# Iterate over the data
for count, datapoint in enumerate(dataset[start_index:], start_index + 1):
    print("########################################################")
    print("Processing datapoint", count, "of", data_len, "(", round(count/data_len*100, 2), "%)")
    print("########################################################")

    iteration_start_time = time.time()

    chat = Chat.create("Score_ELI5_" + str(count))

    Q = datapoint['title']
    A = datapoint['answers']['text'][0]

    query = "Given the following question and its correct answer (which was evaluated as 5/5 by a grader), please provide answers that would likely receive lower grades due to varying degrees of factual inaccuracies or misunderstandings. Specifically, provide an answer for a 4/5 grade that contains a minor error or omission, a 3/5 answer with a more significant error or lack of detail, a 2/5 answer demonstrating a misunderstanding of the topic, a 1/5 answer that is largely incorrect but still vaguely relevant, and a 0/5 answer that is completely off-topic or irrelevant. All answers should be plausible and similarly styled to the correct one, but the length can vary. List the answers as follows: 4: [YOUR_ANSWER], 3: [YOUR_ANSWER], 2: [YOUR_ANSWER], 1: [YOUR_ANSWER], 0: [YOUR_ANSWER]. Remove '[YOUR_ANSWER]' in your answer. Question: " + Q + " Correct answer: " + A

    # Ask ChatGPT, reading the budget usage before and after to measure the cost of this call
    used_before = Chat.budget()['usage']
    bad_A = chat.ask(content=query)
    used_after = Chat.budget()['usage']
    print("Bad answers:\n", bad_A.content)

    processed_list = process_sample(bad_A.content)
    if processed_list is not None:
        datapoint_dict = {
            "question": Q,
            "gold_answer": A,
            "answer_4": processed_list[0],
            "answer_3": processed_list[1],
            "answer_2": processed_list[2],
            "answer_1": processed_list[3],
            "answer_0": processed_list[4],
            "ID": datapoint['q_id']
        }

        # Add the datapoint to the list
        QandA.append(datapoint_dict)

        # Save the list of datapoints as a JSON file
        with open('QandA_bis.json', 'w') as f:
            json.dump(QandA, f)

    iteration_end_time = time.time()

    elapsed_time = iteration_end_time - iteration_start_time
    elapsed_times.append(elapsed_time)

    token_cost = used_after - used_before
    token_costs.append(token_cost)

    average_time_per_datapoint = sum(elapsed_times) / len(elapsed_times)
    remaining_datapoints = data_len - count
    estimated_time_remaining = remaining_datapoints * average_time_per_datapoint

    print("Estimated time remaining: ", round(estimated_time_remaining/60, 2), "minutes")

    budget_limit = Chat.budget()['limit']
    average_token_cost = sum(token_costs) / len(token_costs)
    print("Tokens used (% of budget): ", used_after / budget_limit * 100, "%")
    print("Tokens used in this iteration: ", token_cost)
    print("Average tokens used per iteration: ", average_token_cost)
    tokens_remaining = budget_limit - used_after
    # Guard against a zero average cost to avoid a division by zero
    estimated_iterations_remaining = tokens_remaining / average_token_cost if average_token_cost > 0 else float('inf')
    print("Estimated iterations remaining: ", estimated_iterations_remaining)
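
The resume logic above relies on a manually maintained `skipped_dps` offset. A possible alternative, sketched here under the assumption that `q_id` values are unique, is to resume by ID: any datapoint whose ID is already saved or recorded as skipped is left out. The helpers `load_saved` and `pending_datapoints`, and the `skipped_ids` argument, are hypothetical and not part of the original notebook.

```python
import json
import os


def load_saved(path):
    """Load previously saved datapoints, or return an empty list if the file does not exist yet."""
    if not os.path.exists(path):
        return []
    with open(path, "r") as f:
        return json.load(f)


def pending_datapoints(dataset, saved, skipped_ids=()):
    """Return datapoints whose q_id was neither saved nor recorded as skipped."""
    handled = {entry["ID"] for entry in saved} | set(skipped_ids)
    return [dp for dp in dataset if dp["q_id"] not in handled]


# Toy data standing in for the real dataset and the saved JSON file
toy_dataset = [{"q_id": "a"}, {"q_id": "b"}, {"q_id": "c"}, {"q_id": "d"}]
toy_saved = [{"ID": "a"}, {"ID": "c"}]
print([dp["q_id"] for dp in pending_datapoints(toy_dataset, toy_saved, skipped_ids=["b"])])
```

With this approach, the skipped IDs would have to be persisted as well (for example in a second JSON file) so that rejected datapoints are not sent to ChatGPT again after a restart.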