# Encoder 

#### Multi-Head Self-Attention

Q => Queries <br>
K => Keys   <br>
V => Values  <br>

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$

where $d_k$ is the dimension of the keys.

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention.

    Implements ``softmax(Q K^T / sqrt(d_k)) V``, where ``d_k`` is the size of
    the last axis of ``key``.

    Args:
        query: Tensor of queries. Its last axis is contracted with the last
            axis of ``key`` by ``tf.matmul(..., transpose_b=True)``; leading
            axes are treated as batch axes by ``tf.matmul``.
        key: Tensor of keys (same last-axis size as ``query``).
        value: Tensor of values. Its second-to-last axis must match the
            second-to-last axis of ``key`` (the attended-over positions).
        mask: Optional tensor broadcastable to the attention-score tensor.
            Score positions where ``mask == 0`` are excluded from attention.

    Returns:
        A tuple ``(output, weights)``: the attention-weighted combination of
        ``value`` and the softmax attention weights.
    """
    scores = tf.matmul(query, key, transpose_b=True)

    # Use TF ops (not NumPy) so this also works inside tf.function / graph
    # mode, and cast to the scores' dtype so float16/float64 inputs work too.
    key_dimension = tf.cast(tf.shape(key)[-1], scores.dtype)
    scaled_scores = scores / tf.math.sqrt(key_dimension)

    if mask is not None:
        # Fill masked positions with the most negative finite value instead of
        # -inf: a row that is entirely -inf makes softmax return NaN, while a
        # finite minimum still yields exactly 0 weight for masked entries in
        # any partially masked row.
        fill = tf.constant(scaled_scores.dtype.min, dtype=scaled_scores.dtype)
        scaled_scores = tf.where(mask == 0, fill, scaled_scores)

    weights = tf.nn.softmax(scaled_scores, axis=-1)
    return tf.matmul(weights, value), weights

In [None]:
class MultHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head attention layer.

    Projects queries, keys and values with separate Dense layers, splits each
    projection into ``num_heads`` heads, applies
    ``scaled_dot_product_attention`` per head, merges the heads back and
    applies a final Dense projection.

    Args:
        dimension_model: Size of the model (feature) dimension. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``dimension_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dimension_model, num_heads):
        super(MultHeadSelfAttention, self).__init__()
        if dimension_model % num_heads != 0:
            # split_heads reshapes the feature axis into (num_heads,
            # dimension_head); that only works when it divides evenly.
            raise ValueError(
                f"dimension_model ({dimension_model}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.dimension_model = dimension_model
        self.num_heads = num_heads
        self.dimension_head = self.dimension_model // self.num_heads

        self.query_weights = tf.keras.layers.Dense(self.dimension_model)
        self.key_weights = tf.keras.layers.Dense(self.dimension_model)
        self.value_weights = tf.keras.layers.Dense(self.dimension_model)

        self.dense = tf.keras.layers.Dense(self.dimension_model)

    def split_heads(self, x):
        """Reshape ``(batch, seq, dimension_model)`` into
        ``(batch, num_heads, seq, dimension_head)``."""
        # tf.shape (dynamic) instead of x.shape (static): the static batch
        # size is None when the batch dimension is unknown in graph mode.
        batch_size = tf.shape(x)[0]

        split_inputs = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.dimension_head))
        return tf.transpose(split_inputs, perm=[0, 2, 1, 3])

    def merge_heads(self, x):
        """Inverse of ``split_heads``: ``(batch, num_heads, seq,
        dimension_head)`` back to ``(batch, seq, dimension_model)``."""
        batch_size = tf.shape(x)[0]

        merged_inputs = tf.transpose(x, perm=[0, 2, 1, 3])
        return tf.reshape(merged_inputs, (batch_size, -1, self.dimension_model))

    def call(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: Query, key and value inputs; presumably
                ``(batch, seq, features)`` — the split/merge reshapes assume
                a batch-first, 3-D layout.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``;
                it must broadcast to ``(batch, num_heads, seq_q, seq_k)``.

        Returns:
            ``(output, attention_weights)`` where ``output`` has last-axis
            size ``dimension_model`` and ``attention_weights`` are per head.
        """
        qs = self.split_heads(self.query_weights(q))
        ks = self.split_heads(self.key_weights(k))
        vs = self.split_heads(self.value_weights(v))

        output, attention_weights = scaled_dot_product_attention(qs, ks, vs, mask)
        output = self.merge_heads(output)

        return self.dense(output), attention_weights

