# Load Libs

In [2]:
import os
import sys
import random
import json
import re
import collections

import numpy as np

import keras
import tensorflow as tf

from official.nlp import bert
import official.nlp.bert.configs
import official.nlp.bert.tokenization

from keras_bert.loader import load_trained_model_from_checkpoint

from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Dense

from squad_test import compute_f1, compute_exact

from bert_test import Bert

import pickle as pkl

# Global Variables

### Pretrained Directory

In [None]:
# Directory containing the pretrained BERT files used in this notebook
# (vocab.txt, bert_config.json, bert_model.ckpt).
BERT_PRETRAINED_DIR = "bert/model/pretrained"

### Token Limit

In [None]:
# Sequence length: passed as seq_len when loading the pretrained model and as
# the first argument of the Bert validation wrapper.
TOKEN_LIMIT = 350

### Checkpoint Directory

In [None]:
# Path the fine-tuned weights are written to after training; the same path is
# handed to the Bert validation wrapper.
CHKPT_SAVE_DIR = 'bert/chkpt/save/dir/'

# Load Tokenizer

In [None]:
# Set up tokenizer to generate Tensorflow dataset.
# The vocabulary is read from vocab.txt in the pretrained directory.
# NOTE(review): do_lower_case=True presumably matches an uncased checkpoint — confirm.
tokenizer = bert.tokenization.FullTokenizer(
    vocab_file=os.path.join(BERT_PRETRAINED_DIR, "vocab.txt"),
     do_lower_case=True)

# Create Model

### Hyper-Parameters

In [None]:
# Step size handed to the Adam optimizer.
learning_rate = 5e-5
# Number of examples per batch of the training dataset.
batch_size = 8

### Set Config files

In [None]:
# Path of the model configuration JSON inside the pretrained directory.
config_file = os.path.join(BERT_PRETRAINED_DIR, 'bert_config.json')

### Set Checkpoint file

In [None]:
# Path of the pretrained checkpoint (bert_model.ckpt) inside the pretrained directory.
checkpoint_file = os.path.join(BERT_PRETRAINED_DIR, 'bert_model.ckpt')

### Load Pretrained Model

In [None]:
# Build the keras_bert model from the config and checkpoint, with the input
# length fixed to TOKEN_LIMIT.
# NOTE(review): training=True presumably keeps the pretraining heads and makes
# the layers trainable — confirm against the keras_bert loader documentation.
bert_pretrained = load_trained_model_from_checkpoint(config_file, checkpoint_file, 
                                                     training=True, seq_len=TOKEN_LIMIT)

### Add Dense layer to end of Encoder

In [None]:
# Index (counted from the end of bert_pretrained.layers) of the layer whose
# output feeds the new head.
# NOTE(review): -7 presumably skips the pretraining-head layers to reach the
# final encoder output; the index depends on the keras_bert layer layout — confirm.
last_encoder_layer = -7

sequence_output = bert_pretrained.layers[last_encoder_layer].output

# Dense layer with two output units; the loss reads index 0 as the start
# logits and index 1 as the end logits.
pool_output = Dense(2, kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
                    name = 'real_output')(sequence_output)

# Fine-tuning model: same inputs as the pretrained model, new Dense output.
bert_model = Model(inputs=bert_pretrained.input, outputs=pool_output)

### Optimizer

In [None]:
decay = 0.01 # same as used in BERT paper
# NOTE(review): the Keras optimizer `decay` argument is presumably a
# learning-rate decay applied per update, whereas the 0.01 in the BERT paper
# is a weight decay — confirm this is the intended setting.

adam = Adam(lr=learning_rate, decay=decay)

### Softmax CrossEntropy Loss

In [None]:
class BertSquadError(tf.keras.losses.Loss):
    """Span loss: mean of the start- and end-position softmax cross-entropies.

    positions: tensor of size batch_size x 2; [answer_start_index, answer_end_index]
    logits: tensor of size batch_size x max_tokens x 2
    """

    def call(self, positions, logits):
        """Return the scalar loss averaged over the batch and over start/end.

        Args:
            positions: answer token indices, batch_size x 2
                ([start_index, end_index] per example).
            logits: per-token scores, batch_size x max_tokens x 2; index 0 of
                the last axis holds the start logits, index 1 the end logits.

        Returns:
            Scalar tensor: the mean of the start loss and the end loss, each
            already averaged over the batch.
        """
        # tf.one_hot only accepts integer indices; cast so labels that arrive
        # with a floating dtype do not fail here.
        positions = tf.cast(positions, tf.int32)

        # [batch, max_tokens, 2] -> [batch, 2, max_tokens]
        logits = tf.transpose(logits, [0, 2, 1])

        # Dynamic shape keeps this valid when the static sequence length is
        # unknown. Result shape: [batch, 2, max_tokens].
        one_hot_positions = tf.one_hot(
            positions, depth=tf.shape(logits)[2], dtype=tf.float32)

        loss_start = self._position_loss(one_hot_positions[:, 0], logits[:, 0])
        loss_end = self._position_loss(one_hot_positions[:, 1], logits[:, 1])

        return tf.reduce_mean([loss_start, loss_end])

    @staticmethod
    def _position_loss(one_hot_targets, position_logits):
        """Batch-mean softmax cross-entropy for one position (start or end)."""
        log_probs = tf.nn.log_softmax(position_logits, axis=-1)
        per_example = -tf.reduce_sum(one_hot_targets * log_probs, axis=-1)
        return tf.reduce_mean(per_example, axis=-1)


### Compile Model

In [None]:
# Attach the span loss and the Adam optimizer to the fine-tuning model.
bert_model.compile(loss=BertSquadError(), optimizer=adam)

### Summary

