# NEURAL MACHINE TRANSLATION - GRU with Attention

## Required Modules & Config Files

In [1]:
import evaluate
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import NAdam
from torchinfo import summary

import src.RNN_GRUAttention as gruANMT
from src.Tokenizer import Corpus, LangData, dataLoader
from src.utils import load_config, get_device, train_model, sentence_bleu, corpus_bleu
from src.TranslatorAtt import TranslatorAtt

# Project configuration: provides the data directories (TRAIN_DATA, VAL_DATA) used below.
config = load_config()

# Compute back-end picked by the project helper (GPU, Apple MPS back-end, or CPU).
device = get_device()
print(f"Using device: {device}")

Using device: mps


## Load the dataset

In [2]:
# Parallel training corpora: English is the source side (encoder input),
# Afrikaans is the target side (decoder input/output).
english_data = Corpus(
    f"{config.TRAIN_DATA}/english.txt",
    "English",
)
afrikaans_data = Corpus(
    f"{config.TRAIN_DATA}/afrikaans.txt",
    "Afrikaans",
)

## Set Hyperparameters

In [3]:
# Reproducibility: training is stochastic (weight init, presumably batch shuffling),
# so seed the RNGs before the model and loader are created.
# torch.manual_seed seeds every device's generator (CPU/CUDA/MPS).
# NOTE(review): if src/ also draws from Python's `random` module, seed it as well.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Encoder - source (English)
IN_ENCODER = english_data.vocab_size
ENCODER_EMB = 256

# Decoder - target (Afrikaans); input and output vocabularies are the same
IN_DECODER = afrikaans_data.vocab_size
OUT_DECODER = afrikaans_data.vocab_size
DECODER_EMB = 256

# Shared: both the encoder and the decoder are built with these
HIDDEN_SIZE = 1024
NUM_LAYERS = 2

LR = 1e-3
BATCH_SIZE = 128

## Set the model

In [4]:
# Build the GRU encoder/decoder pair and wrap it in the attention seq2seq model.
encoder_net = gruANMT.Encoder(IN_ENCODER, ENCODER_EMB, HIDDEN_SIZE, NUM_LAYERS, type='GRU').to(device)
decoder_net = gruANMT.Decoder(IN_DECODER, DECODER_EMB, HIDDEN_SIZE, NUM_LAYERS, type='GRU').to(device)
# Move the assembled model too: .to() on the sub-modules alone would leave any
# parameter/buffer registered directly on RNNAtt on the CPU. Idempotent otherwise.
model = gruANMT.RNNAtt(encoder_net, decoder_net, OUT_DECODER).to(device)

summary(model)

Layer (type:depth-idx)                   Param #
RNNAtt                                   --
├─Encoder: 1-1                           --
│    └─GRU: 2-1                          10,235,904
│    └─Embedding: 2-2                    744,448
├─Decoder: 1-2                           --
│    └─GRU: 2-3                          10,235,904
│    └─Embedding: 2-4                    737,280
│    └─Linear: 2-5                       5,901,120
Total params: 27,854,656
Trainable params: 27,854,656
Non-trainable params: 0

In [5]:
train_data = LangData(english_data, afrikaans_data)
train_loader = dataLoader(train_data, BATCH_SIZE)

# Exclude padded target positions from the loss. Use the vocabulary's real '<pad>'
# index instead of a hard-coded 0, so the wrong token is never ignored if the
# special-token order in the vocab changes.
# NOTE(review): assumes dataLoader pads targets with this same '<pad>' index — confirm.
pad_idx = afrikaans_data.stoi['<pad>']
criterion = CrossEntropyLoss(ignore_index=pad_idx)

optimizer = NAdam(model.parameters(), lr=LR)
translator = TranslatorAtt(model, english_data, afrikaans_data, device)

