##### Copyright 2019 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tuning a BERT model

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/fine_tune_bert.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
  <td>
    <a href="https://tfhub.dev/google/collections/bert"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />See TF Hub model</a>
  </td>
</table>

This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).

You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve GLUE tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you only need to fine-tune a model, the TF Hub tutorial is a good starting point.

On the other hand, if you're interested in deeper customization, follow this tutorial. It performs many steps manually, so you can learn how to customize the workflow from data preprocessing through training, exporting, and saving the model.

## Setup

### Install pip packages

Start by installing the TensorFlow Text and Model Garden pip packages.

*  `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, install `tf-models-nightly`, the Model Garden package that is built automatically every day.
*  pip will install all models and dependencies automatically.

In [1]:
!pip install -q opencv-python

In [2]:
!pip install -q -U "tensorflow-text==2.11.*"

In [3]:
!pip install -q tf-models-official

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastapi 0.104.1 requires typing-extensions>=4.8.0, but you have typing-extensions 4.5.0 which is incompatible.
pydantic 2.5.2 requires typing-extensions>=4.6.1, but you have typing-extensions 4.5.0 which is incompatible.
pydantic-core 2.14.5 requires typing-extensions!=4.7.0,>=4.6.0, but you have typing-extensions 4.5.0 which is incompatible.

### Import libraries

In [4]:
import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_hub as hub
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

2023-11-28 14:39:59.232399: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-28 14:39:59.257163: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-28 14:39:59.257633: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Resources

The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:

In [5]:
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12"
tf.io.gfile.listdir(gs_folder_bert)

2023-11-28 15:10:27.770299: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


['bert_config.json',
 'bert_model.ckpt.data-00000-of-00001',
 'bert_model.ckpt.index',
 'vocab.txt']

## Load and preprocess the dataset

This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).

The raw dataset cannot be fed directly into the BERT model. The following sections handle the necessary preprocessing.

### Get the dataset from TensorFlow Datasets

The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:

*   Number of labels: 2
*   Size of training dataset: 3668
*   Size of evaluation dataset: 408
*   Maximum sequence length of training and evaluation dataset: 128

Begin by loading the MRPC dataset from TFDS:

In [6]:
batch_size = 32
glue, info = tfds.load('glue/mrpc',
                       with_info=True,
                       batch_size=batch_size)

Downloading and preparing dataset 1.43 MiB (download: 1.43 MiB, generated: 1.74 MiB, total: 3.17 MiB) to /home/zgao/tensorflow_datasets/glue/mrpc/2.0.0...
Dataset glue downloaded and prepared to /home/zgao/tensorflow_datasets/glue/mrpc/2.0.0. Subsequent calls will reuse this data.


2023-11-28 15:10:46.073362: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-11-28 15:10:46.149616: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [7]:
glue

{Split('train'): <_PrefetchDataset element_spec={'idx': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'sentence1': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'sentence2': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>,
 Split('validation'): <_PrefetchDataset element_spec={'idx': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'sentence1': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'sentence2': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>,
 Split('test'): <_PrefetchDataset element_spec={'idx': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'sentence1': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'sentence2': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>}

The `info` object describes the dataset and its features:

In [8]:
info.features

FeaturesDict({
    'idx': int32,
    'label': ClassLabel(shape=(), dtype=int64, num_classes=2),
    'sentence1': Text(shape=(), dtype=string),
    'sentence2': Text(shape=(), dtype=string),
})

The two classes are:

In [9]:
info.features['label'].names

['not_equivalent', 'equivalent']

Here is one example from the training set:

In [10]:
example_batch = next(iter(glue['train']))

for key, value in example_batch.items():
  print(f"{key:9s}: {value[0].numpy()}")

idx      : 1680
label    : 0
sentence1: b'The identical rovers will act as robotic geologists , searching for evidence of past water .'
sentence2: b'The rovers act as robotic geologists , moving on six wheels .'


2023-11-28 15:10:57.966894: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


### Preprocess the data

The keys `"sentence1"` and `"sentence2"` in the GLUE MRPC dataset contain two input sentences for each example.

Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:

1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.
2. Then, the _indices_ need to be packed into the format that the model expects.

#### The BERT tokenizer

To fine-tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.

The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:

In [11]:
tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(
    vocab_file=os.path.join(gs_folder_bert, "vocab.txt"),
    lower_case=True)

Let's tokenize a test sentence:

In [12]:
tokens = tokenizer(tf.constant(["Hello TensorFlow!"]))
tokens

<tf.RaggedTensor [[[7592], [23435, 12314], [999]]]>

Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides.
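
To make the idea of splitting text into word pieces more concrete, here is a minimal plain-Python sketch of greedy longest-match-first WordPiece splitting for a single lowercase word. It is an illustration only, not the `FastWordpieceBertTokenizer` implementation: the function name `wordpiece_tokenize`, the toy vocabulary, and its indices are all hypothetical and do not correspond to the real `vocab.txt`.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Split one lowercase word into word pieces, longest match first."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # Pieces that continue a word carry the "##" prefix.
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matched: the whole word maps to the unknown token.
            return [unk]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary (piece -> index); not the real BERT vocabulary.
toy_vocab = {"hello": 0, "tensor": 1, "##flow": 2, "!": 3, "[UNK]": 4}

for word in ["hello", "tensorflow", "!"]:
    pieces = wordpiece_tokenize(word, toy_vocab)
    print(word, "->", pieces, "->", [toy_vocab[p] for p in pieces])
```

As in the tokenizer output above, a word that is not in the vocabulary as a whole comes back as more than one index.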

#### Pack the inputs

TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. The `tfm.nlp.layers.BertPackInputs` layer handles the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.

`tfm.nlp.layers.BertPackInputs` concatenates the two input sentences of each example in the MRPC dataset. The packed input is expected to start with a `[CLS]` ("this is a classification problem") token, and each sentence should end with a `[SEP]` ("separator") token.

Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor needs the indices of the tokenizer's special tokens, which the tokenizer provides as a dictionary:

In [13]:
special = tokenizer.get_special_tokens_dict()
special

{'vocab_size': 30522,
 'start_of_sequence_id': 101,
 'end_of_segment_id': 102,
 'padding_id': 0,
 'mask_id': 103}

In [14]:
max_seq_length = 128

packer = tfm.nlp.layers.BertPackInputs(
    seq_length=max_seq_length,
    special_tokens_dict = tokenizer.get_special_tokens_dict())

The `packer` takes a list of tokenized sentences as input. For example:

In [15]:
sentences1 = ["hello tensorflow"]
tok1 = tokenizer(sentences1)
tok1

<tf.RaggedTensor [[[7592], [23435, 12314]]]>

In [16]:
sentences2 = ["goodbye tensorflow"]
tok2 = tokenizer(sentences2)
tok2

<tf.RaggedTensor [[[9119], [23435, 12314]]]>

Then, it returns a dictionary containing three outputs:

- `input_word_ids`: The tokenized sentences packed together.
- `input_mask`: The mask indicating which locations are valid in the other outputs.
- `input_type_ids`: Indicating which sentence each token belongs to.

In [17]:
packed = packer([tok1, tok2])

for key, tensor in packed.items():
  print(f"{key:15s}: {tensor[:, :12]}")

input_word_ids : [[  101  7592 23435 12314   102  9119 23435 12314   102     0     0     0]]
input_mask     : [[1 1 1 1 1 1 1 1 1 0 0 0]]
input_type_ids : [[0 0 0 0 0 1 1 1 1 0 0 0]]
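
The packing logic can be sketched in plain Python to show how the three outputs relate to each other. This is a simplified illustration, not the `BertPackInputs` implementation: the helper `pack_pair` is hypothetical, it works on flat lists of token IDs, and it raises an error instead of truncating when a pair is too long. The special-token IDs are the ones reported by `get_special_tokens_dict()` above.

```python
# Special-token IDs taken from the tokenizer's special-tokens dictionary.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pack_pair(ids_a, ids_b, seq_length):
    """Pack two token-ID lists as [CLS] a [SEP] b [SEP] plus padding."""
    word_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment 0 covers [CLS], sentence a and its [SEP];
    # segment 1 covers sentence b and its [SEP].
    type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    if len(word_ids) > seq_length:
        # Assumption of this sketch: no truncation strategy is implemented.
        raise ValueError("pair is longer than seq_length")
    padding = seq_length - len(word_ids)
    return {
        "input_word_ids": word_ids + [PAD_ID] * padding,
        "input_mask": [1] * len(word_ids) + [0] * padding,
        "input_type_ids": type_ids + [0] * padding,
    }


# Flattened token IDs of "hello tensorflow" and "goodbye tensorflow".
packed_sketch = pack_pair([7592, 23435, 12314], [9119, 23435, 12314],
                          seq_length=12)
for key, values in packed_sketch.items():
    print(f"{key:15s}: {values}")
```

With `seq_length=12` the sketch reproduces the first 12 columns printed by the real layer.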


#### Put it all together

Combine these two parts into a `tf.keras.layers.Layer` that can be attached to your model:

In [18]:
class BertInputProcessor(tf.keras.layers.Layer):
  def __init__(self, tokenizer, packer):
    super().__init__()
    self.tokenizer = tokenizer
    self.packer = packer

  def call(self, inputs):
    tok1 = self.tokenizer(inputs['sentence1'])
    tok2 = self.tokenizer(inputs['sentence2'])

    packed = self.packer([tok1, tok2])

    if 'label' in inputs:
      return packed, inputs['label']
    else:
      return packed

For now, apply it to the dataset using `Dataset.map`, since the dataset you loaded from TFDS is a `tf.data.Dataset` object:

In [19]:
bert_inputs_processor = BertInputProcessor(tokenizer, packer)

In [20]:
glue_train = glue['train'].map(bert_inputs_processor).prefetch(1)

Here is an example batch from the processed dataset:

In [21]:
example_inputs, example_labels = next(iter(glue_train))

2023-11-28 15:11:25.313238: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [22]:
example_inputs

{'input_word_ids': <tf.Tensor: shape=(32, 128), dtype=int32, numpy=
 array([[ 101, 1996, 7235, ...,    0,    0,    0],
        [ 101, 2625, 2084, ...,    0,    0,    0],
        [ 101, 6804, 1011, ...,    0,    0,    0],
        ...,
        [ 101, 2021, 2049, ...,    0,    0,    0],
        [ 101, 2274, 2062, ...,    0,    0,    0],
        [ 101, 2043, 1037, ...,    0,    0,    0]], dtype=int32)>,
 'input_mask': <tf.Tensor: shape=(32, 128), dtype=int32, numpy=
 array([[1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        ...,
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>,
 'input_type_ids': <tf.Tensor: shape=(32, 128), dtype=int32, numpy=
 array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}

In [23]:
example_labels

<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 1, 0, 1])>

In [24]:
for key, value in example_inputs.items():
  print(f'{key:15s} shape: {value.shape}')

print(f'{"labels":15s} shape: {example_labels.shape}')

input_word_ids  shape: (32, 128)
input_mask      shape: (32, 128)
input_type_ids  shape: (32, 128)
labels          shape: (32,)


The `input_word_ids` contain the token IDs:

In [25]:
plt.pcolormesh(example_inputs['input_word_ids'])

<matplotlib.collections.QuadMesh at 0x7f49041fb4c0>

The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding.

In [26]:
plt.pcolormesh(example_inputs['input_mask'])

<matplotlib.collections.QuadMesh at 0x7f4904126d90>

The "input type" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of.

In [27]:
plt.pcolormesh(example_inputs['input_type_ids'])

<matplotlib.collections.QuadMesh at 0x7f49043a8520>

Apply the same preprocessing to the validation and test subsets of the GLUE MRPC dataset:

In [28]:
glue_validation = glue['validation'].map(bert_inputs_processor).prefetch(1)
glue_test = glue['test'].map(bert_inputs_processor).prefetch(1)

## Build, train and export the model

Now that you have formatted the data as expected, you can start working on building and training the model.

### Build the model


The first step is to download the configuration file for the pre-trained BERT model and parse it into `config_dict`:


In [29]:
import json

bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
config_dict

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 30522}

In [30]:
encoder_config = tfm.nlp.encoders.EncoderConfig({
    'type':'bert',
    'bert': config_dict
})

In [31]:
bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)
bert_encoder

<official.nlp.modeling.networks.bert_encoder.BertEncoder at 0x7f48f02ac370>

The configuration file defines the core BERT encoder from the Model Garden. Wrapping the encoder in `tfm.nlp.models.BertClassifier` gives a Keras model that predicts `num_classes` outputs from inputs with a maximum sequence length of `max_seq_length`.

In [32]:
bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)

Run it on a test batch of data and inspect the first 10 examples from the training set. The output is the logits for the two classes:

In [33]:
bert_classifier(
    example_inputs, training=True).numpy()[:10]

array([[-0.43970585, -0.03515399],
       [ 0.08846277, -0.52144897],
       [-0.6429028 , -0.3931797 ],
       [-0.60580117, -0.63702095],
       [-0.3618883 , -0.05291632],
       [-0.6457471 ,  0.09466767],
       [-0.18039268,  0.22603166],
       [-0.60778785, -0.58083296],
       [ 0.5504377 , -0.3834118 ],
       [ 0.3804643 ,  0.48679772]], dtype=float32)
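
Logits are unnormalized scores. As a rough illustration, the plain-Python sketch below converts a row of logits into class probabilities with a softmax and picks the most likely label. The helpers `softmax` and `predict` are hypothetical names for this example; the label names come from `info.features['label'].names`, and the two sample rows are copied from the output above. Because the classification head has not been trained yet, these predictions carry no meaning at this point.

```python
import math

LABEL_NAMES = ["not_equivalent", "equivalent"]


def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return the most likely label name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABEL_NAMES[best], probs[best]


sample_logits = [[-0.43970585, -0.03515399],
                 [0.08846277, -0.52144897]]
for row in sample_logits:
    name, prob = predict(row)
    print(f"{row} -> {name} (p={prob:.3f})")
```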

In [39]:
!pip install pydot



The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.

If you inspect the encoder, notice the stack of `Transformer` layers connected to those same three inputs:

In [40]:
tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


### Restore the encoder weights

When built, the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:

In [41]:
checkpoint = tf.train.Checkpoint(encoder=bert_encoder)
checkpoint.read(
    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f488441dc70>

Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details.

### Set up the optimizer

BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).
It also employs a learning rate schedule that first warms up from 0 and then decays to 0:

In [42]:
# Set up epochs and steps
epochs = 5
batch_size = 32
eval_batch_size = 32

train_data_size = info.splits['train'].num_examples
steps_per_epoch = int(train_data_size / batch_size)
num_train_steps = steps_per_epoch * epochs
warmup_steps = int(0.1 * num_train_steps)
initial_learning_rate=2e-5
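
As a quick check of the arithmetic, the following sketch recomputes the step counts in plain Python, assuming the training set size of 3668 listed earlier. It also computes the number of batches Keras iterates over per epoch when the final partial batch is kept, which is why a progress bar can show one more step than `steps_per_epoch`. The variable names ending in `_sketch` are hypothetical and only used here.

```python
import math

train_size_sketch = 3668   # size of the MRPC training split
batch_size_sketch = 32
epochs_sketch = 5

# Same formulas as in the cell above: partial batches are rounded down.
steps_per_epoch_sketch = int(train_size_sketch / batch_size_sketch)
num_train_steps_sketch = steps_per_epoch_sketch * epochs_sketch
warmup_steps_sketch = int(0.1 * num_train_steps_sketch)

# Number of batches per epoch if the last, smaller batch is kept.
batches_per_epoch_sketch = math.ceil(train_size_sketch / batch_size_sketch)

print("steps_per_epoch  :", steps_per_epoch_sketch)
print("num_train_steps  :", num_train_steps_sketch)
print("warmup_steps     :", warmup_steps_sketch)
print("batches per epoch:", batches_per_epoch_sketch)
```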

First, define a linear decay from `initial_learning_rate` to zero over `num_train_steps`:

In [43]:
linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_learning_rate,
    end_learning_rate=0,
    decay_steps=num_train_steps)

Warmup to that value over `warmup_steps`:

In [44]:
warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(
    warmup_learning_rate = 0,
    after_warmup_lr_sched = linear_decay,
    warmup_steps = warmup_steps
)

The overall schedule looks like this:

In [45]:
x = tf.linspace(0, num_train_steps, 1001)
y = [warmup_schedule(xi) for xi in x]
plt.plot(x,y)
plt.xlabel('Train step')
plt.ylabel('Learning rate')

Text(0, 0.5, 'Learning rate')
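
For readers who want to see the shape of the schedule without TensorFlow, here is a plain-Python approximation. It rests on two assumptions that you should verify against the library documentation: that `PolynomialDecay` with its default power of 1 is a straight line clipped at `decay_steps`, and that `LinearWarmup` ramps linearly from the warmup learning rate to the value of the decay schedule at `warmup_steps`. The function names are hypothetical.

```python
def linear_decay_lr(step, initial_lr, decay_steps):
    """Linear decay from initial_lr to 0, held at 0 after decay_steps."""
    step = min(step, decay_steps)
    return initial_lr * (1.0 - step / decay_steps)


def warmup_then_decay_lr(step, initial_lr, decay_steps, warmup_steps):
    """Linear warmup from 0, then follow the linear decay schedule."""
    if step < warmup_steps:
        # Assumed warmup target: the decay schedule's value at warmup_steps.
        target = linear_decay_lr(warmup_steps, initial_lr, decay_steps)
        return target * step / warmup_steps
    return linear_decay_lr(step, initial_lr, decay_steps)


# Values matching the setup above: 570 training steps, 57 warmup steps.
for step in [0, 28, 57, 285, 570]:
    lr = warmup_then_decay_lr(step, initial_lr=2e-5,
                              decay_steps=570, warmup_steps=57)
    print(f"step {step:3d}: learning rate {lr:.2e}")
```

Under these assumptions the learning rate peaks at the end of warmup, slightly below `initial_learning_rate`, because the decay line has already dropped by the time warmup finishes.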

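Instantiate the optimizer with that schedule. Note that the cell below uses plain `tf.keras.optimizers.experimental.Adam`; to apply weight decay as described above, substitute `tf.keras.optimizers.experimental.AdamW`: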

In [46]:
optimizer = tf.keras.optimizers.experimental.Adam(
    learning_rate = warmup_schedule)

### Train the model

Set the metric as accuracy and the loss as sparse categorical cross-entropy. Then, compile and train the BERT classifier:

In [47]:
metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bert_classifier.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics)

In [48]:
bert_classifier.evaluate(glue_validation)



[0.7043040990829468, 0.38235294818878174]

[0.7620382308959961, 0.31617647409439087]

In [None]:
bert_classifier.fit(
    glue_train,
    validation_data=glue_validation,
    # The datasets are already batched (32 examples per batch),
    # so `batch_size` is not passed to `fit`.
    epochs=epochs)

Epoch 1/5
  8/115 [=>............................] - ETA: 9:01 - loss: 0.7399 - accuracy: 0.5391

  4/115 [>.............................] - ETA: 8:29 - loss: 0.8523 - accuracy: 0.4688

  5/115 [>.............................] - ETA: 8:21 - loss: 0.8298 - accuracy: 0.5000

  6/115 [>.............................] - ETA: 8:13 - loss: 0.8307 - accuracy: 0.4948

  7/115 [>.............................] - ETA: 8:12 - loss: 0.8276 - accuracy: 0.4911

  8/115 [=>............................] - ETA: 8:07 - loss: 0.8064 - accuracy: 0.5117

  9/115 [=>............................] - ETA: 8:03 - loss: 0.8327 - accuracy: 0.4965

 10/115 [=>............................] - ETA: 7:58 - loss: 0.8326 - accuracy: 0.4875

 11/115 [=>............................] - ETA: 7:53 - loss: 0.8280 - accuracy: 0.4886

 12/115 [==>...........................] - ETA: 7:49 - loss: 0.8345 - accuracy: 0.4766

 13/115 [==>...........................] - ETA: 7:43 - loss: 0.8241 - accuracy: 0.4856

 14/115 [==>...........................] - ETA: 7:38 - loss: 0.8279 - accuracy: 0.4799

 15/115 [==>...........................] - ETA: 7:32 - loss: 0.8261 - accuracy: 0.4812

 16/115 [===>..........................] - ETA: 7:27 - loss: 0.8268 - accuracy: 0.4805

 17/115 [===>..........................] - ETA: 7:21 - loss: 0.8263 - accuracy: 0.4798

 18/115 [===>..........................] - ETA: 7:16 - loss: 0.8292 - accuracy: 0.4826

 19/115 [===>..........................] - ETA: 7:11 - loss: 0.8301 - accuracy: 0.4852

 20/115 [====>.........................] - ETA: 7:06 - loss: 0.8272 - accuracy: 0.4859

 21/115 [====>.........................] - ETA: 7:00 - loss: 0.8208 - accuracy: 0.4866

 22/115 [====>.........................] - ETA: 6:55 - loss: 0.8220 - accuracy: 0.4872

 23/115 [=====>........................] - ETA: 6:51 - loss: 0.8151 - accuracy: 0.4905

 24/115 [=====>........................] - ETA: 6:45 - loss: 0.8123 - accuracy: 0.4922

 25/115 [=====>........................] - ETA: 6:40 - loss: 0.8074 - accuracy: 0.5000

 26/115 [=====>........................] - ETA: 6:35 - loss: 0.8039 - accuracy: 0.5012





















































































































































































Epoch 2/5


  1/115 [..............................] - ETA: 8:02 - loss: 0.5938 - accuracy: 0.7188

  2/115 [..............................] - ETA: 7:46 - loss: 0.5683 - accuracy: 0.7031

  3/115 [..............................] - ETA: 7:51 - loss: 0.5639 - accuracy: 0.7188

  4/115 [>.............................] - ETA: 7:46 - loss: 0.5822 - accuracy: 0.7109

  5/115 [>.............................] - ETA: 7:41 - loss: 0.5795 - accuracy: 0.7125

  6/115 [>.............................] - ETA: 7:40 - loss: 0.5732 - accuracy: 0.7083

  7/115 [>.............................] - ETA: 7:36 - loss: 0.5677 - accuracy: 0.7098

  8/115 [=>............................] - ETA: 7:34 - loss: 0.5644 - accuracy: 0.7109

  9/115 [=>............................] - ETA: 7:29 - loss: 0.5666 - accuracy: 0.7118

 10/115 [=>............................] - ETA: 7:25 - loss: 0.5632 - accuracy: 0.7094

 11/115 [=>............................] - ETA: 7:20 - loss: 0.5556 - accuracy: 0.7102

 12/115 [==>...........................] - ETA: 7:16 - loss: 0.5600 - accuracy: 0.7135

 13/115 [==>...........................] - ETA: 7:11 - loss: 0.5661 - accuracy: 0.7067

 14/115 [==>...........................] - ETA: 7:06 - loss: 0.5602 - accuracy: 0.7121

 15/115 [==>...........................] - ETA: 7:02 - loss: 0.5681 - accuracy: 0.7042

 16/115 [===>..........................] - ETA: 6:58 - loss: 0.5759 - accuracy: 0.6973

 17/115 [===>..........................] - ETA: 6:54 - loss: 0.5732 - accuracy: 0.7022

 18/115 [===>..........................] - ETA: 6:50 - loss: 0.5705 - accuracy: 0.7014

 19/115 [===>..........................] - ETA: 6:45 - loss: 0.5683 - accuracy: 0.7039

 20/115 [====>.........................] - ETA: 6:41 - loss: 0.5687 - accuracy: 0.7016

 21/115 [====>.........................] - ETA: 6:37 - loss: 0.5720 - accuracy: 0.7009

 22/115 [====>.........................] - ETA: 6:33 - loss: 0.5781 - accuracy: 0.6932

 23/115 [=====>........................] - ETA: 6:29 - loss: 0.5752 - accuracy: 0.6970

 24/115 [=====>........................] - ETA: 6:24 - loss: 0.5678 - accuracy: 0.7057

 25/115 [=====>........................] - ETA: 6:20 - loss: 0.5644 - accuracy: 0.7100

 26/115 [=====>........................] - ETA: 6:17 - loss: 0.5628 - accuracy: 0.7091





















































































































































































Epoch 3/5


  1/115 [..............................] - ETA: 8:09 - loss: 0.3812 - accuracy: 0.8750

  2/115 [..............................] - ETA: 7:50 - loss: 0.3735 - accuracy: 0.8438

  3/115 [..............................] - ETA: 7:46 - loss: 0.4260 - accuracy: 0.8125

  4/115 [>.............................] - ETA: 7:43 - loss: 0.4509 - accuracy: 0.8047

  5/115 [>.............................] - ETA: 7:39 - loss: 0.4470 - accuracy: 0.8188

  6/115 [>.............................] - ETA: 7:39 - loss: 0.4466 - accuracy: 0.8177

  7/115 [>.............................] - ETA: 7:35 - loss: 0.4388 - accuracy: 0.8170

  8/115 [=>............................] - ETA: 7:32 - loss: 0.4301 - accuracy: 0.8242

  9/115 [=>............................] - ETA: 7:28 - loss: 0.4199 - accuracy: 0.8299

 10/115 [=>............................] - ETA: 7:24 - loss: 0.4179 - accuracy: 0.8281

 11/115 [=>............................] - ETA: 7:19 - loss: 0.4153 - accuracy: 0.8267

 12/115 [==>...........................] - ETA: 7:14 - loss: 0.4151 - accuracy: 0.8255

 13/115 [==>...........................] - ETA: 7:10 - loss: 0.4149 - accuracy: 0.8245

 14/115 [==>...........................] - ETA: 7:05 - loss: 0.4034 - accuracy: 0.8326

 15/115 [==>...........................] - ETA: 7:01 - loss: 0.4054 - accuracy: 0.8292

 16/115 [===>..........................] - ETA: 6:57 - loss: 0.4151 - accuracy: 0.8242

 17/115 [===>..........................] - ETA: 6:53 - loss: 0.4094 - accuracy: 0.8254

 18/115 [===>..........................] - ETA: 6:49 - loss: 0.4055 - accuracy: 0.8247

 19/115 [===>..........................] - ETA: 6:45 - loss: 0.4030 - accuracy: 0.8273

 20/115 [====>.........................] - ETA: 6:40 - loss: 0.4035 - accuracy: 0.8250

 21/115 [====>.........................] - ETA: 6:35 - loss: 0.4055 - accuracy: 0.8229

 22/115 [====>.........................] - ETA: 6:32 - loss: 0.4131 - accuracy: 0.8210

 23/115 [=====>........................] - ETA: 6:27 - loss: 0.4117 - accuracy: 0.8193

 24/115 [=====>........................] - ETA: 6:23 - loss: 0.4070 - accuracy: 0.8216

 25/115 [=====>........................] - ETA: 6:19 - loss: 0.4049 - accuracy: 0.8225

 26/115 [=====>........................] - ETA: 6:15 - loss: 0.4032 - accuracy: 0.8209





















































































































































































Epoch 4/5


  1/115 [..............................] - ETA: 8:00 - loss: 0.1989 - accuracy: 0.9375

  2/115 [..............................] - ETA: 7:54 - loss: 0.2917 - accuracy: 0.9062

  3/115 [..............................] - ETA: 7:50 - loss: 0.3592 - accuracy: 0.8542

  4/115 [>.............................] - ETA: 7:44 - loss: 0.3896 - accuracy: 0.8125

  5/115 [>.............................] - ETA: 7:39 - loss: 0.3875 - accuracy: 0.8125

  6/115 [>.............................] - ETA: 7:38 - loss: 0.3687 - accuracy: 0.8333

  7/115 [>.............................] - ETA: 7:33 - loss: 0.3629 - accuracy: 0.8304

  8/115 [=>............................] - ETA: 7:30 - loss: 0.3451 - accuracy: 0.8398

  9/115 [=>............................] - ETA: 7:26 - loss: 0.3213 - accuracy: 0.8576

 10/115 [=>............................] - ETA: 7:22 - loss: 0.3093 - accuracy: 0.8625

 11/115 [=>............................] - ETA: 7:18 - loss: 0.3073 - accuracy: 0.8608

 12/115 [==>...........................] - ETA: 7:13 - loss: 0.3007 - accuracy: 0.8646

 13/115 [==>...........................] - ETA: 7:10 - loss: 0.3050 - accuracy: 0.8630

 14/115 [==>...........................] - ETA: 7:06 - loss: 0.2986 - accuracy: 0.8683

 15/115 [==>...........................] - ETA: 7:01 - loss: 0.3080 - accuracy: 0.8646

 16/115 [===>..........................] - ETA: 6:58 - loss: 0.3069 - accuracy: 0.8652

 17/115 [===>..........................] - ETA: 6:53 - loss: 0.2967 - accuracy: 0.8695

 18/115 [===>..........................] - ETA: 6:49 - loss: 0.2963 - accuracy: 0.8663

 19/115 [===>..........................] - ETA: 6:45 - loss: 0.2952 - accuracy: 0.8668

 20/115 [====>.........................] - ETA: 6:40 - loss: 0.3022 - accuracy: 0.8672

 21/115 [====>.........................] - ETA: 6:36 - loss: 0.2986 - accuracy: 0.8690

 22/115 [====>.........................] - ETA: 6:32 - loss: 0.3013 - accuracy: 0.8679

 23/115 [=====>........................] - ETA: 6:27 - loss: 0.2989 - accuracy: 0.8709

 24/115 [=====>........................] - ETA: 6:23 - loss: 0.2949 - accuracy: 0.8737

 25/115 [=====>........................] - ETA: 6:19 - loss: 0.2927 - accuracy: 0.8775

 26/115 [=====>........................] - ETA: 6:14 - loss: 0.2976 - accuracy: 0.8762





















































































































































































Epoch 5/5


  1/115 [..............................] - ETA: 7:50 - loss: 0.3780 - accuracy: 0.9062

  2/115 [..............................] - ETA: 8:15 - loss: 0.2541 - accuracy: 0.9375

  3/115 [..............................] - ETA: 8:05 - loss: 0.2536 - accuracy: 0.9271

  4/115 [>.............................] - ETA: 8:00 - loss: 0.2598 - accuracy: 0.9219

  5/115 [>.............................] - ETA: 7:53 - loss: 0.2467 - accuracy: 0.9187

  6/115 [>.............................] - ETA: 7:48 - loss: 0.2275 - accuracy: 0.9167

  7/115 [>.............................] - ETA: 7:44 - loss: 0.2229 - accuracy: 0.9152

  8/115 [=>............................] - ETA: 7:40 - loss: 0.2091 - accuracy: 0.9180

  9/115 [=>............................] - ETA: 7:34 - loss: 0.1902 - accuracy: 0.9271

 10/115 [=>............................] - ETA: 7:28 - loss: 0.1832 - accuracy: 0.9250

 11/115 [=>............................] - ETA: 7:25 - loss: 0.1854 - accuracy: 0.9233

 12/115 [==>...........................] - ETA: 7:19 - loss: 0.1795 - accuracy: 0.9245

 13/115 [==>...........................] - ETA: 7:15 - loss: 0.1773 - accuracy: 0.9279

 14/115 [==>...........................] - ETA: 7:10 - loss: 0.1740 - accuracy: 0.9308

 15/115 [==>...........................] - ETA: 7:04 - loss: 0.1874 - accuracy: 0.9208

 16/115 [===>..........................] - ETA: 7:00 - loss: 0.1890 - accuracy: 0.9199

 17/115 [===>..........................] - ETA: 6:56 - loss: 0.1811 - accuracy: 0.9246

 18/115 [===>..........................] - ETA: 6:52 - loss: 0.1826 - accuracy: 0.9236

 19/115 [===>..........................] - ETA: 6:48 - loss: 0.1805 - accuracy: 0.9243

 20/115 [====>.........................] - ETA: 6:44 - loss: 0.1763 - accuracy: 0.9266

 21/115 [====>.........................] - ETA: 6:39 - loss: 0.1773 - accuracy: 0.9256

 22/115 [====>.........................] - ETA: 6:35 - loss: 0.1923 - accuracy: 0.9219

 23/115 [=====>........................] - ETA: 6:31 - loss: 0.1895 - accuracy: 0.9239

 24/115 [=====>........................] - ETA: 6:26 - loss: 0.1873 - accuracy: 0.9258

 25/115 [=====>........................] - ETA: 6:22 - loss: 0.1829 - accuracy: 0.9287

 26/115 [=====>........................] - ETA: 6:17 - loss: 0.1803 - accuracy: 0.9303





















































































































































































<keras.src.callbacks.History at 0x7fe64c5aa4f0>

Now run the fine-tuned model on a custom example to see that it works.

Start by encoding some sentence pairs:

In [45]:
my_examples = {
        'sentence1':[
            'The rain in Spain falls mainly on the plain.',
            'Look I fine tuned BERT.'],
        'sentence2':[
            'It mostly rains on the flat lands of Spain.',
            'Is it working? This does not match.']
    }

The model should report class `1` "match" for the first example and class `0` "no-match" for the second:

In [46]:
ex_packed = bert_inputs_processor(my_examples)
my_logits = bert_classifier(ex_packed, training=False)

# Take the argmax over the class axis (axis 1) to get one class ID per example.
result_cls_ids = tf.argmax(my_logits, axis=1)
result_cls_ids

<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>

In [47]:
tf.gather(tf.constant(info.features['label'].names), result_cls_ids)

<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'equivalent', b'not_equivalent'], dtype=object)>
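
The class ID for each example comes from taking the argmax of the logits over the class axis. The NumPy sketch below illustrates this on the two logit rows this model produces for `my_examples` (printed later in this tutorial), plus a third, hypothetical row added only to make the batch non-square. It also shows how a softmax turns logits into per-class probabilities. This is an illustration, not part of the tutorial's pipeline.

```python
import numpy as np

# Rows are examples, columns are classes (0 = not_equivalent, 1 = equivalent).
# The first two rows are the logits for `my_examples`; the third is hypothetical.
logits = np.array([
    [-3.7263763, 3.4041796],
    [1.8957428, -0.5120322],
    [0.25, 2.5],
])

# Reducing over the last axis yields one class ID per example.
class_ids = np.argmax(logits, axis=-1)

# Reducing over axis 0 instead yields one *example index per class*,
# which is not what a classifier should report.
per_class = np.argmax(logits, axis=0)

# A numerically stable softmax converts each row of logits to probabilities.
shifted = logits - logits.max(axis=-1, keepdims=True)
probabilities = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

print(class_ids.shape, per_class.shape)
print(class_ids)
```

For a batch of two examples with two classes, both reductions return a length-2 vector, so the choice of axis is easy to overlook.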

### Export the model

Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)

First, build a wrapper class to export the model. This wrapper does two things:

- First, it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so that preprocessing and classification are exported together.
- Second, it defines a `tf.function` that implements the end-to-end execution of the model.

Setting the `input_signature` argument of `tf.function` fixes the shapes and dtypes that the function accepts, so it is traced once for that signature. This can be less surprising than the default behavior of automatically retracing for each new input signature.

In [48]:
class ExportModel(tf.Module):
  def __init__(self, input_processor, classifier):
    self.input_processor = input_processor
    self.classifier = classifier

  @tf.function(input_signature=[{
      'sentence1': tf.TensorSpec(shape=[None], dtype=tf.string),
      'sentence2': tf.TensorSpec(shape=[None], dtype=tf.string)}])
  def __call__(self, inputs):
    packed = self.input_processor(inputs)
    logits = self.classifier(packed, training=False)
    # Take the argmax over the class axis to get one class ID per example.
    result_cls_ids = tf.argmax(logits, axis=1)
    return {
        'logits': logits,
        'class_id': result_cls_ids,
        'class': tf.gather(
            tf.constant(info.features['label'].names),
            result_cls_ids)
    }
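
To see the effect of `input_signature` in isolation, here is a minimal sketch using a toy function. The function `sentence_lengths` and the `trace_log` list are hypothetical names introduced only for this illustration. Python side effects inside a `tf.function` run only while it is being traced, so the length of `trace_log` counts the traces.

```python
import tensorflow as tf

trace_log = []  # Appended to only while the function is being traced.

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def sentence_lengths(sentences):
  trace_log.append('traced')
  # Length of each string in bytes.
  return tf.strings.length(sentences)

# Two batches of different sizes both match the `[None]` string signature.
short_batch = sentence_lengths(tf.constant(['Hello']))
long_batch = sentence_lengths(tf.constant(['Hello', 'TensorFlow', 'BERT']))

print(long_batch.numpy())
print('Number of traces:', len(trace_log))
```

Because the signature leaves the batch dimension as `None`, the same traced graph serves both calls.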

Create an instance of this export model and save it:

In [49]:
export_model = ExportModel(bert_inputs_processor, bert_classifier)

In [50]:
import tempfile

export_dir = tempfile.mkdtemp(suffix='_saved_model')
tf.saved_model.save(export_model, export_dir=export_dir,
                    signatures={'serving_default': export_model.__call__})

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpad4uwdme_saved_model/assets



Reload the model and compare the results to the original:

In [51]:
original_logits = export_model(my_examples)['logits']

In [52]:
reloaded = tf.saved_model.load(export_dir)
reloaded_logits = reloaded(my_examples)['logits']

In [53]:
# The results are identical:
print(original_logits.numpy())
print()
print(reloaded_logits.numpy())

[[-3.7263763  3.4041796]
 [ 1.8957428 -0.5120322]]

[[-3.7263763  3.4041796]
 [ 1.8957428 -0.5120322]]


In [54]:
print(np.mean(abs(original_logits - reloaded_logits)))

0.0
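
The same save, reload, and compare pattern can be tried on a toy module that needs no trained model. In the sketch below, `Scaler` is a hypothetical stand-in for `ExportModel`; it only multiplies its input by a variable, but it is exported and checked the same way.

```python
import tempfile

import numpy as np
import tensorflow as tf


class Scaler(tf.Module):
  """Hypothetical toy module standing in for the exported classifier."""

  def __init__(self, scale):
    super().__init__()
    self.scale = tf.Variable(scale, dtype=tf.float32)

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return {'scaled': x * self.scale}


module = Scaler(3.0)
toy_export_dir = tempfile.mkdtemp(suffix='_toy_saved_model')
tf.saved_model.save(module, toy_export_dir,
                    signatures={'serving_default': module.__call__})

# Reload in the same process and compare against the in-memory module.
reloaded_module = tf.saved_model.load(toy_export_dir)
x = tf.constant([1.0, 2.0])
original = module(x)['scaled'].numpy()
restored = reloaded_module(x)['scaled'].numpy()
max_abs_diff = float(np.max(np.abs(original - restored)))
print(max_abs_diff)
```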


Congratulations! You've used `tensorflow_models` to build a BERT classifier, train it, and export it for later use.

## Optional: BERT on TF Hub

<a id="hub_bert"></a>


You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).

This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).

Start by importing TF Hub:

In [55]:
import tensorflow_hub as hub

Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:

In [56]:
# Use the preprocessor that matches the encoder's vocabulary and casing.
hub_preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

# This is a really small BERT.
hub_encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2",
    trainable=True)

print(f"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables")

The Hub encoder has 39 trainable variables


Test run the preprocessor on a batch of data:

In [57]:
hub_inputs = hub_preprocessor(['Hello TensorFlow!'])
{key: value[0, :10].numpy() for key, value in hub_inputs.items()} 

{'input_type_ids': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
 'input_word_ids': array([  101,  7592, 23435, 12314,   999,   102,     0,     0,     0,
            0], dtype=int32),
 'input_mask': array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=int32)}
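
The three arrays are related in a simple way, which the pure-Python sketch below reproduces for a single-segment input. The helper `pack_single_segment` is hypothetical and is not the TF Hub preprocessor. It assumes the token IDs are already known (here copied from the output above), that `101` and `102` are the `[CLS]` and `[SEP]` IDs of this uncased vocabulary, and that `0` is the padding ID. The sequence length is set to 10 to match the slice shown above; the real preprocessor pads to its own default length.

```python
def pack_single_segment(token_ids, seq_length, cls_id=101, sep_id=102, pad_id=0):
  """Packs one tokenized sentence into fixed-length BERT input lists."""
  ids = [cls_id] + list(token_ids) + [sep_id]
  if len(ids) > seq_length:
    raise ValueError('Sequence is longer than seq_length; truncate it first.')
  num_padding = seq_length - len(ids)
  return {
      # Token IDs, padded on the right.
      'input_word_ids': ids + [pad_id] * num_padding,
      # 1 marks real tokens, 0 marks padding.
      'input_mask': [1] * len(ids) + [0] * num_padding,
      # A single segment uses type 0 everywhere.
      'input_type_ids': [0] * seq_length,
  }


# Token IDs for 'Hello TensorFlow!' as shown in the output above.
packed = pack_single_segment([7592, 23435, 12314, 999], seq_length=10)
for name, values in packed.items():
  print(name, values)
```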

In [58]:
result = hub_encoder(
    inputs=hub_inputs,
    training=False,
)

print("Pooled output shape:", result['pooled_output'].shape)
print("Sequence output shape:", result['sequence_output'].shape)

Pooled output shape: (1, 128)
Sequence output shape: (1, 128, 128)


At this point it would be simple to add a classification head yourself.
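
As a rough illustration of what such a head computes, the NumPy sketch below applies a single dense layer to a stand-in for the pooled output. The random `pooled_output` array is a placeholder for `result['pooled_output']`, the weights are untrained, and dropout is omitted, so this only demonstrates the shapes involved.

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, hidden_size, num_classes = 1, 128, 2

# Placeholder with the same shape as the pooled output above: (1, 128).
pooled_output = rng.normal(size=(batch_size, hidden_size)).astype(np.float32)

# Untrained dense-layer parameters (plain normal initialization as a stand-in).
kernel = rng.normal(scale=0.02, size=(hidden_size, num_classes)).astype(np.float32)
bias = np.zeros(num_classes, dtype=np.float32)

# A classification head maps each pooled vector to one logit per class.
head_logits = pooled_output @ kernel + bias
print(head_logits.shape)
```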

The Model Garden `tfm.nlp.models.BertClassifier` class can also build a classifier onto the TF Hub encoder:

In [59]:
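# Note: `bert_encoder` is the Model Garden encoder defined earlier in this
# tutorial, not the `hub_encoder` loaded above.
hub_classifier = tfm.nlp.models.BertClassifier(
    bert_encoder,
    num_classes=2,
    dropout_rate=0.1,
    initializer=tf.keras.initializers.TruncatedNormal(
        stddev=0.02))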

The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.

The BERT encoder model, `hub_encoder`, is now a single layer.

For concrete examples of this approach, refer to [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue).

## Optional: Optimizer `config`s

The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.

The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:

In [60]:
optimization_config = tfm.optimization.OptimizationConfig(
    optimizer=tfm.optimization.OptimizerConfig(
        type='adam'),
    learning_rate=tfm.optimization.LrConfig(
        type='polynomial',
        polynomial=tfm.optimization.PolynomialLrConfig(
            initial_learning_rate=2e-5,
            end_learning_rate=0.0,
            decay_steps=num_train_steps)),
    warmup=tfm.optimization.WarmupConfig(
        type='linear',
        linear=tfm.optimization.LinearWarmupConfig(warmup_steps=warmup_steps)
    ))


fac = tfm.optimization.optimizer_factory.OptimizerFactory(optimization_config)
lr = fac.build_learning_rate()
optimizer = fac.build_optimizer(lr=lr)

In [61]:
x = tf.linspace(0, num_train_steps, 1001).numpy()
y = [lr(xi) for xi in x]
plt.plot(x,y)
plt.xlabel('Train step')
plt.ylabel('Learning rate')

Text(0, 0.5, 'Learning rate')
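
The shape of this curve can be approximated in plain Python. The sketch below assumes that the warmup ramps linearly from the warmup learning rate to the value the decay schedule takes at the end of warmup, and that the polynomial decay follows afterwards; the Model Garden implementation may differ in detail. The function names are hypothetical, and the default values (`2e-5`, `570` decay steps, `57` warmup steps) are the ones this tutorial's configuration resolves to.

```python
def polynomial_decay(step, initial_lr=2e-5, end_lr=0.0, decay_steps=570,
                     power=1.0):
  """Polynomial decay from `initial_lr` to `end_lr` over `decay_steps`."""
  step = min(step, decay_steps)
  remaining = 1.0 - step / decay_steps
  return (initial_lr - end_lr) * remaining ** power + end_lr


def warmup_then_decay(step, warmup_steps=57, warmup_lr=0.0):
  """Linear warmup to the decayed value at `warmup_steps`, then decay."""
  if step < warmup_steps:
    target = polynomial_decay(warmup_steps)
    return warmup_lr + (step / warmup_steps) * (target - warmup_lr)
  return polynomial_decay(step)


schedule = [warmup_then_decay(step) for step in range(571)]
peak_step = max(range(len(schedule)), key=schedule.__getitem__)
print('Peak learning rate', schedule[peak_step], 'at step', peak_step)
```

Under these assumptions the learning rate starts at zero, peaks at the end of warmup slightly below the configured initial rate, and then falls linearly to zero at the last step.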

The advantage of using `config` objects is that they don't contain any complicated TensorFlow objects, so they can be easily serialized to JSON and rebuilt. Here's the dictionary form of the above `tfm.optimization.OptimizationConfig`, which maps directly to JSON:

In [62]:
optimization_config = optimization_config.as_dict()
optimization_config

{'optimizer': {'type': 'adam',
  'adam': {'clipnorm': None,
   'clipvalue': None,
   'global_clipnorm': None,
   'name': 'Adam',
   'beta_1': 0.9,
   'beta_2': 0.999,
   'epsilon': 1e-07,
   'amsgrad': False}},
 'ema': None,
 'learning_rate': {'type': 'polynomial',
  'polynomial': {'name': 'PolynomialDecay',
   'initial_learning_rate': 2e-05,
   'decay_steps': 570,
   'end_learning_rate': 0.0,
   'power': 1.0,
   'cycle': False,
   'offset': 0}},
 'warmup': {'type': 'linear',
  'linear': {'name': 'linear', 'warmup_learning_rate': 0, 'warmup_steps': 57}}}
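
Because the dictionary holds only strings, numbers, booleans, `None`, and nested dictionaries, it round-trips through the standard `json` module unchanged. The sketch below copies the dictionary printed above into a literal (named `config_dict` here for illustration) and checks that round trip.

```python
import json

# Copied from the `as_dict()` output above.
config_dict = {
    'optimizer': {
        'type': 'adam',
        'adam': {
            'clipnorm': None,
            'clipvalue': None,
            'global_clipnorm': None,
            'name': 'Adam',
            'beta_1': 0.9,
            'beta_2': 0.999,
            'epsilon': 1e-07,
            'amsgrad': False,
        },
    },
    'ema': None,
    'learning_rate': {
        'type': 'polynomial',
        'polynomial': {
            'name': 'PolynomialDecay',
            'initial_learning_rate': 2e-05,
            'decay_steps': 570,
            'end_learning_rate': 0.0,
            'power': 1.0,
            'cycle': False,
            'offset': 0,
        },
    },
    'warmup': {
        'type': 'linear',
        'linear': {
            'name': 'linear',
            'warmup_learning_rate': 0,
            'warmup_steps': 57,
        },
    },
}

# Serialize to a JSON string, then parse it back into a dictionary.
serialized = json.dumps(config_dict, indent=2)
restored = json.loads(serialized)
print(restored == config_dict)
```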

The `tfm.optimization.optimizer_factory.OptimizerFactory` can just as easily build the optimizer from this dictionary:

In [63]:
fac = tfm.optimization.optimizer_factory.OptimizerFactory(
    tfm.optimization.OptimizationConfig(optimization_config))
lr = fac.build_learning_rate()
optimizer = fac.build_optimizer(lr=lr)