# NEURAL MACHINE TRANSLATION - Vanilla RNN with Attention

## Required Modules & Config Files

In [1]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import NAdam
from torchinfo import summary

import src.RNN_GRUAttention as rnnANMT
from src.Tokenizer import Corpus, LangData, dataLoader
from src.utils import load_config, get_device, train_model, sentence_bleu, corpus_bleu
from src.TranslatorAtt import TranslatorAtt

# Loading config file
config = load_config()
# Get device : GPU/MPS Back-End/CPU
device = get_device()

# Fix torch's global RNG so weight initialisation (and, presumably, any
# torch-based data shuffling) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

print(f"Using device: {device}")

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


Using device: mps


## Load the dataset

In [2]:
# Parallel training corpora:
#   english_data   -> source language, fed to the encoder
#   afrikaans_data -> target language, fed to / predicted by the decoder
english_data = Corpus(f"{config.TRAIN_DATA}/english.txt", "English")
afrikaans_data = Corpus(f"{config.TRAIN_DATA}/afrikaans.txt", "Afrikaans")

## Set Hyperparameters

In [3]:
# ---- Encoder (source language: English) ----
IN_ENCODER = english_data.vocab_size   # source vocabulary size
ENCODER_EMB = 256                      # source embedding dimension

# ---- Decoder (target language: Afrikaans) ----
# Input and output sizes of the decoder are both the target vocabulary size.
IN_DECODER = OUT_DECODER = afrikaans_data.vocab_size
DECODER_EMB = 256                      # target embedding dimension

# ---- Shared recurrent settings ----
HIDDEN_SIZE = 1024
NUM_LAYERS = 2

# ---- Optimisation ----
LR = 1e-4
BATCH_SIZE = 128

## Set the model

In [4]:
# Build the encoder/decoder pair and wrap them in the attention seq2seq model.
encoder_net = rnnANMT.Encoder(IN_ENCODER, ENCODER_EMB, HIDDEN_SIZE, NUM_LAYERS).to(device)
decoder_net = rnnANMT.Decoder(IN_DECODER, DECODER_EMB, HIDDEN_SIZE, NUM_LAYERS).to(device)
# Move the wrapper too: any parameters RNNAtt may create itself (presumably
# from OUT_DECODER) would otherwise remain on the CPU. `.to` is a no-op for
# sub-modules that are already on `device`.
model = rnnANMT.RNNAtt(encoder_net, decoder_net, OUT_DECODER).to(device)

summary(model)

Layer (type:depth-idx)                   Param #
RNNAtt                                   --
├─Encoder: 1-1                           --
│    └─RNN: 2-1                          3,411,968
│    └─Embedding: 2-2                    743,936
├─Decoder: 1-2                           --
│    └─RNN: 2-3                          3,411,968
│    └─Embedding: 2-4                    737,024
│    └─Linear: 2-5                       5,899,071
Total params: 14,203,967
Trainable params: 14,203,967
Non-trainable params: 0

In [5]:
train_data = LangData(english_data, afrikaans_data)
train_loader = dataLoader(train_data, BATCH_SIZE)

# Exclude padding positions from the loss. Use the target vocabulary's actual
# '<pad>' index instead of assuming it is 0.
pad_idx = afrikaans_data.stoi['<pad>']
criterion = CrossEntropyLoss(ignore_index=pad_idx)

optimizer = NAdam(model.parameters(), LR)
translator = TranslatorAtt(model, english_data, afrikaans_data, device)

In [6]:
# Sentence pair used to follow translation quality during training.
# Raw strings: the LaTeX fragments contain backslashes (`\in`, `\{`, `\}`) that
# are invalid Python escape sequences in ordinary literals (SyntaxWarning on
# Python 3.12+). The string values are unchanged.
mytext = r"<sos> given that we represent the target output as $y\in\{0,1\}$ and we have $n$ training points , we can write the negative log likelihood of the parameters as follows: <eos>"
ground = r"<sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as: <eos>"

