In [1]:
import tensorflow as tf
import numpy as np

from tensorflow import keras

In [2]:
class PositionEmbedding(tf.keras.layers.Layer):
    """Creates a learned positional embedding.

    The layer owns a trainable table of shape ``[max_length, width]``, where
    ``width`` is the size of the input's last axis. On each call it takes the
    first ``seq_len`` rows (``seq_len`` = runtime size of ``seq_axis``) and
    broadcasts them to the full input shape. Only the *shape* of the input is
    used, not its values; add the result to the input yourself.

    Example:
    ```python
    position_embedding = PositionEmbedding(max_length=100)
    inputs = tf.keras.Input((100, 32), dtype=tf.float32)
    outputs = position_embedding(inputs)
    ```
    Args:
    max_length: The maximum size of the dynamic sequence. Must not be None.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    seq_axis: The axis of the input tensor where we add the embeddings.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.
    Raises:
    ValueError: If `max_length` is None.
    Reference: This layer creates a positional embedding as described in
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805).
    """

    def __init__(self, max_length, initializer="glorot_uniform", seq_axis=1,  **kwargs):

        super(PositionEmbedding, self).__init__(**kwargs)

        if max_length is None:
            raise ValueError("`max_length` must be an Integer, not `None`.")

        self._max_length = max_length
        self._initializer = tf.keras.initializers.get(initializer)
        self._seq_axis = seq_axis

    def get_config(self):
        """Returns the layer config so the layer can be re-created from it."""
        config = {
            "max_length": self._max_length,
            "initializer": tf.keras.initializers.serialize(self._initializer),
            "seq_axis": self._seq_axis,
        }
        base_config = super(PositionEmbedding, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        """Creates the ``[max_length, width]`` embedding table."""
        # Keras usually passes a TensorShape, but a plain tuple/list (e.g. from a
        # manual ``layer.build((None, 10, 32))`` call) has no ``as_list``;
        # normalising through tf.TensorShape accepts both.
        width = tf.TensorShape(input_shape).as_list()[-1]

        # __init__ rejects max_length=None, so the table always has exactly
        # max_length rows; longer runtime sequences cannot be embedded.
        self._position_embeddings = self.add_weight(
            "PositionEmbeddings",
            shape=[self._max_length, width],
            initializer=self._initializer)

        super(PositionEmbedding, self).build(input_shape)

    def call(self, inputs):
        """Returns position embeddings broadcast to ``tf.shape(inputs)``."""
        input_shape = tf.shape(inputs)
        actual_seq_len = input_shape[self._seq_axis]
        # Only the first actual_seq_len rows of the table are used this call.
        position_embeddings = self._position_embeddings[:actual_seq_len, :]

        # Give the [seq_len, width] slice the input's rank, with size 1 on every
        # axis except seq_axis and the last one, so it broadcasts to the input.
        new_shape = [1 for _ in inputs.get_shape().as_list()]
        new_shape[self._seq_axis] = actual_seq_len
        new_shape[-1] = position_embeddings.get_shape().as_list()[-1]

        position_embeddings = tf.reshape(position_embeddings, new_shape)

        return tf.broadcast_to(position_embeddings, input_shape)

In [3]:
class MultiHeadAttention(tf.keras.layers.Layer) :
    """Multi-head scaled dot-product attention.

    Projects query/key/value with separate Dense layers and splits the last axis
    into `num_heads` heads of size `depth = d_model // num_heads`. For each head
    it computes softmax(Q K^T / sqrt(depth)) V, then merges the heads back and
    applies a final Dense projection.

    Args:
      d_model: Model width; must be divisible by `num_heads`.
      num_heads: Number of attention heads.
      name: Layer name; also used as a prefix for the sub-layer names.
    """

    def __init__(self, d_model, num_heads, name = 'multi_head_attention') :

        super(MultiHeadAttention, self).__init__(name = name)
        self.d_model = d_model
        self.num_heads = num_heads

        self.depth = d_model // num_heads

        assert d_model == (num_heads * self.depth), \
            'd_model ({}) must be divisible by num_heads ({})'.format(d_model, num_heads)

        self.w_q = tf.keras.layers.Dense(self.d_model, name = "{}/Q_WEIGHT".format(self.name))
        self.w_k = tf.keras.layers.Dense(self.d_model, name = "{}/K_WEIGHT".format(self.name))
        self.w_v = tf.keras.layers.Dense(self.d_model, name = "{}/V_WEIGHT".format(self.name))

        self.dense = tf.keras.layers.Dense(d_model, name = "{}/OUTPUT_DENSE".format(self.name))

    def split_head(self, l, batch_size) :
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)."""
        outputs = tf.reshape(l, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(outputs, perm = [0, 2, 1, 3])

    def _merge_heads(self, x, batch_size) :
        """Inverse of split_head: (batch, num_heads, seq_len, depth) -> (batch, seq_len, d_model)."""
        # Move the head axis back next to depth *before* flattening; reshaping the
        # head-major layout directly would interleave positions from different heads.
        x = tf.transpose(x, perm = [0, 2, 1, 3])
        return tf.reshape(x, (batch_size, -1, self.d_model))

    def scaled_dot_product(self, query, key, value, mask) :
        """Attention for already-split heads.

        Args:
          query, key, value: (batch, num_heads, seq_len, depth) tensors.
          mask: None, or a tensor broadcastable to the score shape
            (batch, num_heads, q_len, k_len), e.g. (batch, 1, 1, k_len).
            Positions where mask == 1 are hidden.

        Returns:
          (outputs, attention_score): outputs of shape
          (batch, num_heads, q_len, depth) and the softmax weights.
        """
        # query/key are already per-head, so d_k is `depth` itself.
        d_k = tf.cast(self.depth, dtype = tf.float32)

        dot_score = tf.matmul(query, key, transpose_b = True) / tf.math.sqrt(d_k)

        if mask is not None :
            # Cast so integer (e.g. int32) masks can be combined with float scores.
            dot_score += tf.cast(mask, dot_score.dtype) * -1e9

        attention_score = tf.nn.softmax(dot_score, axis = -1)
        outputs = tf.matmul(attention_score, value)

        return outputs, attention_score

    def call(self, inputs) :
        """Applies attention.

        Args:
          inputs: dict with 'query', 'key', 'value' tensors of shape
            (batch, seq_len, d_model) and an optional 'mask' (see
            scaled_dot_product).

        Returns:
          A (batch, q_len, d_model) tensor.
        """
        query, key, value = inputs['query'], inputs['key'], inputs['value']
        mask = inputs.get('mask')

        batch_size = tf.shape(query)[0]

        # (batch, seq_len, d_model) projections
        query = self.w_q(query)
        key   = self.w_k(key)
        value = self.w_v(value)

        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        query = self.split_head(query, batch_size)
        key = self.split_head(key, batch_size)
        value = self.split_head(value, batch_size)

        outputs, _ = self.scaled_dot_product(query, key, value, mask)

        outputs = self._merge_heads(outputs, batch_size)
        return self.dense(outputs)

In [4]:
def custom_gelu(x):
    """Tanh approximation of the GELU activation.

    gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    The sqrt(2 / pi) coefficient is a plain Python float, so TensorFlow converts
    it to `x`'s dtype. A float32 tensor constant (as produced by `tf.sqrt` on a
    Python number) would fail to combine with float16/bfloat16 inputs under
    mixed precision.
    """
    sqrt_2_over_pi = (2.0 / np.pi) ** 0.5
    return 0.5 * x * (1 + tf.tanh(sqrt_2_over_pi * (x + 0.044715 * tf.pow(x, 3))))

In [5]:
def encoder_layer(dff, d_model, num_heads, dropout, name = 'encoder_layer') :
    """Builds one post-LayerNorm Transformer encoder block as a Keras Model.

    Layout: self-attention -> dropout -> residual add + LayerNorm ->
    Dense(dff, GELU) -> Dense(d_model) -> dropout -> residual add + LayerNorm.

    Args:
      dff: Hidden width of the position-wise feed-forward network.
      d_model: Model width (size of the last input axis).
      num_heads: Number of attention heads.
      dropout: Dropout rate used after attention and after the FFN.
      name: Model name; also the prefix of every sub-layer name.

    Returns:
      A `tf.keras.Model` taking `[inputs, padding_mask]`, with shapes
      `(batch, seq_len, d_model)` and `(batch, 1, 1, key_len)`, and returning a
      `(batch, seq_len, d_model)` tensor.
    """
    def scoped(suffix):
        # Sub-layer names look like "<name>/<suffix>".
        return '{}/{}'.format(name, suffix)

    x = tf.keras.Input(shape=(None, d_model), name=scoped('Input'))
    mask = tf.keras.Input(shape=(1, 1, None), name=scoped('padding_mask'))

    # Self-attention sub-block: residual connection, then LayerNorm.
    self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads, name=scoped('MHA'))
    attended = self_attention({'query': x, 'key': x, 'value': x, 'mask': mask})
    attended = tf.keras.layers.Dropout(rate=dropout, name=scoped('Dropout1'))(attended)
    norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=scoped('LM1'))
    attended = norm1(x + attended)

    # Feed-forward sub-block with the same residual + LayerNorm pattern.
    ffn = tf.keras.layers.Dense(units=dff, activation=custom_gelu, name=scoped('FFN1'))(attended)
    ffn = tf.keras.layers.Dense(units=d_model, name=scoped('FFN2'))(ffn)
    ffn = tf.keras.layers.Dropout(rate=dropout, name=scoped('Dropout2'))(ffn)
    norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=scoped('LM2'))
    block_output = norm2(attended + ffn)

    return tf.keras.Model(inputs=[x, mask], outputs=block_output, name=name)

In [43]:
def qbert_encoder(vocab_size, max_seq_len, num_layers, dff, d_model, num_heads, dropout, name = 'qbert', type_vocab_size = 3):
    """Builds the BERT-style encoder as a Keras functional Model.

    Token + learned position + segment embeddings -> dropout -> LayerNorm ->
    `num_layers` stacked `encoder_layer` blocks -> pooler on the first token.

    Args:
      vocab_size: Size of the token vocabulary.
      max_seq_len: Maximum sequence length (rows of the position table).
      num_layers: Number of stacked encoder blocks; must be >= 1.
      dff: Hidden width of each block's feed-forward network.
      d_model: Model / hidden width.
      num_heads: Attention heads per block.
      dropout: Dropout rate.
      name: Name of the returned Model.
      type_vocab_size: Number of distinct segment ids. Defaults to 3, the value
        that used to be hard-coded here.

    Returns:
      A Model with inputs `[input_ids, padding_mask, segments]` and a dict of
      outputs:
        'sequence_output': output of the last block, (batch, seq_len, d_model);
        'hidden_states': list with the output of every block;
        'pooled_output': tanh(Dense(first-token vector)), (batch, d_model).

    Raises:
      ValueError: If `num_layers` < 1 (there would be no sequence output).
    """
    if num_layers < 1:
        raise ValueError('`num_layers` must be >= 1, got {}'.format(num_layers))

    input_ids = tf.keras.Input(shape = (None, ), name = 'BertInput/inputs')
    padding_mask = tf.keras.Input(shape = (1, 1, None), name = 'BertInput/padding_mask')
    segments = tf.keras.Input(shape = (None, ), name = 'BertInput/segments')

    # Token + learned absolute position + segment (sentence A / B ...) embeddings.
    embeddings = tf.keras.layers.Embedding(vocab_size, d_model, name = 'Bert/Embedding')(input_ids)
    embeddings += PositionEmbedding(max_seq_len)(embeddings)
    embeddings += tf.keras.layers.Embedding(type_vocab_size, d_model, name = 'Bert/SegEmbedding')(segments)

    # NOTE(review): reference BERT applies LayerNorm *before* dropout on the
    # embeddings; the original order is kept here.
    output = tf.keras.layers.Dropout(rate = dropout, name = 'Bert/Dropout_Embedding')(embeddings)
    output = tf.keras.layers.LayerNormalization(epsilon = 1e-6, name = 'Bert/LM_Embedding')(output)

    encode_outputs = []
    for i in range(num_layers) :
        output = encoder_layer(dff = dff
                                , d_model=d_model
                                , num_heads = num_heads
                                , dropout = dropout
                                , name = 'Encoding_Layer_{}'.format(i))([output, padding_mask])
        encode_outputs.append(output)

    sequence_output = encode_outputs[-1]

    # BERT-style pooler: Dense + tanh over the first token (the [CLS] slot in
    # standard BERT inputs).
    pooled_output = tf.keras.layers.Dense(d_model, activation = 'tanh', name = 'Bert/pooler_layer')(sequence_output[:, 0, :])

    outputs = {
        'sequence_output': sequence_output,
        'hidden_states': encode_outputs,
        'pooled_output': pooled_output,
    }

    return tf.keras.Model(inputs = [input_ids, padding_mask, segments], outputs = outputs, name = name)


## ================================================================

In [49]:
def get_bert_models_fn(vocab_size
                       , hidden_size
                       , type_vocab_size
                       , num_layers
                       , num_attention_heads
                       , max_seq_length
                       , max_predictions_per_seq
                       , dropout_rate
                       , inner_dim 
                       , initializer) :
    """Builds the BERT pre-training model (MLM + NSP losses) and its encoder.

    Uses `LmLayer`, `Classification` and `BertPretrainLossAndMetricLayer`, which
    are defined in later cells; run those cells before calling this function.

    Args:
      vocab_size: Token vocabulary size.
      hidden_size: Model width (`d_model` of the encoder).
      type_vocab_size: Currently NOT forwarded to the encoder (unused here).
      num_layers: Number of encoder blocks.
      num_attention_heads: Attention heads per block.
      max_seq_length: Fixed length of the token inputs.
      max_predictions_per_seq: Fixed number of masked-LM slots per example.
      dropout_rate: Dropout rate.
      inner_dim: Feed-forward hidden width (`dff`).
      initializer: Currently unused.

    Returns:
      (pretrain_model, bert_encoder): the pre-training Model, whose dict inputs
      are listed below and whose output is the per-example total loss, and the
      encoder Model it wraps.
    """
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,)
                                           , name='input_word_ids', dtype=tf.int32)

    input_mask = tf.keras.layers.Input(shape=(max_seq_length,)
                                       , name='input_mask', dtype=tf.int32)

    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length,)
                                           , name='input_type_ids', dtype=tf.int32)

    masked_lm_positions = tf.keras.layers.Input(shape=(max_predictions_per_seq,)
                                                ,  name='masked_lm_positions', dtype=tf.int32)

    masked_lm_ids = tf.keras.layers.Input(shape=(max_predictions_per_seq,)
                                          , name='masked_lm_ids', dtype=tf.int32)

    masked_lm_weights = tf.keras.layers.Input(shape=(max_predictions_per_seq,)
                                              , name='masked_lm_weights', dtype=tf.int32)

    next_sentence_labels = tf.keras.layers.Input(shape=(1, )
                                                 , name = 'next_sentence_labels', dtype = tf.int32)

    bert_encoder = qbert_encoder(
        vocab_size, max_seq_length, num_layers,
        inner_dim, hidden_size, num_attention_heads,
        dropout_rate, name = 'qbert_encoder')

    # `input_mask` is assumed to follow the BERT pre-training convention
    # (1 = real token, 0 = padding) -- NOTE(review): confirm against the input
    # pipeline. MultiHeadAttention adds `mask * -1e9` to the scores, i.e. it
    # hides positions where the mask is 1, so invert it (1 = padding) and make
    # it float. Resulting shape: (batch, 1, 1, max_seq_length).
    padding_mask = tf.cast(tf.equal(input_mask, 0), tf.float32)[:, tf.newaxis, tf.newaxis, :]

    encoder_output = bert_encoder([input_word_ids, padding_mask, input_type_ids])

    # Tie the MLM output projection to the token-embedding table. Look the layer
    # up by name: the position of a layer in `Model.layers` is chosen by Keras
    # and index 1 is not guaranteed to be the token embedding.
    embedding = bert_encoder.get_layer('Bert/Embedding').embeddings

    sequence_output = encoder_output['sequence_output']
    cls_output = encoder_output['pooled_output']

    lm_output = LmLayer(embedding = embedding,
                        output = 'logits')([sequence_output, masked_lm_positions])

    sentence_output = Classification(input_width = hidden_size,
                                     num_classes = 2)(cls_output)

    loss_metric_layer = BertPretrainLossAndMetricLayer(vocab_size)

    losses = loss_metric_layer(lm_output_logits = lm_output
                               ,sentence_output_logits = sentence_output
                               ,lm_label_ids = masked_lm_ids 
                               ,lm_label_weights = masked_lm_weights
                               ,sentence_labels = next_sentence_labels)

    inputs = {'input_ids' : input_word_ids,
              'input_mask' : input_mask,
              'segment_ids' : input_type_ids,
              'masked_lm_ids' : masked_lm_ids,
              'masked_lm_positions': masked_lm_positions,
              'masked_lm_weights' : masked_lm_weights,
              'next_sentence_labels' : next_sentence_labels}

    pretrain_model = tf.keras.Model(inputs = inputs, outputs = losses)

    return pretrain_model, bert_encoder

In [50]:
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
    """Returns layer that computes custom loss and metrics for pretraining.

    Combines the weighted masked-LM loss with the optional next-sentence loss.
    It registers accuracy/loss metrics through `add_metric` and returns the
    scalar total loss repeated once per example (shape `(batch,)`).

    Args:
      vocab_size: Vocabulary size; stored in the layer config.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, vocab_size, **kwargs):
        super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
        self._vocab_size = vocab_size
        self.config = {
            'vocab_size': vocab_size,
        }

    def get_config(self):
        """Returns the base config plus `vocab_size`, so the layer can be re-created.

        Without this override Keras cannot serialize a layer whose `__init__`
        takes extra arguments.
        """
        config = super(BertPretrainLossAndMetricLayer, self).get_config()
        config.update(self.config)
        return config

    def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                     lm_example_loss, sentence_output, sentence_labels,
                     next_sentence_loss):
        """Adds MLM accuracy/loss and, when available, NSP accuracy/loss metrics."""
        masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
            lm_labels, lm_output)
        # Weighted mean over masked slots; the epsilon avoids 0/0 when a batch
        # has no weighted positions.
        numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
        denominator = tf.reduce_sum(lm_label_weights) + 1e-5
        masked_lm_accuracy = numerator / denominator
        self.add_metric(
            masked_lm_accuracy, name='MLM_ACC', aggregation='mean')

        self.add_metric(lm_example_loss, name='MLM_LOSS', aggregation='mean')

        if sentence_labels is not None:
            next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
                sentence_labels, sentence_output)
            self.add_metric(
                next_sentence_accuracy,
                name='NSP_ACC',
                aggregation='mean')

        if next_sentence_loss is not None:
            self.add_metric(
                next_sentence_loss, name='NSP_LOSS', aggregation='mean')

    def call(self,
             lm_output_logits,
             sentence_output_logits,
             lm_label_ids,
             lm_label_weights,
             sentence_labels=None):
        """Computes the total pre-training loss and records metrics.

        Args:
          lm_output_logits: Masked-LM logits, presumably
            (batch, num_masked, vocab_size).
          sentence_output_logits: Next-sentence logits, presumably
            (batch, num_classes).
          lm_label_ids: Target token id for each masked slot.
          lm_label_weights: Per-slot weights (presumably 0 for padding slots);
            cast to float32.
          sentence_labels: Optional next-sentence labels; when None, the NSP
            loss and metrics are skipped.

        Returns:
          A float32 tensor of shape (batch,) whose every element equals the
          total loss (MLM loss, plus NSP loss when labels are given).
        """
        lm_label_weights = tf.cast(lm_label_weights, tf.float32)
        lm_output_logits = tf.cast(lm_output_logits, tf.float32)

        lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
            lm_label_ids, lm_output_logits, from_logits=True)
        lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
        lm_denominator_loss = tf.reduce_sum(lm_label_weights)
        # divide_no_nan returns 0 when no slot is weighted.
        mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                                lm_denominator_loss)

        if sentence_labels is not None:
            sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
            sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(sentence_labels, sentence_output_logits, from_logits=True)
            sentence_loss = tf.reduce_mean(sentence_loss)
            loss = mask_label_loss + sentence_loss
        else:
            sentence_loss = None
            loss = mask_label_loss

        # Repeat the scalar loss once per example so the model output has a
        # batch dimension.
        batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        final_loss = tf.fill(batch_shape, loss)

        self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                          mask_label_loss, sentence_output_logits, sentence_labels,
                          sentence_loss)

        return final_loss

