# ДЗ_13: Нейронные сети. Рекурентные сети

## Сравнить LSTM, RNN и GRU на задаче предсказания части речи (качество предсказания, скорость обучения, время инференса модели)

In [1]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [8]:
data_dir = ''
train_lang = 'en'

In [9]:
class DatasetSeq(Dataset):
    def __init__(self, data_dir, train_lang='en'):
	#open file
        with open(data_dir + train_lang + '.train', 'r') as f:
            train = f.read().split('\n\n') #разбиваем на предложения
        # delete extra tag markup
        train = [x for x in train if not '_ ' in x]
	    #init vocabs of tokens for encoding {<str> token: <int> id}
        # формируем словари токенов
        self.target_vocab = {'<pad>': 0} # {p: 1, a: 2, r: 3, pu: 4}
        self.word_vocab = {'<pad>': 0} # {cat: 1, sat: 2, on: 3, mat: 4, '.': 5}
        self.char_vocab = {'<pad>': 0} # {c: 1, a: 2, t: 3, ' ': 4, s: 5}
	    
        # Cat sat on mat. -> [1, 2, 3, 4, 5]
        # p    a  r  p pu -> [1, 2, 3, 1, 4]
        # chars  -> [1, 2, 3, 4, 5, 2, 3, 4]

	    #init encoded sequences lists (processed data)
        self.encoded_sequences = []
        self.encoded_targets = []
        self.encoded_char_sequences = []
        # n=1 because first value is padding
        n_word = 1
        n_target = 1
        n_char = 1
        for line in train:
            sequence = []
            target = []
            chars = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')

                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    for char in word:
                        if self.char_vocab.get(char) is None:
                            self.char_vocab[char] = n_char
                            n_char += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                    chars.append([self.char_vocab[char] for char in word])
            self.encoded_sequences.append(sequence)
            self.encoded_targets.append(target)
            self.encoded_char_sequences.append(chars)

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return {
            'data': self.encoded_sequences[index], # [1, 2, 3, 4, 6] len=5
            'char': self.encoded_char_sequences[index],# [[1,2,3], [4,5], [1,2], [2,6,5,4], []] len=5
            'target': self.encoded_targets[index], # [1, 2, 3, 4, 6] len=5
        }

In [10]:
dataset = DatasetSeq(data_dir)

In [11]:
def collate_fn(batch):
    data = []
    target = []
    for item in batch:
        data.append(torch.as_tensor(item['data']))
        target.append(torch.as_tensor(item['target']))
    data = pad_sequence(data, batch_first=True, padding_value=0)
    target = pad_sequence(target, batch_first=True, padding_value=0)

    return {'data': data, 'target': target}

In [12]:
class RNNPredictorV2(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes, RNN):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = RNN(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
    
    def forward(self, x):
        emb = self.word_emb(x) # B x T x Emb_dim
        hidden, _ = self.rnn(emb) # B x T x Hid, B x 1 x Hid # hidden всей последовательности и hidden последнего токена
        pred = self.clf(self.do(hidden)) # B x T x N_classes

        return pred

In [17]:
#hyper params
vocab_size = len(dataset.word_vocab) + 1
n_classes = len(dataset.target_vocab) + 1
n_chars = len(dataset.char_vocab) + 1
#TODO try to use other model parameters
emb_dim = 512
hidden = 256
n_epochs = 3
batch_size = 128
cuda_device = -1
device = f'cuda:{cuda_device}' if cuda_device != -1 else 'cpu'

In [18]:
def NNtrain (model):
    start = datetime.datetime.now()
    model.train()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(n_epochs):
        dataloader = DataLoader(dataset, 
                                batch_size, 
                                shuffle=True, 
                                collate_fn=collate_fn,
                                drop_last = True,
                                )
        for i, batch in enumerate(dataloader):
            optim.zero_grad()

            predict = model(batch['data'].to(device))
            loss = loss_func(predict.view(-1, n_classes),
                            batch['target'].to(device).view(-1), 
                            )
            loss.backward()
            optim.step()
            #if i % 100 == 0:
            #   print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
        
        torch.save(model.state_dict(), f'./rnn_chkpt_{epoch}.pth')
    print(f'train time: {datetime.datetime.now()-start}')
    print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')

In [19]:
def NNinference(model):
    start = datetime.datetime.now()
    #example
    phrase = 'He ran quickly after the red bus and caught it'
    words = phrase.split(' ')
    tokens = [dataset.word_vocab[w] for w in words]

    start = datetime.datetime.now()
    with torch.no_grad():
        model.eval()
        predict = model(torch.tensor(tokens).unsqueeze(0).to(device)) # 1 x T x N_classes
        labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
        end = datetime.datetime.now() - start

    target_labels = list(dataset.target_vocab.keys())
    print([target_labels[l] for l in labels])
    print(f'inferense time: {datetime.datetime.now()-start}')


'He ran quickly after the red bus and caught it'

In [20]:
for NN in [nn.GRU, nn.LSTM, nn.RNN]:
    print(NN)
    model = RNNPredictorV2(vocab_size, emb_dim, hidden, n_classes, NN).to(device)
    NNtrain(model)   
    NNinference(model)


<class 'torch.nn.modules.rnn.GRU'>
train time: 0:06:31.560606
epoch: 2, step: 164, loss: 0.07197506725788116
['PRON', 'VERB', 'ADV', 'ADP', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
inferense time: 0:00:00.015866
<class 'torch.nn.modules.rnn.LSTM'>
train time: 0:07:51.866777
epoch: 2, step: 164, loss: 0.11098306626081467
['PRON', 'VERB', 'ADV', 'SCONJ', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
inferense time: 0:00:00.004960
<class 'torch.nn.modules.rnn.RNN'>
train time: 0:04:09.242716
epoch: 2, step: 164, loss: 0.09951762109994888
['PRON', 'VERB', 'ADV', 'ADP', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
inferense time: 0:00:00.003800


Обучал модели на CPU поэтому сделал всего 3 эпохи.
По качеству:
1. GRU
2. RNN
3. LSTM

По скорости обучения:
1. RNN
2. GRU
3. LSTM

По скорости инференса:
1. RNN
2. LSTM
3. GRU

