# Imports

In [1]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import pathlib
import random
import string
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization

# Parse text file and prepare the data

In [6]:
# Read the tab-separated English/Spanish corpus.
# Splitting on "\n" and dropping empty lines works whether or not the file ends
# with a newline (slicing off the last element with [:-1] would silently lose the
# final pair when there is no trailing newline) and tolerates blank lines.
with open("spa.txt", encoding="utf-8") as f:
    lines = [line for line in f.read().split("\n") if line.strip()]

text_pairs = []
for line in lines:
    # Only the first two tab-separated columns are used; any extra columns
    # (e.g. an attribution field, if present) are ignored.
    eng, spa = line.split("\t")[:2]
    # Wrap every target sentence in explicit markers the decoder learns to
    # start from and to emit when it is done.
    spa = "[start] " + spa + " [end]"
    text_pairs.append((eng, spa))

## What does the data look like?

In [7]:
# Peek at a handful of randomly drawn (English, Spanish) training pairs.
for sample_pair in (random.choice(text_pairs) for _ in range(5)):
    print(sample_pair)

('What about next Sunday?', '[start] ¿Qué te parece el próximo domingo? [end]')
("He's very proud of his custom motorcycle.", '[start] Está muy orgulloso de su motocicleta personalizada. [end]')
('I have a Vietnamese friend. Her name is Tiên.', '[start] Tengo una amigo Vietnamita. Su nombre es Tiên. [end]')
('What is the secret of success?', '[start] ¿Cuál es el secreto del éxito? [end]')
('When she was young, she was very popular.', '[start] Cuando ella era joven, era muy popular. [end]')


## Train / validation / test split

In [8]:
# Shuffle a *copy* with a fixed, local RNG:
#   - the split is reproducible from a fresh kernel (Restart & Run All), and
#   - re-running this cell yields the same split instead of re-permuting
#     `text_pairs` in place on every execution.
SPLIT_SEED = 42
shuffled_pairs = list(text_pairs)
random.Random(SPLIT_SEED).shuffle(shuffled_pairs)

# 15% validation, 15% test, the remainder (~70%) for training.
num_val_samples = int(0.15 * len(shuffled_pairs))
num_train_samples = len(shuffled_pairs) - 2 * num_val_samples
train_pairs = shuffled_pairs[:num_train_samples]
val_pairs = shuffled_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = shuffled_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

118964 total pairs
83276 training pairs
17844 validation pairs
17844 test pairs


## Vectorization

In [None]:
# Characters removed from Spanish targets: ASCII punctuation plus both Spanish
# inverted marks. "¡" must be listed alongside "¿"; otherwise "¡hola" and "hola"
# end up as two different vocabulary entries.
strip_chars = string.punctuation + "¿¡"

# Keep "[" and "]": they delimit the [start] / [end] marker tokens, which must
# survive standardization intact.
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

# Preprocessing / training parameters
vocab_size = 15000        # Keep only the most frequent tokens (per language)
sequence_length = 20      # English inputs are padded/truncated to this length
batch_size = 64           # Examples per training batch

# Custom standardization function for Spanish text
def custom_standardization(input_string):
    """Lowercase Spanish text and delete every character listed in `strip_chars`.

    `tf.strings.lower` only lowercases ASCII by default (`encoding=""`), which
    would leave accented capitals such as "Él" or "Á" untouched; passing
    `encoding="utf-8"` lowercases the full Unicode range.

    Args:
        input_string: a string tensor of raw Spanish sentences.

    Returns:
        A string tensor of the same shape with text lowercased and the
        `strip_chars` characters removed ("[" and "]" are not in that set, so
        the [start]/[end] markers are preserved).
    """
    lowercase = tf_strings.lower(input_string, encoding="utf-8")
    # Build a regex character class from the escaped characters to delete.
    return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")

# English source vectorizer: integer token ids, padded/truncated to
# `sequence_length`. No `standardize=` argument is passed, so the Keras default
# standardization is used.
eng_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Spanish target vectorizer: one extra position, so each sequence can later be
# split into decoder inputs (all but the last token) and labels (all but the
# first). Uses the Spanish-specific standardization defined above.
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)

# Learn both vocabularies from the training split only, so validation/test
# sentences do not influence the vocabulary.
train_eng_texts = [eng for eng, _ in train_pairs]
train_spa_texts = [spa for _, spa in train_pairs]
eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)


## Create datasets

