# Translation with cross-attention

## Michał Gromadzki

In [1]:
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow_text as tf_text
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
tf.random.set_seed(1337)
np.random.seed(1337)

2023-05-17 01:34:16.974545: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)

path_to_file = pathlib.Path(path_to_zip).parent/'spa-eng/spa.txt'

In [3]:
text = path_to_file.read_text(encoding="utf-8")

In [4]:
lines = text.splitlines()
pairs = [line.split('\t') for line in lines]
pairs[:10]

[['Go.', 'Ve.'],
 ['Go.', 'Vete.'],
 ['Go.', 'Vaya.'],
 ['Go.', 'Váyase.'],
 ['Hi.', 'Hola.'],
 ['Run!', '¡Corre!'],
 ['Run.', 'Corred.'],
 ['Who?', '¿Quién?'],
 ['Fire!', '¡Fuego!'],
 ['Fire!', '¡Incendio!']]

In [5]:
len(pairs), len(pairs[0])

(118964, 2)

In [6]:
pairs = np.array(pairs)

In [7]:
np.random.shuffle(pairs)

In [8]:
target = pairs[:,0]
context = pairs[:,1]

In [9]:
target.shape, context.shape

((118964,), (118964,))

In [10]:
example_context = context[1337]
example_target = target[1337]
example_context, example_target

('¿Dirías que es verdad?', "Would you say that's true?")

In [11]:
batch_size = 64
n = int(0.8*len(context))

train_raw = (
    tf.data.Dataset
    .from_tensor_slices((context[:n], target[:n]))
    .shuffle(10000)
    .batch(batch_size))
val_raw = (
    tf.data.Dataset
    .from_tensor_slices((context[n:], target[n:]))
    .shuffle(10000)
    .batch(batch_size))

2023-05-17 01:34:22.398635: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-05-17 01:34:22.471414: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-05-17 01:34:22.471531: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-05-17 01:34:22.476404: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-05-17 01:34:22.476493: I tensorflow/compile

In [12]:
def preprocess(text):
    # Split accented characters.
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    # Keep space, a to z, and select punctuation.
    text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')
    # Add spaces around punctuation.
    text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')
    # Strip whitespace.
    text = tf.strings.strip(text)

    text = tf.strings.join(['[START]', text, '[END]'], separator=' ')
    return text

In [13]:
example_context = preprocess(example_context)
example_target = preprocess(example_target)
example_context.numpy(), example_target.numpy()

(b'[START] \xc2\xbf dirias que es verdad ? [END]',
 b'[START] would you say thats true ? [END]')

In [14]:
vocab_size = 8192
processor_spa = tf.keras.layers.TextVectorization(
    standardize=preprocess,
    max_tokens=vocab_size,
    ragged=True)

processor_spa.adapt(train_raw.map(lambda context, target: context))

2023-05-17 01:34:24.543289: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype string and shape [95171]
	 [[{{node Placeholder/_1}}]]


In [15]:
processor_spa.get_vocabulary()[:10]

['', '[UNK]', '[START]', '[END]', '.', 'que', 'de', 'el', 'a', 'no']

In [16]:
processor_eng = tf.keras.layers.TextVectorization(
    standardize=preprocess,
    max_tokens=vocab_size,
    ragged=True)

processor_eng.adapt(train_raw.map(lambda context, target: target))

2023-05-17 01:34:30.099665: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype string and shape [95171]
	 [[{{node Placeholder/_1}}]]


In [17]:
processor_eng.get_vocabulary()[:10]

['', '[UNK]', '[START]', '[END]', '.', 'the', 'i', 'to', 'you', 'tom']

In [18]:
example_context = processor_spa(example_context)
example_target = processor_eng(example_target)
example_context.numpy(), example_target.numpy()

(array([   2,    1,   13, 5577,    5,   15,  109,   12,    1,    3]),
 array([  2, 397,  78,   8, 135, 146, 289,  11, 535,   3]))

In [19]:
def prepare_text(context, target):
    context = processor_spa(context).to_tensor()
    target = processor_eng(target)
    targ_in = target[:,:-1].to_tensor()
    targ_out = target[:,1:].to_tensor()
    return (context, targ_in), targ_out

In [20]:
train_ds = train_raw.map(prepare_text)
val_ds = val_raw.map(prepare_text)

In [21]:
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    print(ex_context_tok[0, :10].numpy()) 
    print()
    print(ex_tar_in[0, :10].numpy()) 
    print(ex_tar_out[0, :10].numpy())

[   2   86    5   15   42  106 1243    5   10  500]

