In [None]:
import tensorflow as tf
from tensorflow import keras
from transformers import AutoTokenizer, TFAutoModel

class CustomBERTModel(keras.Model):
    """Pretrained transformer encoder followed by a two-layer dense head.

    The head takes the hidden state of the first token of every sequence,
    passes it through ``linear1`` and then ``linear2``, and returns the
    result without applying any activation (i.e. raw per-class scores).

    Args:
        model_name: Checkpoint identifier handed to
            ``TFAutoModel.from_pretrained``.
        hidden_units: Output width of the intermediate dense layer.
        num_classes: Number of scores produced per example.
    """

    def __init__(self, model_name="bert-base-uncased", hidden_units=256, num_classes=2):
        super().__init__()
        self.bert = TFAutoModel.from_pretrained(model_name)
        # New layers stacked on top of the encoder.
        self.linear1 = keras.layers.Dense(hidden_units)
        self.linear2 = keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        """Run the encoder and the dense head.

        Args:
            inputs: A two-element sequence ``(ids, mask)``. Keras passes a
                single positional argument to ``call``, so both tensors
                travel together and are unpacked here.
            training: Keras' reserved training flag, forwarded to the encoder.

        Returns:
            The output of ``linear2`` for each example in the batch.
        """
        ids, mask = inputs
        # Pass the mask by keyword so it cannot be bound to the wrong
        # encoder argument.
        encoder_output = self.bert(ids, attention_mask=mask, training=training)

        # (batch_size, sequence_length, hidden_size)
        sequence_output = encoder_output.last_hidden_state

        # Keep only the first token's embedding of every sequence.
        first_token = sequence_output[:, 0, :]

        return self.linear2(self.linear1(first_token))


In [None]:
# Build the classifier; the constructor loads the pretrained encoder via TFAutoModel.from_pretrained.
model = CustomBERTModel()

In [None]:
# No earlier visible cell defines these metrics at module scope, so calling
# reset_states() on them fails with NameError on a fresh kernel. Creating
# them here gives the same zeroed state and makes the cell re-runnable.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [None]:
def train_step(model, tf_train_dataset, tf_test_dataset, epochs=2,
               log_every=10, weights_path='my_model'):
    """Train ``model`` with a manual ``GradientTape`` loop, validating every epoch.

    Each batch is indexed as ``batch[0]['input_ids']``,
    ``batch[0]['attention_mask']`` and ``batch[1]`` (the labels). The model is
    called with the ``(ids, mask)`` pair and its output is treated as logits.

    Args:
        model: Model to optimise; called as ``model((ids, mask), training=...)``.
        tf_train_dataset: Iterable of training batches.
        tf_test_dataset: Iterable of batches used for the per-epoch accuracy.
        epochs: Number of passes over ``tf_train_dataset``.
        log_every: Print loss and running training accuracy every this many
            steps. A value of 0 or ``None`` disables step logging.
        weights_path: Prefix passed to ``model.save_weights`` after training.

    Returns:
        None. Progress is printed and the weights are written to
        ``weights_path``.
    """
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
    val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

    for epoch in range(epochs):
        print(f"\nStart of Training Epoch {epoch}")
        for step, batch in enumerate(tf_train_dataset):
            ids = batch[0]['input_ids']
            mask = batch[0]['attention_mask']
            y = batch[1]

            # Record the forward pass so the loss can be differentiated.
            with tf.GradientTape() as tape:
                logits = model((ids, mask), training=True)
                loss_value = loss_fn(y, logits)
            grads = tape.gradient(loss_value, model.trainable_weights)

            # tape.gradient yields None for variables that did not take part
            # in computing the loss; skip those pairs when applying updates.
            optimizer.apply_gradients([
                (grad, var)
                for grad, var in zip(grads, model.trainable_weights)
                if grad is not None
            ])

            # Update training metric.
            train_acc_metric.update_state(y, logits)

            # Periodic progress report.
            if log_every and step % log_every == 0:
                print(
                    "Training loss at step %d: %.4f"
                    % (step, float(loss_value))
                )
                train_acc = train_acc_metric.result()
                print("Training acc over epoch: %.4f" % (float(train_acc),))

        # Start the next epoch's running accuracy from zero.
        train_acc_metric.reset_states()

        # Validation pass on the test data (no gradient tracking, training=False).
        for batch in tf_test_dataset:
            ids = batch[0]['input_ids']
            mask = batch[0]['attention_mask']
            y = batch[1]
            logits = model((ids, mask), training=False)
            val_acc_metric.update_state(y, logits)
        val_acc = val_acc_metric.result()
        print("Test acc: %.4f" % (float(val_acc),))
        # Reset val metrics at the end of each epoch.
        val_acc_metric.reset_states()

    model.save_weights(weights_path, save_format='tf')

In [None]:
# Train the instance created above. `new_model` is only constructed in a later
# cell, so referencing it here fails when the notebook is run top to bottom.
# NOTE(review): tf_train_dataset / tf_test_dataset are not defined in the
# visible cells; confirm they are created before this cell runs.
train_step(model, tf_train_dataset, tf_test_dataset, epochs=2)

In [None]:
# For a subclassed model, persist the parameters with save_weights. With save_format='tf' this
# writes a weights-only TensorFlow checkpoint under the 'tape_model' prefix (not a full SavedModel).
model.save_weights('tape_model', save_format='tf')

In [None]:
# Create a second, separate instance of the model class.
# NOTE(review): nothing is restored here; no load_weights call is visible in this section —
# confirm the saved weights are loaded into new_model in a later cell.
new_model = CustomBERTModel()

In [None]:
# The optimizer, loss and metric are built inline: the names `optimizer`,
# `loss_fn` and `train_acc_metric` exist only as locals inside train_step, so
# referring to them at module scope fails on a fresh kernel. The settings
# mirror the ones train_step uses.
new_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# call the model on part of the training set to build the model