In [6]:
from collections import defaultdict
from itertools import chain
import json
import ijson 
import argparse
from collections import Counter
import random
from argparse import Namespace
from tqdm.notebook import tqdm
import torch
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer
import sys
sys.path.insert(0,'..')
from utils import *

### Load Dataset

In [2]:
# Base directory used for the dataset file read below and for every JSON file this notebook writes.
# The value ends with "/" and the cells that use it prepend another "/", so the built paths contain "//".
wow_path = "/ssd003/projects/aieng/conversational_ai/data/Wizard_of_Wikipedia/"

In [26]:
# The file content is decoded twice: json.load yields a value that is passed
# straight to json.loads, so the file is expected to hold a JSON string whose
# text is itself a JSON document.
with open(wow_path + "/wizard_of_wikipedia.json") as f:
    wow_data = json.loads(json.load(f))

In [46]:
# Show the top-level fields of the loaded dataset.
wow_data.keys()

dict_keys(['id', 'input', 'meta', 'output', 'answer', 'passages_text'])

### Convert WoW to MSMarco Format

In [67]:
# Tally of query types; the conversion loop fills it in.
c = Counter()
# NOTE(review): because of the "- 1", one record fewer than len(wow_data['id'])
# gets converted — confirm that leaving a record out is intended.
wow_len = len(wow_data['id']) - 1
# One empty dict per record to convert, populated by the conversion loop.
wow_msmarco = [dict() for _ in range(wow_len)]

In [68]:
# Copy every WoW record into the MSMARCO field layout and count the query types.
c = Counter()
for i in range(wow_len):
    key = str(i)  # the wow_data columns are addressed by the stringified position
    record = wow_msmarco[i]

    record['query_id'] = wow_data['id'][key]
    record['query'] = wow_data['input'][key]
    record['answers'] = [wow_data['answer'][key]]

    # The title of the first provenance entry of the first output is used
    # both as the query type and as the category.
    title = wow_data['output'][key][0]['provenance'][0]['title']
    record['query_type'] = title
    record['category'] = title

    record['wellFormedAnswers'] = []

    # A single passage per record, always marked as selected.
    record['passages'] = [{
        'passage_text': wow_data['passages_text'][key],
        'is_selected': 1,
    }]

    c[title] += 1

In [69]:
# Display how many converted records carry each query type (the Wikipedia article title).
print("List of available Wikipedia categories")
c

List of available Wikipedia categories