In [None]:
def feed_forward_network(dimension_model, hidden_dimension):
    """Build the feed-forward sub-layer used inside an encoder block.

    The model is two Dense layers: an expansion to ``hidden_dimension`` units
    with ReLU activation, followed by a linear projection back to
    ``dimension_model`` units.

    Args:
        dimension_model: Output size (the model dimension).
        hidden_dimension: Size of the hidden ReLU layer.

    Returns:
        A ``tf.keras.Sequential`` containing the two Dense layers.
    """
    expand = tf.keras.layers.Dense(units=hidden_dimension, activation='relu')
    project = tf.keras.layers.Dense(units=dimension_model)
    return tf.keras.Sequential([expand, project])

In [None]:
class EncoderBlock(tf.keras.layers.Layer):
    """One Transformer encoder block.

    Self-attention followed by a feed-forward network. Each sub-layer output
    goes through dropout, is added back to its input (residual connection)
    and is then layer-normalized.

    Args:
        dimension_model: Model (feature) dimension.
        num_heads: Number of attention heads.
        hidden_dimension: Hidden size of the feed-forward network.
        dropout_rate: Dropout rate applied to both sub-layer outputs.
    """

    def __init__(self, dimension_model, num_heads, hidden_dimension, dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.mhsa = MultHeadSelfAttention(dimension_model, num_heads)
        self.ffn = feed_forward_network(dimension_model, hidden_dimension)

        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

        self.layernorm1 = tf.keras.layers.LayerNormalization()
        self.layernorm2 = tf.keras.layers.LayerNormalization()

    def call(self, x, training=None, mask=None):
        """Apply the block.

        Args:
            x: Input tensor; its last axis must equal ``dimension_model`` for
                the residual additions to be valid.
            training: Forwarded to the Dropout layers (dropout is active only
                in training mode).
            mask: Optional attention mask forwarded to self-attention.

        Returns:
            ``(output, attention_weights)`` from this block.
        """
        mhsa_output, attention_weights = self.mhsa(x, x, x, mask)
        mhsa_output = self.dropout1(mhsa_output, training=training)
        # Residual (skip) connection, then layer norm.
        mhsa_output = self.layernorm1(x + mhsa_output)

        ffn_output = self.ffn(mhsa_output)
        ffn_output = self.dropout2(ffn_output, training=training)
        # Second residual connection around the feed-forward network.
        output = self.layernorm2(mhsa_output + ffn_output)

        return output, attention_weights

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Transformer encoder: token + learned positional embeddings, followed by
    a stack of ``EncoderBlock``s.

    Args:
        num_blocks: Number of stacked encoder blocks.
        dimension_model: Model (embedding) dimension.
        num_heads: Number of attention heads per block.
        hidden_dimension: Hidden size of each block's feed-forward network.
        src_vocab_size: Size of the source token vocabulary.
        max_seq_len: Number of learned positions; input sequences must not be
            longer than this.
        dropout_rate: Dropout rate for the embeddings and every block.
    """

    def __init__(self, num_blocks, dimension_model, num_heads, hidden_dimension,
                 src_vocab_size, max_seq_len, dropout_rate=0.1):
        super(Encoder, self).__init__()

        self.dimension_model = dimension_model
        self.max_seq_len = max_seq_len
        # Misspelled alias kept so existing code reading ``max_sql_len`` works.
        self.max_sql_len = max_seq_len

        self.token_embedding = tf.keras.layers.Embedding(src_vocab_size, self.dimension_model)
        # The (misspelled) attribute name is kept as-is: renaming a tracked
        # layer attribute would change its object-based checkpoint key.
        self.positonal_embedding = tf.keras.layers.Embedding(max_seq_len, self.dimension_model)

        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        self.blocks = [
            EncoderBlock(self.dimension_model, num_heads, hidden_dimension, dropout_rate)
            for _ in range(num_blocks)
        ]

    def call(self, input, training=None, mask=None):
        """Encode a batch of token ids.

        Args:
            input: Integer token ids; axis 1 is treated as the sequence axis
                (presumably shape ``(batch, seq_len)``), with
                ``seq_len <= max_seq_len``.
            training: Forwarded to Dropout layers.
            mask: Optional attention mask forwarded to every block.

        Returns:
            ``(x, weights)``: the encoded sequence and the attention weights
            of the last block (``None`` when ``num_blocks == 0``).
        """
        token_embeddings = self.token_embedding(input)

        # One positional embedding per sequence position, shape
        # (seq_len, dimension_model); it broadcasts over the batch axis when
        # added to the (batch, seq_len, dimension_model) token embeddings.
        # Using the dynamic sequence length supports inputs shorter than
        # max_seq_len and works in graph mode.
        positions = tf.range(tf.shape(input)[1])
        positional_embeddings = self.positonal_embedding(positions)

        x = self.dropout(token_embeddings + positional_embeddings, training=training)

        # Initialized so an encoder with zero blocks returns None instead of
        # raising UnboundLocalError.
        weights = None
        for block in self.blocks:
            x, weights = block(x, training=training, mask=mask)

        return x, weights