A sentence-pair text-similarity classification example that fine-tunes ALBERT on a TPU.  
The data comes from LCQMC (A Large-scale Chinese Question Matching Corpus).

In [None]:
# Mount Google Drive at /content/drive so that the project code, pretrained
# model files and datasets stored there are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%tensorflow_version 2.x

import os
import warnings
import time
import math
import json
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras

# Work from the project directory on the mounted Drive (mounted at
# /content/drive). The path is absolute so that re-running this cell is
# idempotent: a relative './drive/...' path only resolves while the CWD is
# /content, and the first run moves the CWD away from it.
os.chdir('/content/drive/My Drive/Python/Research/bert')
# NOTE(review): this silences *all* Python warnings for the whole session;
# narrow it (e.g. by category/module) if genuine problems must stay visible.
warnings.filterwarnings('ignore')
# Only show TensorFlow log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import mymodels as mm

In [None]:
# --- Model / task settings -------------------------------------------------
MODEL, MODE = 'albert', 'cls'  # forwarded to mm.BERT(config, model, mode)
MAXLEN = 128                   # max sequence length given to the tokenizer
CATE = 2                       # number of classes (units of the softmax layer)
DROP = 0.5                     # dropout rate in front of the classifier

# --- Optimisation settings -------------------------------------------------
# BATCH is the *global* batch; it is divided across the TPU replicas later.
LRATE, BATCH, EPOCH = 1e-4, 64, 5

# --- Pretrained ALBERT assets (vocab, config, checkpoint) -------------------
VOCAB = 'models/albert_small_ch/vocab.txt'
CONFIG = 'models/albert_small_ch/albert_config.json'
CKPT = 'models/albert_small_ch/albert_model.ckpt'

In [None]:
# Connect to the Colab TPU and build a distribution strategy over its cores.
# COLAB_TPU_ADDR is expected to be provided by a Colab TPU runtime; fail with
# an actionable message instead of a bare KeyError when it is missing.
if 'COLAB_TPU_ADDR' not in os.environ:
  raise RuntimeError(
      'COLAB_TPU_ADDR is not set: switch the Colab runtime type to TPU '
      '(Runtime > Change runtime type) and re-run this notebook.')

resolver_1 = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://'+os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver_1)
tf.tpu.experimental.initialize_tpu_system(resolver_1)
strategy_1 = tf.distribute.TPUStrategy(resolver_1)

# List the logical TPU cores the strategy will replicate over.
for i_1 in tf.config.list_logical_devices('TPU'):
  print(i_1)

LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU')


In [None]:
def data_processing(data, tokenizer, strategy, maxlen, batch, training):
  """Tokenize sentence pairs and wrap them in a distributed tf.data pipeline.

  Args:
    data: DataFrame with 'sentence1', 'sentence2' and 'label' columns.
    tokenizer: object whose `encoding(sentence1, sentence2, maxlen)` returns a
      (token ids, segment ids, mask) triple for one sentence pair.
    strategy: tf.distribute strategy used to distribute the dataset.
    maxlen: maximum sequence length forwarded to the tokenizer.
    batch: batch size applied to the dataset; callers here pass the
      per-replica size (global BATCH // num_replicas_in_sync).
    training: if True, shuffle all examples before batching.

  Returns:
    The distributed dataset built by `strategy`; each element is a
    (text, segment, mask, label) tuple.

  Raises:
    ValueError: if `data` has no rows.
  """
  if len(data) == 0:
    raise ValueError('data_processing() received an empty DataFrame.')

  text1, seg1, mask1 = [], [], []
  # Walk the columns positionally rather than looking rows up by index label
  # (data['sentence1'][i]); label lookup raises KeyError or misaligns rows as
  # soon as the DataFrame has a non-default index (e.g. after filtering).
  for sentence1, sentence2 in zip(data['sentence1'], data['sentence2']):
    text2, seg2, mask2 = tokenizer.encoding(sentence1, sentence2, maxlen)
    text1.append(text2)
    seg1.append(seg2)
    mask1.append(mask2)

  text1, seg1, mask1 = np.array(text1), np.array(seg1), np.array(mask1)
  label1 = data['label'].to_numpy()

  data1 = tf.data.Dataset.from_tensor_slices((text1, seg1, mask1, label1))
  if training:
    data1 = data1.shuffle(len(text1))
  data1 = data1.batch(batch)
  # NOTE(review): the InputContext argument is ignored and the same dataset is
  # returned for every input pipeline. That is fine on a single host (as on a
  # Colab TPU) but multi-host setups would need to shard by the context.
  return strategy.experimental_distribute_datasets_from_function(lambda _: data1)


