# Overview
 - Paper [Link](https://arxiv.org/pdf/1901.08634.pdf)<br>
 - GitHub [Link](https://github.com/google-research/language/tree/master/language/question_answering/bert_joint)
 - Supported by Google Cloud Platform TPU instances. 
 - Training to Validation

In [1]:
import flags

FLAGS = flags.FLAGS
FLAGS.VERSION = 'BertLargeUncasedNQ-002'
FLAGS.DATA_SPLIT = False
FLAGS.TUNING_MODE = False
FLAGS.PREPROCESS = False
FLAGS.TRAININGS = False
FLAGS.do_valid = True


In [2]:
import numpy as np
import pandas as pd
import json, os, gc, sys, collections, itertools, datetime, logging, warnings, tqdm, jsonlines

import tensorflow as tf

sys.path.extend([FLAGS.SACREMOSES_PATH, FLAGS.TRANSFORMERS_PATH])
import sacremoses, transformers, tokenization

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.getLogger("tensorflow").setLevel(logging.CRITICAL)
logging.getLogger("tensorflow_hub").setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

# 1. Load Datasets

In [3]:
def read_data(path, sample = True, chunksize = 80000):
    if not sample:
        chunksize = 307373
    df = []
    with open(path, 'rt') as reader:
        for i in range(chunksize):
            df.append(json.loads(reader.readline()))
    df = pd.DataFrame(df)
    return df

In [4]:
%%time
################
# train data
################
if FLAGS.DATA_SPLIT:
    
    train = read_data(FLAGS.LOCAL_PATH+'/simplified-nq-train.jsonl', sample = FLAGS.TUNING_MODE)
    print("train shape", train.shape)
    display(train.head(5))
    

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 33.6 µs


In [5]:
%%time
################
# sample_submission data
################
sample_submission = pd.read_csv(FLAGS.LOCAL_PATH+'/sample_submission.csv')
print("Sample submission shape", sample_submission.shape)
display(sample_submission.head(2))


Sample submission shape (692, 2)


Unnamed: 0,example_id,PredictionString
0,-1011141123527297803_long,
1,-1011141123527297803_short,


CPU times: user 64 ms, sys: 0 ns, total: 64 ms
Wall time: 251 ms


# 2. Preprocessing

## 2.1 Split Dataset to Train & Validation

In [6]:
%%time
########################
# Train-Validation Split
########################
from sklearn.model_selection import StratifiedShuffleSplit

if FLAGS.TUNING_MODE:
    train_file = FLAGS.LOCAL_PATH+'/simplified-nq-train_Seed{}_Split{}_Fold{}-small.jsonl'.format(FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD)
    valid_file = FLAGS.LOCAL_PATH+'/simplified-nq-valid_Seed{}_Split{}_Fold{}-small.jsonl'.format(FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD)
else:
    train_file = FLAGS.LOCAL_PATH+'/simplified-nq-train_Seed{}_Split{}_Fold{}.jsonl'.format(FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD)
    valid_file = FLAGS.LOCAL_PATH+'/simplified-nq-valid_Seed{}_Split{}_Fold{}.jsonl'.format(FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD)

if FLAGS.DATA_SPLIT:
    
    from data_split import extract_target_variable, make_percentile
    
    #### Stratify Key ####
    has_long_answer = train.annotations.apply(
        lambda x: (x[0]['long_answer']['start_token'], x[0]['long_answer']['end_token'])).apply(
        lambda x: x if x != (-1, -1) else None).apply(
        lambda x: 1 if x else 0)
    has_short_answers = train.annotations.apply(
        lambda x: [(y['start_token'], y['end_token']) for y in x[0]['short_answers']]).apply(
        lambda x: len(x)).apply(
        lambda x: x > 0)
    has_yes_no = train.annotations.apply(
        lambda x: x[0]['yes_no_answer']).apply(
        lambda x: None if x == 'NONE' else x).apply(
        lambda x: x is not None)
    
    train['long_length'] = extract_target_variable(train, short=False)
    l_length_nonzero = train[train['long_length']>0]['long_length']
    q33, q66 = np.percentile(l_length_nonzero, [33, 66])
    long_l_percentile_id = train['long_length'].apply(lambda x: make_percentile(x, q33, q66))
    
    train['short_length'] = extract_target_variable(train, short=True)
    s_length_nonzero = train[train['short_length']>0]['short_length']
    q33, q66 = np.percentile(s_length_nonzero, [33, 66])
    short_l_percentile_id = extract_target_variable(train, short=True).apply(lambda x: make_percentile(x, q33, q66))
    
    stratify_key = has_long_answer.astype(int).astype(str) + \
                    has_short_answers.astype(int).astype(str) + \
                    has_yes_no.astype(int).astype(str) + \
                    long_l_percentile_id.astype(int).astype(str) + \
                    short_l_percentile_id.astype(int).astype(str)
    
    #### StratifiedShuffleSplit ####
    sss = StratifiedShuffleSplit(n_splits=FLAGS.N_SPLITS, random_state=FLAGS.SEED)
    trn_idx, val_idx = [
        (trn_idx, val_idx) for trn_idx, val_idx in sss.split(train, stratify_key)
    ][FLAGS.FOLD]
    
    train_df = train.loc[trn_idx].reset_index(drop=True)
    valid_df = train.loc[val_idx].reset_index(drop=True)
    del train
    gc.collect()
    
    #### Write ####
    train_dict = train_df.to_dict(orient='records')
    valid_dict = valid_df.to_dict(orient='records')
    
    with jsonlines.open(train_file, 'w') as f:
        f.write_all(train_dict)
    with jsonlines.open(valid_file, 'w') as f:
        f.write_all(valid_dict)


CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 5.28 ms


## 2.2 BERT-Joint Preprocessing

In [7]:
from tokenization import FullTokenizer
from multiprocessing import Pool
from preprocessing import ConvertExamples2Features, FeatureWriter, nq_examples_iter

tokenizer = FullTokenizer(
    vocab_file=FLAGS.TOKENIZER_MODEL_PATH, do_lower_case=True)

if FLAGS.TUNING_MODE:
    train_records = FLAGS.LOCAL_PATH+\
            '/nq-train_<{}>_Seed{}_Split{}_Fold{}-small.tfrecords'.format(
        FLAGS.MODEL_VERSION, FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD
    )
    valid_records = FLAGS.LOCAL_PATH+\
            '/nq-valid_<{}>_Seed{}_Split{}_Fold{}-small.tfrecords'.format(
        FLAGS.MODEL_VERSION, FLAGS.SEED, FLAGS.N_SPLITS, FLAGS.FOLD
    )
else:
    train_records = FLAGS.LOCAL_PATH+'/nq-train_{}.tfrecords'.format(FLAGS.MODEL_VERSION)
    valid_records = FLAGS.LOCAL_PATH+'/nq-valid_{}.tfrecords'.format(FLAGS.MODEL_VERSION)
    
input_files = [train_file]
input_records = [train_records]

if FLAGS.PREPROCESS:
    
    for i, records in enumerate(input_records):
    
        writer = FeatureWriter(
            filename = os.path.join(records), 
            is_training = True)

        converter = ConvertExamples2Features(
            tokenizer=tokenizer,
            is_training=True,
            output_fn=writer.process_feature,
            collect_stat=False)
        
        n_examples = 0

        tqdm_notebook = tqdm.tqdm_notebook
        for examples in nq_examples_iter(input_file=input_files[i], is_training=True, tqdm=tqdm_notebook):
            for example in examples:
                n_examples += converter(example)

        writer.close()
        print('number of test examples: %d, written to file: %d' % (n_examples, writer.num_features))


# 3. Dataset Generator

In [8]:
def get_dataset(tf_record_file, seq_length, batch_size=1, shuffle_buffer_size=0, is_training=False):

    if is_training:
        features = {
            "unique_ids": tf.io.FixedLenFeature([], tf.int64),
            "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
            "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
            "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
            "start_positions": tf.io.FixedLenFeature([], tf.int64),
            "end_positions": tf.io.FixedLenFeature([], tf.int64),
            "answer_types": tf.io.FixedLenFeature([], tf.int64)
        }
    else:
        features = {
            "unique_ids": tf.io.FixedLenFeature([], tf.int64),
            "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
            "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
            "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
            "token_map": tf.io.FixedLenFeature([seq_length], tf.int64)
        }        

    # Taken from the TensorFlow models repository: https://github.com/tensorflow/models/blob/befbe0f9fe02d6bc1efb1c462689d069dae23af1/official/nlp/bert/input_pipeline.py#L24
    def decode_record(record, features):
        """Decodes a record to a TensorFlow example."""
        example = tf.io.parse_single_example(record, features)
    
        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
                        
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.cast(t, tf.int32)
            example[name] = t
        return example

    def select_data_from_record(record):
        
        x = {
            'unique_ids': record['unique_ids'],
            'input_ids': record['input_ids'],
            'input_mask': record['input_mask'],
            'segment_ids': record['segment_ids']
        }
        
        if not is_training:
            x['token_map'] = record['token_map']

        if is_training:
            y = {
                'start_positions': record['start_positions'],
                'end_positions': record['end_positions'],
                'answer_types': record['answer_types']
            }
            return (x, y)
        
        return x

    dataset = tf.data.TFRecordDataset(tf_record_file)
    
    dataset = dataset.map(lambda record: decode_record(record, features))
    dataset = dataset.map(select_data_from_record)
    
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    
    dataset = dataset.batch(batch_size)
    
    return dataset

# 3. Modeling

## 3.1 Model Definition

In [9]:
from transformers import TFBertPreTrainedModel, TFBertMainLayer, TFSequenceSummary
from transformers.modeling_tf_utils import get_initializer

TFBertPreTrainedModel.pretrained_model_archive_map = {
    "bert-base-uncased": "",
    "bert-large-uncased": "",
    "bert-base-cased": "",
    "bert-large-cased": "",
    "bert-large-uncased-whole-word-masking-finetuned-squad": FLAGS.PRETRAINED_MODEL_PATH,
    "bert-large-cased-whole-word-masking-finetuned-squad": "",
}

class TFBertForQuestionAnswering(TFBertPreTrainedModel):
    
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForQuestionAnswering, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.sequence_output_dropout = tf.keras.layers.Dropout(kwargs.get('sequence_output_dropout_prob', 0.05))
        self.pooled_output_dropout = tf.keras.layers.Dropout(kwargs.get('pooled_output_dropout_prob', 0.05))
        self.classifier = tf.keras.layers.Dense(
            FLAGS.NUM_LABELS, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        
    def call(self, inputs, **kwargs):
    
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        sequence_output = self.sequence_output_dropout(sequence_output, training=kwargs.get('training', False))
        pooled_output = self.pooled_output_dropout(pooled_output, training=kwargs.get('training', False))
    
        ### Natural Question ###
        logits = self.qa_outputs(sequence_output)
        start_pos_logits = logits[:, :, 0]
        end_pos_logits = logits[:, :, 1]
        
        ### Sequence Classification ###
        answer_type_logits = self.classifier(pooled_output)

        outputs = (start_pos_logits, end_pos_logits, answer_type_logits)

        return outputs
    
def mk_model(use_unique_id=False):
    
    if use_unique_id:
        unique_id  = tf.keras.Input(shape=(1,),dtype=tf.int64,name='unique_id')
    input_ids   = tf.keras.Input(shape=(FLAGS.SEQ_LENGTH,),dtype=tf.int32,name='input_ids')
    input_mask  = tf.keras.Input(shape=(FLAGS.SEQ_LENGTH,),dtype=tf.int32,name='attention_mask')
    segment_ids = tf.keras.Input(shape=(FLAGS.SEQ_LENGTH,),dtype=tf.int32,name='token_type_ids')

    BERT_JOINT = TFBertForQuestionAnswering.from_pretrained(FLAGS.MODEL_VERSION)

    start_logits,end_logits,ans_type = BERT_JOINT((input_ids, input_mask, segment_ids))

    if use_unique_id:
        return tf.keras.Model([input_ for input_ in [unique_id,input_ids,input_mask,segment_ids] if input_ is not None],
                              [unique_id,start_logits,end_logits,ans_type], name='bert-joint-large')
    else:
        return tf.keras.Model([input_ for input_ in [input_ids,input_mask,segment_ids] if input_ is not None],
                              [start_logits,end_logits,ans_type], name='bert-joint-large')


In [10]:
#####################
# TPU Settings
#####################

# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)


Running on TPU  ['192.168.21.2:8470']
REPLICAS:  8


In [11]:
with strategy.scope():

    model = mk_model()

    print('SUMMARY')
    display(model.summary())


SUMMARY
Model: "bert-joint-large"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 512)]        0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 512)]        0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 512)]        0                                            
__________________________________________________________________________________________________
tf_bert_for_question_answering  ((None, 512), (None, 335149063   input_ids[0][0]                  
                                                                 attention_

None

## 3.2 Training Utility Functions

In [12]:
################
# Loss Functions
################

with strategy.scope():

    def loss_function(start_positions,
                      end_positions,
                      answer_types,
                      start_logits,
                      end_logits,
                      answer_type_logits,
                      loss_factor=1.0 / strategy.num_replicas_in_sync):
        """Returns sparse categorical crossentropy for start/end logits."""
        start_loss = tf.keras.backend.sparse_categorical_crossentropy(
            start_positions, start_logits, from_logits=True)
        end_loss = tf.keras.backend.sparse_categorical_crossentropy(
            end_positions, end_logits, from_logits=True)
        answer_type_loss = tf.keras.backend.sparse_categorical_crossentropy(
            answer_types, answer_type_logits, from_logits=True)

        start_loss = tf.reduce_mean(start_loss)
        end_loss = tf.reduce_mean(end_loss)
        answer_type_loss = tf.reduce_mean(answer_type_loss)

        total_loss = (start_loss + end_loss + answer_type_loss) / 3
        total_loss *= loss_factor
        return total_loss, start_loss, end_loss, answer_type_loss

    def get_loss_fn(loss_factor=1.0 / strategy.num_replicas_in_sync):
        """Gets a loss function for nq task."""

        def _loss_fn(nq_labels, nq_logits):
            start_positions, end_positions, answer_types = nq_labels
            start_logits, end_logits, answer_type_logits = nq_logits
            return loss_function(
                start_positions,
                end_positions,
                answer_types,
                start_logits,
                end_logits,
                answer_type_logits,
                loss_factor=loss_factor)

        return _loss_fn


In [13]:
################
# Metrics
################

def get_metrics(name):

    loss = tf.keras.metrics.Mean(name='{}_loss'.format(name))

    loss_short_start_pos = tf.keras.metrics.Mean(name='{}_loss_short_start_pos'.format(name))
    loss_short_end_pos = tf.keras.metrics.Mean(name='{}_loss_short_end_pos'.format(name))
    loss_short_ans_type = tf.keras.metrics.Mean(name='{}_loss_short_ans_type'.format(name))
    
    acc = tf.keras.metrics.SparseCategoricalAccuracy(name='{}_acc'.format(name))
    
    acc_short_start_pos = tf.keras.metrics.SparseCategoricalAccuracy(name='{}_acc_short_start_pos'.format(name))
    acc_short_end_pos = tf.keras.metrics.SparseCategoricalAccuracy(name='{}_acc_short_end_pos'.format(name))
    acc_short_ans_type = tf.keras.metrics.SparseCategoricalAccuracy(name='{}_acc_short_ans_type'.format(name))
    
    return loss, loss_short_start_pos, loss_short_end_pos, loss_short_ans_type, acc, acc_short_start_pos, acc_short_end_pos, acc_short_ans_type


In [14]:
###################
# Training Utilities
###################
from optimizer import CustomSchedule, AdamW
import tensorflow.keras.backend as K

num_train_examples = len(list(tf.compat.v1.python_io.tf_record_iterator(train_records)))
num_train_steps = int(FLAGS.EPOCHS * num_train_examples / FLAGS.BATCH_SIZE)
num_train_steps_per_epoch = int(num_train_steps / FLAGS.EPOCHS)
print('num_train_examples : {}, num_train_steps per epoch: {}'.format(num_train_examples, num_train_steps_per_epoch))

with strategy.scope():
    
    criteria = get_loss_fn()
    
    train_loss, train_loss_start_pos, train_loss_end_pos, \
    train_loss_ans_type, train_acc, train_acc_start_pos, \
    train_acc_end_pos, train_acc_ans_type = get_metrics("train")
    
    learning_rate = CustomSchedule(
        initial_learning_rate=FLAGS.LEARNING_RATE,
        decay_steps=num_train_steps,
        end_learning_rate=FLAGS.LEARNING_RATE,
        power=1.0,
        cycle=FLAGS.CYCLIC_LEARNING_RATE,
        num_warmup_steps=FLAGS.NUM_WARMUP_STEPS
    )

    decay_var_list = []
    for i in range(len(model.trainable_variables)):
        name = model.trainable_variables[i].name
        if any(x in name for x in ["LayerNorm", "layer_norm", "bias"]):
            decay_var_list.append(name)
            
    optimizer = AdamW(
        weight_decay=FLAGS.WEIGHT_DECAY_RATE, learning_rate=learning_rate, 
        beta_1=0.9, beta_2=0.999, epsilon=1e-6, decay_var_list=decay_var_list)
#     optimizer = tf.keras.optimizers.SGD(
#         learning_rate=learning_rate,
#         momentum=0.9
#     )
    

num_train_examples : 494670, num_train_steps per epoch: 20611


In [15]:
input_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32)
]
@tf.function(input_signature=input_signature)
def train_step(input_ids, input_masks, segment_ids, short_start_pos_labels, short_end_pos_labels, short_answer_type_labels):

    nq_inputs = (input_ids, input_masks, segment_ids)
    nq_labels = (short_start_pos_labels, short_end_pos_labels, short_answer_type_labels)

    with tf.GradientTape() as tape:

        nq_logits = model(nq_inputs, training=True)
        loss, loss_short_start_pos, loss_short_end_pos, loss_short_ans_type = criteria(nq_labels, nq_logits)

    gradients = tape.gradient(loss, model.trainable_variables)

    (short_start_pos_logits, short_end_pos_logits, short_answer_type_logits) = nq_logits

    train_acc.update_state(short_start_pos_labels, short_start_pos_logits)
    train_acc.update_state(short_end_pos_labels, short_end_pos_logits)
    train_acc.update_state(short_answer_type_labels, short_answer_type_logits)

    train_acc_start_pos.update_state(short_start_pos_labels, short_start_pos_logits)
    train_acc_end_pos.update_state(short_end_pos_labels, short_end_pos_logits)
    train_acc_ans_type.update_state(short_answer_type_labels, short_answer_type_logits)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)

    train_loss_start_pos(loss_short_start_pos)
    train_loss_end_pos(loss_short_end_pos)
    train_loss_ans_type(loss_short_ans_type)


