https://machinetalk.org/2019/03/29/neural-machine-translation-with-attention-mechanism/?unapproved=1376&moderation-hash=5feb059786bee61a003bab66b69c7c18#comment-1376

In [1]:
import tensorflow as tf
import numpy as np
import unicodedata
import re

In [2]:
tf.__version__

'2.1.0'

In [3]:
#English to French
raw_data = (
    ('What a ridiculous concept!', 'Quel concept ridicule !'),
    ('Your idea is not entirely crazy.', "Votre idée n'est pas complètement folle."),
    ("A man's worth lies in what he is.", "La valeur d'un homme réside dans ce qu'il est."),
    ('What he did is very wrong.', "Ce qu'il a fait est très mal."),
    ("All three of you need to do that.", "Vous avez besoin de faire cela, tous les trois."),
    ("Are you giving me another chance?", "Me donnez-vous une autre chance ?"),
    ("Both Tom and Mary work as models.", "Tom et Mary travaillent tous les deux comme mannequins."),
    ("Can I have a few minutes, please?", "Puis-je avoir quelques minutes, je vous prie ?"),
    ("Could you close the door, please?", "Pourriez-vous fermer la porte, s'il vous plaît ?"),
    ("Did you plant pumpkins this year?", "Cette année, avez-vous planté des citrouilles ?"),
    ("Do you ever study in the library?", "Est-ce que vous étudiez à la bibliothèque des fois ?"),
    ("Don't be deceived by appearances.", "Ne vous laissez pas abuser par les apparences."),
    ("Excuse me. Can you speak English?", "Je vous prie de m'excuser ! Savez-vous parler anglais ?"),
    ("Few people know the true meaning.", "Peu de gens savent ce que cela veut réellement dire."),
    ("Germany produced many scientists.", "L'Allemagne a produit beaucoup de scientifiques."),
    ("Guess whose birthday it is today.", "Devine de qui c'est l'anniversaire, aujourd'hui !"),
    ("He acted like he owned the place.", "Il s'est comporté comme s'il possédait l'endroit."),
    ("Honesty will pay in the long run.", "L'honnêteté paye à la longue."),
    ("How do we know this isn't a trap?", "Comment savez-vous qu'il ne s'agit pas d'un piège ?"),
    ("I can't believe you're giving up.", "Je n'arrive pas à croire que vous abandonniez."),
)

