## I referred to the following webpages for the implementation.
- Implementation of Transformer<br>
https://qiita.com/halhorn/items/c91497522be27bde17ce<br>
https://github.com/kpot/keras-transformer/tree/master/keras_transformer<br>
https://github.com/Lsdefine/attention-is-all-you-need-keras<br>
- Usage of "__call__" method<br>
https://qiita.com/kyo-bad/items/439d8cc3a0424c45214a

In [84]:
import numpy as np
import math

import tensorflow as tf

from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Layer, Embedding, Input, Reshape, Lambda
from keras import backend as K
from keras.initializers import RandomNormal

In [85]:
class MultiheadAttention():
    ## hidden_dim has to be multiples of head_num
    def __init__(self, hidden_dim=512, head_num=8, dropout_rate=0.1, *args, **kwargs):
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.dropout_rate = dropout_rate
        
        self.q_dense_layer = Dense(hidden_dim, use_bias=False, name="q_dense_layer")
        self.k_dense_layer = Dense(hidden_dim, use_bias=False, name="k_dense_layer")
        self.v_dense_layer = Dense(hidden_dim, use_bias=False, name="v_dense_layer")
        self.output_dense_layer = Dense(hidden_dim, use_bias=False, name="output_dense_layer")
        self.attention_dropout_layer = Dropout(dropout_rate, name="attention_dropout_layer")
        
    def split_heads(self, x):
        def reshape(x):
            batch_size, max_len, hidden_dim = tf.unstack(tf.shape(x))
            x = tf.reshape(x, [batch_size, max_len, self.head_num, self.hidden_dim // self.head_num])
            return tf.transpose(x, [0, 2, 1, 3])
        
        out = Lambda(reshape)(x)
        return out
    
    def combine_heads(self, heads):
        def reshape(x):
            batch_size, _, max_len, _ = tf.unstack(tf.shape(x))
            heads = tf.transpose(x, [0, 2, 1, 3])
            return tf.reshape(x, [batch_size, max_len, self.hidden_dim])
        
        out = Lambda(reshape)(heads)
        return out
        
    def __call__(self, query, memory):
        #two arguments of query and memory are already encoded as embedded vectors for all words
        q = self.q_dense_layer(query)
        k = self.k_dense_layer(memory)
        v = self.v_dense_layer(memory)
        
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)
        
        #for scaled dot-product
        depth_inside_each_head = self.hidden_dim // self.head_num
        q *= depth_inside_each_head ** -0.5
        #q.shape = (batch_size, query_len, emb_dim)
        #k.shape = (batch_size, memory_len, emb_dim)
        #batch_dot(q, k).shape = (batch_size, query_len, memory_len)
        score = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]))([q, k])
        normalized_score = Activation("softmax")(score)
        normalized_score = self.attention_dropout_layer(normalized_score)
        #attention_weighted_output = tf.matmul(normalized_score, v)
        
        #normalized_score.shape = (batch_size, query_length, memory_length)
        #v.shape = (batch_size, memory_length, depth)
        #attention_weighted_output.shape = (batch_size, query_length, depth)
        attention_weighted_output = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 1]))([normalized_score, v])
        attention_weighted_output = self.combine_heads(attention_weighted_output)
        return self.output_dense_layer(attention_weighted_output)

In [86]:
# SlefAttention class inherits MultiheadAttention class so that it can make query and memory come from the same source.
class SelfAttention(MultiheadAttention):
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __call__(self, query):
        return super().__call__(query, query)

In [87]:
class PositionwiseFeedForwardNetwork():
    
    def __init__(self, hidden_dim, dropout_rate, *args, **kwargs):
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        
        self.first_dense_layer = Dense(hidden_dim*4, use_bias=True, activation="relu", name="first_dense_layer")
        self.second_dense_layer = Dense(hidden_dim, use_bias=True, activation="linear", name="second_dense_layer")
        self.dropout_layer = Dropout(dropout_rate, name="PFFN_dropout")
        
    def __call__(self, inputs):
        # make the network more flexible to learn for the first dense layer(non-linear transformation is used),
        # and put the network back into the same hidden dim as original(linear transformation is used)
        x = self.first_dense_layer(inputs)
        x = self.dropout_layer(x)
        return self.second_dense_layer(x)

