# English-to-German translation with a sequence-to-sequence Transformer


**Disclaimer**: This code has been adapted from an original notebook with the following details:

**Notebook:** https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb

## Introduction

In this example, we'll build a sequence-to-sequence Transformer model, which
we'll train on an English-to-German machine translation task.

The code featured here is adapted from the book
[Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition)
(chapter 11: Deep learning for text).
The present example is fairly barebones, so for detailed explanations of
how each building block works, as well as the theory behind Transformers,
I recommend reading the book.

In [1]:
!pip install -q tqdm
!pip install -q datasets
!pip install -q torchmetrics

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m761.3/761.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h

## Setup

In [2]:
import re
import random
import string
import pathlib
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization

In [3]:
# for reproducible code
import numpy as np
import tensorflow as tf
import random as python_random

seed = 1

def reset_seeds():
   np.random.seed(seed)
   python_random.seed(seed)
   tf.random.set_seed(seed)

reset_seeds()
keras.utils.set_random_seed(seed)
tf.config.experimental.enable_op_determinism()

## Downloading the data

We'll be working with an English-to-German translation dataset called Multi30k (https://arxiv.org/abs/1605.00459)

## Parsing the data

Each line contains an English sentence and its corresponding German sentence.
The English sentence is the *source sequence* and German one is the *target sequence*.
We prepend the token `"[start]"` and we append the token `"[end]"` to the German sentence.

In [4]:
import datasets

dataset = datasets.load_dataset('bentrevett/multi30k')
dataset

Downloading readme:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.60M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/164k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/156k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['en', 'de'],
        num_rows: 29000
    })
    validation: Dataset({
        features: ['en', 'de'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['en', 'de'],
        num_rows: 1000
    })
})

performing consonants transform

In [5]:
def mask_vowels(text, mask="a"): # we could've used # but this is a puctuation char that will be deleted in preprocessing steps
    text_with_no_vowels = re.sub(
        r"[AEIOU]",
        mask,
        text,
        flags=re.IGNORECASE,
    )
    return text_with_no_vowels

In [6]:
dataset['train'] = dataset['train'].map(
    lambda example:{
        'en_masked':mask_vowels(example['en']),
        **example,
      }
  )
dataset['validation'] = dataset['validation'].map(
    lambda example:{
        'en_masked':mask_vowels(example['en']),
        **example,
      }
    )
dataset['test'] = dataset['test'].map(
    lambda example:{
        'en_masked':mask_vowels(example['en']),
        **example,
      }
  )
dataset['train']['en'][:1],dataset['train']['en_masked'][:1]

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

(['Two young, White males are outside near many bushes.'],
 ['Twa yaang, Whata malas ara aatsada naar many bashas.'])

In [7]:
train_pairs = list(
    [de_document,f'[start] {en_document} [end]']
    for de_document,en_document in zip(
        dataset['train']['de'],
        dataset['train']['en_masked'],
      )
  )

val_pairs = list(
    [de_document,f'[start] {en_document} [end]']
    for de_document,en_document in zip(
        dataset['validation']['de'],
        dataset['validation']['en_masked'],
      )
  )

test_pairs = list(
    [de_document,f'[start] {en_document} [end]']
    for de_document,en_document in zip(
        dataset['test']['de'],
        dataset['test']['en_masked'],
      )
  )

print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

29000 training pairs
1014 validation pairs
1000 test pairs


## Vectorizing the text data

We'll use two instances of the `TextVectorization` layer to vectorize the text
data (one for English and one for German),
that is to say, to turn the original strings into integer sequences
where each integer represents the index of a word in a vocabulary.

Both layers will use the default string standardization (strip punctuation characters)
and splitting scheme (split on whitespace)

In [8]:
from collections import defaultdict

In [9]:
# calculate vocab sizes
en_vocab = defaultdict(int)
de_vocab = defaultdict(int)

for (de_item,en_item) in train_pairs:
  for token in de_item.split():
    de_vocab[token] += 1
  for token in en_item.split():
    en_vocab[token] += 1

de_vocab = dict(de_vocab)
en_vocab = dict(en_vocab)

len(en_vocab),len(de_vocab)

(14396, 24889)

In [10]:
de_vocab = {vocab:freq for vocab,freq in de_vocab.items() if freq > 1}
len(de_vocab)

9758

In [11]:
strip_chars = string.punctuation
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

sequence_length = 35
batch_size = 64


def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


de_vectorization = TextVectorization(
    max_tokens=len(de_vocab),
    output_mode="int",
    output_sequence_length=sequence_length,
)
eng_vectorization = TextVectorization(
    max_tokens=len(en_vocab),
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)
train_de_texts = [pair[0] for pair in train_pairs]
train_eng_texts = [pair[1] for pair in train_pairs]
eng_vectorization.adapt(train_eng_texts)
de_vectorization.adapt(train_de_texts)

In [12]:
input_vocab_size = len(de_vectorization.get_vocabulary())
output_vocab_size = len(eng_vectorization.get_vocabulary())
# numbers here might be different than above because of the transformations applied
input_vocab_size,output_vocab_size

(9758, 9374)

Next, we'll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond)
using the source sentence and the target words 0 to N.

As such, the training dataset will yield a tuple `(inputs, targets)`, where:

- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
- `target` is the target sentence offset by one step:
it provides the next words in the target sentence -- what the model will try to predict.

