# 20 Training the Transformer Model

In [1]:
from pickle import load
from time import time

from keras.losses import sparse_categorical_crossentropy
from numpy.random import shuffle
from tensorflow import (
    GradientTape,
    TensorSpec,
    argmax,
    cast,
    convert_to_tensor,
    data,
    equal,
    float32,
    function,
    int64,
    math,
    reduce_sum,
    train,
)
from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

from xformer.model import Xformer

## 20.1 Preparing the Training Dataset

The dataset is already standardized and clean (no punctuation, all lowercae, etc.) and it can be downloaded from [here](https://github.com/Rishav09/Neural-Machine-Translation-System/blob/master/english-german-both.pkl).  
The class below loads the data, selects a subset of it for demonstration purposes (because it's very large), appends special `<START>` and `<EOS>` tokens to the beginning and end of the sequences, splits them based on a pre-defined ratio (the train-test split), tokenizes the input and target sequences separately and uses these to deduce the maximum sequence length and vocabulary size for the encoder and decoder respectively.  
Let's go!

In [2]:
class PrepareDataset:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_sentences = (
            10_000  # Number of sentences to include in the dataset
        )
        self.train_split = 0.9  # Proportion of the data to use for training

    # Fit a tokenizer
    def create_tokenizer(self, dataset):
        tokenizer = Tokenizer()
        tokenizer.fit_on_texts(dataset)

        return tokenizer

    def find_seq_length(self, dataset):
        return max(len(seq.split()) for seq in dataset)

    def find_vocab_size(self, tokenizer, dataset):
        tokenizer.fit_on_texts(dataset)

        return len(tokenizer.word_index) + 1

    def __call__(self, filename, **kwargs):
        # Load a clean dataset
        clean_dataset = load(open(filename, "rb"))

        # Reduce dataset size
        dataset = clean_dataset[: self.n_sentences, :]

        # Include start and end of string tokens
        # Note: The book uses <START> but that is no good since it will be
        # cleaned and lowercased to "start" and get mixed up with the actual
        # English word "start", which does appear in the training data.
        for i in range(dataset[:, 0].size):
            dataset[i, 0] = "<SEQSTART> " + dataset[i, 0] + " <EOS>"
            dataset[i, 1] = "<SEQSTART> " + dataset[i, 1] + " <EOS>"

        # Random shuffle the dataset
        shuffle(dataset)

        # Split the dataset
        train = dataset[: int(self.n_sentences * self.train_split)]

        # Prepare tokenizer for the encoder input
        enc_tokenizer = self.create_tokenizer(train[:, 0])
        enc_seq_length = self.find_seq_length(train[:, 0])
        enc_vocab_size = self.find_vocab_size(enc_tokenizer, train[:, 0])

        # Encode and pad the input sequences
        trainX = enc_tokenizer.texts_to_sequences(train[:, 0])
        trainX = pad_sequences(trainX, maxlen=enc_seq_length, padding="post")
        trainX = convert_to_tensor(trainX, dtype=int64)

        # Prepare tokenizer for the decoder input
        dec_tokenizer = self.create_tokenizer(train[:, 1])
        dec_seq_length = self.find_seq_length(train[:, 1])
        dec_vocab_size = self.find_vocab_size(dec_tokenizer, train[:, 1])

        # Encode and pad the input sequences
        trainY = dec_tokenizer.texts_to_sequences(train[:, 1])
        trainY = pad_sequences(trainY, maxlen=dec_seq_length, padding="post")
        trainY = convert_to_tensor(trainY, dtype=int64)

        return (
            trainX,
            trainY,
            train,
            enc_seq_length,
            dec_seq_length,
            enc_vocab_size,
            dec_vocab_size,
        )

Let's test it and take a look at some sample sentence pairs.

In [3]:
# Prepare the training data
dataset = PrepareDataset()
(
    trainX,
    trainY,
    train_orig,
    enc_seq_length,
    dec_seq_length,
    enc_vocab_size,
    dec_vocab_size,
) = dataset("data/english-german-both.pkl")

In [4]:
print(
    train_orig[0, 0],
    trainX[0, :],
    "\n",
    train_orig[0, 1],
    trainY[0, :],
    sep="\n",
)

<SEQSTART> toms conscious <EOS>
tf.Tensor([   1   44 1442    2    0    0    0], shape=(7,), dtype=int64)


<SEQSTART> tom ist bei bewusstsein <EOS>
tf.Tensor([  1   5   4 196 552   2   0   0   0   0   0   0], shape=(12,), dtype=int64)


It's a dataset of very short English and German sentence pairs.

In [5]:
print("Encoder sequence length:", enc_seq_length)
print("Decoder sequence length:", dec_seq_length)

Encoder sequence length: 7
Decoder sequence length: 12


## 20.2 Applying a Padding Mask
### (And Introducing the Loss Function and Accuracy Metric)

So, here's the thing: Just masking the input and target sequences was not enough. We also need to exclude the masked tokens from being used in the calculation of our loss function and our accuracy metric.  
We will be using a sparse categorical cross-entropy loss function. Here's the implementation.

In [6]:
def loss_fn(target, prediction):
    # Create mask so that the zero padding values are not included
    # in the computation of loss
    mask = math.logical_not(equal(target, 0))
    mask = cast(mask, float32)

    # Compute a sparse categorical cross-entropy loss on the unmasked values
    loss = (
        sparse_categorical_crossentropy(target, prediction, from_logits=True)
        * mask
    )

    # Compute the mean loss over the unmasked values
    return reduce_sum(loss) / reduce_sum(mask)

Note that the output of the decoder is a tensor of shape `(batch_size, dec_seq_length, dec_vocab_size)` and its values represent the probabilities for each vocabulary token at each position in the output sequence. In order to compare the output to the target sequence, we will pick only the highest probability token at each position (and retrieve its corresponding token/word using `argmax`) and calculate the average accuracy (which is 0 or 1 for an individual token) over all unmasked values:

In [7]:
def accuracy_fn(target, prediction):
    # Create mask so that the zero padding values are not included in the
    # computation of accuracy
    mask = math.logical_not(math.equal(target, 0))

    # Find equal prediction and target values, and apply the padding mask
    accuracy = equal(target, argmax(prediction, axis=2)) # Should this be `argmax(prediction, axis=2) + 1` ??
    accuracy = math.logical_and(mask, accuracy)

    # Cast the True/False values to 32-bit-precision floating-point numbers
    mask = cast(mask, float32)
    accuracy = cast(accuracy, float32)

    # Compute the mean accuracy over the unmasked values
    return reduce_sum(accuracy) / reduce_sum(mask)

## 20.3 Training the Transformer Model

As always, we will use the parameters used in the AIAYN paper.

In [8]:
# Define the model parameters
h = 8  # Number of self-attention heads
d_model = 512  # Dimensionality of model layers' outputs
d_ff = 2048  # Dimensionality of the inner fully connected layer
n = 6  # Number of layers in the encoder stack

# Define the training parameters
epochs = 2
batch_size = 64
beta_1 = 0.9
beta_2 = 0.98
epsilon = 1e-9
dropout_rate = 0.1

And we'll use a learning rate scheduler which was specified in the same paper as follows:  
$$\text { lrate }=d_{\mathrm{model}}^{-0.5} \cdot \min \left(step\_num^{-0.5}, \text { step_num } \cdot \text { warmup_steps }{ }^{-1.5}\right)$$

In [9]:
class LRScheduler(LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000, **kwargs):
        super().__init__(**kwargs)
        self.d_model = cast(d_model, float32)
        self.warmup_steps = cast(warmup_steps, float32)

    def __call__(self, step_num):
        # Linearly increasing the learning rate for the first warmup_steps, and
        # decreasing it thereafter
        step_num = cast(step_num, float32)
        arg1 = step_num**-0.5
        arg2 = step_num * (self.warmup_steps**-1.5)

        return (self.d_model**-0.5) * math.minimum(arg1, arg2)

Let's prepare our batches for training and instantiate our model and optimizer:

In [10]:
train_dataset = data.Dataset.from_tensor_slices((trainX, trainY))
train_dataset = train_dataset.batch(batch_size)

In [11]:
optimizer = Adam(LRScheduler(d_model), beta_1, beta_2, epsilon)

In [12]:
training_model = Xformer(
    enc_vocab_size,
    dec_vocab_size,
    enc_seq_length,
    dec_seq_length,
    h,
    d_model,
    d_ff,
    n,
    dropout_rate,
)

Next, we write our own training loop, taking advantage of the loss and accuracy functions we coded earlier.  
**Note:** The default execution mode in TensorFlow 2 is *eager execution*. However, for a fairly large model such as this, we want to leverage the optimizations provided by *graph execution* (at the cost of some overhead). In order to do so, we need to use the `@function` decorator below.

In [13]:
@function
def train_step(encoder_input, decoder_input, decoder_output):
    with GradientTape() as tape:
        # Run the forward pass of the model to generate a prediction
        prediction = training_model(encoder_input, decoder_input, training=True)

        # Compute the training loss
        loss = loss_fn(decoder_output, prediction)

        # Compute the training accuracy
        accuracy = accuracy_fn(decoder_output, prediction)

    # Retrieve gradients of the trainable variables with respect to the training loss
    gradients = tape.gradient(loss, training_model.trainable_weights)

    # Update the values of the trainable variables by gradient descent
    optimizer.apply_gradients(zip(gradients, training_model.trainable_weights))

    train_loss(loss)
    train_accuracy(accuracy)

In [14]:
train_loss = Mean(name="train_loss")
train_accuracy = Mean(name="train_accuracy")

# Create a checkpoint object and manager to manage multiple checkpoints
ckpt = train.Checkpoint(model=training_model, optimizer=optimizer)
ckpt_manager = train.CheckpointManager(ckpt, "./checkpoints", max_to_keep=3)

for epoch in range(epochs):
    train_loss.reset_states()
    train_accuracy.reset_states()

    print("\nStart of epoch %d" % (epoch + 1))
    
    start_time = time()

    # Iterate over the dataset batches
    for step, (train_batchX, train_batchY) in enumerate(train_dataset):
        # Define the encoder and decoder inputs, and the decoder output
        encoder_input = train_batchX[:, 1:]
        decoder_input = train_batchY[:, :-1]
        decoder_output = train_batchY[:, 1:]

        train_step(encoder_input, decoder_input, decoder_output)

        if step % 50 == 0:
            print(
                f"Epoch {epoch + 1} Step {step} Loss {train_loss.result():.4f} "
                + f"Accuracy {train_accuracy.result():.4f}"
            )

    # Print epoch number and loss value at the end of every epoch
    print(
        f"Epoch {epoch +1}: Training Loss {train_loss.result():.4f}, "
        + f"Training Accuracy {train_accuracy.result():.4f}"
    )

    # Save a checkpoint after every five epochs
    if (epoch + 1) % 5 == 0:
        save_path = ckpt_manager.save()
        print("Saved checkpoint at epoch %d" % (epoch + 1))
        
print("Total time taken: %.2fs" % (time() - start_time))


Start of epoch 1
Epoch 1 Step 0 Loss 8.2269 Accuracy 0.0000
Epoch 1 Step 50 Loss 7.3872 Accuracy 0.1554
Epoch 1 Step 100 Loss 6.8584 Accuracy 0.1876
Epoch 1: Training Loss 6.6163, Training Accuracy 0.1966

Start of epoch 2
Epoch 2 Step 0 Loss 5.7815 Accuracy 0.2140
Epoch 2 Step 50 Loss 5.5019 Accuracy 0.2682
Epoch 2 Step 100 Loss 5.3345 Accuracy 0.2718
Epoch 2: Training Loss 5.2267, Training Accuracy 0.2740
Total time taken: 57.24s