# `experimental_run_v2` replicates the provided computation and runs it with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):

    features, targets = dataset_inputs
    (input_ids, input_masks, segment_ids) = (features['input_ids'], features['input_mask'], features['segment_ids'])
    (start_pos_labels, end_pos_labels, answer_type_labels) = (targets['start_positions'], targets['end_positions'], targets['answer_types'])

    strategy.experimental_run_v2(train_step, args=(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels))
    

In [16]:
###################
# Model Checkpoint
###################

checkpoint_path = os.path.join(FLAGS.INPUT_CHECKPOINT_DIRECTORY, FLAGS.VERSION)
ckpt = tf.train.Checkpoint(model=model)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10000)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    last_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])
    print ('Latest BertNQ checkpoint restored -- Model trained for {} epochs'.format(last_epoch))
else:
    print('Checkpoint not found. Train BertNQ from scratch')
    last_epoch = 0
    
# Reset saving path, because the FLAGS.input_checkpoint_dir is not writable on Kaggle
print(ckpt_manager._directory)
ckpt_manager._directory = os.path.join(FLAGS.OUTPUT_CHECKPOINT_DIRECTORY, FLAGS.VERSION)
ckpt_manager._checkpoint_prefix = os.path.join(ckpt_manager._directory, "ckpt")
print(ckpt_manager._directory)