# Load the LCQMC splits (columns used below: sentence1, sentence2, label).
training_1 = pd.read_csv('tasks/datasets/lcqmc/train.csv')
dev_1 = pd.read_csv('tasks/datasets/lcqmc/dev.csv')
test_1 = pd.read_csv('tasks/datasets/lcqmc/test.csv')

tokenizer_1 = mm.Tokenizer()
tokenizer_1.loading(VOCAB)

# Datasets are batched per replica, while the training step averages the loss
# over the *global* BATCH; that scaling is only correct if BATCH splits evenly.
assert BATCH % strategy_1.num_replicas_in_sync == 0, (
    'BATCH must be a multiple of the number of replicas')
batch_1 = BATCH//strategy_1.num_replicas_in_sync

training_2 = data_processing(training_1, tokenizer_1, strategy_1, MAXLEN, batch_1, True)
dev_2 = data_processing(dev_1, tokenizer_1, strategy_1, MAXLEN, batch_1, False)
test_2 = data_processing(test_1, tokenizer_1, strategy_1, MAXLEN, batch_1, False)

print(training_1.head())

            sentence1        sentence2  label
0    喜欢打篮球的男生喜欢什么样的女生  爱打篮球的男生喜欢什么样的女生      1
1        我手机丢了，我想换个手机      我想买个新手机，求推荐      1
2            大家觉得她好看吗       大家觉得跑男好看吗？      0
3           求秋色之空漫画全集        求秋色之空全集漫画      1
4  晚上睡觉带着耳机听音乐有什么害处吗？     孕妇可以戴耳机听音乐吗?      0


In [None]:
class ModelBERT(keras.Model):
  """Sentence-pair classifier: an encoder from `mymodels` (mm.BERT), followed
  by dropout and a softmax Dense layer with `category` outputs."""

  def __init__(self, model, mode, config, drop, category):
    super().__init__()
    # Attribute names are kept as-is: `bert` is used from outside the class
    # (model_1.bert.loading) and Keras tracks sub-layers by attribute.
    self.bert = mm.BERT(config, model, mode)
    self.drop = keras.layers.Dropout(drop)
    self.dense = keras.layers.Dense(category, activation='softmax')

  def propagating(self, text, segment, mask, training=False):
    """Encode one batch and return class probabilities.

    `training` is forwarded to the encoder and toggles the dropout layer.
    """
    encoded = self.bert.propagating(text, segment, mask, training)
    dropped = self.drop(encoded, training=training)
    return self.dense(dropped)


