## I referred to the following webpages for the implementation.
- Implementation of Transformer<br>
https://qiita.com/halhorn/items/c91497522be27bde17ce<br>
https://github.com/kpot/keras-transformer/tree/master/keras_transformer<br>
https://github.com/Lsdefine/attention-is-all-you-need-keras<br>
- Usage of "\_\_call\_\_" method<br>
https://qiita.com/kyo-bad/items/439d8cc3a0424c45214a

In [None]:
import warnings

# Suppress every Python warning for the rest of this session (the 'ignore'
# action with no category/message filter matches all warnings). Note that this
# also hides deprecation notices from the libraries imported below.
warnings.filterwarnings('ignore')

In [12]:
import numpy as np
import math

import tensorflow as tf

from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Layer, Embedding, Input, Reshape, Lambda, Add
from keras import backend as K
from keras.initializers import RandomNormal
from keras.utils import plot_model
from keras.optimizers import Adam
from keras.callbacks import Callback

In [8]:
# --- Data / vocabulary -------------------------------------------------------
vocab_size = 8_000
MAX_LEN = 716
PAD_ID = 0             # token id treated as padding by the masking code
class_num = 9
NUM_TRAIN, NUM_TEST = 5_893, 1_474

# --- Model -------------------------------------------------------------------
d_model = 512
negative_inf = -1.0e9  # added to attention scores at masked positions

# --- Training settings (not referenced in the cells shown; presumably used later)
warmup_steps = 4_000
batch_size = 16
epochs = 600

