# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune BERT for the text-extraction task with a dataset of questions and answers. The questions are about a given paragraph (*context*) that contains the answers. The model is trained to locate the answer in the context by predicting the positions where the answer starts and ends.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

 More info:
  * [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [1]:
import ipcmagic

In [2]:
# Start 2 IPython engines through MPI (ipcmagic). Every %%px cell below runs
# on all engines, which is why outputs appear once per engine ([stdout:0],
# [stdout:1]).
%ipcluster start -n 2 --mpi

IPCluster is ready! (9 seconds)


In [3]:
%%px
import json
import os
import shutil
import tempfile

import numpy as np

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

import dataset_utils as du
from tokenizers import BertWordPieceTokenizer
from transformers import TFBertModel, BertTokenizer

# Data-parallel training across the engines: cluster membership comes from
# the SLURM job (SlurmClusterResolver) and NCCL is requested for collectives.
cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver()
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    cluster_resolver=cluster_resolver,
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL,
)

# Job geometry as reported by SLURM.
# NOTE(review): num_workers counts SLURM nodes, i.e. this assumes one replica
# per node; strategy.num_replicas_in_sync would count replicas directly — confirm.
num_workers = int(os.environ['SLURM_NNODES'])
node_id = int(os.environ['SLURM_NODEID'])

# Global batch = 16 examples per worker; max_len is the token sequence length
# used both for the tokenized examples and for the model's Input layers.
per_worker_batch_size = 16
batch_size = per_worker_batch_size * num_workers
max_len = 384

In [4]:
%%px
# In this notebook the slow (pure-Python) tokenizer is only used to export the
# vocabulary file that the fast BertWordPieceTokenizer is built from.
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                               cache_dir="/scratch/snx3000/stud50/_bert_tockenizer")

save_path = "/scratch/snx3000/stud50/bert_tockenizer"
vocab_path = os.path.join(save_path, "vocab.txt")

# Every engine runs this cell at the same time (%%px) against the same shared
# directory, so the export must be race-free:
#  * test for the file that is actually needed, not just the directory, so a
#    directory left behind by an interrupted run gets repaired;
#  * makedirs(exist_ok=True) cannot fail when another engine created it first;
#  * export into a private temp dir next to save_path (same filesystem) and
#    move each file into place with os.replace, which is atomic on POSIX, so
#    no engine can ever read a half-written vocab.txt.
if not os.path.exists(vocab_path):
    os.makedirs(save_path, exist_ok=True)
    tmp_dir = tempfile.mkdtemp(prefix="bert_tockenizer.", dir=os.path.dirname(save_path))
    try:
        slow_tokenizer.save_pretrained(tmp_dir)
        for file_name in os.listdir(tmp_dir):
            os.replace(os.path.join(tmp_dir, file_name),
                       os.path.join(save_path, file_name))
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

# Load the fast tokenizer from the exported vocabulary file.
tokenizer = BertWordPieceTokenizer(vocab_path, lowercase=True)

In [5]:
%%px
# SQuAD v1.1 splits; the official "dev" split serves as the evaluation set.
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"

# Fetch each split (train first, then eval) into the local cache under "./".
train_path, eval_path = (
    keras.utils.get_file(file_name, url, cache_dir="./")
    for file_name, url in (("train.json", train_data_url),
                           ("eval.json", eval_data_url))
)

In [6]:
%%px
# Parse the raw SQuAD files. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the engine's locale-dependent default
# encoding (batch/MPI-launched engines may run under a non-UTF-8 locale).
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)
print(f"{len(raw_train_data['data'])} training items loaded.")

with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)
print(f"{len(raw_eval_data['data'])} evaluation items loaded.")

# The local dataset_utils helpers turn the raw data into examples (using
# max_len and the fast tokenizer) and then into (inputs, targets) arrays.
# Only the training arrays are requested shuffled, with a fixed seed.
train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(train_squad_examples, shuffle=True, seed=42)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

[stdout:0] 
442 training items loaded.
48 evaluation items loaded.
86136 training points created.
10331 evaluation points created.
[stdout:1] 
442 training items loaded.
48 evaluation items loaded.
86136 training points created.
10331 evaluation points created.


