# This implementation is based on [**this article**](https://qiita.com/halhorn/items/c91497522be27bde17ce).

In [2]:
import numpy as np

import tensorflow as tf

from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Layer
from keras import backend as K

In [2]:
class MultiheadAttention(Model):
    """Multi-head scaled dot-product attention.

    ``query`` and ``memory`` are projected to ``hidden_dim`` features, split
    into ``head_num`` heads, attended per head, recombined and passed through
    a final bias-free dense projection.

    Args:
        hidden_dim: Width of the q/k/v/output projections. Must be a
            multiple of ``head_num``.
        head_num: Number of attention heads.
        dropout_rate: Dropout rate applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim=512, head_num=8, dropout_rate=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if hidden_dim % head_num != 0:
            raise ValueError(
                "hidden_dim ({}) must be a multiple of head_num ({})".format(
                    hidden_dim, head_num
                )
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.dropout_rate = dropout_rate

        self.q_dense_layer = Dense(hidden_dim, use_bias=False, name="q_dense_layer")
        self.k_dense_layer = Dense(hidden_dim, use_bias=False, name="k_dense_layer")
        self.v_dense_layer = Dense(hidden_dim, use_bias=False, name="v_dense_layer")
        self.output_dense_layer = Dense(hidden_dim, use_bias=False, name="output_dense_layer")
        self.attention_dropout_layer = Dropout(dropout_rate, name="attention_dropout_layer")

    def split_heads(self, x):
        """Reshape [batch, length, hidden_dim] to [batch, head_num, length, depth]."""
        # Use the dynamic shape so unknown batch/length dimensions are supported.
        dynamic_shape = tf.shape(x)
        batch_size, max_len = dynamic_shape[0], dynamic_shape[1]
        x = tf.reshape(
            x, [batch_size, max_len, self.head_num, self.hidden_dim // self.head_num]
        )
        return tf.transpose(x, [0, 2, 1, 3])

    def combine_heads(self, heads):
        """Inverse of ``split_heads``: back to [batch, length, hidden_dim]."""
        dynamic_shape = tf.shape(heads)
        batch_size, max_len = dynamic_shape[0], dynamic_shape[2]
        heads = tf.transpose(heads, [0, 2, 1, 3])
        return tf.reshape(heads, [batch_size, max_len, self.hidden_dim])

    def call(self, query, memory, attention_mask, train_flag):
        """Attend from ``query`` over ``memory``.

        Args:
            query: Embedded query sequence; projected to the attention queries.
            memory: Embedded sequence projected to the keys and values.
            attention_mask: Tensor added to the scores after being multiplied
                by the most negative float, so non-zero entries are suppressed.
                It must broadcast against the per-head score tensor.
            train_flag: Forwarded as ``training`` to the attention dropout.

        Returns:
            The output projection of the recombined, attention-weighted values.
        """
        # query and memory are already embedded vectors for every token.
        q = self.q_dense_layer(query)
        k = self.k_dense_layer(memory)
        v = self.v_dense_layer(memory)

        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        # Scaled dot-product: scale the queries by 1/sqrt(depth per head).
        depth_inside_each_head = self.hidden_dim // self.head_num
        q *= depth_inside_each_head ** -0.5

        score = tf.matmul(q, k, transpose_b=True)
        # dtype.min acts as -inf, so masked positions get ~0 weight after softmax.
        score += tf.cast(attention_mask, score.dtype) * score.dtype.min
        normalized_score = tf.nn.softmax(score, name="attention_weight")
        normalized_score = self.attention_dropout_layer(normalized_score, training=train_flag)
        attention_weighted_output = tf.matmul(normalized_score, v)
        attention_weighted_output = self.combine_heads(attention_weighted_output)
        return self.output_dense_layer(attention_weighted_output)

In [3]:
# SelfAttention inherits from MultiheadAttention and passes the same tensor as both query and memory.
class SelfAttention(MultiheadAttention):
    """Multi-head attention in which the query sequence is also the memory."""

    def call(self, query, attention_mask, train_flag):
        """Run the parent attention with ``query`` used for both inputs."""
        return super().call(
            query=query,
            memory=query,
            attention_mask=attention_mask,
            train_flag=train_flag,
        )

In [4]:
class PositionwiseFeedforwardNetwork(Model):
    """Two dense layers with dropout in between.

    The first layer widens the features to ``hidden_dim * 4`` with a ReLU;
    the second projects back to ``hidden_dim`` with a linear activation.
    """

    def __init__(self, hidden_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        expanded_dim = hidden_dim * 4
        self.first_dense_layer = Dense(
            expanded_dim,
            use_bias=True,
            activation="relu",
            name="first_dense_layer",
        )
        self.second_dense_layer = Dense(
            hidden_dim,
            use_bias=True,
            activation="linear",
            name="second_dense_layer",
        )
        self.dropout_layer = Dropout(dropout_rate, name="PFFN_dropout")

    def call(self, input, train_flag):
        """Apply expand -> dropout -> project; ``train_flag`` toggles the dropout."""
        # Non-linear widening gives the network room to learn; the linear
        # projection then restores the original hidden width.
        widened = self.first_dense_layer(input)
        regularized = self.dropout_layer(widened, training=train_flag)
        return self.second_dense_layer(regularized)

In [5]:
class LayerNormalization(Layer):
    """Normalize the last axis to zero mean / unit variance, then scale and shift."""

    def build(self, input_shape):
        """Create the per-feature ``scale`` (ones) and ``shift`` (zeros) weights."""
        hidden_dim = input_shape[-1]
        # Pass the name by keyword: the first positional parameter of
        # add_weight is not the name in every Keras version.
        self.scale = self.add_weight(name="layer_norm_scale", shape=[hidden_dim],
                                     initializer="ones")
        self.shift = self.add_weight(name="layer_norm_shift", shape=[hidden_dim],
                                     initializer="zeros")
        super().build(input_shape)

    def call(self, input, epsilon=1e-6):
        """Return the normalized ``input`` multiplied by ``scale`` plus ``shift``.

        Args:
            input: Tensor whose last axis is normalized.
            epsilon: Small constant added to the variance for numerical
                stability.
        """
        mean = K.mean(input, axis=-1, keepdims=True)
        variance = K.var(input, axis=-1, keepdims=True)
        # Keep epsilon inside the square root: sqrt has an unbounded gradient
        # at zero, so a constant feature vector would otherwise yield NaN
        # gradients.
        normalized_input = (input - mean) / K.sqrt(variance + epsilon)
        return normalized_input * self.scale + self.shift

In [6]:
class PreLayerNormPostResidualConnectionWrapper(Model):
    """Wrap a sub-layer as ``input + dropout(layer(layer_norm(input)))``.

    Args:
        layer: The sub-layer to wrap. It is called with the normalized input,
            any extra positional/keyword arguments, and ``train_flag``.
        dropout_rate: Dropout rate applied to the sub-layer output.
    """

    def __init__(self, layer, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer = layer
        self.layer_norm = LayerNormalization()
        self.dropout_layer = Dropout(dropout_rate)

    def call(self, input, train_flag, *args, **kwargs):
        """Apply layer norm, the wrapped layer, dropout, then the residual add.

        Args:
            input: Tensor fed to the layer norm and added back as the residual.
            train_flag: Passed to the wrapped layer and to the dropout.
            *args, **kwargs: Forwarded unchanged to the wrapped layer.
        """
        x = self.layer_norm(input)
        # The sub-layers in this file name their training switch `train_flag`,
        # so forward it under that name rather than `training`.
        x = self.layer(x, *args, train_flag=train_flag, **kwargs)
        output = self.dropout_layer(x, training=train_flag)
        return input + output

In [None]:
class AddPositionalEncoding(Layer):
    """Add sinusoidal positional encodings to an embedded sequence."""

    def call(self, input):
        """Return ``input`` plus a sine/cosine position signal.

        Even feature indices receive ``sin(pos / 10000**(i / emb_dim))`` and
        odd indices the matching cosine, where ``i`` is the even index of the
        pair.

        Args:
            input: Tensor unpacked as (batch, max_len, emb_dim).
        """
        data_type = input.dtype
        dynamic_shape = tf.shape(input)
        max_len, emb_dim = dynamic_shape[1], dynamic_shape[2]

        emb_dim_counter = tf.range(emb_dim)
        # 0, 0, 2, 2, 4, 4, ...: each sin/cos pair shares one wavelength.
        paired_counter = emb_dim_counter // 2 * 2
        exponent = tf.cast(paired_counter, data_type) / tf.cast(emb_dim, data_type)
        wavelength = tf.pow(tf.cast(10000.0, data_type), exponent)  # [emb_dim]

        # cos(x) == sin(x + pi/2), so odd indices only need a phase offset.
        phase = tf.cast(emb_dim_counter % 2, data_type) * (np.pi / 2)  # [emb_dim]
        position = tf.cast(tf.range(max_len), data_type)  # [max_len]

        # Broadcast to [max_len, emb_dim].
        angle = (
            tf.expand_dims(position, 1) / tf.expand_dims(wavelength, 0)
            + tf.expand_dims(phase, 0)
        )
        positional_encoding = tf.sin(angle)

        # The leading axis of size 1 broadcasts over the batch.
        return input + tf.expand_dims(positional_encoding, 0)

In [15]:
# Scratch check: materialize range(5) as a list and display it.
test = [*range(5)]
test

[0, 1, 2, 3, 4]

In [7]:
# Scratch check: build 0..39, print the flat shape, then view it as (5, 4, 2).
test = np.arange(40)
print(np.shape(test))
test = np.reshape(test, (5, 4, 2))
test.dtype

(40,)


dtype('int64')

In [22]:
# Scratch check: regroup the same values as (5, 2, 4, 1) and display them.
np.reshape(test, (5, 2, 4, 1))

array([[[[ 0],
         [ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6],
         [ 7]]],


       [[[ 8],
         [ 9],
         [10],
         [11]],

        [[12],
         [13],
         [14],
         [15]]],


       [[[16],
         [17],
         [18],
         [19]],

        [[20],
         [21],
         [22],
         [23]]],


       [[[24],
         [25],
         [26],
         [27]],

        [[28],
         [29],
         [30],
         [31]]],


       [[[32],
         [33],
         [34],
         [35]],

        [[36],
         [37],
         [38],
         [39]]]])

In [20]:
# Scratch check: move the last axis to the front, giving shape (2, 5, 4).
np.transpose(test, (2, 0, 1))

array([[[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22],
        [24, 26, 28, 30],
        [32, 34, 36, 38]],

       [[ 1,  3,  5,  7],
        [ 9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39]]])

In [10]:
# Scratch check: a 3-D array's shape unpacks into three names.
a, b, c = np.shape(test)

In [12]:
# Display the three dimensions unpacked from test.shape.
a, b, c

(5, 4, 2)