# Baseline translation from the untrained model
predicted = translator.translate_sentence(mytext)
bleu = sentence_bleu(prediction=[predicted], reference=[ground])
print(f"Pred: {predicted}")
print(f"Refe: {ground}")
print(f"BLEU SCORES: {bleu}")

Pred: <sos> weller." persent (all-pass) lewer verlies ophou trek \textit{iris kousale prototipe verdubbel $w_2=0.07$ algebra\"ies waag kringe standardiseer (all-pass) hoogwater aanvaar handskoene
Refe: <sos> as ons die teikenuittree voorstel as $y\in\{0,1\}$ en ons $n$ afrigpunte het , dan kan ons die negatiewe log-waarskynlikheidskostefunksie skryf as: <eos>
BLEU SCORES: [0.195, 0.099, 0.063, 0.0]


## Train the model

In [7]:
# Keyword arguments for the training loop. The follow-up sentence pair
# (source_test / reference) is translated and BLEU-scored after each epoch.
EPOCHS = 20
params = dict(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=EPOCHS,
    source_test=mytext,
    reference=ground,
    translator=translator,
)

train_model(**params)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.
Epoch 1/20: 100%|██████████| 20/20 [00:12<00:00,  1.61batch/s, loss=1.891]


Predicted: <sos> ek het nie <eos>
BLEU Score: [0.028, 0.022, 0.017, 0.0]


Epoch 2/20: 100%|██████████| 20/20 [00:07<00:00,  2.75batch/s, loss=1.599]


Predicted: <sos> ek is 'n die die die <eos>
BLEU Score: [0.07, 0.052, 0.038, 0.0]


Epoch 3/20: 100%|██████████| 20/20 [00:07<00:00,  2.83batch/s, loss=1.554]


Predicted: <sos> die die die die die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 4/20: 100%|██████████| 20/20 [00:07<00:00,  2.72batch/s, loss=1.581]


Predicted: <sos> die die die die die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 5/20: 100%|██████████| 20/20 [00:07<00:00,  2.75batch/s, loss=1.502]


Predicted: <sos> die die die die die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 6/20: 100%|██████████| 20/20 [00:07<00:00,  2.81batch/s, loss=1.456]


Predicted: <sos> die die die van die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 7/20: 100%|██████████| 20/20 [00:06<00:00,  2.90batch/s, loss=1.404]


Predicted: <sos> die die die die die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 8/20: 100%|██████████| 20/20 [00:07<00:00,  2.84batch/s, loss=1.386]


Predicted: <sos> die die die die die die die die die die die die die die die die die die die die
BLEU Score: [0.108, 0.07, 0.049, 0.0]


Epoch 9/20: 100%|██████████| 20/20 [00:07<00:00,  2.70batch/s, loss=1.430]


Predicted: <sos> die die die van die die die die die van die , van die , van die , van die
BLEU Score: [0.13, 0.077, 0.052, 0.0]


Epoch 10/20: 100%|██████████| 20/20 [00:07<00:00,  2.72batch/s, loss=1.387]


Predicted: <sos> die die die die van die en die die die <eos>
BLEU Score: [0.134, 0.092, 0.066, 0.0]


Epoch 11/20: 100%|██████████| 20/20 [00:07<00:00,  2.72batch/s, loss=1.390]


Predicted: <sos> ons is die die die die die die die die , , die die die <eos>
BLEU Score: [0.202, 0.131, 0.092, 0.0]


Epoch 12/20: 100%|██████████| 20/20 [00:07<00:00,  2.66batch/s, loss=1.343]


Predicted: <sos> ons is die die die van die , , , , , , , , , die die <eos>
BLEU Score: [0.223, 0.144, 0.1, 0.0]


Epoch 13/20: 100%|██████████| 20/20 [00:07<00:00,  2.71batch/s, loss=1.320]


Predicted: <sos> ons is die die die van die , , , , die die die , , , die die die
BLEU Score: [0.152, 0.083, 0.055, 0.0]


