In [None]:
username = 'MarcelloCeresini'
repository = 'QuestionAnswering'

# COLAB ONLY CELLS
try:
    import google.colab
    IN_COLAB = True
    !pip3 install transformers
    !git clone https://www.github.com/{username}/{repository}.git
    #from google.colab import drive
    #drive.mount('/content/drive/')
    %cd /content/QuestionAnswering/src
    using_TPU = True    # If we are running this notebook on Colab, use a TPU
    # Google cloud credentials
    %env GOOGLE_APPLICATION_CREDENTIALS=/content/drive/MyDrive/Uni/Magistrale/NLP/Project/nlp-project-338723-0510aa0a4912.json
except:
    IN_COLAB = False
    using_TPU = False   # If you're not on Colab you probably won't have access to a TPU

# Imports

In [None]:
import os
import numpy as np
import random
from tqdm import tqdm
from functools import partial
tqdm = partial(tqdm, position=0, leave=True)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import Sequence
from typing import List, Union, Dict, Tuple
from transformers import BertTokenizerFast, DistilBertTokenizerFast, \
                         TFBertModel, TFDistilBertModel
import datetime
from sklearn.feature_extraction.text import TfidfVectorizer
import utils
from collections import deque

RANDOM_SEED = 42
MAX_SEQ_LEN = 512
BERT_DIMENSIONALITY = 768

np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

ROOT_PATH = os.path.dirname(os.getcwd())
TRAINING_FILE = os.path.join(ROOT_PATH, 'data', 'training_set.json')
VALIDATION_FILE = os.path.join(ROOT_PATH, 'data', 'validation_set.json')
TEST_FILE = os.path.join(ROOT_PATH, 'data', 'dev_set.json')

# Default single-device strategy; replaced by a TPUStrategy below when a TPU is reachable,
# so that `strategy` is always defined for later cells.
strategy = tf.distribute.get_strategy()
if using_TPU:
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
        tf.config.experimental_connect_to_cluster(resolver)
        # This is the TPU initialization code that has to be at the beginning.
        tf.tpu.experimental.initialize_tpu_system(resolver)
        print("All devices: ", tf.config.list_logical_devices('TPU'))
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception as tpu_error:
        # Catch regular errors only (a bare `except:` would also swallow KeyboardInterrupt)
        # and report the cause instead of hiding it.
        print("TPUs are not available, setting flag 'using_TPU' to False.")
        print(f"Reason: {tpu_error}")
        using_TPU = False

if not using_TPU:
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

BATCH_SIZE = 32 if using_TPU else 4

TPUs are not available, setting flag 'using_TPU' to False.
Num GPUs Available:  0


# Question Answering with the DPR