In [88]:
class LayerNormalization(Layer):
    def __init__(self, axis=-1, **kwargs):
        self.axis = axis
        super(LayerNormalization, self).__init__(**kwargs)
        
    def get_config(self):
        config = super().get_config()
        config["axis"] = self.axis
        return config
    
    def build(self, input_shape):
        hidden_dim = input_shape[-1]
        self.scale = self.add_weight("layer_norm_scale", shape=[hidden_dim],
                                    initializer="ones")
        self.shift = self.add_weight("layer_norm_shift", shape=[hidden_dim],
                                    initializer="zeros")
        super(LayerNormalization, self).build(input_shape)
        
    def call(self, inputs, epsilon=1e-6):
        mean = K.mean(inputs, axis=[-1], keepdims=True)
        variance = K.var(inputs, axis=[-1], keepdims=True)
        normalized_inputs = (inputs - mean) / (K.sqrt(variance) + epsilon)
        return normalized_inputs * self.scale + self.shift
    
    def compute_output_shape(self, input_shape):
        return input_shape

In [89]:
class PreLayerNormPostResidualConnectionWrapper():
    def __init__(self, layer, dropout_rate, *args, **kwargs):
        self.layer = layer
        self.layer_norm = LayerNormalization()
        self.dropout_layer = Dropout(dropout_rate)
        
    def __call__(self, inputs, *args, **kwargs):
        x = self.layer_norm(inputs)
        x = self.layer(x, *args, **kwargs)
        output = self.dropout_layer(x)
        return inputs + output

In [90]:
class AddPositionalEncoding(Layer):
    def call(self, inputs):
        data_type = inputs.dtype
        print("add pos inputs", inputs)
        batch_size, max_len, emb_dim = tf.unstack(tf.shape(inputs))
        # i is from 0 to 255 when emb_dim is 512
        #so the doubled_i is from 0 to 510
        doubled_i = K.arange(emb_dim) // 2 * 2
        exponent = K.tile(K.expand_dims(doubled_i, 0), [max_len, 1])
        denominator_matrix = K.pow(10000.0, K.cast(exponent / emb_dim, data_type))
        
        # since cos(x) = sin(x + π/2), we convert the series of [sin, cos, sin, cos, ...]
        # into [sin, sin, sin, sin, ...]
        to_convert = K.cast(K.arange(emb_dim) % 2, data_type) * math.pi / 2
        convert_matrix = K.tile(tf.expand_dims(to_convert, 0), [max_len, 1])
        
        seq_pos = K.arange(max_len)
        numerator_matrix = K.cast(K.tile(K.expand_dims(seq_pos, 1), [1, emb_dim]), data_type)
        
        positinal_encoding = K.sin(numerator_matrix / denominator_matrix + convert_matrix)
        batched_positional_encoding = K.tile(K.expand_dims(positinal_encoding, 0), [batch_size, 1, 1])
        return inputs + batched_positional_encoding
    
    def compute_output_shape(self, input_shape):
        print("input_shape", input_shape)
        return input_shape

In [94]:
PAD_ID = 0

class TokenEmbedding(Layer):
    def __init__(self, seq_len, vocab_size, emb_dim, data_type="float32", *args, **kwargs):
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.data_type = data_type
        super(TokenEmbedding, self).__init__(*args, **kwargs)
        
    def build(self, input_shape):
        self.embedding_layer = Embedding(self.vocab_size,
                                   self.emb_dim,
                                   embeddings_initializer=RandomNormal(mean=0.0, stddev=self.emb_dim**-0.5)
                                  )
        super(TokenEmbedding, self).build(input_shape)
        
    def call(self, inputs):
        mask_for_pads = tf.to_float(tf.not_equal(inputs, PAD_ID))
        embedding = self.embedding_layer(inputs)
        pads_masked_embedding = embedding * tf.expand_dims(mask_for_pads, -1)
        return pads_masked_embedding * (self.emb_dim ** 0.5)
    
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.seq_len, self.emb_dim)

