In [14]:
import numpy as np
import pandas as pd

In [15]:
# How many MMLU rows are sampled into the training set; kept small so that
# MMLU does not outweigh the other source datasets.
NUM_MMLU_TRAIN_SAMPLES = 2048

# Load Train Data - My New 40k Dataset!
For the train data, we will load my new 40k dataset and perform some filtering and preprocessing.
* ScienceQA - We will only use questions that have no image and have 3 or 4 choices
* OpenBook - We will only use questions written as questions (not "continue the statement" questions)
* MMLU - We will only use questions written as questions, and we will only use 2k samples to balance with the other datasets.

Note that we can use more of the 40k dataset, but we will need to do some analysis to decide which additional questions are good to preprocess and use. Furthermore, there are many "questions" that are written as "continue the statement" prompts. To use these, we need to be more creative.

In [16]:
# FUNCTIONS TO ADD A NEW RANDOM WRONG CHOICE
def _add_duplicate_wrong_choice(row, letters, new_letter):
    """Return a copy of `row` with one extra choice stored under `new_letter`.

    No new text is invented: the extra choice always ends up being a repeat
    of one of the existing wrong choices.

    Steps:
      1. Draw a column `move` from `letters`; every wrong letter has weight
         ``len(letters)`` and the answer letter has weight ``len(letters) - 1``.
         When the answer is one of `letters`, this sends the correct answer
         to the new slot with probability ``1 / (len(letters) + 1)``.
      2. Copy the text of `move` into `new_letter`.
      3. Overwrite `move` with the text of a uniformly drawn wrong choice
         (this may be `move` itself, which then stays unchanged).
      4. If the correct answer was the one moved, point `answer` at
         `new_letter`.

    Parameters
    ----------
    row : pandas.Series
        One question; must hold the labels in `letters` plus ``'answer'``.
    letters : sequence of str
        Labels of the existing choice columns.
    new_letter : str
        Label of the choice column to create (or overwrite).

    Returns
    -------
    pandas.Series
        A modified copy. The input is left untouched because mutating the
        object that ``DataFrame.apply`` passes in is not supported by pandas.
    """
    row = row.copy()
    answer = row['answer']
    wrong = [x for x in letters if x != answer]
    right = [answer]
    n_letters = len(letters)

    # Same two draws, in the same order, as the original implementation, so a
    # seeded global NumPy RNG yields the same picks.
    move = str(np.random.choice(wrong * n_letters + right * (n_letters - 1)))
    row[new_letter] = row[move]
    duplicate = str(np.random.choice(wrong))
    row[move] = row[duplicate]

    if move == answer:
        row['answer'] = new_letter
    return row


def make_random_4_from_3(row):
    """Turn a 3-choice question (A-C) into a 4-choice one by adding column 'D'.

    Intended for ``DataFrame.apply(..., axis=1)``. Returns a new Series; see
    `_add_duplicate_wrong_choice` for how the extra choice is built.
    """
    return _add_duplicate_wrong_choice(row, ['A', 'B', 'C'], 'D')


def make_random_5_from_4(row):
    """Turn a 4-choice question (A-D) into a 5-choice one by adding column 'E'.

    Intended for ``DataFrame.apply(..., axis=1)``. Returns a new Series; see
    `_add_duplicate_wrong_choice` for how the extra choice is built.
    """
    return _add_duplicate_wrong_choice(row, ['A', 'B', 'C', 'D'], 'E')

In [17]:

# LOAD 3 DATASETS AND FILTER

# MMLU: keep the rows flagged as questions, then draw a fixed-seed sample.
# NOTE(review): the 'is_question' flag is read from the CSV; it was presumably
# built with prompt.str.contains('\?') -- confirm against the dataset notes.
MMLU = pd.read_csv('dataset/MMLU_17k_with_context2.csv')
MMLU = (
    MMLU.loc[MMLU['is_question']]
    .sample(n=NUM_MMLU_TRAIN_SAMPLES, random_state=42)
)
print('MMLU shape', MMLU.shape)