Now that we have trained the Dense Passage Retriever, we have an efficient way to collect the best scoring paragraphs given a question (a dot product between the question's encoding representation and the matrix of paragraph representations).

On top of this paragraph selection method, we can define our final Question Answering model in this way:
- As input, the model should take the tokenized question (using the standard DistilBert tokenizer)
- The model should then use `bert_q` (non-trainable) to create a representation of the question in the learnt 768-d space, that is in common with the paragraph representations.
- We compute similarity scores between the representation of the question and all pre-computed (non-trainable) representations of paragraphs. Based on these scores, we select the top-k ($k=100$) paragraphs. 
    - Note that at training, validation and testing time we have three different paragraph representation matrices to use, so we should build a mechanism to easily switch between these representations.
    - It's also advisable to maintain an array of pre-tokenised paragraphs since they will be used in the second part of the model.

Once we have the indices of the top-scoring paragraphs, we must decide which specific paragraph contains the answer, and where in that paragraph the answer lies. 
- To do that, we encode both the question and the selected paragraphs using cross-attention through another Bert model (`reader`) (trainable model). This model will output k $512 \times 768$ encodings for each question in the batch. Each encoding will be denoted as $P_i$, while $P_i^{[CLS]}$ denotes its 768-d encoding at the `[CLS]` token. 
- For question answering, for each of the $k$ selected paragraphs, we must compute the probability of the paragraph being selected $P_{selected}(i)$, as well as the usual $P_{start, i}(s), P_{end, i}(t)$ for each of the $s$-th and $t$-th words of the $i$-th paragraph.
- All probabilities are computed through dense layers:
\begin{gather}
P_{start,i}(s) = \text{softmax}(P_i w_{start})_s
\\
P_{end,i}(t) = \text{softmax}(P_i w_{end})_t
\\
P_{selected}(i) = \text{softmax}(\hat{P}^\intercal w_{selected})_i
\end{gather}
where $w_{start}$, $w_{end}$ and $w_{selected}$ are learnt vectors, while $\hat{P} = [P_{1}^{[CLS]}, \dots, P_k^{[CLS]}] \in \mathbb{R}^{768 \times k}$ collects the `[CLS]` encodings of all $k$ paragraphs.
- As final answer, we select the highest scoring start-end legal span from the highest-scoring paragraph.

During training: for each question, we create a batch by sampling $m$ passages ($m=24$ in the paper) from the top-100 passages returned by the retrieval system (the DPR, i.e. by computing similarities with the pre-computed representations). The training objective is to maximize the marginal log-likelihood of all the correct answer spans in the positive passage (the answer string may appear multiple times in one passage), combined with the log-likelihood of the positive passage being selected. In the paper, a batch size of 16 was used.


# Setup

## Dataset Loading

We load all the data that was prepared in the `dense_passage_retriever` notebook.

In [None]:
if IN_COLAB:
    # On Colab, weights and prepared datasets are stored on Google Drive
    from google.colab import drive
    drive.mount('/content/drive/')
    checkpoint_dir = '/content/drive/MyDrive/Uni/Magistrale/NLP/Project/weights/training_dpr/'
    datasets_dir = '/content/drive/MyDrive/Uni/Magistrale/NLP/Project/datasets/dpr/'
else:
    # Locally, everything lives under ./checkpoints/training_dpr (datasets in its "dataset" subfolder)
    checkpoint_dir = os.path.join("checkpoints", "training_dpr")
    datasets_dir = os.path.join(checkpoint_dir, "dataset")

# Make sure both folders exist (no-op when they are already there)
for directory in (checkpoint_dir, datasets_dir):
    os.makedirs(directory, exist_ok=True)

Mounted at /content/drive/


In [None]:
# Read the three SQuAD-style splits (lists of articles), in train/val/test order
train_paragraphs_and_questions, val_paragraphs_and_questions, test_paragraphs_and_questions = (
    utils.read_question_set(path)['data']
    for path in (TRAINING_FILE, VALIDATION_FILE, TEST_FILE)
)

# The validation articles come from the training file: drop them from the training split
train_paragraphs_and_questions = list(filter(
    lambda article: article not in val_paragraphs_and_questions,
    train_paragraphs_and_questions
))

def get_questions_and_paragraphs(dataset):
    '''
    Flatten a list of SQuAD-style articles into a question list and a paragraph list.

    Params:
        dataset: list of articles, each a dict with a 'paragraphs' list whose items
            have a 'context' string and a 'qas' list.
    Returns:
        (questions, paragraphs), both in article order and then paragraph order:
        - questions: one dict per question, {'qas': <question entry>,
          'context_id': (article index, paragraph index within the article)},
          so every question keeps a pointer to its ground-truth paragraph;
        - paragraphs: one dict per paragraph, {'context': <paragraph text>,
          'context_id': article index}.
    '''
    questions, paragraphs = [], []
    for article_idx, article in enumerate(dataset):
        for paragraph_idx, para in enumerate(article['paragraphs']):
            paragraphs.append({'context': para['context'], 'context_id': article_idx})
            questions.extend(
                {'qas': qas, 'context_id': (article_idx, paragraph_idx)}
                for qas in para['qas']
            )
    return questions, paragraphs

# Flatten every split into its question and paragraph lists (train, val, test order)
((train_questions, train_paragraphs),
 (val_questions, val_paragraphs),
 (test_questions, test_paragraphs)) = [
    get_questions_and_paragraphs(split_articles)
    for split_articles in (train_paragraphs_and_questions,
                           val_paragraphs_and_questions,
                           test_paragraphs_and_questions)
]

In [None]:
def decode_fn(record_bytes):
    '''
    Parse one serialized example (raw bytes from the TFRecord file) into the nested
    dictionary consumed by the models.

    Every feature is int64: token-level features are fixed-length vectors of
    MAX_SEQ_LEN entries, while the question/context/paragraph indices are scalars.
    '''
    sequence_feature = tf.io.FixedLenFeature(shape=(MAX_SEQ_LEN,), dtype=tf.int64)
    scalar_feature = tf.io.FixedLenFeature(shape=(), dtype=tf.int64)
    schema = {
        "question__input_ids": sequence_feature,
        "question__attention_mask": sequence_feature,
        "question__index": scalar_feature,
        "answer__out_s": sequence_feature,
        "answer__out_e": sequence_feature,
        "paragraph__input_ids": sequence_feature,
        "paragraph__attention_mask": sequence_feature,
        "hard_paragraph__input_ids": sequence_feature,
        "hard_paragraph__attention_mask": sequence_feature,
        "paragraph__tokens_s": sequence_feature,
        "paragraph__tokens_e": sequence_feature,
        "context__index": scalar_feature,
        "paragraph__index": scalar_feature,
    }
    parsed = tf.io.parse_single_example(record_bytes, schema)

    questions = {'input_ids': parsed['question__input_ids'],
                 'attention_mask': parsed['question__attention_mask'],
                 'index': parsed['question__index']}
    answers = {'out_s': parsed['answer__out_s'],
               'out_e': parsed['answer__out_e']}
    paragraphs = {'input_ids': parsed['paragraph__input_ids'],
                  'attention_mask': parsed['paragraph__attention_mask'],
                  'tokens_s': parsed['paragraph__tokens_s'],
                  'tokens_e': parsed['paragraph__tokens_e']}
    hard_paragraphs = {'input_ids': parsed['hard_paragraph__input_ids'],
                       'attention_mask': parsed['hard_paragraph__attention_mask']}
    # (article index, paragraph index within the article) of the ground-truth paragraph
    context_ids = (parsed['context__index'], parsed['paragraph__index'])

    return {
        "questions": questions,
        "answers": answers,
        "paragraphs": paragraphs,
        "hard_paragraphs": hard_paragraphs,
        "context_ids": context_ids
    }

def load_tf_dataset_from_cloud(questions, fn, batch_size=BATCH_SIZE,
                               bucket_name='volpepe-nlp-project-squad-datasets'):
    '''
    Stream a serialized split from Google Cloud Storage as a batched tf.data pipeline.

    Only the basename of `fn` is used: the object read is
    `gs://<bucket_name>/<basename(fn)>_<BERT_DIMENSIONALITY>.proto`.

    Params:
        questions: the split's question list; only its length is used, to declare the
            dataset cardinality (it must match the number of records in the file).
        fn: path whose basename identifies the split (e.g. '.../train').
        batch_size: number of examples per batch.
        bucket_name: GCS bucket holding the `.proto` files.
    Returns:
        A cached, shuffled (reshuffled every epoch), batched and prefetched tf.data.Dataset
        of examples decoded by `decode_fn`.
    '''
    fn_type = f'{os.path.basename(fn)}_{BERT_DIMENSIONALITY}'
    gcs_filename = f'gs://{bucket_name}/{fn_type}.proto'
    print(f"Loading {fn_type} dataset from GCS ({gcs_filename}).")
    dataset = tf.data.TFRecordDataset([gcs_filename]).map(
        decode_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(len(questions)))
    # Cache the decoded examples *before* shuffling: caching after the shuffle would
    # freeze the first epoch's order and defeat reshuffle_each_iteration.
    dataset = dataset.cache()
    dataset = dataset.shuffle(5000, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

dataset_train = load_tf_dataset_from_cloud(train_questions, os.path.join(datasets_dir, 'train'))
dataset_val = load_tf_dataset_from_cloud(val_questions, os.path.join(datasets_dir, 'val'))
# Each split must be sized with its own questions: assert_cardinality checks this count
# when the dataset is iterated.
dataset_test = load_tf_dataset_from_cloud(test_questions, os.path.join(datasets_dir, 'test'))

Loading train_768 dataset from GCS (gs://volpepe-nlp-project-squad-datasets/train_768.proto).
Loading val_768 dataset from GCS (gs://volpepe-nlp-project-squad-datasets/val_768.proto).
Loading test_768 dataset from GCS (gs://volpepe-nlp-project-squad-datasets/test_768.proto).


We also load the paragraphs' `model_p` encodings and the questions' `model_q` encodings for the selection part.

In [None]:
representations_dir = os.path.join(datasets_dir, 'representations')

def load_representation(filename):
    '''Load one pre-computed encoding matrix (a .npy file) from `representations_dir`.'''
    return np.load(os.path.join(representations_dir, filename))

# Paragraph encodings, one matrix per split
train_paragraphs_encodings = load_representation('train_paragraphs_encodings.npy')
val_paragraphs_encodings   = load_representation('val_paragraphs_encodings.npy')
test_paragraphs_encodings  = load_representation('test_paragraphs_encodings.npy')

# Question encodings, one matrix per split
train_questions_encodings  = load_representation('train_questions_encodings.npy')
val_questions_encodings    = load_representation('val_questions_encodings.npy')
test_questions_encodings   = load_representation('test_questions_encodings.npy')

Finally, we load the tokenizer.

In [None]:
# Fast tokenizer for the uncased DistilBert vocabulary (a fast tokenizer is required for
# the offset mappings requested when pre-tokenising the paragraphs)
tokenizer_distilbert = DistilBertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-uncased')

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

## Paragraphs preparation

Firstly, we pre-tokenize the paragraphs to directly use them inside the model.

In [None]:
# Per-split lists of pre-tokenised paragraphs, aligned with `<split>_paragraphs`: entry n of
# each list holds the (1, MAX_SEQ_LEN) tensor returned by the tokenizer for paragraph n.
TOKENIZED_FIELDS = ('input_ids', 'attention_mask', 'offset_mapping')

pretokenized_paragraphs = {
    split: {field: [] for field in TOKENIZED_FIELDS}
    for split in ('train', 'val', 'test')
}

# The same tokenisation is applied to every split, so loop over them instead of
# repeating the code three times.
for split, split_paragraphs in (('train', train_paragraphs),
                                ('val', val_paragraphs),
                                ('test', test_paragraphs)):
    for paragraph in tqdm(split_paragraphs, desc=split):
        token_p = tokenizer_distilbert(
            paragraph['context'], max_length=MAX_SEQ_LEN,
            return_tensors='tf', truncation=True,
            padding='max_length', return_offsets_mapping=True
        )
        for field in TOKENIZED_FIELDS:
            pretokenized_paragraphs[split][field].append(token_p[field])

100%|██████████| 13975/13975 [00:15<00:00, 923.00it/s] 
100%|██████████| 4921/4921 [00:03<00:00, 1249.20it/s]
100%|██████████| 2067/2067 [00:01<00:00, 1162.22it/s]


# Question Answering model

First we create some utility layers for the model.

In [None]:
class ParagraphSelector(keras.layers.Layer):
    '''
    Non-trainable layer that selects, for each question, the `top_k` (default 100)
    best-scoring paragraphs.

    Scores are dot products between the question encodings and the pre-computed paragraph
    encodings of the requested split ('train', 'val' or 'test'), so switching split only
    means passing a different string at call time.

    Constructor params (all optional; the defaults are the notebook globals):
        pretokenized: dict split -> {'input_ids': [...], 'attention_mask': [...]} of
            pre-tokenised paragraphs; defaults to `pretokenized_paragraphs`.
        representations: dict split -> paragraph encoding matrix (one row per paragraph);
            defaults to the `<split>_paragraphs_encodings` matrices.
        top_k: number of paragraphs returned per question (clamped to the number of
            paragraphs available in the split).

    Call returns (top_indices, best_input_ids, best_attention_masks): the indices of the
    selected paragraphs, sorted by decreasing score, and their pre-tokenised input ids and
    attention masks (one extra leading `top_k` axis after the batch axis).
    '''
    def __init__(self, pretokenized=None, representations=None, top_k=100, **kwargs) -> None:
        super(ParagraphSelector, self).__init__(trainable=False, **kwargs)
        self.top_k = top_k
        self.pretokenized_paragraphs = (pretokenized if pretokenized is not None
                                        else pretokenized_paragraphs)
        self.representations = representations if representations is not None else {
            'train': train_paragraphs_encodings,
            'val'  : val_paragraphs_encodings,
            'test' : test_paragraphs_encodings
        }
        self.input_ids = {
            dset: self._stack_sequences(self.pretokenized_paragraphs[dset]['input_ids'])
            for dset in self.representations
        }
        self.attention_masks = {
            dset: self._stack_sequences(self.pretokenized_paragraphs[dset]['attention_mask'])
            for dset in self.representations
        }

    @staticmethod
    def _stack_sequences(sequences):
        # Every paragraph was tokenised on its own, so each entry is a (1, seq_len) tensor:
        # stack them into one (num_paragraphs, seq_len) constant that can be gathered by row.
        # (The explicit reshape also keeps the row axis when there is a single paragraph.)
        return tf.reshape(tf.stack(sequences), [len(sequences), -1])

    def call(self, q_repr: tf.Tensor, representation_type: str):
        paragraph_matrix = self.representations[representation_type]
        # Dot product of every question encoding with every paragraph encoding
        # (contracting the last axis of both works for numpy arrays and tf tensors alike)
        scores = tf.tensordot(q_repr, paragraph_matrix, axes=[[-1], [-1]])
        # Indices of the best `k` scores, already sorted by decreasing score
        k = min(self.top_k, paragraph_matrix.shape[0])
        top_indices = tf.math.top_k(scores, k=k).indices
        # The paragraph matrices are shared by every question, so a plain gather with the
        # (batch, k) indices returns the (batch, k, seq_len) pre-tokenised paragraphs.
        best_input_ids = tf.gather(self.input_ids[representation_type], top_indices)
        best_attention_masks = tf.gather(self.attention_masks[representation_type], top_indices)
        return top_indices, best_input_ids, best_attention_masks


class BestScoringCollector(keras.layers.Layer):
    '''
    Non-trainable layer that keeps, for each question, only the start/end token
    probabilities of the paragraph with the highest selection probability.
    '''
    def __init__(self, **kwargs):
        super(BestScoringCollector, self).__init__(trainable=False, **kwargs)

    def call(self, probs_s, probs_e, probs_sel):
        '''
        Params:
            probs_s, probs_e: start/end probabilities, expected shape (batch, k, seq_len).
            probs_sel: paragraph selection probabilities, expected shape (batch, k).
        Returns:
            (probs_s, probs_e) of the best-scoring paragraph, each of shape (batch, seq_len).
        '''
        # Index of the best-scoring paragraph for every question: shape (batch,)
        best_scoring_paragraphs = tf.argmax(probs_sel, axis=1, output_type=tf.int32)
        # batch_dims=1 pairs the i-th index with the i-th question. Gathering with a (batch,)
        # index vector removes the paragraph axis directly, so no squeeze is needed (an
        # unqualified squeeze would also drop the batch axis for a single-question batch).
        probs_s = tf.gather(probs_s, best_scoring_paragraphs, batch_dims=1)
        probs_e = tf.gather(probs_e, best_scoring_paragraphs, batch_dims=1)
        return probs_s, probs_e


class ReaderEvaluator(keras.layers.Layer):
    '''
    Custom layer computing start-token, end-token and paragraph-selection probabilities
    from the reader's encodings.

    Call params (as produced by QuestionAnsweringModel.call):
        paragraphs_full_encodings: per-token encodings, either (batch, seq_len, hidden) for
            one paragraph per question, or (batch, k, seq_len, hidden) for k paragraphs.
        paragraphs_search_encodings: the [CLS] encodings, (batch, hidden) or (batch, k, hidden).
    Returns:
        (out_S, out_E, out_SEL). Start/end probabilities are normalised over the token
        positions of each paragraph. With k paragraphs per question, out_S/out_E are those
        of the paragraph with the highest selection probability, shape (batch, seq_len),
        and out_SEL has shape (batch, k).
    '''
    def __init__(self, **kwargs) -> None:
        super(ReaderEvaluator, self).__init__(trainable=True, **kwargs)
        self.start_token_logits = keras.layers.TimeDistributed(keras.layers.Dense(1), name="start_token_logits")
        # axis=-1: softmax over the token positions (axis=1 would normalise over the
        # paragraphs when the input holds k paragraphs per question)
        self.start_token_probabilities = keras.layers.Softmax(name="start_probs", axis=-1, dtype='float32')
        self.end_token_logits = keras.layers.TimeDistributed(keras.layers.Dense(1), name="end_token_logits")
        self.end_token_probabilities = keras.layers.Softmax(name="end_probs", axis=-1, dtype='float32')
        self.paragraph_selection_logits = keras.layers.Dense(1, name="selection_logits")
        self.paragraph_selection_probabilities = keras.layers.Softmax(name="selection_probs", dtype='float32')
        self.flatten = keras.layers.Flatten()
        self.best_scoring_collector = BestScoringCollector(name='best_scoring_collector')

    def call(self, paragraphs_full_encodings, paragraphs_search_encodings):
        # Start-token probabilities: one logit per position, then drop only the trailing
        # unit axis (an unqualified squeeze would also drop batch/paragraph axes of size 1)
        out_S = tf.squeeze(self.start_token_logits(paragraphs_full_encodings), axis=-1)
        out_S = self.start_token_probabilities(out_S)

        # The same is done for the end tokens.
        out_E = tf.squeeze(self.end_token_logits(paragraphs_full_encodings), axis=-1)
        out_E = self.end_token_probabilities(out_E)

        # Paragraph selection probabilities, normalised over the paragraphs of each question
        out_SEL = self.paragraph_selection_logits(paragraphs_search_encodings)
        out_SEL = self.flatten(out_SEL)
        out_SEL = self.paragraph_selection_probabilities(out_SEL)

        # Only with several paragraphs per question, i.e. (batch, k, seq_len) probabilities,
        # is there a paragraph to select; otherwise the probabilities are already per question.
        if out_S.shape.rank == 3:
            out_S, out_E = self.best_scoring_collector(out_S, out_E, out_SEL)
        return out_S, out_E, out_SEL


We need a function to map the dataset's context ids to ground truth indices in the representation matrix:

In [None]:
# Step 1: for each split, the cumulative number of paragraphs per article:
# paragraphs_per_context[split][a] = total number of paragraphs in articles 0..a
paragraphs_per_context = {
    split: np.cumsum([len(article['paragraphs']) for article in articles])
    for split, articles in (
        ('train', train_paragraphs_and_questions),
        ('val', val_paragraphs_and_questions),
        ('test', test_paragraphs_and_questions),
    )
}

In [None]:
# Step 2: actual computation with tf.function for efficiency
@tf.function
def tf_get_paragraph_encoding_index(batch, dataset):
    '''
    Map every (article index, paragraph-in-article index) pair in `batch['context_ids']`
    to the row of that paragraph in the `dataset` split's paragraph matrices.
    '''
    article_ids, paragraph_ids = batch['context_ids']
    # Number of paragraphs preceding each article: 0 for the first one, then the
    # cumulative counts shifted by one position
    cumulative_counts = tf.constant(paragraphs_per_context[dataset])
    offsets = tf.concat([tf.zeros([1], dtype=cumulative_counts.dtype), cumulative_counts], axis=0)
    return paragraph_ids + tf.gather(offsets, article_ids)

In [None]:
# Sanity check on one training batch: its (article, paragraph) ids and the
# corresponding rows in the training paragraph matrix
test_elements = next(iter(dataset_train.as_numpy_iterator()))
print(test_elements['context_ids'])
tf_get_paragraph_encoding_index(test_elements, 'train')

(array([ 3, 14,  0,  0]), array([ 4,  2, 49,  2]))


<tf.Tensor: shape=(4,), dtype=int64, numpy=array([151, 811,  49,   2])>

Finally we can create the whole Question Answering model

In [None]:
# Trainable reader: a DistilBert encoder (the checkpoint's unused pre-training heads are dropped)
reader = TFDistilBertModel.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-uncased')

Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_projector', 'vocab_layer_norm', 'vocab_transform', 'activation_13']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


In [None]:
# One training batch (as numpy arrays) used to walk through the model step by step
input_data = next(iter(dataset_train.as_numpy_iterator()))

In [None]:
# Walk through the training data preparation on `input_data`.
# NOTE(review): `model_QA` is not created in any cell above (QuestionAnsweringModel is
# defined further down), so this cell only runs once the model has been instantiated.

# Obtain the best 100 paragraphs given the questions' representations
q_repr = tf.gather(model_QA.questions_representations['train'],
                   input_data['questions']['index'])
top100_indices, best_input_ids, best_attention_mask = model_QA.paragraph_selector(
    q_repr, 'train')    # We are training, so we index in the training paragraphs matrix
# Obtain the ground truth indices
gt_indices = model_QA.tf_get_paragraph_encoding_index(input_data, 'train')
# Obtain the training batch and the ground truth mask
data_batch, gt_mask = model_QA.tf_obtain_training_info(gt_indices, top100_indices)
# Stack the per-paragraph (1, MAX_SEQ_LEN) tensors into one (num_paragraphs, MAX_SEQ_LEN)
# matrix; a plain gather with the (batch, m) indices then yields (batch, m, MAX_SEQ_LEN)
# without any squeeze that could also drop a batch or m axis of size 1.
train_paragraphs_input_ids = tf.concat(pretokenized_paragraphs['train']['input_ids'], axis=0)
# Put together question and paragraph encodings to obtain a full tokenization
input_ids_for_reader, attention_mask_for_reader, gt_start, gt_end = \
    model_QA.tf_put_questions_and_paragraphs_together(
        input_data['questions']['input_ids'],
        tf.gather(train_paragraphs_input_ids, data_batch),
        input_data['answers'])

In [None]:
@tf.function
def test():
    '''
    Sanity check of the reader on the batch prepared above: encode, one question at a time,
    all of that question's (question + paragraph) sequences and stack the last hidden states.
    Relies on the globals `reader`, `input_ids_for_reader` and `attention_mask_for_reader`.
    '''
    # Size the loop on the actual batch instead of a hard-coded 4 (BATCH_SIZE is 32 on TPU)
    n_questions = tf.shape(input_ids_for_reader)[0]
    encodings = tf.TensorArray(tf.float32, size=n_questions)
    for i in tf.range(n_questions):
        paragraph_encodings = reader({
            'input_ids': input_ids_for_reader[i,:,:],
            'attention_mask': attention_mask_for_reader[i,:,:],
            'training': False
        }).last_hidden_state
        encodings = encodings.write(i, paragraph_encodings)
    return encodings.stack()

In [None]:
# Expected (batch, m, MAX_SEQ_LEN, hidden size)
test().get_shape()

TensorShape([4, 4, 512, 768])

In [None]:
# True when the reader input holds several paragraphs per question
tf.greater(tf.rank(input_ids_for_reader), 2)

<tf.Tensor: shape=(), dtype=bool, numpy=True>

---

In [None]:
class QuestionAnsweringModel(keras.Model):
    def __init__(self, reader, questions_representations, m=24):
        '''
        End-to-end Question Answering model built on top of the Dense Passage Retriever.

        - "predict" can be called with a pre-tokenised question (dict with keys
          'input_ids' and 'attention_mask') and a split name ('train', 'val' or 'test',
          default 'test') choosing the paragraph set used for reading.
        - The model can otherwise be called directly on a batch of pre-tokenised
          questions + paragraphs.

        Params:
            reader: the cross-attention encoder (usually a fresh DistilBert model).
            questions_representations: dict {'train': ..., 'val': ..., 'test': ...} of the
                pre-computed `model_q` question encoding matrices, used at training and
                evaluation time.
            m: paragraphs per question at training time (1 positive + m-1 negatives).
        '''
        super().__init__()
        # Sub-layers (assignment order kept: it determines the tracked layer order)
        self.paragraph_selector = ParagraphSelector(name='paragraph_selector')
        self.reader = reader
        self.reader_evaluator = ReaderEvaluator(name='reader_evaluator')
        # Pre-computed question encodings, so no question encoder has to run at train/eval time
        self.questions_representations = questions_representations
        # Hyperparameters
        self.m = m
        # Metrics
        accuracy = keras.metrics.CategoricalAccuracy
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.start_acc_metric = accuracy(name="start_token_accuracy")
        self.end_acc_metric = accuracy(name="end_token_accuracy")
        self.sel_acc_metric = accuracy(name="par_selection_accuracy")
    
    def call(self, inputs, training=None):
        '''
        Run the reader and the reader evaluator on pre-tokenised question+paragraph sequences.

        Params:
            inputs: dict with 'input_ids' and 'attention_mask', either of shape
                (batch, seq_len) for one paragraph per question, or (batch, k, seq_len)
                for k candidate paragraphs per question. An optional 'training' entry
                takes precedence over the `training` argument.
            training: Keras training flag, forwarded to the reader.
        Returns:
            (out_S, out_E, out_SEL) as computed by the reader evaluator.
        '''
        input_ids = tf.convert_to_tensor(inputs['input_ids'])
        attention_mask = tf.convert_to_tensor(inputs['attention_mask'])
        training = inputs.get('training', training)
        # Static rank check: a Python `if` on the tensor `tf.rank(...)` cannot be traced in
        # graph mode because the two branches produce tensors of different ranks.
        rank = input_ids.shape.rank
        if rank is not None and rank > 2:
            # Fold the k paragraphs into the batch axis so the reader encodes every
            # (question, paragraph) sequence in a single call, then restore the k axis.
            ids_shape = tf.shape(input_ids)
            batch_size, num_paragraphs, seq_len = ids_shape[0], ids_shape[1], ids_shape[-1]
            flat_encodings = self.reader({
                'input_ids': tf.reshape(input_ids, [-1, seq_len]),
                'attention_mask': tf.reshape(attention_mask, [-1, seq_len]),
                'training': training
            }).last_hidden_state
            hidden = flat_encodings.shape[-1]
            if hidden is None:
                hidden = tf.shape(flat_encodings)[-1]
            full_encodings = tf.reshape(
                flat_encodings, [batch_size, num_paragraphs, seq_len, hidden])
        else:
            full_encodings = self.reader({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'training': training
            }).last_hidden_state
        # The encoding at position 0 ([CLS]) summarises each sequence for paragraph selection
        search_encodings = full_encodings[..., 0, :]
        out_S, out_E, out_SEL = self.reader_evaluator(full_encodings, search_encodings)
        return out_S, out_E, out_SEL

    @tf.function
    def tf_put_questions_and_paragraphs_together(self, questions_input_ids, 
                                                 data_batch_input_ids, gt_answers=None):
        '''
        Merge question and paragraph token ids into `[CLS] question [SEP] paragraph [SEP]`
        sequences so that the reader can apply cross-attention. Optionally pass gt_answers
        to also shift the one-hot ground-truth positions to the merged sequences.

        Params:
            questions_input_ids: (batch, seq_len) question ids, each containing a [SEP] (id 102).
            data_batch_input_ids: (batch, m, seq_len) paragraph ids, each starting with
                [CLS] (which is dropped) and containing a [SEP].
            gt_answers: optional dict with 'out_s' / 'out_e' one-hot tensors, one row per
                question, marking the answer start/end inside the positive paragraph.
        Returns:
            (input_ids, attention_masks, gt_s, gt_e):
            - input_ids, attention_masks: int32, (batch, m, MAX_SEQ_LEN);
            - gt_s, gt_e: int32 one-hot (batch, MAX_SEQ_LEN) positions in the merged sequence
              (clipped to MAX_SEQ_LEN-1), or (batch,) zeros when gt_answers is None.
        '''
        n_questions = tf.shape(questions_input_ids)[0]
        # Define the arrays that will hold the data
        final_input_ids_array = tf.TensorArray(tf.int32, size=n_questions)
        final_attention_masks_array = tf.TensorArray(tf.int32, size=n_questions)
        # Ground-truth arrays (filled with zeros when gt_answers is None)
        final_gt_s_array = tf.TensorArray(tf.int32, size=n_questions)
        final_gt_e_array = tf.TensorArray(tf.int32, size=n_questions)
        # Iterate over batch dimension
        for i in tf.range(n_questions):
            q = questions_input_ids[i]
            # Keep the question ids up to and including its first separator (id 102)
            pos_sep_q = tf.where(q == 102)[0][0]
            appendix = tf.cast(q[:pos_sep_q + 1], tf.int32)
            n_paragraphs = tf.shape(data_batch_input_ids[i])[0]
            input_ids = tf.TensorArray(tf.int32, size=n_paragraphs)
            attention_masks = tf.TensorArray(tf.int32, size=n_paragraphs)
            # Iterate over the paragraphs for each question
            for j in tf.range(n_paragraphs):
                # Drop the paragraph's initial token and keep everything after it
                # (tokens, separator and padding)
                postfix = tf.cast(data_batch_input_ids[i, j][1:], tf.int32)
                # Concatenate the two parts
                final_id_seq = tf.concat([appendix, postfix], axis=0)
                # Cut the sequence at MAX_SEQ_LEN. If the paragraph's separator (the second
                # 102) falls beyond the limit, close the truncated sequence with a 102 token;
                # otherwise only padding is cut.
                if tf.where(final_id_seq == 102)[1][0] >= MAX_SEQ_LEN:
                    final_id_seq = tf.concat([final_id_seq[:MAX_SEQ_LEN-1], 
                                              tf.constant([102])], axis=0)
                else:
                    final_id_seq = final_id_seq[:MAX_SEQ_LEN]
                # Attend to every non-padding (id 0) token
                final_attention_mask = tf.cast(final_id_seq != 0, tf.int32)
                # Write results on arrays
                input_ids = input_ids.write(j, final_id_seq)
                attention_masks = attention_masks.write(j, final_attention_mask)
            final_input_ids_array = final_input_ids_array.write(i, input_ids.stack())
            final_attention_masks_array = final_attention_masks_array.write(i, attention_masks.stack())
            if gt_answers is not None:
                # This part really only makes sense for the positive paragraph.
                # Dropping the paragraph's initial token and prepending pos_sep_q+1 question
                # tokens shifts every paragraph position by pos_sep_q, so we add the same
                # number to the index where out_s and out_e are 1.
                # Both can go out of bounds, so we cap the positions to the maximum sequence
                # length.
                # NOTE(review): tf.where(...)[0][0] assumes out_s/out_e contain a 1.
                pos_gt_s = tf.math.minimum(tf.where(gt_answers['out_s'][i])[0][0] + pos_sep_q, 
                                           tf.constant(MAX_SEQ_LEN-1, dtype=tf.int64))
                pos_gt_e = tf.math.minimum(tf.where(gt_answers['out_e'][i])[0][0] + pos_sep_q, 
                                           tf.constant(MAX_SEQ_LEN-1, dtype=tf.int64))
                # Create the final GT tensors
                final_gt_s = tf.one_hot(pos_gt_s, MAX_SEQ_LEN, dtype=tf.int32)
                final_gt_e = tf.one_hot(pos_gt_e, MAX_SEQ_LEN, dtype=tf.int32)
                # Write the partial tensors into the final array
                final_gt_s_array = final_gt_s_array.write(i, final_gt_s)
                final_gt_e_array = final_gt_e_array.write(i, final_gt_e)
            else:
                # If None, fill these arrays with 0s. TensorArray.write returns the updated
                # array: it must be reassigned, otherwise the writes are lost in graph mode
                # and stack() finds nothing to read.
                final_gt_s_array = final_gt_s_array.write(i, tf.constant(0, dtype=tf.int32))
                final_gt_e_array = final_gt_e_array.write(i, tf.constant(0, dtype=tf.int32))
        # Return the stacked arrays
        return final_input_ids_array.stack(), final_attention_masks_array.stack(), \
            final_gt_s_array.stack(), final_gt_e_array.stack()

    @tf.function
    def tf_get_paragraph_encoding_index(self, batch, dataset):
        '''
        Ground-truth row, in the `dataset` split's paragraph matrices, of each paragraph
        identified by `batch['context_ids']` = (article indices, paragraph-in-article indices).
        '''
        article_ids, paragraph_ids = batch['context_ids']
        # offsets[a] = number of paragraphs in all articles before article `a`:
        # a leading 0 followed by the cumulative paragraph counts
        cumulative_counts = tf.constant(paragraphs_per_context[dataset])
        offsets = tf.concat([tf.zeros([1], dtype=cumulative_counts.dtype), cumulative_counts], axis=0)
        # Vectorised over the whole batch instead of looping element by element
        return paragraph_ids + tf.gather(offsets, article_ids)

    @tf.function
    def tf_shuffle_on_columns(self, value):
        '''
        Randomly permute the entries within each row of `value`, independently per row.
        '''
        # A random key per entry; sorting the keys of a row yields a random permutation of it
        random_keys = tf.random.uniform(tf.shape(value))
        row_permutations = tf.argsort(random_keys)
        return tf.gather(value, row_permutations, batch_dims=1)

    @tf.function
    def tf_obtain_training_info(self, gt_indexes, top100_indexes):
        '''
        Build the training candidates of each question: its positive paragraph plus m-1
        distinct negatives sampled from the retrieved paragraphs.

        Params:
            gt_indexes: (batch,) index of each question's positive paragraph.
            top100_indexes: (batch, K) indices of the retrieved paragraphs, with K >= m.
        Returns:
            data_batch: (batch, m) paragraph indices, with the positive at a random position.
            gt_mask: (batch, m) int32 one-hot mask of the positive's position in data_batch.
        '''
        # Collect ground truth indexes
        gt_paragraphs = tf.expand_dims(tf.cast(gt_indexes, tf.int32), -1)
        top100_indexes = tf.cast(top100_indexes, tf.int32)
        # True where the retrieved paragraph is not the positive one
        negative_masks = tf.math.not_equal(top100_indexes, gt_paragraphs)
        # Sample m-1 negatives without replacement: give every retrieved paragraph a random
        # key in [0, 1) and the positive a key of -1, then keep the m-1 largest keys. The
        # positive can therefore never be drawn as a negative, and no paragraph is repeated,
        # so gt_mask has exactly one 1 per row.
        random_keys = tf.where(negative_masks,
                               tf.random.uniform(tf.shape(top100_indexes)),
                               -1.0)
        negative_positions = tf.math.top_k(random_keys, k=self.m - 1).indices
        negatives = tf.gather(top100_indexes, negative_positions, batch_dims=1)
        # We concatenate the positive paragraph index to the selected negatives and shuffle
        # so that the positive is not always the last element
        data_batch = self.tf_shuffle_on_columns(
            tf.concat([negatives, gt_paragraphs], axis=1))
        # Ground truth mask: one-hot position of the positive sample in the data batch
        gt_mask = tf.cast(data_batch == gt_paragraphs, tf.int32)
        return data_batch, gt_mask

    def train_step(self, data):
        '''
        Runs one optimisation step on a batch of training questions.

        1. Retrieve the top-100 training paragraphs for every question.
        2. Sample `self.m` paragraphs per question (ground truth + negatives) and the
           matching ground truth mask.
        3. Join questions and sampled paragraphs into reader inputs and obtain the
           start/end targets.
        4. Run the model, sum the start, end and selection losses and apply the gradients.

        Returns a dict with the running loss and the start/end/selection accuracies.
        '''
        if isinstance(data, tuple):
            # Only the features dict is used; targets are derived from it below
            data = data[0]
        q_repr = tf.gather(self.questions_representations['train'],
                           data['questions']['index'])
        # Only the retrieved indexes are needed: the paragraphs fed to the reader are
        # gathered below from the sampled batch. We are training, so we index into
        # the training paragraphs matrix.
        top100_indices, _, _ = self.paragraph_selector(q_repr, 'train')
        gt_indices = self.tf_get_paragraph_encoding_index(data, 'train')
        data_batch, gt_mask = self.tf_obtain_training_info(gt_indices, top100_indices)
        # data_batch holds global paragraph indexes, so a plain gather over the
        # paragraph axis is what is needed. The former `axis=0, batch_dims=1`
        # contradicts tf.gather's requirement that axis >= batch_dims.
        # NOTE(review): tf.squeeze without an axis would also drop the batch axis
        # for a batch of size 1 — confirm the intended shapes.
        sampled_paragraph_ids = tf.squeeze(
            tf.gather(pretokenized_paragraphs['train']['input_ids'], data_batch))
        # Put together question and paragraph encodings to obtain a full tokenization
        input_ids_for_reader, attention_mask_for_reader, gt_start, gt_end = \
            self.tf_put_questions_and_paragraphs_together(
                data['questions']['input_ids'],
                sampled_paragraph_ids,
                data['answers'])
        # Open the gradient tape, obtain predictions and compute the loss
        with tf.GradientTape() as tape:
            out_S, out_E, out_SEL = self({
                'input_ids': input_ids_for_reader,
                'attention_mask': attention_mask_for_reader,
                'training': True
            })
            loss_start = self.compiled_loss(gt_start, out_S)
            loss_end = self.compiled_loss(gt_end, out_E)
            loss_sel = self.compiled_loss(gt_mask, out_SEL)
            loss_value = loss_start + loss_end + loss_sel
        # Compute the gradients and apply them on the variables
        grads = tape.gradient(loss_value, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update the metrics
        self.loss_tracker.update_state(loss_value)
        self.start_acc_metric.update_state(gt_start, out_S)
        self.end_acc_metric.update_state(gt_end, out_E)
        self.sel_acc_metric.update_state(gt_mask, out_SEL)
        return {"loss": self.loss_tracker.result(), 
                "start_accuracy": self.start_acc_metric.result(),
                "end_accuracy": self.end_acc_metric.result(),
                "sel_accuracy": self.sel_acc_metric.result()}

    def test_step(self, data):
        '''
        Evaluates the model on a batch of validation questions.

        The batch is prepared exactly like in `train_step`, but on the validation
        split. The model runs without gradient computation, and the loss and metrics
        are updated. The loss tracker must be updated here too, otherwise the
        reported validation loss never changes. EarlyStopping and ModelCheckpoint
        rely on that value.

        Returns a dict with the running loss and the start/end/selection accuracies.
        '''
        if isinstance(data, tuple):
            data = data[0]
        q_repr = tf.gather(self.questions_representations['val'],
                           data['questions']['index'])
        # We are validating, so we index into the validation paragraphs matrix
        top100_indices, _, _ = self.paragraph_selector(q_repr, 'val')
        gt_indices = self.tf_get_paragraph_encoding_index(data, 'val')
        # Sample the validation batch (positive + negatives) and its ground truth mask
        data_batch, gt_mask = self.tf_obtain_training_info(gt_indices, top100_indices)
        # data_batch holds global paragraph indexes: plain gather over the paragraph
        # axis (`axis=0, batch_dims=1` contradicts tf.gather's axis >= batch_dims).
        # NOTE(review): tf.squeeze without an axis would also drop the batch axis
        # for a batch of size 1 — confirm the intended shapes.
        sampled_paragraph_ids = tf.squeeze(
            tf.gather(pretokenized_paragraphs['val']['input_ids'], data_batch))
        # Put together question and paragraph encodings to obtain a full tokenization
        input_ids_for_reader, attention_mask_for_reader, gt_start, gt_end = \
            self.tf_put_questions_and_paragraphs_together(
                data['questions']['input_ids'],
                sampled_paragraph_ids,
                data['answers'])
        # Obtain predictions and compute the loss
        out_S, out_E, out_SEL = self({
            'input_ids': input_ids_for_reader,
            'attention_mask': attention_mask_for_reader,
            'training': False
        })
        loss_start = self.compiled_loss(gt_start, out_S)
        loss_end = self.compiled_loss(gt_end, out_E)
        loss_sel = self.compiled_loss(gt_mask, out_SEL)
        loss_value = loss_start + loss_end + loss_sel
        # Update the metrics with the targets computed for this batch, as in train_step
        self.loss_tracker.update_state(loss_value)
        self.start_acc_metric.update_state(gt_start, out_S)
        self.end_acc_metric.update_state(gt_end, out_E)
        self.sel_acc_metric.update_state(gt_mask, out_SEL)
        # Return a dict mapping metric names to current value.
        return {"loss": self.loss_tracker.result(), 
                "start_accuracy": self.start_acc_metric.result(),
                "end_accuracy": self.end_acc_metric.result(),
                "sel_accuracy": self.sel_acc_metric.result()}

    def predict_step(self, data):
        '''
        Runs retrieval and reading for a batch of test questions.

        Returns:
            top100_indices: indexes of the paragraphs retrieved for each question.
            out_S, out_E: start / end token outputs of `self.reader_evaluator`.
            out_SEL: paragraph selection output of `self.reader_evaluator`.
        '''
        if isinstance(data, tuple):
            data = data[0]
        # At prediction time we receive the pre-tokenised questions:
        # retrieve the best 100 test paragraphs for each of them
        q_repr = tf.gather(self.questions_representations['test'],
                           data['questions']['index'])
        top100_indices, best_input_ids, _ = self.paragraph_selector(q_repr, 'test')
        # Build the reader inputs. There are no ground truth answers at prediction time.
        input_ids_for_reader, attention_mask_for_reader, _, _ = \
            self.tf_put_questions_and_paragraphs_together(
                data['questions']['input_ids'], best_input_ids, gt_answers=None)
        # Let the reader encode the question and paragraphs with cross-attention
        reader_output = self.reader({
            'input_ids': input_ids_for_reader,
            'attention_mask': attention_mask_for_reader,
            'training': False
        })
        # Hugging Face TF models return a ModelOutput, which cannot be sliced like a
        # tensor. Use its `last_hidden_state` when present; otherwise the reader
        # already returned the encodings tensor.
        full_encodings = getattr(reader_output, 'last_hidden_state', reader_output)
        # Position 0 (presumably the [CLS] token) is passed to the evaluator together
        # with the full encodings to obtain start, end and selection outputs
        search_encodings = full_encodings[:, 0, :]
        out_S, out_E, out_SEL = self.reader_evaluator(full_encodings, search_encodings)
        return top100_indices, out_S, out_E, out_SEL

    @property
    def metrics(self):
        '''
        Metric objects tracked by the model.

        Listing them here lets Keras call `reset_states()` on each of them
        automatically at the start of every epoch and at the start of `evaluate()`.
        '''
        tracked = [self.loss_tracker]
        tracked.extend((self.start_acc_metric, self.end_acc_metric, self.sel_acc_metric))
        return tracked

### Training

To train the model we need to:

- Compile it, defining the loss and the optimizer.
- Build, for every batch, the ground-truth targets against which the model's outputs are compared.

In [None]:
def create_trainable_QA_model(freeze_reader_up_to:int, 
                              questions_representations:dict, 
                              m:int=24,
                              learning_rate:float=1e-5):
    '''
    Builds and compiles the end-to-end question answering model.

    Parameters
    ----------
    freeze_reader_up_to: int
        Number of DistilBERT transformer blocks, counted from the input, whose weights
        are frozen (blocks 0 .. freeze_reader_up_to-1). The remaining blocks keep learning.
    questions_representations: dict
        Question encodings per split ('train', 'val', 'test'), forwarded to
        QuestionAnsweringModel.
    m: int
        Forwarded to QuestionAnsweringModel. Presumably the number of paragraphs per
        question in a training sample (ground truth + m-1 negatives) — TODO confirm.
    learning_rate: float
        Learning rate of the Adam optimizer. Defaults to the previously hard-coded 1e-5.

    Returns
    -------
    The compiled QuestionAnsweringModel.

    Raises
    ------
    ValueError
        If freeze_reader_up_to is larger than the number of transformer blocks.
    '''
    print("Creating Reader model...")
    reader = TFDistilBertModel.from_pretrained('distilbert-base-uncased')

    # Freeze the first `freeze_reader_up_to` transformer blocks (0 .. n-1)
    blocks = reader.distilbert.transformer.layer
    if freeze_reader_up_to > len(blocks):
        raise ValueError(
            f"freeze_reader_up_to={freeze_reader_up_to} exceeds the "
            f"{len(blocks)} transformer blocks of the reader")
    for i in range(freeze_reader_up_to):
        blocks[i].trainable = False

    print("Creating Question Answering model...")
    model = QuestionAnsweringModel(reader, questions_representations, m=m)

    print("Compiling...")
    # Compile the model (metrics are defined into the model)
    model.compile(
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate),
        loss = keras.losses.CategoricalCrossentropy()
    )

    return model

In [None]:
import datetime

# Question encodings of every split, looked up by the model through the
# 'train' / 'val' / 'test' keys.
questions_representations = dict(
    train=train_questions_encodings,
    val=val_questions_encodings,
    test=test_questions_encodings,
)

if not using_TPU:
    # GPU / local run: TensorBoard can be used (it cannot on TPU), so prepare a
    # timestamped log directory for it.
    log_dir = os.path.join(ROOT_PATH, "data", "logs", 
        "training_qa", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    freeze_up_to = 5
    model_QA = create_trainable_QA_model(freeze_up_to, questions_representations, m=4)
else:
    # On TPU the model has to be created inside the scope of the distribution strategy.
    freeze_up_to = 3
    with strategy.scope():
        model_QA = create_trainable_QA_model(freeze_up_to, questions_representations, m=12)

# Workaround for saving checkpoints locally when using cloud TPUs
local_device_option = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
# GPUs and local systems don't need the option above: the callbacks only need a
# filename pattern.
checkpoint_path = os.path.join(checkpoint_dir, "qa_model.ckpt")

Creating Reader model...


Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_projector', 'vocab_layer_norm', 'vocab_transform', 'activation_13']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


Creating Question Answering model...
Compiling...


Train the model:

In [None]:
DO_TRAINING = True  # when False, the fit cell below is skipped entirely
PATIENCE = 6        # epochs without improvement tolerated by EarlyStopping
EPOCHS = 60         # maximum number of training epochs passed to fit()

In [None]:
if DO_TRAINING:
    if not using_TPU:
        # TensorBoard writes event files to local disk, which cannot be done
        # automatically on TPUs, so it is only created on GPU / local runs.
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

    # Checkpointing works on both TPU and GPU thanks to `options`, which points the
    # checkpoint I/O at the local host.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        options=local_device_option,
        save_weights_only=True,
        save_best_only=True,
        verbose=1,
    )

    # Early stopping works on any hardware; the best weights are restored at the end
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=PATIENCE)

    callbacks = [es_callback, cp_callback] + ([] if using_TPU else [tensorboard_callback])

    # Fit the model (verbose=1 shows a progress bar)
    history = model_QA.fit(
        dataset_train,
        y=None,
        validation_data=dataset_val,
        initial_epoch=0,
        epochs=EPOCHS,
        callbacks=callbacks,
        use_multiprocessing=True,
        verbose=1,
    )

Epoch 1/60


TypeError: ignored

### Obtaining an answer

In [None]:
# def start_end_token_from_probabilities(
#     pstartv: np.array, 
#     pendv: np.array, 
#     dim:int=512) -> List[List[int]]:
#     '''
#     Returns a List of [StartToken, EndToken] elements computed from the batch outputs.
#     '''
#     idxs = []
#     for i in range(pstartv.shape[0]):
#         # For each element in the batch, transform the vectors into matrices
#         # by repeating them dim times:
#         # - Vectors of starting probabilities are stacked on the columns
#         pstart = np.stack([pstartv[i,:]]*dim, axis=1)
#         # - Vectors of ending probabilities are repeated on the rows
#         pend = np.stack([pendv[i,:]]*dim, axis=0)
#         # Once we have the two matrices, we sum them (element-wise operation)
#         # to obtain the scores of each combination
#         sums = pstart + pend
#         # We only care about the scores in the upper triangular part of the matrix
#         # (where the ending index is greater than the starting index)
#         # therefore we zero out the diagonal and the lower triangular area
#         sums = np.triu(sums, k=1)
#         # The most probable set of tokens is the one with highest score in the
#         # remaining matrix. Through argmax we obtain its position.
#         val = np.argmax(sums)
#         # Since the starting probabilities are repeated on the columns, each element
#         # is identified by the row. Ending probabilities are instead repeated on rows,
#         # so each element is identified by the column.
#         row = val // dim
#         col = val - dim*row
#         idxs.append([row,col])
#     return idxs

In [None]:
# answers_start_end = start_end_token_from_probabilities(probs_s, probs_e)
# print(answers_start_end)

Finally, we can obtain the answers to the questions we have given the network.

In [None]:
# best_indices = tf.squeeze(tf.gather(paragraphs['indexes'], tf.expand_dims(best_scoring_paragraphs, -1), batch_dims=1))
# best_offsets = tf.squeeze(tf.gather(pretokenized_val_paragraphs['offset_mapping'], best_indices))
# best_offsets.shape, best_indices.shape

In [None]:
# char_start_end = [(best_offsets[i][answers_start_end[i][0]][0].numpy(),
#                    best_offsets[i][answers_start_end[i][1]][1].numpy())
#                  for i in range(BATCH_SIZE)]
# char_start_end

In [None]:
# # Correction for answers arriving to the end of the sequence
# for i in range(BATCH_SIZE):
#     c = char_start_end[i]
#     if c[1] >= c[0]:
#         char_start_end[i] = (c[0], c[1])
#     else:
#         char_start_end[i] = (c[0], 1000000)

In [None]:
# for i, p in enumerate(best_indices.numpy()):
#     print(val_paragraphs[p]['context'][char_start_end[i][0]:char_start_end[i][1]])

Of course, the answers are still very poor: the Dense layers that select the start and end tokens, as well as the paragraph selector, have not been trained yet.