from tensorflow.python.lib.io.file_io import recursive_create_dir
recursive_create_dir(ckpt_manager._directory)


Checkpoint not found. Train BertNQ from scratch
gs://tensorflow2-question-answering-cuedej/weights/checkpoints/input_checkpoint/BertLargeUncasedNQ-002
gs://tensorflow2-question-answering-cuedej/weights/checkpoints/output_checkpoint/BertLargeUncasedNQ-002


## 3.3 Fitting Starts Here

In [17]:
################
# Training
################

if FLAGS.TRAININGS:
    
    train_start_time = datetime.datetime.now()

    epochs = FLAGS.EPOCHS
    for epoch in range(epochs):

        print("Epoch = {}".format(epoch))

        train_dataset = get_dataset(
            train_records,
            FLAGS.SEQ_LENGTH,
            FLAGS.BATCH_SIZE,
            FLAGS.SHUFFLE_BUFFER_SIZE,
            is_training=True
        )

        train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

        train_loss.reset_states()

        train_loss_start_pos.reset_states()
        train_loss_end_pos.reset_states()
        train_loss_ans_type.reset_states()           

        train_acc.reset_states()

        train_acc_start_pos.reset_states()
        train_acc_end_pos.reset_states()
        train_acc_ans_type.reset_states()

        epoch_start_time = datetime.datetime.now()

        print("start iterating over train_dist_dataset ...")

        for (batch_idx, dataset_inputs) in enumerate(train_dist_dataset):

            if batch_idx == num_train_steps_per_epoch:
                break

            batch_start_time = datetime.datetime.now()

            distributed_train_step(dataset_inputs)

            batch_end_time = datetime.datetime.now()
            batch_elapsed_time = (batch_end_time - epoch_start_time).total_seconds()

            if (batch_idx + 1) % 500 == 0:
                print('Epoch {} | Batch {} | Elapsed Time {}'.format(
                    epoch + 1,
                    batch_idx + 1,
                    batch_elapsed_time
                ))
                print('Loss {:.6f} | Loss_SS {:.6f} | Loss_SE {:.6f} | Loss_ST {:.6f}'.format(
                    train_loss.result(),
                    train_loss_start_pos.result(),
                    train_loss_end_pos.result(),
                    train_loss_ans_type.result()
                ))
                print(' Acc {:.6f} |  Acc_SS {:.6f} |  Acc_SE {:.6f} |  Acc_ST {:.6f}'.format(
                    train_acc.result(),
                    train_acc_start_pos.result(),
                    train_acc_end_pos.result(),
                    train_acc_ans_type.result()                
                ))
                print("-" * 100)

        epoch_end_time = datetime.datetime.now()
        epoch_elapsed_time = (epoch_end_time - epoch_start_time).total_seconds()

        if (epoch + 1) % 1 == 0:

            ckpt_save_path = ckpt_manager.save()
            print ('\nSaving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

            print('\nEpoch {}'.format(epoch + 1))
            print('Loss {:.6f} | Loss_SS {:.6f} | Loss_SE {:.6f} | Loss_ST {:.6f}'.format(
                    train_loss.result(),
                    train_loss_start_pos.result(),
                    train_loss_end_pos.result(),
                    train_loss_ans_type.result()
            ))
            print(' Acc {:.6f} |  Acc_SS {:.6f} |  Acc_SE {:.6f} |  Acc_ST {:.6f}'.format(
                    train_acc.result(),
                    train_acc_start_pos.result(),
                    train_acc_end_pos.result(),
                    train_acc_ans_type.result() 
            ))

        print('\nTime taken for 1 epoch: {} secs\n'.format(epoch_elapsed_time))
        print("-" * 80 + "\n")
    

In [18]:
# Saving checkpoint for epoch 1 at gs://tensorflow2-question-answering-cuedej/weights/checkpoints/output_checkpoint/BertLargeUncasedNQ-001/ckpt-1

# Epoch 1
# Loss 0.159261 | Loss_SS 1.557903 | Loss_SE 1.581118 | Loss_ST 0.683250
#  Acc 0.635170 |  Acc_SS 0.582146 |  Acc_SE 0.587823 |  Acc_ST 0.735541

# Time taken for 1 epoch: 8173.363619 secs
    
# Saving checkpoint for epoch 2 at gs://tensorflow2-question-answering-cuedej/weights/checkpoints/output_checkpoint/BertLargeUncasedNQ-001/ckpt-2

# Epoch 2
# Loss 0.134737 | Loss_SS 1.332973 | Loss_SE 1.346110 | Loss_ST 0.554601
#  Acc 0.682298 |  Acc_SS 0.630010 |  Acc_SE 0.636412 |  Acc_ST 0.780474

# Time taken for 1 epoch: 8046.585563 secs

##################################################

# Epoch 1
# Loss 0.138367 | Loss_SS 1.382416 | Loss_SE 1.394313 | Loss_ST 0.544081
#  Acc 0.675149 |  Acc_SS 0.617815 |  Acc_SE 0.623987 |  Acc_ST 0.783645

# 4. Validation Score

In [19]:
###################
# Validation Dataset
###################

if FLAGS.smaller_valid_dataset:
    valid_records = os.path.join(FLAGS.LOCAL_PATH, "nq-valid_{}-small.tfrecords".format(FLAGS.MODEL_VERSION))
else:
    valid_records = os.path.join(FLAGS.LOCAL_PATH, "nq-valid_{}.tfrecords".format(FLAGS.MODEL_VERSION))
    
validation_dataset = get_dataset(
    valid_records,
    seq_length=FLAGS.SEQ_LENGTH,
    batch_size=15,
    is_training=False
)
num_train_examples = len(list(tf.compat.v1.python_io.tf_record_iterator(valid_records)))
num_train_steps = int(num_train_examples / 15)
print('num_train_examples : {}, num_train_steps: {}'.format(num_train_examples, num_train_steps))

validation_dist_dataset = strategy.experimental_distribute_dataset(validation_dataset)


###################
# Model Checkpoint
###################

checkpoint_path = os.path.join(FLAGS.OUTPUT_CHECKPOINT_DIRECTORY, FLAGS.VERSION, )
ckpt = tf.train.Checkpoint(model=model)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10000)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    last_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])
    print ('Latest BertNQ checkpoint restored -- Model trained for {} epochs'.format(last_epoch))
