# 22 Inference with the Transformer Model

In [1]:
from pickle import load

from tensorflow import (
    Module,
    TensorArray,
    argmax,
    convert_to_tensor,
    int64,
    newaxis,
    transpose,
)
from tensorflow.keras.preprocessing.sequence import pad_sequences

from xformer.model import Xformer

## 22.1 Performing Inference with the Transformer Model

In [2]:
# Define the model parameters
h = 8  # Number of self-attention heads
d_model = 512  # Dimensionality of model layers' outputs
d_ff = 2048  # Dimensionality of the inner fully connected layer
n = 6  # Number of layers in the encoder stack

# Define the dataset parameters
enc_seq_length = 7  # Encoder sequence length
dec_seq_length = 12  # Decoder sequence length
enc_vocab_size = 2405  # Encoder vocabulary size
dec_vocab_size = 3864  # Decoder vocabulary size

# Create model
trained_model = Xformer(
    enc_vocab_size,
    dec_vocab_size,
    enc_seq_length,
    dec_seq_length,
    h,
    d_model,
    d_ff,
    n,
    0,
)  # Note that dropout_rate is zero for inference

To tokenize our test sequences, we need the same encoder tokenizer that we used during training. We use the decoder tokenizer to tokenize the special `<SEQSTART>` and `<EOS>` tokens for the output.  
When preparing the output array, we don't know what its final size will be, so we initialize it with size zero and set the `dynamic_size` parameter to `True` so that it can grow dynamically. Then we write the START token to it.  
Notice that at inference time the transformer model works iteratively; this is called 'autoregressive generation'. There is no 'teacher forcing' as there is during training, so the decoder has to generate the output one token at a time, feeding everything generated so far back in as input. Picking the next token requires some strategy: we've chosen 'greedy' decoding, which always takes the highest-scoring token, but do read up on 'beam search', which is another important alternative. We iterate until the maximum decoder output length is reached or the special EOS token is predicted.
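
To make the loop concrete before we wire it up to the real model, here is a minimal pure-Python sketch of greedy autoregressive decoding. Everything in it is a made-up stand-in: `VOCAB`, `NEXT_TOKEN`, `toy_scores` and `greedy_decode` are hypothetical names, and `toy_scores` only imitates a decoder by returning one score per vocabulary entry. It is not the transformer from this chapter, just the same control flow in miniature.

```python
# Toy vocabulary and a hard-coded "preferred next token" table.
# Both are invented for illustration; a real model learns its scores.
VOCAB = ["<SEQSTART>", "ich", "bin", "durstig", "<EOS>"]
NEXT_TOKEN = {
    "<SEQSTART>": "ich",
    "ich": "bin",
    "bin": "durstig",
    "durstig": "<EOS>",
}


def toy_scores(generated):
    """Stand-in for the decoder: one score per vocabulary entry."""
    preferred = NEXT_TOKEN.get(generated[-1], "<EOS>")
    return [1.0 if token == preferred else 0.0 for token in VOCAB]


def greedy_decode(score_fn, max_len):
    """Generate one token at a time, always taking the highest score."""
    generated = ["<SEQSTART>"]
    for _ in range(max_len):
        # Score the whole vocabulary given everything generated so far
        scores = score_fn(generated)
        # Greedy choice: index of the highest-scoring entry (argmax)
        best = max(range(len(scores)), key=scores.__getitem__)
        generated.append(VOCAB[best])
        # Stop early once the end-of-sequence token is produced
        if VOCAB[best] == "<EOS>":
            break
    return generated


print(greedy_decode(toy_scores, max_len=12))
```

The two stopping conditions are the same ones we use below: the loop ends either when `<EOS>` is generated or when `max_len` tokens have been produced. Beam search would differ only in the selection step, keeping several candidate sequences alive instead of a single argmax.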

In [3]:
class Translate(Module):
    def __init__(self, inferencing_model, **kwargs):
        super().__init__(**kwargs)
        # Use the model passed to the constructor, not the global variable
        self.transformer = inferencing_model

    def load_tokenizer(self, name):
        with open(name, "rb") as handle:
            return load(handle)

    def __call__(self, sentence):
        # Add the special tokens on a copy so the caller's list is not modified
        sentence = ["<SEQSTART> " + sentence[0] + " <EOS>"]

        enc_tokenizer = self.load_tokenizer("data/enc_tokenizer.pkl")
        dec_tokenizer = self.load_tokenizer("data/dec_tokenizer.pkl")

        # Prepare the input sentence by tokenizing, padding and converting to tensor
        encoder_input = enc_tokenizer.texts_to_sequences(sentence)
        encoder_input = pad_sequences(
            encoder_input, maxlen=enc_seq_length, padding="post"
        )
        encoder_input = convert_to_tensor(encoder_input, dtype=int64)
        
        # Prepare the output <SEQSTART> token by tokenizing, and converting to tensor
        output_start = dec_tokenizer.texts_to_sequences(["<SEQSTART>"])
        output_start = convert_to_tensor(output_start[0], dtype=int64)
        
        # Prepare the output <EOS> token by tokenizing, and converting to tensor
        output_end = dec_tokenizer.texts_to_sequences(["<EOS>"])
        output_end = convert_to_tensor(output_end[0], dtype=int64)
        
        # Prepare the output array of dynamic size
        decoder_output = TensorArray(dtype=int64, size=0, dynamic_size=True)
        decoder_output = decoder_output.write(0, output_start)
        
        for i in range(dec_seq_length):
            # Predict an output token
            prediction = self.transformer(
                encoder_input, transpose(decoder_output.stack()), training=False
            )
            prediction = prediction[:, -1, :]
            
            # Select the prediction with the highest score
            predicted_id = argmax(prediction, axis=-1)
            predicted_id = predicted_id[0][newaxis]
            
            # Write the selected prediction to the output array at the next
            # available index
            decoder_output = decoder_output.write(i + 1, predicted_id)
            
            # Break if an <EOS> token is predicted
            if predicted_id == output_end:
                break
            
        output = transpose(decoder_output.stack())[0]
        output = output.numpy()

        output_str = []

        # Decode the predicted tokens into an output string
        for i in range(output.shape[0]):
            key = output[i]
            output_str.append(dec_tokenizer.index_word[key])

        return output_str

## 22.2 Testing Out the Code

My validation loss more or less plateaued after the 8th epoch, so we will use the weights from that epoch for inference. (See the learning curves in the previous chapter.)  
**Note:** In the book, the learning curves look better, and validation loss only starts to plateau after the 16th epoch. Mine actually gets _worse_ starting around the 10th-12th epoch. I have a nagging suspicion that the book's implementation has some bugs, and it is even more likely that my own code introduced new ones as I tried to correct some of the mistakes and shortcomings of the book's code. But we will let it all slide. This was a learning project anyway, and I have learned a great deal, in a very in-depth and hands-on manner, about the inner workings of attention mechanisms and transformer models. \*pat myself on the back\*  
Remember: Transformers are notoriously data-hungry. The more training data we feed them, the better they're going to perform (to a point, supposedly).  
And now for the last hurrah...

In [9]:
# Sentence to translate
sentence = ['im thirsty']
# Ideally we should get "<SEQSTART> ich bin durstig <EOS>"

# Load the trained model's weights at the specified epoch
trained_model.load_weights('weights/wghts8.ckpt')

# Create a new instance of the 'Translate' class
translator = Translate(trained_model)

# Translate the input sentence
print(translator(sentence))

['seqstart', 'es', 'ist', 'nicht', 'eos']


...  
It's... something!! 😅🤷‍♂️