# Task 3: NLP and Attention Mechanism

## Part 1

In [33]:
# import required libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Flatten, Dense, LSTM, Input
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.preprocessing import Normalizer
import warnings
warnings.filterwarnings("ignore")

In [None]:
embedded_tokens = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
    [0.1, 0.2, 0.3],
    [1.3, 1.4, 1.5]
])
Q = K = V = embedded_tokens

print(scaled_dot_product_attention(Q, K, V))

[[0.6703881  0.7703881  0.8703881 ]
 [0.77624039 0.87624039 0.97624039]
 [0.87532355 0.97532355 1.07532355]
 [0.96179317 1.06179317 1.16179317]
 [0.6703881  0.7703881  0.8703881 ]
 [1.03314235 1.13314235 1.23314235]]


In [34]:
# https://medium.com/@funcry/in-depth-understanding-of-attention-mechanism-part-ii-scaled-dot-product-attention-and-its-7743804e610e
# https://machinelearningmastery.com/how-to-implement-scaled-dot-product-attention-from-scratch-in-tensorflow-and-keras/

def scaled_dot_product_attention(queries, keys, values):
  dot_prod = np.dot(queries, keys.T) # QK^T
  dk = keys.shape[-1]
  scaled = dot_prod / np.sqrt(dk)  # / sqrt(dk)
  exp_x = np.exp(scaled - np.max(scaled, axis=-1, keepdims=True))  # softmax
  norm = exp_x / np.sum(exp_x, axis=-1, keepdims=True)
  return np.dot(norm, values) # *V

## Part 2

In [49]:
# https://www.tensorflow.org/text/tutorials/nmt_with_attention
# https://arxiv.org/abs/1508.04025v5
# https://machinelearningmastery.com/develop-encoder-decoder-model-sequence-sequence-prediction-keras/

# returns train, inference_encoder and inference_decoder models
def define_models(n_input, n_output, n_units):
  # encoder
  encoder_inputs = Input(shape=(None, n_input))
  encoder = LSTM(n_units, return_state=True)
  encoder_outputs, state_h, state_c = encoder(encoder_inputs)
  encoder_states = [state_h, state_c]
  # decoder
  decoder_inputs = Input(shape=(None, n_output))
  decoder_lstm = LSTM(n_units, return_sequences=True, return_state=True)

  # Attention helps decoder focus on relevant output
  decoder_hidden_state = Input(shape=(1, n_units))
  context_vector = tf.keras.layers.Lambda(lambda x: scaled_dot_product_attention(x[0], x[1], x[1]), output_shape=(1, n_units))([decoder_hidden_state, encoder_outputs])
  decoder_combined_inputs = keras.layers.Concatenate(axis=-1)([context_vector, decoder_inputs])

  decoder_outputs, _, _ = decoder_lstm(decoder_combined_inputs, initial_state=encoder_states)
  decoder_dense = Dense(n_output, activation='softmax')
  decoder_outputs = decoder_dense(decoder_outputs)
  model = keras.Model([encoder_inputs, decoder_inputs, decoder_hidden_state], decoder_outputs)

  # define inference encoder
  encoder_model = keras.Model(encoder_inputs, [encoder_outputs, state_h, state_c])
  # define inference decoder
  decoder_state_input_h = Input(shape=(n_units,))
  decoder_state_input_c = Input(shape=(n_units,))
  decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

  decoder_hidden_state_input = Input(shape=(1, n_units))
  context_vector_inf = keras.layers.Lambda(lambda x: scaled_dot_product_attention(x[0], x[1], x[1]), output_shape=(1, n_units))([decoder_hidden_state_input, encoder_outputs])
  decoder_combined_inputs_inf = keras.layers.Concatenate(axis=-1)([context_vector_inf, decoder_inputs])
  decoder_outputs = decoder_dense(decoder_outputs)

  decoder_outputs, state_h, state_c = decoder_lstm(decoder_combined_inputs_inf, initial_state=decoder_states_inputs)
  # decoder_states = [state_h, state_c]
  # decoder_outputs = decoder_dense(decoder_outputs)
  decoder_model = keras.Model([decoder_inputs, decoder_hidden_state_input, encoder_outputs] + decoder_states_inputs, [decoder_outputs, state_h, state_c])
  # return all models
  return model, encoder_model, decoder_model