In [7]:
%%px
# Build and compile the span-extraction model inside the strategy scope so the
# model's and optimizer's variables are created under the distribution strategy.
with strategy.scope():
    # Pretrained BERT encoder; cache_dir keeps the downloaded weights on scratch.
    encoder = TFBertModel.from_pretrained("bert-base-uncased",
                                          cache_dir="/scratch/snx3000/stud50/bert_model")

    # Three int32 inputs of length max_len, in the order the model expects them.
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)

    # Element 0 of the encoder output is used as the token embeddings
    # (presumably the last hidden state, (batch, max_len, hidden) — confirm).
    embedding = encoder(input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask)[0]

    # Two independent linear heads score every token as answer start / end.
    # Flatten drops the size-1 score axis, and the softmax turns the scores
    # into a probability distribution over positions.
    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)
    start_probs = layers.Activation(keras.activations.softmax)(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(inputs=[input_ids, token_type_ids, attention_mask],
                        outputs=[start_probs, end_probs])

    # Targets are integer positions, and the heads already output probabilities,
    # hence sparse cross-entropy with from_logits=False (applied to each output).
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    # The base rate is scaled by sqrt(num_workers) because the global batch is
    # num_workers times larger. `learning_rate` is the documented keyword;
    # `lr` is only a deprecated alias.
    optimizer = tfa.optimizers.LAMB(learning_rate=5e-4 * np.sqrt(num_workers))

    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [8]:
%%px
def get_dataset(x, y, batch_size=batch_size, shuffle=True):
    """Build an endless, batched tf.data pipeline from in-memory inputs/targets.

    Args:
        x: structure of model inputs, sliced along the first axis (presumably
            the (input_ids, token_type_ids, attention_mask) arrays produced by
            du.create_inputs_targets — confirm).
        y: structure of targets, sliced the same way (presumably the
            (start, end) answer positions — confirm).
        batch_size: global batch size. The default is bound to the notebook's
            `batch_size` at the moment this cell is executed.
        shuffle: if True (the default and the original behavior), shuffle the
            examples with a fixed seed before repeating. Pass False to keep
            the original order, e.g. for evaluation.

    Returns:
        A tf.data.Dataset yielding (inputs, targets) batches. It repeats
        forever, so callers must bound iteration (steps / steps_per_epoch /
        validation_steps / take).
    """
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(x),
        tf.data.Dataset.from_tensor_slices(y),
    ))
    if shuffle:
        # Fixed seed: every engine produces the same order (see the identical
        # sample batches printed by both engines). Shuffling before repeat()
        # keeps each pass over the data a complete permutation.
        dataset = dataset.shuffle(2048, seed=42)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    # Prepare the next batches while the current one is being consumed.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

# Sanity check on one training batch: the shapes of the three input tensors,
# followed by the start- and end-position target tensors.
for batch_inputs, batch_targets in get_dataset(x_train, y_train).take(1):
    print([tensor.shape for tensor in batch_inputs])
    for target in batch_targets:
        print(target)

[stdout:0] 
[TensorShape([32, 384]), TensorShape([32, 384]), TensorShape([32, 384])]
tf.Tensor(
[  3  82  14 134 124  23   3 134  32 124  23  80  48  10 229 105 104 114
  84  37   7  86  63   6  90  11  71  75   1  98  43  32], shape=(32,), dtype=int64)
tf.Tensor(
[  7  91  31 139 132  29   8 137  34 125  27 101  64  14 231 109 105 116
 105  53   8  88  71   6  92  42  74  76   4 105  45  40], shape=(32,), dtype=int64)
[stdout:1] 
[TensorShape([32, 384]), TensorShape([32, 384]), TensorShape([32, 384])]
tf.Tensor(
[  3  82  14 134 124  23   3 134  32 124  23  80  48  10 229 105 104 114
  84  37   7  86  63   6  90  11  71  75   1  98  43  32], shape=(32,), dtype=int64)
tf.Tensor(
[  7  91  31 139 132  29   8 137  34 125  27 101  64  14 231 109 105 116
 105  53   8  88  71   6  92  42  74  76   4 105  45  40], shape=(32,), dtype=int64)


In [9]:
# Disabled alternative input pipeline: builds the same (inputs, targets)
# batches from a Python generator via tf.data.Dataset.from_generator instead of
# from_tensor_slices as in get_dataset. Everything in this cell, including the
# %%px magic, is commented out, so running it does nothing.
# NOTE(review): output_types declares the targets as tf.int32, while the
# pipeline actually used yields int64 targets (see the printed sample batch) —
# confirm the dtypes before re-enabling.
# %%px
# def dataset_generator(x, y):
#     for i in range(x[0].shape[0]):
#         yield ((x[0][i], x[1][i], x[2][i]),
#                (y[0][i], y[1][i]))

# def get_generator(x, y, batch_size=batch_size, output_shapes=(((max_len,), (max_len,), (max_len,)),((), ()))):
#     dataset = tf.data.Dataset.from_generator(lambda: dataset_generator(x, y),
#                                              output_types=((tf.int32, tf.int32, tf.int32),(tf.int32, tf.int32)),
#                                              output_shapes=output_shapes)
#     dataset = dataset.shuffle(2048, seed=42)
#     dataset = dataset.repeat()
#     dataset = dataset.batch(batch_size)
#     return dataset

# # batch shapes
# for X, Y in get_generator(x_train, y_train).take(1):
#     print([i.shape for i in X])
#     [print(i) for i in Y]

In [10]:
%%px
# Short run: one epoch of only 50 training steps (far less than a full pass
# over the training data), followed by validation on
# len(y_eval[0]) // batch_size batches of the evaluation pipeline.
train_dataset = get_dataset(x_train, y_train)
eval_dataset = get_dataset(x_eval, y_eval)
eval_steps = len(y_eval[0]) // batch_size

fit = model.fit(
    train_dataset,
    epochs=1,
    steps_per_epoch=50,
    validation_data=eval_dataset,
    validation_steps=eval_steps,
)



### Evaluation

In [11]:
%%px
# Sharded evaluation over len(y_eval[0]) // batch_size batches of the
# (repeating) evaluation pipeline. Both engines report the same aggregated
# [loss, start loss, end loss, start accuracy, end accuracy] values.
eval_steps = len(y_eval[0]) // batch_size
model.evaluate(get_dataset(x_eval, y_eval), steps=eval_steps)



[0;31mOut[0:8]: [0m
[4.731622695922852,
 2.4260923862457275,
 2.3055310249328613,
 0.39334240555763245,
 0.41731366515159607]

[0;31mOut[1:8]: [0m
[4.731622695922852,
 2.4260923862457275,
 2.3055310249328613,
 0.39334240555763245,
 0.41731366515159607]

In [12]:
# Shut down the engines started at the top of the notebook.
%ipcluster stop