In [3]:
import os
import sys
import random
import argparse
import json
import nltk
import numpy as np
from tqdm import tqdm
random.seed(42)
np.random.seed(42)
import nltk
nltk.download('punkt')

In [4]:
def total_exs(dataset):
    """
    Returns the total number of (context, question, answer) triples,
    given the data read from the SQuAD json file.
    """
    total = 0
    for article in dataset['data']:
        for para in article['paragraphs']:
            total += len(para['qas'])
    return total


def data_from_json(filename):
    """Loads JSON data from filename and returns"""
    with open(filename) as data_file:
        data = json.load(data_file)
    return data


def tokenize(sequence):
    tokens = [token.replace("``", '"').replace("''", '"').lower() for token in nltk.word_tokenize(sequence)]
    return tokens


def get_char_word_loc_mapping(context, context_tokens):
    """
    Return a mapping that maps from character locations to the corresponding token locations.
    If we're unable to complete the mapping e.g. because of special characters, we return None.

    Inputs:
      context: string (unicode)
      context_tokens: list of strings (unicode)

    Returns:
      mapping: dictionary from ints (character locations) to (token, token_idx) pairs
        Only ints corresponding to non-space character locations are in the keys
        e.g. if context = "hello world" and context_tokens = ["hello", "world"] then
        0,1,2,3,4 are mapped to ("hello", 0) and 6,7,8,9,10 are mapped to ("world", 1)
    """
    acc = '' # accumulator
    current_token_idx = 0 # current word loc
    mapping = dict()

    for char_idx, char in enumerate(context): # step through original characters
        if char != u' ' and char != u'\n': # if it's not a space:
            acc += char # add to accumulator
            context_token = str(context_tokens[current_token_idx]) # current word token
            if acc == context_token: # if the accumulator now matches the current word token
                syn_start = char_idx - len(acc) + 1 # char loc of the start of this word
                for char_loc in range(syn_start, char_idx+1):
                    mapping[char_loc] = (acc, current_token_idx) # add to mapping
                acc = '' # reset accumulator
                current_token_idx += 1

    if current_token_idx != len(context_tokens):
        return None
    else:
        return mapping

def write_to_file(out_file, line):
    #out_file.write(line.encode('utf8') + '\n'.encode('utf8'))
    out_file.write(str(line) + '\n')

