In [None]:
from translation_machine import dataset_utils

import torchtext,torch
from torchtext.datasets import Multi30k
# Multi30k English->German training pairs, cut down with DatasetSlicer
# (start_index=0, stop_index=3; presumably keeps items [0, 3) — TODO confirm)
# so the whole notebook runs quickly while experimenting.
train_dataset = dataset_utils.DatasetSlicer(
    Multi30k(language_pair=("en", "de"), split="train"),
    start_index=0,
    stop_index=3,
)

In [None]:
from torchtext.data.utils import get_tokenizer

# spaCy-backed word tokenizers, one per language of the translation pair.
en_tokenizer = get_tokenizer('spacy', language='en')
de_tokenizer = get_tokenizer('spacy', language='de')

In [None]:
from collections import Counter, OrderedDict
from itertools import islice

from tqdm import tqdm

# Token frequency tables for each side of the translation pair.
english_counter, german_counter = Counter(), Counter()

for pair in tqdm(train_dataset):
    english_sentence, german_sentence = en_tokenizer(pair[0]), de_tokenizer(pair[1])
    english_counter.update(english_sentence)
    german_counter.update(german_sentence)

# NOTE(review): a Counter keeps first-seen insertion order, so token ids follow
# order of appearance; the torchtext docs suggest an OrderedDict sorted by
# frequency if frequency-ordered ids are wanted.
# The German vocabulary also carries the <sos>/<eos> markers that collate_fn
# wraps around every target sentence.
vocab_english = torchtext.vocab.vocab(english_counter, specials=['<unk>'])
vocab_german = torchtext.vocab.vocab(german_counter, specials=['<unk>', '<sos>', '<eos>'])

# Out-of-vocabulary lookups fall back to the <unk> id.
for vocabulary in (vocab_english, vocab_german):
    vocabulary.set_default_index(vocabulary['<unk>'])

In [None]:
# Longest sentence per language, measured in TOKENS with the same tokenizers
# collate_fn uses, so each value is exactly the padding width that batch needs.
# (Previously raw character counts were used and the two languages were swapped.)
length_en_sentences = []
length_de_sentences = []
for pair in train_dataset:
    length_en_sentences.append(len(en_tokenizer(pair[0])))
    length_de_sentences.append(len(de_tokenizer(pair[1])))

if not length_en_sentences:
    raise ValueError("train_dataset is empty: cannot compute maximum sentence lengths")

max_length_english = max(length_en_sentences)
max_length_german = max(length_de_sentences)
# +2 leaves room for the <sos> and <eos> ids collate_fn adds to each German sentence.
max_length_german_extended = max_length_german + 2

In [None]:
import torch

def collate_fn(batch):
    """Turn a batch of (english, german) sentence pairs into padded id tensors.

    Each English sentence is tokenized with ``en_tokenizer``, mapped through
    ``vocab_english`` and right-padded with the ``<unk>`` id up to
    ``max_length_english``. Each German sentence is tokenized with
    ``de_tokenizer``, mapped through ``vocab_german``, wrapped in
    ``<sos>`` ... ``<eos>`` and right-padded with the ``<unk>`` id up to
    ``max_length_german_extended``. Padding reuses ``<unk>`` because neither
    vocabulary defines a dedicated padding special.

    Relies on notebook globals from earlier cells: the two tokenizers, the two
    vocabularies, ``max_length_english`` and ``max_length_german_extended``.

    Args:
        batch: iterable of ``(english_sentence, german_sentence)`` string pairs.

    Returns:
        list of four tensors ``[en_ids, ge_ids, en_lengths, ge_lengths]``:
        ``en_ids`` is (batch, max_length_english), ``ge_ids`` is
        (batch, max_length_german_extended), and the lengths are the unpadded
        token counts (German lengths include <sos> and <eos>).

    Raises:
        ValueError: if a sentence is longer than its padding width. Without
            this check the batch would be ragged and ``torch.tensor`` would
            fail with an unhelpful shape error.
    """
    en_pad_id = vocab_english['<unk>']
    ge_pad_id = vocab_german['<unk>']
    ge_sos_id = vocab_german['<sos>']
    ge_eos_id = vocab_german['<eos>']

    en_id_tokens_batchs = []
    ge_id_tokens_batchs = []
    en_lengths = []
    ge_lengths = []
    for english_sentence, german_sentence in batch:
        id_token_en = [vocab_english[token] for token in en_tokenizer(english_sentence)]
        # Every German (target) sentence is wrapped in <sos> ... <eos>.
        id_token_ge = (
            [ge_sos_id]
            + [vocab_german[token] for token in de_tokenizer(german_sentence)]
            + [ge_eos_id]
        )

        if len(id_token_en) > max_length_english:
            raise ValueError(
                f"English sentence has {len(id_token_en)} tokens, more than "
                f"max_length_english={max_length_english}: {english_sentence!r}"
            )
        if len(id_token_ge) > max_length_german_extended:
            raise ValueError(
                f"German sentence has {len(id_token_ge)} tokens (with <sos>/<eos>), more than "
                f"max_length_german_extended={max_length_german_extended}: {german_sentence!r}"
            )

        en_lengths.append(len(id_token_en))
        ge_lengths.append(len(id_token_ge))
        en_id_tokens_batchs.append(id_token_en + [en_pad_id] * (max_length_english - len(id_token_en)))
        ge_id_tokens_batchs.append(id_token_ge + [ge_pad_id] * (max_length_german_extended - len(id_token_ge)))

    # All rows now have equal width, so each list stacks into a rectangular tensor.
    return [
        torch.tensor(values)
        for values in (en_id_tokens_batchs, ge_id_tokens_batchs, en_lengths, ge_lengths)
    ]



