A simple MRC example according to https://arxiv.org/abs/2001.09415.  
The data is from https://dataset.org/dream/.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers
!pip install sentencepiece

In [1]:
%tensorflow_version 2.x

import os
import warnings
import time
import math
import json
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from transformers import TFAlbertModel

os.chdir('./drive/My Drive/Python/Research')
warnings.filterwarnings('ignore')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
MODEL = 'albert-base-v2'  # model name on huggingface
VOCAB = 'models/albert_base_en/30k-clean.vocab'  # download vocab file of albert
SPM = 'models/albert_base_en/30k-clean.model'  # download spm file of albert
FILEPATH = 'tasks/datasets/dream'  # download dream datasets
LOWER = True  # whether to process text in lower case
LEFTLEN = 460  # max length of passages
RIGHTLEN = 52  # max total length of question-answer pairs
NCLASS = 3  # num of options
INDIM = 768  # hidden dim of bert model
HEAD = 12  # num of co-attention heads
SIZE = 64  # dim of a co-attention head
KLAYER = 2  # num of co-attention layers
BATCH = 16  # batch size
EPOCH = 3  # num of epochs
LRATE = 1e-5  # learning rate
DROP = 0.1  # drop rate of co-attention layers

resolver_1 = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://'+os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver_1)
tf.tpu.experimental.initialize_tpu_system(resolver_1)
strategy_1 = tf.distribute.TPUStrategy(resolver_1)
tf.config.list_logical_devices('TPU')

Data processing refers to https://github.com/nlpdata/mrc_bert_baseline.  
Also, need to download the source code of ALBERT tokenizer. 

In [3]:
from utils.albert import tokenization


class InputExample(object):
  def __init__(self, guid, text_a, text_b=None, label=None, text_c=None):
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.text_c = text_c
    self.label = label


class InputFeatures(object):
  def __init__(self, input_ids, input_mask, segment_ids, label_id):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id


class DataProcessor(object):
  def get_train_examples(self, data_dir):
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    raise NotImplementedError()

  def get_labels(self):
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    with open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []

      for line in reader:
        lines.append(line)
      
      return lines


class dreamProcessor(DataProcessor):
  def __init__(self, fpath):
    self.D = [[], [], []]
    self.fpath = fpath

    for sid in range(3):
      with open([self.fpath+'/train.json', self.fpath+'/dev.json', self.fpath+'/test.json'][sid], "r") as f:
        data = json.load(f)

        for i in range(len(data)):
          for j in range(len(data[i][1])):
            d = ['\n'.join(data[i][0]).lower(), data[i][1][j]["question"].lower()]

            for k in range(len(data[i][1][j]["choice"])):
              d += [data[i][1][j]["choice"][k].lower()]
            
            d += [data[i][1][j]["answer"].lower()] 
            self.D[sid] += [d]

  def get_train_examples(self):
    return self._create_examples(self.D[0], "train")

  def get_test_examples(self):
    return self._create_examples(self.D[2], "test")

  def get_dev_examples(self):
    return self._create_examples(self.D[1], "dev")

  def get_labels(self):
    return ["0", "1", "2"]

  def _create_examples(self, data, set_type):
    examples = []

    for (i, d) in enumerate(data):
      for k in range(3):
        if data[i][2+k] == data[i][5]:
          answer = str(k)
              
      label = tokenization.convert_to_unicode(answer)

      for k in range(3):
        guid = "%s-%s-%s" % (set_type, i, k)
        text_a = tokenization.convert_to_unicode(data[i][0])
        text_b = tokenization.convert_to_unicode(data[i][k+2])
        text_c = tokenization.convert_to_unicode(data[i][1])
        examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, text_c=text_c))
        
    return examples


