In [24]:
import os 
import random
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import warnings
warnings.filterwarnings('ignore')

from data import TranslationDataset
from rnn import RNN, RNNTools
from transformers import Transformers, TransformersTools

from IPython.display import Markdown

In [25]:
# configurable parameters, change as needed

# set to true if loading existing model file, false if training a new model
skip_training = True
data_dir = 'data'
rnn_model_save_path = 'models/rnn.pth'
tra_model_save_path = 'models/transformers.pth'


In [26]:
# create dirs if not existing
os.makedirs(data_dir, exist_ok=True)
os.makedirs('models', exist_ok=True)
os.makedirs('logs', exist_ok=True)

In [27]:
# additional settings, automatically selects cuda if available
if skip_training:
    device_type = 'cpu'
elif torch.cuda.is_available():
    device_type = 'cuda:0'
else:
    device_type = 'cpu'

# set manually if needed e.g. device_type = 'cpu'
print("Using device type:", device_type)
device = torch.device(device_type)

Using device type: cpu


In [28]:
trainset = TranslationDataset(data_dir, train=True)
testset = TranslationDataset(data_dir, train=False)
print('Number of sentence pairs in the training set: ', len(trainset))
print('Number of sentence pairs in the test set: ', len(testset))

Number of sentence pairs in the training set:  8682
Number of sentence pairs in the test set:  2171


## RNN

In [29]:
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True, collate_fn=RNNTools.collate, pin_memory=True)
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=False, collate_fn=RNNTools.collate)

In [30]:
rnn = RNN(trainset.input_lang.n_words, trainset.output_lang.n_words, embed_size=256, hidden_size=256)
rnn.to(device)

RNN(
  (encoder): Encoder(
    (embedding): Embedding(4489, 256)
    (gru): GRU(256, 256)
  )
  (decoder): Decoder(
    (embedding): Embedding(2925, 256)
    (gru): GRU(256, 256)
    (out): Linear(in_features=256, out_features=2925, bias=True)
  )
)

In [31]:
if not skip_training:
    PADDING_VALUE = 0 
    teacher_forcing_ratio = 0.5
    num_epochs = 2

    optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)    
    criterion = nn.NLLLoss(ignore_index=PADDING_VALUE)
    
    rnn.train()
    
    for epoch in range(num_epochs):
        total_loss = 0
        total_data = 0
        for src_seqs, src_seq_lengths, tgt_seqs in trainloader:
            src_seqs, tgt_seqs = src_seqs.to(device), tgt_seqs.to(device)
            
            if torch.rand(1) < teacher_forcing_ratio:
                teacher_forcing=True
            else:
                teacher_forcing=False
            
            # forward pass
            outputs = rnn(src_seqs, tgt_seqs, src_seq_lengths, teacher_forcing)
            loss = criterion(outputs.permute(0, 2, 1).to(device), tgt_seqs)
            
            # compute loss metric
            total_loss += (loss.item() * src_seqs.shape[1])
            total_data += src_seqs.shape[1]

            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("epoch: {0} training loss: {1:.3f}".format(epoch, total_loss/total_data))

In [32]:
if not skip_training:
    torch.save(rnn.state_dict(), rnn_model_save_path)

In [33]:
if skip_training:
    rnn.load_state_dict(torch.load(rnn_model_save_path, map_location=lambda storage, loc: storage))
    print('RNN model loaded from: {}'.format(rnn_model_save_path))
    rnn.to(device)
    rnn.eval()

RNN model loaded from: models/rnn.pth


In [34]:
rnntools = RNNTools(device)

In [35]:
results_df = pd.DataFrame(index=range(20), columns=['batch_i', 'Source', 'Actual Translation', 'RNN Translation', 'Transformer Translation'])

In [36]:
print('Test data:')

i = 0
for src_seqs, src_mask, tgt_seqs in testloader:
    if i >= 20:
        break

    out_seqs = rnntools.translate(rnn, src_seqs, src_mask)

    for r in random.sample(range(0, 64), 1):
        results_df.loc[i, 'batch_i'] = r
        results_df.loc[i, 'Source'] = rnntools.seq_to_string(src_seqs[:,r], testset.input_lang)
        results_df.loc[i, 'Actual Translation'] = rnntools.seq_to_string(tgt_seqs[:,r], testset.output_lang)
        results_df.loc[i, 'RNN Translation'] = rnntools.seq_to_string(out_seqs[:,r], testset.output_lang)

    i += 1

display(Markdown(results_df[['Source', 'Actual Translation', 'RNN Translation']].to_markdown()))

Test data:


