In [2]:
import tensorflow as tf
import numpy as np

from tensorflow import keras

In [3]:
class PositionEmbedding(tf.keras.layers.Layer):
    """Creates a learned positional embedding.

    The layer owns a trainable table of shape `(max_length, width)`, where
    `width` is the last dimension of the input. When called it slices the
    first `seq_len` rows (with `seq_len` read from the input at `seq_axis`)
    and broadcasts them to the full shape of the input. The input values
    themselves are not used, only the input shape; the caller is expected to
    add the result to its own embeddings.

    Inputs whose sequence length exceeds `max_length` are not supported: the
    slice yields only `max_length` rows and the reshape in `call` fails.

    Example:
    ```python
    position_embedding = PositionEmbedding(max_length=100)
    inputs = tf.keras.Input((100, 32), dtype=tf.float32)
    outputs = position_embedding(inputs)
    ```

    Args:
      max_length: The maximum size of the dynamic sequence. Must not be `None`.
      initializer: The initializer to use for the embedding weights. Defaults
        to "glorot_uniform".
      seq_axis: The axis of the input tensor where we add the embeddings.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.

    Raises:
      ValueError: If `max_length` is `None`.

    Reference: This layer creates a positional embedding as described in
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805).
    """

    def __init__(self, max_length, initializer="glorot_uniform", seq_axis=1, **kwargs):
        super(PositionEmbedding, self).__init__(**kwargs)

        if max_length is None:
            raise ValueError("`max_length` must be an Integer, not `None`.")

        self._max_length = max_length
        self._initializer = tf.keras.initializers.get(initializer)
        self._seq_axis = seq_axis

    def get_config(self):
        """Returns the constructor arguments merged into the base layer config."""
        config = super(PositionEmbedding, self).get_config()
        config.update({
            "max_length": self._max_length,
            "initializer": tf.keras.initializers.serialize(self._initializer),
            "seq_axis": self._seq_axis,
        })
        return config

    def build(self, input_shape):
        """Creates the `(max_length, width)` embedding table."""
        # Depending on the Keras version, `input_shape` arrives either as a
        # `tf.TensorShape` or as a plain tuple; normalise before indexing.
        width = tf.TensorShape(input_shape).as_list()[-1]

        # `name` is passed by keyword: newer Keras releases take `shape` as the
        # first positional parameter of `add_weight`, so a positional name
        # would collide with the `shape=` keyword.
        self._position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self._max_length, width],
            initializer=self._initializer,
        )

        super(PositionEmbedding, self).build(input_shape)

    def call(self, inputs):
        """Returns the position embeddings broadcast to the shape of `inputs`."""
        input_shape = tf.shape(inputs)
        actual_seq_len = input_shape[self._seq_axis]
        position_embeddings = self._position_embeddings[:actual_seq_len, :]

        # Target shape is 1 everywhere except the sequence axis and the
        # feature axis, so the table broadcasts over all remaining axes.
        new_shape = [1] * len(inputs.shape)
        new_shape[self._seq_axis] = actual_seq_len
        new_shape[-1] = position_embeddings.shape[-1]

        position_embeddings = tf.reshape(position_embeddings, new_shape)

        return tf.broadcast_to(position_embeddings, input_shape)

