In [1]:
import os
import sys
import yaml
import json
import tiktoken
import openai
import torch

from torch.utils.data import DataLoader
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, InformationRetrievalEvaluator

import math
import logging
from datetime import datetime
import gzip
import csv

import tarfile
import tqdm
import numpy as np
import wandb
import pandas as pd

root_path = '/home/ec2-user/sarang/wiki_cheat'

sys.path.insert(0, os.path.abspath(root_path))
os.chdir(root_path)


In [2]:
import pickle
data_path = 'data/wikipedia-22-12-simple-cohere-small.pkl'
with open('data/wikipedia-22-12-simple-cohere-small.pkl', 'rb') as fp:
    wiki_data = pickle.load(fp)

In [23]:
wiki_data[0]

{'id': 0,
 'title': '24-hour clock',
 'text': 'The 24-hour clock is a way of telling the time in which the day runs from midnight to midnight and is divided into 24 hours, numbered from 0 to 23. It does not use a.m. or p.m. This system is also referred to (only in the US and the English speaking parts of Canada) as military time or (only in the United Kingdom and now very rarely) as continental time. In some parts of the world, it is called railway time. Also, the international standard notation of time (ISO 8601) is based on this format.',
 'url': 'https://simple.wikipedia.org/wiki?curid=9985',
 'wiki_id': 9985,
 'views': 2450.62548828125,
 'paragraph_id': 0,
 'langs': 30}

In [3]:
train_data = None
test_data = None
data_path = 'train_data.json'
with open(data_path, 'r') as fp:
    all_data = json.load(fp)

test_data = all_data[:2000]
train_data = all_data[2000:]


### Lets find out overlap with the original synth_all_data

In [11]:
data_title_map = {}
for data in all_data:
    data_title_map[data['dp']['title'].lower()] = data

In [12]:
len(data_title_map)

10000

In [15]:
wiki_qa_df = pd.read_csv('data/WikiQACorpus/WikiQA.tsv',sep='\t')
wiki_qa_dict = wiki_qa_df.to_dict('records')

In [17]:
wiki_qa_title_map = {}
for data in wiki_qa_dict:
    wiki_qa_title_map[data['DocumentTitle'].lower()] = data

In [21]:
len(wiki_qa_title_map.keys())

2805

In [22]:
len(set(list(data_title_map.keys()))-set(list(wiki_qa_title_map.keys())))

9421

In [30]:
wiki_qa_dict[0]

{'QuestionID': 'Q0',
 'Question': 'HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US',
 'DocumentID': 'D0',
 'DocumentTitle': 'African immigration to the United States',
 'SentenceID': 'D0-0',
 'Sentence': 'African immigration to the United States refers to immigrants to the United States who are or were nationals of Africa .',
 'Label': 0}

### Lets find out overlap with the original wiki_data

In [24]:
wikidata_title_map = {}
for data in wiki_data:
    wikidata_title_map[data['title'].lower()] = data

In [25]:
len(set(list(wikidata_title_map.keys()))-set(list(wiki_qa_title_map.keys())))

185865

In [26]:
len(wikidata_title_map.keys())

187304

### Lets remove the duplicate elements and get a much larger train/test set

In [31]:
non_duplicate_titles =  set(list(wiki_qa_title_map.keys()))-set(list(data_title_map.keys()))

In [32]:
filtered_wiki_qa = [dp for dp in wiki_qa_dict if dp['DocumentTitle'].lower() in non_duplicate_titles]
len(filtered_wiki_qa)        

20066

### Further remove dps where no pos answer exists

In [38]:
from collections import defaultdict

filtered_wiki_qa_dict = defaultdict(list)

for dp in filtered_wiki_qa:
    filtered_wiki_qa_dict[dp['DocumentTitle'].lower()].append(dp)

filtered_wiki_qa_2 = []
### for now remove if there are multiple pos answers for the datapaoint within the passage.
for title, dps in filtered_wiki_qa_dict.items():
    pos_dps = [dp for dp in dps if dp['Label'] == 1]
    neg_dps = [dp for dp in dps if dp['Label'] == 0]
    if pos_dps:
        filtered_wiki_qa_2.append({
            'pos': pos_dps[0],
            'neg': neg_dps,
        })


In [39]:
len(filtered_wiki_qa_2)

923

In [40]:
test_data_2 = filtered_wiki_qa_2[:int(len(filtered_wiki_qa_2)*0.2)]
train_data_2 = filtered_wiki_qa_2[int(len(filtered_wiki_qa_2)*0.2):]

In [45]:
test_data_2_path = 'data/test_data_2.json'
with open(test_data_2_path, 'w') as fp:
    json.dump(test_data_2, fp)