|    | Source                                | Actual Translation               | RNN Translation                        |
|---:|:--------------------------------------|:---------------------------------|:---------------------------------------|
|  0 | vous vous etes trompe d avion .       | you are on the wrong plane .     | you re always anticipating of others . |
|  1 | je suis interesse par l anglais .     | i am interested in english .     | i m interested in english .            |
|  2 | nous sommes tellement fiers de vous ! | we re so proud of you !          | we re so proud of you !                |
|  3 | tu mens n est ce pas ?                | you re lying aren t you ?        | you re staying aren t you ?            |
|  4 | il est un peu emeche .                | he s a bit tipsy .               | he s a bit rough .                     |
|  5 | elle sembla indifferente .            | she seemed uninterested .        | she seemed uninterested .              |
|  6 | tu es tout seul .                     | you re all alone .               | you re all alone .                     |
|  7 | je ne suis pas une menteuse .         | i m not a liar .                 | i m not a liar .                       |
|  8 | je suis hongrois .                    | i am hungarian .                 | i m a .                                |
|  9 | nous avons juste peur .               | we re just scared .              | i m still waiting .                    |
| 10 | desole d etre si stupide .            | i m sorry i m so stupid .        | i m so happy for your loss . .         |
| 11 | nous sommes une grande famille .      | we re a big family .             | we re a big boy .                      |
| 12 | je ne vais pas citer de noms .        | i m not going to name names .    | i m not going to lose .                |
| 13 | elles sont toutes parties .           | they re all gone .               | they re all dead .                     |
| 14 | je suis amoureuse de toi .            | i m in love with you .           | i m on your side .                     |
| 15 | elle l a sermonne .                   | she scolded him .                | she studies him .                      |
| 16 | je suis dans le jardin .              | i am in the garden .             | i am in the bathtub .                  |
| 17 | il profite de sa vie d ecolier .      | he is enjoying his school life . | he is afraid to have his hair .        |
| 18 | vous n etes pas si vieux tom .        | you re not that old tom .        | you re not that old tom .              |
| 19 | vous n etes pas millionnaire .        | you re not a millionaire .       | you re not a millionaire .             |

In [37]:
score = rnntools.compute_bleu_score(rnn, trainloader, trainset.output_lang)
print(f'BLEU score on training data: {score*100}')
score = rnntools.compute_bleu_score(rnn, testloader, trainset.output_lang)
print(f'BLEU score on test data: {score*100}')

BLEU score on training data: 96.69817090034485
BLEU score on test data: 47.73730933666229


## Transformers

In [38]:
# skip_training = True

In [39]:
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True, collate_fn=TransformersTools.collate, pin_memory=True)
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=False, collate_fn=TransformersTools.collate)

In [40]:
tra = Transformers(trainset.input_lang.n_words, trainset.output_lang.n_words, n_blocks=3, n_features=256, n_heads=16, n_hidden=1024)
tra.to(device)