[   2    6   65   51 1673 1427   16    9   49  406]
[   6   65   51 1673 1427   16    9   49  406    4]


2023-05-17 01:34:34.965510: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_18' with dtype int64
	 [[{{node Placeholder/_18}}]]


In [22]:
word_to_id_eng = tf.keras.layers.StringLookup(
        vocabulary=processor_eng.get_vocabulary(),
        mask_token='', oov_token='[UNK]')

In [23]:
word_to_id_spa = tf.keras.layers.StringLookup(
        vocabulary=processor_spa.get_vocabulary(),
        mask_token='', oov_token='[UNK]')

In [24]:
id_to_word_eng = tf.keras.layers.StringLookup(
        vocabulary=processor_eng.get_vocabulary(),
        mask_token='', oov_token='[UNK]',
        invert=True)

In [25]:
id_to_word_spa = tf.keras.layers.StringLookup(
        vocabulary=processor_spa.get_vocabulary(),
        mask_token='', oov_token='[UNK]',
        invert=True)

In [26]:
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    print(id_to_word_spa(ex_context_tok[0, :10]).numpy()) 
    print()
    print(id_to_word_eng(ex_tar_in[0, :10]).numpy()) 
    print(id_to_word_eng(ex_tar_out[0, :10]).numpy())

[b'[START]' b'este' b'es' b'el' b'museo' b'mas' b'grande' b'de' b'la'
 b'ciudad']

[b'[START]' b'this' b'is' b'the' b'largest' b'museum' b'in' b'the' b'city'
 b'.']
[b'this' b'is' b'the' b'largest' b'museum' b'in' b'the' b'city' b'.'
 b'[END]']


2023-05-17 01:34:35.516698: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_13' with dtype string
	 [[{{node Placeholder/_13}}]]


# Model

In [27]:
n_units = 256
n_embed = 16

In [28]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, units, n_embed):
        super(Encoder, self).__init__()
        self.vocab_size = vocab_size
        self.units = units
        self.embedding = tf.keras.layers.Embedding(self.vocab_size, n_embed, mask_zero=True)
        self.rnn = tf.keras.layers.Bidirectional(merge_mode='sum',
            layer=tf.keras.layers.GRU(units, return_sequences=True))      

    def call(self, x):
        x = self.embedding(x)
        x = self.rnn(x)
        return x

In [29]:
encoder = Encoder(n_units, n_embed)
ex_context = encoder(ex_context_tok)

ex_context_tok.shape, ex_context.shape

2023-05-17 01:34:36.960581: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8600


(TensorShape([64, 18]), TensorShape([64, 18, 256]))

In [30]:
n_heads = 4

In [31]:
class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=n_heads, **kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        attn_output = self.mha(
            query=x,
            value=context,
            return_attention_scores=False)

        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x

In [32]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, units, n_embed):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.units = units
        self.embedding = tf.keras.layers.Embedding(self.vocab_size, n_embed, mask_zero=True)
        self.rnn = tf.keras.layers.GRU(units, return_sequences=True, return_state=True)
        self.attention = CrossAttention(units)
        self.output_layer = tf.keras.layers.Dense(self.vocab_size)

    def get_initial_state(self, context):
        batch_size = tf.shape(context)[0]
        start_tokens = tf.fill([batch_size, 1], word_to_id_spa('[START]'))
        done = tf.zeros([batch_size, 1], dtype=tf.bool)
        embed = self.embedding(start_tokens)
        return start_tokens, done, self.rnn.get_initial_state(embed)[0]
    
    def get_next_token(self, context, next_token, done, state):
        logits, state = self(context, next_token, state = state, return_state=True) 
        logits = logits[:, -1, :]
        next_token = tf.random.categorical(logits, num_samples=1)
        done = done | (next_token == word_to_id_eng('[END]'))
        next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)
        return next_token, done, state

    def call(self, context, x, state=None, return_state=False):  
        x = self.embedding(x)
        x, state = self.rnn(x, initial_state=state)
        x = self.attention(x, context)
        logits = self.output_layer(x)

        if return_state:
            return logits, state
        else:
            return logits

In [33]:
decoder = Decoder(n_units, n_embed)
logits = decoder(ex_context, ex_tar_in)
logits.shape

TensorShape([64, 17, 8192])

In [34]:
class Model(tf.keras.Model):
    def __init__(self, units, n_embed):
        super().__init__()
        self.encoder = Encoder(units, n_embed)
        self.decoder = Decoder(units, n_embed)

    def get_initial_state(self, context):
        return self.decoder.get_initial_state(context)
    
    def get_next_token(self, context, next_token, done, state):
        return self.decoder.get_next_token(context, next_token, done, state)

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)
        logits = self.decoder(context, x)
        return logits