In [4]:
#Clean text
def strip_accents(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

def normalize_string(s):
    s = strip_accents(s)
    s = re.sub(r'([!.?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    s = re.sub(r'\s+', r' ', s)
    return s
                   

In [5]:
raw_data_en, raw_data_fr = list(zip(*raw_data))
raw_data_en, raw_data_fr = list(raw_data_en), list(raw_data_fr)


In [6]:
raw_data_en = [normalize_string(data) for data in raw_data_en]
raw_data_fr_in = ['<start> ' + normalize_string(data) for data in raw_data_fr]
raw_data_fr_out = [normalize_string(data) + ' <end>' for data in raw_data_fr]


In [7]:
raw_data_fr_out

['Quel concept ridicule ! <end>',
 'Votre idee n est pas completement folle . <end>',
 'La valeur d un homme reside dans ce qu il est . <end>',
 'Ce qu il a fait est tres mal . <end>',
 'Vous avez besoin de faire cela tous les trois . <end>',
 'Me donnez vous une autre chance ? <end>',
 'Tom et Mary travaillent tous les deux comme mannequins . <end>',
 'Puis je avoir quelques minutes je vous prie ? <end>',
 'Pourriez vous fermer la porte s il vous plait ? <end>',
 'Cette annee avez vous plante des citrouilles ? <end>',
 'Est ce que vous etudiez a la bibliotheque des fois ? <end>',
 'Ne vous laissez pas abuser par les apparences . <end>',
 'Je vous prie de m excuser ! Savez vous parler anglais ? <end>',
 'Peu de gens savent ce que cela veut reellement dire . <end>',
 'L Allemagne a produit beaucoup de scientifiques . <end>',
 'Devine de qui c est l anniversaire aujourd hui ! <end>',
 'Il s est comporte comme s il possedait l endroit . <end>',
 'L honnetete paye a la longue . <end>',
 'C

In [8]:
raw_data_fr_in

['<start> Quel concept ridicule !',
 '<start> Votre idee n est pas completement folle .',
 '<start> La valeur d un homme reside dans ce qu il est .',
 '<start> Ce qu il a fait est tres mal .',
 '<start> Vous avez besoin de faire cela tous les trois .',
 '<start> Me donnez vous une autre chance ?',
 '<start> Tom et Mary travaillent tous les deux comme mannequins .',
 '<start> Puis je avoir quelques minutes je vous prie ?',
 '<start> Pourriez vous fermer la porte s il vous plait ?',
 '<start> Cette annee avez vous plante des citrouilles ?',
 '<start> Est ce que vous etudiez a la bibliotheque des fois ?',
 '<start> Ne vous laissez pas abuser par les apparences .',
 '<start> Je vous prie de m excuser ! Savez vous parler anglais ?',
 '<start> Peu de gens savent ce que cela veut reellement dire .',
 '<start> L Allemagne a produit beaucoup de scientifiques .',
 '<start> Devine de qui c est l anniversaire aujourd hui !',
 '<start> Il s est comporte comme s il possedait l endroit .',
 '<start> 

In [9]:
#convert raw strings to integer sequences
#set filters to blank as we already took care of punctuation
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)


In [10]:
print(en_tokenizer.word_index)

{'.': 1, 'you': 2, '?': 3, 'the': 4, 'a': 5, 'is': 6, 'he': 7, 'what': 8, 'in': 9, 'do': 10, 'can': 11, 't': 12, 'did': 13, 'giving': 14, 'me': 15, 'i': 16, 'few': 17, 'please': 18, 'this': 19, 'know': 20, 'ridiculous': 21, 'concept': 22, '!': 23, 'your': 24, 'idea': 25, 'not': 26, 'entirely': 27, 'crazy': 28, 'man': 29, 's': 30, 'worth': 31, 'lies': 32, 'very': 33, 'wrong': 34, 'all': 35, 'three': 36, 'of': 37, 'need': 38, 'to': 39, 'that': 40, 'are': 41, 'another': 42, 'chance': 43, 'both': 44, 'tom': 45, 'and': 46, 'mary': 47, 'work': 48, 'as': 49, 'models': 50, 'have': 51, 'minutes': 52, 'could': 53, 'close': 54, 'door': 55, 'plant': 56, 'pumpkins': 57, 'year': 58, 'ever': 59, 'study': 60, 'library': 61, 'don': 62, 'be': 63, 'deceived': 64, 'by': 65, 'appearances': 66, 'excuse': 67, 'speak': 68, 'english': 69, 'people': 70, 'true': 71, 'meaning': 72, 'germany': 73, 'produced': 74, 'many': 75, 'scientists': 76, 'guess': 77, 'whose': 78, 'birthday': 79, 'it': 80, 'today': 81, 'acted'

In [11]:
#convert raw eng sentences to int sequences
data_en = en_tokenizer.texts_to_sequences(raw_data_en)


In [12]:
data_en

[[8, 5, 21, 22, 23],
 [24, 25, 6, 26, 27, 28, 1],
 [5, 29, 30, 31, 32, 9, 8, 7, 6, 1],
 [8, 7, 13, 6, 33, 34, 1],
 [35, 36, 37, 2, 38, 39, 10, 40, 1],
 [41, 2, 14, 15, 42, 43, 3],
 [44, 45, 46, 47, 48, 49, 50, 1],
 [11, 16, 51, 5, 17, 52, 18, 3],
 [53, 2, 54, 4, 55, 18, 3],
 [13, 2, 56, 57, 19, 58, 3],
 [10, 2, 59, 60, 9, 4, 61, 3],
 [62, 12, 63, 64, 65, 66, 1],
 [67, 15, 1, 11, 2, 68, 69, 3],
 [17, 70, 20, 4, 71, 72, 1],
 [73, 74, 75, 76, 1],
 [77, 78, 79, 80, 6, 81, 1],
 [7, 82, 83, 7, 84, 4, 85, 1],
 [86, 87, 88, 9, 4, 89, 90, 1],
 [91, 10, 92, 20, 19, 93, 12, 5, 94, 3],
 [16, 11, 12, 95, 2, 96, 14, 97, 1]]

In [13]:
#add padding to have the same input length
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en, padding='post')
data_en[:3]

array([[ 8,  5, 21, 22, 23,  0,  0,  0,  0,  0],
       [24, 25,  6, 26, 27, 28,  1,  0,  0,  0],
       [ 5, 29, 30, 31, 32,  9,  8,  7,  6,  1]], dtype=int32)

In [14]:
#Repeat the same steps for French sentences
fr_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
fr_tokenizer.fit_on_texts(raw_data_fr_in)
fr_tokenizer.fit_on_texts(raw_data_fr_out)

data_fr_in = fr_tokenizer.texts_to_sequences(raw_data_fr_in)
data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(data_fr_in,
                                                           padding='post')

data_fr_out = fr_tokenizer.texts_to_sequences(raw_data_fr_out)
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(data_fr_out,
                                                            padding='post')

In [15]:
#Create an instanceof tf dataset
dataset = tf.data.Dataset.from_tensor_slices(
    (data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(20).batch(5)

In [16]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, lstm_size):
        super(Encoder, self).__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)

    def call(self, sequence, states):
        embed = self.embedding(sequence)
        output, state_h, state_c = self.lstm(embed, initial_state=states)

        return output, state_h, state_c

    def init_states(self, batch_size):
        return (tf.zeros([batch_size, self.lstm_size]),
                tf.zeros([batch_size, self.lstm_size]))

In [17]:
class Decoder_without_attention(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, lstm_size):
        super(Decoder_without_attention, self).__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state):
        embed = self.embedding(sequence)
        lstm_out, state_h, state_c = self.lstm(embed, state)
        logits = self.dense(lstm_out)

        return logits, state_h, state_c

In [18]:
# Test our model
EMBEDDING_SIZE = 32
LSTM_SIZE = 64

en_vocab_size = len(en_tokenizer.word_index) + 1
encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)

fr_vocab_size = len(fr_tokenizer.word_index) + 1
decoder = Decoder_without_attention(fr_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)

source_input = tf.constant([[1, 3, 5, 7, 2, 0, 0, 0]])
initial_state = encoder.init_states(1)
encoder_output, en_state_h, en_state_c = encoder(source_input, initial_state)

target_input = tf.constant([[1, 4, 6, 9, 2, 0, 0]])
decoder_output, de_state_h, de_state_c = decoder(target_input, (en_state_h, en_state_c))

print('Source sequences', source_input.shape)
print('Encoder outputs', encoder_output.shape)
print('Encoder state_h', en_state_h.shape)
print('Encoder state_c', en_state_c.shape)

print('\nDestination vocab size', fr_vocab_size)
print('Destination sequences', target_input.shape)
print('Decoder outputs', decoder_output.shape)
print('Decoder state_h', de_state_h.shape)
print('Decoder state_c', de_state_c.shape)

Source sequences (1, 8)
Encoder outputs (1, 8, 64)
Encoder state_h (1, 64)
Encoder state_c (1, 64)

Destination vocab size 110
Destination sequences (1, 7)
Decoder outputs (1, 7, 110)
Decoder state_h (1, 64)
Decoder state_c (1, 64)


In [25]:
#Define a loss function
#Since we padded zeros into the sequences, let’s not take those zeros into account when computing the loss:
def loss_func(targets, logits):
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    mask = tf.math.logical_not(tf.math.equal(targets, 0))
    mask = tf.cast(mask, dtype=tf.int64)
    loss = crossentropy(targets, logits, sample_weight=mask)

    return loss

In [20]:
optimizer = tf.keras.optimizers.Adam()

In [21]:
# create the training function in which we perform a forward pass followed by a backward pass.
# use the @tf.function decorator to take advance of static graph computation (remove it when you want to debug)
#Network’s computations need to be put under tf.GradientTape() to keep track of gradients
@tf.function
def train_step_without_attention(source_seq, target_seq_in, target_seq_out, en_initial_states):
    #print("TRAIN_STEP START")
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        #print("EN_OUTPUTS OBJECT: ", en_outputs[1:])
        en_states = en_outputs[1:]
        de_states = en_states

        de_outputs = decoder(target_seq_in, de_states)
        logits = de_outputs[0]
        loss = loss_func(target_seq_out, logits)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    #print("TRAIN_STEP END")
    return loss

In [22]:
#let’s define a method for inference purpose. 
#What it does is basically a forward pass, but instead of target sequences, we will feed in the <start> token. 
#Every next time step will take the output of the last time step as input until 
#we hit the <end> token or the output sequence has exceed a specific length:
def predict_without_attention():
    #print("PREDICT START ")
    test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []

    while True:
        de_output, de_state_h, de_state_c = decoder(
            de_input, (de_state_h, de_state_c))
        de_input = tf.argmax(de_output, -1)
        out_words.append(fr_tokenizer.index_word[de_input.numpy()[0][0]])

        if out_words[-1] == '<end>' or len(out_words) >= 20:
            break

    print(' '.join(out_words))

In [23]:
#Training loop
NUM_EPOCHS = 250
BATCH_SIZE = 5

for e in range(NUM_EPOCHS):
    en_initial_states = encoder.init_states(BATCH_SIZE)

    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step_without_attention(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)

    print('Epoch {} Loss {:.4f}'.format(e + 1, loss.numpy()))
    
    try:
        predict_without_attention()
    except Exception:
        continue

Epoch 1 Loss 3.4235
What a ridiculous concept !
[[8, 5, 21, 22, 23]]
s il il vous vous <end>
Epoch 2 Loss 3.6857
Are you giving me another chance ?
[[41, 2, 14, 15, 42, 43, 3]]
vous vous vous vous vous <end>
Epoch 3 Loss 3.5402
Did you plant pumpkins this year ?
[[13, 2, 56, 57, 19, 58, 3]]
vous vous vous vous <end>
Epoch 4 Loss 3.3937
Guess whose birthday it is today .
[[77, 78, 79, 80, 6, 81, 1]]
vous vous vous <end>
Epoch 5 Loss 3.8518
Do you ever study in the library ?
[[10, 2, 59, 60, 9, 4, 61, 3]]
vous vous vous vous <end>
Epoch 6 Loss 3.5430
What he did is very wrong .
[[8, 7, 13, 6, 33, 34, 1]]
vous vous vous <end>
Epoch 7 Loss 3.4282
All three of you need to do that .
[[35, 36, 37, 2, 38, 39, 10, 40, 1]]
vous vous vous <end>
Epoch 8 Loss 3.2189
Germany produced many scientists .
[[73, 74, 75, 76, 1]]
vous vous vous <end>
Epoch 9 Loss 2.7267
Do you ever study in the library ?
[[10, 2, 59, 60, 9, 4, 61, 3]]
vous vous vous <end>
Epoch 10 Loss 3.0996
Excuse me . Can you speak Engl

Epoch 74 Loss 1.3244
Could you close the door please ?
[[53, 2, 54, 4, 55, 18, 3]]
vous vous de a des les . <end>
Epoch 75 Loss 1.6243
Are you giving me another chance ?
[[41, 2, 14, 15, 42, 43, 3]]
l a de de scientifiques . <end>
Epoch 76 Loss 1.5545
How do we know this isn t a trap ?
[[91, 10, 92, 20, 19, 93, 12, 5, 94, 3]]
vous vous vous vous vous la vous vous ? ? <end>
Epoch 77 Loss 1.7872
Did you plant pumpkins this year ?
[[13, 2, 56, 57, 19, 58, 3]]
l a a de longue . <end>
Epoch 78 Loss 1.5357
Honesty will pay in the long run .
[[86, 87, 88, 9, 4, 89, 90, 1]]
l a a de longue . <end>
Epoch 79 Loss 1.7710
Guess whose birthday it is today .
[[77, 78, 79, 80, 6, 81, 1]]
l paye de longue . <end>
Epoch 80 Loss 1.5736
Excuse me . Can you speak English ?
[[67, 15, 1, 11, 2, 68, 69, 3]]
vous vous vous de la la vous ? <end>
Epoch 81 Loss 1.5384
How do we know this isn t a trap ?
[[91, 10, 92, 20, 19, 93, 12, 5, 94, 3]]
vous vous vous vous a la la des ? <end>
Epoch 82 Loss 1.4676
Can I hav

Epoch 142 Loss 0.6077
What he did is very wrong .
[[8, 7, 13, 6, 33, 34, 1]]
qu il est tres . <end>
Epoch 143 Loss 0.6462
Your idea is not entirely crazy .
[[24, 25, 6, 26, 27, 28, 1]]
idee n est pas completement . <end>
Epoch 144 Loss 0.5403
Honesty will pay in the long run .
[[86, 87, 88, 9, 4, 89, 90, 1]]
l paye a la longue . <end>
Epoch 145 Loss 0.5560
Honesty will pay in the long run .
[[86, 87, 88, 9, 4, 89, 90, 1]]
l paye a la longue . <end>
Epoch 146 Loss 0.5654
What a ridiculous concept !
[[8, 5, 21, 22, 23]]
quel <end>
Epoch 147 Loss 0.7049
Can I have a few minutes please ?
[[11, 16, 51, 5, 17, 52, 18, 3]]
je vous avez vous plante des citrouilles ? <end>
Epoch 148 Loss 0.7146
Honesty will pay in the long run .
[[86, 87, 88, 9, 4, 89, 90, 1]]
l paye a la longue . <end>
Epoch 149 Loss 0.5567
Your idea is not entirely crazy .
[[24, 25, 6, 26, 27, 28, 1]]
idee n est pas completement . <end>
Epoch 150 Loss 0.6018
I can t believe you re giving up .
[[16, 11, 12, 95, 2, 96, 14, 97, 

de qui c est l anniversaire aujourd hui ! <end>
Epoch 209 Loss 0.2531
Your idea is not entirely crazy .
[[24, 25, 6, 26, 27, 28, 1]]
votre n est pas completement folle . <end>
Epoch 210 Loss 0.2040
What he did is very wrong .
[[8, 7, 13, 6, 33, 34, 1]]
qu il est tres un . <end>
Epoch 211 Loss 0.2953
All three of you need to do that .
[[35, 36, 37, 2, 38, 39, 10, 40, 1]]
vous avez besoin de faire cela tous les trois . <end>
Epoch 212 Loss 0.2453
Do you ever study in the library ?
[[10, 2, 59, 60, 9, 4, 61, 3]]
comment savez vous qu il ne s agit pas d piege ? <end>
Epoch 213 Loss 0.1715
Could you close the door please ?
[[53, 2, 54, 4, 55, 18, 3]]
pourriez vous la porte s il vous ? <end>
Epoch 214 Loss 0.1991
Germany produced many scientists .
[[73, 74, 75, 76, 1]]
de qui est anniversaire aujourd hui ! <end>
Epoch 215 Loss 0.2318
Do you ever study in the library ?
[[10, 2, 59, 60, 9, 4, 61, 3]]
comment savez vous qu il ne s agit pas d piege ? <end>
Epoch 216 Loss 0.2148
A man s worth lie

### Luong Attention

In [29]:
class LuongAttention(tf.keras.Model):
    def __init__(self, rnn_size, attention_func):
        super(LuongAttention, self).__init__()
        self.attention_func = attention_func

        if attention_func not in ['dot', 'general', 'concat']:
            raise ValueError(
                'Unknown attention score function! Must be either dot, general or concat.')

        if attention_func == 'general':
            # General score function
            self.wa = tf.keras.layers.Dense(rnn_size)
        elif attention_func == 'concat':
            # Concat score function
            self.wa = tf.keras.layers.Dense(rnn_size, activation='tanh')
            self.va = tf.keras.layers.Dense(1)

    def call(self, decoder_output, encoder_output):
        if self.attention_func == 'dot':
            # Dot score function: decoder_output (dot) encoder_output
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, encoder_output, transpose_b=True)
        elif self.attention_func == 'general':
            # General score function: decoder_output (dot) (Wa (dot) encoder_output)
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, self.wa(
                encoder_output), transpose_b=True)
        elif self.attention_func == 'concat':
            # Concat score function: va (dot) tanh(Wa (dot) concat(decoder_output + encoder_output))
            # Decoder output must be broadcasted to encoder output's shape first
            decoder_output = tf.tile(
                decoder_output, [1, encoder_output.shape[1], 1])

            # Concat => Wa => va
            # (batch_size, max_len, 2 * rnn_size) => (batch_size, max_len, rnn_size) => (batch_size, max_len, 1)
            score = self.va(
                self.wa(tf.concat((decoder_output, encoder_output), axis=-1)))

            # Transpose score vector to have the same shape as other two above
            # (batch_size, max_len, 1) => (batch_size, 1, max_len)
            score = tf.transpose(score, [0, 2, 1])

        # alignment a_t = softmax(score)
        alignment = tf.nn.softmax(score, axis=2)

        # context vector c_t is the weighted average sum of encoder output
        context = tf.matmul(alignment, encoder_output)

        return context, alignment

In [30]:
#Rewrite the decoder
#At each time step t, we will concatenate the context vector and the current output (of the RNN unit) to 
#form a new output vector. We then continue as normal: convert that vector to vocabulary space for the final
#output.
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, rnn_size, attention_func):
        super(Decoder, self).__init__()
        self.attention = LuongAttention(rnn_size, attention_func)
        self.rnn_size = rnn_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            rnn_size, return_sequences=True, return_state=True)
        self.wc = tf.keras.layers.Dense(rnn_size, activation='tanh')
        self.ws = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state, encoder_output):
        # Remember that the input to the decoder
        # is now a batch of one-word sequences,
        # which means that its shape is (batch_size, 1)
        embed = self.embedding(sequence)

        # Therefore, the lstm_out has shape (batch_size, 1, rnn_size)
        lstm_out, state_h, state_c = self.lstm(embed, initial_state=state)

        # Use self.attention to compute the context and alignment vectors
        # context vector's shape: (batch_size, 1, rnn_size)
        # alignment vector's shape: (batch_size, 1, source_length)
        context, alignment = self.attention(lstm_out, encoder_output)

        # Combine the context vector and the LSTM output
        # Before combined, both have shape of (batch_size, 1, rnn_size),
        # so let's squeeze the axis 1 first
        # After combined, it will have shape of (batch_size, 2 * rnn_size)
        lstm_out = tf.concat(
            [tf.squeeze(context, 1), tf.squeeze(lstm_out, 1)], 1)

        # lstm_out now has shape (batch_size, rnn_size)
        lstm_out = self.wc(lstm_out)

        # Finally, it is converted back to vocabulary space: (batch_size, vocab_size)
        logits = self.ws(lstm_out)

        return logits, state_h, state_c, alignment

