# 17 Implementing the Transformer Encoder in Keras

In [1]:
from numpy import random
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    Layer,
    LayerNormalization,
    ReLU,
)

from xformer.multihead_attention import MultiHeadAttention
from xformer.positional_encoding import CustomEmbeddingWithFixedPosnWts

## 17.1 Recap of the Transformer Encoder

Recall that the encoder block is a stack of N identical layers. Each layer contains a multi-head self-attention sub-layer, which we covered in detail in Ch. 16. Now we will fill in some important missing details.

- Multi-head self-attention is one of _two_ sub-layers in each layer of the encoder. The _other_ sub-layer is a fully-connected feed-forward network.
- Each of these two sub-layers is followed by an "Add & Norm" step. It first adds the sub-layer's output to the sub-layer's input (this forms what we call a "residual connection") and then applies layer normalization to the sum.
- Regularization is performed by applying dropout to the output of each of the two sub-layers, before it is added to the sub-layer input and normalized. Dropout is also applied to the positionally-encoded embeddings right before they are fed into the encoder.
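Put together, the bullet points above describe a single wrapper around each sub-layer: $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$. Below is a minimal NumPy sketch of that wrapper. It is not the Keras code we build in this chapter. The names `layer_norm`, `dropout`, `sublayer_connection` and `toy_sublayer` are hypothetical, the learnable LayerNorm parameters are omitted, and the dropout is the "inverted" variant (surviving units rescaled by $1/(1-\text{rate})$), which is also what Keras' `Dropout` layer does.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each position's feature vector (last axis) to zero mean, unit variance.
    # Learnable gain and bias are omitted here (equivalent to gain=1, bias=0).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def dropout(x, rate, training, rng):
    # Inverted dropout: zero units with probability `rate` and rescale the
    # survivors by 1/(1 - rate) so the expected value is unchanged.
    # At inference (training=False) it is the identity.
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)

def sublayer_connection(x, sublayer, rate, training, rng):
    # LayerNorm(x + Dropout(Sublayer(x))): dropout on the sub-layer output,
    # then the residual addition, then normalization.
    return layer_norm(x + dropout(sublayer(x), rate, training, rng))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))   # (batch_size, sequence_length, d_model)
w = rng.normal(size=(8, 8))      # toy weights for a hypothetical sub-layer

def toy_sublayer(t):
    # Stand-in for attention or the feed-forward network: any shape-preserving map
    return t @ w

out = sublayer_connection(x, toy_sublayer, rate=0.1, training=True, rng=rng)
print(out.shape)  # (2, 5, 8)
```

Because the residual sum `x + sublayer(x)` must be well-defined, every sub-layer has to return a tensor of the same shape as its input.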

## 17.2 Implementing the Transformer Encoder from Scratch

Note: We will reuse the multi-head attention and the positional embedding logic we implemented in previous chapters.

### The Feedforward Network and Layer Normalization

In AIAYN, the feed-forward network is simply two fully-connected (AKA linear) layers with a ReLU activation in between, applied to each position separately and identically. The first layer's output has dimension $d_{ff}=2048$, and the second layer brings it back to $d_{model}=512$.
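Written out, this is $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with the same weights used at every position. The NumPy sketch below (toy sizes, random weights, hypothetical `ffn` helper) checks that property: applying the network to a whole `(batch_size, sequence_length, d_model)` tensor gives the same result as applying it one position at a time. It also counts the weights and biases that the two `Dense` layers hold at the AIAYN sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, d_model, d_ff = 2, 5, 8, 32  # toy sizes (AIAYN uses 512 and 2048)

# Random weights standing in for the two Dense layers
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)

def ffn(x):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, applied over the last axis
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

x = rng.normal(size=(batch_size, seq_len, d_model))
out = ffn(x)

# Same result position by position: weights are shared and positions never interact
per_position = np.stack([ffn(x[:, t, :]) for t in range(seq_len)], axis=1)
print(out.shape, np.allclose(out, per_position))  # (2, 5, 8) True

# Parameters of the two Dense layers at the AIAYN sizes (weights + biases)
n_params = 512 * 2048 + 2048 + 2048 * 512 + 512
print(n_params)  # 2099712
```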

In [2]:
class FeedForward(Layer):
    def __init__(self, d_ff, d_model, **kwargs):
        super().__init__(**kwargs)
        self.fully_connected_1 = Dense(d_ff)  # First fully-connected layer
        self.fully_connected_2 = Dense(d_model)  # Second fully-connected layer
        self.activation = ReLU()  # ReLU activation layer to come in between

    def call(self, x):
        # The input is passed into the two fully-connected layers, with a ReLU in between
        fc1_output = self.fully_connected_1(x)
        fc2_output = self.fully_connected_2(self.activation(fc1_output))
        return fc2_output

Next, we define our "Add & Norm" layer, which implements the residual connection followed by layer normalization. [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf), not to be confused with (but in many ways similar to) [Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf), is a way of ensuring better, more stable training.
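The key difference from batch normalization is the axis over which the mean and variance are computed. The NumPy sketch below compares the two on a `(batch_size, sequence_length, d_model)` tensor. The helpers `layer_norm` and `batch_norm` are hypothetical, the learnable scale and offset are left at their initial values (1 and 0), and `eps=1e-3` mirrors the default epsilon of Keras' `LayerNormalization`.

```python
import numpy as np

def layer_norm(x, eps=1e-3):
    # Statistics over the feature (last) axis, separately for every token
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def batch_norm(x, eps=1e-3):
    # Training-mode batch norm: statistics per feature, pooled over batch and positions
    mean = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 5, 8))  # (batch_size, sequence_length, d_model)

ln, bn = layer_norm(x), batch_norm(x)
print(np.allclose(ln.mean(axis=-1), 0.0))      # True: every token vector is centered
print(np.allclose(bn.mean(axis=(0, 1)), 0.0))  # True: every feature is centered across the batch

# Layer norm of one sentence does not depend on the rest of the batch; batch norm does
print(np.allclose(layer_norm(x[:1]), ln[:1]))  # True
print(np.allclose(batch_norm(x[:1]), bn[:1]))  # False
```

This batch independence is one reason layer normalization suits variable-length sequence models: it behaves the same during training and inference, even with a batch size of 1.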

In [3]:
class AddAndNorm(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm = LayerNormalization()  # Layer normalization layer

    def call(self, x, sublayer_x):
        # Note: The sublayer's input and output need to be of the same shape to be summable
        add = x + sublayer_x
        # Apply layer normalization to the sum
        return self.layer_norm(add)

### The Encoder Layer

Next, we will define what an encoder layer looks like. **Note:** I may have used the term "encoder block" for this elsewhere. Going forward, I will try to stay consistent and use "encoder layer". Just picture AIAYN's block diagram and recall that they stack N=6 of these to form their transformer's encoder "block". But we'll get to that in the next section.  
The `training` flag in the `call()` method is there so that dropout is applied only during training, not during testing and inference.  
The `padding_mask` argument, as explained in previous chapters, prevents the zero-padding tokens in input sequences from being attended to alongside the valid input tokens.
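As a reminder of how such a mask suppresses padding, here is a NumPy sketch. It assumes the convention that the mask holds 1 at padding positions (token ID 0) and that masked attention scores get a large negative value added before the softmax. The helpers `padding_mask` and `masked_softmax` are hypothetical, so check which mask convention the `MultiHeadAttention` class from Ch. 16 actually expects.

```python
import numpy as np

def padding_mask(token_ids, pad_id=0):
    # 1.0 where the token is padding, 0.0 elsewhere; shaped
    # (batch_size, 1, 1, seq_len) so it broadcasts over heads and query positions.
    mask = (token_ids == pad_id).astype(np.float32)
    return mask[:, np.newaxis, np.newaxis, :]

def masked_softmax(scores, mask):
    # Add a huge negative number at masked (padding) key positions so that
    # they receive (numerically) zero attention weight after the softmax.
    scores = scores + -1e9 * mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

token_ids = np.array([[5, 7, 2, 0, 0]])  # one sentence ending in two padding tokens
mask = padding_mask(token_ids)           # shape (1, 1, 1, 5)
scores = np.random.default_rng(0).normal(size=(1, 8, 5, 5))  # (batch, heads, queries, keys)
weights = masked_softmax(scores, mask)
print(weights[0, 0, 0, 3:])  # [0. 0.]
```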

In [4]:
class EncoderLayer(Layer):
    def __init__(self, n_heads, d_model, d_ff, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.multihead_attention = MultiHeadAttention(n_heads, d_model)
        self.dropout1 = Dropout(dropout_rate)
        self.add_norm1 = AddAndNorm()
        self.feed_forward = FeedForward(d_ff, d_model)
        self.dropout2 = Dropout(dropout_rate)
        self.add_norm2 = AddAndNorm()

    def call(self, x, padding_mask, training):
        # Multi-head attention layer
        multihead_output = self.multihead_attention(x, x, x, padding_mask)
        # Expected output shape = (batch_size, sequence_length, d_model)
        # Add in a dropout layer
        multihead_output = self.dropout1(multihead_output, training=training)
        # Followed by an Add & Norm layer
        addnorm_output = self.add_norm1(x, multihead_output)
        # Expected output shape = (batch_size, sequence_length, d_model)
        # Followed by the position-wise feed-forward network
        feedforward_output = self.feed_forward(addnorm_output)
        # Expected output shape = (batch_size, sequence_length, d_model)
        # Add in another dropout layer
        feedforward_output = self.dropout2(
            feedforward_output, training=training
        )
        # Followed by another Add & Norm layer
        return self.add_norm2(addnorm_output, feedforward_output)

### The Transformer Encoder

We are now finally ready to stack these encoder layers to form our transformer encoder. It receives tokenized input sequences and passes them through word embedding, positional encoding (re-using our `CustomEmbeddingWithFixedPosnWts` class from chapter 14) and dropout before feeding them into the stack.

In [5]:
class Encoder(Layer):
    def __init__(
        self,
        vocab_size,
        sequence_length,
        n_heads,
        d_model,
        d_ff,
        n_enc_layers,
        dropout_rate,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.wrd_emb_posn_enc = CustomEmbeddingWithFixedPosnWts(
            sequence_length, vocab_size, d_model
        )
        self.dropout = Dropout(dropout_rate)
        self.encoder_layers = [
            EncoderLayer(n_heads, d_model, d_ff, dropout_rate)
            for _ in range(n_enc_layers)
        ]

    def call(self, input_sentence, padding_mask, training):
        # Generate the word embeddings & positional encodings
        emb_enc_output = self.wrd_emb_posn_enc(input_sentence)
        # Expected output shape = (batch_size, sequence_length, d_model)
        # Add in a dropout layer
        x = self.dropout(emb_enc_output, training=training)
        # Feed the result into the stack of encoder layers
        for layer in self.encoder_layers:
            x = layer(x, padding_mask, training=training)
        return x
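Because every encoder layer maps a `(batch_size, sequence_length, d_model)` tensor to one of the same shape, stacking them is just a loop. Note also that the list comprehension in `Encoder.__init__` creates N separate `EncoderLayer` objects, each with its own weights; writing `[EncoderLayer(...)] * n_enc_layers` instead would repeat one layer object (and hence one set of weights) N times. The NumPy sketch below illustrates both points with a hypothetical `ToyLayer` standing in for `EncoderLayer`.

```python
import numpy as np

class ToyLayer:
    # Hypothetical stand-in for EncoderLayer: a shape-preserving map with its own weights
    def __init__(self, d_model, rng):
        self.w = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))

    def __call__(self, x):
        return x + np.tanh(x @ self.w)  # residual-style update, shape preserved

rng = np.random.default_rng(0)
d_model, n_layers = 8, 6

distinct = [ToyLayer(d_model, rng) for _ in range(n_layers)]  # N separate layers
shared = [ToyLayer(d_model, rng)] * n_layers                  # one layer repeated N times

print(len({id(layer) for layer in distinct}), len({id(layer) for layer in shared}))  # 6 1

x = rng.normal(size=(2, 5, d_model))  # (batch_size, sequence_length, d_model)
for layer in distinct:
    x = layer(x)  # each layer maps (batch, seq, d_model) -> (batch, seq, d_model)
print(x.shape)  # (2, 5, 8)
```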

## 17.3 Testing Out the Code

As usual, we will use the parameter values specified in AIAYN and dummy data for our input sequences (until chapter 20).
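A note on the dummy input: embedding layers look up rows by integer token ID, and Keras' `Embedding` layer casts non-integer inputs to integers. Uniform floats in $[0, 1)$, such as those returned by `random.random`, would therefore all collapse to token 0. The NumPy sketch below shows this, assuming `CustomEmbeddingWithFixedPosnWts` wraps Keras `Embedding` layers, and draws integer IDs from the vocabulary instead.

```python
import numpy as np

batch_size, input_seq_length, enc_vocab_size = 64, 5, 20
rng = np.random.default_rng(0)

# Floats in [0, 1) truncate to token ID 0 when cast to integers
float_ids = rng.random((batch_size, input_seq_length)).astype(np.int64)
print(np.unique(float_ids))  # [0]

# Integer token IDs drawn uniformly from [0, enc_vocab_size)
int_ids = rng.integers(0, enc_vocab_size, size=(batch_size, input_seq_length))
print(int_ids.min() >= 0, int_ids.max() < enc_vocab_size)  # True True
```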

In [7]:
h = 8  # Number of self-attention heads
d_ff = 2048  # Dimensionality of the inner fully-connected layer
d_model = 512  # Dimensionality of the model
n = 6  # Number of layers in the encoder stack

batch_size = 64  # Batch size from the training process
dropout_rate = 0.1  # Frequency of dropping the input units in dropout layers

enc_vocab_size = 20  # Vocabulary size for the encoder
input_seq_length = 5  # Maximum length of the input sequence
# Dummy token IDs: embedding lookups expect integers in [0, enc_vocab_size)
input_seq = random.randint(enc_vocab_size, size=(batch_size, input_seq_length))

encoder = Encoder(
    enc_vocab_size, input_seq_length, h, d_model, d_ff, n, dropout_rate
)

print(encoder(input_seq, None, True))

tf.Tensor(
[[[-6.4046764e-01 -2.5869149e-01  2.9826114e-02 ...  1.7524600e+00
    8.3793515e-01  1.4727314e+00]
  [-2.5010207e-01  8.2378335e-02  7.0842355e-01 ...  2.1619058e+00
    4.8494333e-01  5.1776910e-01]
  [-4.0501606e-01 -7.2256976e-01  6.9958973e-01 ...  1.5095438e+00
    9.2207301e-01  2.0221012e+00]
  [-7.8895479e-01 -6.4195299e-01  3.8406304e-01 ...  1.4902123e+00
   -1.3452987e-01  9.5190138e-01]
  [-7.2948223e-01 -1.0197861e+00  4.0769660e-01 ...  2.4138000e+00
    2.5342479e-01  1.1871268e+00]]

 [[-1.2775974e+00 -6.0047591e-01  3.1905037e-01 ...  1.7848492e+00
   -5.4142863e-01  1.3197116e+00]
  [-8.6346459e-01 -4.9204442e-01  2.2519235e-01 ...  1.5220001e+00
    2.3462519e-02  1.4124132e+00]
  [-1.0557163e+00 -8.8454580e-01  9.2237324e-02 ...  1.4283720e+00
   -2.2392024e-01  1.4467469e+00]
  [-1.5382400e+00 -7.9968345e-01  2.0473975e-01 ...  1.2802440e+00
    1.0641914e-01  1.6147338e+00]
  [-1.8815150e+00  1.0320888e-02  8.3954060e-01 ...  9.6625865e-01
    1.83956