In [35]:
n_units = 1024
n_embed = 64
model = Model(n_units, n_embed)
logits = model((ex_context_tok, ex_tar_in))
logits.shape

TensorShape([64, 17, 8192])

In [36]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_1 (Encoder)         multiple                  7221248   
                                                                 
 decoder_1 (Decoder)         multiple                  29062144  
                                                                 
Total params: 36,283,392
Trainable params: 36,283,392
Non-trainable params: 0
_________________________________________________________________


In [37]:
model.compile(optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
            metrics=["accuracy"])

In [87]:
model.fit(train_ds.repeat(), epochs=30, steps_per_epoch=100, validation_data=val_ds, validation_steps=20)

Epoch 1/30


2023-05-16 21:35:02.211257: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_16' with dtype int64
	 [[{{node Placeholder/_16}}]]
2023-05-16 21:35:09.069370: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/ReverseV2_grad/ReverseV2/ReverseV2/axis' with dtype int32 and shape [1]
	 [[{{node gradients/ReverseV2_grad/ReverseV2/ReverseV2/axis}}]]
2023-05-16 21:35:16.065551: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 



2023-05-16 21:36:00.911481: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_12' with dtype int64
	 [[{{node Placeholder/_12}}]]


Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x7fd7e22a6e90>

## Test the model

In [88]:
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    pass

2023-05-16 21:52:24.706814: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_12' with dtype int64
	 [[{{node Placeholder/_12}}]]


In [89]:
ex_context = model.encoder(ex_context_tok)

In [90]:
next_token, done, state = model.get_initial_state(ex_context)
tokens = []

for n in range(10):
    next_token, done, state = model.get_next_token(
        ex_context, next_token, done, state)
    tokens.append(next_token)

tokens = tf.concat(tokens, axis=-1)

In [91]:
def tokens_to_text(tokens, language):
    words = id_to_word_eng(tokens) if language=="eng" else id_to_word_spa(tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    result = tf.strings.regex_replace(result, '^ *\[START\] *', '')
    result = tf.strings.regex_replace(result, ' *\[END\] *$', '')
    return result.numpy()

In [92]:
tokens_to_text(tokens, language="eng")

array([b'tom would never be too surprised if mary behaves among',
       b'why are you eating vegetables ?    ',
       b'the child cant power tom watches for ten minutes .',
       b'the poor woman woman beyond money .   ',
       b'she took another book off the shelf .  ',
       b'by the time [UNK] is coming to japan . ',
       b'i cant play the piano .    ', b'dont be thankful .      ',
       b'she scolded a sweater for her father .  ',
       b'ill take this .      ', b'i no mind .      ',
       b'i need one more .     ', b'could i open the door ?    ',
       b'to translate it worse .     ',
       b'those ok went to set pouring yesterday .  ',
       b'that knife tastes good .     ', b'they are on ten oclock .    ',
       b'everyone and tom is to arrest tom .  ',
       b'in swimming . we have fun .   ', b'shes also vegetarian .      ',
       b'i had fun of them .    ',
       b'what are you doing this afternoon ?   ',
       b'how old do you have the children ?  ',
       

In [93]:
tokens_to_text(ex_context_tok,language="spa")

array([b'tom no estaria demasiado sorprendido si maria [UNK] aceptar la oferta de trabajo .',
       b'\xc2\xbf por que no comes vegetales ?',
       b'ese nino es incapaz de [UNK] quieto durante diez minutos .',
       b'a la pobre mujer mayor le robaron el dinero .',
       b'ella saco un libro de la repisa .',
       b'por fin la primavera ha llegado a esta parte de japon .',
       b'no puedo tocar el piano .', b'no seas [UNK] .',
       b'ella [UNK] un jersey para su padre .', b'lo tomare en cuenta .',
       b'no me importa .', b'necesito uno mas .',
       b'\xc2\xbf puedo abrir la puerta ?', b'un dia lo lamentaras .',
       b'esos prisioneros fueron liberados ayer .',
       b'ese cuchillo corta bien .', b'son las diez en punto .',
       b'todos [UNK] a tom .', b'las apariencias enganan .',
       b'tomas tambien es vegetariano .', b'nos divertimos con ellos .',
       b'\xc2\xbf que estas haciendo a estas horas ?',
       b'\xc2\xbf cuantos anos tienen los ninos ?',
       b

# Final model in model.py