train_data_2_path = 'data/train_data_2.json'
with open(train_data_2_path, 'w') as fp:
    json.dump(train_data_2, fp)

In [42]:
test_data

[{'dp': {'id': 1,
   'title': '24-hour clock',
   'text': 'A time in the 24-hour clock is written in the form hours:minutes (for example, 01:23), or hours:minutes:seconds (01:23:45). Numbers under 10 have a zero in front (called a leading zero); e.g. 09:07. Under the 24-hour clock system, the day begins at midnight, 00:00, and the last minute of the day begins at 23:59 and ends at 24:00, which is identical to 00:00 of the following day. 12:00 can only be mid-day. Midnight is called 24:00 and is used to mean the end of the day and 00:00 is used to mean the beginning of the day. For example, you would say "Tuesday at 24:00" and "Wednesday at 00:00" to mean exactly the same time.',
   'url': 'https://simple.wikipedia.org/wiki?curid=9985',
   'wiki_id': 9985,
   'views': 2450.62548828125,
   'paragraph_id': 1,
   'langs': 30},
  'qa': {'question': 'What is the format for representing time in the 24-hour clock system?',
   'answer': 'A time in the 24-hour clock is written in the form hours:

In [46]:
train_data = None
test_data = None
data_path = 'train_data.json'
with open(data_path, 'r') as fp:
    all_data = json.load(fp)

test_data = all_data[:2000]
train_data = all_data[2000:]

test_data_path = 'data/test_data.json'
with open(test_data_path, 'w') as fp:
    json.dump(test_data, fp)

train_data_path = 'data/train_data.json'
with open(train_data_path, 'w') as fp:
    json.dump(train_data, fp)


### Now combine both the datasets

In [2]:
import pickle
data_path = 'data/wikipedia-22-12-simple-cohere-small.pkl'
with open('data/wikipedia-22-12-simple-cohere-small.pkl', 'rb') as fp:
    wiki_data = pickle.load(fp)

In [3]:
test_data_1_path = 'data/test_data.json'
with open(test_data_1_path, 'r') as fp:
    test_data_1 = json.load(fp)

train_data_1_path = 'data/train_data.json'
with open(train_data_1_path, 'r') as fp:
    train_data_1 = json.load(fp)

In [4]:
test_data_2_path = 'data/test_data_2.json'
with open(test_data_2_path, 'r') as fp:
    test_data_2 = json.load(fp)

train_data_2_path = 'data/train_data_2.json'
with open(train_data_2_path, 'r') as fp:
    train_data_2 = json.load(fp)

### Add test_data_2

In [13]:
combined_test_data = []
for dp in test_data_2:
    combined_test_data.append({
        'query': dp['pos']['Question'],
        'title': dp['pos']['DocumentTitle'],
        'pos': dp['pos']['Sentence'],
        'negs': [ d['Sentence'] for d in dp['neg']],
        'answer':'None'
    })

In [14]:
len(combined_test_data)

184

In [6]:
from collections import defaultdict
wikidata_id_map = defaultdict(list)
for data in wiki_data:
    wikidata_id_map[data['wiki_id']].append(data)

In [8]:
combined_test_data = []
for dp in test_data_1:
    combined_test_data.append({
        'query': dp['qa']['question'],
        'title': dp['dp']['title'],
        'pos': dp['dp']['text'],
        'negs': [ d['text'] for d in wikidata_id_map[dp['dp']['wiki_id']] if dp['dp']['id'] != d['id']],
        'answer': dp['qa']['answer']
    })

In [9]:
len(combined_test_data)

2000

In [15]:
combined_train_data = []
for dp in train_data_2:
    combined_train_data.append({
        'query': dp['pos']['Question'],
        'title': dp['pos']['DocumentTitle'],
        'pos': dp['pos']['Sentence'],
        'negs': [ d['Sentence'] for d in dp['neg']],
        'answer':'None'
    })

In [16]:
len(combined_train_data)

739

In [10]:
combined_train_data = []
for dp in train_data_1:
    combined_train_data.append({
        'query': dp['qa']['question'],
        'title': dp['dp']['title'],
        'pos': dp['dp']['text'],
        'negs': [ d['text'] for d in wikidata_id_map[dp['dp']['wiki_id']] if dp['dp']['id'] != d['id']],
        'answer': dp['qa']['answer']
    })

In [11]:
len(combined_train_data)

8000

In [17]:
test_data_path = 'data/test_data_wikiqa_nf.json'
with open(test_data_path, 'w') as fp:
    json.dump(combined_test_data, fp)

train_data_path = 'data/train_data_wikiqa_nf.json'
with open(train_data_path, 'w') as fp:
    json.dump(combined_train_data, fp)