with strategy_1.scope():
  # Objects that create TF variables (model, optimizer, metrics) are built
  # under the TPU strategy scope so their variables are replicated.
  model_1 = ModelBERT(MODEL, MODE, CONFIG, DROP, CATE)
  model_1.bert.loading(CKPT)

  # Per-example losses (no reduction); the training step averages them over
  # the global batch itself.
  function_1 = keras.losses.SparseCategoricalCrossentropy(
      reduction=keras.losses.Reduction.NONE)

  # NOTE(review): the first argument is presumably the total number of
  # optimisation steps for the LR schedule. It evaluates to
  # EPOCH * (len(training_1) // BATCH + 1), while the training loop runs only
  # len(training_1) // BATCH steps per epoch - confirm this is intended.
  optimizer_1 = mm.AdamWV2(EPOCH * (len(training_1) // BATCH + 1), LRATE)

  loss_1 = keras.metrics.Mean(name='training_loss')
  acc_1 = keras.metrics.SparseCategoricalAccuracy(name='training_accuracy')
  acc_2 = keras.metrics.SparseCategoricalAccuracy(name='dev_accuracy')

In [None]:
@tf.function
def step_training(iterator):
  """Run one distributed training step on the next batch from `iterator`."""

  def replica_step(batch):
    text, segment, mask, label = batch

    with tf.GradientTape() as tape:
      probs = model_1.propagating(text, segment, mask, True)
      per_example = function_1(label, probs)
      # Divide by the *global* batch size so that summing the replicas'
      # gradients yields the gradient of the mean loss.
      loss = tf.nn.compute_average_loss(per_example, global_batch_size=BATCH)

    variables = model_1.trainable_variables
    grads = tape.gradient(loss, variables)
    # Clipping is applied to this replica's gradients.
    grads, _ = tf.clip_by_global_norm(grads, 1.0)
    optimizer_1.apply_gradients(list(zip(grads, variables)))

    # Multiply back by the replica count so the metric tracks the mean loss
    # (exact when each replica gets a full BATCH // replicas batch).
    loss_1.update_state(loss * strategy_1.num_replicas_in_sync)
    acc_1.update_state(label, probs)

  strategy_1.run(replica_step, args=(next(iterator),))


@tf.function
def step_evaluating(iterator):
  """Run one distributed evaluation step, accumulating accuracy into acc_2."""

  def replica_eval(batch):
    text, segment, mask, label = batch
    # Inference mode: dropout disabled.
    acc_2.update_state(label, model_1.propagating(text, segment, mask, False))

  strategy_1.run(replica_eval, args=(next(iterator),))

In [None]:
print_1 = 'Training loss is {:.4f}, and accuracy is {:.4f}.'
print_2 = 'Dev accuracy is {:.4f}, and epoch cost is {:.4f}.'
LOG_EVERY = 1000  # training steps between progress lines


def _reset_metrics():
  """Clear the running loss/accuracy metrics so each epoch reports only its own statistics."""
  for metric in (loss_1, acc_1, acc_2):
    metric.reset_states()


for e_1 in range(EPOCH):
  print('Epoch {} running.'.format(e_1+1))
  # Also reset at the start of each epoch so that re-running this cell after
  # an interrupted run does not mix in stale counts.
  _reset_metrics()
  time_0, training_3, dev_3 = time.time(), iter(training_2), iter(dev_2)

  # floor(N / BATCH) steps: a trailing partial global batch is never consumed.
  for s_1 in range(len(training_1) // BATCH):
    step_training(training_3)

    if (s_1+1) % LOG_EVERY == 0:
      print(print_1.format(float(loss_1.result()), float(acc_1.result())))

  for s_1 in range(math.ceil(len(dev_1)/BATCH)):
    step_evaluating(dev_3)

  print(print_2.format(float(acc_2.result()), time.time()-time_0))
  print('**********')
  # Previously only the accuracies were reset here, so the reported training
  # loss kept averaging over all earlier epochs. Reset every metric, which also
  # leaves clean metrics for later cells.
  _reset_metrics()

Epoch 1 running.
Training loss is 0.3683, and accuracy is 0.8363.
Training loss is 0.3146, and accuracy is 0.8650.
Training loss is 0.2903, and accuracy is 0.8772.
Dev accuracy is 0.8299, and epoch cost is 100.3739.
**********
Epoch 2 running.
Training loss is 0.2622, and accuracy is 0.9191.
Training loss is 0.2525, and accuracy is 0.9181.
Training loss is 0.2448, and accuracy is 0.9187.
Dev accuracy is 0.8432, and epoch cost is 77.7369.
**********
Epoch 3 running.
Training loss is 0.2316, and accuracy is 0.9378.
Training loss is 0.2246, and accuracy is 0.9374.
Training loss is 0.2189, and accuracy is 0.9374.
Dev accuracy is 0.8522, and epoch cost is 76.7962.
**********
Epoch 4 running.
Training loss is 0.2075, and accuracy is 0.9582.
Training loss is 0.2011, and accuracy is 0.9575.
Training loss is 0.1955, and accuracy is 0.9575.
Dev accuracy is 0.8601, and epoch cost is 76.9570.
**********
Epoch 5 running.
Training loss is 0.1847, and accuracy is 0.9759.
Training loss is 0.1786, and 

In [None]:
# Evaluate on the test split, reusing the dev accuracy metric. Reset it first:
# if the training cell stopped before its final reset (e.g. interrupted during
# dev evaluation), acc_2 would still hold dev counts and skew the test result.
acc_2.reset_states()
test_3 = iter(test_2)

for s_1 in range(math.ceil(len(test_1)/BATCH)):
  step_evaluating(test_3)

print('Test accuracy is {:.4f}.'.format(float(acc_2.result())))

Test accuracy is 0.8539.