def convert_examples_to_features(examples, label_list, left_len, right_len, tokenizer, n_class):
  label_map, features, stat = {}, [[]], []

  for (i, label) in enumerate(label_list):
    label_map[label] = i

  for (ex_index, example) in enumerate(examples):
    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = None
    tokens_c = None
    
    if example.text_b:
      tokens_b = tokenizer.tokenize(example.text_b)

    if example.text_c:
      tokens_c = tokenizer.tokenize(example.text_c)

    stat.append([len(tokens_a), len(tokens_b), len(tokens_c)])
    tokens_a = tokens_a[0:(left_len-2)]
    _truncate_seq_pair(tokens_b, tokens_c, right_len-2)
    tokens_b = tokens_c+["[SEP]"]+tokens_b
    tokens, segment_ids = [], []
    tokens.append("[CLS]")
    segment_ids.append(0)

    for token in tokens_a:
      tokens.append(token)
      segment_ids.append(0)

    tokens.append("[SEP]")
    segment_ids.append(0)
    input_mask = [1]*len(tokens)

    while len(tokens) < left_len:
      tokens.append("[PAD]")
      input_mask.append(0)
      segment_ids.append(0)

    if tokens_b:
      for token in tokens_b:
        tokens.append(token)
        segment_ids.append(1)
        input_mask.append(1)
      
      tokens.append("[SEP]")
      segment_ids.append(1)
      input_mask.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    while len(input_ids) < left_len+right_len:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    label_id = label_map[example.label]
    features[-1].append(InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id))

    if len(features[-1]) == n_class:
      features.append([])

  if len(features[-1]) == 0:
    features = features[:-1]

  return features, stat


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  while True:
    total_length = len(tokens_a)+len(tokens_b)

    if total_length <= max_length:
      break
    
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()


def _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_length):
  while True:
    total_length = len(tokens_a)+len(tokens_b)+len(tokens_c)

    if total_length <= max_length:
      break
    
    if len(tokens_a) >= len(tokens_b) and len(tokens_a) >= len(tokens_c):
      tokens_a.pop()
    
    elif len(tokens_b) >= len(tokens_a) and len(tokens_b) >= len(tokens_c):
      tokens_b.pop()
    else:
      tokens_c.pop()            


def accuracy(out, labels):
  outputs = np.argmax(out, axis=1)
  return np.sum(outputs==labels)


def data_processing(data, label, leftlen, rightlen, tokenizer, nclass):
  train_features, stat = convert_examples_to_features(data, label, leftlen, rightlen, tokenizer, nclass)
  input_ids, input_mask, segment_ids, label_id = [], [], [], []

  for f in train_features:
    input_ids.append([])
    input_mask.append([])
    segment_ids.append([])

    for i in range(nclass):
      input_ids[-1].append(f[i].input_ids)
      input_mask[-1].append(f[i].input_mask)
      segment_ids[-1].append(f[i].segment_ids)
    
    label_id.append([f[0].label_id])

  return input_ids, input_mask, segment_ids, label_id, stat


def data_preparing(text, seg, mask, label, batch, training, strategy):
  text1, seg1, mask1, label1 = np.array(text), np.array(seg), np.array(mask), np.array(label)
  data1 = tf.data.Dataset.from_tensor_slices((text1, seg1, mask1, label1))
  data1 = data1.shuffle(len(text1)).batch(batch) if training else data1.batch(batch)
  return strategy.experimental_distribute_datasets_from_function(lambda _: data1)


batch_1 = BATCH//strategy_1.num_replicas_in_sync
processor_1 = dreamProcessor(FILEPATH)
label_1 = processor_1.get_labels()
tokenizer_1 = tokenization.FullTokenizer.from_scratch(vocab_file=VOCAB, do_lower_case=LOWER, spm_model_file=SPM)
training_1 = data_processing(processor_1.get_train_examples(), label_1, LEFTLEN, RIGHTLEN, tokenizer_1, NCLASS)
training_2 = data_preparing(training_1[0], training_1[2], training_1[1], training_1[3], batch_1, True, strategy_1)
dev_1 = data_processing(processor_1.get_dev_examples(), label_1, LEFTLEN, RIGHTLEN, tokenizer_1, NCLASS)
dev_2 = data_preparing(dev_1[0], dev_1[2], dev_1[1], dev_1[3], batch_1, False, strategy_1)
test_1 = data_processing(processor_1.get_test_examples(), label_1, LEFTLEN, RIGHTLEN, tokenizer_1, NCLASS)
test_2 = data_preparing(test_1[0], test_1[2], test_1[1], test_1[3], batch_1, False, strategy_1)