In [31]:
#modify the train_step function. Since we are dealing with each time step at a time on the decoder’s side, we will need to explicitly create a loop for that:
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    loss = 0
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        en_states = en_outputs[1:]
        de_state_h, de_state_c = en_states
        
        # We need to create a loop to iterate through the target sequences
        for i in range(target_seq_out.shape[1]):
            # Input to the decoder must have shape of (batch_size, length)
            # so we need to expand one dimension
            decoder_in = tf.expand_dims(target_seq_in[:, i], 1)
            logit, de_state_h, de_state_c, _ = decoder(
                decoder_in, *(de_state_h, de_state_c), en_outputs[0])
            
            # The loss is now accumulated through the whole batch
            loss += loss_func(target_seq_out[:, i], logit)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss / target_seq_out.shape[1]

In [32]:
#Modify predict function. We also need get the source sequence, the translated sequence and the alignment vector for visualization purpose:
def predict(test_source_text=None):
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    de_state =[de_state_h, de_state_c]
    
    out_words = []
    alignments = []

    while True:
        de_output, de_state_h, de_state_c, alignment = decoder(de_input, de_state , en_outputs[0])
        de_input = tf.expand_dims(tf.argmax(de_output, -1), 0)
        out_words.append(fr_tokenizer.index_word[de_input.numpy()[0][0]])
        
        alignments.append(alignment.numpy())

        if out_words[-1] == '<end>' or len(out_words) >= 20:
            break

    print(' '.join(out_words))
    return np.array(alignments), test_source_text.split(' '), out_words