In [13]:
def format_dataset(de, eng):
    eng = eng_vectorization(eng)
    de = de_vectorization(de)
    return (
        {
            "encoder_inputs": de,
            "decoder_inputs": eng[:, :-1],
        },
        eng[:, 1:],
    )


def make_dataset(pairs):
    # eng_texts, de_texts = zip(*pairs)
    de_texts, eng_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    de_texts = list(de_texts)
    dataset = tf.data.Dataset.from_tensor_slices((de_texts, eng_texts))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(format_dataset)
    return dataset.shuffle(2048).prefetch(16).cache()


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

Let's take a quick look at the sequence shapes
(we have batches of 64 pairs, and all sequences are 20 steps long):

In [14]:
for inputs, targets in train_ds.take(1):
  print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
  print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
  print(f"targets.shape: {targets.shape}")

inputs["encoder_inputs"].shape: (64, 35)
inputs["decoder_inputs"].shape: (64, 35)
targets.shape: (64, 35)


## Building the model

Our sequence-to-sequence Transformer consists of a `TransformerEncoder`
and a `TransformerDecoder` chained together. To make the model aware of word order,
we also use a `PositionalEmbedding` layer.

The source sequence will be pass to the `TransformerEncoder`,
which will produce a new representation of it.
This new representation will then be passed
to the `TransformerDecoder`, together with the target sequence so far (target words 0 to N).
The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).

A key detail that makes this possible is causal masking
(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`).
The `TransformerDecoder` sees the entire sequences at once, and thus we must make
sure that it only uses information from target tokens 0 to N when predicting token N+1
(otherwise, it could use information from the future, which would
result in a model that cannot be used at inference time).

In [15]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        attention_output = self.attention(query=inputs, value=inputs, key=inputs)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.add = layers.Add()  # instead of `+` to preserve mask
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, use_causal_mask=True
        )
        out_1 = self.layernorm_1(self.add([inputs, attention_output_1]))

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
        )
        out_2 = self.layernorm_2(self.add([out_1, attention_output_2]))

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(self.add([out_2, proj_output]))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


Next, we assemble the end-to-end model.

In [16]:
embed_dim = 256
latent_dim = 2048
num_heads = 8

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, input_vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, output_vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(output_vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
)

## Training our model

We'll use accuracy as a quick way to monitor training progress on the validation data.
Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy.

Here, we are training the model for 10 epochs

In [17]:
from keras.callbacks import ModelCheckpoint

In [18]:
epochs = 10

transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
checkpointing_callback = ModelCheckpoint('model-{epoch:03d}-{loss:03f}-{val_loss:03f}.h5',
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
    mode='auto'
)
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds, callbacks=[checkpointing_callback])

Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 positional_embedding (Position  (None, None, 256)   2507008     ['encoder_inputs[0][0]']         
 alEmbedding)                                                                                     
                                                                                                  
 decoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 transformer_encoder (Transform  (None, None, 256)   3155456     ['positional_embedding[

<keras.callbacks.History at 0x7ae71019bf10>

## Decoding test sentences

Finally, let's demonstrate how to translate brand new English sentences.
We simply feed into the model the vectorized English sentence
as well as the target token `"[start]"`, then we repeatedly generated the next token, until
we hit the token `"[end]"`.

In [19]:
eng_vocab = eng_vectorization.get_vocabulary()
eng_index_lookup = dict(zip(range(len(eng_vocab)), eng_vocab))
max_decoded_sentence_length = sequence_length


def decode_sequence(input_sentence):
    tokenized_input_sentence = de_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = eng_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])

        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = eng_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


test_de_texts = [pair[0] for pair in test_pairs]
test_eng_texts = [pair[1] for pair in test_pairs]
for _ in range(3):
    input_sentence = random.choice(test_de_texts)
    translated = decode_sequence(input_sentence)
    true_sentence = test_eng_texts[test_de_texts.index(input_sentence)].lower()
    print('#'*80)
    print(input_sentence)
    print(translated)
    print(true_sentence)
    print('#'*80)

################################################################################
Ein paar Bier-Zapfhähne in einer Bar mit Weihnachtsdekoration an der Decke.
[start] a caapla drassad an a bar wath a bar wath blankat [end]
[start] a banch af baar pall tabs at a bar wath chrastmas laghts an tha caalang. [end]
################################################################################
################################################################################
Santa Claus wird bei einem Medienevent zu Weihnachten fotografiert.
[start] straat scana takas a pactara at a pactara takan by takas thaar pactara [end]
[start] santa claas baang phatagraphad at a haladay madaa avant. [end]
################################################################################
################################################################################
Eine Frau in einem roten Hemd reitet auf einem weißen Pferd, das an den Bäumen entlang galoppiert.
[start] a waman an a rad shart radas a whata 

## Calculate BLUE score

In [20]:
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

In [21]:
target = [[sentence.lower()] for sentence in tqdm(test_eng_texts)]
preds = [decode_sequence(custom_standardization(sentence)) for sentence in tqdm(test_de_texts)]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [22]:
# show some examples by index
index = 10
test_de_texts[index],preds[index],target[index]

('Eine Mutter und ihr kleiner Sohn genießen einen schönen Tag im Freien.',
 '[start] a mathar and har san anjayang a baaatafal day [end]',
 ['[start] a mathar and har yaang sang anjayang a baaatafal day aatsada. [end]'])

In [23]:
bleu = BLEUScore(n_gram=4)
bleu(preds, target)

tensor(0.2980)