In [10]:
class MultiheadAttention():
    """Multi-head scaled dot-product attention assembled from Keras layers.

    Calling an instance projects ``query`` into Q and ``memory`` into K and V
    with bias-free Dense layers, splits each projection into ``head_num``
    heads, computes masked softmax attention per head (with dropout on the
    attention weights), merges the heads back together and applies a final
    bias-free Dense projection.

    Args:
        max_len: Sequence length used when reshaping into heads. Both ``query``
            and ``memory`` are reshaped with this length, so they must share it.
        hidden_dim: Model width. Has to be a multiple of ``head_num``.
        head_num: Number of attention heads.
        dropout_rate: Dropout rate applied to the normalized attention weights.
    """

    def __init__(self, max_len, hidden_dim=512, head_num=8, dropout_rate=0.1, *args, **kwargs):
        self.max_len = max_len
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.dropout_rate = dropout_rate

        self.q_dense_layer = Dense(hidden_dim, use_bias=False)
        self.k_dense_layer = Dense(hidden_dim, use_bias=False)
        self.v_dense_layer = Dense(hidden_dim, use_bias=False)
        self.output_dense_layer = Dense(hidden_dim, use_bias=False)
        self.attention_dropout_layer = Dropout(dropout_rate)

    def split_heads(self, x):
        """Reshape ``(batch, max_len, hidden_dim)`` into ``(batch, head_num, max_len, depth)``.

        ``depth`` is ``hidden_dim // head_num``.
        """
        def reshape(x):
            x = tf.reshape(x, [-1, self.max_len, self.head_num, self.hidden_dim // self.head_num])
            return tf.transpose(x, [0, 2, 1, 3])

        return Lambda(reshape)(x)

    def combine_heads(self, heads):
        """Inverse of ``split_heads``: ``(batch, head_num, max_len, depth)`` -> ``(batch, max_len, hidden_dim)``."""
        def reshape(x):
            # Move the head axis back next to depth BEFORE flattening; reshaping
            # the untransposed tensor would interleave positions across heads.
            x = tf.transpose(x, [0, 2, 1, 3])
            return tf.reshape(x, [-1, self.max_len, self.hidden_dim])

        return Lambda(reshape)(heads)

    def __call__(self, query, memory, attention_mask):
        """Attend from ``query`` over ``memory``.

        Args:
            query: Already-embedded query sequence.
            memory: Already-embedded key/value sequence.
            attention_mask: Tensor broadcastable to the score shape
                ``(batch, head_num, query_len, memory_len)``. Positions where it
                is non-zero/True get ``negative_inf`` added to their score and
                therefore receive (near-)zero attention weight.

        Returns:
            Tensor of shape ``(batch, max_len, hidden_dim)``.
        """
        q = self.q_dense_layer(query)
        k = self.k_dense_layer(memory)
        v = self.v_dense_layer(memory)

        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        # Scale queries by 1/sqrt(depth) for scaled dot-product attention.
        depth_inside_each_head = self.hidden_dim // self.head_num
        q = Lambda(lambda x: x * (depth_inside_each_head ** -0.5))(q)

        # q: (batch, head_num, query_len, depth), k: (batch, head_num, memory_len, depth)
        # score = q @ k^T: (batch, head_num, query_len, memory_len)
        score = Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([q, k])

        # The mask may be boolean. K.cast works on symbolic tensors, whereas
        # K.cast_to_floatx only converts NumPy data.
        neg_inf_for_pads = Lambda(lambda x: K.cast(x, K.floatx()) * negative_inf)(attention_mask)
        masked_score = Add()([score, neg_inf_for_pads])

        normalized_score = Activation("softmax")(masked_score)
        normalized_score = self.attention_dropout_layer(normalized_score)

        # (batch, head_num, query_len, memory_len) @ (batch, head_num, memory_len, depth)
        # -> (batch, head_num, query_len, depth)
        attention_weighted_output = Lambda(lambda x: tf.matmul(x[0], x[1]))([normalized_score, v])
        attention_weighted_output = self.combine_heads(attention_weighted_output)
        return self.output_dense_layer(attention_weighted_output)

In [11]:
# SelfAttention inherits from MultiheadAttention and feeds one tensor in as both
# the query and the memory, so a sequence attends over itself.
class SelfAttention(MultiheadAttention):
    """Multi-head attention whose query and memory are the same tensor.

    Construction arguments are exactly those of ``MultiheadAttention``.
    """

    def __call__(self, query, attention_mask):
        memory = query
        return super().__call__(query, memory, attention_mask=attention_mask)

In [5]:
class PositionwiseFeedForwardNetwork():
    """Position-wise two-layer MLP.

    ``Dense(4 * hidden_dim, relu) -> Dropout(dropout_rate) -> Dense(hidden_dim, linear)``.
    The ReLU layer widens the representation so the network can learn a richer
    transform, and the linear layer projects back to ``hidden_dim``.
    """

    def __init__(self, hidden_dim, dropout_rate, *args, **kwargs):
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        expanded_dim = 4 * hidden_dim
        self.first_dense_layer = Dense(expanded_dim, activation="relu", use_bias=True)
        self.second_dense_layer = Dense(hidden_dim, activation="linear", use_bias=True)
        self.dropout_layer = Dropout(dropout_rate)

    def __call__(self, inputs):
        stages = (self.first_dense_layer, self.dropout_layer, self.second_dense_layer)
        out = inputs
        for stage in stages:
            out = stage(out)
        return out

In [6]:
class LayerNormalization(Layer):
    """Layer normalization with a learned per-feature scale and shift.

    Normalizes ``inputs`` over ``axis`` (default: the last axis) as
    ``(x - mean) / (std + epsilon) * scale + shift``. ``scale`` starts at ones
    and ``shift`` at zeros.

    Args:
        axis: Integer axis to normalize over. It is stored in the config and is
            used for the statistics and for the size of ``scale``/``shift``.
    """

    def __init__(self, axis=-1, **kwargs):
        self.axis = axis
        super(LayerNormalization, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config["axis"] = self.axis
        return config

    def build(self, input_shape):
        hidden_dim = input_shape[self.axis]
        self.scale = self.add_weight("layer_norm_scale", shape=[hidden_dim],
                                     initializer="ones")
        self.shift = self.add_weight("layer_norm_shift", shape=[hidden_dim],
                                     initializer="zeros")
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, epsilon=1e-6):
        mean = K.mean(inputs, axis=[self.axis], keepdims=True)
        variance = K.var(inputs, axis=[self.axis], keepdims=True)
        normalized_inputs = (inputs - mean) / (K.sqrt(variance) + epsilon)

        # scale/shift are 1-D over the normalized axis. Reshape them to
        # (1, ..., dim, ..., 1) so they broadcast along that axis; for the last
        # axis this gives the same result as plain broadcasting.
        broadcast_shape = [1] * K.ndim(inputs)
        broadcast_shape[self.axis] = -1
        scale = K.reshape(self.scale, broadcast_shape)
        shift = K.reshape(self.shift, broadcast_shape)
        return normalized_inputs * scale + shift

    def compute_output_shape(self, input_shape):
        return input_shape

In [7]:
class PreLayerNormPostResidualConnectionWrapper():
    """Residual wrapper with pre-layer-normalization.

    Computes ``inputs + Dropout(layer(LayerNormalization(inputs), ...))``.

    Args:
        layer: Callable sublayer (attention or feed-forward) to wrap.
        dropout_rate: Dropout rate applied to the sublayer output before the
            residual addition.
    """

    def __init__(self, layer, dropout_rate, *args, **kwargs):
        self.layer = layer
        self.layer_norm = LayerNormalization()
        self.dropout_layer = Dropout(dropout_rate)

    def __call__(self, inputs=None, *args, **kwargs):
        """Apply the wrapped sublayer with a residual connection.

        The sublayer input can be given positionally or as ``query=...``; the
        Encoder and Decoder call it the latter way. All remaining positional
        and keyword arguments (e.g. ``memory``, ``attention_mask``) are
        forwarded to the wrapped layer.

        Raises:
            TypeError: if no sublayer input is supplied.
        """
        if inputs is None:
            if "query" not in kwargs:
                raise TypeError("missing sublayer input: pass it positionally or as query=")
            inputs = kwargs.pop("query")
        x = self.layer_norm(inputs)
        x = self.layer(x, *args, **kwargs)
        outputs = self.dropout_layer(x)
        return Add()([inputs, outputs])

In [8]:
class AddPositionalEncoding(Layer):
    """Add the fixed sinusoidal positional encoding to a ``(batch, length, dim)`` input.

    Feature ``j`` at position ``pos`` gets ``sin(pos / 10000 ** (2*(j//2) / dim) + phase_j)``.
    ``phase_j`` is 0 for even ``j`` and pi/2 for odd ``j``. Because
    ``cos(x) == sin(x + pi/2)``, one ``sin`` call yields the usual
    ``[sin, cos, sin, cos, ...]`` pattern.
    """

    def call(self, inputs):
        dtype = inputs.dtype
        batch_size, seq_len, depth = tf.unstack(tf.shape(inputs))

        feature_idx = K.arange(depth)
        # 0, 0, 2, 2, 4, 4, ... (e.g. up to 510 when depth is 512)
        paired_idx = feature_idx // 2 * 2
        # 0 for even features, pi/2 for odd ones.
        phase_row = K.cast(feature_idx % 2, dtype) * math.pi / 2

        exponent_grid = K.tile(K.expand_dims(paired_idx, 0), [seq_len, 1])
        denominator = K.pow(10000.0, K.cast(exponent_grid / depth, dtype))
        phase_grid = K.tile(tf.expand_dims(phase_row, 0), [seq_len, 1])
        position_grid = K.cast(K.tile(K.expand_dims(K.arange(seq_len), 1), [1, depth]), dtype)

        encoding_table = K.sin(position_grid / denominator + phase_grid)
        return inputs + K.tile(K.expand_dims(encoding_table, 0), [batch_size, 1, 1])

    def compute_output_shape(self, input_shape):
        return input_shape

In [9]:
class MakeZeroPads(Layer):
    """Zero out padded positions of an embedded sequence, then scale by ``sqrt(emb_dim)``.

    Two call forms are accepted:

    * ``[embeddings, token_ids]`` (recommended): every embedding vector whose
      token id equals ``PAD_ID`` is zeroed.
    * ``embeddings`` alone (backward-compatible form): the ``PAD_ID``
      comparison is made against the embedding values themselves. It therefore
      only zeroes individual components that are exactly equal to ``PAD_ID``
      and does not reliably mask padded positions.

    ``seq_len``, ``vocab_size`` and ``data_type`` are accepted but not used.
    """

    def __init__(self, seq_len, vocab_size, emb_dim, data_type="float32", *args, **kwargs):
        self.emb_dim = emb_dim
        super(MakeZeroPads, self).__init__(*args, **kwargs)

    def call(self, inputs):
        if isinstance(inputs, (list, tuple)):
            embeddings, token_ids = inputs
            # Assumes token_ids is (batch, length) and embeddings is
            # (batch, length, emb_dim) — TODO confirm. The trailing axis lets the
            # per-token mask broadcast over the embedding dimension.
            keep_mask = K.expand_dims(tf.not_equal(token_ids, PAD_ID), -1)
        else:
            embeddings = inputs
            keep_mask = tf.not_equal(inputs, PAD_ID)
        pads_masked_embedding = embeddings * K.cast(keep_mask, embeddings.dtype)
        return pads_masked_embedding * (self.emb_dim ** 0.5)

    def compute_output_shape(self, input_shape):
        # For the list form, the output has the embeddings' shape.
        if isinstance(input_shape, list):
            return input_shape[0]
        return input_shape

In [10]:
class Encoder():
    """Transformer encoder stack.

    Pipeline: token embedding -> MakeZeroPads -> AddPositionalEncoding ->
    dropout -> ``stack_num`` x (self-attention, position-wise FFN), each
    sublayer wrapped in PreLayerNormPostResidualConnectionWrapper -> final
    LayerNormalization.
    """

    def __init__(self, vocab_size, max_len, stack_num, head_num, emb_dim, dropout_rate, *args, **kwargs):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.stack_num = stack_num
        self.head_num = head_num
        self.emb_dim = emb_dim
        self.dropout_rate = dropout_rate

        # Embedding weights start from a normal distribution with mean 0 and
        # stddev emb_dim ** -0.5.
        initializer = RandomNormal(mean=0.0, stddev=self.emb_dim ** -0.5)
        self.embedding_layer = Embedding(self.vocab_size, self.emb_dim,
                                         embeddings_initializer=initializer)
        self.make_zero_pads_layer = MakeZeroPads(self.max_len, vocab_size, emb_dim)
        self.add_pos_enc_layer = AddPositionalEncoding()
        self.input_dropout_layer = Dropout(dropout_rate)

        self.attention_block_list = [self._build_block(dropout_rate) for _ in range(stack_num)]
        self.output_layer_norm = LayerNormalization()

    def _build_block(self, dropout_rate):
        """Create one ``[self-attention, feed-forward]`` pair of residual-wrapped sublayers."""
        attention = SelfAttention(self.max_len, self.emb_dim, self.head_num, self.dropout_rate)
        feed_forward = PositionwiseFeedForwardNetwork(self.emb_dim, self.dropout_rate)
        return [
            PreLayerNormPostResidualConnectionWrapper(attention, dropout_rate),
            PreLayerNormPostResidualConnectionWrapper(feed_forward, dropout_rate),
        ]

    def __call__(self, inputs, self_attention_mask):
        embedded = self.embedding_layer(inputs)
        x = self.input_dropout_layer(self.add_pos_enc_layer(self.make_zero_pads_layer(embedded)))
        for attention_sublayer, feed_forward_sublayer in self.attention_block_list:
            x = attention_sublayer(query=x, attention_mask=self_attention_mask)
            x = feed_forward_sublayer(x)
        return self.output_layer_norm(x)

In [13]:
class Decoder():
    """Transformer decoder stack.

    Pipeline: token embedding -> MakeZeroPads -> AddPositionalEncoding ->
    dropout -> ``stack_num`` x (self-attention, source-target attention,
    position-wise FFN), each sublayer wrapped in
    PreLayerNormPostResidualConnectionWrapper -> final LayerNormalization ->
    bias-free Dense projection to ``vocab_size`` unnormalized scores.
    """

    def __init__(self, vocab_size, stack_num, head_num, emb_dim, dropout_rate, max_len, *args, **kwargs):
        # vocab_size and max_len must be stored before they are read below.
        self.vocab_size = vocab_size
        self.stack_num = stack_num
        self.head_num = head_num
        self.emb_dim = emb_dim
        self.dropout_rate = dropout_rate
        self.max_len = max_len

        self.embedding_layer = Embedding(self.vocab_size,
                                         self.emb_dim,
                                         embeddings_initializer=RandomNormal(mean=0.0, stddev=self.emb_dim**-0.5)
                                         )
        self.make_zero_pads_layer = MakeZeroPads(self.max_len, vocab_size, emb_dim)
        self.add_pos_enc_layer = AddPositionalEncoding()
        self.input_dropout_layer = Dropout(dropout_rate)

        self.attention_block_list = []
        for _ in range(stack_num):
            self_attention_layer = SelfAttention(self.max_len, self.emb_dim, self.head_num, self.dropout_rate)
            sou_tar_attn_layer = MultiheadAttention(self.max_len, self.emb_dim, self.head_num, self.dropout_rate)
            pffn_layer = PositionwiseFeedForwardNetwork(self.emb_dim, self.dropout_rate)
            self.attention_block_list.append([
                PreLayerNormPostResidualConnectionWrapper(self_attention_layer, dropout_rate),
                PreLayerNormPostResidualConnectionWrapper(sou_tar_attn_layer, dropout_rate),
                PreLayerNormPostResidualConnectionWrapper(pffn_layer, dropout_rate)
            ])
        self.output_layer_norm = LayerNormalization()

        # TODO: tie this output projection to the embedding weights (shared input/output embedding).
        self.output_dense_layer = Dense(vocab_size, use_bias=False)

    def __call__(self, inputs, encoder_outputs, self_attention_mask, sou_tar_attn_mask, train_flag=None):
        """Decode ``inputs`` while attending over ``encoder_outputs``.

        Args:
            inputs: Decoder token ids.
            encoder_outputs: Encoder output used as memory by source-target attention.
            self_attention_mask: Mask for decoder self-attention.
            sou_tar_attn_mask: Mask for attention over the encoder output.
            train_flag: Currently unused; accepted for API compatibility.

        Returns:
            Output of the final Dense layer (no activation), one ``vocab_size``
            vector per position.
        """
        x = self.embedding_layer(inputs)
        x = self.make_zero_pads_layer(x)
        x = self.add_pos_enc_layer(x)
        x = self.input_dropout_layer(x)

        for self_attention_layer, sou_tar_attn_layer, pffn_layer in self.attention_block_list:
            x = self_attention_layer(query=x, attention_mask=self_attention_mask)
            x = sou_tar_attn_layer(query=x, memory=encoder_outputs, attention_mask=sou_tar_attn_mask)
            x = pffn_layer(x)

        x = self.output_layer_norm(x)
        return self.output_dense_layer(x)

In [None]:
class Transformer():
    """Encoder-decoder Transformer composed of ``Encoder`` and ``Decoder``.

    Calling an instance builds the padding and look-ahead masks from the token
    ids, encodes ``encoder_input`` and decodes ``decoder_input`` against the
    encoder output.
    """

    def __init__(self, vocab_size, stack_num, head_num, emb_dim, dropout_rate, max_len, *args, **kwargs):
        self.vocab_size = vocab_size
        self.stack_num = stack_num
        self.head_num = head_num
        self.emb_dim = emb_dim
        self.dropout_rate = dropout_rate
        self.max_len = max_len

        self.encoder = Encoder(vocab_size=vocab_size,
                               stack_num=stack_num,
                               head_num=head_num,
                               emb_dim=emb_dim,
                               dropout_rate=dropout_rate,
                               max_len=max_len
                               )
        self.decoder = Decoder(vocab_size=vocab_size,
                               stack_num=stack_num,
                               head_num=head_num,
                               emb_dim=emb_dim,
                               dropout_rate=dropout_rate,
                               max_len=max_len
                               )

    def create_enc_attention_mask(self, encoder_input):
        """Return a boolean mask, True at PAD positions, shaped ``(batch, 1, 1, length)``."""
        batch_size, length = tf.unstack(tf.shape(encoder_input))
        pad_array = tf.equal(encoder_input, PAD_ID)
        return tf.reshape(pad_array, [batch_size, 1, 1, length])

    def create_dec_self_attention_mask(self, decoder_input):
        """Return a boolean mask that is True at PAD positions and at future time steps.

        The result has shape ``(batch, 1, length, length)``; entry ``[.., i, j]``
        is True when key ``j`` is padding or ``j > i``.
        """
        batch_size, length = tf.unstack(tf.shape(decoder_input))
        pad_array = tf.equal(decoder_input, PAD_ID)
        pad_array = tf.reshape(pad_array, [batch_size, 1, 1, length])

        # band_part(..., -1, 0) keeps the lower triangle; its negation marks j > i.
        autoregression_array = tf.logical_not(
            tf.linalg.band_part(tf.ones([length, length], dtype=tf.bool), -1, 0)
        )
        autoregression_array = tf.reshape(autoregression_array, [1, 1, length, length])
        return tf.logical_or(pad_array, autoregression_array)

    def __call__(self, encoder_input, decoder_input, train_flag=True):
        """Run the encoder and decoder and return the decoder output.

        Args:
            encoder_input: Source token ids.
            decoder_input: Target token ids.
            train_flag: Forwarded to the decoder.
        """
        # Build the masks inside Lambda layers so the resulting tensors carry
        # Keras history and can take part in a functional keras Model.
        enc_attention_mask = Lambda(self.create_enc_attention_mask)(encoder_input)
        dec_self_attention_mask = Lambda(self.create_dec_self_attention_mask)(decoder_input)

        encoder_output = self.encoder(
            encoder_input,
            self_attention_mask=enc_attention_mask
        )

        decoder_output = self.decoder(
            decoder_input,
            encoder_output,
            self_attention_mask=dec_self_attention_mask,
            sou_tar_attn_mask=enc_attention_mask,
            train_flag=train_flag
        )

        return decoder_output