In [13]:
########################################################################################################################
## -- libraries and packages -- ########################################################################################
########################################################################################################################
import os
import sys
sys.path.append(os.path.abspath(".."))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformer import DataHandler, TransformerDataset, Transformer, TransformerTrainingModule

########################################################################################################################
## -- testing data handling, helper functions, and tokenizer -- ########################################################
########################################################################################################################
# Input files: one vocabulary JSON per language, plus the Tatoeba archive that both sides of the
# parallel corpus (English source, Persian target) are read from.
src_vocab_path, tgt_vocab_path = "../data/vocabs/en_vocab.json", "../data/vocabs/fa_vocab.json"
src_path = tgt_path = "../data/dataset/Tatoeba.zip"  # same archive for both languages
src_name, tgt_name = "en.txt", "fa.txt"             # per-language file names handed to DataHandler with the archive

# Special marker tokens (start-of-sequence, padding, end-of-sequence).
SOS_TOKEN = '<SOS>'
PAD_TOKEN = '<PAD>'
EOS_TOKEN = '<EOS>'

# Data / batching limits. max_sentences = 1 loads a single sentence pair (the recorded run shows
# "Batch: 1 / 1"), i.e. this cell is an overfit-one-example sanity check.
batch_size, max_sequence_length, max_sentences = 64, 128, 1

# Model architecture.
model_emb, hidden, num_heads = 64, 256, 8
num_layers_enc = num_layers_dec = 2
dropout_p = 0.1

device = 'cpu'

## -- creating the dataset and dataloader -- ##
# Seed torch's global RNG before anything random happens (DataLoader shuffling here, and presumably
# the weight initialisation of the model built further down) so repeated runs are reproducible.
torch.manual_seed(42)

# Arguments are grouped as: source (archive, file name, vocab), target (archive, file name, vocab),
# then the three special tokens, followed by the length / corpus-size limits.
data_handler = DataHandler(
    src_path, src_name, src_vocab_path,
    tgt_path, tgt_name, tgt_vocab_path,
    SOS_TOKEN, PAD_TOKEN, EOS_TOKEN,
    max_sequence_length=max_sequence_length,
    max_sentences=max_sentences,
)
data = data_handler.data()

# Pair up source/target sentences and batch them; with shuffle=True the DataLoader reshuffles every epoch.
dataset = TransformerDataset(data.src_sentences, data.tgt_sentences)
tr_data = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## -- creating model, loss function and optimizer -- ##
# Target padding positions are excluded from the loss via ignore_index. reduction='none' yields
# per-element losses, so the final reduction presumably happens inside TransformerTrainingModule — confirm there.
criterion = nn.CrossEntropyLoss(
    ignore_index=data.tgt_stoi[PAD_TOKEN],
    reduction='none',
)

# Hyperparameters first, then source/target vocabularies, then the special tokens.
transformer_model = Transformer(
    model_emb, hidden, num_heads, dropout_p,
    num_layers_enc, num_layers_dec, max_sequence_length,
    data.src_stoi, data.tgt_stoi,
    SOS_TOKEN, PAD_TOKEN, EOS_TOKEN,
    device=device,
).to(device)

# The model is already on `device`, so the optimizer tracks the parameters that will actually be trained.
learning_rate = 2e-5
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=learning_rate)

## -- training loop -- ##
# save_path is presumably where the trainer writes checkpointed weights — confirm in TransformerTrainingModule.
trainer = TransformerTrainingModule(
    transformer_model, criterion, optimizer, data,
    max_sequence_length, SOS_TOKEN, PAD_TOKEN, EOS_TOKEN,
    device=device,
    save_path='../weights/model_weights.pth',
)

# Train in rounds of `epochs_per_round` epochs; after each round, translate one known sentence and
# print it next to its reference so convergence can be eyeballed (see the recorded output below).
num_rounds = 10
epochs_per_round = 250
sample_idx = 0  # which training pair to translate after every round

for _ in range(num_rounds):
    trainer.fit(data_loader=tr_data, epochs=epochs_per_round, verbose=True)
    # Decoding settings passed straight to the trainer's translate(); their exact semantics live there.
    translation = trainer.translate(
        data.src_sentences[sample_idx], repetition_penalty=1.2, temperature=0.7, top_k=12
    )
    # Plain strings: these labels have no placeholders, so the former f-prefix did nothing (ruff F541).
    # print's default sep adds the second space, keeping the output identical to before.
    print()
    print("src: ", data.src_sentences[sample_idx])
    print("tgt: ", data.tgt_sentences[sample_idx])
    print("trn: ", translation)
    print()

Epoch 250 / 250, Batch: 1 / 1, Loss: 3.14794
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:  م   م ه  ۸ط ننم یرغینومغنم

Epoch 250 / 250, Batch: 1 / 1, Loss: 2.43668
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:  م طم. ن نممان چگوط ینو.م عمیم م.ظ

Epoch 250 / 250, Batch: 1 / 1, Loss: 1.80733
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:   فمینمی دچووی ه ع یغ چهم.

Epoch 250 / 250, Batch: 1 / 1, Loss: 1.30678
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:  من انمگط فذ دانم.

Epoch 250 / 250, Batch: 1 / 1, Loss: 0.93974
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:  من فق^ دانگویم چه چویم.

Epoch 250 / 250, Batch: 1 / 1, Loss: 0.69524
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn:  من فقط دانم دانم بگوی چم.

Epoch 250 / 250, Batch: 1 / 1, Loss: 0.50033
src:  i just don't know what to say.
tgt:  من فقط نمی دانم چه بگویم.
trn: 