Modeling with ALBERT and Co-Attention layer.

In [None]:
def w_initializing(param=0.02):
  return keras.initializers.TruncatedNormal(stddev=param)


class AdamW(keras.optimizers.Adam):
  def __init__(self, step, lrate=1e-3, drate=1e-2, name='AdamW', **kwargs):
    super(AdamW, self).__init__(learning_rate=lrate, name=name, **kwargs)
    self.step, self.drate, self.spec = step, drate, ['bias', 'normalization', 'lnorm', 'layernorm']

  @staticmethod
  def _rate_sch(rate, step, total):
    warm1 = total*0.1
    return tf.where(step < warm1, rate*step/warm1, rate*(total-step)/(total-warm1))

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(AdamW, self)._prepare_local(var_device, var_dtype, apply_state)
    rate1 = self._rate_sch(1., tf.cast(self.iterations+1, var_dtype), self.step+1)
    apply_state[(var_device, var_dtype)]['lr_t'] *= rate1
    apply_state[(var_device, var_dtype)]['lr'] *= rate1

  def _resource_apply_base(self, var, apply_state=None):
    devi1, type1, spec1 = var.device, var.dtype.base_dtype, any(c1 in var.name.lower() for c1 in self.spec)
    coef1 = ((apply_state or {}).get((devi1, type1)) or self._fallback_apply_state(devi1, type1))
    return tf.no_op if spec1 else var.assign_sub(coef1['lr_t']*var*self.drate, use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    deca1 = self._resource_apply_base(var, apply_state)

    with tf.control_dependencies([deca1]):
      return super(AdamW, self)._resource_apply_dense(grad, var, apply_state)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    deca1 = self._resource_apply_base(var, apply_state)

    with tf.control_dependencies([deca1]):
      return super(AdamW, self)._resource_apply_sparse(grad, var, indices, apply_state)

  def get_config(self):
    conf1 = super(AdamW, self).get_config()
    conf1.update({'decaying_rate': self.drate, 'step': self.step})
    return conf1


class CoAttention(keras.layers.Layer):
  def __init__(self, bname, head, size, attdrop=0., drop=0., eps=1e-6, ninf=-1e4, **kwargs):
    super(CoAttention, self).__init__(**kwargs)
    self.head, self.size, self.dim, self.ninf = head, size, head*size, ninf
    self.wq = keras.layers.Dense(self.dim, None, True, w_initializing(), name=bname+'/attention/query')
    self.wk = keras.layers.Dense(self.dim, None, True, w_initializing(), name=bname+'/attention/key')
    self.wv = keras.layers.Dense(self.dim, None, True, w_initializing(), name=bname+'/attention/value')
    self.dense = keras.layers.Dense(self.dim, None, True, w_initializing(), name=bname+'/attention/dense')
    self.norm = keras.layers.LayerNormalization(-1, eps, name=bname+'/attention/LayerNorm')
    self.attdrop = keras.layers.Dropout(attdrop)
    self.drop = keras.layers.Dropout(drop)

  def transposing(self, x):
    return tf.transpose(tf.reshape(x, [-1, tf.shape(x)[1], self.head, self.size]), [0, 2, 1, 3])

  def masking(self, mask):
    m1 = mask[:, tf.newaxis, tf.newaxis, :]
    return tf.cast(m1, tf.float32)*self.ninf

  def calculating(self, q, k, v, mask, training):
    a1 = tf.matmul(self.transposing(q), self.transposing(k), transpose_b=True)
    a1 = a1/tf.math.sqrt(tf.cast(self.size, tf.float32))
    a1 = tf.nn.softmax(a1+self.masking(1-mask) if mask is not None else a1, axis=-1)
    return tf.matmul(self.attdrop(a1, training=training), self.transposing(v)), a1

  def propagating(self, left, right, mask=None, training=False):
    x1, a1 = self.calculating(self.wq(left), self.wk(right), self.wv(right), mask, training)
    x1 = tf.reshape(tf.transpose(x1, [0, 2, 1, 3]), [-1, tf.shape(x1)[2], self.dim])
    return self.norm(left+self.drop(self.dense(x1), training=training))


class ModelBERT(keras.Model):
  def __init__(self, model, head, size, drop, nclass, leftlen, rightlen, klayer, indim=768):
    super(ModelBERT, self).__init__()
    self.nclass, self.size, self.left, self.klayer = nclass, [-1, leftlen+rightlen], leftlen, klayer
    self.bert = TFAlbertModel.from_pretrained(model)
    self.outdense = keras.layers.Dense(head*size) if head*size != indim else None
    self.coa = CoAttention('bert/coattention', head, size, drop, drop)
    self.poola = keras.layers.GlobalAveragePooling1D()
    self.poolb = keras.layers.GlobalAveragePooling1D()
    self.drop = keras.layers.Dropout(drop)
    self.dense = keras.layers.Dense(1)

  def propagating(self, text, seg, mask, training):
    x1, x2, x3 = tf.reshape(text, self.size), tf.reshape(seg, self.size), tf.reshape(mask, self.size)
    x1 = self.bert(input_ids=x1, attention_mask=x3, token_type_ids=x2, training=training)['last_hidden_state']

    if self.outdense is not None:
      x1 = self.outdense(x1)

    stat1, stat2 = x1[:, :self.left], x1[:, self.left:]
    mask1, mask2 = x3[:, :self.left], x3[:, self.left:]

    for i1 in range(self.klayer):
      stat1 = self.coa.propagating(stat1, stat2, mask2, training)
      stat2 = self.coa.propagating(stat2, stat1, mask1, training)

    stat1 = stat1*tf.cast(mask1, tf.float32)[:, :, tf.newaxis]
    stat2 = stat2*tf.cast(mask2, tf.float32)[:, :, tf.newaxis]
    stat1 = tf.concat([self.poola(stat1), self.poolb(stat2)], 1)
    stat1 = self.dense(self.drop(stat1, training=training))
    return tf.nn.softmax(tf.reshape(stat1, [-1, self.nclass]))


@tf.function
def step_training(iterator):
  def training(data):
    text_1, seg_1, mask_1, label_1 = data

    with tf.GradientTape() as tape_1:
      pred_1 = model_1.propagating(text_1, seg_1, mask_1, True)
      value_1 = function_1(label_1, pred_1)
      value_1 = tf.nn.compute_average_loss(value_1, global_batch_size=BATCH)

    grad_1 = tape_1.gradient(value_1, model_1.trainable_variables)
    grad_1, _ = tf.clip_by_global_norm(grad_1, 1.0)
    optimizer_1.apply_gradients(list(zip(grad_1, model_1.trainable_variables)))
    loss_1.update_state(value_1*strategy_1.num_replicas_in_sync)
    acc_1.update_state(label_1, pred_1)
    
  strategy_1.run(training, args=(next(iterator),))


@tf.function
def step_evaluating(iterator):
  def evaluating(data):
    text_1, seg_1, mask_1, label_1 = data
    pred_1 = model_1.propagating(text_1, seg_1, mask_1, False)
    acc_2.update_state(label_1, pred_1)

  strategy_1.run(evaluating, args=(next(iterator),))


with strategy_1.scope():
  model_1 = ModelBERT(MODEL, HEAD, SIZE, DROP, NCLASS, LEFTLEN, RIGHTLEN, KLAYER, INDIM)
  function_1 = keras.losses.SparseCategoricalCrossentropy(reduction=keras.losses.Reduction.NONE)
  optimizer_1 = AdamW(EPOCH*(int(len(training_1[0])/BATCH)+1), LRATE)
  loss_1 = keras.metrics.Mean(name='training_loss')
  acc_1 = keras.metrics.SparseCategoricalAccuracy(name='training_accuracy')
  acc_2 = keras.metrics.SparseCategoricalAccuracy(name='dev_accuracy')

In [5]:
print_1 = 'Training loss is {:.4f}, and accuracy is {:.4f}.'
print_2 = 'Dev accuracy is {:.4f}, and epoch cost is {:.4f}.'

for e_1 in range(EPOCH):
  print('Epoch {} running.'.format(e_1+1))
  time_0, training_3, dev_3 = time.time(), iter(training_2), iter(dev_2)

  for s_1 in range(math.floor(len(training_1[0])/BATCH)):
    step_training(training_3)

    if (s_1+1) % 50 == 0:
      print(print_1.format(float(loss_1.result()), float(acc_1.result())))

  for s_1 in range(math.ceil(len(dev_1[0])/BATCH)):
    step_evaluating(dev_3)

  print(print_2.format(float(acc_2.result()), time.time()-time_0))
  print('**********')
  acc_1.reset_states()
  acc_2.reset_states()

Epoch 1 running.
Training loss is 1.0971, and accuracy is 0.3450.
Training loss is 1.0748, and accuracy is 0.4106.
Training loss is 1.0479, and accuracy is 0.4437.
Training loss is 1.0200, and accuracy is 0.4756.
Training loss is 1.0058, and accuracy is 0.4970.
Training loss is 0.9856, and accuracy is 0.5131.
Training loss is 0.9727, and accuracy is 0.5334.
Dev accuracy is 0.6299, and epoch cost is 132.9115.
**********
Epoch 2 running.
Training loss is 0.9395, and accuracy is 0.7050.
Training loss is 0.9179, and accuracy is 0.7244.
Training loss is 0.9044, and accuracy is 0.7154.
Training loss is 0.8944, and accuracy is 0.7175.
Training loss is 0.8840, and accuracy is 0.7207.
Training loss is 0.8742, and accuracy is 0.7225.
Training loss is 0.8668, and accuracy is 0.7223.
Dev accuracy is 0.6647, and epoch cost is 67.4803.
**********
Epoch 3 running.
Training loss is 0.8440, and accuracy is 0.8338.
Training loss is 0.8254, and accuracy is 0.8275.
Training loss is 0.8055, and accuracy is

In [6]:
test_3 = iter(test_2)

for s_1 in range(math.ceil(len(test_1[0])/BATCH)):
  step_evaluating(test_3)

print('Test accuracy is {:.4f}.'.format(float(acc_2.result())))

Test accuracy is 0.6712.


Have a test!

In [7]:
class MRCBot(object):
  def __init__(self, processor, tokenizer, model):
    self.processor, self.tokenizer, self.model = processor, tokenizer, model

  def processing(self, examples):
    pair1 = []

    for i1, text1 in enumerate(examples):
      cont1 = '\n'.join(text1['passage'])
      ques1 = text1['question']
      cont1 = cont1.lower() if LOWER else cont1
      ques1 = ques1.lower() if LOWER else ques1
      
      for j1, answ1 in enumerate(text1['options']):
        answ1 = answ1.lower() if LOWER else answ1
        pair1.append(InputExample(guid=0, text_a=cont1, text_b=answ1, label='0', text_c=ques1))

    data1 = data_processing(pair1, self.processor.get_labels(), LEFTLEN, RIGHTLEN, self.tokenizer, NCLASS)
    return np.array(data1[0]), np.array(data1[2]), np.array(data1[1])

  def predicting(self, examples):
    x1, x2, x3 = self.processing(examples)
    pred1 = np.argmax(self.model.propagating(x1, x2, x3, False), -1).tolist()
    return [examples[i1]['options'][pred1[i1]] for i1 in range(len(examples))]


bot_1 = MRCBot(processor_1, tokenizer_1, model_1)

In [8]:
sample_1 = [{
  'passage': [
    'm: my name is andy tao, very nice to meet you.',
    'w: nice to meet you, my name is jennifer, and thank you for this expensive dinner.',
    'm: you are welcome. you know, i am quite rich, so it is a piece of cake.',
    'w: you are so humorous, i am looking forward to our next dating.',
    'm: me too. I can drive my luxurious tesla and take you to a nice park.'],
  'question': 'what is the man going to do for next dating?',
  'options': [
    'drive the woman to a park.',
    'show the woman his luxurious tesla.',
    'take the woman to a restaurant for dinner.']}]

bot_1.predicting(sample_1)

['drive the woman to a park.']