# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune BERT for the text-extraction task with a dataset of questions and answers. The question are about a give paragraph (*context*) that contains the answers. The model will be trained to locate the answer in the context by giving the possitions where the answer stars and finish.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

 More info:
  * [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (5 seconds)


In [3]:
%%px
import os
import json
import numpy as np

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

import horovod.tensorflow.keras as hvd

import dataset_utils as du
from tokenizers import BertWordPieceTokenizer
from transformers import TFBertModel, BertTokenizer


batch_size = 16
max_len = 384

In [4]:
%%px
hvd.init()

slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                               cache_dir=f"/scratch/snx3000/stud50/_bert_tockenizer")

save_path = f"/scratch/snx3000/stud50/bert_tockenizer"
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from saved file
tokenizer = BertWordPieceTokenizer(f"{save_path}/vocab.txt", lowercase=True)

In [5]:
%%px
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url, cache_dir="./")

eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url, cache_dir="./")

In [6]:
%%px

with open(train_path) as f:
    raw_train_data = json.load(f)
print(f"{len(raw_train_data['data'])} training items loaded.")

with open(eval_path) as f:
    raw_eval_data = json.load(f)
print(f"{len(raw_eval_data['data'])} evaluation items loaded.")


train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(train_squad_examples, shuffle=True, seed=42)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

[stdout:0] 
442 training items loaded.
48 evaluation items loaded.
86136 training points created.
10331 evaluation points created.
[stdout:1] 
442 training items loaded.
48 evaluation items loaded.
86136 training points created.
10331 evaluation points created.


In [7]:
%%px
encoder = TFBertModel.from_pretrained("bert-base-uncased",
                                      cache_dir=f"/scratch/snx3000/stud50/bert_model")

input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)

embedding = encoder(input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask)[0]

start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
start_logits = layers.Flatten()(start_logits)
start_probs = layers.Activation(keras.activations.softmax)(start_logits)

end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
end_logits = layers.Flatten()(end_logits)
end_probs = layers.Activation(keras.activations.softmax)(end_logits)

model = keras.Model(inputs=[input_ids, token_type_ids, attention_mask],
                    outputs=[start_probs, end_probs])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tfa.optimizers.LAMB(lr=5e-4 * np.sqrt(hvd.size()))
optimizer = hvd.DistributedOptimizer(optimizer)

model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [8]:
%%px
def get_dataset(x, y, batch_size=batch_size):
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(x),
        tf.data.Dataset.from_tensor_slices(y),
    ))
    dataset = dataset.shard(hvd.size(), hvd.rank())
    dataset = dataset.shuffle(2048, seed=42)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return dataset

# batch shapes
for X, Y in get_dataset(x_train, y_train).take(1):
    print([i.shape for i in X])
    [print(i) for i in Y]

[stdout:0] 
[TensorShape([16, 384]), TensorShape([16, 384]), TensorShape([16, 384])]
tf.Tensor([126 104  42 150 164  65  10  96  19  25  70  16  88  44  81 108], shape=(16,), dtype=int64)
tf.Tensor([126 108  42 150 164  71  12 101  20  27  70  17  89  50  83 111], shape=(16,), dtype=int64)
[stdout:1] 
[TensorShape([16, 384]), TensorShape([16, 384]), TensorShape([16, 384])]
tf.Tensor([ 24  89 171  68  61  18  41 113  22   7  26  49  35  31  11   2], shape=(16,), dtype=int64)
tf.Tensor([ 25 102 172  74  61  19  46 125  22   8  28  60  36  35  12   2], shape=(16,), dtype=int64)


In [9]:
# %%px
# def dataset_generator(x, y):
#     for i in range(x[0].shape[0]):
#         yield ((x[0][i], x[1][i], x[2][i]),
#                (y[0][i], y[1][i]))