In [51]:
class LmLayer(tf.keras.layers.Layer) :
    """Masked-LM head: gathers the masked positions and scores them against the vocab.

    For each masked position, the layer gathers the sequence vector and applies
    Dense(GELU) and LayerNorm. It then multiplies by the transposed embedding
    table and adds a per-token bias.

    Args:
      embedding: Embedding table of shape (vocab_size, hidden_size); its
        transpose is the output projection.
      output: 'logits' to return raw scores, or 'predictions' to return
        log-probabilities (log_softmax over the vocabulary).
      **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).

    Raises:
      ValueError: If `output` is neither 'logits' nor 'predictions'.
    """

    def __init__(self, embedding, output, **kwargs) :

        super(LmLayer, self).__init__(**kwargs)

        self.embedding_table = embedding

        if output not in ('predictions', 'logits'):
            raise ValueError(
                ('Unknown `output` value "%s". `output` can be either "logits" or '
                 '"predictions"') % output)
        self._output_type = output

        self._vocab_size, hidden_size = self.embedding_table.shape
        self.dense = tf.keras.layers.Dense(hidden_size, activation = custom_gelu, name = 'transform/dense')
        self.layer_norm = tf.keras.layers.LayerNormalization(axis = -1, epsilon = 1e-12, name = 'transform/LayerNorm')
        # One bias per vocabulary entry; shape must be a tuple, `(v)` is just an int.
        self.bias = self.add_weight('transform/bias', shape = (self._vocab_size,), initializer = 'zeros', trainable = True)

    def call(self, inputs) :
        """Scores the masked positions.

        Args:
          inputs: `[sequence_output, masked_lm_positions]`: a
            (batch, seq_length, width) tensor and an int (batch, num_masked)
            tensor of positions within each sequence.

        Returns:
          (batch, num_masked, vocab_size) logits, or log-probabilities when
          `output='predictions'`.
        """
        sequence_output, masked_lm_positions = inputs[0], inputs[1]

        sequence_shape = tf.shape(sequence_output)
        batch_size, seq_length = sequence_shape[0], sequence_shape[1]
        width = sequence_output.shape.as_list()[2] or sequence_shape[2]

        # Flatten to (batch * seq_length, width): position p of example b is row
        # b * seq_length + p, so offset each example's positions accordingly.
        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(masked_lm_positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(sequence_output,
                                          [batch_size * seq_length, width])
        gathered_tensor = tf.gather(flat_sequence_tensor, flat_positions)

        hidden = self.dense(gathered_tensor)
        hidden = self.layer_norm(hidden)
        logits = tf.matmul(hidden, self.embedding_table, transpose_b = True)
        logits = tf.nn.bias_add(logits, self.bias)

        masked_positions_length = masked_lm_positions.shape.as_list()[1] or tf.shape(masked_lm_positions)[1]
        logits = tf.reshape(logits, shape = [-1, masked_positions_length, self._vocab_size])

        if self._output_type == 'logits' :
            return logits

        return tf.nn.log_softmax(logits)
        


In [52]:
class Classification(tf.keras.Model):
    """Single Dense classification head over a fixed-width input vector.

    Args:
      input_width: Width of the input vector.
      num_classes: Number of output classes.
      initializer: Kernel initializer for the Dense layer.
      output: 'logits' for raw scores, or 'predictions' for log-probabilities
        (log_softmax).
      **kwargs: Forwarded to `tf.keras.Model`.

    Raises:
      ValueError: If `output` is neither 'logits' nor 'predictions'.
    """

    def __init__(self,
                 input_width,
                 num_classes,
                 initializer='glorot_uniform',
                 output='logits',
                 **kwargs):

        # Validate first; otherwise an unknown value would reach super().__init__
        # with `output_tensors` never assigned.
        if output not in ('logits', 'predictions'):
            raise ValueError(
                ('Unknown `output` value "%s". `output` can be either "logits" or '
                 '"predictions"') % output)

        cls_output = tf.keras.layers.Input(shape=(input_width,), name='cls_output', dtype=tf.float32)

        logits = tf.keras.layers.Dense(
            num_classes,
            activation=None,
            kernel_initializer=initializer,
            name='predictions/transform/logits')(cls_output)

        if output == 'logits':
            output_tensors = logits
        else:
            # 'predictions': log_softmax for every dtype policy; only the dtype
            # of the activation depends on the policy.
            policy = tf.keras.mixed_precision.global_policy()
            if policy.name == 'mixed_bfloat16':
                # b/158514794: bf16 is not stable with post-softmax cross-entropy.
                policy = tf.float32
            output_tensors = tf.keras.layers.Activation(
                tf.nn.log_softmax, dtype=policy)(logits)

        super(Classification, self).__init__(inputs=[cls_output], outputs=output_tensors, **kwargs)