In [34]:
#Training loop
EMBEDDING_SIZE = 32
RNN_SIZE = 512
BATCH_SIZE = 5
ATTENTION_FUNC = 'concat'

en_vocab_size = len(en_tokenizer.word_index) + 1
fr_vocab_size = len(fr_tokenizer.word_index) + 1

encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, RNN_SIZE)
decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, RNN_SIZE, ATTENTION_FUNC)
# These lines can be used for debugging purpose
# Or can be seen as a way to build the models
initial_state = encoder.init_states(1)
encoder_outputs = encoder(tf.constant([[1]]), initial_state)
decoder_outputs = decoder(tf.constant(
    [[1]]), encoder_outputs[1:], encoder_outputs[0])


def loss_func(targets, logits):
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    mask = tf.math.logical_not(tf.math.equal(targets, 0))
    mask = tf.cast(mask, dtype=tf.int64)
    loss = crossentropy(targets, logits, sample_weight=mask)

    return loss


optimizer = tf.keras.optimizers.Adam(clipnorm=5.0)


NUM_EPOCHS = 300

for e in range(NUM_EPOCHS):
    en_initial_states = encoder.init_states(BATCH_SIZE)
    
    predict()

    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)

    print('Epoch {} Loss {:.4f}'.format(e + 1, loss.numpy()))

How do we know this isn t a trap ?
[[91, 10, 92, 20, 19, 93, 12, 5, 94, 3]]
s donnez s donnez s donnez s donnez s donnez s donnez s donnez s donnez s donnez s donnez


TypeError: in converted code:

    <ipython-input-19-daac7785e3e5>:15 train_step  *
        logit, de_state_h, de_state_c, _ = decoder(
    /usr/local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py:778 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)

    TypeError: tf__call() takes 4 positional arguments but 5 were given
