Сравнить LSTM, RNN и GRU на задаче предсказания части речи (качество предсказания, скорость обучения, время инференса модели)  
 *к первой зачаче добавить bidirectional

In [7]:
import torch
from time import perf_counter
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam

In [8]:
class DatasetSeq(Dataset):

    def __init__(self, data_dir='./', train_lang='en'):
        # open file
        with open(data_dir + train_lang + '.train', 'r',
                  encoding='utf-8') as f:
            train = f.read().split('\n\n')

        # delete extra tag markup
        train = [x for x in train if not '_ ' in x]
        # init vocabs of tokens for encoding { token:  id}
        self.target_vocab = {}  # {NOUN: 1, VERB: 2, ADP: 3, NOUN: 1, PUNCT: 4}
        self.word_vocab = {}  # {cat: 1, sat: 2, on: 3, mat: 4, '.': 5}
        self.char_vocab = {}  # {c: 1, a: 2, t: 3, ' ': 4, s: 5}

        # init encoded sequences lists (processed data)
        self.encoded_sequences = []
        self.encoded_targets = []
        self.encoded_char_sequences = []
        # n=1 because first value is padding
        n_word = 1
        n_target = 1
        n_char = 1
        for line in train:
            sequence = []
            target = []
            chars = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')

                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    for char in word:
                        if self.char_vocab.get(char) is None:
                            self.char_vocab[char] = n_char
                            n_char += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                    chars.append([self.char_vocab[char] for char in word])
            self.encoded_sequences.append(sequence)  # n_seq x words_in_seq
            self.encoded_targets.append(target)  # n_seq x words_in_seq
            # n_seq x words_in_seq x word_len
            self.encoded_char_sequences.append(chars)

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return {
            'data': self.encoded_sequences[index],  # words_in_seq
            'char': self.encoded_char_sequences[index],  # words_in_seq
            'target': self.encoded_targets[index],  # words_in_seq x word_len
        }

In [9]:
dataset = DatasetSeq()

In [10]:
def collate_fn(input_data):
    data = []
    chars = []
    targets = []
    data_len = len(input_data)
    max_len = 0
    for item in input_data:
        if len(item['data']) > max_len:
            max_len = len(item['data'])
        data.append(torch.as_tensor(item['data']))
        chars.append(item['char'])
        targets.append(torch.as_tensor(item['target']))
    chars_seq = [[torch.as_tensor([0]) for _ in range(data_len)]
                 for _ in range(max_len)]  # max_seq_len x batch_len
    for j in range(data_len):  # batch_len
        i = 0
        while i < len(chars[j]):  # max_seq_len
            # batch_len x seq_len x word_len
            chars_seq[i][j] = torch.as_tensor(chars[j][i])
            i += 1
    for i in range(max_len):
        chars_seq[i] = pad_sequence(chars_seq[i],
                                    batch_first=True,
                                    padding_value=0)
    data = pad_sequence(data, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    return {'data': data, 'chars': chars_seq, 'target': targets}

In [11]:
class SelectItem(nn.Module):

    def __init__(self, item_index):
        super().__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]

In [22]:
class CustomRNN(nn.Module):

    def __init__(self, vocab_size, emb_dim, hid_dim, n_classes, arch=nn.RNN, do_r=0.1):
        super().__init__()
        self.seq = nn.Sequential()
        self.seq.append(nn.Embedding(vocab_size, emb_dim))
        self.seq.append(arch(emb_dim, hid_dim, batch_first=True))
        self.seq.append(SelectItem(0))
        self.seq.append(nn.Dropout(do_r))
        self.seq.append(nn.Linear(hid_dim, n_classes))

    def forward(self, x):
        return self.seq(x)

In [34]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data_dict in enumerate(dataloader):
        words = data_dict['data'].to(device)
        targets = data_dict['target'].to(device).view(-1)

        # Compute prediction error
        pred = model(words).view(-1, n_classes)
        loss = loss_fn(pred, targets)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(words)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [39]:
# hyper params
VOCAB_SIZE = len(dataset.word_vocab) + 1
N_CLASSES = len(dataset.target_vocab) + 1
N_CHARS = len(dataset.char_vocab) + 1
EMB_DIM = 300
HID_DIM = 300
N_EPOCHS = 10
BATCH_SIZE = 128
SEED = 123
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [43]:
simple_rnn = CustomRNN(VOCAB_SIZE, EMB_DIM, HID_DIM, N_CLASSES, arch=nn.RNN)
gru_rnn = CustomRNN(VOCAB_SIZE, EMB_DIM, HID_DIM, N_CLASSES, arch=nn.GRU)
lstm_rnn = CustomRNN(VOCAB_SIZE, EMB_DIM, HID_DIM, N_CLASSES, arch=nn.LSTM)

In [44]:
simple_rnn_opt = Adam(simple_rnn.parameters())
gru_rnn_opt = Adam(gru_rnn.parameters())
lstm_rnn_opt = Adam(lstm_rnn.parameters())

In [45]:
loss_fn = nn.CrossEntropyLoss()

In [46]:
# Simple RNN model
torch.manual_seed(SEED)
start = perf_counter()
for e in range(N_EPOCHS):
    print(f"Epoch {e+1}\n-------------------------------")
    dataloader = DataLoader(
        dataset,
        BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
    )
    train(dataloader, simple_rnn, loss_fn, simple_rnn_opt)
end = perf_counter()
print(f'Done! Train time: {end - start:.4f}s')

Epoch 1
-------------------------------
loss: 2.840596  [    0/21235]
loss: 0.202978  [12800/21235]
Epoch 2
-------------------------------
loss: 0.200353  [    0/21235]
loss: 0.132252  [12800/21235]
Epoch 3
-------------------------------
loss: 0.176207  [    0/21235]
loss: 0.120550  [12800/21235]
Epoch 4
-------------------------------
loss: 0.139270  [    0/21235]
loss: 0.099176  [12800/21235]
Epoch 5
-------------------------------
loss: 0.075701  [    0/21235]
loss: 0.090879  [12800/21235]
Epoch 6
-------------------------------
loss: 0.067742  [    0/21235]
loss: 0.056539  [12800/21235]
Epoch 7
-------------------------------
loss: 0.046602  [    0/21235]
loss: 0.061081  [12800/21235]
Epoch 8
-------------------------------
loss: 0.040490  [    0/21235]
loss: 0.049335  [12800/21235]
Epoch 9
-------------------------------
loss: 0.032992  [    0/21235]
loss: 0.027958  [12800/21235]
Epoch 10
-------------------------------
loss: 0.038102  [    0/21235]
loss: 0.035124  [12800/21235]

In [47]:
# GRU RNN model
torch.manual_seed(SEED)
start = perf_counter()
for e in range(N_EPOCHS):
    print(f"Epoch {e+1}\n-------------------------------")
    dataloader = DataLoader(
        dataset,
        BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
    )
    train(dataloader, gru_rnn, loss_fn, gru_rnn_opt)
end = perf_counter()
print(f'Done! Train time: {end - start:.4f}s')

Epoch 1
-------------------------------
loss: 3.242490  [    0/21235]
loss: 0.207586  [12800/21235]
Epoch 2
-------------------------------
loss: 0.191505  [    0/21235]
loss: 0.121904  [12800/21235]
Epoch 3
-------------------------------
loss: 0.162709  [    0/21235]
loss: 0.107539  [12800/21235]
Epoch 4
-------------------------------
loss: 0.120855  [    0/21235]
loss: 0.092994  [12800/21235]
Epoch 5
-------------------------------
loss: 0.061491  [    0/21235]
loss: 0.079759  [12800/21235]
Epoch 6
-------------------------------
loss: 0.055441  [    0/21235]
loss: 0.049521  [12800/21235]
Epoch 7
-------------------------------
loss: 0.039002  [    0/21235]
loss: 0.052963  [12800/21235]
Epoch 8
-------------------------------
loss: 0.030167  [    0/21235]
loss: 0.038532  [12800/21235]
Epoch 9
-------------------------------
loss: 0.024477  [    0/21235]
loss: 0.022229  [12800/21235]
Epoch 10
-------------------------------
loss: 0.029689  [    0/21235]
loss: 0.025807  [12800/21235]

In [48]:
# LSTM RNN model
torch.manual_seed(SEED)
start = perf_counter()
for e in range(N_EPOCHS):
    print(f"Epoch {e+1}\n-------------------------------")
    dataloader = DataLoader(
        dataset,
        BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
    )
    train(dataloader, lstm_rnn, loss_fn, lstm_rnn_opt)