In [19]:
def preprocess_and_write(dataset, tier, out_dir):
    """Reads the dataset, extracts context, question, answer, tokenizes them,
    and calculates answer span in terms of token indices.
    Note: due to tokenization issues, and the fact that the original answer
    spans are given in terms of characters, some examples are discarded because
    we cannot get a clean span in terms of tokens.

    This function produces the {train/dev}.{context/question/answer/span} files.

    Inputs:
      dataset: read from JSON
      tier: string ("train" or "dev")
      out_dir: directory to write the preprocessed files
    Returns:
      the number of (context, question, answer) triples written to file by the dataset.
    """

    num_exs = 0 # number of examples written to file
    num_mappingprob, num_tokenprob, num_spanalignprob, num_empty_charloc = 0, 0, 0,0
    examples = []

    for articles_id in tqdm(range(len(dataset['data'])), desc="Preprocessing {}".format(tier)):
        print('id=',articles_id)
        article_paragraphs = dataset['data'][articles_id]['paragraphs']
        for pid in range(len(article_paragraphs)):
            context = str(article_paragraphs[pid]['context']) # string

            # The following replacements are suggested in the paper
            # BidAF (Seo et al., 2016)
            context = context.replace("''", '" ')
            context = context.replace("``", '" ')

            context_tokens = tokenize(context) # list of strings (lowercase)
            context = context.lower()
            
            qas = article_paragraphs[pid]['qas']
                
            article_paragraphs = dataset['data'][articles_id]['paragraphs']

            charloc2wordloc = get_char_word_loc_mapping(context, context_tokens) # charloc2wordloc maps the character location (int) of a context token to a pair giving (word (string), word loc (int)) of that token
            '''
            if charloc2wordloc is None: # there was a problem
                num_mappingprob += len(qas)
                continue # skip this context example
            '''
            # for each question, process the question and answer and write to file
            for qn in qas:
                
                question = str(qn['question']) # string
                question_tokens = tokenize(question) # list of strings

                # of the three answers, just take the first
                
                if qn['is_impossible']==False:
                    impossible = 0
                    artifical_answer = 0
                    ans_text = str(qn['answers'][0]['text']).lower() 
                    ans_start_charloc = qn['answers'][0]['answer_start'] # answer start loc (character count)
                    ans_end_charloc = ans_start_charloc + len(ans_text) # answer end loc (character count) (exclusive)

                    # Check that the provided character spans match the provided answer text

                    if str(context[ans_start_charloc:ans_end_charloc]) != str(ans_text):
                        # Sometimes this is misaligned, mostly because "narrow builds" of Python 2 interpret certain Unicode characters to have length 2 https://stackoverflow.com/questions/29109944/python-returns-length-of-2-for-single-unicode-character-string
                        # We should upgrade to Python 3 next year!
                        num_spanalignprob += 1      
                        continue

                    # get word locs for answer start and end (inclusive)
                    #if (articles_id==14):
                        #print('pid=',pid)
                        #print(charloc2wordloc)
                    
                    if charloc2wordloc is None:
                        num_empty_charloc+=1
                        continue
                        
                    ans_start_wordloc = charloc2wordloc[ans_start_charloc][1] # answer start word loc
                    ans_end_wordloc = charloc2wordloc[ans_end_charloc-1][1] # answer end word loc
                    assert ans_start_wordloc <= ans_end_wordloc

                    # Check retrieved answer tokens match the provided answer text.
                    # Sometimes they won't match, e.g. if the context contains the phrase "fifth-generation"
                    # and the answer character span is around "generation",
                    # but the tokenizer regards "fifth-generation" as a single token.
                    # Then ans_tokens has "fifth-generation" but the ans_text is "generation", which doesn't match.
                    ans_tokens = context_tokens[ans_start_wordloc:ans_end_wordloc+1]
                    
                    
                    if str(context[ans_start_charloc:ans_end_charloc]) != str(ans_text):
                        # Sometimes this is misaligned, mostly because "narrow builds" of Python 2 interpret certain Unicode characters to have length 2 
                        #https://stackoverflow.com/questions/29109944/python-returns-length-of-2-for-single-unicode-character-string
                        # We should upgrade to Python 3 next year!
                        num_spanalignprob += 1      
                        continue
                        
                    ans_end_charloc = ans_start_charloc + len(ans_text) # answer end loc (character count) (exclusive)
                    
                    
                    # get word locs for answer start and end (inclusive)
                    ans_start_wordloc = charloc2wordloc[ans_start_charloc][1] # answer start word loc
                    ans_end_wordloc = charloc2wordloc[ans_end_charloc-1][1] # answer end word loc
                    assert ans_start_wordloc <= ans_end_wordloc
                    
                    
                    if "".join(ans_tokens) != "".join(ans_text.split()):
                        num_tokenprob += 1
                        continue # skip this question/answer pair if 
                    
                    
                else: 
                    impossible = 1
                    #print(qn['plausible_answers'])
                    if len(qn['plausible_answers'])>0:
                        artifical_answer = 1
                        ans_text = str(qn['plausible_answers'][0]['text']).lower()    
                        ans_start_charloc = qn['plausible_answers'][0]['answer_start'] # answer start loc (character count)
                    
                    else: #empty implausible questions, meaning questions with no artificial answers. 
                        artifical_answer = 0
                        ans_text = 'N/A'
                        ans_start_charloc = 'N/A'
                        ans_end_charloc = 'N/A'
                        ans_start_wordloc = 0.5
                        ans_end_wordloc = 0.5
                        num_mappingprob+=1

                
                
                examples.append((' '.join(context_tokens), ' '.join(question_tokens), ' '.join(ans_tokens), ' '.join([str(ans_start_wordloc), str(ans_end_wordloc)]),' '.join(str(impossible)),' '.join(str(artifical_answer ))))

                num_exs += 1

    print("Number of (context, question, answer) triples discarded due to char -> token mapping problems: ", num_mappingprob)
    print("Number of (context, question, answer) triples discarded because character-based answer span is unaligned with tokenization: ", num_tokenprob)
    print("Number of (context, question, answer) triples discarded due character span alignment problems (usually Unicode problems): ", num_spanalignprob)
    print("Processed %i examples of total %i\n" % (num_exs, num_exs + num_mappingprob + num_tokenprob + num_spanalignprob))
    

    indices = list(range(len(examples)))
    #np.random.shuffle(indices)

    with open(os.path.join(out_dir, tier +'.context'), 'w') as context_file,\
    open(os.path.join(out_dir, tier +'.question'), 'w') as question_file,\
    open(os.path.join(out_dir, tier +'.answer'), 'w') as ans_text_file, \
    open(os.path.join(out_dir, tier +'.span'), 'w') as span_file,\
    open(os.path.join(out_dir, tier +'.impossible'), 'w') as impossible_boolean_file,\
    open(os.path.join(out_dir, tier +'.artificial_answer'), 'w') as artificial_boolean_file:

        for i in indices:
            (context, question, answer, answer_span, impossible, artifical_answer) = examples[i]

            # write tokenized data to file
            write_to_file(context_file, context)
            write_to_file(question_file, question)
            write_to_file(ans_text_file, answer)
            write_to_file(span_file, answer_span)
            write_to_file(impossible_boolean_file, impossible)
            write_to_file(artificial_boolean_file, artifical_answer)

