# Load Libs

In [2]:
import os
import sys
import random
import json
import collections
import re

import keras
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

from official import nlp
from official.nlp import bert
import official.nlp.bert.tokenization
import official.nlp.bert.configs

from squad_preprocess import convert_squad_to_features

import pickle

# Global Variables

### Word Limit

In [None]:
# Passed as the third positional argument to convert_squad_to_features and
# embedded in the names of the saved feature/label files.
# NOTE(review): presumably a sequence-length limit — confirm against squad_preprocess.
WORD_LIMIT = 350

### Pretrained Model Dir

In [None]:
# Directory from which "vocab.txt" is read when the tokenizer is built; the
# piece after its last '_' is also embedded in the saved file names.
# NOTE(review): looks like a placeholder path — set it to the real model directory.
BERT_PRETRAINED_DIR = "bert/pretrained/model"

# Load Tokenizer

In [None]:
# Build the lower-casing BERT FullTokenizer from the vocabulary file that
# sits inside the pretrained model directory.
vocab_path = os.path.join(BERT_PRETRAINED_DIR, "vocab.txt")
tokenizer = bert.tokenization.FullTokenizer(vocab_file=vocab_path,
                                            do_lower_case=True)

# Load SQuAD v2.0 Dataset

In [None]:
# The builder is instantiated once to reach its BUILDER_CONFIGS list and then
# re-instantiated with the second entry of that list.
# NOTE(review): index 1 is presumably the v2.0 config — confirm against the
# installed tensorflow_datasets version.
sq = tfds.question_answering.squad.Squad()
sq = tfds.question_answering.squad.Squad(config=sq.BUILDER_CONFIGS[1])
# squad_v2 is indexed by split name ('train', 'validation') further down.
# NOTE(review): as_dataset() presumably needs the data to be prepared on disk
# already (download_and_prepare) — confirm.
squad_v2 = sq.as_dataset()

# Prepare Data

In [None]:
def extract_examples(squad_dataset, max_iter=None):
    """Convert SQuAD records into whitespace-tokenised ``Example`` tuples.

    Args:
        squad_dataset: Object exposing ``as_numpy_iterator()`` (and ``len()``
            when ``max_iter`` is None). Every record is indexed with the keys
            ``'question'``, ``'context'``, ``'is_impossible'`` and
            ``'answers'`` (which holds ``'text'`` and ``'answer_start'``).
        max_iter: Maximum number of records to convert. Defaults to
            ``len(squad_dataset)``.

    Returns:
        A list of ``Example`` namedtuples with the fields ``doc_tokens``
        (lower-cased context split on whitespace), ``is_impossible`` (bool),
        ``start_position`` / ``end_position`` (word indices into
        ``doc_tokens``, both ``-1`` for unanswerable questions),
        ``orig_answer_text`` (``''`` for unanswerable questions, otherwise the
        first answer text exactly as stored in the record) and
        ``question_text`` (exactly as stored in the record).
    """
    Example = collections.namedtuple(
        "Example",
        ["doc_tokens", "is_impossible", "start_position", "end_position",
         "orig_answer_text", "question_text"])

    if max_iter is None:
        max_iter = len(squad_dataset)

    # Compiled once; raw string avoids the invalid-escape warning of '[\s]+'.
    whitespace = re.compile(r'\s+')
    examples = []

    for i, s in enumerate(squad_dataset.as_numpy_iterator()):

        if i == max_iter:
            break

        question_text = s['question']
        # Read the flag from the record. The previous code assigned the list
        # literal ['is_impossible'], which is always truthy, so every example
        # was marked unanswerable and no answer span was ever extracted.
        is_impossible = bool(s['is_impossible'])
        raw_context = s['context'].decode("utf-8")
        doc_tokens = whitespace.split(raw_context.lower())

        if is_impossible:
            start_position = -1
            end_position = -1
            orig_answer_text = ''
        else:
            orig_answer_text = s['answers']['text'][0]
            start_index = s['answers']['answer_start'][0]

            # Index of the word containing the answer's first character: the
            # number of whitespace runs before it. The prefix is cut from the
            # un-lowercased context because answer_start refers to that text
            # and lower() can change the length of some Unicode strings.
            start_position = len(whitespace.split(raw_context[:start_index])) - 1

            answer_length = len(
                whitespace.split(orig_answer_text.decode("utf-8").lower()))
            # NOTE(review): this is one past the last answer word; confirm
            # that convert_squad_to_features expects an exclusive end index.
            end_position = start_position + answer_length

        example = Example(doc_tokens=doc_tokens, is_impossible=is_impossible,
                          start_position=start_position,
                          end_position=end_position,
                          orig_answer_text=orig_answer_text,
                          question_text=question_text)

        examples.append(example)

        # shows progress
        if (len(examples) + 1) % 500 == 0:
            # Imported lazily so IPython is only needed when progress is shown.
            from IPython.display import clear_output
            clear_output(wait=True)
            print('{0:.2f}%'.format(round(len(examples) / max_iter, 4) * 100))

    return examples

