# NEURAL MACHINE TRANSLATION - LSTM with Attention

## Required Modules & Config Files

In [1]:
import src.LSTMAttention as lstmANMT
from src.Tokenizer import Corpus, LangData, dataLoader
from src.utils import load_config, get_device, train_model, sentence_bleu, corpus_bleu
from src.TranslatorAtt import TranslatorAtt
from torch.nn import CrossEntropyLoss
from torch.optim import NAdam
import evaluate
import numpy as np
from torchinfo import summary

# Project configuration; it supplies the data directories read further down
# (config.TRAIN_DATA, config.VAL_DATA).
config = load_config()

# Resolve the compute device once and reuse it everywhere.
# Candidates per the original note: GPU / MPS back-end / CPU.
device = get_device()
print(f"Using device: {device}")

Using device: mps


## Load the dataset

In [2]:
# Training corpora, one text file per language.
# English is the source side (encoder input); Afrikaans is the target side
# (decoder input/output).
english_data = Corpus(
    f"{config.TRAIN_DATA}/english.txt",
    "English",
)
afrikaans_data = Corpus(
    f"{config.TRAIN_DATA}/afrikaans.txt",
    "Afrikaans",
)

## Set Hyperparameters

In [3]:
# --- Encoder (source language: English) ------------------------------------
IN_ENCODER = english_data.vocab_size    # source vocabulary size
ENCODER_EMB = 256                       # encoder embedding width

# --- Decoder (target language: Afrikaans) ----------------------------------
# Decoder input and output share the same target vocabulary size.
IN_DECODER = OUT_DECODER = afrikaans_data.vocab_size
DECODER_EMB = 256                       # decoder embedding width

# --- Shared by encoder and decoder -----------------------------------------
HIDDEN_SIZE = 1024
NUM_LAYERS = 2

# --- Optimisation ----------------------------------------------------------
LR = 1e-3
BATCH_SIZE = 128

## Set the model

In [4]:
# Build the encoder and decoder and place each on the selected device.
encoder_net = lstmANMT.Encoder(IN_ENCODER, ENCODER_EMB, HIDDEN_SIZE, NUM_LAYERS).to(device)
decoder_net = lstmANMT.Decoder(IN_DECODER, DECODER_EMB, HIDDEN_SIZE, NUM_LAYERS).to(device)

# Move the wrapper itself to the device as well. The sub-modules are already
# there, so this is a no-op for them, but it guarantees that any parameter or
# buffer the wrapper registers on its own does not silently stay on the CPU.
model = lstmANMT.LSTMANMT(encoder_net, decoder_net, OUT_DECODER).to(device)

# Per-layer parameter counts of the assembled model.
summary(model)

Layer (type:depth-idx)                   Param #
LSTMANMT                                 --
├─Encoder: 1-1                           --
│    └─Embedding: 2-1                    744,448
│    └─LSTM: 2-2                         13,647,872
├─Decoder: 1-2                           --
│    └─Embedding: 2-3                    737,280
│    └─LSTM: 2-4                         13,647,872
│    └─Linear: 2-5                       5,901,120
Total params: 34,678,592
Trainable params: 34,678,592
Non-trainable params: 0

In [5]:
# Pair the source and target corpora and batch them for training.
train_data = LangData(english_data, afrikaans_data)
train_loader = dataLoader(train_data, BATCH_SIZE)

# Padding positions must not contribute to the loss. Use the index looked up
# from the target vocabulary instead of a hard-coded 0, so the loss stays
# correct even if the tokenizer assigns '<pad>' a different id.
pad_idx = afrikaans_data.stoi['<pad>']
criterion = CrossEntropyLoss(ignore_index=pad_idx)

optimizer = NAdam(model.parameters(), lr=LR)

# Inference helper used for the follow-up sentence and for all evaluations below.
translator = TranslatorAtt(model, english_data, afrikaans_data, device, lstm=True)

In [6]:
# Sentence pair used to follow up on translation quality during training.
# Raw string literals are required here: the LaTeX fragments contain
# backslashes (\in, \{, \}) that are invalid escape sequences in a normal
# string literal and trigger a SyntaxWarning on recent Python versions.
# The runtime values are unchanged.
mytext = r"<sos> given that we represent the target output as $y\in\{0,1\}$ and we have $n$ training points , we can write the negative log likelihood of the parameters as follows : <eos>"
ground = r"<sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>"