In [4]:
class MultiHeadAttention(tf.keras.layers.Layer) :
    """Multi-head scaled dot-product attention.

    Projects query, key and value with separate Dense layers, splits the
    projections into `num_heads` heads of size `d_model // num_heads`,
    applies scaled dot-product attention per head, merges the heads back and
    applies a final Dense projection.

    The layer is called with a dict:
      'query', 'key', 'value': tensors of shape (batch, seq_len, d_model).
      'mask' (optional): tensor added to the attention logits after being
        multiplied by -1e9, so non-zero entries suppress the corresponding
        key positions. It must broadcast against
        (batch, num_heads, seq_len_q, seq_len_k).

    Args:
      d_model: Width of the projections; must be divisible by `num_heads`.
      num_heads: Number of attention heads.
      name: Layer name.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, d_model, num_heads, name = 'multi_head_attention', **kwargs) :

        super(MultiHeadAttention, self).__init__(name = name, **kwargs)
        self.d_model = d_model
        self.num_heads = num_heads

        self.depth = d_model // num_heads
        assert d_model == (num_heads * self.depth), \
            "`d_model` must be divisible by `num_heads`."

        self.w_q = tf.keras.layers.Dense(self.d_model)
        self.w_k = tf.keras.layers.Dense(self.d_model)
        self.w_v = tf.keras.layers.Dense(self.d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def get_config(self) :
        """Returns the constructor arguments merged into the base layer config."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads,
        })
        return config

    def split_head(self, l, batch_size) :
        """Reshapes (batch, seq_len, d_model) to (batch, num_heads, seq_len, depth)."""
        outputs = tf.reshape(l, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(outputs, perm = [0, 2, 1, 3])

    def scaled_dot_product(self, query, key, value, mask) :
        """Computes softmax(Q K^T / sqrt(depth) + mask * -1e9) V.

        Returns:
          A tuple `(outputs, attention_score)` where `outputs` has shape
          (batch, num_heads, seq_len_q, depth) and `attention_score` holds the
          softmax weights over the key axis.
        """
        d_k = tf.cast(self.depth, dtype = tf.float32)

        # Scale by sqrt of the per-head key dimension. `self.depth` is already
        # d_model // num_heads, so it must not be divided by num_heads again.
        dot_score = tf.matmul(query, key, transpose_b = True) / tf.math.sqrt(d_k)

        if mask is not None :
            # Large negative logits drive masked positions to ~0 after softmax.
            dot_score += mask * -1e9

        attention_score = tf.nn.softmax(dot_score, axis = -1)
        outputs = tf.matmul(attention_score, value)

        return outputs, attention_score

    def call(self, inputs) :
        """Applies multi-head attention to the `inputs` dict (see class docstring)."""
        query, key, value = inputs['query'], inputs['key'], inputs['value']
        # The mask is optional; a missing key is treated like an explicit None.
        mask = inputs.get('mask')

        batch_size = tf.shape(query)[0]

        # inputs : (batch, seq_len, d_model)
        query = self.w_q(query)
        key   = self.w_k(key)
        value = self.w_v(value)

        # q, k, v
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        query = self.split_head(query, batch_size)
        key = self.split_head(key, batch_size)
        value = self.split_head(value, batch_size)

        # scaled_dot_product
        outputs, _ = self.scaled_dot_product(query, key, value, mask)

        # Undo split_head before merging: move the head axis back next to
        # depth so each position's d_model features are the concatenation of
        # that position's heads. Reshaping without this transpose would mix
        # values from different sequence positions.
        # (batch, num_heads, seq_len, depth) -> (batch, seq_len, num_heads, depth)
        outputs = tf.transpose(outputs, perm = [0, 2, 1, 3])
        outputs = tf.reshape(outputs, (batch_size, -1, self.d_model))
        outputs = self.dense(outputs)

        return outputs

In [5]:
def custom_gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

    Args:
      x: Floating-point tensor.

    Returns:
      The activation applied element-wise to `x`.
    """
    # Keep the constant a plain Python float so it adopts the dtype of `x`.
    # A `tf.sqrt(...)` result is a float32 tensor, which cannot be multiplied
    # with float16/float64 inputs.
    sqrt_two_over_pi = float(np.sqrt(2.0 / np.pi))
    inner = sqrt_two_over_pi * (x + 0.044715 * tf.pow(x, 3))
    return 0.5 * x * (1 + tf.tanh(inner))

In [6]:
def encoder_layer(dff, d_model, num_heads, dropout, name = 'encoder_layer') :
    """Builds one Transformer encoder block as a functional Keras model.

    The block is self-attention followed by a two-layer feed-forward network.
    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation.

    Args:
      dff: Hidden width of the feed-forward network.
      d_model: Feature width of the block's input and output.
      num_heads: Number of attention heads.
      dropout: Dropout rate applied after each sub-layer.
      name: Name of the returned model.

    Returns:
      A `tf.keras.Model` taking `[inputs, padding_mask]`, with per-sample
      shapes `(None, d_model)` and `(1, 1, None)`, and producing a tensor
      with last dimension `d_model`.
    """
    hidden_in = tf.keras.Input(shape = (None, d_model), name = 'inputs')
    mask_in = tf.keras.Input(shape = (1, 1, None), name = 'padding_mask')

    # --- Self-attention sub-layer: query, key and value are all the input.
    self_attention = MultiHeadAttention(d_model = d_model, num_heads = num_heads)
    attn_out = self_attention(
        dict(query = hidden_in, key = hidden_in, value = hidden_in, mask = mask_in)
    )
    attn_out = tf.keras.layers.Dropout(rate = dropout)(attn_out)
    attn_norm = tf.keras.layers.LayerNormalization(epsilon = 1e-6)
    attn_out = attn_norm(hidden_in + attn_out)

    # --- Position-wise feed-forward sub-layer: expand to dff, project back.
    ffn_out = tf.keras.layers.Dense(units = dff, activation = custom_gelu)(attn_out)
    ffn_out = tf.keras.layers.Dense(units = d_model)(ffn_out)
    ffn_out = tf.keras.layers.Dropout(rate = dropout)(ffn_out)
    ffn_norm = tf.keras.layers.LayerNormalization(epsilon = 1e-6)
    block_out = ffn_norm(attn_out + ffn_out)

    return tf.keras.Model(inputs = [hidden_in, mask_in], outputs = block_out, name = name)

In [44]:
def qbert_model(vocab_size, max_seq_len, num_layers, dff, d_model, num_heads, dropout, name = 'qbert'):
    """Builds a BERT-style encoder as a functional Keras model.

    Token, position and segment embeddings are summed, passed through
    dropout, then through `num_layers` stacked encoder blocks.

    Args:
      vocab_size: Size of the token vocabulary.
      max_seq_len: Maximum sequence length supported by the position table.
      num_layers: Number of encoder blocks; must be at least 1.
      dff: Hidden width of each block's feed-forward network.
      d_model: Embedding / hidden width.
      num_heads: Number of attention heads per block.
      dropout: Dropout rate.
      name: Name of the returned model.

    Returns:
      A `tf.keras.Model` taking `[inputs, padding_mask, segments]` and
      returning a dict with:
        'pooled_output': Dense(d_model) applied to the final encoder output.
        'sequence_output': output of the final encoder block.
        'hidden_states': list with the output of every encoder block.

    Raises:
      ValueError: If `num_layers` is less than 1.
    """
    if num_layers < 1:
        raise ValueError("`num_layers` must be at least 1.")

    # Named `token_ids` locally to avoid shadowing the builtin `input`; the
    # Keras input name is unchanged.
    token_ids = tf.keras.Input(shape = (None, ), name = 'inputs')
    padding_mask = tf.keras.Input(shape = (1, 1, None), name = 'padding_mask')
    segments = tf.keras.Input(shape = (None, ), name = 'segments')

    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(token_ids)
    embeddings += PositionEmbedding(max_seq_len)(embeddings)
    embeddings += tf.keras.layers.Embedding(3, d_model)(segments) # sentence A or sentence B

    # TODO: the original code carried a "LayerNorm +" marker here; BERT
    # normalises the summed embeddings before dropout. Not added - confirm intent.
    output = tf.keras.layers.Dropout(rate = dropout)(embeddings)

    encode_outputs = []
    for i in range(num_layers) :
        output = encoder_layer(dff = dff
                                , d_model = d_model
                                , num_heads = num_heads
                                , dropout = dropout
                                , name = 'encoding_layer_{}'.format(i))([output, padding_mask])
        encode_outputs.append(output)

    sequence_output = encode_outputs[-1]

    # The pooler reads the FINAL encoder output. Taking it from the first
    # block would leave every later block disconnected from 'pooled_output'.
    # NOTE(review): BERT's pooler also selects only the first token and
    # applies tanh; left out here to keep the existing output shape - confirm.
    pooler_output = tf.keras.layers.Dense(d_model)(sequence_output)

    outputs = {
        'pooled_output': pooler_output,
        'sequence_output': sequence_output,
        'hidden_states': encode_outputs,
    }

    return tf.keras.Model(inputs = [token_ids, padding_mask, segments], outputs = outputs, name = name)


## TEST CODE