In [53]:
import tensorflow as tf
import numpy as np

# Sample parallel corpus (tiny for demo)

In [54]:
english_sentences = ["hello", "how are you", "thank you", "good night"]
french_sentences = ["bonjour", "comment ça va", "merci", "bonne nuit"]
french_sentences = [f"<start> {s} <end>" for s in french_sentences]

# Tokenize source (English)

In [55]:
src_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
src_tokenizer.fit_on_texts(english_sentences)
src_sequences = src_tokenizer.texts_to_sequences(english_sentences)
src_word_index = src_tokenizer.word_index
src_vocab_size = len(src_word_index) + 1

# Tokenize target (French) and add <start>, <end> tokens

In [None]:
tgt_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
tgt_tokenizer.fit_on_texts(french_sentences)
tgt_sequences = tgt_tokenizer.texts_to_sequences(french_sentences)
tgt_word_index = tgt_tokenizer.word_index
tgt_index_word = {v: k for k, v in tgt_word_index.items()}
tgt_vocab_size = len(tgt_word_index) + 1

# Pad sequences

In [57]:
src_padded = tf.keras.preprocessing.sequence.pad_sequences(src_sequences, padding='post')
tgt_padded = tf.keras.preprocessing.sequence.pad_sequences(tgt_sequences, padding='post')

# Split target into decoder input and output

In [58]:
decoder_input = tgt_padded[:, :-1]
decoder_target = tf.keras.utils.to_categorical(tgt_padded[:, 1:], num_classes=tgt_vocab_size)

# Define the Seq2Seq model

In [59]:
embedding_dim = 64
latent_dim = 64

encoder_inputs = tf.keras.Input(shape=(None,))
enc_emb = tf.keras.layers.Embedding(src_vocab_size, embedding_dim)(encoder_inputs)
encoder_outputs, state_h, state_c = tf.keras.layers.LSTM(latent_dim, return_state=True)(enc_emb)
 
decoder_inputs = tf.keras.Input(shape=(None,))
dec_emb_layer = tf.keras.layers.Embedding(tgt_vocab_size, embedding_dim)
dec_emb = dec_emb_layer(decoder_inputs)
decoder_lstm = tf.keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=[state_h, state_c])
decoder_dense = tf.keras.layers.Dense(tgt_vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Compile and train

In [60]:
model = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit([src_padded, decoder_input], decoder_target, epochs=300, verbose=0)

<keras.src.callbacks.history.History at 0x26dbed07080>

# Inference models

In [61]:
# Encoder model
encoder_model = tf.keras.Model(encoder_inputs, [state_h, state_c])

# Decoder model (for inference)
decoder_state_input_h = tf.keras.Input(shape=(latent_dim,))
decoder_state_input_c = tf.keras.Input(shape=(latent_dim,))
decoder_inputs_inf = tf.keras.Input(shape=(None,))
dec_emb2 = dec_emb_layer(decoder_inputs_inf)
decoder_outputs2, h, c = decoder_lstm(dec_emb2, initial_state=[decoder_state_input_h, decoder_state_input_c])
decoder_outputs2 = decoder_dense(decoder_outputs2)

decoder_model = tf.keras.Model(
    [decoder_inputs_inf, decoder_state_input_h, decoder_state_input_c],
    [decoder_outputs2, h, c]
)


# Inference: simple translation (just index decoding here)

In [62]:
def translate(input_text):
    seq = src_tokenizer.texts_to_sequences([input_text])
    seq = tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=src_padded.shape[1], padding='post')

    h, c = encoder_model.predict(seq)

    target_seq = np.array([[tgt_word_index['<start>']]])
    stop_condition = False
    decoded_sentence = []

    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq, h, c])
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_word = tgt_index_word.get(sampled_token_index, '')

        if sampled_word == '<end>' or len(decoded_sentence) > 10:
            stop_condition = True
        else:
            decoded_sentence.append(sampled_word)
            target_seq = np.array([[sampled_token_index]])

    return ' '.join(decoded_sentence)


# Test translation

In [63]:
print("Translate 'thank you':", translate("thank you"))
print("Translate 'good night':", translate("good night"))
print("Translate 'hello':", translate("hello"))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 130ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 115ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 118ms/step
Translate 'thank you': merci
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
Translate 'good night': bonne nuit
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
Translate 'hello': bonjour