else:
    print('Checkpoint not found. Train BertNQ from scratch')
    last_epoch = 0
    

num_train_examples : 27465, num_train_steps: 1831
Latest BertNQ checkpoint restored -- Model trained for 1 epochs


In [20]:
###################
# Validation Functions
###################
from predict import compute_f1_scores, get_best_indexes, read_candidates, compute_pred_dict

input_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None, None), dtype=tf.int32)
]
@tf.function(input_signature=input_signature)
def valid_step(input_ids, input_masks, segment_ids):

    nq_inputs = (input_ids, input_masks, segment_ids)
    nq_logits = model(nq_inputs, training=False)

    return nq_logits

@tf.function
def distributed_valid_step(features):

    (input_ids, input_masks, segment_ids) = (features['input_ids'], features['input_mask'], features['segment_ids'])
    nq_logits = strategy.experimental_run_v2(valid_step, args=(input_ids, input_masks, segment_ids))
    
    return nq_logits
    
def get_prediction_json(mode, max_nb_pos_logits=-1):
    
    if mode == 'valid':
        dataset = validation_dataset
        if FLAGS.smaller_valid_dataset:
            predict_file = FLAGS.validation_predict_file_small
            prediction_output_file = FLAGS.validation_small_prediction_output_file
        else:
            predict_file = FLAGS.validation_predict_file
            prediction_output_file = FLAGS.validation_prediction_output_file
        eval_features = validation_features
    else:
        dataset = test_dataset
        predict_file = FLAGS.predict_file
        eval_features = test_features
        prediction_output_file = FLAGS.prediction_output_file
    
    print(predict_file)
    print(prediction_output_file)
    
    all_results = []

    prediction_start_time = datetime.datetime.now()

    for (batch_idx, features) in enumerate(validation_dist_dataset):

        batch_start_time = datetime.datetime.now()

        unique_ids = features['unique_ids']
        token_maps = features['token_map']       
        
        (input_ids, input_masks, segment_ids) = (features['input_ids'], features['input_mask'], features['segment_ids'])
        
        nq_logits = distributed_valid_step(features)

        (start_pos_logits, end_pos_logits, answer_type_logits) = nq_logits

        unique_ids = unique_ids.values[0].numpy().tolist()
        
        token_maps = token_maps.values[0].numpy().tolist()
        
        start_pos_prob_dist = tf.nn.softmax(start_pos_logits.values[0], axis=-1).numpy().tolist()
        end_pos_prob_dist = tf.nn.softmax(end_pos_logits.values[0], axis=-1).numpy().tolist()
        answer_type_prob_dist = tf.nn.softmax(answer_type_logits.values[0], axis=-1).numpy().tolist()
        
        start_pos_logits = start_pos_logits.values[0].numpy().tolist()
        end_pos_logits = end_pos_logits.values[0].numpy().tolist()
        answer_type_logits = answer_type_logits.values[0].numpy().tolist()

        for uid, token_map, s, e, a, sp, ep, ap in zip(unique_ids, token_maps, start_pos_logits, end_pos_logits, answer_type_logits, start_pos_prob_dist, end_pos_prob_dist, answer_type_prob_dist):

            if max_nb_pos_logits < 0:
                max_nb_pos_logits = len(start_pos_logits)
            
            cls_start_logit = s[0]
            cls_end_logit = e[0]
            
            start_indexes = get_best_indexes(s, max_nb_pos_logits, token_map)
            end_indexes = get_best_indexes(e, max_nb_pos_logits, token_map)            
            
            s = [s[idx] for idx in start_indexes]
            e = [e[idx] for idx in end_indexes]
            sp = [sp[idx] for idx in start_indexes]
            ep = [ep[idx] for idx in end_indexes]            
            
            raw_result = {
                "unique_id": uid,
                "start_indexes": start_indexes,
                "end_indexes": end_indexes,
                "start_logits": s,
                "end_logits": e,
                "answer_type_logits": a,
                "start_pos_prob_dist": sp,
                "end_pos_prob_dist": ep,
                "answer_type_prob_dist": ap,
                "cls_start_logit": cls_start_logit,
                "cls_end_logit": cls_end_logit
            }
            all_results.append(raw_result)

        batch_end_time = datetime.datetime.now()
        batch_elapsed_time = (batch_end_time - prediction_start_time).total_seconds()

        if (batch_idx + 1) % 250 == 0:
            print('Batch {} | Elapsed Time {}'.format(
                batch_idx + 1,
                batch_elapsed_time
            ))
      
    prediction_end_time = datetime.datetime.now()
    prediction_elapsed_time = (prediction_end_time - prediction_start_time).total_seconds()

    print('\nTime taken for prediction: {} secs\n'.format(prediction_elapsed_time))
    print("-" * 80 + "\n")

    print("Going to candidates file")
    candidates_dict = read_candidates(predict_file)

