<a href="https://colab.research.google.com/github/VickkiMars/NLP_Mastery/blob/main/nmt_with_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import pathlib
import random
import string
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings
import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization

In [3]:
with open("dataset_2.txt", "r") as f:
  lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
  inp, targ = line.split("\t")
  targ = "[start] " + targ + " [end]"
  text_pairs.append((inp, targ))

In [5]:
for _ in range(5):
  print(random.choice(text_pairs))

('Please make a transfer of 8h to Peace Microfinance Bank at Loucas Domhnall, account number 9793807269.', '[start] 800 Peace Microfinance Bank 9793807269 Loucas Domhnall [end]')
('Please send 131k to 1146429305, Abbey Mortgage Bank, at Chiaka Uju.', '[start] 131000 Abbey Mortgage Bank 1146429305 Chiaka Uju [end]')
('please transfer 76k to 2036822496', '[start] 76000 2036822496 [end]')
('Deposit 54k to Amucha MFB, account number 4104402882, for Angels Mukolu', '[start] 54000 Amucha MFB 4104402882 Angels Mukolu [end]')
('abeg send 34k to 7234499301', '[start] 34000 7234499301 [end]')


In [6]:
random.shuffle(text_pairs)
num_val_samples = int(0.2 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples

train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

38985 total pairs
23391 training pairs
7797 validation pairs
7797 test pairs


In [7]:
strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

In [36]:
vocab_size = 15000
sequence_length = 140
batch_size = 16

def custom_standardization(input_string):
  lowercase = tf_strings.lower(input_string)
  return tf_strings.regex_replace(
      lowercase, f"[{re.escape(strip_chars)}]", "")

In [37]:
inp_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
targ_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)

In [38]:
train_inp_texts = [pair[0] for pair in train_pairs]
train_targ_texts = [pair[1] for pair in train_pairs]
inp_vectorization.adapt(train_inp_texts)
targ_vectorization.adapt(train_targ_texts)

In [39]:
def format_dataset(inp, targ):
  inp = inp_vectorization(inp)
  targ = targ_vectorization(targ)
  return (
    {
      "encoder_inputs": inp,
      "decoder_inputs": targ[:, :-1],
    },
    targ[:, 1:],
  )

In [40]:
def make_dataset(pairs):
  inp_texts, targ_texts = zip(*pairs)
  inp_texts = list(inp_texts)
  targ_texts = list(targ_texts)
  dataset = tf_data.Dataset.from_tensor_slices((inp_texts, targ_texts))
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(format_dataset, num_parallel_calls=4)
  return dataset.shuffle(2048).prefetch(16)

In [41]:
train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

In [42]:
for inputs, targets in train_ds.take(1):
  print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
  print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
  print(f"targets.shape: {targets.shape}")

inputs["encoder_inputs"].shape: (16, 140)
inputs["decoder_inputs"].shape: (16, 140)
targets.shape: (16, 140)


In [43]:
class TransformerEncoder(layers.Layer):
  def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
    super().__init__(**kwargs)
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads
    self.attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim
    )
    self.dense_proj = keras.Sequential(
        [
          layers.Dense(dense_dim, activation="relu"),
          layers.Dense(embed_dim),
        ]
    )
    self.layernorm_1 = layers.LayerNormalization()
    self.layernorm_2 = layers.LayerNormalization()
    self.supports_masking = True

  def call(self, inputs, mask=None):
    if mask is not None:
      padding_mask = ops.cast(mask[:, None, :], dtype='int32')
    else:
      padding_mask = None

    attention_output = self.attention(
        query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
    )
    proj_input = self.layernorm_1(inputs + attention_output)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input + proj_output)

  def get_config(self):
    config = super().get_config()
    config.update(
        {
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        }
    )
    return config

In [44]:
class PositionalEmbedding(layers.Layer):
  def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
    super().__init__(**kwargs)
    self.token_embeddings = layers.Embedding(
        input_dim=vocab_size, output_dim=embed_dim
    )
    self.position_embeddings = layers.Embedding(
        input_dim = sequence_length, output_dim = embed_dim
    )
    self.sequence_length = sequence_length
    self.vocab_size = vocab_size
    self.embed_dim = embed_dim

  def call(self, inputs):
    length = ops.shape(inputs)[-1]
    positions = ops.arange(0, length, 1)
    embedded_tokens = self.token_embeddings(inputs)
    embedded_positions = self.position_embeddings(positions)
    return embedded_tokens + embedded_positions

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    else:
      return ops.not_equal(inputs, 0)

  def get_config(self):
    config = super().get_config()
    config.update(
        {
            "sequence_length": self.sequence_length,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        }
    )
    return config

