A text classification example for BERT using TPU.  
The data is CMNLI from https://github.com/CLUEbenchmark/CLUE.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [2]:
%tensorflow_version 2.x

import os
import warnings
import time
import math
import json
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

os.chdir('./drive/My Drive/Python/Research/bert')
warnings.filterwarnings('ignore')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import mymodels as mm

In [3]:
MODEL = 'bert'
MODE = 'cls'
MAXLEN = 128
CATE = 3
DROP = 0.5
LRATE = 3e-5
BATCH = 64
EPOCH = 2
VOCAB = 'models/bert_base_ch/vocab.txt'
CONFIG = 'models/bert_base_ch/bert_config.json'
CKPT = 'models/bert_base_ch/bert_model.ckpt'

In [4]:
resolver_1 = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://'+os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver_1)
tf.tpu.experimental.initialize_tpu_system(resolver_1)
strategy_1 = tf.distribute.TPUStrategy(resolver_1)

for i_1 in tf.config.list_logical_devices('TPU'):
  print(i_1)

LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')
LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU')


In [5]:
def file_loading(file):
  reader1 = open(file, 'r', encoding='utf-8').readlines()
  return [json.loads(i1.strip()) for i1 in reader1]


def data_processing(data, tokenizer, strategy, maxlen, batch, training):
  text1, seg1, mask1, label1 = [], [], [], []
  cate1 = {'neutral':0, 'entailment':1, 'contradiction':2}

  for i1 in data:
    text2, seg2, mask2 = tokenizer.encoding(i1['sentence1'], i1['sentence2'], maxlen)
    text1.append(text2)
    seg1.append(seg2)
    mask1.append(mask2)
    label1.append(cate1[i1['label']])
  
  text1, seg1, mask1, label1 = np.array(text1), np.array(seg1), np.array(mask1), np.array(label1)
  data1 = tf.data.Dataset.from_tensor_slices((text1, seg1, mask1, label1))
  data1 = data1.shuffle(len(text1)).batch(batch) if training else data1.batch(batch)
  return strategy.experimental_distribute_datasets_from_function(lambda _: data1)


tokenizer_1 = mm.Tokenizer()
tokenizer_1.loading(VOCAB)
training_1 = file_loading('tasks/datasets/cmnli/train.json')
dev_1 = file_loading('tasks/datasets/cmnli/dev.json')
dev_1 = [i_1 for i_1 in dev_1 if i_1['label'] != '-']
batch_1 = BATCH//strategy_1.num_replicas_in_sync
training_2 = data_processing(training_1, tokenizer_1, strategy_1, MAXLEN, batch_1, True)
dev_2 = data_processing(dev_1, tokenizer_1, strategy_1, MAXLEN, batch_1, False)
print(training_1[0])

{'sentence1': '从概念上讲，奶油略读有两个基本维度-产品和地理。', 'sentence2': '产品和地理位置是使奶油撇油起作用的原因。', 'label': 'neutral'}


In [6]:
class ModelBERT(keras.Model):
  def __init__(self, model, mode, config, drop, category):
    super(ModelBERT, self).__init__()
    self.bert = mm.BERT(config, model, mode)
    self.drop = keras.layers.Dropout(drop)
    self.dense = keras.layers.Dense(category, 'softmax')

  def propagating(self, text, segment, mask, training=False):
    x1 = self.bert.propagating(text, segment, mask, training)
    return self.dense(self.drop(x1, training=training))


with strategy_1.scope():
  model_1 = ModelBERT(MODEL, MODE, CONFIG, DROP, CATE)
  model_1.bert.loading(CKPT)
  function_1 = keras.losses.SparseCategoricalCrossentropy(reduction=keras.losses.Reduction.NONE)
  optimizer_1 = mm.AdamWV2(EPOCH*(int(len(training_1)/BATCH)+1), LRATE)

  loss_1 = keras.metrics.Mean(name='training_loss')
  acc_1 = keras.metrics.SparseCategoricalAccuracy(name='training_accuracy')
  acc_2 = keras.metrics.SparseCategoricalAccuracy(name='dev_accuracy')

In [7]:
@tf.function
def step_training(iterator):
  def training(data):
    text_1, seg_1, mask_1, label_1 = data

    with tf.GradientTape() as tape_1:
      pred_1 = model_1.propagating(text_1, seg_1, mask_1, True)
      value_1 = function_1(label_1, pred_1)
      value_1 = tf.nn.compute_average_loss(value_1, global_batch_size=BATCH)

    grad_1 = tape_1.gradient(value_1, model_1.trainable_variables)
    grad_1, _ = tf.clip_by_global_norm(grad_1, 1.0)
    optimizer_1.apply_gradients(list(zip(grad_1, model_1.trainable_variables)))
    loss_1.update_state(value_1*strategy_1.num_replicas_in_sync)
    acc_1.update_state(label_1, pred_1)
    
  strategy_1.run(training, args=(next(iterator),))


@tf.function
def step_evaluating(iterator):
  def evaluating(data):
    text_1, seg_1, mask_1, label_1 = data
    pred_1 = model_1.propagating(text_1, seg_1, mask_1, False)
    acc_2.update_state(label_1, pred_1)

  strategy_1.run(evaluating, args=(next(iterator),))

In [8]:
print_1 = 'Training loss is {:.4f}, and accuracy is {:.4f}.'
print_2 = 'Dev accuracy is {:.4f}, and epoch cost is {:.4f}.'


for e_1 in range(EPOCH):
  print('Epoch {} running.'.format(e_1+1))
  time_0, training_3, dev_3 = time.time(), iter(training_2), iter(dev_2)

  for s_1 in range(math.floor(len(training_1)/BATCH)):
    step_training(training_3)

    if (s_1+1) % 1000 == 0:
      print(print_1.format(float(loss_1.result()), float(acc_1.result())))

  for s_1 in range(math.ceil(len(dev_1)/BATCH)):
    step_evaluating(dev_3)

  print(print_2.format(float(acc_2.result()), time.time()-time_0))
  print('**********')
  acc_1.reset_states()
  acc_2.reset_states()

Epoch 1 running.
Training loss is 0.9800, and accuracy is 0.5716.
Training loss is 0.8362, and accuracy is 0.6451.
Training loss is 0.7688, and accuracy is 0.6789.
Training loss is 0.7287, and accuracy is 0.6984.
Training loss is 0.7004, and accuracy is 0.7119.
Training loss is 0.6784, and accuracy is 0.7223.
Dev accuracy is 0.7873, and epoch cost is 610.4352.
**********
Epoch 2 running.
Training loss is 0.6481, and accuracy is 0.8187.
Training loss is 0.6276, and accuracy is 0.8175.
Training loss is 0.6111, and accuracy is 0.8178.
Training loss is 0.5971, and accuracy is 0.8190.
Training loss is 0.5844, and accuracy is 0.8208.
Training loss is 0.5737, and accuracy is 0.8219.
Dev accuracy is 0.8015, and epoch cost is 520.2713.
**********