## Part 3

The encoder-decoder model will now be used for a machine translation task, using a subset of the Multi30k dataset.

https://github.com/multi30k/dataset


https://www.geeksforgeeks.org/nlp-bleu-score-for-evaluating-neural-machine-translation-python/


### Import Data

In [3]:
! git clone --recursive https://github.com/multi30k/dataset.git multi30k-dataset

Cloning into 'multi30k-dataset'...
remote: Enumerating objects: 313, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 313 (delta 17), reused 21 (delta 16), pack-reused 281 (from 1)[K
Receiving objects: 100% (313/313), 18.21 MiB | 18.24 MiB/s, done.
Resolving deltas: 100% (69/69), done.
Submodule 'scripts/subword-nmt' (https://github.com/rsennrich/subword-nmt.git) registered for path 'scripts/subword-nmt'
Cloning into '/content/multi30k-dataset/scripts/subword-nmt'...
remote: Enumerating objects: 622, done.        
remote: Counting objects: 100% (46/46), done.        
remote: Compressing objects: 100% (30/30), done.        
remote: Total 622 (delta 25), reused 31 (delta 16), pack-reused 576 (from 1)        
Receiving objects: 100% (622/622), 261.27 KiB | 1.34 MiB/s, done.
Resolving deltas: 100% (374/374), done.
Submodule path 'scripts/subword-nmt': checked out '80b7c1449e2e26673fb0b5cae993fe2d0dc23846'


In [24]:
def load_data(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        sentences = f.read().strip().split("\n")
    return sentences

path_en = 'multi30k-dataset/data/task1/tok/train.lc.norm.tok.en'
train_sentences = load_data(path_en)
path_de = 'multi30k-dataset/data/task1/tok/train.lc.norm.tok.de'
test_sentences = load_data(path_de)

x_train = train_sentences[:500]
y_train = test_sentences[:500]

print(train_sentences[:5])
print(test_sentences[:5])
print(len(train_sentences))

['two young , white males are outside near many bushes .', 'several men in hard hats are operating a giant pulley system .', 'a little girl climbing into a wooden playhouse .', 'a man in a blue shirt is standing on a ladder cleaning a window .', 'two men are at the stove preparing food .']
['zwei junge weiße männer sind im freien in der nähe vieler büsche .', 'mehrere männer mit schutzhelmen bedienen ein antriebsradsystem .', 'ein kleines mädchen klettert in ein spielhaus aus holz .', 'ein mann in einem blauen hemd steht auf einer leiter und putzt ein fenster .', 'zwei männer stehen am herd und bereiten essen zu .']
29000


In [33]:
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

In [42]:
tokenizer = Tokenizer(filters="")
tokenizer.fit_on_texts(x_train)
x_sequences = tokenizer.texts_to_sequences(x_train)
x_vocab_size = len(tokenizer.word_index) + 1
x_max_length = max(len(seq) for seq in x_sequences)
x_padded = pad_sequences(x_sequences, maxlen=x_max_length, padding="post")
x_input = x_padded[:, :-1]
X1 = to_categorical(x_input, num_classes=x_vocab_size)

tokenizer.fit_on_texts(y_train)
y_sequences = tokenizer.texts_to_sequences(y_train)
y_vocab_size = len(tokenizer.word_index) + 1
y_max_length = max(len(seq) for seq in y_sequences)
y_padded = pad_sequences(y_sequences, maxlen=y_max_length, padding="post")
y_input = y_padded[:, :-1]
X2 = to_categorical(y_input, num_classes=y_vocab_size)

y_output = y_padded[:, 1:]
y_target = to_categorical(y_output, num_classes=y_vocab_size)

print(x_padded)
print(x_padded.shape)
print(x_vocab_size)
print(x_max_length)
print(y_padded.shape)

print(X1.shape)
print(X2.shape)
print(y_target.shape)

[[ 13  18  11 ...   0   0   0]
 [103  32   3 ...   0   0   0]
 [  1  33  26 ...   0   0   0]
 ...
 [  1 159  37 ...   0   0   0]
 [211  14 283 ...   0   0   0]
 [  1 180 215 ...   0   0   0]]
(500, 35)
1220
35
(500, 44)
(500, 34, 1220)
(500, 43, 2500)
(500, 43, 2500)


### Training

In [41]:
train, infenc, infdec = define_models(x_vocab_size, y_vocab_size, 128)
train.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [43]:
train.fit([X1, X2], y_target, epochs=1)

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 884ms/step - accuracy: 0.4709 - loss: 7.6576


<keras.src.callbacks.history.History at 0x796c9c7f8210>

In [47]:
train.fit([X1, X2], y_target, epochs=5)

Epoch 1/5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 538ms/step - accuracy: 0.7249 - loss: 4.6259
Epoch 2/5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 459ms/step - accuracy: 0.7144 - loss: 2.1228
Epoch 3/5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 428ms/step - accuracy: 0.7230 - loss: 1.8363
Epoch 4/5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 527ms/step - accuracy: 0.7157 - loss: 1.8167
Epoch 5/5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 496ms/step - accuracy: 0.7250 - loss: 1.7493


<keras.src.callbacks.history.History at 0x796c96d3ca50>

In [48]:
train.fit([X1, X2], y_target, epochs=10)

Epoch 1/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 491ms/step - accuracy: 0.7332 - loss: 1.6785
Epoch 2/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 500ms/step - accuracy: 0.7371 - loss: 1.6358
Epoch 3/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 532ms/step - accuracy: 0.7290 - loss: 1.6688
Epoch 4/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 464ms/step - accuracy: 0.7316 - loss: 1.6368
Epoch 5/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 465ms/step - accuracy: 0.7320 - loss: 1.6090
Epoch 6/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 526ms/step - accuracy: 0.7316 - loss: 1.6041
Epoch 7/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 525ms/step - accuracy: 0.7312 - loss: 1.5763
Epoch 8/10
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 472ms/step - accuracy: 0.7282 - loss: 1.5896
Epoch 9/10
[1m16/16[0m [32m━━━━━━━━━

<keras.src.callbacks.history.History at 0x796c96b41810>

In [54]:
# generate target given source sequence
def predict_sequence(infenc, infdec, source, n_steps, cardinality):
	# encode
	state = infenc.predict(source)
	# start of sequence input
	target_seq = array([0.0 for _ in range(cardinality)]).reshape(1, 1, cardinality)
	# collect predictions
	output = list()
	for t in range(n_steps):
		# predict next char
		yhat, h, c = infdec.predict([target_seq] + state)
		# store prediction
		output.append(yhat[0,0,:])
		# update state
		state = [h, c]
		# update target sequence
		target_seq = yhat
	return array(output)

In [62]:
test_input = np.expand_dims(X1[0], axis=0)
print(test_input.shape)

(1, 34, 1220)


In [66]:
print(X1.shape)
my_test = predict_sequence(infenc, infdec, np.expand_dims(X1[0], axis=0) , y_max_length, y_vocab_size)

(500, 34, 1220)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 61ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 75ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 74ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 64ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 90ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 74ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[

In [71]:
print(my_test)

[[2.7809513e-03 1.8604430e-03 7.4829304e-06 ... 3.5775869e-04
  6.1969746e-05 5.7597761e-05]
 [3.1104428e-03 3.0320552e-03 1.3676123e-05 ... 3.7945618e-04
  9.9735313e-05 1.0487767e-04]
 [4.3736571e-03 5.0791660e-03 1.4689128e-05 ... 2.8943416e-04
  1.1866018e-04 1.3413813e-04]
 ...
 [9.8439121e-01 3.6152357e-03 2.4909036e-07 ... 7.3339891e-07
  4.2242568e-06 4.0520817e-06]
 [9.8439121e-01 3.6152236e-03 2.4908988e-07 ... 7.3339675e-07
  4.2242486e-06 4.0520699e-06]
 [9.8439121e-01 3.6152201e-03 2.4908965e-07 ... 7.3339612e-07
  4.2242409e-06 4.0520658e-06]]


In [78]:
final_sentence = []

for pred in my_test:
  if isinstance(pred, np.ndarray):
    pred_index = np.argmax(pred)
  else:
    pred_index = pred
  # index to word
  word = tokenizer.index_word.get(pred_index, '')
  if word != '':
    final_sentence.append(word)
  else:
    break

result = ' '.join(final_sentence)
print(result)
print(y_train[0])
print(x_train[0])

mann in , , , , , , , , . . . .
zwei junge weiße männer sind im freien in der nähe vieler büsche .
two young , white males are outside near many bushes .


In [79]:
# https://www.geeksforgeeks.org/nlp-bleu-score-for-evaluating-neural-machine-translation-python/
from nltk.translate.bleu_score import sentence_bleu

# Calculate BLEU score with weights
score = sentence_bleu(x_train[0], result)
print(score)

1.2558634180836711e-231


The BLEU score is quite bad, which makes sense as this example generated sentence is not at all structured like an actual sentence. This is a bit surprising though as the training accuracy was around $73\%$, and the loss did decrease from $7.66$ to $1.54$. Something that is likely an issue with this model is that I am using a very small amount of input data that does not allow the network to learn and understand the languages and their relationships.

## Part 4

I made this Transformer for English to German translation based on this tutorial: https://www.tensorflow.org/text/tutorials/transformer. I also used many other sources that are linked below.

In [7]:
# https://keras.io/keras_hub/api/tokenizers/byte_pair_tokenizer/
# https://www.tensorflow.org/text/tutorials/transformer

def load_data(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        sentences = f.read().strip().split("\n")
    return sentences

path_en = 'multi30k-dataset/data/task1/tok/train.lc.norm.tok.en'
en_sentences = load_data(path_en)
path_de = 'multi30k-dataset/data/task1/tok/train.lc.norm.tok.de'
de_sentences = load_data(path_de)

x_train = en_sentences[:10000]
y_train = de_sentences[:10000]

print(x_train[:5])
print(y_train[:5])
print(len(x_train))

['two young , white males are outside near many bushes .', 'several men in hard hats are operating a giant pulley system .', 'a little girl climbing into a wooden playhouse .', 'a man in a blue shirt is standing on a ladder cleaning a window .', 'two men are at the stove preparing food .']
['zwei junge weiße männer sind im freien in der nähe vieler büsche .', 'mehrere männer mit schutzhelmen bedienen ein antriebsradsystem .', 'ein kleines mädchen klettert in ein spielhaus aus holz .', 'ein mann in einem blauen hemd steht auf einer leiter und putzt ein fenster .', 'zwei männer stehen am herd und bereiten essen zu .']
10000


### Word-level tokenization

In [50]:
# https://www.kaggle.com/code/shivanshuman/a-song-of-words-and-tokens

from collections import Counter

def get_vocab(sentences):
  vocab = Counter()
  for sentence in sentences:
    for word in sentence.split():
      vocab[word] += 1
  return vocab

en_vocab = get_vocab(x_train)
de_vocab = get_vocab(y_train)

print(1.most_common(10))
print(de_vocab.most_common(10))

[('a', 16897), ('.', 9473), ('in', 5015), ('the', 3644), ('on', 2734), ('man', 2606), ('is', 2505), ('and', 2457), ('of', 2233), ('with', 2015)]
[('.', 9883), ('ein', 6602), ('einem', 4270), ('in', 3717), (',', 3483), ('eine', 3063), ('mit', 3060), ('auf', 2952), ('und', 2941), ('mann', 2629)]


### Encoder Decoder

I did a lot of research into the structure, but due to the complexity and potential issues with mismatches shapes, decided to stick very close to this tutorial and alter it to the simplified specifications: https://www.tensorflow.org/text/tutorials/transformer

In [None]:
# https://arxiv.org/abs/1706.03762
# https://arxiv.org/pdf/1706.03762
# https://www.tensorflow.org/text/tutorials/transformer
# https://www.tensorflow.org/text/tutorials/nmt_with_attention
# https://arxiv.org/abs/1508.04025v5

In [1]:
# Install the most re version of TensorFlow to use the improved
# masking support for `tf.keras.layers.MultiHeadAttention`.
!apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
!pip uninstall -y -q tensorflow keras tensorflow-estimator tensorflow-text
!pip install protobuf~=3.20.3
!pip install -q tensorflow_datasets
!pip install -q -U tensorflow-text tensorflow

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Package libcudnn8 is not available, but is referred to by another package.
This may mean that the package is missing, has been obsoleted, or
is only available from another source

[1;31mE: [0mVersion '8.1.0.77-1+cuda11.2' for 'libcudnn8' was not found[0m
[0mCollecting protobuf~=3.20.3
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 4.25.6
    Uninstalling protobuf-4.25.6:
      Successfully uninstalled protobuf-4.25.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the follow

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m42.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m615.5/615.5 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

import tensorflow_text

In [5]:
class BaseAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.layernorm = tf.keras.layers.LayerNormalization()
    self.add = tf.keras.layers.Add()

class CrossAttention(BaseAttention):
  def call(self, x, context):
    attn_output, attn_scores = self.mha(
        query=x,
        key=context,
        value=context,
        return_attention_scores=True)
    # Cache the attention scores for plotting later.
    self.last_attn_scores = attn_scores
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

class GlobalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query=x,
        value=x,
        key=x)
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

class CausalSelfAttention(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query=x,
        value=x,
        key=x,
        use_causal_mask = True)
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

class FeedForward(tf.keras.layers.Layer):
  def __init__(self, d_model, dff, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),
      tf.keras.layers.Dense(d_model),
      tf.keras.layers.Dropout(dropout_rate)
    ])
    self.add = tf.keras.layers.Add()
    self.layer_norm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = self.add([x, self.seq(x)])
    x = self.layer_norm(x)
    return x

### Encoder

In [14]:
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
    super().__init__()

    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.ffn = FeedForward(d_model, dff)

  def call(self, x):
    x = self.self_attention(x)
    x = self.ffn(x)
    return x

class Encoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads,
               dff, vocab_size, dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(
        vocab_size=vocab_size, d_model=d_model)

    self.enc_layers = [
        EncoderLayer(d_model=d_model,
                     num_heads=num_heads,
                     dff=dff,
                     dropout_rate=dropout_rate)
        for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, x):
    # `x` is token-IDs shape: (batch, seq_len)
    x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.

    # Add dropout.
    x = self.dropout(x)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x)

    return x  # Shape `(batch_size, seq_len, d_model)`.

### Decoder

In [7]:
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self,
               *,
               d_model,
               num_heads,
               dff,
               dropout_rate=0.1):
    super(DecoderLayer, self).__init__()

    self.causal_self_attention = CausalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.cross_attention = CrossAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.ffn = FeedForward(d_model, dff)

  def call(self, x, context):
    x = self.causal_self_attention(x=x)
    x = self.cross_attention(x=x, context=context)

    # Cache the last attention scores for plotting later
    self.last_attn_scores = self.cross_attention.last_attn_scores

    x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
    return x

class Decoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
               dropout_rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                             d_model=d_model)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads,
                     dff=dff, dropout_rate=dropout_rate)
        for _ in range(num_layers)]

    self.last_attn_scores = None

  def call(self, x, context):
    # `x` is token-IDs shape (batch, target_seq_len)
    x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

    x = self.dropout(x)

    for i in range(self.num_layers):
      x  = self.dec_layers[i](x, context)

    self.last_attn_scores = self.dec_layers[-1].last_attn_scores

    # The shape of x is (batch_size, target_seq_len, d_model).
    return x

### Transformer

In [8]:
class Transformer(tf.keras.Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff,
               input_vocab_size, target_vocab_size, dropout_rate=0.1):
    super().__init__()
    self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           vocab_size=input_vocab_size,
                           dropout_rate=dropout_rate)

    self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           vocab_size=target_vocab_size,
                           dropout_rate=dropout_rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs):
    # To use a Keras model with `.fit` you must pass all your inputs in the
    # first argument.
    context, x  = inputs

    context = self.encoder(context)  # (batch_size, context_len, d_model)

    x = self.decoder(x, context)  # (batch_size, target_len, d_model)

    # Final linear layer output.
    logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

    try:
      # Drop the keras mask, so it doesn't scale the losses/metrics.
      # b/250038731
      del logits._keras_mask
    except AttributeError:
      pass

    # Return the final output and the attention weights.
    return logits

In [12]:
# Smaller simplified values
num_layers = 2
d_model = 64
dff = 128
num_heads = 2
dropout_rate = 0.1
num_samples = 10000

In [None]:
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=num_samples,
    target_vocab_size=num_samples,
    dropout_rate=dropout_rate)