# def get_generator(x, y, batch_size=batch_size, output_shapes=(((max_len,), (max_len,), (max_len,)),((), ()))):
#     dataset = tf.data.Dataset.from_generator(lambda: dataset_generator(x, y),
#                                              output_types=((tf.int32, tf.int32, tf.int32),(tf.int32, tf.int32)),
#                                              output_shapes=output_shapes)
#     dataset = dataset.shard(hvd.size(), hvd.rank())
#     dataset = dataset.shuffle(2048, seed=42)
#     dataset = dataset.repeat()
#     dataset = dataset.batch(batch_size)
#     return dataset

# # batch shapes
# for X, Y in get_generator(x_train, y_train).take(1):
#     print([i.shape for i in X])
#     [print(i) for i in Y]

In [10]:
%%px
fit = model.fit(get_dataset(x_train, y_train),
                epochs=1,
                steps_per_epoch=50,
                callbacks=[hvd.callbacks.BroadcastGlobalVariablesCallback(0),
                           hvd.callbacks.MetricAverageCallback()],
                validation_data=get_dataset(x_eval, y_eval),
                validation_steps=len(y_eval[0]) // batch_size)



### Evaluation

In [11]:
%%px
# sharded eval
model.evaluate(get_dataset(x_eval, y_eval),
               steps=len(y_eval[0]) // batch_size,
               callbacks=[hvd.callbacks.MetricAverageCallback()])



[0;31mOut[0:8]: [0m
[4.349886894226074,
 2.2529540061950684,
 2.0969340801239014,
 0.4310077428817749,
 0.47344961762428284]

[0;31mOut[1:8]: [0m
[4.404617786407471,
 2.2865185737609863,
 2.118100881576538,
 0.43042635917663574,
 0.46531006693840027]

In [12]:
%%px
import string
import re


def normalize_text(text):
    text = text.lower()

    # Remove punctuations
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in exclude)

    # Remove articles
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    text = re.sub(regex, " ", text)

    # Remove extra white space
    text = " ".join(text.split())
    return text


class ExactMatch():
    """
    Each `SquadExample` object contains the character level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval, model, squad_examples):
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.model = model
        self.squad_examples = squad_examples

    def score(self, logs=None):
        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        eval_examples_no_skip = [_ for _ in self.squad_examples if _.skip == False]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            if start >= len(offsets):
                continue

            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
                
            # print(f'  - {squad_eg.question}\n')
            print(f'  - {normalized_pred_ans:30.30s} | ref: {squad_eg.answer_text:30s} | {squad_eg.question}')

        acc = count / len(self.y_eval[0])
        return acc

In [13]:
%%px
samples = np.random.choice(len(x_eval[0]), 50, replace=False)

em = ExactMatch([x_eval[0][samples], x_eval[1][samples], x_eval[2][samples]],
                [y_eval[0][samples], y_eval[1][samples]],
                model,
                eval_squad_examples[samples])
em.score()

[stdout:0] 
  - 1870 to 1939                   | ref: 1870 to 1939                   | How long was the Summer Theatre in operation?
  -                                | ref: extra-legal                    | Excessive bureaucratic red tape is one of the reasons for what type of ownership?
  - paramount pictures             | ref: Paramount Pictures             | What company did Eisner become president of when he left ABC in 1976?
  - 300 km long and up to 40 km    | ref: 300 km long                    | How long is the Upper Rhine Plain?
  - 14th to 17th centuries accordi | ref: the plague was present somewhere in Europe in every year between 1346 and 1671. | What did Biraben say about the plague in Europe?
  - luna 180foot 55 mtall 600yearo | ref: 738 days                       | How long did Julia Butterfly Hill live in a tree?
  - sybilla of normandy            | ref: Sybilla of Normandy            | Who did Alexander I marry?
  - tesla went on to pursue his id | ref: boat         

[0;31mOut[0:10]: [0m0.34

[0;31mOut[1:10]: [0m0.34

In [14]:
%ipcluster stop