#     print ("setting up eval features")
#     eval_features = ...

    print ("compute_pred_dict")
    nq_pred_dict = compute_pred_dict(candidates_dict, eval_features, all_results)
    
    predictions_json = {"predictions": list(nq_pred_dict.values())}

    print ("writing json")
    with tf.io.gfile.GFile(prediction_output_file, "w") as f:
        json.dump(predictions_json, f, indent=4)
        
    return predictions_json


In [21]:
if FLAGS.do_valid:
    
    validation_features = (tf.train.Example.FromString(r.numpy()) for r in tf.data.TFRecordDataset(valid_records))
    valid_predictions_json = get_prediction_json(mode='valid', max_nb_pos_logits=FLAGS.n_best_size)
    
    if FLAGS.smaller_valid_dataset:
        predict_file = FLAGS.validation_predict_file_small
    else:
        predict_file = FLAGS.validation_predict_file

    f1, long_f1, short_f1 = compute_f1_scores(valid_predictions_json, predict_file)

    print('-'*80)
    print("valid f1: {}\nvalid long_f1: {}\nvalid short_f1: {}".format(f1, long_f1, short_f1))


/home/jupyter/input/simplified-nq-valid-small.jsonl
gs://tensorflow2-question-answering-cuedej/input/validatioin_predictions-small.json
Batch 250 | Elapsed Time 56.521221
Batch 500 | Elapsed Time 95.564363
Batch 750 | Elapsed Time 134.236095
Batch 1000 | Elapsed Time 173.314548
Batch 1250 | Elapsed Time 212.971855
Batch 1500 | Elapsed Time 251.911648
Batch 1750 | Elapsed Time 291.047498