Epoch 14/20: 100%|██████████| 20/20 [00:07<00:00,  2.68batch/s, loss=1.284]


Predicted: <sos> ons is die die van van die , , , , , die die van die , van die ,
BLEU Score: [0.152, 0.083, 0.055, 0.0]


Epoch 15/20: 100%|██████████| 20/20 [00:07<00:00,  2.71batch/s, loss=1.247]


Predicted: <sos> ons is die die van die van die , van die , van die , van die , van die
BLEU Score: [0.152, 0.083, 0.055, 0.0]


Epoch 16/20: 100%|██████████| 20/20 [00:07<00:00,  2.68batch/s, loss=1.188]


Predicted: <sos> ons ons die die van van die , , , , , , die die die van die , ,
BLEU Score: [0.173, 0.109, 0.065, 0.0]


Epoch 17/20: 100%|██████████| 20/20 [00:07<00:00,  2.82batch/s, loss=1.135]


Predicted: <sos> ons ons die die van van 'n , , , , , ons , die die die , , ,
BLEU Score: [0.195, 0.115, 0.068, 0.0]


Epoch 18/20: 100%|██████████| 20/20 [00:07<00:00,  2.78batch/s, loss=1.124]


Predicted: <sos> dit is die die van van 'n , , , ons , die die van die , van die ,
BLEU Score: [0.152, 0.083, 0.055, 0.0]


Epoch 19/20: 100%|██████████| 20/20 [00:07<00:00,  2.70batch/s, loss=1.119]


Predicted: <sos> dit is die die van van 'n , , , ons , die die van die , wat ons die
BLEU Score: [0.173, 0.109, 0.065, 0.0]


Epoch 20/20: 100%|██████████| 20/20 [00:07<00:00,  2.64batch/s, loss=1.101]


Predicted: <sos> dit is die die wat wat ons , , , , die die die , , , , die die
BLEU Score: [0.152, 0.083, 0.055, 0.0]


## Evaluate on the Training set

In [8]:
# Re-translate every training sentence and score it against its own reference.
# data_str entries are joined with spaces (presumably token lists — TODO confirm).
EN_SRC = list(map(" ".join, english_data.data_str))
AF_REF = [[" ".join(tokens)] for tokens in afrikaans_data.data_str]
TRANSLATED = list(map(translator.translate_sentence, EN_SRC))
corpus_bleu(TRANSLATED, AF_REF)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.554970097742182
precisions          : [0.554970097742182]
brevity_penalty     : 1.0
length_ratio        : 1.0282548626537358
translation_length  : 37957
reference_length    : 36914
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.45742393193969
precisions          : [0.554970097742182, 0.3770232925384919]
brevity_penalty     : 1.0
length_ratio        : 1.0282548626537358
translation_length  : 37957
reference_length    : 36914
******************************************************************************************
                                     BLEU-3  

## Evaluate on the held-out Validation set

In [13]:
# Parallel validation corpus; line i of english.txt is scored against line i of
# afrikaans.txt. Read as UTF-8 explicitly so non-ASCII Afrikaans text does not
# depend on the platform's locale encoding.
with open(f"{config.VAL_DATA}/english.txt", encoding="utf-8") as data:
    english_val = data.read().strip().split("\n")
with open(f"{config.VAL_DATA}/afrikaans.txt", encoding="utf-8") as data:
    afrikaans_val = data.read().strip().split("\n")

# corpus_bleu pairs hypotheses and references by position, so the files must align.
assert len(english_val) == len(afrikaans_val), "validation source/reference line counts differ"

### Greedy Search

In [14]:
# Greedy decoding over the validation set; each hypothesis has exactly one reference.
VAL_AF_REF = list(map(lambda ref: [ref], afrikaans_val))
VAL_TRANSLATED = list(map(translator.translate_sentence, english_val))

corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.5568948575171963
precisions          : [0.5568948575171963]
brevity_penalty     : 1.0
length_ratio        : 1.0451184444748733
translation_length  : 15265
reference_length    : 14606
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.45444803388550437
precisions          : [0.5568948575171963, 0.3708474099099099]
brevity_penalty     : 1.0
length_ratio        : 1.0451184444748733
translation_length  : 15265
reference_length    : 14606
******************************************************************************************
                                     BL