In [None]:
def format_dataset(eng, spa):
    """Turn a batch of raw (English, Spanish) strings into seq2seq tensors.

    The Spanish batch is vectorized to ``sequence_length + 1`` positions and
    then split two ways. The decoder input drops the last position, and the
    labels drop the first. So at every step the decoder sees the target prefix
    and is trained to predict the next target token (teacher forcing).

    Args:
        eng: batch of raw English strings.
        spa: batch of raw Spanish strings wrapped in [start] ... [end].

    Returns:
        ``(inputs, labels)``. ``inputs`` is a dict with "encoder_inputs" and
        "decoder_inputs"; ``labels`` holds the next-token targets.
    """
    source_ids = eng_vectorization(eng)   # (batch, sequence_length)
    target_ids = spa_vectorization(spa)   # (batch, sequence_length + 1)

    decoder_input_ids = target_ids[:, :-1]   # everything except the last token
    next_token_labels = target_ids[:, 1:]    # everything except the first token

    features = {
        "encoder_inputs": source_ids,
        "decoder_inputs": decoder_input_ids,
    }
    return features, next_token_labels


def make_dataset(pairs, shuffle=True, shuffle_buffer_size=2048):
    """Build a batched, model-ready ``tf.data.Dataset`` from (eng, spa) pairs.

    Pipeline:
      1. Texts are vectorized batch-wise via ``format_dataset``.
      2. The batches are split back into single examples and cached in memory.
      3. Examples are shuffled every epoch (if ``shuffle``) and re-batched.

    Shuffling happens at the example level on purpose. Shuffling cached
    *batches* only reorders them, so the same examples would always be trained
    together in the same batch.

    Args:
        pairs: non-empty iterable of ``(english_text, spanish_text)`` tuples.
        shuffle: reshuffle examples each epoch; pass False for evaluation data.
        shuffle_buffer_size: size of the example-level shuffle buffer.

    Returns:
        A dataset yielding ``({"encoder_inputs": ..., "decoder_inputs": ...},
        labels)`` batches of at most ``batch_size`` examples.

    Raises:
        ValueError: if ``pairs`` is empty.
    """
    pairs = list(pairs)
    if not pairs:
        raise ValueError("make_dataset() requires at least one (eng, spa) pair")

    eng_texts = [eng for eng, _ in pairs]
    spa_texts = [spa for _, spa in pairs]
    dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))

    # Vectorize in batches (efficient), then unbatch so later shuffling mixes
    # individual examples; cache the vectorized examples after the first pass.
    dataset = (
        dataset.batch(batch_size)
        .map(format_dataset, num_parallel_calls=tf_data.AUTOTUNE)
        .unbatch()
        .cache()
    )
    if shuffle:
        # reshuffle_each_iteration defaults to True: new order every epoch.
        dataset = dataset.shuffle(shuffle_buffer_size)

    # Re-batch and let tf.data tune how many batches to prepare ahead.
    return dataset.batch(batch_size).prefetch(tf_data.AUTOTUNE)


# Training and validation pipelines; the test split stays as raw text and is
# used for qualitative greedy decoding further below.
train_ds, val_ds = (make_dataset(split) for split in (train_pairs, val_pairs))


# Model