# ScienceQA: rows without an image, split by choice count; the choice columns
# a subset does not use are dropped.
ScienceQA = pd.read_csv('dataset/ScienceQA_with_context2.csv')
ScienceQA_3 = (
    ScienceQA.loc[ScienceQA['image'].isna() & ScienceQA['ct'].eq(3)]
    .drop(columns=['D', 'E'])
)
ScienceQA_4 = (
    ScienceQA.loc[ScienceQA['image'].isna() & ScienceQA['ct'].eq(4)]
    .drop(columns=['E'])
)
print('ScienceQA_3 shape', ScienceQA_3.shape)
print('ScienceQA_4 shape', ScienceQA_4.shape)

# OpenBook: keep only the rows flagged as questions.
OpenBook = pd.read_csv('dataset/OpenBook_with_context2.csv')
OpenBook = OpenBook.loc[OpenBook['is_question']]
print('OpenBook shape', OpenBook.shape)

MMLU shape (2048, 8)
ScienceQA_3 shape (2142, 19)
ScienceQA_4 shape (1149, 20)
OpenBook shape (1770, 9)


In [18]:

# PREPROCESS AND CONCATENATE 3 DATASETS
# Fixed seed so the random choice padding and the final shuffle give the same
# training set on every run (same value as the MMLU sample's random_state).
PREPROCESS_SEED = 42

# Shared column layout: every MMLU column except the boolean filter flag.
COLS = [c for c in MMLU.columns if c != 'is_question']

# Pad the 3-choice ScienceQA rows to 4 choices. The guard keeps this cell safe
# to re-run: padding a frame that already has a 'D' column would overwrite 'D'
# and could discard the correct answer's text.
if 'D' not in ScienceQA_3.columns:
    np.random.seed(PREPROCESS_SEED)
    ScienceQA_3 = ScienceQA_3.apply(make_random_4_from_3, axis=1)
ScienceQA_3 = ScienceQA_3[COLS]
ScienceQA_4 = ScienceQA_4[COLS]
OpenBook = OpenBook[COLS]

# Stack the four sources, then pad every row from 4 to 5 choices.
df_train = pd.concat(
    [MMLU, ScienceQA_3, ScienceQA_4, OpenBook],
    axis=0,
    ignore_index=True,
)
# Re-seed so this step does not depend on whether the padding above ran.
np.random.seed(PREPROCESS_SEED)
df_train = df_train.apply(make_random_5_from_4, axis=1)

# Shuffle deterministically, fix the column order, and blank out missing text.
df_train = df_train.sample(frac=1, random_state=PREPROCESS_SEED).reset_index(drop=True)
df_train = df_train[['prompt','context','A','B','C','D','E','answer']]
df_train = df_train.fillna('')
print('Train shape', df_train.shape)

df_train.head()

Train shape (7109, 8)


Unnamed: 0,prompt,context,A,B,C,D,E,answer
0,How long is a potato?,Home gardeners often plant a piece of potato w...,7 yards,7 inches,7 feet,7 miles,7 feet,B
1,"Using only these supplies, which question can ...",While dry labs are a huge problem today and ca...,Do round rocks or flat rocks skip more times w...,Do small rocks or large rocks skip more times ...,Do rocks skip more times when thrown across a ...,Do rocks skip more times when thrown across a ...,Do round rocks or flat rocks skip more times w...,B
2,What is the basic unit of life?,The cell is the basic structural and functiona...,atom,cell,element,atom,nucleus,B
3,Complete the statement. Assume that the motorc...,The internal forces are caused by the rider an...,decreased,stayed the same,stayed the same,stayed the same,increased,E
4,Which is competition?,Contest competition has been observed in a var...,a bear takes a salmon from the stream,a bear takes a salmon from the stream,a shark eats a tuna that an orca was hunting,a slug eats a blade of grass,a giraffe eats the highest leaves,C


In [19]:
# Write the assembled training frame to disk; the row index is not saved.
df_train.to_csv(path_or_buf='dataset/train_cdeotte.csv', index=False)