# Baseline: translate the follow-up sentence with the still-untrained model
# and score it against the reference.
predicted = translator.translate_sentence(mytext)
bleu = sentence_bleu(prediction=[predicted], reference=[ground])

# One report line each for prediction, reference and scores.
print(
    f"Pred: {predicted}",
    f"Refe: {ground}",
    f"BLEU SCORES: {bleu}",
    sep="\n",
)

Pred: <sos> skrif. trafiek besta besta besta hierso toiletsitplek frase frase frase frase swart gebyt wonderlike gebyt wonderlike waarnemings rooi." frans frans frans hoofstad (fir) frans frans hoofstad (fir) pretparke pretparke pretparke huis
Refe: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU SCORES: [0.073, 0.06, 0.045, 0.0]


## Train the model

In [7]:
EPOCHS = 20

# Keyword arguments forwarded to train_model. The source_test / reference pair
# is the follow-up sentence whose translation and BLEU scores are reported
# after every epoch (see the output below).
params = dict(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=EPOCHS,
    source_test=mytext,
    reference=ground,
    translator=translator,
)

train_loss = train_model(**params)

# Persist the returned training losses so they can be reused without retraining.
np.save('lstm_att_train_loss.npy', np.array(train_loss))

Epoch 1/20: 100%|██████████| 20/20 [00:11<00:00,  1.80batch/s, loss=1.727]


Predicted: <sos> die die die <eos>
BLEU Score: [0.032, 0.024, 0.018, 0.0]


Epoch 2/20: 100%|██████████| 20/20 [00:10<00:00,  1.92batch/s, loss=1.467]


Predicted: <sos> die volgende van die volgende van die volgende <eos>
BLEU Score: [0.096, 0.07, 0.052, 0.0]


Epoch 3/20: 100%|██████████| 20/20 [00:10<00:00,  1.90batch/s, loss=1.391]


Predicted: <sos> die volgende bloklengte van die volgende bladsy <eos>
BLEU Score: [0.083, 0.061, 0.045, 0.0]


Epoch 4/20: 100%|██████████| 20/20 [00:10<00:00,  1.88batch/s, loss=1.259]


Predicted: <sos> die volgende sein wat wat die berekening benodig word : <eos>
BLEU Score: [0.134, 0.103, 0.081, 0.056]


Epoch 5/20: 100%|██████████| 20/20 [00:11<00:00,  1.81batch/s, loss=1.132]


Predicted: <sos> ons het die data van die datastel , en ons en die gemiddeld , en ons en die gemiddeld , en ons en die gemiddeld , en ons die gemiddeld van
BLEU Score: [0.279, 0.171, 0.092, 0.0]


Epoch 6/20: 100%|██████████| 20/20 [00:10<00:00,  1.89batch/s, loss=0.951]


Predicted: <sos> die die waardes van die model , ons die \% van die model , en ons die \% in die model , en ons die \% in die model , en
BLEU Score: [0.358, 0.217, 0.108, 0.0]


Epoch 7/20: 100%|██████████| 20/20 [00:10<00:00,  1.87batch/s, loss=0.791]


Predicted: <sos> as ons die teikenuittree voorstel as ons $y\in\{0,1\}$ as ons afrigpunte as ons wil as ons dan as ons dan as ons dan as ons dan as ons dan as ons
BLEU Score: [0.558, 0.489, 0.444, 0.411]


Epoch 8/20: 100%|██████████| 20/20 [00:10<00:00,  1.87batch/s, loss=0.651]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons afrigpunte afrigpunte van die afrigpunte van die datastel , en $d$ in die geval dat die kenmerk van die duitse mark as
BLEU Score: [0.6, 0.522, 0.494, 0.477]


Epoch 9/20: 100%|██████████| 20/20 [00:11<00:00,  1.82batch/s, loss=0.541]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [0.974, 0.961, 0.947, 0.933]


Epoch 10/20: 100%|██████████| 20/20 [00:10<00:00,  1.93batch/s, loss=0.420]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [0.974, 0.961, 0.947, 0.933]