### Beam search

In [15]:
# Beam-search decoding over the same validation sentences.
# NOTE: this rebinds VAL_TRANSLATED, replacing the greedy hypotheses.
BEAM_WIDTH = 2
VAL_TRANSLATED = [
    translator.translate_sentence(src_text, method="beam", beam_width=BEAM_WIDTH)
    for src_text in english_val
]
corpus_bleu(VAL_TRANSLATED, VAL_AF_REF)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.5654433740314893
precisions          : [0.6959073997519636]
brevity_penalty     : 0.8125267445539816
length_ratio        : 0.8280843488977133
translation_length  : 12095
reference_length    : 14606
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.4711443517782686
precisions          : [0.6959073997519636, 0.4831491212176119]
brevity_penalty     : 0.8125267445539816
length_ratio        : 0.8280843488977133
translation_length  : 12095
reference_length    : 14606
******************************************************************************************
          

## Evaluate on the SUN validation set only

In [9]:
# SUN-only parallel validation subset; line i of each file forms a pair. Read
# as UTF-8 explicitly so the result does not depend on the locale encoding.
with open(f"{config.VAL_DATA}/sun_english.txt", encoding="utf-8") as data:
    sun_english_val = data.read().strip().split("\n")
with open(f"{config.VAL_DATA}/sun_afrikaans.txt", encoding="utf-8") as data:
    sun_afrikaans_val = data.read().strip().split("\n")

# corpus_bleu pairs hypotheses and references by position, so the files must align.
assert len(sun_english_val) == len(sun_afrikaans_val), "SUN source/reference line counts differ"

### Greedy Search

In [10]:
# Greedy decoding on the SUN subset; one reference per hypothesis.
SUN_VAL_AF = list(map(lambda ref: [ref], sun_afrikaans_val))
SUN_VAL_TRANSLATED = list(map(translator.translate_sentence, sun_english_val))
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.3795550652662408
precisions          : [0.42424242424242425]
brevity_penalty     : 0.8946655109847104
length_ratio        : 0.8998426848453067
translation_length  : 3432
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.28272016183778237
precisions          : [0.42424242424242425, 0.2353846153846154]
brevity_penalty     : 0.8946655109847104
length_ratio        : 0.8998426848453067
translation_length  : 3432
reference_length    : 3814
******************************************************************************************
           

### Beam Search

In [11]:
# Beam-search decoding on the SUN subset.
# NOTE: this rebinds SUN_VAL_TRANSLATED, replacing the greedy hypotheses.
BEAM_WIDTH = 2
SUN_VAL_TRANSLATED = [
    translator.translate_sentence(src_text, method="beam", beam_width=BEAM_WIDTH)
    for src_text in sun_english_val
]
corpus_bleu(SUN_VAL_TRANSLATED, SUN_VAL_AF)

Using the latest cached version of the module from /Users/lucien/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bleu/9e0985c1200e367cce45605ce0ecb5ede079894e0f24f54613fca08eeb8aff76 (last modified on Thu Jul 18 16:29:52 2024) since it couldn't be found locally at evaluate-metric--bleu, or remotely on the Hugging Face Hub.


                                     BLEU-1                                     
------------------------------------------------------------------------------------------
bleu                : 0.354415067780592
precisions          : [0.5483992467043315]
brevity_penalty     : 0.646271981426835
length_ratio        : 0.6961195595175669
translation_length  : 2655
reference_length    : 3814
******************************************************************************************
                                     BLEU-2                                     
------------------------------------------------------------------------------------------
bleu                : 0.27118267830255316
precisions          : [0.5483992467043315, 0.32106752931661947]
brevity_penalty     : 0.646271981426835
length_ratio        : 0.6961195595175669
translation_length  : 2655
reference_length    : 3814
******************************************************************************************
               