In [6]:
# Fixed example used to follow translation quality during training.
# Raw strings: the embedded LaTeX (`\in`, `\{`, `\}`) are invalid escape sequences in
# a normal literal (DeprecationWarning, SyntaxWarning since Python 3.12). The resulting
# string value is unchanged.
mytext = r"<sos> given that we represent the target output as $y\in\{0,1\}$ and we have $n$ training points , we can write the negative log likelihood of the parameters as follows : <eos>"
ground = r"<sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>"

# Baseline translation from the untrained model.
predicted = translator.translate_sentence(mytext)
bleu = sentence_bleu(prediction=[predicted], reference=[ground])
print(f"Pred: {predicted}")
print(f"Refe: {ground}")
print(f"BLEU SCORES: {bleu}")

Pred: <sos> vrees stelsels roomys hou waarskynlik hande herhaal beursie beursie minstens stout stout faseweergawes ter neem. neem. aanhou aanhou aangebied individualis neem. neem. aangebied stasie japan burgerskap. sleutel roomys jou jou validasiedata
Refe: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU SCORES: [0.077, 0.064, 0.048, 0.0]


## Train the model

In [7]:
EPOCHS = 15

# Keyword arguments for train_model. Per the logged output, it also translates
# `mytext` with `translator` and scores it against `ground` after every epoch.
params = dict(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=EPOCHS,
    source_test=mytext,
    reference=ground,
    translator=translator,
)

train_loss = train_model(**params)
# Persist the loss history (presumably one value per epoch) for later plotting.
np.save('gru_att_train_loss.npy', np.array(train_loss))

Epoch 1/15: 100%|██████████| 20/20 [00:11<00:00,  1.78batch/s, loss=1.704]


Predicted: <sos> die die van die : <eos>
BLEU Score: [0.064, 0.05, 0.04, 0.028]


Epoch 2/15: 100%|██████████| 20/20 [00:10<00:00,  1.84batch/s, loss=1.464]


Predicted: <sos> die volgende van die volgende van die volgende en en die gemiddeld , en en die gemiddeld , en die gemiddeld en die gemiddeld , en die gemiddeld en die gemiddeld
BLEU Score: [0.178, 0.096, 0.063, 0.0]


Epoch 3/15: 100%|██████████| 20/20 [00:10<00:00,  1.88batch/s, loss=1.174]


Predicted: <sos> die volgende figuur as die volgende : <eos>
BLEU Score: [0.104, 0.077, 0.06, 0.041]


Epoch 4/15: 100%|██████████| 20/20 [00:10<00:00,  1.90batch/s, loss=0.985]