In [None]:
from torch.utils.data import DataLoader
import numpy as np

# Batches of training pairs; collate_fn turns each batch into padded id tensors.
# NOTE(review): no seed/generator is set, so the shuffle order differs per run.
batch_size = 20
shuffle = True

train_data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=collate_fn,
)

In [None]:
from translation_machine import model

# Embedding size per language and hidden size for the encoder/decoder.
en_embeddings_size = ge_embeddings_size = 128
hidden_size_encoder = hidden_size_decoder = 256

# Keyword arguments for model.SequenceTranslator. Kept as a dict because the
# checkpoint-restore cell reads a "model_inputs" entry back from disk.
model_inputs = dict(
    en_embeddings_size=en_embeddings_size,
    ge_embeddings_size=ge_embeddings_size,
    hidden_size_encoder=hidden_size_encoder,
    hidden_size_decoder=hidden_size_decoder,
    vocab_english=vocab_english,
    vocab_german=vocab_german,
    max_length_german_extended=max_length_german_extended,
)

sequence_translator = model.SequenceTranslator(**model_inputs)


In [None]:
from torch import nn
# Token-level cross-entropy, averaged over elements (PyTorch's default reduction).
# NOTE(review): collate_fn pads with the <unk> id; unless ModelTrainer masks those
# positions, padding contributes to this loss — consider ignore_index.
baseline_loss = nn.CrossEntropyLoss(reduction="mean")


In [None]:
from torch import optim
from translation_machine import model_trainer
# NOTE(review): lr=0.1 is far above NAdam's default (2e-3); confirm it is intentional.
optimizer = optim.NAdam(params=sequence_translator.parameters(), lr=0.1)

# Multiplies the learning rate by 0.9 on every scheduler.step() call
# (presumably made by ModelTrainer — not visible here).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# NOTE: this rebinds the imported module name `model_trainer` to the trainer
# instance; later cells call model_trainer.train_on_epoch() on the instance.
# The None argument occupies the slot after the training loader (presumably
# a validation loader — TODO confirm against ModelTrainer's signature).
model_trainer = model_trainer.ModelTrainer(
    sequence_translator, optimizer, scheduler, train_data_loader, None, baseline_loss
)

# Set restart = False to resume from the saved checkpoint instead of fresh weights.
restart = True

if not restart:
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict then copies tensors onto the model's own device.
    # torch.load unpickles arbitrary objects — only load checkpoints you trust.
    checkpoint = torch.load("./sequence_translator_extended.pth", map_location="cpu")
    model_params = checkpoint["model_params"]
    # NOTE(review): the model above was already built from this notebook's
    # model_inputs; the stored ones are read back here but not used to rebuild it,
    # so load_state_dict fails if they describe a differently-sized model.
    model_inputs = checkpoint["model_inputs"]
    model_trainer.model.load_state_dict(model_params)


In [None]:
from tqdm import tqdm
# Switch the model to training mode before optimizing.
sequence_translator.train()

# History lists for the training loop; each entry corresponds to one epoch.
losses_train, losses_val = [], []
bleu_scores_train, bleu_scores_val = [], []

In [None]:
# Number of calls to model_trainer.train_on_epoch().
nb_epochs = 40

for epoch in tqdm(range(nb_epochs)):
    print("optimizing for epoch")
    loss_train, bleu_score_train = model_trainer.train_on_epoch()
    # float() turns each entry into a plain Python number (it also unwraps
    # 0-d tensors, if that is what train_on_epoch returns — TODO confirm).
    loss_train = list(map(float, loss_train))
    losses_train.append(loss_train)
    # TODO: validation and BLEU tracking are disabled. To re-enable, call
    # model_trainer.validate_on_epoch() and append its loss/BLEU (and
    # bleu_score_train) to losses_val, bleu_scores_train and bleu_scores_val.