In [1]:
from collections import defaultdict
from itertools import chain
import json
import ijson 
import argparse
import random
from argparse import Namespace
from tqdm.notebook import tqdm
import torch
from transformers import GPT2Tokenizer
import sys
sys.path.insert(0,'..')
from utils import *

In [2]:
# Root directory of the AmazonQA data; the file paths read and written below are built by appending to it.
path_prefix = "/ssd003/projects/aieng/conversational_ai/data/AmazonQA"

In [3]:
# Checkpoint directory later handed to GPT2Tokenizer.from_pretrained to load the tokenizer.
tokenizer_path = "/ssd003/projects/aieng/conversational_ai/shirley_checkpoints/marco_dialogptmedium/"

## Load Dataset

In [4]:
# Each raw file is JSON Lines: one JSON object per line.
# The encoding is pinned to UTF-8 so parsing does not depend on the machine's locale,
# and blank lines (e.g. a trailing newline at end of file) are skipped instead of
# crashing json.loads.
with open(path_prefix + "/raw/train-qar_msmarco.jsonl", encoding="utf-8") as f:
    train_data = [json.loads(line) for line in f if line.strip()]

with open(path_prefix + "/raw/val-qar_msmarco.jsonl", encoding="utf-8") as f:
    val_data = [json.loads(line) for line in f if line.strip()]

In [5]:
# Report how many records each split contains.
for size_label, split_records in (('Train Size: ', train_data), ('Validation Size: ', val_data)):
    print(size_label, len(split_records))

Train Size:  455931
Validation Size:  58969


## Building the QA Dataset

1- In this section, we first convert the original format to the MSMARCO format. The data is saved in the <b>AmazonQA_raw_msmarco_format</b> folder.

2- Then we convert the result from part 1 into a QA dataset that can be used to train/validate the GPT model. The data is saved in the <b>AmazonQA_model_files</b> folder.

In [6]:
def inverse_dataStruct(source):
    """Turn a collection of records into a column-oriented dict of dicts.

    Args:
        source: Iterable of mappings. Each one must provide the keys
            'answers', 'passages', 'query', 'query_id', 'query_type' and
            'wellFormedAnswers'. Any iterable works (list, generator, ...);
            it is consumed exactly once.

    Returns:
        A dict with one entry per field listed above, in that order. Each
        entry maps the record's 0-based position in ``source`` to that
        record's value for the field. An empty ``source`` yields six empty
        inner dicts.

    Raises:
        KeyError: If a record is missing one of the required fields.
    """
    fields = ('answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers')
    # Create the columns up front so the output key order is fixed even for empty input.
    result = {field: {} for field in fields}
    # enumerate avoids len()/indexing, so generators and other one-shot iterables work too.
    for position, record in enumerate(source):
        for field in fields:
            result[field][position] = record[field]
    return result

In [7]:
# Swap the per-record lists for the column-oriented dicts produced by inverse_dataStruct.
# NOTE: both names are rebound to the converted result, so this cell is not safe to
# re-run on its own -- a second run would pass the converted dict, not the original
# list of records. Re-run the loading cell first.
train_data = inverse_dataStruct(train_data)
val_data = inverse_dataStruct(val_data)

In [8]:
# Write the column-oriented train and validation splits as indented JSON, train first.
for msmarco_name, msmarco_split in (
    ('/AmazonQA_raw_msmarco_format/amazon_train_MSMARCO.json', train_data),
    ('/AmazonQA_raw_msmarco_format/amazon_valid_MSMARCO.json', val_data),
):
    with open(path_prefix + msmarco_name, 'w') as json_file:
        json.dump(msmarco_split, json_file, indent=2)

In [9]:
# Build the QA dataset for the training split and write it out as indented JSON.
amazon_train = create_marco_format_QAdataset(train_data, categorical=False)
train_model_file = path_prefix + '/AmazonQA_model_files/amazon_train.json'
with open(train_model_file, 'w') as train_out:
    json.dump(amazon_train, train_out, indent=2)