Predicted: <sos> ons het die -gemiddelde -gemiddelde -gemiddelde bondeling , en die derde het ons eerder eerder die gulsige -gemiddelde (soos (soos (soos (soos (soos (soos (soos ons dan dan -gemiddelde die -gemiddelde
BLEU Score: [0.293, 0.121, 0.072, 0.0]


Epoch 5/15: 100%|██████████| 20/20 [00:10<00:00,  1.87batch/s, loss=0.769]


Predicted: <sos> ons het die kaggel en ons die \% verandering , dan ons dan die $k$ -gemiddelde algoritme <eos>
BLEU Score: [0.42, 0.267, 0.153, 0.0]


Epoch 6/15: 100%|██████████| 20/20 [00:11<00:00,  1.80batch/s, loss=0.611]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons het die teikenuittree gehad en dan $y\in\{0,1\}$ die $k$ -gemiddelde algoritme hardloop kan as ons slegs die log-waarskynlikheidskostefunksie gebruik as ons slegs
BLEU Score: [0.556, 0.458, 0.425, 0.405]


Epoch 7/15: 100%|██████████| 20/20 [00:11<00:00,  1.78batch/s, loss=0.476]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons afrigpunte afrigpunte het , dan ons die negatiewe log-waarskynlikheidskostefunksie as : <eos>
BLEU Score: [0.867, 0.826, 0.783, 0.736]


Epoch 8/15: 100%|██████████| 20/20 [00:10<00:00,  1.90batch/s, loss=0.361]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 9/15: 100%|██████████| 20/20 [00:11<00:00,  1.76batch/s, loss=0.318]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 10/15: 100%|██████████| 20/20 [00:10<00:00,  1.83batch/s, loss=0.278]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 11/15: 100%|██████████| 20/20 [00:10<00:00,  1.90batch/s, loss=0.258]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 12/15: 100%|██████████| 20/20 [00:10<00:00,  1.95batch/s, loss=0.242]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 13/15: 100%|██████████| 20/20 [00:10<00:00,  1.96batch/s, loss=0.238]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 14/15: 100%|██████████| 20/20 [00:10<00:00,  1.84batch/s, loss=0.244]


Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]


Epoch 15/15: 100%|██████████| 20/20 [00:10<00:00,  1.89batch/s, loss=0.237]

Predicted: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as : <eos>
BLEU Score: [1.0, 1.0, 1.0, 1.0]





## Evaluate on the Training set

In [8]:
# Join each stored training sentence (presumably a token list) back into one string.
EN_SRC = list(map(' '.join, english_data.data_str))
# corpus_bleu takes one list of reference strings per hypothesis.
AF_REF = [[' '.join(tokens)] for tokens in afrikaans_data.data_str]
TRANSLATED = list(map(translator.translate_sentence, EN_SRC))
corpus_bleu(TRANSLATED, AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.9149107367362334
precisions          : [0.9149107367362334]
brevity_penalty     : 1.0
length_ratio        : 1.0772522888563845
translation_length  : 39770
reference_length    : 36918
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.9033514952243483
precisions          : [0.9149107367362334, 0.8919382964453387]
brevity_penalty     : 1.0
length_ratio        : 1.0772522888563845
translation_length  : 39770
reference_length    : 36918
******************************************************************************************
                                     BLE

## Evaluate on the Validation (held-out) set

In [9]:
# Held-out parallel corpus: one sentence per line, paired with the other file by line.
# Decode explicitly as UTF-8 so Afrikaans characters read the same on every OS,
# instead of depending on the platform's locale encoding.
with open(f"{config.VAL_DATA}/english.txt", encoding="utf-8") as fh:
    english_val = fh.read().strip().split("\n")
with open(f"{config.VAL_DATA}/afrikaans.txt", encoding="utf-8") as fh:
    afrikaans_val = fh.read().strip().split("\n")

# Sentences are paired by position, so both sides must have the same length.
assert len(english_val) == len(afrikaans_val), "validation source/target line counts differ"

### Greedy Search

In [10]:
# corpus_bleu takes one list of references per hypothesis: wrap each reference string.
VAL_AF_REF = [[reference] for reference in afrikaans_val]
# No `method` argument: the greedy path (see the section heading).
VAL_TRANSLATED = list(map(translator.translate_sentence, english_val))
corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.546646897810219
precisions          : [0.546646897810219]
brevity_penalty     : 1.0
length_ratio        : 1.2009313792631147
translation_length  : 17536
reference_length    : 14602
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.4597544067217559
precisions          : [0.546646897810219, 0.3866739486619334]
brevity_penalty     : 1.0
length_ratio        : 1.2009313792631147
translation_length  : 17536
reference_length    : 14602
******************************************************************************************
                                     BLEU-3

### Beam search

In [11]:
# Same validation sentences, now decoded with beam search (beam width 3).
VAL_TRANSLATED = [
    translator.translate_sentence(english_sentence, method="beam", beam_width=3)
    for english_sentence in english_val
]
corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.6045925523672416
precisions          : [0.7211213146447559]
brevity_penalty     : 0.8384061600856723
length_ratio        : 0.8501575126694974
translation_length  : 12414
reference_length    : 14602
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.5147232195663568
precisions          : [0.7211213146447559, 0.522673241172845]
brevity_penalty     : 0.8384061600856723
length_ratio        : 0.8501575126694974
translation_length  : 12414
reference_length    : 14602
******************************************************************************************
           

## Evaluate on the SUN validation set only

In [12]:
# SUN-only subset of the validation data: one sentence per line, paired by line.
# Decode explicitly as UTF-8 so Afrikaans characters read the same on every OS,
# instead of depending on the platform's locale encoding.
with open(f"{config.VAL_DATA}/sun_english.txt", encoding="utf-8") as fh:
    sun_english_val = fh.read().strip().split("\n")
with open(f"{config.VAL_DATA}/sun_afrikaans.txt", encoding="utf-8") as fh:
    sun_afrikaans_val = fh.read().strip().split("\n")

# Sentences are paired by position, so both sides must have the same length.
assert len(sun_english_val) == len(sun_afrikaans_val), "SUN source/target line counts differ"

### Greedy Search

In [13]:
# One single-element reference list per SUN target sentence, as corpus_bleu expects.
SUN_VAL_AF = [[reference] for reference in sun_afrikaans_val]
# Greedy decoding (no `method` argument).
SUN_VAL_TRANSLATED = list(map(translator.translate_sentence, sun_english_val))
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.3540898242568887
precisions          : [0.3540898242568887]
brevity_penalty     : 1.0
length_ratio        : 1.208442579968537
translation_length  : 4609
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.25874066568009585
precisions          : [0.3540898242568887, 0.18906708832166252]
brevity_penalty     : 1.0
length_ratio        : 1.208442579968537
translation_length  : 4609
reference_length    : 3814
******************************************************************************************
                                     BLEU-3 

### Beam Search

In [14]:
# SUN subset decoded with beam search (beam width 3).
SUN_VAL_TRANSLATED = [
    translator.translate_sentence(english_sentence, method="beam", beam_width=3)
    for english_sentence in sun_english_val
]
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.402030567768737
precisions          : [0.5625218914185639]
brevity_penalty     : 0.7146931948815344
length_ratio        : 0.748557944415312
translation_length  : 2855
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.31396016562831097
precisions          : [0.5625218914185639, 0.34306023194912083]
brevity_penalty     : 0.7146931948815344
length_ratio        : 0.748557944415312
translation_length  : 2855
reference_length    : 3814
******************************************************************************************
               

In [15]:
metric = evaluate.load("bleu")

# Inspect a fixed window of SUN examples, decoded with a wider beam (width 5),
# with a per-sentence BLEU score.
window = slice(10, 20)
predictions = [
    translator.translate_sentence(english_sentence, method="beam", beam_width=5)
    for english_sentence in sun_english_val[window]
]
labels = SUN_VAL_AF[window]

for src_text, hypothesis, refs in zip(sun_english_val[window], predictions, labels):
    print(f"Source    : {src_text}")
    # Long sentences are truncated to 150 characters for display.
    print(f"Prediction: {hypothesis[:150]}")
    print(f"Label     : {refs[0][:150]}")
    sentence_score = metric.compute(predictions=[hypothesis], references=refs)['bleu']
    print(f"BLEU      : {sentence_score}")
    print()

Source    : <sos> component <eos>
Prediction: <sos> ontbyt <eos>
Label     : <sos> komponent <eos>
BLEU      : 0.0

Source    : <sos> architecture <eos>
Prediction: <sos> ontbyt <eos>
Label     : <sos> argitektuur <eos>
BLEU      : 0.0

Source    : <sos> specification <eos>
Prediction: <sos> ontbyt <eos>
Label     : <sos> spesifikasies <eos>
BLEU      : 0.0

Source    : <sos> at which stage of the design process would we choose the communication protocol between subsystems <eos>
Prediction: <sos> watter waarde van $x[n]$ : <eos>
Label     : <sos> by watter stap van die ontwerpsproses word die kommunikasie-kanaal tussen substelsels gekies <eos>
BLEU      : 0.0

Source    : <sos> motivate your answer <eos>
Prediction: <sos> motiveer jou antwoord <eos>
Label     : <sos> motiveer jou antwoord <eos>
BLEU      : 1.0

Source    : <sos> describe the meaning if a system is described as a cyber-physical system <eos>
Prediction: <sos> die stelsel met 'n monsterfrekwensie : <eos>
Label     : <sos>