In [1]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')


In [3]:
# Сравнить LSTM, RNN и GRU на задаче предсказания части речи (качество предсказания, скорость обучения, время инференса модели)

### Краткие итоги

In [32]:
# Итого по RNN:
# Training 02:26

# Loss 0.010004318319261074
# inf time:
# CPU times: user 21.7 s, sys: 26.1 s, total: 47.9 s
# Wall time: 27.9 s

In [5]:
# Итого LSTM 
# Training 06:18
# Loss 0.0064738476648926735

# inf time:
# CPU times: user 1min 3s, sys: 1min 8s, total: 2min 11s
# Wall time: 2min 11s

In [6]:
# Итого по GRU
# Training 05:08  mm:ss
# Loss 0.006786808371543884

# inf:
# CPU times: user 56.7 s, sys: 56.1 s, total: 1min 52s
# Wall time: 1min 34s

In [7]:
# Самая быстрая получилась RNN, по лоссу почти в два раза хуже остальных сеток.
# Наименьший лосс у LSTM, но она и  работает дольше RNN.
# В целом время тренировки коррелирует с временем инференса.

# При проверки модели использовался весь обучающий датасет.
# Делить выборку пробовал, но с этим были проблемы из-за словаря. Решил оставить затею и вернуть, как было.

### Далее код

In [8]:
data_dir = './'
train_lang = 'en'

In [9]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class DatasetSeq(Dataset):
    def __init__(self, data_dir, train_lang='en'):
	#open file
        with open(data_dir + train_lang + '.train', 'r') as f:
            train = f.read().split('\n\n')

        # delete extra tag markup
        train = [x for x in train if not '_ ' in x]
	    #init vocabs of tokens for encoding {<str> token: <int> id}
        self.target_vocab = {'<pad>': 0} # {p: 1, a: 2, r: 3, pu: 4}
        self.word_vocab = {'<pad>': 0} # {cat: 1, sat: 2, on: 3, mat: 4, '.': 5}
        self.char_vocab = {'<pad>': 0} # {c: 1, a: 2, t: 3, ' ': 4, s: 5}
	    
        # Cat sat on mat. -> [1, 2, 3, 4, 5]
        # p    a  r  p pu -> [1, 2, 3, 1, 4]
        # chars  -> [1, 2, 3, 4, 5, 2, 3, 4]

	    #init encoded sequences lists (processed data)
        self.encoded_sequences = []
        self.encoded_targets = []
        self.encoded_char_sequences = []
        # n=1 because first value is padding
        n_word = 1
        n_target = 1
        n_char = 1
        for line in train:
            sequence = []
            target = []
            chars = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')

                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    for char in word:
                        if self.char_vocab.get(char) is None:
                            self.char_vocab[char] = n_char
                            n_char += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                    chars.append([self.char_vocab[char] for char in word])
            self.encoded_sequences.append(sequence)
            self.encoded_targets.append(target)
            self.encoded_char_sequences.append(chars)

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return {
            'data': self.encoded_sequences[index], # [1, 2, 3, 4, 6] len=5
            'char': self.encoded_char_sequences[index],# [[1,2,3], [4,5], [1,2], [2,6,5,4], []] len=5
            'target': self.encoded_targets[index], # [1, 2, 3, 4, 6] len=5
        }

In [10]:
dataset = DatasetSeq(data_dir)

In [11]:
def collate_fn(batch):
    data = []
    target = []
    for item in batch:
        data.append(torch.as_tensor(item['data']))
        target.append(torch.as_tensor(item['target']))
    data = pad_sequence(data, batch_first=True, padding_value=0)
    target = pad_sequence(target, batch_first=True, padding_value=0)

    return {'data': data, 'target': target}