In [None]:
# Print the layer-by-layer summary of the fine-tuning model.
bert_model.summary()

# Load Data

In [None]:
# Pre-computed training features; indexed later as train[0], train[1] and
# train[2], one entry per model input.
train = np.load('squad_feats.npy')
# Training labels, paired with the features when the dataset is built.
labels = np.load('squad_labs.npy')

## Create Dataset Instance

In [None]:
# Pair the three feature arrays with the labels as (inputs dict, labels).
# NOTE(review): the dict keys are presumably the keras_bert input layer
# names — confirm against bert_model.summary().
squad_train_data = tf.data.Dataset.from_tensor_slices(({"Input-Token": train[0],
                                   "Input-Segment": train[1],
                                   "Input-Masked": train[2]},
                                  labels))

## Set Batch Size

In [None]:
# Group the examples into batches of batch_size.
squad_train_data = squad_train_data.batch(batch_size)

## Train

In [None]:
# Fine-tune for two epochs over the batched dataset, with progress output.
bert_model.fit(squad_train_data, verbose=1, epochs=2)

### Save Checkpoint

In [None]:
# Write the fine-tuned weights to CHKPT_SAVE_DIR; the validation step passes
# the same path to the Bert wrapper.
bert_model.save_weights(CHKPT_SAVE_DIR)

# Validation

In [None]:
def model_validation(bert, collection, docs, v2=True, max_iter=None):
    """Score model predictions with exact-match and F1 over a set of questions.

    Args:
        bert: object exposing ``encode_contents(questions, contents)`` and
            ``predict(questions, contents, data, indices)``.
        collection: sequence of dicts with the keys ``'label'``,
            ``'question'``, ``'id'`` and ``'answers'``.
        docs: mapping from an entry's ``'id'`` to its context text.
        v2: when False, entries with a truthy ``'label'`` are dropped before
            prediction and the null-answer shortcut is never taken.
        max_iter: score at most this many of the remaining entries. ``None``
            scores all of them; values above the number of remaining entries
            are clamped instead of raising ``IndexError``.

    Returns:
        ``(em_pred, f1_scores)``: two lists with one score per scored entry.
    """
    # With v2 disabled, entries with a truthy label are excluded up front.
    entries = [entry for entry in collection if v2 or not entry['label']]
    questions = [entry['question'] for entry in entries]
    contents = [docs[entry['id']] for entry in entries]

    # Filtering may have shrunk the collection, so never iterate past it.
    if max_iter is None:
        limit = len(entries)
    else:
        limit = min(max_iter, len(entries))

    data, indices = bert.encode_contents(questions, contents)
    bert_results = bert.predict(questions, contents, data, indices)

    em_pred = []
    f1_scores = []

    for j in range(limit):
        entry = entries[j]
        results = bert_results[j]
        bert_ans = results[0]

        # NOTE(review): results[1] / results[2] are presumably the best-span
        # score and the null-answer score — confirm against Bert.predict.
        best_null = results[1] < results[2]

        if best_null and v2:
            # A null prediction is credited with the entry's label.
            # NOTE(review): presumably the label is truthy for unanswerable
            # questions — confirm.
            credit = float(entry['label'])
            em_pred.append(credit)
            f1_scores.append(credit)
            continue

        # Otherwise take the best score over all reference answers.
        match = 0
        best_f1 = 0
        for answer in entry['answers']:
            match = match or compute_exact(answer, bert_ans)
            new_f1, _ = compute_f1(answer, bert_ans)
            best_f1 = max(best_f1, new_f1)

        em_pred.append(match)
        f1_scores.append(best_f1)

    return em_pred, f1_scores

## Load Model

In [None]:
# Build the project's Bert wrapper from the token limit, the pretrained
# directory and the saved-checkpoint path.
# NOTE(review): this rebinds the name `bert`, which the imports bind to the
# `official.nlp.bert` package; the tokenizer cell (bert.tokenization...) will
# fail if it is re-run after this cell.
bert = Bert(TOKEN_LIMIT, BERT_PRETRAINED_DIR, CHKPT_SAVE_DIR)

## Load Samples

In [None]:
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
# The file handles are named explicitly so the builtin `input` is not shadowed.
with open('squad_val_questions.pkl', 'rb') as questions_file:
    questions = pkl.load(questions_file)

with open('squad_val_contexts.pkl', 'rb') as contexts_file:
    contexts = pkl.load(contexts_file)

## Test With Impossible Questions

In [None]:
# Score the first 10 validation questions with v2=True, so entries with a
# truthy 'label' are kept.
em_pred, f1_scores = model_validation(bert, questions[:10], contexts, v2=True)

In [None]:
# Scale to a percentage before rounding so float noise (e.g. 30.000000000000004) is not printed.
print('The v2.0 EM score for this model is: {}%'.format(round(np.mean(em_pred) * 100, 2)))

In [None]:
# Report the F1 scores (not the EM scores); scale before rounding to avoid float noise.
print('The v2.0 F1 score for this model is: {}%'.format(round(np.mean(f1_scores) * 100, 2)))

## Test Without Impossible Questions

In [None]:
# Score the first 10 validation questions with v2=False, which drops entries
# whose 'label' is truthy before scoring.
em_pred, f1_scores = model_validation(bert, questions[:10], contexts, v2=False)

In [None]:
# Scale to a percentage before rounding so float noise (e.g. 30.000000000000004) is not printed.
print('The v1.1 EM score for this model is: {}%'.format(round(np.mean(em_pred) * 100, 2)))

In [None]:
# Report the F1 scores (not the EM scores); scale before rounding to avoid float noise.
print('The v1.1 F1 score for this model is: {}%'.format(round(np.mean(f1_scores) * 100, 2)))