Transformers(
  (encoder): Encoder(
    (embedding): Embedding(4489, 256, padding_idx=0)
    (positional_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder_blocks): ModuleList(
      (0-2): 3 x EncoderBlock(
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (layer_norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): Dropout(p=0.1, inplace=False)
          (2): ReLU()
          (3): Linear(in_features=1024, out_features=256, bias=True)
        )
        (dropout2): Dropout(p=0.1, inplace=False)
        (layer_norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(2925, 256, padding_idx=0)
 

In [41]:
if not skip_training:
    PADDING_VALUE = 0
    num_epochs = 2

    optimizer = torch.optim.Adam(tra.parameters(), lr=0.001)
    criterion = nn.NLLLoss(ignore_index=PADDING_VALUE)

    for epoch in range(num_epochs):
        total_loss = 0
        total_data = 0
        for src_seqs, src_mask, tgt_seqs in trainloader:
            src_seqs, src_mask, tgt_seqs = src_seqs.to(device), src_mask.to(device), tgt_seqs.to(device)
            
            # forward
            outputs = tra(src_seqs, tgt_seqs, src_mask)
            
            # compute loss metric
            loss = criterion(outputs.permute(0, 2, 1).to(device), tgt_seqs[1:])
            total_loss += (loss.item() * src_seqs.shape[1])
            total_data += src_seqs.shape[1]

            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("epoch: {0} training loss: {1:.3f}".format(epoch, total_loss/total_data))


In [42]:
if not skip_training:
    torch.save(tra.state_dict(), tra_model_save_path)

In [43]:
if skip_training:
    tra.load_state_dict(torch.load(tra_model_save_path, map_location=lambda storage, loc: storage))
    print('Transformers model loaded from: {}'.format(tra_model_save_path))
    tra.to(device)
    tra.eval()

Transformers model loaded from: models/transformers.pth


In [44]:
tratools = TransformersTools(device)

In [48]:
print('Test data:')

i = 0
for src_seqs, src_mask, tgt_seqs in testloader:
    if i >= 20:
        break

    out_seqs = tratools.translate(tra, src_seqs, src_mask)

    r = results_df.loc[i, 'batch_i']
    print(r)
    results_df.loc[i, 'Transformer Translation'] = tratools.seq_to_string(out_seqs[:,r], testset.output_lang)

    i += 1

display(Markdown(results_df[['Source', 'Actual Translation', 'Transformer Translation']].to_markdown()))

Test data:
20
19
25
20
34
61
41
13
49
39
43
33
9
50
26
38
34
4
9
34


|    | Source                                | Actual Translation               | Transformer Translation          |
|---:|:--------------------------------------|:---------------------------------|:---------------------------------|
|  0 | vous vous etes trompe d avion .       | you are on the wrong plane .     | she is devoted to her children . |
|  1 | je suis interesse par l anglais .     | i am interested in english .     | she s too loud .                 |
|  2 | nous sommes tellement fiers de vous ! | we re so proud of you !          | he s a fit .                     |
|  3 | tu mens n est ce pas ?                | you re lying aren t you ?        | you re very intelligent .        |
|  4 | il est un peu emeche .                | he s a bit tipsy .               | he is my hero .                  |
|  5 | elle sembla indifferente .            | she seemed uninterested .        | i m from croatia .               |
|  6 | tu es tout seul .                     | you re all alone .               | she s six years older than me .  |
|  7 | je ne suis pas une menteuse .         | i m not a liar .                 | i m not your husband anymore .   |
|  8 | je suis hongrois .                    | i am hungarian .                 | we re being attacked .           |
|  9 | nous avons juste peur .               | we re just scared .              | i m undressing .                 |
| 10 | desole d etre si stupide .            | i m sorry i m so stupid .        | you re alone aren t you ?        |
| 11 | nous sommes une grande famille .      | we re a big family .             | i m ready for them .             |
| 12 | je ne vais pas citer de noms .        | i m not going to name names .    | she is war .                     |
| 13 | elles sont toutes parties .           | they re all gone .               | i m not sure .                   |
| 14 | je suis amoureuse de toi .            | i m in love with you .           | we re all ears .                 |
| 15 | elle l a sermonne .                   | she scolded him .                | i m truly sorry .                |
| 16 | je suis dans le jardin .              | i am in the garden .             | they are christians .            |
| 17 | il profite de sa vie d ecolier .      | he is enjoying his school life . | he is eight .                    |
| 18 | vous n etes pas si vieux tom .        | you re not that old tom .        | he is able to play his entire .  |
| 19 | vous n etes pas millionnaire .        | you re not a millionaire .       | i m as tall as my grandfather    |

In [46]:
score = tratools.compute_bleu_score(tra, trainloader, trainset.output_lang)
print(f'BLEU score on training data: {score*100}')
score = tratools.compute_bleu_score(tra, testloader, trainset.output_lang)
print(f'BLEU score on test data: {score*100}')

BLEU score on training data: 95.63669647179238
BLEU score on test data: 58.79185315508608


In [47]:
display(Markdown(results_df[['Source', 'Actual Translation', 'RNN Translation', 'Transformer Translation']].to_markdown()))

|    | Source                                | Actual Translation               | RNN Translation                        | Transformer Translation          |
|---:|:--------------------------------------|:---------------------------------|:---------------------------------------|:---------------------------------|
|  0 | vous vous etes trompe d avion .       | you are on the wrong plane .     | you re always anticipating of others . | she is devoted to her children . |
|  1 | je suis interesse par l anglais .     | i am interested in english .     | i m interested in english .            | she s too loud .                 |
|  2 | nous sommes tellement fiers de vous ! | we re so proud of you !          | we re so proud of you !                | he s a fit .                     |
|  3 | tu mens n est ce pas ?                | you re lying aren t you ?        | you re staying aren t you ?            | you re very intelligent .        |
|  4 | il est un peu emeche .                | he s a bit tipsy .               | he s a bit rough .                     | he is my hero .                  |
|  5 | elle sembla indifferente .            | she seemed uninterested .        | she seemed uninterested .              | i m from croatia .               |
|  6 | tu es tout seul .                     | you re all alone .               | you re all alone .                     | she s six years older than me .  |
|  7 | je ne suis pas une menteuse .         | i m not a liar .                 | i m not a liar .                       | i m not your husband anymore .   |
|  8 | je suis hongrois .                    | i am hungarian .                 | i m a .                                | we re being attacked .           |
|  9 | nous avons juste peur .               | we re just scared .              | i m still waiting .                    | i m undressing .                 |
| 10 | desole d etre si stupide .            | i m sorry i m so stupid .        | i m so happy for your loss . .         | you re alone aren t you ?        |
| 11 | nous sommes une grande famille .      | we re a big family .             | we re a big boy .                      | i m ready for them .             |
| 12 | je ne vais pas citer de noms .        | i m not going to name names .    | i m not going to lose .                | she is war .                     |
| 13 | elles sont toutes parties .           | they re all gone .               | they re all dead .                     | i m not sure .                   |
| 14 | je suis amoureuse de toi .            | i m in love with you .           | i m on your side .                     | we re all ears .                 |
| 15 | elle l a sermonne .                   | she scolded him .                | she studies him .                      | i m truly sorry .                |
| 16 | je suis dans le jardin .              | i am in the garden .             | i am in the bathtub .                  | they are christians .            |
| 17 | il profite de sa vie d ecolier .      | he is enjoying his school life . | he is afraid to have his hair .        | he is eight .                    |
| 18 | vous n etes pas si vieux tom .        | you re not that old tom .        | you re not that old tom .              | he is able to play his entire .  |
| 19 | vous n etes pas millionnaire .        | you re not a millionaire .       | you re not a millionaire .             | i m as tall as my grandfather    |