# Overview
 - Paper [Link](https://arxiv.org/pdf/1901.08634.pdf)<br>
 - GitHub [Link](https://github.com/google-research/language/tree/master/language/question_answering/bert_joint)<br>
 - Discussion [Link1](https://www.kaggle.com/c/tensorflow2-question-answering/discussion/119957), [Link2](https://www.kaggle.com/c/tensorflow2-question-answering/discussion/117370)

### Versions
 - ALBERT001<br>
 Baseline Built

In [1]:
import flags

# Run configuration object provided by the project-local `flags` module.
FLAGS = flags.FLAGS
# Tag for this run; it is embedded in the checkpoint directory name further down.
FLAGS.VERSION = 'ALBERT001'
# When True, the raw training file is loaded and a fresh train/validation split is written to disk.
FLAGS.DATA_SPLIT = False
# When True, the "-small" split/record files and the sampled read of the raw data are used.
FLAGS.TUNING_MODE = True
# When True, the JSONL splits are converted into TFRecord feature files.
FLAGS.PREPROCESS = True


In [2]:
import numpy as np
import pandas as pd
import json, os, gc, sys, collections, itertools, datetime, logging, warnings, tqdm

import tensorflow as tf

sys.path.extend([FLAGS.SACREMOSES_PATH, FLAGS.TRANSFORMERS_PATH])
import sacremoses, transformers, tokenization

# Quieten TensorFlow: native log level via the environment, Python loggers via
# `logging`, and suppress all Python warnings.
# NOTE(review): TensorFlow is imported above this point; confirm the environment
# variable still takes effect when set after the import.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})
for _logger_name in ("tensorflow", "tensorflow_hub"):
    logging.getLogger(_logger_name).setLevel(logging.CRITICAL)
warnings.filterwarnings(action='ignore')

# 1. Load Datasets

In [3]:
def read_data(path, sample = True, chunksize = 80_000):
    """Load the leading lines of a JSON-lines file into a DataFrame.

    Args:
        path: Path of a text file holding one JSON document per line.
        sample: If True, read at most ``chunksize`` lines. If False,
            ``chunksize`` is ignored and at most 307,373 lines are read.
        chunksize: Upper bound on the number of lines read when ``sample``
            is True.

    Returns:
        pandas.DataFrame with one row per parsed line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not sample:
        chunksize = 307_373
    with open(path, 'rt') as reader:
        # islice stops quietly at end-of-file, so a file shorter than
        # `chunksize` no longer ends in json.loads('') raising. Blank lines
        # (e.g. a trailing newline) are skipped for the same reason.
        records = [
            json.loads(line)
            for line in itertools.islice(reader, max(chunksize, 0))
            if line.strip()
        ]
    return pd.DataFrame(records)

In [4]:
%%time
# ------------------------------------------------------------------
# Raw training data: only loaded when a new split has to be produced.
# ------------------------------------------------------------------
if FLAGS.DATA_SPLIT:
    train = read_data(
        f'{FLAGS.LOCAL_PATH}/simplified-nq-train.jsonl',
        sample=FLAGS.TUNING_MODE,
    )
    print("train shape", train.shape)
    display(train.head(5))
    

CPU times: user 35 µs, sys: 3 µs, total: 38 µs
Wall time: 42 µs


In [5]:
%%time
# ----------------------------------------------
# Sample submission: load the file and preview it.
# ----------------------------------------------
sample_submission = pd.read_csv(f'{FLAGS.LOCAL_PATH}/sample_submission.csv')
print("Sample submission shape", sample_submission.shape)
display(sample_submission.head(2))


Sample submission shape (692, 2)


Unnamed: 0,example_id,PredictionString
0,-1011141123527297803_long,
1,-1011141123527297803_short,


CPU times: user 22 ms, sys: 179 µs, total: 22.2 ms
Wall time: 20 ms


# 2. Preprocessing

## 2.1 Split Dataset to Train & Validation

In [6]:
%%time
########################
# Train-Validation Split
########################
from sklearn.model_selection import StratifiedShuffleSplit

# Paths of the train/validation JSONL splits. The name encodes seed, split
# count and fold; tuning mode uses the "-small" variant of each file.
train_file, valid_file = (
    FLAGS.LOCAL_PATH
    + f'/simplified-nq-{split}_Seed{FLAGS.SEED}_Split{FLAGS.N_SPLITS}_Fold{FLAGS.FOLD}'
    + ('-small.jsonl' if FLAGS.TUNING_MODE else '.jsonl')
    for split in ('train', 'valid')
)

if FLAGS.DATA_SPLIT:

    def _flatten_annotation(annotations):
        """Keep only the first annotation, with its id converted to a string.

        The id is stringified to avoid the OverflowError the original
        comment reports when the frame is written back to JSON.
        """
        first = annotations[0]
        return {
            'yes_no_answer': first['yes_no_answer'],
            'long_answer': first['long_answer'],
            'short_answers': first['short_answers'],
            'annotation_id': str(first['annotation_id']),
        }

    #### Stratify Key ####
    # Three 0/1 flags taken from the first annotation of every example:
    # has a long answer, has at least one short answer, has a yes/no answer.
    first_annotation = train.annotations.apply(lambda x: x[0])
    has_long_answer = first_annotation.apply(
        lambda a: (a['long_answer']['start_token'], a['long_answer']['end_token']) != (-1, -1))
    has_short_answers = first_annotation.apply(
        lambda a: len(a['short_answers']) > 0)
    has_yes_no = first_annotation.apply(
        lambda a: a['yes_no_answer'] not in ('NONE', None))

    stratify_key = has_long_answer.astype(int).astype(str) + \
                    has_short_answers.astype(int).astype(str) + \
                    has_yes_no.astype(int).astype(str)

    #### StratifiedShuffleSplit ####
    sss = StratifiedShuffleSplit(n_splits=FLAGS.N_SPLITS, random_state=FLAGS.SEED)
    # The fold comes from FLAGS.FOLD, the same value used in the output file
    # names; the bare name `FOLD` was never defined.
    trn_idx, val_idx = list(sss.split(train, stratify_key))[FLAGS.FOLD]

    # split() yields positional indices, so select rows by position.
    train_df = train.iloc[trn_idx].reset_index(drop=True)
    valid_df = train.iloc[val_idx].reset_index(drop=True)

    # Release the full frame before writing; `del_gc` was not defined anywhere
    # in the notebook, so do the delete + collect explicitly.
    del train, first_annotation
    gc.collect()

    # To avoid OverflowError : annotations - annotation_id to string values
    for split_df in (train_df, valid_df):
        split_df['annotations'] = split_df.annotations.apply(_flatten_annotation)

    #### Write ####
    train_df.to_json(
        path_or_buf=train_file,
        orient='records',
        lines=True
    )
    valid_df.to_json(
        path_or_buf=valid_file,
        orient='records',
        lines=True
    )


CPU times: user 5.26 ms, sys: 118 µs, total: 5.38 ms
Wall time: 4.4 ms


## 2.2 ALBERT-Joint Preprocessing

In [8]:
from transformers import AlbertTokenizer
from preprocessing import ConvertExamples2Features, FeatureWriter, nq_examples_iter

# Override the tokenizer's pretrained-vocabulary table. Only
# "albert-xlarge-v2" gets a configured path; every other variant maps to "".
AlbertTokenizer.pretrained_vocab_files_map = {
    "vocab_file": {
        f"albert-{size}-{version}": (
            FLAGS.TOKENIZER_MODEL_PATH_XLARGE if (size, version) == ("xlarge", "v2") else ""
        )
        for version in ("v1", "v2")
        for size in ("base", "large", "xlarge", "xxlarge")
    }
}
tokenizer = AlbertTokenizer.from_pretrained(FLAGS.MODEL_VERSION)

# Register every line of the additional-vocabulary file as a special token.
with open(FLAGS.ADDITIONAL_VOCAB_FILE) as f:
    additional_vocabs = f.read().splitlines()
num_added_toks = tokenizer.add_special_tokens(
    dict(additional_special_tokens=additional_vocabs)
)
print('We have added', num_added_toks, 'tokens')

# TFRecord output paths, one per split. The name encodes model version, seed,
# split count and fold; tuning mode uses the "-small" variant.
train_records, valid_records = (
    FLAGS.LOCAL_PATH
    + f'/nq-{split}_<{FLAGS.MODEL_VERSION}>_Seed{FLAGS.SEED}_Split{FLAGS.N_SPLITS}_Fold{FLAGS.FOLD}'
    + ('-small.tfrecords' if FLAGS.TUNING_MODE else '.tfrecords')
    for split in ('train', 'valid')
)

# JSONL inputs, in the same order as the record files they are converted into.
input_files = [train_file, valid_file]

if FLAGS.PREPROCESS:

    tqdm_notebook = tqdm.tqdm_notebook

    # Convert each JSONL split into its TFRecord file.
    for input_file, records in zip(input_files, [train_records, valid_records]):

        writer = FeatureWriter(
            filename = records,
            is_training = True)

        # The converter hands every feature it builds to `output_fn`, so the
        # writer is fed through it and needs no separate process_feature() call.
        converter = ConvertExamples2Features(
            tokenizer=tokenizer,
            is_training=True,
            output_fn=writer.process_feature,
            collect_stat=False)

        n_examples = 0

        try:
            for examples in nq_examples_iter(input_file=input_file, is_training=True, tqdm=tqdm_notebook):
                for example in examples:
                    # Accumulates the converter's return value; presumably a
                    # per-example count - TODO confirm against `preprocessing`.
                    n_examples += converter(example)
        finally:
            # Close the record file even when conversion is interrupted.
            writer.close()

        # Report on the writer used above; `eval_writer` was never defined.
        print('number of test examples: %d, written to file: %d' % (n_examples, writer.num_features))


We have added 207 tokens
Reading: /home/ec2-user/SageMaker/input/simplified-nq-train_Seed9253_Split16_Fold0-small.jsonl


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

KeyboardInterrupt: 

## 2.3 Dataset Generator

In [None]:
def get_dataset(tf_record_file, seq_length, batch_size=1, shuffle_buffer_size=0, is_training=False):
    """Build a batched ``tf.data`` pipeline over a TFRecord feature file.

    Args:
        tf_record_file: Path (or paths) handed to ``tf.data.TFRecordDataset``.
        seq_length: Length of the ``input_ids`` / ``input_mask`` /
            ``segment_ids`` vectors stored in each record.
        batch_size: Number of examples per emitted batch.
        shuffle_buffer_size: Shuffle buffer size; shuffling is skipped when
            this is not positive.
        is_training: When True, the label fields are parsed as well and each
            element is an ``(inputs, labels)`` pair; otherwise each element is
            the ``inputs`` dict alone.

    Returns:
        The batched ``tf.data.Dataset``.
    """
    input_keys = ('unique_ids', 'input_ids', 'input_mask', 'segment_ids')
    label_keys = ('unique_ids', 'start_positions', 'end_positions', 'answer_types')

    feature_spec = {
        "unique_ids": tf.io.FixedLenFeature([], tf.int64),
        "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
        "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    }
    if is_training:
        # Labels are stored as scalar integers.
        for label_name in ("start_positions", "end_positions", "answer_types"):
            feature_spec[label_name] = tf.io.FixedLenFeature([], tf.int64)

    # Taken from the TensorFlow models repository: https://github.com/tensorflow/models/blob/befbe0f9fe02d6bc1efb1c462689d069dae23af1/official/nlp/bert/input_pipeline.py#L24
    def _parse(serialized):
        """Parse one serialized example and cast int64 tensors to int32."""
        parsed = tf.io.parse_single_example(serialized, feature_spec)
        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        return {
            key: (tf.cast(value, tf.int32) if value.dtype == tf.int64 else value)
            for key, value in parsed.items()
        }

    def _to_model_io(example):
        """Split a parsed example into model inputs and, when training, labels."""
        inputs = {key: example[key] for key in input_keys}
        if not is_training:
            return inputs
        labels = {key: example[key] for key in label_keys}
        return (inputs, labels)

    dataset = tf.data.TFRecordDataset(tf_record_file).map(_parse).map(_to_model_io)

    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)

    return dataset.batch(batch_size)

# 3. Modeling

## 3.1 Model Definition

In [None]:
from transformers import TFAlbertModel, TFAlbertPreTrainedModel, TFSequenceSummary
from transformers.modeling_tf_utils import get_initializer

# Override the pretrained-weights table. Only "albert-xlarge-v2" gets a
# configured path; every other variant maps to "".
TFAlbertPreTrainedModel.pretrained_model_archive_map = {
    f"albert-{size}-{version}": (
        FLAGS.PRETRAINED_MODEL_PATH_XLARGE if (size, version) == ("xlarge", "v2") else ""
    )
    for version in ("v1", "v2")
    for size in ("base", "large", "xlarge", "xxlarge")
}

class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
    """ALBERT encoder with a span head and an answer-type head.

    ``call`` returns ``(start, end, answer_type)`` followed by any further
    encoder outputs. The three leading tensors have already been passed
    through softmax, so they are probabilities, not raw logits.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Create the encoder, the dropout layer and the two dense heads."""
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertModel(config, name="albert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        # Answer-type head over the pooled output: FLAGS.NUM_LABELS classes.
        self.classifier = tf.keras.layers.Dense(
            FLAGS.NUM_LABELS,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        # Span head over every token; its output is split in two in `call`.
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="qa_outputs",
        )
        self.softmax = tf.nn.softmax

    def call(self, inputs, **kwargs):
        """Run the encoder and both heads; see the class docstring for outputs."""
        encoder_outputs = self.albert(inputs, **kwargs)
        token_states, pooled_state = encoder_outputs[0], encoder_outputs[1]

        # Answer type: dropout (active only when training=True) then dense.
        is_training = kwargs.get("training", False)
        type_scores = self.classifier(self.dropout(pooled_state, training=is_training))

        # Span: split the last axis into start and end scores and drop it.
        span_scores = self.qa_outputs(token_states)
        start_scores, end_scores = [
            tf.squeeze(part, axis=-1) for part in tf.split(span_scores, 2, axis=-1)
        ]

        heads = (
            self.softmax(start_scores),
            self.softmax(end_scores),
            self.softmax(type_scores),
        )
        return heads + encoder_outputs[2:]


In [None]:
model = TFAlbertForQuestionAnswering.from_pretrained(FLAGS.MODEL_VERSION)

##### Changing the embedding layer for additional tokens #####
# Append one randomly initialised row per added special token. The row width
# and dtype are taken from the loaded matrix instead of a hard-coded 128.
# NOTE(review): the original weight object may still be tracked by the layer
# after the attribute is replaced - confirm via model.trainable_variables.
pretrained_embeddings = model.albert.embeddings.word_embeddings.value()
embedding_layer = tf.concat([
    pretrained_embeddings,
    tf.random.normal(
        [num_added_toks, pretrained_embeddings.shape[-1]],
        mean=0, stddev=1, dtype=pretrained_embeddings.dtype)
], axis=0)
model.albert.embeddings.word_embeddings = tf.Variable(
    embedding_layer,
    name='tf_albert_for_question_answering_1/albert/embeddings/word_embeddings/weight',
)

print('SUMMARY')
# summary() prints the table itself and returns None, so it is not wrapped in display().
model.summary()


In [None]:
# Smoke test: push one hand-built example through the model and show its outputs.

# Token ids for a single sentence; `[None, :]` adds the leading batch axis.
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute and whatever you say will be great at the end.", add_special_tokens=True))[None, :]  # Batch size 1
# Hard-coded attention mask and segment ids for the same example.
# NOTE(review): these literals presumably need the same length as `input_ids` - confirm against the tokenizer output.
mask_ids = tf.constant([0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], dtype=tf.float32)[None, :]
seg_ids = tf.constant([0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0], dtype=tf.int32)[None, :]

# Last expression of the cell, so the returned tuple is displayed.
model(input_ids, attention_mask=mask_ids, token_type_ids=seg_ids)


## 3.2 Fitting Starts Here

In [None]:
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K

# The model already applies softmax to all three outputs, so the loss must
# treat them as probabilities (from_logits=False). Reduction is disabled so
# each call yields one loss per example; loss_function sums them and the
# training step divides by the example count.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=tf.keras.losses.Reduction.NONE,
)

def loss_function(nq_labels, nq_logits):
    """Compute the joint loss for start position, end position and answer type.

    Args:
        nq_labels: Tuple ``(start_pos_labels, end_pos_labels,
            answer_type_labels)`` of integer class labels.
        nq_logits: Tuple ``(start, end, answer_type)`` of model outputs. These
            are softmax probabilities despite the parameter name.

    Returns:
        Tuple ``(loss, loss_start_pos, loss_end_pos, loss_ans_type)`` where the
        last three are per-task losses summed over the batch and ``loss`` is
        their mean.
    """
    (start_pos_labels, end_pos_labels, answer_type_labels) = nq_labels
    (start_pos_logits, end_pos_logits, answer_type_logits) = nq_logits
    start_pos_labels = tf.dtypes.cast(start_pos_labels, tf.float32)
    end_pos_labels = tf.dtypes.cast(end_pos_labels, tf.float32)
    answer_type_labels = tf.dtypes.cast(answer_type_labels, tf.float32)

    # Per-example losses (no reduction inside loss_object) ...
    loss_start_pos = loss_object(start_pos_labels, start_pos_logits)
    loss_end_pos = loss_object(end_pos_labels, end_pos_logits)
    loss_ans_type = loss_object(answer_type_labels, answer_type_logits)

    # ... summed over the batch.
    loss_start_pos = tf.math.reduce_sum(loss_start_pos)
    loss_end_pos = tf.math.reduce_sum(loss_end_pos)
    loss_ans_type = tf.math.reduce_sum(loss_ans_type)

    loss = (loss_start_pos + loss_end_pos + loss_ans_type) / 3.0

    return loss, loss_start_pos, loss_end_pos, loss_ans_type


def get_loss_and_gradients(unique_ids, input_ids, input_masks, segment_ids, start_pos_labels, 
                           end_pos_labels, answer_type_labels):
    """Run a training forward pass on one mini-batch and differentiate the loss.

    Also updates the global training accuracy metrics with this mini-batch.

    Args:
        unique_ids: Example ids; accepted for signature symmetry, unused here.
        input_ids: Token ids passed to the model.
        input_masks: Attention mask passed to the model.
        segment_ids: Token type ids passed to the model.
        start_pos_labels: Start position labels.
        end_pos_labels: End position labels.
        answer_type_labels: Answer type labels.

    Returns:
        Tuple ``(loss, gradients, loss_start_pos, loss_end_pos,
        loss_ans_type)``; ``gradients`` is aligned with
        ``model.trainable_variables``.
    """
    nq_labels = (start_pos_labels, end_pos_labels, answer_type_labels)

    with tf.GradientTape() as tape:
        # The mask keyword is `attention_mask`, as in the smoke-test call
        # earlier in the notebook; `input_mask` is not a recognised argument.
        model_outputs = model(input_ids, attention_mask=input_masks, token_type_ids=segment_ids, training=True)
        # Keep only the three head outputs; the model appends any extra
        # encoder outputs after them.
        nq_logits = tuple(model_outputs[:3])
        loss, loss_start_pos, loss_end_pos, loss_ans_type = loss_function(nq_labels, nq_logits)

    gradients = tape.gradient(loss, model.trainable_variables)

    (start_pos_logits, end_pos_logits, answer_type_logits) = nq_logits

    # Combined metric: fed with all three tasks.
    train_acc.update_state(start_pos_labels, start_pos_logits)
    train_acc.update_state(end_pos_labels, end_pos_logits)
    train_acc.update_state(answer_type_labels, answer_type_logits)

    # Per-task metrics.
    train_acc_start_pos.update_state(start_pos_labels, start_pos_logits)
    train_acc_end_pos.update_state(end_pos_labels, end_pos_logits)
    train_acc_ans_type.update_state(answer_type_labels, answer_type_logits)

    return loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type


def get_metrics(name):
    """Create the loss and accuracy trackers for one phase.

    Args:
        name: Prefix for every metric name, e.g. ``"train"`` or ``"valid"``.

    Returns:
        Tuple ``(loss, loss_start_pos, loss_end_pos, loss_ans_type,
        comp_metric, acc_start_pos, acc_end_pos, acc_ans_type)``: four running
        means followed by four accuracy metrics.
    """
    loss = tf.keras.metrics.Mean(name=f'{name}_loss')
    loss_start_pos = tf.keras.metrics.Mean(name=f'{name}_loss_start_pos')
    loss_end_pos = tf.keras.metrics.Mean(name=f'{name}_loss_end_pos')
    loss_ans_type = tf.keras.metrics.Mean(name=f'{name}_loss_ans_type')

    # Labels are integer class indices (parsed as scalar ints from the records),
    # not one-hot vectors, so the sparse accuracy variant is required.
    comp_metric = tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc')##### ToDo
    acc_start_pos = tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc_start_pos')
    acc_end_pos = tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc_end_pos')
    acc_ans_type = tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc_ans_type')

    return loss, loss_start_pos, loss_end_pos, loss_ans_type, comp_metric, acc_start_pos, acc_end_pos, acc_ans_type


def train_step_with_batch_accumulation(unique_ids, input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels):
    """Apply one optimizer update using gradients accumulated over mini-batches.

    The incoming batch is cut into up to ``FLAGS.BATCH_ACCUMULATION_SIZE``
    consecutive slices of ``FLAGS.BATCH_SIZE`` rows. Gradients and losses are
    summed over the slices, divided by the number of counted examples, applied
    through the global ``optimizer`` and recorded in the global training loss
    metrics.

    Args:
        unique_ids, input_ids, input_masks, segment_ids: Model inputs for the
            whole accumulated batch.
        start_pos_labels, end_pos_labels, answer_type_labels: Matching labels.

    Returns:
        None.
    """
    # The static batch dimension can be None (probably due to input_signature),
    # so count examples from the labels instead. Rows whose start label is -2
    # are not counted - presumably padding; TODO confirm.
    nb_examples = tf.math.reduce_sum(tf.cast(tf.math.not_equal(start_pos_labels, -2), tf.int32))
    # Clamp to 1 so a batch with no counted examples yields zeros, not NaN.
    denominator = tf.cast(tf.math.maximum(nb_examples, 1), tf.float32)

    total_loss = 0.0
    total_loss_start_pos = 0.0
    total_loss_end_pos = 0.0
    total_loss_ans_type = 0.0

    total_gradients = [tf.constant(0, shape=x.shape, dtype=tf.float32) for x in model.trainable_variables]

    for idx in tf.range(FLAGS.BATCH_ACCUMULATION_SIZE):

        start_idx = FLAGS.BATCH_SIZE * idx
        end_idx = FLAGS.BATCH_SIZE * (idx + 1)

        # Stop once the remaining slices would hold no counted examples.
        if start_idx >= nb_examples:
            break

        (unique_ids_mini, input_ids_mini, input_masks_mini, segment_ids_mini) = (unique_ids[start_idx:end_idx], input_ids[start_idx:end_idx], input_masks[start_idx:end_idx], segment_ids[start_idx:end_idx])
        (start_pos_labels_mini, end_pos_labels_mini, answer_type_labels_mini) = (start_pos_labels[start_idx:end_idx], end_pos_labels[start_idx:end_idx], answer_type_labels[start_idx:end_idx])

        loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type = get_loss_and_gradients(unique_ids_mini, input_ids_mini, input_masks_mini, segment_ids_mini, start_pos_labels_mini, end_pos_labels_mini, answer_type_labels_mini)

        total_loss += loss
        total_loss_start_pos += loss_start_pos
        total_loss_end_pos += loss_end_pos
        total_loss_ans_type += loss_ans_type

        # A variable that did not take part in the forward pass has a None
        # gradient; leave its accumulator untouched instead of failing on
        # `tensor + None`. Sparse gradients are densified before adding.
        total_gradients = [
            acc if grad is None else acc + tf.convert_to_tensor(grad)
            for acc, grad in zip(total_gradients, gradients)
        ]

    average_loss = tf.math.divide(total_loss, denominator)
    average_gradients = [tf.divide(x, denominator) for x in total_gradients]

    optimizer.apply_gradients(zip(average_gradients, model.trainable_variables))

    average_loss_start_pos = tf.math.divide(total_loss_start_pos, denominator)
    average_loss_end_pos = tf.math.divide(total_loss_end_pos, denominator)
    average_loss_ans_type = tf.math.divide(total_loss_ans_type, denominator)

    train_loss(average_loss)
    train_loss_start_pos(average_loss_start_pos)
    train_loss_end_pos(average_loss_end_pos)
    train_loss_ans_type(average_loss_ans_type)
    

In [None]:
from optimizer import AdamW, CustomSchedule

# Each dataset element holds a full accumulated batch; the training step cuts
# it into FLAGS.BATCH_SIZE-sized mini-batches.
train_dataset = get_dataset(
    train_records,
    seq_length=FLAGS.max_seq_length,
    batch_size=FLAGS.BATCH_SIZE*FLAGS.BATCH_ACCUMULATION_SIZE,
    shuffle_buffer_size=500_000,
    is_training=True
)

# Loss and accuracy trackers for both phases.
(train_loss, train_loss_start_pos, train_loss_end_pos, train_loss_ans_type,
 train_acc, train_acc_start_pos, train_acc_end_pos, train_acc_ans_type) = get_metrics("train")
(valid_loss, valid_loss_start_pos, valid_loss_end_pos, valid_loss_ans_type,
 valid_acc, valid_acc_start_pos, valid_acc_end_pos, valid_acc_ans_type) = get_metrics("valid")

# NOTE(review): outside tuning mode the example count is 0, which makes
# num_train_steps (and so decay_steps) 0 - confirm the intended full-data value.
num_training_examples = 72_000 if FLAGS.TUNING_MODE else 0
num_train_steps = int(FLAGS.EPOCHS * num_training_examples / FLAGS.BATCH_SIZE / FLAGS.BATCH_ACCUMULATION_SIZE)
learning_rate = CustomSchedule(
    initial_learning_rate=FLAGS.LEARNING_RATE,
    decay_steps=num_train_steps,
    end_learning_rate=0.0,
    power=1.0,
    cycle=False,
    num_warmup_steps=500
)

# Names of the trainable variables that contain "LayerNorm", "layer_norm" or "bias".
# NOTE(review): these are passed as `decay_var_list`; verify against the
# project's AdamW whether that list is the set to decay or the set to exclude.
decay_var_list = [
    variable.name
    for variable in model.trainable_variables
    if any(token in variable.name for token in ("LayerNorm", "layer_norm", "bias"))
]

optimizer = AdamW(
    weight_decay=0.01,
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    decay_var_list=decay_var_list,
)

# Checkpoints cover model and optimizer state; at most five are kept.
checkpoint_path = os.path.join(FLAGS.WEIGHTS_PATH, f'model_{FLAGS.VERSION}.ckpt')
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)


In [None]:

train_step = train_step_with_batch_accumulation


def _print_train_metrics():
    """Print the running training losses and accuracies accumulated so far."""
    print('Loss {:.6f} | Loss_S {:.6f} | Loss_E {:.6f} | Loss_T {:.6f}'.format(
        train_loss.result(),
        train_loss_start_pos.result(),
        train_loss_end_pos.result(),
        train_loss_ans_type.result()
    ))
    print(' Acc {:.6f} |  Acc_S {:.6f} |  Acc_E {:.6f} |  Acc_T {:.6f}'.format(
        train_acc.result(),
        train_acc_start_pos.result(),
        train_acc_end_pos.result(),
        train_acc_ans_type.result()
    ))


train_start_time = datetime.datetime.now()

for epoch in range(FLAGS.EPOCHS):

    # Start every epoch with fresh running statistics.
    for metric in (train_loss, train_loss_start_pos, train_loss_end_pos, train_loss_ans_type,
                   train_acc, train_acc_start_pos, train_acc_end_pos, train_acc_ans_type):
        metric.reset_states()

    epoch_start_time = datetime.datetime.now()

    for (batch_idx, (features, targets)) in tqdm.tqdm_notebook(enumerate(train_dataset)):

        (unique_ids, input_ids, input_masks, segment_ids) = (features['unique_ids'], features['input_ids'], features['input_mask'], features['segment_ids'])
        (_, start_pos_labels, end_pos_labels, answer_type_labels) = (targets['unique_ids'], targets['start_positions'], targets['end_positions'], targets['answer_types'])

        batch_start_time = datetime.datetime.now()

        train_step(unique_ids, input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels)

        batch_end_time = datetime.datetime.now()
        batch_elapsed_time = (batch_end_time - batch_start_time).total_seconds()

        # Progress report every 10 accumulated batches.
        if (batch_idx + 1) % 10 == 0:
            print('Epoch {} | Batch {} | Elapsed Time {}'.format(
                epoch + 1,
                batch_idx + 1,
                batch_elapsed_time
            ))
            _print_train_metrics()
            print("-" * 100)

    epoch_end_time = datetime.datetime.now()
    epoch_elapsed_time = (epoch_end_time - epoch_start_time).total_seconds()

    # Checkpoint interval in epochs (currently every epoch).
    if (epoch + 1) % 1 == 0:

        ckpt_save_path = ckpt_manager.save()
        print ('\nSaving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

        print('\nEpoch {}'.format(epoch + 1))
        _print_train_metrics()

    print('\nTime taken for 1 epoch: {} secs\n'.format(epoch_elapsed_time))
    print("-" * 80 + "\n")

    # Path and version come from FLAGS, as for the checkpoint directory; the
    # bare names WEIGHTS_PATH / VERSION were never defined.
    model.save_weights(FLAGS.WEIGHTS_PATH+f'/modelweights_{FLAGS.VERSION}.h5')
    

# 4. Predict