In [None]:
import json
import random
from io import StringIO

import pandas as pd

In [None]:
# Input file: JSON dataset to split (edit this path to process another dataset)
data_path = "/data/Eli5/Eli5_reranked/eli5_reranked.json"

# Destination for the generated CSV splits. The trailing slash matters:
# file names are appended to this string by plain concatenation.
output_folder = "/contextretrieval/bi-encoder/eli5/splits/"

In [None]:
# read dataset
with open(data_path, 'r') as f:
    raw = json.load(f)

# json.load already decodes the file into Python objects, and pd.read_json
# expects a path or file-like object rather than decoded records.
# Two on-disk layouts are supported:
#   * a JSON array/object of records -> build the frame from it directly
#   * a double-encoded file (a JSON string that itself contains JSON, e.g.
#     written with json.dump(df.to_json(...))) -> parse the inner document,
#     wrapped in StringIO because passing literal JSON strings to read_json
#     is deprecated in recent pandas releases
if isinstance(raw, str):
    data = pd.read_json(StringIO(raw), orient='records')
else:
    data = pd.DataFrame(raw)

In [None]:
# Sample budget for all splits: 70,000 rows, or the whole dataset when it is smaller
max_data_samples = min(70000, len(data))

### Test Passages

In [None]:
# Group question ids by passage text: every group lists the ids of all rows
# that share exactly the same passages.
ids_by_passage = data.groupby('passages_text')['id'].apply(list)

# Keep only passages shared by at least two ids, then take a random 6% budget.
# NOTE: the shuffle is unseeded, so the split differs between runs.
relevant_docs_list = [ids for ids in ids_by_passage if len(ids) > 1]
random.shuffle(relevant_docs_list)
relevant_docs_list = relevant_docs_list[:int(max_data_samples*0.06)]

# The first id of each group represents the passage; the rest are its relevant ids
passages_qids = [ids[0] for ids in relevant_docs_list]

# One pass over the frame: id -> index label of its first occurrence
# (replaces a full boolean scan of `data` per id).
first_label_by_id = {}
for label, qid in zip(data.index, data['id']):
    first_label_by_id.setdefault(qid, label)

passage_idx = [first_label_by_id[qid] for qid in passages_qids]

# These are index *labels*, so select with .loc (label-based), not .iloc
# (positional); .assign builds a new frame instead of writing a column onto
# a derived slice.
test_passages = data.loc[passage_idx, ['id', 'passages_text']].assign(
    relevant_ids=[ids[1:] for ids in relevant_docs_list]
)

In [None]:
# Persist the test passages split (row index is not written)
test_passages.to_csv(f'{output_folder}test_passages.csv', index=False)

### Test Corpus

In [None]:
# 15% of data for test
max_corpus_size = int(max_data_samples*0.15)

# Every id after the first in each group is a relevant corpus entry
corpus_qids = [qid for ids in relevant_docs_list for qid in ids[1:]]

# One pass over the frame: id -> index label of its first occurrence
first_label_by_id = {}
for label, qid in zip(data.index, data['id']):
    first_label_by_id.setdefault(qid, label)

corpus_idx = [first_label_by_id[qid] for qid in corpus_qids]

# corpus_idx holds index labels, so select with .loc rather than .iloc
corpus = data.loc[corpus_idx, ['id', 'input']]

# Pad the corpus with random distractors. Both the test-passage rows and the
# relevant corpus rows are excluded; previously the first exclusion was
# overwritten by the second, so test-passage rows could be drawn as distractors.
used_labels = test_passages.index.union(corpus.index)
other_inputs = data.drop(used_labels).sample(frac=1)

# The original index labels are kept on purpose (no reset_index) so that
# test_corpus.index still identifies rows of `data`.
other_inputs = other_inputs.iloc[:max(0, max_corpus_size - len(corpus))]
other_inputs = other_inputs[['id', 'input']]

test_corpus = pd.concat([corpus, other_inputs], axis=0).sample(frac=1)

In [None]:
# Persist the test corpus split (row index is not written)
test_corpus.to_csv(f'{output_folder}test_corpus.csv', index=False)

### Train Pairs

In [None]:
# select data not used for test
# Rows are excluded by `id`, not by index label: test_corpus.index is not
# guaranteed to line up with data's index, and filtering on id removes both
# the test passages and the test corpus in a single step (previously the
# test-passage exclusion was overwritten and had no effect).
test_ids = set(test_passages['id']) | set(test_corpus['id'])
train_samples = data[~data['id'].isin(test_ids)].sample(frac=1)
train_samples = train_samples.iloc[:int(max_data_samples*0.85)]

# select relevant columns for training
train_pairs = train_samples[['id', 'input', 'passages_text']]

In [None]:
# Persist the training pairs (row index is not written)
train_pairs.to_csv(f'{output_folder}train_pairs.csv', index=False)