Epoch 11/20: 100%|██████████| 20/20 [00:10<00:00,  1.92batch/s, loss=0.353]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 12/20: 100%|██████████| 20/20 [00:10<00:00,  1.88batch/s, loss=0.317]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan dan die negatiewe log-waarskynlikheidskostefunksie as ons skryf <eos>
BLEU Score: [0.948, 0.88, 0.837, 0.799]


Epoch 13/20: 100%|██████████| 20/20 [00:10<00:00,  1.82batch/s, loss=0.291]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 14/20: 100%|██████████| 20/20 [00:10<00:00,  1.87batch/s, loss=0.269]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan dan die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [0.948, 0.935, 0.921, 0.906]


Epoch 15/20: 100%|██████████| 20/20 [00:10<00:00,  1.85batch/s, loss=0.256]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 16/20: 100%|██████████| 20/20 [00:10<00:00,  1.89batch/s, loss=0.256]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 17/20: 100%|██████████| 20/20 [00:11<00:00,  1.82batch/s, loss=0.252]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 18/20: 100%|██████████| 20/20 [00:10<00:00,  1.91batch/s, loss=0.240]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 19/20: 100%|██████████| 20/20 [00:10<00:00,  1.95batch/s, loss=0.235]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 20/20: 100%|██████████| 20/20 [00:10<00:00,  1.84batch/s, loss=0.248]

Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]





## Evaluate on the Training set

In [8]:
# Rebuild plain-text sentences by space-joining each entry of data_str
# (presumably token lists - verify against Corpus).
EN_SRC = list(map(' '.join, english_data.data_str))
# Each reference is wrapped in a one-element list: one reference per sentence.
AF_REF = [[' '.join(tokens)] for tokens in afrikaans_data.data_str]

