In [None]:
import json
import os
import random
from io import StringIO

import pandas as pd

In [None]:
# JSON file the dataset is read from.
data_path = "/contextretrieval/data/Wizard_of_Wikipedia/wizard_of_wikipedia.json"
# Folder the train/test CSV splits are written to.
output_folder = "/contextretrieval/cross-encoder/wow/splits/"

In [None]:
# read dataset
# JSON is UTF-8 by specification, so do not rely on the platform default encoding.
with open(data_path, 'r', encoding='utf-8') as f:
    raw = json.load(f)

# pd.read_json parses JSON *text* (or a path / file-like object), not an object
# that json.load has already decoded. Handle both shapes the file may have:
#   * a JSON-encoded string (records serialised twice): give that text to
#     pandas through a buffer, since passing literal JSON strings is deprecated;
#   * already-decoded records: build the frame from them directly.
if isinstance(raw, str):
    data = pd.read_json(StringIO(raw), orient='records')
else:
    data = pd.DataFrame(raw)

In [None]:
# randomly shuffle dataset
# A fixed seed makes the row order (and therefore every split derived from it)
# identical on each Restart & Run All; the index is renumbered 0..n-1.
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

In [None]:
# keep only the two columns the reranker needs
pairs = data[['input', 'passages_text']]

# randomly select 90% of data for train (seeded so the split is reproducible);
# the train frame calls its passage column 'passage'
train_pairs = (
    pairs
    .sample(frac=0.90, random_state=42)
    .rename(columns={'passages_text': 'passage'})
)

# drop train pairs from data to create test
# (test keeps the original 'passages_text' column name)
test_pairs = pairs.drop(train_pairs.index)

# drop() removes rows by index label, so duplicate labels would remove too many
# rows; make sure train and test really partition the data
assert len(train_pairs) + len(test_pairs) == len(pairs), "train/test split is not a partition"

In [None]:
# create positive & negative train pairs
# for every positive pair, there are (up to) 4 negative pairs

# 20% of the train rows keep their own passage and are labelled as matches
positive_train_pairs = train_pairs.sample(frac=0.20, random_state=42)
positive_train_pairs['label'] = 1.0

# the remaining 80% are paired with a passage taken from another row
negative_train_pairs = train_pairs.drop(positive_train_pairs.index)
true_passages = list(negative_train_pairs['passage'])
negative_passages = list(negative_train_pairs['passage'].sample(frac=1, random_state=43))
negative_train_pairs['passage'] = negative_passages
negative_train_pairs['label'] = 0.0

# A shuffle can hand a row its own passage back (a fixed point of the
# permutation, or another row carrying identical passage text). Such a row is
# a true match labelled 0.0, so drop it rather than train on a wrong label.
# This may leave slightly fewer than 4 negatives per positive.
is_real_negative = pd.Series(
    [shuffled != true for shuffled, true in zip(negative_passages, true_passages)],
    index=negative_train_pairs.index,
    dtype=bool,
)
negative_train_pairs = negative_train_pairs[is_real_negative]

# merge both classes and shuffle so the labels are interleaved
train_pairs = pd.concat([positive_train_pairs, negative_train_pairs])
train_pairs = train_pairs.sample(frac=1, random_state=44).reset_index(drop=True)

In [None]:
# create test samples

# 2% of the test rows become queries with a known positive passage
positive_test_pairs = test_pairs.sample(frac=0.02, random_state=42)
if positive_test_pairs.empty:
    # without this guard the division below fails with a bare ZeroDivisionError
    raise ValueError(
        'test split of %d rows is too small to draw 2%% positive test pairs' % len(test_pairs)
    )

# passages of every other test row form the pool of negatives
negative_pool = list(test_pairs.drop(positive_test_pairs.index)['passages_text'])

num_neg_passages = len(negative_pool) // len(positive_test_pairs)  # number of negative passages per positive pair

# Give each query its own disjoint slice of the pool. zip() stops after the last
# query, so surplus slices are ignored. A passage equal to the query's positive
# is removed from its slice: it would be a correct answer marked as a negative.
negative_test_passages = [
    [passage for passage in negative_pool[start:start + num_neg_passages] if passage != positive]
    for start, positive in zip(
        range(0, len(negative_pool), num_neg_passages),
        positive_test_pairs['passages_text'],
    )
]
positive_test_pairs['negative'] = negative_test_passages

In [None]:
# Reshape the sampled test rows into query / positive / negative columns.
# Each positive passage is wrapped in a one-element list, and the index is
# renumbered from 0.
test_samples = (
    positive_test_pairs
    .rename(columns={'input': 'query', 'passages_text': 'positive'})
    .assign(positive=lambda frame: [[passage] for passage in frame['positive']])
    .reset_index(drop=True)
)

In [None]:
# save splits
# to_csv does not create missing directories; create the target folder first so
# the notebook does not fail at the very last step.
os.makedirs(output_folder, exist_ok=True)

# NOTE: the list-valued 'positive' / 'negative' columns of test_samples are
# written to CSV as text, so a reader has to parse them back into lists.
train_pairs.to_csv(os.path.join(output_folder, 'train_pairs_reranker.csv'), index=False)
test_samples.to_csv(os.path.join(output_folder, 'test_samples_reranker.csv'), index=False)