# Sequence Classification with Transformers using TensorFlow's Strategies

This Colab notebook will guide you through using the Transformers library to obtain state-of-the-art results on the sequence classification task. It relies on TensorFlow's distribution strategies, using `TPUStrategy` to train on a TPU. It is attached to the following tutorial: TODO.

We will be using HuggingFace's own model: DistilBERT.

The models from the `Transformers` repository need TF2 installed. You may be prompted to restart the runtime before the changes take effect.

This colab notebook was inspired by several other notebooks/tutorials:

[The official TensorFlow input pipeline for BERT](https://github.com/tensorflow/models/blob/master/official/nlp/bert/input_pipeline.py)

[The official TensorFlow modeling utils for BERT](https://github.com/tensorflow/models/blob/master/official/modeling/model_training_utils.py)

[The official TFRecord documentation](https://www.tensorflow.org/tutorials/load_data/tfrecord)

[The official custom training with Strategy tutorial by TensorFlow](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/custom_training.ipynb)

In [0]:
%tensorflow_version 2.x
!pip install transformers



### Imports

We'll only be importing the components that we'll use during this tutorial: the TensorFlow model alongside its model-specific tokenizer. The last two imports handle the pre-processing of our data.

In [0]:
import tensorflow as tf
print(tf.__version__)

import os
from transformers import (TFDistilBertForSequenceClassification, 
                          DistilBertTokenizer, 
                          glue_convert_examples_to_features, 
                          glue_processors)

2.0.0


# Pre-processing

## Dataset

We'll be using the MRPC dataset that was used in the previous example, as a means of comparison. How the `glue_convert_examples_to_features` function handles the dataset was detailed in the [previous notebook](https://colab.research.google.com/drive/1l39vWjZ5jRUimSQDoUcuWGIoNjLjA2zu#scrollTo=ipMamgbw6bjL), so refer to it if you are unsure how this works.

The difference in this example is that we'll be building an input pipeline beforehand: the features are serialized as `tf.train.Example` records in TFRecord files, and a `tf.data` pipeline reads them back to feed the model during training.

### Importing the data

We'll use the handy `tensorflow_datasets` package to import our data. As we are using a TPU, which cannot read from the Colab machine's local filesystem, we use a Google Cloud Storage bucket to save our data.

**You will not be able to use our bucket for this notebook. Please create your own and replace the string corresponding to the bucket.**

The data is handled exactly the same way as in the previous tutorial.

In [0]:
bucket = "gs://huggingface-bucket/public"

In [0]:
import tensorflow_datasets

IS_COLAB_BACKEND = 'COLAB_GPU' in os.environ  # this is always set on Colab, the value is 0 or 1 depending on GPU presence
if IS_COLAB_BACKEND:
    from google.colab import auth
    # Authenticates the Colab machine and also the TPU using your
    # credentials so that they can access your private GCS buckets.
    auth.authenticate_user()

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
data, info = tensorflow_datasets.load('glue/mrpc', with_info=True, data_dir=bucket)
train_examples = info.splits['train'].num_examples
validation_examples = info.splits['validation'].num_examples

MAX_SEQUENCE_LENGTH = 128

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset_w_features = glue_convert_examples_to_features(data['train'], tokenizer, MAX_SEQUENCE_LENGTH, 'mrpc')
validation_dataset_w_features = glue_convert_examples_to_features(data['validation'], tokenizer, MAX_SEQUENCE_LENGTH, 'mrpc')

100%|██████████| 231508/231508 [00:00<00:00, 2653183.99B/s]
INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (gs://huggingface-bucket/public/glue/mrpc/0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from gs://huggingface-bucket/public/glue/mrpc/0.0.2
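For reference, here is a minimal pure-Python sketch of the layout of each padded feature, assuming the defaults of `glue_convert_examples_to_features` in the version used here: padding with token id 0 on the right, and an attention mask of 1 for real tokens and 0 for padding. The converter also returns `token_type_ids`, which DistilBERT does not use and which the serialization step below leaves out. The helper `pad_and_mask` and the token ids are hypothetical; the real function also tokenizes both sentences, adds the special tokens and truncates long pairs.

```python
from typing import Dict, List

MAX_SEQUENCE_LENGTH = 128  # same value as in the notebook
PAD_TOKEN_ID = 0  # assumption: padding id used by the converter's defaults


def pad_and_mask(token_ids: List[int], max_length: int = MAX_SEQUENCE_LENGTH) -> Dict[str, List[int]]:
    """Hypothetical helper: right-pad already tokenized ids and build the attention mask."""
    if len(token_ids) > max_length:
        raise ValueError("the real converter truncates long pairs; this sketch does not")
    padding = max_length - len(token_ids)
    return {
        "input_ids": token_ids + [PAD_TOKEN_ID] * padding,
        # 1 marks a real token the model attends to, 0 marks padding it ignores
        "attention_mask": [1] * len(token_ids) + [0] * padding,
    }


# Hypothetical ids for "[CLS] sentence A [SEP] sentence B [SEP]"; 101 and 102 stand in for [CLS] and [SEP]
features = pad_and_mask([101, 7592, 102, 2088, 102])
print(len(features["input_ids"]), sum(features["attention_mask"]))  # 128 5
```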


### Serialization

Here we are using [TFRecord alongside tf.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord) as a way to read data efficiently. Feeding data to a TPU can very easily become a bottleneck, so we store our pre-processed data in files that are streamed during training.

**Unless you change the bucket to your own, you will not be able to run this cell, as we have not granted public write access to our folder. If you switch this cell to your own bucket in order to run it, you will also have to change the location from which the TFRecord files are read to your bucket URL.**

In [0]:
skip = True  # set to False to (re)write the TFRecord files to your own bucket

if not skip:
    # Prepare tf.Examples and tf.Features and write them as TFRecords
    def save_tfrecord_to_bucket(features_dataset, bucket_url, file_name):
        with tf.io.TFRecordWriter(f"{bucket_url}/{file_name}.tfrecord") as tfwriter:
            for inputs, label in features_dataset:
                # Convert the eager tensors to NumPy values before building the protos
                feature_key_value_pair = {
                    'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=inputs['input_ids'].numpy())),
                    'attention_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=inputs['attention_mask'].numpy())),
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label.numpy()]))
                }
                features = tf.train.Features(feature=feature_key_value_pair)
                tf_example = tf.train.Example(features=features)

                tfwriter.write(tf_example.SerializeToString())
        print(f"Saved {file_name}.")

    # Note: despite the "mnli" in their names, these files hold the MRPC features
    save_tfrecord_to_bucket(train_dataset_w_features, bucket, "glue_mnli_train")
    save_tfrecord_to_bucket(validation_dataset_w_features, bucket, "glue_mnli_valid")
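Because the reading code declares `tf.io.FixedLenFeature([MAX_SEQUENCE_LENGTH], tf.int64)` for both sequences and later casts everything to `tf.int32`, each record written here must hold exactly `MAX_SEQUENCE_LENGTH` ids that fit in 32 bits. The hypothetical `check_record` helper below is not part of the original notebook; it is a plain-Python sketch of a check that could run inside the writing loop (on values converted with `.numpy().tolist()`) to catch a mismatch before anything reaches the bucket.

```python
from typing import List

MAX_SEQUENCE_LENGTH = 128
INT32_MIN, INT32_MAX = -(2 ** 31), 2 ** 31 - 1


def check_record(input_ids: List[int], attention_mask: List[int], label: int,
                 max_length: int = MAX_SEQUENCE_LENGTH) -> None:
    """Hypothetical pre-write check mirroring the FixedLenFeature spec used when reading."""
    for name, values in (("input_ids", input_ids), ("attention_mask", attention_mask)):
        if len(values) != max_length:
            raise ValueError(f"{name} has {len(values)} values, expected {max_length}")
        if not all(INT32_MIN <= v <= INT32_MAX for v in values):
            raise ValueError(f"{name} contains a value that does not fit in int32")
    if not all(v in (0, 1) for v in attention_mask):
        raise ValueError("attention_mask must only contain 0 and 1")
    if label not in (0, 1):  # MRPC is a binary task
        raise ValueError(f"unexpected label {label}")


check_record([101] + [0] * 127, [1] + [0] * 127, 1)  # a well-formed record passes silently
```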

# Building the training system

## Strategy

We make use of TensorFlow's distribution strategies, which handle both the distribution of the data and the distributed training across the available devices. In this example we'll be using a `TPUStrategy`, which replicates the model on every TPU core and trains the replicas synchronously. Before creating it, we connect to the TPU cluster and initialize the TPU system.

In [0]:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0
INFO:tensorflow:Initializing the TPU system: 10.45.234.42:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.

<tensorflow.python.tpu.topology.Topology at 0x7f546f000f28>

In [0]:
strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Number of accelerators:  8


## Loading the Dataset with the strategy

Here we define a batch size for each replica; the global batch size is this value multiplied by the number of replicas. We set it to a multiple of 8 to make the best use of the TPU's systolic array, following the [Google TPU performance guide](https://cloud.google.com/tpu/docs/performance-guide#rule_of_thumb_pick_efficient_values_for_batch_and_feature_dimensions).

In [0]:
BATCH_SIZE_PER_REPLICA = 8
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 3
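As a sanity check on these values, the sketch below computes the global batch size and the number of batches per epoch in plain Python. It assumes 8 replicas and the standard MRPC split sizes of 3,668 training and 408 validation pairs (in the notebook these are `train_examples` and `validation_examples`). Because `batch` is called without `drop_remainder=True`, the last batch of each split is smaller than the others; the resulting 58 training and 7 evaluation batches match the iteration counts `tqdm` reports during training.

```python
import math

BATCH_SIZE_PER_REPLICA = 8
NUM_REPLICAS = 8  # strategy.num_replicas_in_sync on an 8-core TPU
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

# Assumed MRPC split sizes (train / validation)
TRAIN_EXAMPLES, VALIDATION_EXAMPLES = 3668, 408


def batches_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches produced by dataset.batch(batch_size) without drop_remainder."""
    return math.ceil(num_examples / batch_size)


def last_batch_size(num_examples: int, batch_size: int) -> int:
    """Size of the final batch; equals batch_size when the split divides evenly."""
    return num_examples - (batches_per_epoch(num_examples, batch_size) - 1) * batch_size


print(GLOBAL_BATCH_SIZE)                                          # 64
print(batches_per_epoch(TRAIN_EXAMPLES, GLOBAL_BATCH_SIZE))       # 58
print(last_batch_size(TRAIN_EXAMPLES, GLOBAL_BATCH_SIZE))         # 20
print(batches_per_epoch(VALIDATION_EXAMPLES, GLOBAL_BATCH_SIZE))  # 7
```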

### Retrieving the TFRecord dataset

The TFRecord files are now entirely processed and ready to be used as input by our training loop. We load them, decode the records, shuffle the examples and batch them.

In [0]:
def get_tfrecord_dataset(bucket_url, file_name):
    features = {
        'input_ids': tf.io.FixedLenFeature([MAX_SEQUENCE_LENGTH], tf.int64),
        'attention_mask': tf.io.FixedLenFeature([MAX_SEQUENCE_LENGTH], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    dataset = tf.data.TFRecordDataset(f"{bucket_url}/{file_name}.tfrecord")

    # Taken from the TensorFlow models repository: https://github.com/tensorflow/models/blob/befbe0f9fe02d6bc1efb1c462689d069dae23af1/official/nlp/bert/input_pipeline.py#L24
    def decode_record(record, features):
        """Decodes a record to a TensorFlow example."""
        example = tf.io.parse_single_example(record, features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.cast(t, tf.int32)
            example[name] = t
        return example


    def select_data_from_record(record):
        x = {
            'input_ids': record['input_ids'],
            'attention_mask': record['attention_mask'],
        }
        y = record['label']
        return (x, y)


    dataset = dataset.map(lambda record: decode_record(record, features))
    dataset = dataset.map(select_data_from_record)
    dataset = dataset.shuffle(100)
    return dataset.batch(GLOBAL_BATCH_SIZE)

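train_dataset = get_tfrecord_dataset(bucket, "glue_mnli_train")
# prefetch returns a new dataset, so the result must be reassigned to take effect
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

validation_dataset = get_tfrecord_dataset(bucket, "glue_mnli_valid")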
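Note that `dataset.shuffle(100)` does not shuffle the whole file at once. As described in the `tf.data.Dataset.shuffle` documentation, it fills a buffer with 100 elements, emits a randomly chosen element from that buffer, and replaces it with the next element of the input. The pure-Python sketch below (`buffered_shuffle` is our own name, not a TensorFlow API) reproduces that behaviour: a small buffer only mixes examples locally, while a buffer at least as large as the split gives a uniform shuffle.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Mimic tf.data.Dataset.shuffle: sample uniformly from a fixed-size sliding buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random element and free its slot for the next input element
            index = rng.randrange(buffer_size)
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: drain what is left of the buffer in random order
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(1000), buffer_size=100))
# The first emitted element always comes from the first 100 inputs
print(shuffled[0] < 100)  # True
```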

There is an additional step here to distribute the dataset among the different TPU cores. We make use of a strategy method, `experimental_distribute_dataset`, to do so.

Every batch of the dataset will now be split across the replicas, one per TPU core. As the TPU we're using has 8 cores and our global batch size is 64, each batch is split evenly into per-replica batches of 64 / 8 = 8 examples.
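Conceptually, the split works like the pure-Python sketch below: one global batch of 64 examples becomes 8 per-replica batches of 8 examples. `split_global_batch` is a hypothetical helper; it assumes each replica receives a contiguous slice and that the batch divides evenly, whereas the exact assignment (and the handling of the smaller final batch) is left to TensorFlow.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_global_batch(batch: Sequence[T], num_replicas: int) -> List[List[T]]:
    """Hypothetical helper: cut one global batch into equal per-replica batches."""
    if len(batch) % num_replicas != 0:
        raise ValueError("this sketch only handles evenly divisible batches")
    per_replica = len(batch) // num_replicas
    return [list(batch[i * per_replica:(i + 1) * per_replica]) for i in range(num_replicas)]


replica_batches = split_global_batch(list(range(64)), num_replicas=8)
print(len(replica_batches), [len(b) for b in replica_batches])  # 8 [8, 8, 8, 8, 8, 8, 8, 8]
```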

In [0]:
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
validation_dist_dataset = strategy.experimental_distribute_dataset(validation_dataset)

## Model creation

We create a function that will instantiate a new model when called.

In [0]:
def model_fn():
    return TFDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

## Hyperparameters initialization

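While in the strategy's scope, we define a sparse categorical crossentropy loss with `reduction=tf.keras.losses.Reduction.NONE`, so that it returns one loss value per example instead of a single scalar. We then define a function `compute_loss`, which computes the loss between the model's predictions and the expected results (labels) and averages it with `tf.nn.compute_average_loss`. Each replica divides the sum of its per-example losses by the global batch size rather than by its own batch size, because the gradients computed on all replicas are summed before being applied.

In order to measure the loss and accuracy during training and evaluation, we define a mean metric for the loss and a sparse categorical accuracy metric for each of the two phases.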

Finally, we initialize a model and create an optimizer object using the Adam optimizer.


In [0]:
with strategy.scope():
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss_metric = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy('training_accuracy')
    
    model = model_fn()
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)

100%|██████████| 492/492 [00:00<00:00, 104703.31B/s]
100%|██████████| 363423424/363423424 [00:10<00:00, 33184074.84B/s]
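To see why `compute_loss` divides by `GLOBAL_BATCH_SIZE` rather than by the per-replica batch size, here is a pure-Python sketch with made-up per-example losses and a toy setup of 2 replicas with 3 examples each (not the notebook's 8 × 8). Each replica contributes the sum of its per-example losses divided by the global batch size; since the replicas' gradients are summed, adding up these contributions yields the mean loss over the whole global batch. By the same arithmetic, averaging the per-replica values, which is what `train_loss_metric` does, gives the global mean divided by the number of replicas; we believe this is why the "Training loss" values printed below are roughly 8 times smaller than the actual mean loss.

```python
from typing import List

NUM_REPLICAS = 2            # toy example instead of the 8 TPU cores
BATCH_SIZE_PER_REPLICA = 3
GLOBAL_BATCH_SIZE = NUM_REPLICAS * BATCH_SIZE_PER_REPLICA


def compute_average_loss(per_example_loss: List[float], global_batch_size: int) -> float:
    """Same arithmetic as tf.nn.compute_average_loss without sample weights."""
    return sum(per_example_loss) / global_batch_size


# Made-up per-example losses, one list per replica
replica_losses = [[0.2, 0.4, 0.6], [0.8, 1.0, 1.2]]
replica_contributions = [compute_average_loss(losses, GLOBAL_BATCH_SIZE) for losses in replica_losses]

all_losses = [x for losses in replica_losses for x in losses]
global_mean = sum(all_losses) / len(all_losses)

# Summing the replica contributions recovers the mean over the global batch
print(sum(replica_contributions), global_mean)
# Averaging them instead (as a Mean metric over per-replica values does) divides by NUM_REPLICAS
print(sum(replica_contributions) / NUM_REPLICAS)
```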


## Steps

We create two functions that will be called during the training and test steps. They are very similar but differ in how gradients are handled: there is no need to compute gradients during evaluation, and the optimizer does not need to adjust the weights.

We make sure to take the first item returned by our model. `TFDistilBertForSequenceClassification`, like the other models of the library, returns a tuple whose first element is the output of the classification head (the logits). Depending on the model's configuration, the tuple can also contain the values computed by each layer (called hidden states) and the attention weights. Those values are helpful in a myriad of settings, but not in this case.

In [0]:
with strategy.scope():
    def train_step(inputs):
        """Runs on every replica with its share of the global batch."""
        features, labels = inputs

        with tf.GradientTape() as tape:
            predictions = model(features, training=True)[0]  # Gather only the outputs of the text-classification head
            loss = compute_loss(labels, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)

        # The gradients are summed across replicas before being applied
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # `loss` is this replica's share of the global-batch loss (see compute_loss)
        train_loss_metric.update_state(loss)
        train_accuracy_metric.update_state(labels, predictions)

    def test_step(inputs):
        """Forward pass only: no gradients and no weight updates."""
        features, labels = inputs

        predictions = model(features, training=False)[0]  # Gather only the outputs of the text-classification head
        t_loss = compute_loss(labels, predictions)

        test_loss_metric.update_state(t_loss)
        test_accuracy_metric.update_state(labels, predictions)
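The classification head returns raw logits, which is why the loss is built with `from_logits=True`, while the accuracy metric only needs to know which logit is largest. The pure-Python sketch below spells out that per-example arithmetic for the two MRPC classes; the function names are ours, and the sketch mirrors what `SparseCategoricalCrossentropy(from_logits=True)` and `SparseCategoricalAccuracy` compute for one example, not how Keras implements them.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_logits(logits: Sequence[float], label: int) -> float:
    """Cross-entropy for an integer label, computed directly from raw logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def is_correct(logits: Sequence[float], label: int) -> bool:
    """What sparse categorical accuracy counts: is the arg-max class the label?"""
    return max(range(len(logits)), key=lambda i: logits[i]) == label


# MRPC has two classes (0: not a paraphrase, 1: paraphrase)
logits = [0.0, 0.0]
print(sparse_ce_from_logits(logits, 1))  # log(2), about 0.693
```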

## Training & Evaluation

Finally, using everything defined above, we create two functions decorated with `tf.function`, which execute the training and test steps on every replica through `strategy.experimental_run_v2`. They don't need to return anything, as the metrics are updated directly inside the steps.

We loop over the number of epochs, training the model and evaluating it at the end of each epoch.

In [0]:
from tqdm import tqdm

with strategy.scope():
    @tf.function
    def distributed_train_step(dist_inputs):
        # Run train_step on every replica with its slice of the global batch
        strategy.experimental_run_v2(train_step, args=(dist_inputs,))

    @tf.function
    def distributed_test_step(dist_inputs):
        # Run test_step on every replica with its slice of the global batch
        strategy.experimental_run_v2(test_step, args=(dist_inputs,))

    global_step = 0
    print_every = 10

    for epoch in range(EPOCHS):
        epoch_step = 0

        ### Training loop ###
        for tensor in tqdm(train_dist_dataset, desc="Training"):
            distributed_train_step(tensor)

            # Running averages since the metrics were last reset
            train_loss = train_loss_metric.result().numpy().astype(float)
            train_accuracy = train_accuracy_metric.result().numpy()

            global_step += 1
            epoch_step += 1

            if epoch_step % print_every == 0:
                print(f"Training step {epoch_step} Accuracy: {train_accuracy}, Training loss: {train_loss}")

        ### Test loop ###
        for tensor in tqdm(validation_dist_dataset, desc="Evaluating"):
            distributed_test_step(tensor)

        ### Output results ###
        test_accuracy = test_accuracy_metric.result().numpy()
        test_loss = test_loss_metric.result().numpy()
        print(f'Epoch: [{epoch}] Validation accuracy = {test_accuracy}')

        ### Reset metrics ###
        train_loss_metric.reset_states()
        train_accuracy_metric.reset_states()
        test_loss_metric.reset_states()
        test_accuracy_metric.reset_states()

        

Training: 11it [00:46,  1.43s/it]

Training step 10 Accuracy: 0.6499999761581421, Training loss: 0.08133892714977264


Training: 21it [00:48,  5.02it/s]

Training step 20 Accuracy: 0.66796875, Training loss: 0.07936723530292511


Training: 31it [00:49,  6.23it/s]

Training step 30 Accuracy: 0.667187511920929, Training loss: 0.07885921001434326


Training: 41it [00:51,  6.19it/s]

Training step 40 Accuracy: 0.6675781011581421, Training loss: 0.07801492512226105


Training: 51it [00:53,  6.10it/s]

Training step 50 Accuracy: 0.6775000095367432, Training loss: 0.07582151144742966


Training: 58it [01:10,  5.07s/it]
Evaluating: 7it [00:09,  2.55s/it]
Training: 0it [00:00, ?it/s]

Epoch: [0] Validation accuracy = 0.718137264251709


Training: 11it [00:02,  5.64it/s]

Training step 10 Accuracy: 0.753125011920929, Training loss: 0.05832377076148987


Training: 21it [00:03,  6.36it/s]

Training step 20 Accuracy: 0.7789062261581421, Training loss: 0.056860826909542084


Training: 31it [00:05,  6.21it/s]

Training step 30 Accuracy: 0.78125, Training loss: 0.056746941059827805


Training: 41it [00:07,  6.04it/s]

Training step 40 Accuracy: 0.78515625, Training loss: 0.05594119429588318


Training: 51it [00:08,  6.18it/s]

Training step 50 Accuracy: 0.7990624904632568, Training loss: 0.053302664309740067


Training: 58it [00:09,  6.29it/s]
Evaluating: 7it [00:01,  3.69it/s]
Training: 0it [00:00, ?it/s]

Epoch: [1] Validation accuracy = 0.8284313678741455


Training: 11it [00:02,  5.57it/s]

Training step 10 Accuracy: 0.8828125, Training loss: 0.03771297633647919


Training: 21it [00:03,  6.04it/s]

Training step 20 Accuracy: 0.883593738079071, Training loss: 0.035818397998809814


Training: 31it [00:05,  6.11it/s]

Training step 30 Accuracy: 0.8880208134651184, Training loss: 0.034640006721019745


Training: 41it [00:07,  6.14it/s]

Training step 40 Accuracy: 0.8863281011581421, Training loss: 0.0344291552901268


Training: 51it [00:08,  6.36it/s]

Training step 50 Accuracy: 0.8928124904632568, Training loss: 0.032859668135643005


Training: 58it [00:09,  5.94it/s]
Evaluating: 7it [00:01,  4.99it/s]

Epoch: [2] Validation accuracy = 0.8504902124404907
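The accuracy and loss printed every 10 steps are running averages over everything seen since the metrics were last reset at the end of the previous epoch, not values for the latest batch alone. The pure-Python class below is a hypothetical stand-in for `tf.keras.metrics.Mean` (the loss metric) that shows the `update_state` / `result` / `reset_states` cycle the loop relies on; `SparseCategoricalAccuracy` follows the same cycle but averages over individual examples rather than over calls.

```python
class RunningMean:
    """Hypothetical stand-in for tf.keras.metrics.Mean: keeps a running total and count."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update_state(self, value: float) -> None:
        self.total += value
        self.count += 1

    def result(self) -> float:
        # Keras' Mean returns 0 before any update; mirror that here
        return self.total / self.count if self.count else 0.0

    def reset_states(self) -> None:
        self.total = 0.0
        self.count = 0


metric = RunningMean()
for batch_value in [0.5, 0.75, 1.0]:  # made-up per-batch values
    metric.update_state(batch_value)
print(metric.result())  # 0.75
metric.reset_states()  # what the loop does at the end of every epoch
print(metric.result())  # 0.0
```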





# Congrats!

You've successfully fine-tuned DistilBERT on MRPC, on a TPU. Feel free to tune the hyper-parameters or change models: simply switch the model, tokenizer and checkpoint.