### Now apply the Preprocessing onto Squad 2.0

In [21]:
train_filename = "train-v2.0.json"
dev_filename = "dev-v2.0.json"
data_dir = '/Users/kefei/Documents/DSML/NLP/Project'

train_data = data_from_json(train_filename)
print("Train data has %i examples total" % total_exs(train_data))
#preprocess_and_write(train_data, 'train', data_dir)

dev_data = data_from_json(dev_filename)
print("Dev data has %i examples total" % total_exs(dev_data))

# preprocess dev set and write to file
preprocess_and_write(dev_data, 'dev', data_dir)


Preprocessing dev:   0%|          | 0/35 [00:00<?, ?it/s]

Train data has 130319 examples total
Dev data has 11873 examples total
id= 0
id= 1


Preprocessing dev:   9%|▊         | 3/35 [00:00<00:04,  7.57it/s]

id= 2
id= 3


Preprocessing dev:  14%|█▍        | 5/35 [00:00<00:03,  8.51it/s]

id= 4
id= 5


Preprocessing dev:  20%|██        | 7/35 [00:00<00:04,  6.74it/s]

id= 6
id= 7


Preprocessing dev:  26%|██▌       | 9/35 [00:01<00:03,  7.06it/s]

id= 8
id= 9


Preprocessing dev:  31%|███▏      | 11/35 [00:01<00:04,  5.56it/s]

id= 10
id= 11


Preprocessing dev:  37%|███▋      | 13/35 [00:01<00:03,  6.77it/s]

id= 12
id= 13


Preprocessing dev:  43%|████▎     | 15/35 [00:02<00:02,  7.55it/s]

id= 14
id= 15


Preprocessing dev:  49%|████▊     | 17/35 [00:02<00:02,  6.87it/s]

id= 16
id= 17


Preprocessing dev:  57%|█████▋    | 20/35 [00:02<00:02,  7.46it/s]

id= 18
id= 19
id= 20


Preprocessing dev:  63%|██████▎   | 22/35 [00:02<00:01,  8.39it/s]

id= 21
id= 22


Preprocessing dev:  69%|██████▊   | 24/35 [00:03<00:01,  6.80it/s]

id= 23
id= 24


Preprocessing dev:  71%|███████▏  | 25/35 [00:03<00:01,  5.87it/s]

id= 25


Preprocessing dev:  74%|███████▍  | 26/35 [00:03<00:02,  4.08it/s]

id= 26


Preprocessing dev:  77%|███████▋  | 27/35 [00:04<00:02,  3.05it/s]

id= 27


Preprocessing dev:  80%|████████  | 28/35 [00:04<00:02,  2.97it/s]

id= 28


Preprocessing dev:  86%|████████▌ | 30/35 [00:05<00:01,  3.33it/s]

id= 29
id= 30


Preprocessing dev:  89%|████████▊ | 31/35 [00:05<00:01,  3.20it/s]

id= 31


Preprocessing dev:  91%|█████████▏| 32/35 [00:05<00:00,  3.56it/s]

id= 32


Preprocessing dev:  97%|█████████▋| 34/35 [00:06<00:00,  3.94it/s]

id= 33
id= 34


Preprocessing dev: 100%|██████████| 35/35 [00:06<00:00,  4.52it/s]


Number of (context, question, answer) triples discarded due to char -> token mapping problems:  15
Number of (context, question, answer) triples discarded because character-based answer span is unaligned with tokenization:  74
Number of (context, question, answer) triples discarded due character span alignment problems (usually Unicode problems):  0
Processed 11799 examples of total 11888



In [None]:
'''
Codes used to debug!


dataset = dev_data
articles_id=7
article_paragraphs = dataset['data'][articles_id]['paragraphs']
pid = 14

context = str(article_paragraphs[pid]['context']) # string

# The following replacements are suggested in the paper
# BidAF (Seo et al., 2016)
context = context.replace("''", '" ')
context = context.replace("``", '" ')

context_tokens = tokenize(context) # list of strings (lowercase)
context = context.lower()

qas = article_paragraphs[pid]['qas']
qn = qas[0]
question = str(qn['question']) # string
question_tokens = tokenize(question) # list of strings
ans_start_charloc = qn['answers'][0]['answer_start'] # answer start loc (character count)

ans_text = str(qn['answers'][0]['text']).lower()
ans_end_charloc = ans_start_charloc + len(ans_text) 

question = str(qn['question'])

charloc2wordloc = get_char_word_loc_mapping(context, context_tokens)
ans_start_wordloc = charloc2wordloc[ans_start_charloc][1] # answer start word loc
ans_end_wordloc = charloc2wordloc[ans_end_charloc-1][1] # answer end word loc
ans_tokens = context_tokens[ans_start_wordloc:ans_end_wordloc+1]
'''