## Train Data

### Load samples

In [None]:
# Training split of the loaded dataset; reused later when the train contexts
# are collected for the document index.
train_samples = squad_v2['train']

In [None]:
# Convert the whole training split (max_iter is left at its default).
examples = extract_examples(train_samples)

In [None]:
# NOTE(review): 300, 50 and True are passed positionally; they presumably map
# to the helper's document stride, maximum query length and training flag —
# confirm against squad_preprocess.convert_squad_to_features.
features = convert_squad_to_features(examples, tokenizer, WORD_LIMIT, 300, 50, True)

In [None]:
# Persist the converted features. The file name embeds WORD_LIMIT and the
# piece of BERT_PRETRAINED_DIR after its last '_'.
features_path = 'squad_features_{}_{}.pkl'.format(
    WORD_LIMIT, BERT_PRETRAINED_DIR.split('_')[-1])
# The handle is not called `input`: that would rebind the builtin for the
# rest of the notebook session.
with open(features_path, 'wb') as features_file:
    pickle.dump(features, features_file)

In [None]:
import numpy as np

# Pull each per-feature field into its own list.
input_ids = [feature['input_ids'] for feature in features]
input_masks = [feature['input_mask'] for feature in features]
input_types = [feature['segment_ids'] for feature in features]

# One [start, end] pair per feature.
labels = [[feature['start_position'], feature['end_position']]
          for feature in features]

# Stacked in the order: ids, segment ids, masks.
squad_f = np.asarray([input_ids, input_types, input_masks])
labels = np.asarray(labels)

In [None]:
# Both files share the "<WORD_LIMIT>_<last '_' piece of the model dir>" suffix.
_name_suffix = '{}_{}'.format(WORD_LIMIT, BERT_PRETRAINED_DIR.split('_')[-1])
np.save('squad_feats_' + _name_suffix, squad_f)
np.save('squad_labs_' + _name_suffix, labels)

## Validation Data

In [None]:
def get_doc_id(content, docs):
    """Return the key of the first entry in ``docs`` whose value equals ``content``.

    Raises:
        ValueError: if no value in ``docs`` equals ``content``.
    """
    doc_ids = list(docs)
    contents = list(docs.values())
    # list.index raises ValueError when the content is absent.
    return doc_ids[contents.index(content)]

def bin_to_str(binary_tf):
    """Decode the UTF-8 payload returned by ``binary_tf.numpy()``.

    Returns a single ``str`` when ``.numpy()`` yields a ``bytes`` object;
    otherwise the result is iterated and a list of decoded strings is returned.
    """
    raw = binary_tf.numpy()

    if type(raw) is not bytes:
        return [item.decode("utf-8") for item in raw]

    return raw.decode("utf-8")

def questions_and_contexts(dataset):
    """Split a dataset into per-question records and de-duplicated contexts.

    Args:
        dataset: Iterable of records indexed with ``'context'``,
            ``'question'``, ``'answers'``/``'plausible_answers'`` (each with a
            ``'text'`` entry) and ``'is_impossible'``. ``len(dataset)`` is
            only needed for the progress display.

    Returns:
        A ``(collection, docs)`` tuple. ``docs`` maps a document id (the
        position of the first record carrying that context) to the context
        string. ``collection`` holds one dict per record with the keys
        ``'question'``, ``'id'`` (id of the record's context in ``docs``),
        ``'answers'``, ``'label'`` (the ``is_impossible`` value) and
        ``'plaus_answers'``.
    """
    docs = {}
    collection = []
    # Context string -> id it was first stored under. Replaces a list
    # membership test plus a linear scan over docs, which made the loop
    # quadratic in the number of records.
    id_by_content = {}

    for doc_id, doc in enumerate(dataset):
        content = bin_to_str(doc['context'])

        current_id = id_by_content.get(content)
        if current_id is None:
            # First time this context is seen: it gets this record's position.
            current_id = doc_id
            id_by_content[content] = doc_id
            docs[doc_id] = content

        collection.append({
            'question': bin_to_str(doc['question']),
            'id': current_id,
            'answers': bin_to_str(doc['answers']['text']),
            'label': doc['is_impossible'].numpy(),
            'plaus_answers': bin_to_str(doc['plausible_answers']['text']),
        })

        # shows progress
        if (doc_id + 1) % 500 == 0:
            from IPython.display import clear_output
            clear_output(wait=True)
            print('{0:.2f}%'.format(round(doc_id / len(dataset), 4) * 100))

    return collection, docs