In [None]:
# Transformer Encoder block implementation
class TransformerEncoder(layers.Layer):
    """A single Transformer encoder block in post-norm layout.

    The block applies self-attention, then a two-layer feed-forward network.
    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``: normalization
    comes *after* the residual add (post-norm, not pre-norm).

    Args:
        embed_dim: size of the token vectors entering and leaving the block.
        dense_dim: hidden width of the feed-forward network.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Stored for get_config() serialization.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        # Multi-head self-attention.
        # NOTE(review): in Keras `key_dim` is the size of *each* head, so every
        # head projects to the full embed_dim (total width num_heads * embed_dim).
        # `embed_dim // num_heads` is the more common choice; kept as-is so
        # the architecture (and any saved weights) stay unchanged.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )

        # Position-wise feed-forward network: expand to dense_dim, project back.
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )

        # Post-norm: each LayerNormalization is applied to (input + sub-layer output).
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

        # Let the incoming padding mask propagate to downstream layers.
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # mask[:, None, :] adds a query axis so the (batch, seq) padding mask
        # broadcasts over query positions: padded *keys* are never attended to.
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        # Self-attention: queries, keys and values are all the input sequence.
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )

        # Residual add, then normalize.
        proj_input = self.layernorm_1(inputs + attention_output)

        # Feed-forward sub-layer.
        proj_output = self.dense_proj(proj_input)

        # Residual add, then normalize.
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        # Constructor arguments needed to re-create the layer when a saved model is loaded.
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Position ids ``0 .. L-1`` (L = size of the last input axis) are looked up in
    a table with ``sequence_length`` rows. Inputs longer than
    ``sequence_length`` would therefore index past that table. Token id 0 is
    treated as padding when computing the mask.

    Args:
        sequence_length: number of rows in the position-embedding table.
        vocab_size: number of rows in the token-embedding table.
        embed_dim: width of both embeddings (and of the output).
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters, kept for get_config().
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # Sub-layers (token table first, then position table).
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )

    def call(self, inputs):
        # One position id per step along the last axis: 0, 1, ..., L-1.
        position_ids = ops.arange(0, ops.shape(inputs)[-1], 1)
        # Position vectors broadcast across the batch axis when added.
        return self.token_embeddings(inputs) + self.position_embeddings(position_ids)

    def compute_mask(self, inputs, mask=None):
        # True for real tokens, False where the id is the padding id 0.
        return ops.not_equal(inputs, 0)

    def get_config(self):
        # Base config plus the constructor arguments, for saving/loading.
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        }


class TransformerDecoder(layers.Layer):
    """One Transformer decoder block.

    The block has three sub-layers: causal self-attention, cross-attention over
    the encoder sequence, and a feed-forward network. Each sub-layer is followed
    by a residual add and LayerNormalization.

    The layer is called on a two-element list
    ``[decoder_sequence, encoder_sequence]``. Any masks Keras supplies arrive as
    the matching two-element list ``[decoder_mask, encoder_mask]``.

    Args:
        embed_dim: width of the token vectors.
        latent_dim: hidden width of the feed-forward network.
        num_heads: number of heads in both attention layers.
    """

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters, kept for get_config().
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads

        # Sub-layers are created in the original order:
        # self-attention, cross-attention, feed-forward, then the three norms.
        self.attention_1 = layers.MultiHeadAttention(  # causal self-attention
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(  # cross-attention
            num_heads=num_heads, key_dim=embed_dim
        )
        ffn_layers = [
            layers.Dense(latent_dim, activation="relu"),
            layers.Dense(embed_dim),
        ]
        self.dense_proj = keras.Sequential(ffn_layers)
        self.layernorm_1, self.layernorm_2, self.layernorm_3 = (
            layers.LayerNormalization() for _ in range(3)
        )

        self.supports_masking = True

    def call(self, inputs, mask=None):
        decoder_seq, encoder_seq = inputs
        causal_mask = self.get_causal_attention_mask(decoder_seq)
        decoder_pad_mask, encoder_pad_mask = (None, None) if mask is None else mask

        # Causal self-attention: each position may only look at itself and
        # earlier positions; padded query positions are masked as well.
        self_attended = self.attention_1(
            query=decoder_seq,
            value=decoder_seq,
            key=decoder_seq,
            attention_mask=causal_mask,
            query_mask=decoder_pad_mask,
        )
        normed_self = self.layernorm_1(decoder_seq + self_attended)

        # Cross-attention: decoder positions query the encoder sequence;
        # padded encoder positions are excluded via key_mask.
        cross_attended = self.attention_2(
            query=normed_self,
            value=encoder_seq,
            key=encoder_seq,
            query_mask=decoder_pad_mask,
            key_mask=encoder_pad_mask,
        )
        normed_cross = self.layernorm_2(normed_self + cross_attended)

        # Feed-forward sub-layer with its own residual add + norm.
        return self.layernorm_3(normed_cross + self.dense_proj(normed_cross))

    def get_causal_attention_mask(self, inputs):
        """Return an int32 lower-triangular mask of shape (batch, steps, steps).

        Entry [b, i, j] is 1 when j <= i, so query step i can attend to key
        steps 0..i only.
        """
        shape = ops.shape(inputs)
        n_batch, n_steps = shape[0], shape[1]

        rows = ops.arange(n_steps)[:, None]
        cols = ops.arange(n_steps)
        lower_tri = ops.cast(rows >= cols, dtype="int32")

        # Add a leading axis, then repeat it once per batch element.
        lower_tri = ops.reshape(lower_tri, (1, n_steps, n_steps))
        repeats = ops.concatenate(
            [ops.expand_dims(n_batch, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(lower_tri, repeats)

    def get_config(self):
        # Base config plus the constructor arguments, for saving/loading.
        return {
            **super().get_config(),
            "embed_dim": self.embed_dim,
            "latent_dim": self.latent_dim,
            "num_heads": self.num_heads,
        }


In [None]:
# Model hyperparameters
embed_dim = 256      # Width of token/positional embeddings and of every block
latent_dim = 2048    # Hidden width of the feed-forward sub-layers
num_heads = 8        # Attention heads per MultiHeadAttention layer

# ------------------------------
# ENCODER
# ------------------------------
# Integer token ids of the English sentence.
encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

# ------------------------------
# DECODER
# ------------------------------
# Integer token ids of the (shifted) Spanish target prefix.
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
# Placeholder for precomputed encoder states, used by the standalone `decoder`.
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")

# Create the decoder layers once, in the same order as before. Both the
# end-to-end model and the standalone decoder reuse these layer objects, so
# they share the same (trained) weights.
decoder_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim)
decoder_block = TransformerDecoder(embed_dim, latent_dim, num_heads)
decoder_dropout = layers.Dropout(0.5)
decoder_head = layers.Dense(vocab_size, activation="softmax")


def _decoder_stack(encoded):
    """Run the shared decoder layers on `decoder_inputs`, attending to `encoded`."""
    y = decoder_embedding(decoder_inputs)
    y = decoder_block([y, encoded])   # causal self-attention + cross-attention
    y = decoder_dropout(y)
    return decoder_head(y)            # per-position softmax over the vocabulary


# End-to-end path: wired directly to the encoder tensor so that the encoder's
# padding mask reaches the decoder's cross-attention.
decoder_outputs = _decoder_stack(encoder_outputs)

# Standalone decoder, built on `encoded_seq_inputs` (not `encoder_outputs`) so
# that its declared second input is actually consumed and the model can be
# called with precomputed encoder states. No encoder padding mask is attached
# to this Input.
decoder = keras.Model(
    [decoder_inputs, encoded_seq_inputs], _decoder_stack(encoded_seq_inputs)
)

# ------------------------------
# FULL TRANSFORMER MODEL
# ------------------------------
transformer = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    decoder_outputs,
    name="transformer",
)





# Training

In [15]:
# Training budget: one epoch is only a smoke test; use at least 30 epochs for
# the model to converge.
epochs = 1

transformer.summary()

# Id 0 is the padding id (see PositionalEmbedding.compute_mask); ignore_class=0
# keeps padded target positions out of the loss.
padding_aware_loss = keras.losses.SparseCategoricalCrossentropy(ignore_class=0)
transformer.compile(
    optimizer="rmsprop",
    loss=padding_aware_loss,
    metrics=["accuracy"],
)

transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)

[1m1302/1302[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m973s[0m 744ms/step - accuracy: 0.1045 - loss: 5.0644 - val_accuracy: 0.1914 - val_loss: 2.8913


<keras.src.callbacks.history.History at 0x27497faa990>

In [17]:
# Reverse lookup table: integer token id -> Spanish token string.
spa_vocab = spa_vectorization.get_vocabulary()
spa_index_lookup = dict(enumerate(spa_vocab))
# Maximum number of tokens generated per translation.
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    """Greedily translate one English sentence into Spanish.

    Decoding starts from "[start]". On every step the full transformer
    (encoder + decoder) is run on the current prefix, and the arg-max token at
    the current position is appended. Decoding stops when "[end]" is produced
    or the step limit is reached.

    Args:
        input_sentence: a raw English string.

    Returns:
        The decoded Spanish string. It starts with "[start]" and ends with
        "[end]" if the model emitted it within the step limit.
    """
    tokenized_input_sentence = eng_vectorization([input_sentence])

    # The decoder input has only `sequence_length` positions (spa_vectorization
    # output minus its last column), so predictions[0, i] does not exist for
    # i >= sequence_length. Cap the loop even if max_decoded_sentence_length
    # is raised.
    max_steps = min(max_decoded_sentence_length, sequence_length)

    decoded_sentence = "[start]"
    for i in range(max_steps):
        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here,
        # so convert to NumPy before extracting the Python int.
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = spa_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


# Qualitative check: greedily translate 30 randomly drawn held-out sentences.
test_eng_texts = [eng for eng, _ in test_pairs]
for _ in range(30):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)
    print(input_sentence, translated, sep="\n")

When was the last time you played soccer?
[start] cuándo fue la última vez que tú [end]
What happens if I press this button?
[start] qué me [UNK] si esta noche [end]
Let's just forget you ever did this.
[start] solo te has hecho alguna vez [end]
Many of the students were tired.
[start] muchos estudiantes estaban cansado [end]
She didn't show up at the party yesterday.
[start] ella no se le [UNK] a la fiesta [end]
You've gained weight, haven't you?
[start] has [UNK] no has hecho [end]
Would you bring me another one, please?
[start] me [UNK] otra otra [end]
My shoulder hurts.
[start] mi [UNK] se ha ido [end]
The train went through a tunnel.
[start] el tren fue un [UNK] [end]
Tom was the only one who respected Mary.
[start] tom fue la único que mary que mary [end]
Take my car.
[start] toma mi coche [end]
I get anything I want.
[start] me [UNK] algo [end]
You need to look after your loved ones.
[start] tienes que [UNK] la [UNK] tu [UNK] [end]
Tom drops his kids off at school on his way to 