Time taken for prediction: 303.892987 secs

--------------------------------------------------------------------------------

Going to candidates file
Reading examples from: /home/jupyter/input/simplified-nq-valid-small.jsonl
compute_pred_dict
merging examples...
done.
Examples processed: 100
Examples processed: 200
Examples processed: 300
Examples processed: 400
Examples processed: 500
Examples processed: 600
Examples processed: 700
Examples processed: 800
Examples processed: 900
Examples processed: 1000
writing json
--------------------------------------------------------------------------------
vali

# 5. PostProcessing & Optimize Threshold

In [22]:
###################
# PostProcessing
###################
from predict import create_long_answer, df_long_index_score, df_short_index_score, create_short_answer

if FLAGS.smaller_valid_dataset:
    valid_answers_df = pd.read_json(FLAGS.validation_small_prediction_output_file)
else:
    valid_answers_df = pd.read_json(FLAGS.validation_prediction_output_file)

valid_answers_df["long_answer_score"] = valid_answers_df["predictions"].apply(lambda q: q[0]["long_answer_score"])
valid_answers_df["short_answer_score"] = valid_answers_df["predictions"].apply(lambda q: q[0]["short_answers_score"])

# We re-format the JSON answers to match the requirements for submission.
valid_answers_df["long_answer"] = valid_answers_df["predictions"].apply(create_long_answer)
valid_answers_df["short_answer"] = valid_answers_df["predictions"].apply(create_short_answer)
valid_answers_df["example_id"] = valid_answers_df["predictions"].apply(lambda q: str(q[0]["example_id"]))

