# INM706 - Deep Learning for Sequence Analysis

Authors: Laerte Adami - Elisa Troschka

In [1]:
import time
from Utilities.lstmHandler import EncoderLSTM, DecoderLSTM
from Utilities.modelHandler import LSTModel
from LanguageDataset import LanguageDataset, my_collate_fn
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss as CEL
from torch.optim import Adam

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
# To force CPU execution, assign torch.device("cpu") to `device` instead.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Sentence-boundary markers handed to the dataset.
start_token, end_token = '<BoS>', '<EoS>'

# English/Italian sentence pairs read from the TSV file.
dataset = LanguageDataset(
    data_path="Data/eng_ita.tsv",
    start_token=start_token,
    end_token=end_token,
)

# `vocabulary_size` is the English vocabulary size; `embedding_size` is the
# embedding width passed to both the encoder and the decoder.
vocabulary_size = dataset.eng_voc_size
embedding_size = 256

# Training criterion.
loss_func = CEL()

# Integer ids of the boundary tokens, looked up in `dataset.from_ita`
# (presumably the Italian token-to-index mapping -- confirm in LanguageDataset).
end_index = dataset.from_ita[end_token]
start_index = dataset.from_ita[start_token]

# Batches of 32 samples, assembled by the project's collate function.
trainloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=my_collate_fn,
)

print("English vocabulary size: {}".format(vocabulary_size))

English vocabulary size: 157


In [3]:
# Encoder: sized by `vocabulary_size`, i.e. the English vocabulary.
encoder = EncoderLSTM(
    vocabulary_size=vocabulary_size,
    embedding_size=embedding_size,
    num_layers=1,
    bidirectional=False,
)

# Decoder: sized by the Italian vocabulary exposed by the dataset; it uses the
# same embedding width, layer count and directionality as the encoder.
decoder = DecoderLSTM(
    vocabulary_size=dataset.ita_voc_size,
    embedding_size=embedding_size,
    num_layers=1,
    bidirectional=False,
)

In [4]:
# Bundle the two networks with their optimizers and the loss. Each network gets
# its own Adam optimizer, created with Adam's default hyper-parameters.
# NOTE(review): `device` is selected in the first cell but is not passed to any
# call visible in this notebook -- confirm that LSTModel handles device placement.
model = LSTModel(
    encoder=encoder,
    decoder=decoder,
    encoder_optimizer=Adam(encoder.parameters()),
    decoder_optimizer=Adam(decoder.parameters()),
    loss_function=loss_func,
    eos_token=end_index,
    bos_token=start_index,
)

In [5]:
# Train for a single epoch and report how long it took.
# The elapsed time is measured with time.perf_counter(), a monotonic clock
# intended for timing intervals; time.time() follows the system wall-clock and
# can jump if that clock is adjusted while training is running.
start_time = time.perf_counter()
model.train_model(
    trainloader,
    max_epochs=1,
    save_every_epochs=10,
    ckp_name='.',
)
print("Time required: {}".format(time.perf_counter() - start_time))

Completed epoch: 0, loss: 60.21760940551758
Time required: 7.196247339248657


In [7]:
# Fetch the sample at index 1 and unpack it into its English and Italian parts.
# Subscription is the idiomatic spelling of a __getitem__ call.
eng_sent, ita_sent = dataset[1]

In [8]:
# Display the English element of the pair unpacked above
# (an integer tensor in the recorded output).
eng_sent

tensor([ 71,  68, 140,   6,   5,   6, 132, 136])

In [9]:
# Display the Italian element of the pair unpacked above
# (an integer tensor in the recorded output).
ita_sent

tensor([115,  11,  82, 151, 152, 185])

In [18]:
# Look up the value stored under "io" in `dataset.from_ita`, the same mapping
# that supplies the BoS/EoS indices (presumably token -> vocabulary index).
dataset.from_ita["io"]

79

In [20]:
# Call .eval() on the encoder first, then the decoder (presumably
# torch.nn.Module.eval, i.e. inference mode -- confirm in lstmHandler).
# A for-loop yields no cell output, so nothing is echoed.
for network in (model.encoder, model.decoder):
    network.eval()