In [12]:
class RNN_GRU(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
#         self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
    
    def forward(self, x):
        emb = self.word_emb(x) # B x T x Emb_dim
        hidden, _ = self.rnn(emb) # B x T x Hid, B x 1 x Hid
        pred = self.clf(self.do(hidden)) # B x T x N_classes

        return pred

In [13]:
#hyper params
vocab_size = len(dataset.word_vocab) + 1
n_classes = len(dataset.target_vocab) + 1
n_chars = len(dataset.char_vocab) + 1
#TODO try to use other model parameters
emb_dim = 256
hidden = 256
n_epochs = 10
# batch_size = 64
cuda_device = -1
batch_size = 100
device = 'mps' if cuda_device != -1 else 'cpu'

In [14]:
model = RNN_GRU(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [15]:
from tqdm import tqdm
for epoch in tqdm(range(n_epochs)):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                         batch['target'].to(device).view(-1), 
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
   
    torch.save(model.state_dict(), f'./rnn_chkpt_{epoch}.pth')

  0%|          | 0/10 [00:00<?, ?it/s]

epoch: 0, step: 0, loss: 3.0209927558898926
epoch: 0, step: 100, loss: 0.23971734941005707
epoch: 0, step: 200, loss: 0.19011838734149933


 10%|█         | 1/10 [00:29<04:24, 29.38s/it]

epoch: 1, step: 0, loss: 0.17262177169322968
epoch: 1, step: 100, loss: 0.18685826659202576
epoch: 1, step: 200, loss: 0.12672248482704163


 20%|██        | 2/10 [01:03<04:15, 31.94s/it]

epoch: 2, step: 0, loss: 0.10885640978813171
epoch: 2, step: 100, loss: 0.11688248068094254
epoch: 2, step: 200, loss: 0.1368888020515442


 30%|███       | 3/10 [01:33<03:38, 31.20s/it]

epoch: 3, step: 0, loss: 0.1278466135263443
epoch: 3, step: 100, loss: 0.09457961469888687
epoch: 3, step: 200, loss: 0.043349333107471466


 40%|████      | 4/10 [02:03<03:04, 30.74s/it]

epoch: 4, step: 0, loss: 0.08545960485935211
epoch: 4, step: 100, loss: 0.0912952646613121
epoch: 4, step: 200, loss: 0.06851591914892197


 50%|█████     | 5/10 [02:35<02:36, 31.21s/it]

epoch: 5, step: 0, loss: 0.05169614031910896
epoch: 5, step: 100, loss: 0.04664351046085358
epoch: 5, step: 200, loss: 0.04021913930773735


 60%|██████    | 6/10 [03:06<02:04, 31.23s/it]

epoch: 6, step: 0, loss: 0.04740612208843231
epoch: 6, step: 100, loss: 0.048272036015987396
epoch: 6, step: 200, loss: 0.04780229181051254


 70%|███████   | 7/10 [03:37<01:32, 30.99s/it]

epoch: 7, step: 0, loss: 0.04671379551291466
epoch: 7, step: 100, loss: 0.054097529500722885
epoch: 7, step: 200, loss: 0.03877849504351616


 80%|████████  | 8/10 [04:07<01:01, 30.78s/it]

epoch: 8, step: 0, loss: 0.03637346997857094
epoch: 8, step: 100, loss: 0.02860109694302082
epoch: 8, step: 200, loss: 0.021310096606612206


 90%|█████████ | 9/10 [04:38<00:30, 30.68s/it]

epoch: 9, step: 0, loss: 0.03170089051127434
epoch: 9, step: 100, loss: 0.02074115164577961
epoch: 9, step: 200, loss: 0.038910266011953354


100%|██████████| 10/10 [05:08<00:00, 30.83s/it]


In [16]:
# Итого на обучение GRU потрачено 05:08

In [17]:
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                         batch['target'].to(device).view(-1), 
                         )

In [18]:
%%time
dataloader = DataLoader(dataset, 
                            batch_size = len(dataset), 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
model.eval()
for batch in tqdm(dataloader):
    with torch.no_grad():
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                             batch['target'].to(device).view(-1), 
                             )
print(loss.item())

100%|██████████| 1/1 [01:34<00:00, 94.80s/it]

0.006786808371543884
CPU times: user 56.7 s, sys: 56.1 s, total: 1min 52s
Wall time: 1min 34s





In [19]:
# Итого по GRU
# Training 05:08  mm:ss
# Loss 0.006786808371543884

# inf:
# CPU times: user 56.7 s, sys: 56.1 s, total: 1min 52s
# Wall time: 1min 34s

# LSTM

In [20]:
class LSTM_RNN(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
#         self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
    
    def forward(self, x):
        emb = self.word_emb(x) # B x T x Emb_dim
        hidden, _ = self.rnn(emb) # B x T x Hid, B x 1 x Hid
        pred = self.clf(self.do(hidden)) # B x T x N_classes

        return pred

In [21]:
model = LSTM_RNN(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [22]:
from tqdm import tqdm
for epoch in tqdm(range(n_epochs)):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    
    for i, batch in enumerate(dataloader):
        optim.zero_grad()
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                         batch['target'].to(device).view(-1), 
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
   
    torch.save(model.state_dict(), f'./rnn_chkpt_{epoch}.pth')

  0%|          | 0/10 [00:00<?, ?it/s]

epoch: 0, step: 0, loss: 2.896737813949585
epoch: 0, step: 100, loss: 0.34550487995147705
epoch: 0, step: 200, loss: 0.22371971607208252


 10%|█         | 1/10 [00:37<05:41, 37.92s/it]

epoch: 1, step: 0, loss: 0.06949827075004578
epoch: 1, step: 100, loss: 0.14931996166706085
epoch: 1, step: 200, loss: 0.0991474837064743


 20%|██        | 2/10 [01:14<04:58, 37.34s/it]

epoch: 2, step: 0, loss: 0.1524089276790619
epoch: 2, step: 100, loss: 0.109224334359169
epoch: 2, step: 200, loss: 0.07955622673034668


 30%|███       | 3/10 [01:52<04:21, 37.29s/it]

epoch: 3, step: 0, loss: 0.03469686582684517
epoch: 3, step: 100, loss: 0.09225381165742874
epoch: 3, step: 200, loss: 0.09221889823675156


 40%|████      | 4/10 [02:29<03:44, 37.40s/it]

epoch: 4, step: 0, loss: 0.05210357531905174
epoch: 4, step: 100, loss: 0.06773122400045395
epoch: 4, step: 200, loss: 0.06038236245512962


 50%|█████     | 5/10 [03:07<03:07, 37.57s/it]

epoch: 5, step: 0, loss: 0.04614487662911415
epoch: 5, step: 100, loss: 0.03866646811366081
epoch: 5, step: 200, loss: 0.048889756202697754


 60%|██████    | 6/10 [03:44<02:29, 37.48s/it]

epoch: 6, step: 0, loss: 0.08125331997871399
epoch: 6, step: 100, loss: 0.050544243305921555
epoch: 6, step: 200, loss: 0.04475888982415199


 70%|███████   | 7/10 [04:21<01:52, 37.37s/it]

epoch: 7, step: 0, loss: 0.049557849764823914
epoch: 7, step: 100, loss: 0.03464404121041298
epoch: 7, step: 200, loss: 0.03848026320338249


 80%|████████  | 8/10 [04:59<01:14, 37.42s/it]

epoch: 8, step: 0, loss: 0.04805496707558632
epoch: 8, step: 100, loss: 0.026667052879929543
epoch: 8, step: 200, loss: 0.03146441653370857


 90%|█████████ | 9/10 [05:40<00:38, 38.50s/it]

epoch: 9, step: 0, loss: 0.020581014454364777
epoch: 9, step: 100, loss: 0.027861973270773888
epoch: 9, step: 200, loss: 0.03414291515946388


100%|██████████| 10/10 [06:18<00:00, 37.86s/it]


In [23]:
%%time
dataloader = DataLoader(dataset, 
                            batch_size = len(dataset), 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
model.eval()
for batch in tqdm(dataloader):
    with torch.no_grad():
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                             batch['target'].to(device).view(-1), 
                             )
print(loss.item())

100%|██████████| 1/1 [02:11<00:00, 131.15s/it]

0.0064738476648926735
CPU times: user 1min 3s, sys: 1min 8s, total: 2min 11s
Wall time: 2min 11s





In [24]:
# Итого LSTM 
# Training 06:18
# Loss 0.0064738476648926735

# inf time:
# CPU times: user 1min 3s, sys: 1min 8s, total: 2min 11s
# Wall time: 2min 11s

# RNN

In [25]:
class RNN(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
#         self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
    
    def forward(self, x):
        emb = self.word_emb(x) # B x T x Emb_dim
        hidden, _ = self.rnn(emb) # B x T x Hid, B x 1 x Hid
        pred = self.clf(self.do(hidden)) # B x T x N_classes

        return pred

In [26]:
model = RNN(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [27]:
from tqdm import tqdm
for epoch in tqdm(range(n_epochs)):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    
    for i, batch in enumerate(dataloader):
        optim.zero_grad()
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                         batch['target'].to(device).view(-1), 
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
   
    torch.save(model.state_dict(), f'./rnn_chkpt_{epoch}.pth')

  0%|          | 0/10 [00:00<?, ?it/s]

epoch: 0, step: 0, loss: 2.690293550491333
epoch: 0, step: 100, loss: 0.3794049620628357
epoch: 0, step: 200, loss: 0.2011912316083908


 10%|█         | 1/10 [00:14<02:11, 14.61s/it]

epoch: 1, step: 0, loss: 0.2103102207183838
epoch: 1, step: 100, loss: 0.21691042184829712
epoch: 1, step: 200, loss: 0.12616492807865143


 20%|██        | 2/10 [00:29<01:56, 14.53s/it]

epoch: 2, step: 0, loss: 0.14406153559684753
epoch: 2, step: 100, loss: 0.10701227188110352
epoch: 2, step: 200, loss: 0.13120998442173004


 30%|███       | 3/10 [00:43<01:42, 14.58s/it]

epoch: 3, step: 0, loss: 0.09255042672157288
epoch: 3, step: 100, loss: 0.09091448038816452
epoch: 3, step: 200, loss: 0.08240093290805817


 40%|████      | 4/10 [00:58<01:27, 14.58s/it]

epoch: 4, step: 0, loss: 0.08517985790967941
epoch: 4, step: 100, loss: 0.08245503902435303
epoch: 4, step: 200, loss: 0.10096476972103119


 50%|█████     | 5/10 [01:12<01:13, 14.60s/it]

epoch: 5, step: 0, loss: 0.06579200178384781
epoch: 5, step: 100, loss: 0.06601131707429886
epoch: 5, step: 200, loss: 0.053149737417697906


 60%|██████    | 6/10 [01:27<00:58, 14.58s/it]

epoch: 6, step: 0, loss: 0.04098718985915184
epoch: 6, step: 100, loss: 0.05772813409566879
epoch: 6, step: 200, loss: 0.05526456981897354


 70%|███████   | 7/10 [01:42<00:43, 14.61s/it]

epoch: 7, step: 0, loss: 0.0627102181315422
epoch: 7, step: 100, loss: 0.07079427689313889
epoch: 7, step: 200, loss: 0.052282970398664474


 80%|████████  | 8/10 [01:56<00:29, 14.57s/it]

epoch: 8, step: 0, loss: 0.04723954200744629
epoch: 8, step: 100, loss: 0.044466450810432434
epoch: 8, step: 200, loss: 0.03113018162548542


 90%|█████████ | 9/10 [02:11<00:14, 14.58s/it]

epoch: 9, step: 0, loss: 0.035367827862501144
epoch: 9, step: 100, loss: 0.03763251751661301
epoch: 9, step: 200, loss: 0.039136942476034164


100%|██████████| 10/10 [02:26<00:00, 14.60s/it]


In [28]:
%%time
dataloader = DataLoader(dataset, 
                            batch_size = len(dataset), 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
model.eval()
for batch in tqdm(dataloader):
    with torch.no_grad():
        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),
                             batch['target'].to(device).view(-1), 
                             )
print(loss.item())

100%|██████████| 1/1 [00:27<00:00, 27.91s/it]

0.010004318319261074
CPU times: user 21.7 s, sys: 26.1 s, total: 47.9 s
Wall time: 27.9 s





In [30]:
# Итого по RNN:
# Training 02:33

# Loss 0.010004318319261074
# inf time:
# CPU times: user 21.7 s, sys: 26.1 s, total: 47.9 s
# Wall time: 27.9 s