valid_answers_df,a = df_long_index_score(valid_answers_df)
valid_answers_df,b,c = df_short_index_score(valid_answers_df)
print('Remove Long', a)
print('Remove Short', b)
print('Not Remove Short', c)

long_answers = dict(zip(valid_answers_df["example_id"], valid_answers_df["long_answer"]))
short_answers = dict(zip(valid_answers_df["example_id"], valid_answers_df["short_answer"]))

display(valid_answers_df["long_answer_score"].describe())
display(valid_answers_df.head())


Remove Long 698
Remove Short 752
Not Remove Short 0


count    1000.000000
mean       -1.556889
std         5.755501
min       -24.981139
25%        -5.178363
50%         0.000000
75%         1.959143
max        12.866796
Name: long_answer_score, dtype: float64

Unnamed: 0,predictions,long_answer_score,short_answer_score,long_answer,short_answer,example_id
0,"[{'start_logits': [-0.538668394088745, -1.1619...",-7.274143,-7.274143,,,7377124121683290109
1,"[{'long_answer_score': 0, 'yes_no_answer': 'NO...",0.0,0.0,,,9049835265123264510
2,"[{'start_logits': [-1.849713563919067, -2.1128...",-5.044115,-5.044115,,,1586691717504133123
3,"[{'start_logits': [-0.680125892162323, -1.7150...",-7.091638,-7.091638,,,-3787783340735291386
4,"[{'start_logits': [0.9072107076644891, -0.6192...",-6.674465,-6.674465,,,-2535097755382176766