# Same two steps for the validation split.
amazon_val = create_marco_format_QAdataset(val_data, categorical=False)
valid_model_file = path_prefix + '/AmazonQA_model_files/amazon_valid.json'
with open(valid_model_file, 'w') as valid_out:
    json.dump(amazon_val, valid_out, indent=2)

  0%|          | 0/455931 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Print a caption, then let the notebook display the first converted training example.
print('Sample record for training neural model:')
amazon_train[0]

## Tokenize Dataset

In [None]:
def tokenize(obj, tokenizer):
    """Replace the text fields of one example with tokenizer output, in place.

    Args:
        obj: Mapping with a 'context' sequence (items carrying 'passage_text')
            and an 'utterances' sequence (items carrying 'history' and
            'candidates').
        tokenizer: Any object exposing ``encode(text)``.

    Returns:
        None. ``obj`` is modified directly: each 'passage_text' and each
        candidate is replaced by its encoding, and each 'history' becomes a new
        list of encodings of its entries with leading ' ', '(', ')' and '_'
        characters stripped first.
    """
    for passage in obj['context']:
        passage['passage_text'] = tokenizer.encode(passage['passage_text'])

    for utterance in obj['utterances']:
        encoded_history = []
        for turn in utterance['history']:
            encoded_history.append(tokenizer.encode(turn.lstrip(' ()_')))
        utterance['history'] = encoded_history

        # Candidates are overwritten slot by slot so the original list object is kept.
        candidates = utterance['candidates']
        for slot, candidate in enumerate(candidates):
            candidates[slot] = tokenizer.encode(candidate)

In [None]:
# Load the GPT-2 tokenizer from the checkpoint directory named by tokenizer_path.
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

In [None]:
# Encode every training example in place, with a progress bar.
for train_example in tqdm(amazon_train):
    tokenize(train_example, tokenizer)

In [None]:
# Encode every validation example in place, with a progress bar.
for val_example in tqdm(amazon_val):
    tokenize(val_example, tokenizer)

In [None]:
# Write the tokenized train and validation splits as indented JSON, train first.
for tokenized_name, tokenized_split in (
    ('/AmazonQA_tokenized/amazon_train_tokenized.json', amazon_train),
    ('/AmazonQA_tokenized/amazon_valid_tokenized.json', amazon_val),
):
    with open(path_prefix + tokenized_name, 'w') as json_file:
        json.dump(tokenized_split, json_file, indent=2)

In [9]:
# Read the tokenized validation file back from disk so its contents can be inspected.
with open(path_prefix + '/AmazonQA_tokenized/amazon_valid_tokenized.json') as f:
    # The whole file is parsed as one JSON document with a single json.load call.
    a = json.load(f)


In [10]:
# Display the first tokenized validation example; 'passage_text' now holds token ids instead of text.
a[0]

{'context': [{'is_selected': 1,
   'url': '',
   'passage_text': [40,
    1422,
    470,
    760,
    777,
    11196,
    1566,
    314,
    8672,
    616,
    4957,
    287,
    13708,
    11,
    290,
    777,
    3420,
    301,
    2840,
    547,
    319,
    607,
    8215,
    13,
    4362,
    262,
    4258,
    318,
    523,
    11607,
    673,
    1364,
    607,
    8215,
    1280,
    3690,
    607,
    2156,
    11,
    290,
    777,
    3420,
    301,
    2840,
    4030,
    606,
    422,
    19280,
    4838,
    13,
    554,
    11287,
    11,
    262,
    4258,
    2125,
    470,
    523,
    11607,
    11,
    475,
    314,
    750,
    423,
    257,
    3155,
    286,
    8215,
    326,
    20780,
    12271,
    561,
    588,
    284,
    6611,
    4838,
    618,
    314,
    550,
    262,
    9168,
    1280,
    11,
    523,
    314,
    6149,
    777,
    3420,
    301,
    2840,
    329,
    606,
    13,
    1119,
    389,
    10974,
    2759,
    7613,
    290,
    67