In [4]:
import tensorflow as tf
from pathlib import Path
url="https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"
path=tf.keras.utils.get_file("spa.zip",origin=url,extract=True,cache_dir="datasets")
data=(Path(path).with_name("spa_extracted")/"spa-eng"/"spa.txt").read_text(encoding="UTF-8")

In [5]:
import numpy as np
text=data.replace("¡","").replace("¿","")
pairs=[line.split("\t") for line in text.splitlines()]
np.random.shuffle(pairs)
sentences_en,sentences_esp=zip(*pairs)

In [6]:
x_train_enc=tf.constant(sentences_en[:100000])
x_val_enc=tf.constant(sentences_en[100000:])

x_train_dec=tf.constant([f"startofseq {sen}" for sen in sentences_esp[:100000]])
x_val_dec=tf.constant([f"startofseq {sen}" for sen in sentences_esp[100000:]])

y_train=[f"{sen} endofseq" for sen in sentences_esp[:100000]]
y_val=[f"{sen} endofseq" for sen in sentences_esp[100000:]]

In [7]:
vocab_size=8000
output_len=50
enc_vec_layer=tf.keras.layers.TextVectorization(vocab_size,output_sequence_length=output_len)
dec_vec_layer=tf.keras.layers.TextVectorization(vocab_size,output_sequence_length=output_len)
enc_vec_layer.adapt(sentences_en)
dec_vec_layer.adapt([f"startofseq {s} endofseq" for s in sentences_esp])

This is the learnable positions method:

In [8]:
#max_training_len=50
#n_dims=128
#pos_emb_layer=tf.keras.layers.Embedding(max_training_len,n_dims)
#batch_max_len_enc=tf.shape(encoder_embeddings)[1]
#encoder_input=encoder_embeddings+pos_emb_layer(tf.range(batch_max_len_enc))
#batch_max_len_dec=tf.shape(decoder_embeddings)[1]
#decoder_input=decoder_embeddings+pos_emb_layer(tf.range(batch_max_len_dec))

I'll be using the fixed positions method using maths functions:

In [9]:
class PositionalEncoding(tf.keras.layers.Layer):
    def __init__(self,max_input_len,n_dims,dtype=tf.float32,**kwargs):
        super().__init__(dtype=dtype,**kwargs)
        assert n_dims%2==0,'n_dims must be even for sin/cos distribution'
        p,i=np.meshgrid(np.arange(max_input_len),2*(np.arange(n_dims//2)))
        PE=np.empty((1,max_input_len,n_dims))
        PE[0,:,::2]=np.sin(p/10000**(i/n_dims)).T
        PE[0,:,1::2]=np.cos(p/10000**(i/n_dims)).T
        self.embedding_table=tf.constant(PE.astype(self.dtype))
        self.supports_masking=True
    def call(self,inputs):
        inputs=tf.cast(inputs,tf.float32)
        batch_max_len=tf.shape(inputs)[1]
        return inputs+self.embedding_table[:,:batch_max_len]

IF MEMORY IS SMALL, IT WILL THROW AN OOM ERROR!!! (100k sentences having 50 tokens/words each, where each token is represented by 128 dims: 100000x50x128 values storing!)

In [10]:
max_training_len=50
n_dims=128

vec_enc_inputs=enc_vec_layer(x_train_enc)
vec_dec_inputs=dec_vec_layer(x_train_dec)
vec_y=dec_vec_layer(y_train)

vec_enc_inputs_val=enc_vec_layer(x_val_enc)
vec_dec_inputs_val=dec_vec_layer(x_val_dec)
vec_y_val=dec_vec_layer(y_val)

Embedding=tf.keras.layers.Embedding(vocab_size,n_dims,dtype=tf.float32)

enc_inputs = tf.keras.Input(shape=(output_len,), dtype=tf.int64)
dec_inputs = tf.keras.Input(shape=(output_len,), dtype=tf.int64)

enc_emb = Embedding(enc_inputs)
dec_emb = Embedding(dec_inputs)

pos_enc = PositionalEncoding(512, n_dims)

final_encoder_inputs = pos_enc(enc_emb)
final_decoder_inputs = pos_enc(dec_emb)


In [11]:
enc_stack_repititions=2
Heads=8
n_units=128
dropout=0.1
Z=final_encoder_inputs
encoder_pad_mask = tf.keras.layers.Lambda(
    lambda x: tf.cast(tf.math.not_equal(x, 0), tf.bool)[:, tf.newaxis, tf.newaxis, :]
)(enc_inputs)
for _ in range(enc_stack_repititions):
    skip=Z
    attn_layer=tf.keras.layers.MultiHeadAttention(num_heads=Heads,key_dim=16,dropout=dropout)
    Z=attn_layer(Z,value=Z,attention_mask=encoder_pad_mask)
    Z=tf.keras.layers.LayerNormalization(epsilon=1e-7)(tf.keras.layers.Add()([Z,skip]))
    skip=Z
    Z=tf.keras.layers.Dense(n_units,activation="relu")(Z)
    Z=tf.keras.layers.Dense(n_dims)(Z)
    Z=tf.keras.layers.LayerNormalization(epsilon=1e-7)(tf.keras.layers.Add()([Z,skip]))

encoder_output=Z

In [12]:
batch_max_len_dec = tf.keras.layers.Lambda(
    lambda x: tf.shape(x)[1]
)(final_decoder_inputs)
dec_pad_mask = tf.keras.layers.Lambda(
    lambda x: tf.cast(tf.math.not_equal(x, 0), tf.bool)[:, tf.newaxis, tf.newaxis, :],
    output_shape=lambda s: (s[0], 1, 1, s[1])
)(dec_inputs)
causal_mask = tf.keras.layers.Lambda(
    lambda x: tf.cast(
        tf.linalg.band_part(tf.ones((tf.shape(x)[1], tf.shape(x)[1])), -1, 0),
        tf.bool
    )[tf.newaxis, tf.newaxis, :, :],
    output_shape=(1, 1, None, None)
)(dec_inputs)
combined_mask = tf.keras.layers.Lambda(
    lambda inputs: tf.logical_and(inputs[0], inputs[1]),
    output_shape=lambda s: s[0]
)([dec_pad_mask, causal_mask])

In [13]:
Z=final_decoder_inputs
dec_stack_repititions=2
for _ in range(dec_stack_repititions):
    skip=Z
    attn_layer=tf.keras.layers.MultiHeadAttention(dropout=dropout,num_heads=Heads,key_dim=16)
    Z=attn_layer(Z,value=Z,attention_mask=combined_mask)
    Z=tf.keras.layers.LayerNormalization(epsilon=1e-7)(tf.keras.layers.Add()([Z,skip]))
    skip=Z
    attn_layer=tf.keras.layers.MultiHeadAttention(dropout=dropout,num_heads=Heads,key_dim=16)
    Z=attn_layer(query=Z,key=encoder_output,value=encoder_output,attention_mask=encoder_pad_mask)
    Z=tf.keras.layers.LayerNormalization(epsilon=1e-7)(tf.keras.layers.Add()([Z,skip]))
    skip=Z
    Z=tf.keras.layers.Dense(units=n_units,activation="relu")(Z)
    Z=tf.keras.layers.Dense(units=n_dims)(Z)
    Z=tf.keras.layers.LayerNormalization(epsilon=1e-7)(tf.keras.layers.Add()([skip,Z]))

In [14]:
probas=tf.keras.layers.Dense(vocab_size,activation="softmax")(Z)
model=tf.keras.Model(inputs=[enc_inputs,dec_inputs],outputs=probas)
model.compile(loss="sparse_categorical_crossentropy",optimizer="nadam",metrics=["Accuracy"])
model.fit([vec_enc_inputs,vec_dec_inputs],vec_y,epochs=7,batch_size=64,validation_data=((vec_enc_inputs_val,vec_dec_inputs_val),vec_y_val))

Epoch 1/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 38ms/step - Accuracy: 0.8864 - loss: 1.1710 - val_Accuracy: 0.9367 - val_loss: 0.3568
Epoch 2/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 10ms/step - Accuracy: 0.9417 - loss: 0.3137 - val_Accuracy: 0.9490 - val_loss: 0.2598
Epoch 3/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 10ms/step - Accuracy: 0.9540 - loss: 0.2175 - val_Accuracy: 0.9545 - val_loss: 0.2196
Epoch 4/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 10ms/step - Accuracy: 0.9592 - loss: 0.1786 - val_Accuracy: 0.9562 - val_loss: 0.2089
Epoch 5/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 10ms/step - Accuracy: 0.9625 - loss: 0.1569 - val_Accuracy: 0.9580 - val_loss: 0.2002
Epoch 6/7
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 10ms/step - Accuracy: 0.9652 - loss: 0.1411 - val_Accuracy: 0.9585 - val_loss: 0.1961
Epoch 7/7

<keras.src.callbacks.history.History at 0x796d6834f470>

In [17]:
def translate(input, model, encoder_vec_layer, decoder_vec_layer,
              start_token="startofseq", end_token="endofseq", max_len=50):
    input_enc = encoder_vec_layer(tf.constant([input]))
    dec_input = decoder_vec_layer(tf.constant([start_token]))

    for _ in range(max_len):
        probas = model.predict([input_enc, dec_input], verbose=0)
        next_token = tf.argmax(probas[:, -1, :], axis=-1, output_type=tf.int32)
        next_token = tf.cast(next_token, dtype=dec_input.dtype)
        dec_input = tf.concat([dec_input, tf.expand_dims(next_token, axis=1)], axis=1)

        end_token_id = decoder_vec_layer([end_token]).numpy()[0][0]
        if next_token.numpy()[0] == end_token_id:
            break

    vocab = decoder_vec_layer.get_vocabulary()
    decoded = [vocab[t] for t in dec_input.numpy()[0]]
    sentence = " ".join(decoded)
    return sentence.replace(start_token, "").replace(end_token, "").strip()

In [18]:
translate("This is the translation by the magic of attention",model,enc_vec_layer,dec_vec_layer)

''