In [23]:
###################
# Making Submission
###################
ss1 = valid_answers_df[['example_id']]
ss1['example_id'] = ss1['example_id'].apply(lambda x: str(x)+'_long')
ss2 = valid_answers_df[['example_id']]
ss2['example_id'] = ss2['example_id'].apply(lambda x: str(x)+'_short')
sample_submission = pd.concat([ss1, ss2]).sort_values('example_id').reset_index(drop=True)
sample_submission['PredictionString'] = np.nan

long_prediction_strings = sample_submission[sample_submission["example_id"].str.contains("_long")].apply(lambda q: long_answers[q["example_id"].replace("_long", "")], axis=1)
short_prediction_strings = sample_submission[sample_submission["example_id"].str.contains("_short")].apply(lambda q: short_answers[q["example_id"].replace("_short", "")], axis=1)

sample_submission.loc[sample_submission["example_id"].str.contains("_long"), "PredictionString"] = long_prediction_strings
sample_submission.loc[sample_submission["example_id"].str.contains("_short"), "PredictionString"] = short_prediction_strings

sample_submission.head()

Unnamed: 0,example_id,PredictionString
0,-1050541191860974370_long,
1,-1050541191860974370_short,
2,-1073911042293548995_long,
3,-1073911042293548995_short,
4,-1095299445678232687_long,


In [24]:
###################
# Evaluation
###################
from evaluation import Score, long_annotations, short_annotations, yes_nos

long_score = Score()
short_score = Score()
total_score = Score()
for example in map(json.loads, open(FLAGS.validation_predict_file_small, 'r')):
    long_pred = sample_submission.loc[str(example['example_id']) + '_long', 'PredictionString']
    long_score.increment(long_pred, long_annotations(example), [])
    total_score.increment(long_pred, long_annotations(example), [])
    short_pred = sample_submission.loc[str(example['example_id']) + '_short', 'PredictionString']
    short_score.increment(short_pred, short_annotations(example), yes_nos(example))
    total_score.increment(short_pred, short_annotations(example), [])
    
    

KeyError: '6915606477668963399_long'