In [45]:
class TransformerDecoder(layers.Layer):
  def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
    super().__init__(**kwargs)
    self.embed_dim = embed_dim
    self.latent_dim = latent_dim
    self.num_heads = num_heads
    self.attention_1 = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim
    )
    self.attention_2 = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim
    )
    self.dense_proj = keras.Sequential(
        [
          layers.Dense(latent_dim, activation="relu"),
          layers.Dense(embed_dim),
        ]
    )
    self.layernorm_1 = layers.LayerNormalization()
    self.layernorm_2 = layers.LayerNormalization()
    self.layernorm_3 = layers.LayerNormalization()
    self.supports_masking = True

  def call(self, inputs, encoder_outputs, mask=None):
    causal_mask = self.get_causal_attention_mask(inputs)
    if mask is not None:
      padding_mask = ops.cast(mask[:, None, :], dtype="int32")
      padding_mask = ops.minimum(
          padding_mask, causal_mask
      )
    else:
      padding_mask = None

    attention_output_1 = self.attention_1(
        query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
    )
    out_1 = self.layernorm_1(inputs + attention_output_1)

    attention_output_2 = self.attention_2(
        query=out_1,
        value=encoder_outputs,
        key=encoder_outputs,
        attention_mask=padding_mask,
    )
    out_2 = self.layernorm_2(out_1 + attention_output_2)

    proj_output = self.dense_proj(out_2)
    return self.layernorm_3(out_2 + proj_output)

  def get_causal_attention_mask(self, inputs):
    input_shape = ops.shape(inputs)
    batch_size, sequence_length = input_shape[0], input_shape[1]
    i = ops.arange(sequence_length)[:, None]
    j = ops.arange(sequence_length)
    mask = ops.cast(i >= j, dtype="int32")
    mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
    mult = ops.concatenate(
        [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
        axis=0,
    )
    return ops.tile(mask, mult)

  def get_config(self):
    config = super().get_config()
    config.update(
        {
            "embed_dim": self.embed_dim,
            "latent_dim": self.latent_dim,
            "num_heads": self.num_heads,
        }
    )
    return config



In [51]:
embed_dim = 128
latent_dim = 2048
num_heads = 8

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
x = layers.Dropout(0.2)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
)

In [52]:
epochs = 5  # This should be at least 30 for convergence

transformer.summary()
transformer.compile(
    "adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)

Epoch 1/5
[1m1462/1462[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 48ms/step - accuracy: 0.9608 - loss: 0.8210 - val_accuracy: 0.9747 - val_loss: 0.1440
Epoch 2/5
[1m1462/1462[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 43ms/step - accuracy: 0.9714 - loss: 0.1956 - val_accuracy: 0.9757 - val_loss: 0.1370
Epoch 3/5
[1m1462/1462[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 40ms/step - accuracy: 0.9722 - loss: 0.1849 - val_accuracy: 0.9762 - val_loss: 0.1352
Epoch 4/5
[1m1462/1462[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 41ms/step - accuracy: 0.9729 - loss: 0.1757 - val_accuracy: 0.9761 - val_loss: 0.1352
Epoch 5/5
[1m1462/1462[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 41ms/step - accuracy: 0.9735 - loss: 0.1668 - val_accuracy: 0.9762 - val_loss: 0.1364


<keras.src.callbacks.history.History at 0x7eba4ee54a60>

In [53]:
targ_vocab = targ_vectorization.get_vocabulary()
targ_index_lookup = dict(zip(range(len(targ_vocab)), targ_vocab))
max_decoded_sentence_length = 140

In [54]:
def decode_sequence(input_sentence):
    tokenized_input_sentence = inp_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = targ_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = targ_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence

In [55]:
test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(10):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)
    print(input_sentence)
    print(translated)
    print("-"*10)

Please send 220k to 5570699237, at Waya Microfinance Bank
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
Kindly send 221m to ISUA MFB, 1105148297, account Okpoko Cris
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
Please send exactly 4h to Ugbong Alez at 8007028386, Fedeth MFB
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
Deposit 8k to 8473423119, account number First City Monument Bank, for Bonilla Usaman
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
I need you to transfer 57k to Lao Usman Ali, Sterling Bank, 2769729384 
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
LOMA MFB, account number 5177889316, Name: Afunwa Rhiannan, please send 798k.
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
please wire 326k at 8465917357 asap
[start] 5000 mayfair finance bank limited [UNK] [UNK] [end]
----------
Please make a transfer of 220k to Fairmoney Mic