end = perf_counter()
print(f'Done! Train time: {end - start:.4f}s')

Epoch 1
-------------------------------
loss: 2.796016  [    0/21235]
loss: 0.226620  [12800/21235]
Epoch 2
-------------------------------
loss: 0.201842  [    0/21235]
loss: 0.123999  [12800/21235]
Epoch 3
-------------------------------
loss: 0.170860  [    0/21235]
loss: 0.110687  [12800/21235]
Epoch 4
-------------------------------
loss: 0.122155  [    0/21235]
loss: 0.091878  [12800/21235]
Epoch 5
-------------------------------
loss: 0.065292  [    0/21235]
loss: 0.083195  [12800/21235]
Epoch 6
-------------------------------
loss: 0.058239  [    0/21235]
loss: 0.049416  [12800/21235]
Epoch 7
-------------------------------
loss: 0.039523  [    0/21235]
loss: 0.049778  [12800/21235]
Epoch 8
-------------------------------
loss: 0.031447  [    0/21235]
loss: 0.041792  [12800/21235]
Epoch 9
-------------------------------
loss: 0.025541  [    0/21235]
loss: 0.022387  [12800/21235]
Epoch 10
-------------------------------
loss: 0.024966  [    0/21235]
loss: 0.024645  [12800/21235]

In [49]:
torch.save(simple_rnn.state_dict(), 'simple_rnn.pth')
torch.save(gru_rnn.state_dict(), 'gru_rnn.pth')
torch.save(lstm_rnn.state_dict(), 'lstm_rnn.pth')

In [65]:
phrase = 'Paul looked at her , caught by the odd savagery beneath her casual attitude'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

In [66]:
start = perf_counter()
with torch.no_grad():
    pred = simple_rnn(torch.tensor(tokens).unsqueeze(0).to(device))
    labels = torch.argmax(pred, dim=-1).squeeze().cpu().detach().tolist()
    end = perf_counter()

print('Simple RNN')
print(f'Inference time: {end - start:.4f}s')
target_labels = list(dataset.target_vocab.keys())
print([target_labels[l - 1] for l in labels])

Simple RNN
Inference time: 0.0036s
['PROPN', 'VERB', 'ADP', 'PRON', 'PUNCT', 'VERB', 'ADP', 'DET', 'ADJ', 'NOUN', 'ADJ', 'PRON', 'ADJ', 'NOUN']


In [67]:
start = perf_counter()
with torch.no_grad():
    pred = gru_rnn(torch.tensor(tokens).unsqueeze(0).to(device))
    labels = torch.argmax(pred, dim=-1).squeeze().cpu().detach().tolist()
    end = perf_counter()

print('GRU RNN')
print(f'Inference time: {end - start:.4f}s')
target_labels = list(dataset.target_vocab.keys())
print([target_labels[l - 1] for l in labels])

GRU RNN
Inference time: 0.0043s
['PROPN', 'VERB', 'ADP', 'PRON', 'PUNCT', 'VERB', 'ADP', 'DET', 'ADJ', 'NOUN', 'ADP', 'PRON', 'ADJ', 'NOUN']


In [68]:
start = perf_counter()
with torch.no_grad():
    pred = lstm_rnn(torch.tensor(tokens).unsqueeze(0).to(device))
    labels = torch.argmax(pred, dim=-1).squeeze().cpu().detach().tolist()
    end = perf_counter()

print('LSTM RNN')
print(f'Inference time: {end - start:.4f}s')
target_labels = list(dataset.target_vocab.keys())
print([target_labels[l - 1] for l in labels])

LSTM RNN
Inference time: 0.0044s
['PROPN', 'VERB', 'ADP', 'PRON', 'PUNCT', 'VERB', 'ADP', 'DET', 'ADJ', 'PRON', 'NOUN', 'PRON', 'ADJ', 'NOUN']


## Выводы:
Как видим все модели неплохо справились, однако со словом beneath справилась только GRU.<br>
GRU и LSTM по времени инференса почти не отличаются, обычная RNN работает чуть быстрее.<br>
С точки зрения качества предсказания на обучающей выборке обычная RNN в 1.5 хуже по функции потерь, однако она обучилась примерно за 8 минут, против 22 минут у GRU и 27 минут у LSTM.<br>
Если нет потребности в сильной точности, то можно ограничится простой RNN. В более сложных задачах наверное лучше использовать GRU и LSTM.