# Implementing the Transformer Decoder in Keras

As with the encoder, the decoder takes the "raw" embeddings of its input (target) sequence and, through multiple self-attention and cross-attention layers, gradually _transforms_ them into representations that are more meaningful given the overall context. During training this is done in parallel over all tokens at once. (At inference we generate one token at a time, but we'll get to inference in future chapters.)

In [1]:
from numpy import random
from tensorflow import shape
from tensorflow.keras.layers import Dropout, Layer

from xformer.common import AddAndNorm, FeedForward
from xformer.multihead_attention import MultiHeadAttention
from xformer.positional_encoding import CustomEmbeddingWithFixedPosnWts

## 18.1 Recap of the Transformer Decoder

In NLP, a sequence-to-sequence model, such as one for translation, would typically be an encoder+decoder transformer. Once again, the architecture of the decoder has _a lot_ in common with the encoder. Let's focus on the few differences and important details.

It has _three_ sub-layers instead of two:
1. One multi-head self-attention with queries, keys and values coming from the embedded and positionally-encoded input sequences. This is architecturally identical to the multi-head self-attention layer in the encoder. (Just remember that the inputs come from _target_ sentences).
2. It has an additional multi-head attention sub-layer which is _not_ self-attending. It gets its keys and values from the output of the transformer's encoder, and it gets its queries from its own multi-head self-attention sub-layer (#1 above). We can call this one multi-head _cross_-attention.
3. As with the encoder, it has a fully-connected feed-forward sub-layer after that.  
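
To make the difference between sub-layers #1 and #2 concrete, here is a minimal NumPy sketch of the shapes involved. It is only an illustration: it uses a single head with no learned projections, and the function name `scaled_dot_product_attention` and the toy sizes are assumptions made for this example, not the `MultiHeadAttention` layer used in this chapter.

```python
import numpy as np


def scaled_dot_product_attention(queries, keys, values):
    """Single-head attention without learned projections (illustration only)."""
    d_k = queries.shape[-1]
    # (batch, T_q, d_k) @ (batch, d_k, T_k) -> (batch, T_q, T_k)
    scores = queries @ keys.transpose(0, 2, 1) / np.sqrt(d_k)
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # (batch, T_q, T_k) @ (batch, T_k, d_k) -> (batch, T_q, d_k)
    return weights @ values, weights


rng = np.random.default_rng(0)
batch, target_len, source_len, d = 2, 4, 7, 16  # hypothetical toy sizes
decoder_states = rng.normal(size=(batch, target_len, d))
encoder_output = rng.normal(size=(batch, source_len, d))

# Sub-layer 1: queries, keys and values all come from the target side
self_out, self_w = scaled_dot_product_attention(
    decoder_states, decoder_states, decoder_states
)
# Sub-layer 2: queries from the decoder, keys and values from the encoder
cross_out, cross_w = scaled_dot_product_attention(
    decoder_states, encoder_output, encoder_output
)

print(self_w.shape)     # (2, 4, 4)
print(cross_w.shape)    # (2, 4, 7)
print(cross_out.shape)  # (2, 4, 16)
```

Note that the cross-attention output keeps the _target_ sequence length, even though the attention weights span the _source_ positions. This is why the source and target sequences are allowed to have different lengths.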

As with the encoder, each sub-layer above is followed by an "Add-and-Norm" step: a residual connection followed by layer normalization. And just as before, regularization is performed by applying a dropout layer to the outputs of each of the above three sub-layers right before the Add-and-Norm step, as well as to the positionally-encoded embeddings right before they are fed into the decoder.
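
As a rough sketch of that ordering (sub-layer output, then dropout, then residual add, then normalization), here is a NumPy stand-in. The helper names `dropout` and `add_and_norm` are hypothetical, and the normalization omits the learned scale and offset that a real layer-normalization layer would have, so treat it as an illustration rather than the `AddAndNorm` class imported above.

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: zero units with probability `rate`, rescale the rest."""
    if not training or rate == 0.0:
        return x
    keep_mask = rng.random(x.shape) >= rate
    return x * keep_mask / (1.0 - rate)


def add_and_norm(x, sublayer_output, eps=1e-6):
    """Residual connection followed by normalization over the last axis.

    Assumption: no learned gain/bias, unlike a trainable layer-norm layer.
    """
    summed = x + sublayer_output
    mean = summed.mean(axis=-1, keepdims=True)
    var = summed.var(axis=-1, keepdims=True)
    return (summed - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))                # input to a sub-layer
sublayer_output = rng.normal(size=(2, 5, 8))  # e.g. attention output

# Dropout is applied to the sub-layer output *before* the residual add
out = add_and_norm(x, dropout(sublayer_output, rate=0.1, training=True, rng=rng))
print(out.shape)  # (2, 5, 8)
```

With `training=False` the dropout helper returns its input unchanged, which mirrors why the `training` flag is threaded through every `call` method in the classes below.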

## 18.2 Implementing the Transformer Decoder from Scratch

Let's jump right into implementing it, starting by defining the `DecoderLayer` and `Decoder` classes, similarly to how we did things for the encoder and reusing a lot of the code from before.

In [2]:
# Implementing the Decoder Layer
class DecoderLayer(Layer):
    def __init__(self, n_heads, d_model, d_ff, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.multihead_attention1 = MultiHeadAttention(n_heads, d_model)
        self.dropout1 = Dropout(dropout_rate)
        self.add_norm1 = AddAndNorm()
        self.multihead_attention2 = MultiHeadAttention(n_heads, d_model)
        self.dropout2 = Dropout(dropout_rate)
        self.add_norm2 = AddAndNorm()
        self.feed_forward = FeedForward(d_ff, d_model)
        self.dropout3 = Dropout(dropout_rate)
        self.add_norm3 = AddAndNorm()

    def call(self, x, lookahead_mask, encoder_output, encoder_padding_mask, training):
        # Multi-head self-attention layer
        multihead_output1 = self.multihead_attention1(x, x, x, lookahead_mask)
        # Expected output shape = (batch_size, sequence_length, d_model)

        # Add in a dropout layer
        multihead_output1 = self.dropout1(multihead_output1, training=training)

        # Followed by an Add & Norm layer
        addnorm_output1 = self.add_norm1(x, multihead_output1)
        # Expected output shape = (batch_size, sequence_length, d_model)

        # Followed by another multi-head (cross-)attention layer
        multihead_output2 = self.multihead_attention2(
            addnorm_output1, encoder_output, encoder_output, encoder_padding_mask
        )

        # Add in another dropout layer
        multihead_output2 = self.dropout2(multihead_output2, training=training)

        # Followed by another Add & Norm layer
        addnorm_output2 = self.add_norm2(addnorm_output1, multihead_output2)

        # Followed by a fully connected layer
        feedforward_output = self.feed_forward(addnorm_output2)
        # Expected output shape = (batch_size, sequence_length, d_model)
        # Add in another dropout layer
        feedforward_output = self.dropout3(
            feedforward_output, training=training
        )
        # Followed by another Add & Norm layer
        return self.add_norm3(addnorm_output2, feedforward_output)

In [3]:
# Implementing the Decoder
class Decoder(Layer):
    def __init__(
        self,
        vocab_size,
        sequence_length,
        n_heads,
        d_model,
        d_ff,
        n_dec_layers,
        dropout_rate,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.wrd_emb_posn_enc = CustomEmbeddingWithFixedPosnWts(
            sequence_length, vocab_size, d_model
        )
        self.dropout = Dropout(dropout_rate)
        self.decoder_layer = [
            DecoderLayer(n_heads, d_model, d_ff, dropout_rate)
            for _ in range(n_dec_layers)
        ]

    def call(
        self,
        target_sequence,
        lookahead_mask,
        encoder_output,
        encoder_padding_mask,
        training,
    ):
        # Generate the positional encoding
        pos_encoding_output = self.wrd_emb_posn_enc(target_sequence)
        # Expected output shape = (number of sentences, sequence_length, d_model)
        # Add in a dropout layer
        x = self.dropout(pos_encoding_output, training=training)
        # Pass on the positional encoded values to each decoder layer
        for layer in self.decoder_layer:
            x = layer(x, lookahead_mask, encoder_output, encoder_padding_mask, training)
        return x
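
Both `call` methods accept a `lookahead_mask`, whose job is to stop each target position from attending to later positions. Here is a minimal NumPy sketch of what such a mask could look like. It assumes the convention that a 1 marks a position to hide and that masked scores are pushed towards negative infinity before the softmax. The `MultiHeadAttention` layer from the earlier chapter may use a different convention, and the helper names here are hypothetical.

```python
import numpy as np


def make_lookahead_mask(sequence_length):
    """Return a (T, T) matrix with 1 where position i must not see position j > i."""
    return np.triu(np.ones((sequence_length, sequence_length)), k=1)


def masked_softmax(scores, mask):
    """Softmax over the last axis after pushing masked scores towards -infinity."""
    scores = scores + mask * -1e9
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


mask = make_lookahead_mask(4)
print(mask)
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]

# With all-zero scores, row i spreads its attention evenly over positions 0..i
weights = masked_softmax(np.zeros((4, 4)), mask)
print(weights.round(2))
```

The first row can only attend to the first token, while the last row can attend to all four, which is what lets training run over all target positions in parallel without leaking future tokens.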

## 18.3 Testing Out the Code

As before, let's test it out with parameter values from the "Attention Is All You Need" paper (AIAYN). We'll use dummy data for the target sequences _and_ for our encoder output. Also, we won't be using masks yet.

In [5]:
h = 8  # Number of self-attention heads
d_ff = 2048  # Dimensionality of the inner fully-connected layer
d_model = 512  # Dimensionality of the model
n = 6  # Number of layers in the decoder stack

batch_size = 64  # Batch size from the training process
dropout_rate = 0.1  # Frequency of dropping the input units in dropout layers

dec_vocab_size = 20  # Vocabulary size for the decoder
input_seq_length = 5  # Maximum length of the input sequence
input_seq = random.random((batch_size, input_seq_length))
enc_output = random.random((batch_size, input_seq_length, d_model))

decoder = Decoder(
    dec_vocab_size, input_seq_length, h, d_model, d_ff, n, dropout_rate
)

print(decoder(input_seq, None, enc_output, None, True))

tf.Tensor(
[[[-8.2598048e-01 -1.6173053e-01  3.3324018e+00 ...  2.3628049e+00
    4.7588381e-01  4.0246236e-01]
  [ 1.2710889e-01 -4.0063667e-01  3.0413105e+00 ...  2.6443102e+00
    1.1692477e-01  9.2432606e-01]
  [ 6.5477896e-01 -4.0979478e-01  2.6779525e+00 ...  2.3136461e+00
    9.8803163e-01  6.7036331e-01]
  [-2.9772779e-01 -9.3412715e-01  2.9222012e+00 ...  1.8688526e+00
    7.0343709e-01  8.1980890e-01]
  [-1.1378692e-01 -2.8452662e-01  1.7573909e+00 ...  1.8726753e+00
    8.8464886e-01  1.7555189e-01]]

 [[-9.1874218e-03 -6.5855700e-01  3.6245153e+00 ...  1.7825265e+00
    6.5002859e-01  2.3037803e-01]
  [ 6.8538249e-01 -6.6915894e-01  3.7835851e+00 ...  1.7818812e+00
    8.2844245e-01  5.7019848e-01]
  [ 8.2658291e-01  4.1048513e-03  3.9944384e+00 ...  1.0317283e+00
    8.5510725e-01  3.9327469e-02]
  [ 3.9483979e-01 -6.3777500e-01  3.0753524e+00 ...  1.7436596e+00
    5.2387822e-01  5.7222742e-01]
  [ 2.6328373e-01 -2.5644457e-01  3.3943105e+00 ...  1.9001243e+00
    1.08589