### Load Samples

In [None]:
# Validation split of the loaded dataset.
val_samples = squad_v2['validation']

In [None]:
# val_questions: one record per validation question;
# val_contexts: document id -> de-duplicated context text.
val_questions, val_contexts = questions_and_contexts(val_samples)

### Save Samples

In [None]:
# Persist the validation questions and contexts. The handles are not called
# `input`: that would rebind the builtin for the rest of the notebook session.
with open('squad_val_questions.pkl', 'wb') as questions_file:
    pickle.dump(val_questions, questions_file)

with open('squad_val_contexts.pkl', 'wb') as contexts_file:
    pickle.dump(val_contexts, contexts_file)

# Lucene Index JSON

In [None]:
def bin_to_str(binary_tf):
    """Decode the UTF-8 payload returned by ``binary_tf.numpy()``.

    A ``bytes`` result is decoded to one ``str``; any other result is iterated
    and each element decoded, giving a list of strings.
    """
    payload = binary_tf.numpy()

    if type(payload) is bytes:
        return payload.decode("utf-8")

    return [chunk.decode("utf-8") for chunk in payload]

def collect_contexts(dataset, start_docId):
    """Collect the distinct contexts of ``dataset`` under offset document ids.

    Args:
        dataset: Iterable of records indexed with ``'context'``.
            ``len(dataset)`` is only needed for the progress display.
        start_docId: Offset added to each record's position to form its
            document id, so the ids can continue after an existing collection.

    Returns:
        Dict mapping ``position + start_docId`` of the first record carrying
        each distinct context to that context string.
    """
    docs = {}
    all_contexts = set()

    for doc_id, doc in enumerate(dataset):
        content = bin_to_str(doc['context'])

        if content not in all_contexts:
            all_contexts.add(content)
            # Key by the offset id. The previous code computed it but stored
            # under the raw position, so start_docId had no effect.
            docs[doc_id + start_docId] = content

        # shows progress
        if (doc_id + 1) % 500 == 0:
            from IPython.display import clear_output
            clear_output(wait=True)
            print('{0:.2f}%'.format(round(doc_id / len(dataset), 4) * 100))

    return docs

## List of Documents

In [None]:
# Documents to be indexed; each entry is a dict with an 'id' (document id)
# and a 'contents' (context text) key.
all_documents = []

### Add Validation Contexts

In [None]:
# Wrap every validation context as an {'id', 'contents'} document.
for docId, content in val_contexts.items():
    docum = {'id': docId, 'contents': content}
    all_documents.append(docum)

In [None]:
# Largest document id handed out to a validation context.
val_last_docId = max(val_contexts)

### Add Train Contexts

In [None]:
# Start one past the last validation id; passing val_last_docId itself would
# hand the first train document an id already used by a validation document.
train_contexts = collect_contexts(train_samples, val_last_docId + 1)

In [None]:
# Wrap every train context as an {'id', 'contents'} document.
for docId, content in train_contexts.items():
    docum = {'id': docId, 'contents': content}
    all_documents.append(docum)

## Encode as JSON

In [None]:
# Default-configured JSON encoder (json.JSONEncoder is json.encoder.JSONEncoder).
encoder = json.JSONEncoder()

In [None]:
# Serialise the whole document list into a single JSON string.
json_string = encoder.encode(all_documents)

### Save in file

In [None]:
# json_string is one str, so write() is the right call; writelines() would
# iterate it and write one character at a time.
with open('documents.json', 'w') as json_file:
    json_file.write(json_string)

### To create Index

Run the command below in a terminal, replacing `Index_dir` with the directory where the index will be stored and `JSON_dir` with the directory that contains the JSON file.

python3 -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator  -threads 2 -input JSON_dir -index Index_dir -storePositions -storeDocvectors -storeRaw