In [35]:
import torch
from torch import optim
from torch import nn
from torch.utils.data import DataLoader, Dataset
import time

class DatasetSeq(Dataset):
    def __init__(self, train_lang='en'):
        with open(train_lang + '.train', 'r', encoding='utf-8') as f:
            train = f.read().split('\n\n')
        train = [x for x in train if not '_ ' in x]
        self.target_vocab = {}
        self.word_vocab = {}
        self.encoded_sequences = []
        self.encoded_targets = []
        n_word = 0
        n_target = 0
        
        for line in train:
            sequence = []
            target = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')
                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                
            self.encoded_sequences.append(sequence)
            self.encoded_targets.append(target)

    def __len__(self):
        return len(self.encoded_sequences)
    
    def __getitem__(self, index):
        return {
            'data': torch.tensor(self.encoded_sequences[index]),
            'target': torch.tensor(self.encoded_targets[index]),
        }
    
dataset = DatasetSeq()

In [36]:
class RNN(nn.Module):
    def __init__(self, 
                 word_vocab_len: int, 
                 n_classes: int, 
                 rnn: object, 
                 bidirect: bool,
                 emb_size: int = 128, 
                 hidden_size: int = 128):
        super().__init__()
        self.word_emb = nn.Embedding(word_vocab_len, emb_size)
        self.rnn = rnn(input_size=emb_size, hidden_size=hidden_size, batch_first=True, bidirectional=bidirect)
        self.classifier = nn.Linear(hidden_size*(2 if bidirect else 1), n_classes)

    def forward(self, x):
        embedded = self.word_emb(x)
        out, _ = self.rnn(embedded)
        return self.classifier(out)

In [37]:
lr = 0.001
batch_size = 32
n_epochs = 3
vocab_len = len(dataset.word_vocab)
n_classes = len(dataset.target_vocab)

def collate_fn(data):
    return data[0]

dataloader = DataLoader(
                dataset=dataset, 
                collate_fn=collate_fn, 
                batch_size=batch_size, 
                shuffle=True, 
                drop_last=True)

In [38]:
models = []
RNN_list = [nn.RNN, nn.GRU, nn.LSTM]
bidirect_params = [True, False]
info = {}
train_time_key = 'train_time'
accuracy_key = 'accuracy'
inference_time_key = 'inference_time'

for rnn in RNN_list:
    info[rnn.__name__] = {}
    
    for bidirect in bidirect_params:
        info_bidirect = 'bidirect' if bidirect else 'no_bidirect'
        info[rnn.__name__][info_bidirect] = {}
             
        model = RNN(vocab_len, n_classes, rnn, bidirect)       
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.Adam(model.parameters(), lr=lr)
        
        start_train_time = time.time()
        
        for epoch in range(n_epochs):
            
            for step, batch in enumerate(dataloader):
                data = batch['data'].unsqueeze(0)
                target = batch['target']
                predict = model(data)
                predict = predict.view(-1, n_classes)
                loss = criterion(predict, target)
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()
        
        info[rnn.__name__][info_bidirect][train_time_key] = time.time()-start_train_time
        
        model_dic = {'name':rnn.__name__, 'bidirect':bidirect, 'model':model}
        models.append(model_dic)

In [39]:
for model_dic in models:
    info_bidirect = 'bidirect' if model_dic['bidirect'] else 'no_bidirect'
    model_name = model_dic["name"]
    
    model = model_dic['model']
    inferences_times = []
    
    def check_accuracy(loader, model):
        num_correct = 0
        num_samples = 0
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(loader):
                x = batch['data'].unsqueeze(0)
                y = batch['target']
                start_inference_time = time.time()
                scores = model(x).view(-1, n_classes)
                inference_time = time.time() - start_inference_time
                inferences_times.append(inference_time)
                _, predictions = scores.max(1)
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)
        model.train()
        return num_correct / num_samples
    
    accuracy = check_accuracy(dataloader, model)
    mean_inference_time = sum(inferences_times)/len(inferences_times)
    
    info[model_name][info_bidirect][accuracy_key] = accuracy
    info[model_name][info_bidirect][inference_time_key] = mean_inference_time

In [40]:
for model_name in info:
    for bidirect_or_not in info[model_name]:
        print(model_name, bidirect_or_not)
        curr_info = info[model_name][bidirect_or_not]
        print('train time = ', f"{curr_info[train_time_key]:.2f} sec")
        print('accuracy = ', f"{curr_info[accuracy_key]*100:.2f} %")
        print('inference time = ', f"{curr_info[inference_time_key]*1000:.2f} ms")
        print('')

RNN bidirect
train time =  76.89 sec
accuracy =  74.88 %
inference time =  0.79 ms

RNN no_bidirect
train time =  73.19 sec
accuracy =  71.68 %
inference time =  0.47 ms

GRU bidirect
train time =  90.62 sec
accuracy =  77.50 %
inference time =  1.67 ms

GRU no_bidirect
train time =  79.80 sec
accuracy =  74.55 %
inference time =  0.93 ms

LSTM bidirect
train time =  90.49 sec
accuracy =  77.62 %
inference time =  1.72 ms

LSTM no_bidirect
train time =  79.56 sec
accuracy =  74.41 %
inference time =  0.90 ms