Counter({'Chicago Blackhawks': 3,
         'Surgeon': 3,
         'Equestrianism': 56,
         'Vitamin C': 62,
         'Hiking': 267,
         'Pop music': 37,
         'Cat': 78,
         'Fishing tackle': 23,
         'Vietnam War': 1,
         '4chan': 1,
         'Kid Rock': 31,
         'Garlic oil': 1,
         'SpongeBob SquarePants': 50,
         'Veganism': 306,
         'Blue Ridge Mountains': 3,
         'The Humane Society of the United States': 28,
         'Nordic countries': 2,
         'Community theatre': 94,
         'Surfing': 205,
         'Nursing': 27,
         'Fiction writing': 18,
         'Ice hockey': 103,
         'Ravioli': 22,
         'United States Postal Service': 1,
         'Walmart': 98,
         'Hippopotamus': 24,
         'Telenovela': 50,
         'Meatloaf': 63,
         'Dog': 439,
         'New York City': 251,
         'Overweight': 29,
         'Homeschooling': 73,
         'List of Agents of S.H.I.E.L.D. characters': 4,
         'The Lit

In [70]:
# Show the field names of the first converted record.
wow_msmarco[0].keys()

dict_keys(['query_id', 'query', 'answers', 'query_type', 'category', 'wellFormedAnswers', 'passages'])

In [72]:
# Display the first converted record in full as a sanity check.
print("Sample data:")
wow_msmarco[0]

Sample data:


{'query_id': '6bc20426-99d6-11ea-8a20-773209e30a7b_0',
 'query': 'I like to watch ice hockey on TV. My favorite team is the Chicago Blackhawks.',
 'answers': ["The Blackhawks are one of my favorite teams, they've won 6 Stanley Cup Championships since they started in 1926"],
 'query_type': 'Chicago Blackhawks',
 'category': 'Chicago Blackhawks',
 'wellFormedAnswers': [],
 'passages': [{'passage_text': 'The Chicago Blackhawks (spelled Black Hawks until 1986, and known colloquially as the Hawks) are a professional ice hockey team based in Chicago, Illinois. They are members of the Central Division of the Western Conference of the National Hockey League (NHL). They have won six Stanley Cup championships since their founding in 1926. The Blackhawks are one of the "Original Six" NHL teams along with the Detroit Red Wings, Montreal Canadiens, Toronto Maple Leafs, Boston Bruins and New York Rangers. Since , the club\'s home rink is the United Center, which they share with the National Basketba

### Train Test Split

In [11]:
# determine your own splits
# determine your own splits
# A fixed seed makes both shuffles repeatable, so restarting the kernel and
# re-running the notebook yields the same partitions.
SPLIT_SEED = 42
# 80% of the records become training data; the rest is divided again below.
train_data, val_test_data = train_test_split(wow_msmarco, train_size=0.8, random_state=SPLIT_SEED)
# train_size=0.1 keeps 10% of the held-out records for validation and leaves
# the other 90% for testing.
val_data, test_data = train_test_split(val_test_data, train_size=0.1, random_state=SPLIT_SEED)

In [12]:
# Report how many entries each split holds.
print(f'Train Size:  {len(train_data)}')
print(f'Validation Size:  {len(val_data)}')
print(f'Test Size:  {len(test_data)}')

Train Size:  45906
Validation Size:  1147
Test Size:  10330


### Inverse Data Structure

In [13]:
def inverse_dataStruct(source):
    """Turn a sequence of record dicts into a dict of per-field columns.

    Args:
        source: Sequence of dicts, each providing the keys 'answers',
            'passages', 'query', 'query_id', 'query_type',
            'wellFormedAnswers' and 'category'.

    Returns:
        A dict with one entry per field name; each entry maps the record's
        position in ``source`` (0, 1, ...) to that record's value for the
        field. Values are stored by reference, not copied.
    """
    # The order of this tuple fixes the key order of the returned dict.
    fields = ('answers', 'passages', 'query', 'query_id', 'query_type',
              'wellFormedAnswers', 'category')
    result = {field: {} for field in fields}
    for position, record in enumerate(source):
        for field in fields:
            result[field][position] = record[field]
    return result

In [14]:
# Replace each split (a list of records) by its column-oriented form.
# The names are rebound, so this cell must run exactly once after the split cell.
train_data, val_data, test_data = map(inverse_dataStruct, (train_data, val_data, test_data))

In [125]:
# Display the inverted validation split; this prints the entire structure, so the output is long.
val_data

{'answers': {0: "Yes, I've actually gotten counterfeit bills before! Counterfeiting has been going on for as long as money has been around.",
  1: 'I like that it has a bunch of different lakes and rivers such as Lake Travis and McKinney Falls.  It is a great place for outdoor activities',
  2: 'I really enjoy lots of different chesse on my pizza. Their styles, textures and flavors depend on the origin of the milk',
  3: 'Me too. I do now know why I love it so much. Maybe because it originated  In cuba. It was as late as the 20th century',
  4: 'Maybe.  McDonald was very innovative when founded but has suffered a backs;ash in recent years, for the alleged unhealthiness of its food.',
  5: "They are nice; it has the longest coastline in the lower 48. I'm surprised you haven't gotten sick of the beach though. I did after just over a year.",
  6: "I'm not familiar personally, but I imagine it's much better than places like Iraq for example, where women's rights organizations struggles aga

In [126]:
# Write the three inverted splits as indented JSON files.
# All files share the wow_* prefix. The test split used to be written to
# 'eli5_test_MSMARCO.json'; any reader still expecting that name must be updated.
raw_dir = wow_path + '/WoW_raw_msmarco_format/'
for split_name, split_data in (('train', train_data), ('val', val_data), ('test', test_data)):
    with open(raw_dir + 'wow_' + split_name + '_MSMARCO.json', 'w') as json_file:
        json.dump(split_data, json_file, indent=2)

### Define Util Functions

In [114]:
def get_distractor(target_key, dataset):
    """Pick a random answer of another record to serve as a wrong candidate.

    Args:
        target_key: Key of the record being processed; the answer stored
            under this key is never drawn.
        dataset: Mapping with an ``'answers'`` entry that maps record keys
            to answers.

    Returns:
        The value stored in ``dataset['answers']`` for a randomly chosen
        other key whose value is not equal to ``'No Answer Present.'``.

    Raises:
        ValueError: If ``target_key`` is not a key of ``dataset['answers']``.
        IndexError: If no other record has a usable answer.
    """
    no_answer = 'No Answer Present.'
    answers = dataset['answers']
    other_keys = list(answers.keys())
    other_keys.remove(target_key)

    # Draw again when a "no answer" record comes up. This is a loop rather
    # than a recursive call: the recursive retry was invoked without its
    # arguments and raised TypeError instead of retrying.
    for _ in range(16):
        candidate = answers[random.choice(other_keys)]
        if candidate != no_answer:
            return candidate

    # Repeatedly unlucky, or nothing usable exists: pick uniformly among the
    # usable records so the call always terminates. random.choice raises
    # IndexError when the list is empty.
    usable_keys = [k for k in other_keys if answers[k] != no_answer]
    return answers[random.choice(usable_keys)]

In [111]:
def _first_answer(answer):
    """Return the first element of a non-empty list/tuple answer, else the value unchanged."""
    if isinstance(answer, (list, tuple)) and answer:
        return answer[0]
    return answer


def create_marco_dataset(data, categorical = True):
    """Build one training example per record of a column-oriented dataset.

    Args:
        data: Dict of columns with the entries 'answers', 'passages',
            'query', 'query_id', 'query_type' and 'wellFormedAnswers'
            (plus 'category' when ``categorical`` is true); every column maps
            a record key to that record's value.
        categorical: When true, copy the record's 'category' into the example.

    Returns:
        A list of dicts with the keys 'query_id', 'context', 'query_type',
        optionally 'category', and 'utterances'. 'utterances' holds a single
        entry whose 'history' is ``[query]`` and whose 'candidates' is a list
        of three answers with the chosen answer last.

    Notes:
        Answers stored as a list (the layout produced when the answer is
        wrapped in a list) are reduced to their first element, so every
        candidate is the answer text itself rather than a list around it.
    """
    no_answer = 'No Answer Present.'
    fallback_answer = "I'm sorry, I don't know."

    new_data = []
    for key in tqdm(data['answers']):
        q_id = data['query_id'][key]
        question = data['query'][key]
        answers = data['answers'][key]
        better_answers = data['wellFormedAnswers'][key]
        contexts = data['passages'][key]
        question_type = data['query_type'][key]

        if categorical:
            category = data['category'][key]

        # if there exists a well-formed answer, use that one. Otherwise just use the answers provided.
        if isinstance(better_answers, list) and len(better_answers) != 0:
            new_answer = better_answers[0]
        else:
            new_answer = _first_answer(answers)

        # if there's multiple answers and a person has not provided a correct good response, skip this question
        # (a plain string counts as a single answer, not as len(string) answers)
        has_several_answers = isinstance(answers, (list, tuple)) and len(answers) > 1
        if has_several_answers and not isinstance(better_answers, list):
            continue

        # if the model can't extract the answer, at least be nice.
        if new_answer == no_answer:
            new_answer = fallback_answer

        data_chunk = {}
        data_chunk['query_id'] = q_id
        data_chunk['context'] = contexts
        data_chunk['query_type'] = question_type

        # if the dataset has category information
        if categorical:
            data_chunk['category'] = category

        # The chosen answer always comes last. When it already is the
        # fallback sentence, a second distractor takes the fallback's slot.
        if new_answer == fallback_answer:
            candidates = [
                _first_answer(get_distractor(key, data)),
                _first_answer(get_distractor(key, data)),
                new_answer,
            ]
        else:
            candidates = [
                fallback_answer,
                _first_answer(get_distractor(key, data)),
                new_answer,
            ]

        data_chunk['utterances'] = [{
            'history': [question],
            'candidates': candidates
        }]

        new_data.append(data_chunk)
    return new_data

In [112]:
# Smoke test: draw a distractor for the validation record stored under key 0.
# get_distractor takes the record key first and the dataset second.
get_distractor(0, val_data)

'yea well the exact origins of the phenomenon remain contested, it is agreed that it developed in the 70s, and it was centered in the UK'

### Create and Save MSMarco Data Splits

In [127]:
# Convert each column-oriented split into its list of training examples
# (train first, then validation, then test).
wow_train, wow_val, wow_test = [
    create_marco_dataset(split) for split in (train_data, val_data, test_data)
]

  0%|          | 0/45906 [00:00<?, ?it/s]

  0%|          | 0/1147 [00:00<?, ?it/s]

  0%|          | 0/10330 [00:00<?, ?it/s]

In [118]:
# Inspect one generated training example (the index is an arbitrary pick).
wow_train[1012]

{'query_id': '5313d5f8-99d6-11ea-8a20-773209e30a7b_2',
 'context': [{'passage_text': 'Veganism is the practice of abstaining from the use of animal products, particularly in diet, and an associated philosophy that rejects the commodity status of animals. A follower of the diet or the philosophy is known as a vegan. Distinctions may be made between several categories of veganism. "Dietary vegans" (also known as strict vegetarians) refrain from consuming animal products, not only meat but also eggs, dairy products and other animal-derived substances. The term "ethical vegan" is often applied to those who not only follow a vegan diet but extend the philosophy into other areas of their lives, and oppose the use of animals for any purpose. Another term is "environmental veganism", which refers to the avoidance of animal products on the premise that the industrial farming of animals is environmentally damaging and unsustainable.\n',
   'is_selected': 1}],
 'query_type': 'Veganism',
 'categor

In [128]:
# Store the three example lists as indented JSON, one file per split.
for file_name, examples in (
    ('wow_train.json', wow_train),
    ('wow_val.json', wow_val),
    ('wow_test.json', wow_test),
):
    with open(wow_path + '/WoW_model_files/' + file_name, 'w') as json_file:
        json.dump(examples, json_file, indent=2)

### Tokenize Dataset

In [129]:
def tokenize(obj, tokenizer):
    """Replace every text field of one example with ``tokenizer.encode`` output, in place.

    Args:
        obj: Example dict with a 'context' list (items carrying
            'passage_text') and an 'utterances' list (items carrying
            'history' and 'candidates').
        tokenizer: Object exposing ``encode(text)``.

    Returns:
        None. ``obj`` is modified: passage texts and candidates are
        overwritten with their encodings, and each 'history' is replaced by a
        new list of encodings.
    """
    for passage in obj['context']:
        passage['passage_text'] = tokenizer.encode(passage['passage_text'])

    for utterance in obj['utterances']:
        # History turns lose leading spaces, parentheses and underscores before encoding.
        utterance['history'] = [
            tokenizer.encode(turn.lstrip(' ()_')) for turn in utterance['history']
        ]
        # Candidates are encoded as they are and written back into the same list.
        candidates = utterance['candidates']
        for position, text in enumerate(candidates):
            candidates[position] = tokenizer.encode(text)

In [132]:
# Build the GPT2Tokenizer from the files stored in the directory below.
tokenizer_path = "/ssd003/projects/aieng/conversational_ai/shirley_checkpoints/marco_dialogptmedium/"
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

In [133]:
# Encode every training example. tokenize overwrites the text in place, so
# run this cell only once per freshly built wow_train.
for example in tqdm(wow_train):
    tokenize(example, tokenizer)

  0%|          | 0/45906 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1108 > 1024). Running this sequence through the model will result in indexing errors


In [134]:
# Encode every validation example. tokenize overwrites the text in place, so
# run this cell only once per freshly built wow_val.
for example in tqdm(wow_val):
    tokenize(example, tokenizer)

  0%|          | 0/1147 [00:00<?, ?it/s]

In [135]:
# Encode every test example. tokenize overwrites the text in place, so run
# this cell only once per freshly built wow_test.
for example in tqdm(wow_test):
    tokenize(example, tokenizer)

  0%|          | 0/10330 [00:00<?, ?it/s]

In [136]:
# Store the tokenized example lists as indented JSON, one file per split.
# Note that the validation file uses "valid" in its name, not "val".
for file_name, examples in (
    ('wow_train_tokenized.json', wow_train),
    ('wow_valid_tokenized.json', wow_val),
    ('wow_test_tokenized.json', wow_test),
):
    with open(wow_path + '/WoW_tokenized/' + file_name, 'w') as json_file:
        json.dump(examples, json_file, indent=2)