In [95]:
class Encoder():
    def __init__(self, vocab_size, stack_num, head_num, emb_dim, dropout_rate, max_len, *args, **kwargs):
        self.stack_num = stack_num
        self.head_num = head_num
        self.emb_dim = emb_dim
        self.dropout_rate = dropout_rate
        
        self.token_emb_layer = TokenEmbedding(max_len, vocab_size, emb_dim)
        self.add_pos_enc_layer = AddPositionalEncoding()
        self.input_dropout_layer = Dropout(dropout_rate)
        
        self.attention_block_list = []
        for _ in range(stack_num):
            self_attention_layer = SelfAttention(emb_dim, head_num, dropout_rate, name="self_attention_layer")
            pffn_layer = PositionwiseFeedForwardNetwork(emb_dim, dropout_rate, "pffn_layer")
            self.attention_block_list.append([
                PreLayerNormPostResidualConnectionWrapper(self_attention_layer, dropout_rate, name="prepos_self_attention_wrapper"),
                PreLayerNormPostResidualConnectionWrapper(pffn_layer, dropout_rate, name="prepos_pffn_wrapper")
            ])
        self.output_layer_norm = LayerNormalization()
        
    def __call__(self, inputs):
        x = self.token_emb_layer(inputs)
        print("first x", x)
        x = self.add_pos_enc_layer(x)
        print("second x", x)
        x = self.input_dropout_layer(x)
        print("third x", x)
        
        for i, set_of_layers_list in enumerate(self.attention_block_list):
            self_attention_layer, pffn_layer = tuple(set_of_layers_list)
            print("self_attention_layer", self_attention_layer)
            print("fourth x", x)
            x = self_attention_layer(x)
            print("fifth x", x)
            x = pffn_layer(x)
            
        return self.output_layer_norm(x)

In [96]:
# Transformer classification model
MAX_LEN = 717

inputs = Input(shape=(MAX_LEN,))
transformer_encoder = Encoder(vocab_size=8000, stack_num=6, head_num=8, emb_dim=512, dropout_rate=0.1, max_len=MAX_LEN)
encoder_output = transformer_encoder(inputs)
print("encoder output", encoder_output)
summarized_vecs = encoder_output[:, 0, :]
outputs = Dense(MAX_LEN, activation="softmax")(summarized_vecs)
print("outputs", outputs)
model = Model(inputs, outputs)
model.summary()

first x Tensor("token_embedding_18/mul_1:0", shape=(?, 717, 512), dtype=float32)
add pos inputs Tensor("token_embedding_18/mul_1:0", shape=(?, 717, 512), dtype=float32)
input_shape (None, 717, 512)
second x Tensor("add_positional_encoding_18/add_1:0", shape=(?, 717, 512), dtype=float32)
third x Tensor("dropout_186/cond/Merge:0", shape=(?, 717, 512), dtype=float32)
self_attention_layer <__main__.PreLayerNormPostResidualConnectionWrapper object at 0x7fff19f1ab00>
fourth x Tensor("dropout_186/cond/Merge:0", shape=(?, 717, 512), dtype=float32)
fifth x Tensor("add_120:0", shape=(?, 717, 512), dtype=float32)
self_attention_layer <__main__.PreLayerNormPostResidualConnectionWrapper object at 0x7fff1ce433c8>
fourth x Tensor("add_121:0", shape=(?, 717, 512), dtype=float32)
fifth x Tensor("add_122:0", shape=(?, 717, 512), dtype=float32)
self_attention_layer <__main__.PreLayerNormPostResidualConnectionWrapper object at 0x7fff19f170f0>
fourth x Tensor("add_123:0", shape=(?, 717, 512), dtype=float32

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'