# Translate every training sentence with the translator's default decoding.
TRANSLATED = list(map(translator.translate_sentence, EN_SRC))
corpus_bleu(TRANSLATED, AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.8220112003895788
precisions          : [0.8220112003895788]
brevity_penalty     : 1.0
length_ratio        : 1.1124654640013
translation_length  : 41070
reference_length    : 36918
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.7863161354826563
precisions          : [0.8220112003895788, 0.7521710952689565]
brevity_penalty     : 1.0
length_ratio        : 1.1124654640013
translation_length  : 41070
reference_length    : 36918
******************************************************************************************
                                     BLEU-3   

## Evaluate on the Validation set

In [9]:
# Read the held-out sentence pairs, one sentence per line.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform
# locale (Afrikaans text may contain non-ASCII characters).
with open(f"{config.VAL_DATA}/english.txt", encoding="utf-8") as data:
    english_val = data.read().strip().split("\n")
with open(f"{config.VAL_DATA}/afrikaans.txt", encoding="utf-8") as data:
    afrikaans_val = data.read().strip().split("\n")

### Greedy Search

In [10]:
# One reference per sentence, each wrapped in its own list.
VAL_AF_REF = [[reference] for reference in afrikaans_val]

# Default decoding of the translator (the section heading calls it greedy search).
VAL_TRANSLATED = list(map(translator.translate_sentence, english_val))

corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.540899433427762
precisions          : [0.540899433427762]
brevity_penalty     : 1.0
length_ratio        : 1.160388987809889
translation_length  : 16944
reference_length    : 14602
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.4479252410157097
precisions          : [0.540899433427762, 0.37093220872411403]
brevity_penalty     : 1.0
length_ratio        : 1.160388987809889
translation_length  : 16944
reference_length    : 14602
******************************************************************************************
                                     BLEU-3 

### Beam search

In [11]:
# Same sentences decoded with beam search (width 3).
# NOTE: this rebinds VAL_TRANSLATED, replacing the greedy translations.
VAL_TRANSLATED = [
    translator.translate_sentence(source, method="beam", beam_width=3)
    for source in english_val
]
corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.5834628460379737
precisions          : [0.6963810751994841]
brevity_penalty     : 0.8378499456936506
length_ratio        : 0.8496781262840707
translation_length  : 12407
reference_length    : 14602
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.4943086234370929
precisions          : [0.6963810751994841, 0.49982378854625553]
brevity_penalty     : 0.8378499456936506
length_ratio        : 0.8496781262840707
translation_length  : 12407
reference_length    : 14602
******************************************************************************************
         

## Evaluate on the SUN validation set only

In [12]:
# Read the SUN-only validation sentence pairs, one sentence per line.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform
# locale (Afrikaans text may contain non-ASCII characters).
with open(f"{config.VAL_DATA}/sun_english.txt", encoding="utf-8") as data:
    sun_english_val = data.read().strip().split("\n")
with open(f"{config.VAL_DATA}/sun_afrikaans.txt", encoding="utf-8") as data:
    sun_afrikaans_val = data.read().strip().split("\n")

### Greedy Search

In [13]:
# One reference per sentence, each wrapped in its own list.
SUN_VAL_AF = [[reference] for reference in sun_afrikaans_val]

# Default decoding of the translator (the section heading calls it greedy search).
SUN_VAL_TRANSLATED = list(map(translator.translate_sentence, sun_english_val))
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.36072846135886055
precisions          : [0.3607284613588606]
brevity_penalty     : 1.0
length_ratio        : 1.1229680125852124
translation_length  : 4283
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.25855450913640177
precisions          : [0.3607284613588606, 0.18532065349914656]
brevity_penalty     : 1.0
length_ratio        : 1.1229680125852124
translation_length  : 4283
reference_length    : 3814
******************************************************************************************
                                     BLEU

### Beam Search

In [14]:
# Same SUN sentences decoded with beam search (width 3).
# NOTE: this rebinds SUN_VAL_TRANSLATED, replacing the greedy translations.
SUN_VAL_TRANSLATED = [
    translator.translate_sentence(source, method="beam", beam_width=3)
    for source in sun_english_val
]
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.3858280696088869
precisions          : [0.5504619758351101]
brevity_penalty     : 0.7009168417555892
length_ratio        : 0.7378080755112743
translation_length  : 2814
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.2967410480576267
precisions          : [0.5504619758351101, 0.3256079027355623]
brevity_penalty     : 0.7009168417555892
length_ratio        : 0.7378080755112743
translation_length  : 2814
reference_length    : 3814
******************************************************************************************
              

In [15]:
# Qualitative spot check on ten SUN validation sentences (indices 10-19).
metric = evaluate.load("bleu")

# Beam width here is 5, unlike the corpus-level beam runs, which use 3.
predictions = [
    translator.translate_sentence(sent, method="beam", beam_width=5)
    for sent in sun_english_val[10:20]
]
labels = SUN_VAL_AF[10:20]

for source, pred, lab in zip(sun_english_val[10:20], predictions, labels):
    # Prediction and label are cut to 150 characters for display only;
    # the score below is computed on the full strings.
    print(f"Source    : {source}")
    print(f"Prediction: {pred[:150]}")
    print(f"Label     : {lab[0][:150]}")
    # `lab` is a one-element list, i.e. the single reference for this prediction.
    # The extra newline leaves a blank line between samples.
    print(f"BLEU      : {metric.compute(predictions=[pred], references=lab)['bleu']}", end="\n\n")

Source    : <sos> component <eos>
Prediction: <sos> vis geheime <eos>
Label     : <sos> komponent <eos>
BLEU      : 0.0

Source    : <sos> architecture <eos>
Prediction: <sos> vis geheime <eos>
Label     : <sos> argitektuur <eos>
BLEU      : 0.0

Source    : <sos> specification <eos>
Prediction: <sos> vis geheime <eos>
Label     : <sos> spesifikasies <eos>
BLEU      : 0.0

Source    : <sos> at which stage of the design process would we choose the communication protocol between subsystems <eos>
Prediction: <sos> as een van die data van hz van die data <eos>
Label     : <sos> by watter stap van die ontwerpsproses word die kommunikasie-kanaal tussen substelsels gekies <eos>
BLEU      : 0.0

Source    : <sos> motivate your answer <eos>
Prediction: <sos> motiveer jou <eos>
Label     : <sos> motiveer jou antwoord <eos>
BLEU      : 0.6101950432112578

Source    : <sos> describe the meaning if a system is described as a cyber-physical system <eos>
Prediction: <sos> die sein te maak van die laa