# Богданов Александр Иванович, Б05-003

## Анализ модели LSTM

In [300]:
from tqdm.notebook import tqdm
import numpy as np
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from prettytable import PrettyTable

from nerus import load_nerus

In [301]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

### Вспомогательные функции

In [302]:
def get_sent_tags(docs, size=10000):
    list_of_sent = []
    list_of_tags = []
    for doc in tqdm(docs):
        for sent in doc.sents:
            list_of_sent_toks = []
            list_of_sent_tags = []
            for tok in sent.tokens:
                list_of_sent_toks.append(tok.text)
                list_of_sent_tags.append(tok.pos)
        list_of_sent.append(list_of_sent_toks)
        list_of_tags.append(list_of_sent_tags)
        if len(list_of_sent) > size:
            break
    return list_of_sent, list_of_tags

In [303]:
def pos_dict(list_of_tags, test_size=100):
    pos2idx = {'<PAD>' : 0}
    idx2pos = ['<PAD>']

    for tags in list_of_tags[:-test_size]:
        for word in tags:
            if word not in pos2idx:
                pos2idx[word] = len(idx2pos)
                idx2pos.append(word)
    return pos2idx, idx2pos

In [304]:
def word_dict(list_of_tags, test_size=100):
    word2idx = {'<PAD>': 0, '<UNK>': 1}
    idx2word = ['<PAD>', '<UNK>']
    
    for sent in list_of_sent[:-test_size]:
        for word in sent:
            if word not in word2idx:
                word2idx[word] = len(idx2word)
                idx2word.append(word)
                
    return word2idx, idx2word

In [305]:
class NerusDataset(Dataset):
    def __init__(self, list_of_sent, list_of_tags, word2idx, pos2idx, train=True, test_size=100):
        self.X = []
        self.y = []
        
        if train:
            for sent in list_of_sent[:-test_size]:
                data = []
                for word in sent:
                    data.append(word2idx.get(word, 0))
                self.X.append(data)
            
            for tags in list_of_tags[:-test_size]:
                data = []
                for word in tags:
                    data.append(pos2idx.get(word, 0))
                self.y.append(data)
        else:
            for sent in list_of_sent[-test_size:]:
                data = []
                for word in sent:
                    data.append(word2idx.get(word, 0))
                self.X.append(data)
                
            for tags in list_of_tags[-test_size:]:
                data = []
                for word in tags:
                    data.append(pos2idx.get(word, 0))
                self.y.append(data)
            
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return torch.Tensor(self.X[idx]), torch.Tensor(self.y[idx])

In [306]:
def collate_fn(data):
    X, Y = [], []
    for x, y in data:
        X.append(x)
        Y.append(y)
    x_batch = torch.zeros((len(X), max(list(map(len, X)))), dtype=torch.long)
    y_batch = torch.zeros((len(Y), max(list(map(len, Y)))), dtype=torch.long)
    
    for i, sent in enumerate(X):
        x_batch[i, :len(sent)] = sent

    for i, sent in enumerate(Y):
        y_batch[i, :len(sent)] = sent
        
    return x_batch, y_batch

In [327]:
def check(batch_size, dataset, model, loss_function, idx2word, idx2pos):
    
    model.eval()
    
    batch_generator = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
            
    test_acc = 0
    test_loss = 0
    count = 0
    for it, (x_batch, y_batch) in enumerate(batch_generator):
        x_batch = x_batch.to(model.device)
        y_batch = y_batch.to(model.device)
                
        mask = (y_batch != 0)
        count += mask.sum()
                
        output = model(x_batch)

        test_loss += loss_function(output.transpose(1,2), y_batch).cpu().item()*len(x_batch)
        test_acc += (torch.argmax(output, dim=-1).cpu() == y_batch)[mask].sum().item()
            
    test_loss /= len(dataset)
    test_acc /= count

    print(f'loss: {test_loss}, acc: {test_acc}')
    
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    x, y = next(iter(dataloader))
    x = x.to(device)
    y = y.to(device)

    outputs = model(x)
    
    one_x = x[0].cpu().numpy()
    one_y = y[0].cpu().numpy()
    one_output = outputs[0].argmax(dim=-1).cpu().numpy()

    words = [idx2word[idx] for idx in one_x]
    true_tags = [idx2pos[idx] for idx in one_y]
    pred_tags = [idx2pos[idx] for idx in one_output]

    table = PrettyTable(["Word", "True tag", "Predicted tag"])
    table.align["Word"], table.align["True tag"], table.align["Predicted tag"] = "l", "l", "l"

    for word, true_tag, pred_tag in zip(words, true_tags, pred_tags):
        if word != idx2word[word2idx['<PAD>']]:
            table.add_row([word, true_tag, pred_tag])

    print(table)

    return test_loss, test_acc

In [308]:
def train_on_batch(model, x_batch, y_batch, optimizer, loss_function):
    model.train()
    model.zero_grad()
    
    output = model(x_batch.to(device))
    
    loss = loss_function(output.transpose(1, 2), y_batch.to(device))
    loss.backward()

    optimizer.step()
    return loss.cpu().item()

In [309]:
def train_epoch(train_generator, model, loss_function, optimizer, callback = None):
    epoch_loss = 0
    total = 0
    for it, (batch_of_x, batch_of_y) in enumerate(train_generator):
        batch_loss = train_on_batch(model, batch_of_x.to(device), batch_of_y.to(device), optimizer, loss_function)
        train_generator.set_postfix({'train batch loss': batch_loss})
        
        if callback is not None:
            callback(model, batch_loss)
            
        epoch_loss += batch_loss*len(batch_of_x)
        total += len(batch_of_x)
    
    return epoch_loss/total

In [310]:
def trainer(count_of_epoch, 
            batch_size, 
            dataset,
            model, 
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):

    optima = optimizer(model.parameters(), lr=lr)
    
    iterations = tqdm(range(count_of_epoch), desc='epoch')
    iterations.set_postfix({'train epoch loss': np.nan})
    for it in iterations:
        batch_generator = tqdm(
            torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn), 
            leave=False, total=len(dataset)//batch_size+(len(dataset)%batch_size> 0))
        
        epoch_loss = train_epoch(
                    train_generator=batch_generator, 
                    model=model, 
                    loss_function=loss_function, 
                    optimizer=optima, 
                    callback=callback)
        
        iterations.set_postfix({'train epoch loss': epoch_loss})

In [311]:
class callback():
    def __init__(self, writer, dataset, loss_function, delimeter = 300, batch_size=64):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.dataset = dataset

    def forward(self, model, loss):
        model.eval()
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)
        
        if self.step % self.delimeter == 0:
            
            batch_generator = torch.utils.data.DataLoader(dataset=self.dataset, 
                                                          batch_size=self.batch_size,
                                                          collate_fn=collate_fn)
            
            test_acc = 0
            test_loss = 0
            count = 0
            for it, (x_batch, y_batch) in enumerate(batch_generator):
                x_batch = x_batch.to(model.device)
                y_batch = y_batch.to(model.device)
                
                mask = (y_batch != 0)
                count += mask.sum()
                
                output = model(x_batch)

                test_loss += self.loss_function(output.transpose(1,2), y_batch).cpu().item()*len(x_batch)

                test_acc += (torch.argmax(output, dim=-1).cpu() == y_batch)[mask].sum().item()
            
            test_loss /= len(self.dataset)
            test_acc /= count

            print(f"\t step={self.step}, train_loss={loss}, val_loss={test_loss}, val_acc={test_acc}")
            
            self.writer.add_scalar('LOSS/test', test_loss, self.step)
            self.writer.add_scalar('ACC/test', test_acc, self.step)
          
    def __call__(self, model, loss):
        return self.forward(model, loss)

## Модель

In [312]:
class LSTM(torch.nn.Module):
    @property
    def device(self):
        return next(self.parameters()).device
        
    def __init__(self,
                 vocab_dim,
                 output_dim = 18,
                 emb_dim = 10, 
                 hidden_dim = 10,
                 num_layers = 3,
                 dropout = 0,
                 batch_norm = False,
                 bidirectional = False):
        super(LSTM, self).__init__()
        
        self.embedding = torch.nn.Embedding(vocab_dim, emb_dim, padding_idx=0)
        self.encoder = torch.nn.LSTM(emb_dim, hidden_dim, num_layers, bidirectional = bidirectional, batch_first=True)
        if batchnorm:
            self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)
        else:
            self.batch_norm = None
        self.dropout = torch.nn.Dropout(p=dropout)
        self.linear = torch.nn.Linear(hidden_dim, output_dim)
        
    def forward(self, input):
        out = self.embedding(input)
        out, _ = self.encoder(out)
        if self.batch_norm is not None:
            out = self.batch_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.dropout(out)
        out = self.linear(out)
        return out

## Подключим tensorboard

In [313]:
%load_ext tensorboard
%tensorboard --logdir tensorboard_2/

## Обработка данных

In [314]:
docs = load_nerus("nerus_lenta.conllu.gz")

In [315]:
list_of_sent, list_of_tags = get_sent_tags(docs, 10000)

0it [00:00, ?it/s]

In [316]:
pos2idx, idx2pos = pos_dict(list_of_tags)

In [317]:
word2idx, idx2word = word_dict(list_of_sent)

In [318]:
train_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=True)
test_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=False)

## Обучение

In [319]:
loss_function = torch.nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam

In [320]:
dim_list = [10, 20, 30]
num_layers_list = [3, 5, 7, 9]
dropout_list = [0, 0.3, 0.5]
batch_norm_list = [False, True]
len_dict_list = [50000, 100000, 150000]

In [321]:
for dim in dim_list:
    print(f'dim = {dim}')
    
    model = LSTM(vocab_dim=len(word2idx), emb_dim=dim, hidden_dim=dim, output_dim=len(pos2idx))
    model.to(device)
    
    writer = SummaryWriter(log_dir=f'tensorboard_2/dim_{dim}')
    call = callback(writer, test_data, loss_function)
    
    check(64, test_data, model, loss_function, idx2word, idx2pos)
    trainer(count_of_epoch=10, 
            batch_size=64, 
            dataset=train_data,
            model=model, 
            loss_function=loss_function,
            optimizer = optimizer,
            callback=call)
    check(64, test_data, model, loss_function, idx2word, idx2pos)

dim = 10
loss: 2.943786954879761, acc: 0.022016221657395363
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| В          | ADP      | ADV           |
| нем        | PRON     | ADV           |
| они        | PRON     | ADV           |
| опровергли | VERB     | ADV           |
| информацию | NOUN     | ADV           |
| о          | ADP      | ADV           |
| связи      | NOUN     | ADV           |
| с          | ADP      | ADV           |
| российской | ADJ      | ADV           |
| разведкой  | NOUN     | ADV           |
| и          | CCONJ    | ADV           |
| утверждали | VERB     | ADV           |
| ,          | PUNCT    | ADV           |
| что        | SCONJ    | ADV           |
| с          | ADP      | ADV           |
| визитом    | NOUN     | ADV           |
| и          | CCONJ    | ADV           |
| собирались | VERB     | ADV           |
| посмотреть | VERB     | ADV           |
| на         | A

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.168388605117798, val_loss=2.2419653129577637, val_acc=0.2665121555328369


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.192312240600586, val_loss=2.2246517753601074, val_acc=0.2885283827781677


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.150078058242798, val_loss=2.1914392280578614, val_acc=0.2891077697277069


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.6692229509353638, val_loss=1.6657492542266845, val_acc=0.4548088014125824


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.3005439043045044, val_loss=1.3656794786453248, val_acc=0.5330243110656738
loss: 1.3201122999191284, acc: 0.5643105506896973
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| По         | ADP      | ADP           |
| их         | DET      | PRON          |
| заявлениям | NOUN     | VERB          |
| ,          | PUNCT    | PUNCT         |
| твит       | NOUN     | VERB          |
| Маска      | PROPN    | VERB          |
| принес     | VERB     | NOUN          |
| им         | PRON     | VERB          |
| убыток     | NOUN     | VERB          |
| .          | PUNCT    | PUNCT         |
+------------+----------+---------------+
dim = 20
loss: 2.967445316314697, acc: 0.010428736917674541
+-----------------+----------+---------------+
| Word            | True tag | Predicted tag |
+-----------------+----------+---------------+
| По              | ADP      | PART          |
| предваритель

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.1371102333068848, val_loss=2.1378755283355715, val_acc=0.37833139300346375


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=1.642945647239685, val_loss=1.6544002532958983, val_acc=0.4258400797843933


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.2439254522323608, val_loss=1.297070050239563, val_acc=0.5202780961990356


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=0.9636591672897339, val_loss=0.9738593149185181, val_acc=0.692931592464447


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=0.8070133328437805, val_loss=0.8439490795135498, val_acc=0.7219003438949585
loss: 0.8280075001716614, acc: 0.725376546382904
+---------+----------+---------------+
| Word    | True tag | Predicted tag |
+---------+----------+---------------+
| В       | ADP      | ADP           |
| ответ   | NOUN     | NOUN          |
| он      | PRON     | PRON          |
| :       | PUNCT    | PUNCT         |
| «       | PUNCT    | PUNCT         |
| быть    | AUX      | CCONJ         |
| не      | PART     | CCONJ         |
| стыдно  | ADV      | VERB          |
| ,       | PUNCT    | PUNCT         |
| стыдно  | ADV      | VERB          |
| быть    | AUX      | CCONJ         |
| дешевым | ADJ      | VERB          |
| »       | PUNCT    | PUNCT         |
| .       | PUNCT    | PUNCT         |
+---------+----------+---------------+
dim = 30
loss: 2.880441274642944, acc: 0.010428736917674541
+---------------+----------+---------------+
| Word          | True tag | Predicted tag |

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=1.7502983808517456, val_loss=1.8108046531677247, val_acc=0.4889918863773346


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=1.2603060007095337, val_loss=1.2261375284194946, val_acc=0.6066048741340637


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=0.9872215986251831, val_loss=0.981963210105896, val_acc=0.6865584850311279


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=0.6944470405578613, val_loss=0.7773849034309387, val_acc=0.740440309047699


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=0.544668436050415, val_loss=0.638272910118103, val_acc=0.7972189784049988
loss: 0.6303366827964783, acc: 0.8070683479309082
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------------+----------+---------------+
| Группа      | NOUN     | NOUN          |
| вошла       | VERB     | VERB          |
| в           | ADP      | ADP           |
| Зал         | PROPN    | ADJ           |
| славы       | NOUN     | NOUN          |
| рок-н-ролла | NOUN     | NOUN          |
| в           | ADP      | ADP           |
| 1996        | ADJ      | ADJ           |
| году        | NOUN     | NOUN          |
| .           | PUNCT    | PUNCT         |
+-------------+----------+---------------+


Как мы видим - лучше всего использовать размер слоя 30.

In [322]:
for num_layers in num_layers_list:
    print(f'num_layers = {num_layers}')
    
    model = LSTM(vocab_dim=len(word2idx), num_layers=num_layers)
    model.to(device)
    
    writer = SummaryWriter(log_dir=f'tensorboard_2/num_layers_{num_layers}')
    call = callback(writer, test_data, loss_function)
    
    check(64, test_data, model, loss_function, idx2word, idx2pos)
    trainer(count_of_epoch=10, 
            batch_size=64, 
            dataset=train_data,
            model=model, 
            loss_function=loss_function,
            optimizer = optimizer,
            callback=call)
    check(64, test_data, model, loss_function, idx2word, idx2pos)

num_layers = 3
loss: 2.7688625144958494, acc: 0.26998841762542725
+--------------------+----------+---------------+
| Word               | True tag | Predicted tag |
+--------------------+----------+---------------+
| Он                 | PRON     | NOUN          |
| призвал            | VERB     | NOUN          |
| бороться           | VERB     | NOUN          |
| с                  | ADP      | NOUN          |
| мздоимством        | NOUN     | NOUN          |
| ,                  | PUNCT    | NOUN          |
| нанося             | VERB     | NOUN          |
| удары              | NOUN     | NOUN          |
| «                  | PUNCT    | NOUN          |
| и                  | CCONJ    | NOUN          |
| по                 | ADP      | NOUN          |
| тигру              | NOUN     | NOUN          |
| ,                  | PUNCT    | NOUN          |
| и                  | CCONJ    | NOUN          |
| по                 | ADP      | NOUN          |
| мухе               | NOUN     | 

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.179889440536499, val_loss=2.2437978076934812, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.1469616889953613, val_loss=2.222812490463257, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.1240622997283936, val_loss=2.1720971202850343, val_acc=0.272885262966156


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.748166799545288, val_loss=1.7818132877349853, val_acc=0.44090381264686584


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.5219806432724, val_loss=1.539350953102112, val_acc=0.5504055619239807
loss: 1.4960809230804444, acc: 0.5521436929702759
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------------+----------+---------------+
| В           | ADP      | ADP           |
| подписи     | NOUN     | NOUN          |
| к           | ADP      | ADP           |
| видео       | NOUN     | NOUN          |
| упоминается | VERB     | NOUN          |
| персонаж    | NOUN     | NOUN          |
| —           | PUNCT    | PUNCT         |
| из          | ADP      | ADP           |
| серии       | NOUN     | NOUN          |
| видеоигр    | NOUN     | NOUN          |
| .           | PUNCT    | PUNCT         |
+-------------+----------+---------------+
num_layers = 5
loss: 2.8956561279296875, acc: 0.0
+---------------------------------+----------+---------------+
| Word                            | True tag | Predicted tag |
+--------------------------------

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.2035202980041504, val_loss=2.2489081764221193, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.1978530883789062, val_loss=2.22392333984375, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.1919264793395996, val_loss=2.2092659664154053, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=2.168609619140625, val_loss=2.1984876537323, val_acc=0.2682502865791321


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=2.173926830291748, val_loss=2.189113063812256, val_acc=0.2885283827781677
loss: 2.18811842918396, acc: 0.2885283827781677
+----------------+----------+---------------+
| Word           | True tag | Predicted tag |
+----------------+----------+---------------+
| За             | ADP      | ADP           |
| день           | NOUN     | NOUN          |
| до             | ADP      | NOUN          |
| этого          | PRON     | NOUN          |
| игроки         | NOUN     | NOUN          |
| поставили      | VERB     | NOUN          |
| в              | ADP      | NOUN          |
| Instagram      | NOUN     | NOUN          |
| под            | ADP      | NOUN          |
| стихотворением | NOUN     | NOUN          |
| актера         | NOUN     | NOUN          |
| Дмитрия        | PROPN    | NOUN          |
| Назарова       | PROPN    | NOUN          |
| с              | ADP      | NOUN          |
| критикой       | NOUN     | NOUN          |
| наставника     | NOUN   

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.1824333667755127, val_loss=2.22796422958374, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.162968635559082, val_loss=2.2077189350128172, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.1626052856445312, val_loss=2.1993054294586183, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=2.165659189224243, val_loss=2.1894202041625976, val_acc=0.2885283827781677


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=2.1599009037017822, val_loss=2.1826791858673094, val_acc=0.2885283827781677
loss: 2.1845289993286134, acc: 0.2885283827781677
+-----------+----------+---------------+
| Word      | True tag | Predicted tag |
+-----------+----------+---------------+
| Они       | PRON     | ADP           |
| готовы    | ADJ      | NOUN          |
| даже      | PART     | NOUN          |
| оплатить  | VERB     | NOUN          |
| штраф     | NOUN     | NOUN          |
| в         | ADP      | NOUN          |
| 150       | NUM      | NOUN          |
| евро      | NOUN     | NOUN          |
| (         | PUNCT    | NOUN          |
| примерно  | ADV      | NOUN          |
| тысячи    | NOUN     | NOUN          |
| рублей    | NOUN     | NOUN          |
| )         | PUNCT    | NOUN          |
| ,         | PUNCT    | NOUN          |
| так       | ADV      | NOUN          |
| как       | SCONJ    | NOUN          |
| это       | PRON     | NOUN          |
| обойдется | VERB     | NOUN 

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.2419397830963135, val_loss=2.247304735183716, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.23158860206604, val_loss=2.2289364910125733, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.1356711387634277, val_loss=2.215140895843506, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=2.1517250537872314, val_loss=2.2059787845611574, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=2.1995115280151367, val_loss=2.1966159248352053, val_acc=0.2885283827781677
loss: 2.195254373550415, acc: 0.2885283827781677
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| По         | ADP      | ADP           |
| их         | DET      | NOUN          |
| заявлениям | NOUN     | NOUN          |
| ,          | PUNCT    | NOUN          |
| твит       | NOUN     | NOUN          |
| Маска      | PROPN    | NOUN          |
| принес     | VERB     | NOUN          |
| им         | PRON     | NOUN          |
| убыток     | NOUN     | NOUN          |
| .          | PUNCT    | NOUN          |
+------------+----------+---------------+


Как мы видим - лучше всего использовать 3 слоя.

In [323]:
for dropout in dropout_list:
    print(f'dropout = {dropout}')
    
    model = LSTM(vocab_dim=len(word2idx), dropout=dropout)
    model.to(device)
    
    writer = SummaryWriter(log_dir=f'tensorboard_2/dropout_{dropout}')
    call = callback(writer, test_data, loss_function)
    
    check(64, test_data, model, loss_function, idx2word, idx2pos)
    trainer(count_of_epoch=10, 
            batch_size=64, 
            dataset=train_data,
            model=model, 
            loss_function=loss_function,
            optimizer = optimizer,
            callback=call)
    check(64, test_data, model, loss_function, idx2word, idx2pos)

dropout = 0
loss: 2.934212760925293, acc: 0.0
+-----------+----------+---------------+
| Word      | True tag | Predicted tag |
+-----------+----------+---------------+
| Одни      | DET      | INTJ          |
| пользу    | NOUN     | INTJ          |
| ,         | PUNCT    | INTJ          |
| другие    | ADJ      | INTJ          |
| указывают | VERB     | INTJ          |
| на        | ADP      | INTJ          |
| вред      | NOUN     | INTJ          |
| .         | PUNCT    | INTJ          |
+-----------+----------+---------------+


epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.216475486755371, val_loss=2.240416326522827, val_acc=0.2665121555328369


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.157855987548828, val_loss=2.2154041862487794, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.176269292831421, val_loss=2.1756235122680665, val_acc=0.27462339401245117


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.668185830116272, val_loss=1.722159504890442, val_acc=0.5057937502861023


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.314639687538147, val_loss=1.3743450927734375, val_acc=0.5863267779350281
loss: 1.3361519002914428, acc: 0.5990729928016663
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------------+----------+---------------+
| —           | PUNCT    | PUNCT         |
| (           | PUNCT    | PUNCT         |
| 2012        | ADJ      | NOUN          |
| )           | PUNCT    | PUNCT         |
| и           | CCONJ    | CCONJ         |
| двукратный  | ADJ      | NOUN          |
| призер      | NOUN     | NOUN          |
| чемпионатов | NOUN     | NOUN          |
| Европы      | PROPN    | NOUN          |
| (           | PUNCT    | PUNCT         |
| 2007        | ADJ      | NOUN          |
| и           | CCONJ    | CCONJ         |
| 2011        | ADJ      | NOUN          |
| годы        | NOUN     | NOUN          |
| )           | PUNCT    | PUNCT         |
| ,           | PUNCT    | PUNCT         |
| а           | CCONJ    | CCONJ  

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.2386271953582764, val_loss=2.240007791519165, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=1.8485844135284424, val_loss=1.8651882314682007, val_acc=0.4258400797843933


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.7017019987106323, val_loss=1.705317406654358, val_acc=0.42873695492744446


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.5938231945037842, val_loss=1.568288631439209, val_acc=0.4316338300704956


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.4354887008666992, val_loss=1.414461793899536, val_acc=0.4907299876213074
loss: 1.38475266456604, acc: 0.5
+--------------+----------+---------------+
| Word         | True tag | Predicted tag |
+--------------+----------+---------------+
| В            | ADP      | NOUN          |
| течение      | NOUN     | NOUN          |
| осени        | NOUN     | NOUN          |
| разработчики | NOUN     | NOUN          |
| проведут     | VERB     | NOUN          |
| проекта      | NOUN     | NOUN          |
| ,            | PUNCT    | PUNCT         |
| на           | ADP      | NOUN          |
| которое      | PRON     | PRON          |
| можно        | ADV      | PRON          |
| на           | ADP      | NOUN          |
| сайте        | NOUN     | NOUN          |
| игры         | NOUN     | NOUN          |
| .            | PUNCT    | PUNCT         |
+--------------+----------+---------------+
dropout = 0.5
loss: 2.8985625553131102, acc: 0.03244495764374733
+----------

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.272543430328369, val_loss=2.2536212253570556, val_acc=0.26303592324256897


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.095494031906128, val_loss=1.9804879856109618, val_acc=0.3858632445335388


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.7339156866073608, val_loss=1.65535662651062, val_acc=0.4275782108306885


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.5890618562698364, val_loss=1.4597036361694335, val_acc=0.4600231647491455


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.450932502746582, val_loss=1.3395133876800538, val_acc=0.4600231647491455
loss: 1.3254704618453979, acc: 0.4594437777996063
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| По         | ADP      | PUNCT         |
| подсчетам  | NOUN     | NOUN          |
| «          | PUNCT    | PUNCT         |
| »          | PUNCT    | PUNCT         |
| ,          | PUNCT    | PUNCT         |
| к          | ADP      | PUNCT         |
| 2018-го    | ADJ      | NOUN          |
| в          | ADP      | PUNCT         |
| республике | NOUN     | NOUN          |
| погибли    | VERB     | NOUN          |
| 87         | NUM      | NOUN          |
| россиян    | NOUN     | NOUN          |
| .          | PUNCT    | PUNCT         |
+------------+----------+---------------+
dropout = 0.8
loss: 2.9008523654937743, acc: 0.03244495764374733
+--------------+----------+---------------+
| Word         | True tag | 

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.3407304286956787, val_loss=2.2257367324829103, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.154033899307251, val_loss=2.041973352432251, val_acc=0.2775202691555023


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.9993135929107666, val_loss=1.8684067440032959, val_acc=0.3800695240497589


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.850610613822937, val_loss=1.71686119556427, val_acc=0.3933951258659363


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.8044638633728027, val_loss=1.615329966545105, val_acc=0.39455386996269226
loss: 1.5980723190307617, acc: 0.39455386996269226
+---------------+----------+---------------+
| Word          | True tag | Predicted tag |
+---------------+----------+---------------+
| Во            | ADP      | ADP           |
| ФСИН          | PROPN    | NOUN          |
| пояснили      | VERB     | NOUN          |
| ,             | PUNCT    | NOUN          |
| что           | SCONJ    | ADP           |
| это           | PRON     | ADP           |
| было          | AUX      | ADP           |
| сделано       | VERB     | NOUN          |
| по            | ADP      | ADP           |
| правительства | NOUN     | NOUN          |
| .             | PUNCT    | NOUN          |
+---------------+----------+---------------+


Как мы видим - лучше всего использовать dropout c 0.5.

In [324]:
for batch_norm in batch_norm_list:
    print(f'batchnorm = {batch_norm}')
    
    model = LSTM(vocab_dim=len(word2idx), batch_norm=batch_norm)
    model.to(device)
    
    writer = SummaryWriter(log_dir=f'tensorboard_2/batch_norm_{batch_norm}')
    call = callback(writer, test_data, loss_function)
    
    check(64, test_data, model, loss_function, idx2word, idx2pos)
    trainer(count_of_epoch=10, 
            batch_size=64, 
            dataset=train_data,
            model=model, 
            loss_function=loss_function,
            optimizer = optimizer,
            callback=call)
    check(64, test_data, model, loss_function, idx2word, idx2pos)

batch_norm = False
loss: 2.8620294666290285, acc: 0.16164541244506836
+---------------+----------+---------------+
| Word          | True tag | Predicted tag |
+---------------+----------+---------------+
| В             | ADP      | PUNCT         |
| результате    | NOUN     | PUNCT         |
| землетрясения | NOUN     | PUNCT         |
| были          | AUX      | PUNCT         |
| разрушены     | VERB     | PUNCT         |
| тысячи        | NOUN     | PUNCT         |
| домов         | NOUN     | PUNCT         |
| ,             | PUNCT    | PUNCT         |
| а             | CCONJ    | PUNCT         |
| также         | ADV      | PUNCT         |
| отели         | NOUN     | PUNCT         |
| ,             | PUNCT    | PUNCT         |
| больницы      | NOUN     | PUNCT         |
| и             | CCONJ    | PUNCT         |
| торговые      | ADJ      | PUNCT         |
| центры        | NOUN     | PUNCT         |
| .             | PUNCT    | PUNCT         |
+---------------+----------+--

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=2.1636157035827637, val_loss=2.2413926887512208, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=2.1398491859436035, val_loss=2.2130838298797606, val_acc=0.26998841762542725


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=2.1514692306518555, val_loss=2.1872722721099853, val_acc=0.2885283827781677


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.7868026494979858, val_loss=1.819361572265625, val_acc=0.4484356641769409


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.7129496335983276, val_loss=1.7271662616729737, val_acc=0.45133253931999207
loss: 1.709066891670227, acc: 0.45307067036628723
+---------+----------+---------------+
| Word    | True tag | Predicted tag |
+---------+----------+---------------+
| В       | ADP      | ADP           |
| ответ   | NOUN     | NOUN          |
| он      | PRON     | NOUN          |
| :       | PUNCT    | PUNCT         |
| «       | PUNCT    | PUNCT         |
| быть    | AUX      | NOUN          |
| не      | PART     | NOUN          |
| стыдно  | ADV      | NOUN          |
| ,       | PUNCT    | PUNCT         |
| стыдно  | ADV      | NOUN          |
| быть    | AUX      | NOUN          |
| дешевым | ADJ      | NOUN          |
| »       | PUNCT    | PUNCT         |
| .       | PUNCT    | PUNCT         |
+---------+----------+---------------+
batch_norm = True
loss: 2.9311695384979246, acc: 0.0
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=1.7125821113586426, val_loss=1.6898076629638672, val_acc=0.49884122610092163


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=1.2985923290252686, val_loss=1.3564738988876344, val_acc=0.5701043009757996


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.0579735040664673, val_loss=1.086392731666565, val_acc=0.6460022926330566


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=0.8655945658683777, val_loss=0.883663227558136, val_acc=0.6969872117042542


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=0.7061347365379333, val_loss=0.7359313678741455, val_acc=0.7497103214263916
loss: 0.7095850706100464, acc: 0.7555040121078491
+-----------------+----------+---------------+
| Word            | True tag | Predicted tag |
+-----------------+----------+---------------+
| В               | ADP      | ADP           |
| Кремле          | PROPN    | NOUN          |
| соответствующие | ADJ      | NOUN          |
| инициативы      | NOUN     | NOUN          |
| комментировать  | VERB     | ADJ           |
| не              | PART     | PART          |
| стали           | VERB     | VERB          |
| .               | PUNCT    | PUNCT         |
+-----------------+----------+---------------+


BatchNorm надо использовать.

In [326]:
for len_dict in len_dict_list:
    print(f'len_dict = {len_dict}')
    
    list_of_sent, list_of_tags = get_sent_tags(docs, len_dict)
    pos2idx, idx2pos = pos_dict(list_of_tags)
    word2idx, idx2word = word_dict(list_of_sent)
    
    train_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=True)
    test_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=False)
    
    model = LSTM(vocab_dim=len(word2idx))
    model.to(device)
    
    writer = SummaryWriter(log_dir=f'tensorboard_2/len_dict_{len_dict}')
    call = callback(writer, test_data, loss_function)
    
    check(64, test_data, model, loss_function, idx2word, idx2pos)
    trainer(count_of_epoch=10, 
            batch_size=64, 
            dataset=train_data,
            model=model, 
            loss_function=loss_function,
            optimizer = optimizer,
            callback=call)
    check(64, test_data, model, loss_function, idx2word, idx2pos)

len_dict = 50000


0it [00:00, ?it/s]

loss: 2.751299066543579, acc: 0.11717922240495682
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------------+----------+---------------+
| Первые      | ADJ      | ADP           |
| пять        | NUM      | ADP           |
| миллиардов  | NOUN     | ADP           |
| были        | AUX      | ADP           |
| перечислены | VERB     | ADP           |
| в           | ADP      | ADP           |
| марте       | NOUN     | ADP           |
| 2015        | ADJ      | ADP           |
| года        | NOUN     | ADP           |
| ,           | PUNCT    | ADP           |
| еще         | ADV      | ADP           |
| 1,7         | NUM      | ADP           |
| миллиарда   | NOUN     | ADP           |
| —           | PUNCT    | ADP           |
| в           | ADP      | ADP           |
| августе     | NOUN     | ADP           |
| того        | DET      | ADP           |
| же          | PART     | ADP           |
| года        | NOUN     | ADP           |
| . 

epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/780 [00:00<?, ?it/s]

	 step=300, train_loss=2.17415714263916, val_loss=2.190752458572388, val_acc=0.30063629150390625
	 step=600, train_loss=1.9619183540344238, val_loss=1.9039065217971802, val_acc=0.4432661831378937


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=900, train_loss=1.588650107383728, val_loss=1.659011206626892, val_acc=0.4872746765613556
	 step=1200, train_loss=1.4322758913040161, val_loss=1.494058198928833, val_acc=0.5588547587394714
	 step=1500, train_loss=1.3380062580108643, val_loss=1.4257608461380005, val_acc=0.5657476186752319


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=1800, train_loss=1.2752432823181152, val_loss=1.3551241207122802, val_acc=0.5683987736701965
	 step=2100, train_loss=1.2933509349822998, val_loss=1.2783186435699463, val_acc=0.5874868035316467


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=2400, train_loss=1.1281224489212036, val_loss=1.1814544916152954, val_acc=0.6139979362487793
	 step=2700, train_loss=0.9686529636383057, val_loss=1.023742823600769, val_acc=0.6797455549240112
	 step=3000, train_loss=0.90737384557724, val_loss=0.9168824815750122, val_acc=0.7104984521865845


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=3300, train_loss=0.761817455291748, val_loss=0.8421183323860169, val_acc=0.7200424671173096
	 step=3600, train_loss=0.719936192035675, val_loss=0.7794430875778198, val_acc=0.7279958128929138
	 step=3900, train_loss=0.6839991807937622, val_loss=0.7244663834571838, val_acc=0.7444326877593994


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=4200, train_loss=0.6112229824066162, val_loss=0.6685533118247986, val_acc=0.7656416296958923
	 step=4500, train_loss=0.5870016813278198, val_loss=0.6246959161758423, val_acc=0.7841994166374207


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=4800, train_loss=0.5877983570098877, val_loss=0.5837391066551209, val_acc=0.80169677734375
	 step=5100, train_loss=0.526932954788208, val_loss=0.5497924852371215, val_acc=0.8335101008415222
	 step=5400, train_loss=0.40582510828971863, val_loss=0.5008601999282837, val_acc=0.8547190427780151


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=5700, train_loss=0.41881194710731506, val_loss=0.4740096700191498, val_acc=0.8653234839439392
	 step=6000, train_loss=0.37554383277893066, val_loss=0.44540046572685243, val_acc=0.8727465867996216


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=6300, train_loss=0.29447251558303833, val_loss=0.4266130793094635, val_acc=0.8796395063400269
	 step=6600, train_loss=0.28996631503105164, val_loss=0.41054497718811034, val_acc=0.8865323662757874
	 step=6900, train_loss=0.28703829646110535, val_loss=0.39320804715156554, val_acc=0.8955461978912354


  0%|          | 0/780 [00:00<?, ?it/s]

	 step=7200, train_loss=0.27149537205696106, val_loss=0.37940789461135865, val_acc=0.8960763812065125
	 step=7500, train_loss=0.2777436375617981, val_loss=0.36441257715225217, val_acc=0.9003182053565979
	 step=7800, train_loss=0.2550980746746063, val_loss=0.353090363740921, val_acc=0.9061506390571594
loss: 0.353090363740921, acc: 0.9061506390571594
+-------------+----------+---------------+
| Word        | True tag | Predicted tag |
+-------------+----------+---------------+
| В           | ADP      | ADP           |
| сентябре    | NOUN     | NOUN          |
| 2016        | ADJ      | ADJ           |
| года        | NOUN     | NOUN          |
| апартаменты | NOUN     | NOUN          |
| в           | ADP      | ADP           |
| комплексе   | NOUN     | NOUN          |
| «           | PUNCT    | PUNCT         |
| Федерация   | PROPN    | PROPN         |
| »           | PUNCT    | PUNCT         |
| были        | AUX      | AUX           |
| сданы       | VERB     | NOUN          |
| в 

0it [00:00, ?it/s]

loss: 2.9356502628326417, acc: 0.030534351244568825
+--------------+----------+---------------+
| Word         | True tag | Predicted tag |
+--------------+----------+---------------+
| Это          | PRON     | ADV           |
| позволило    | VERB     | ADV           |
| установить   | VERB     | ADV           |
| личности     | NOUN     | ADV           |
| ряда         | NOUN     | ADV           |
| членов       | NOUN     | ADV           |
| банды        | NOUN     | ADV           |
| и            | CCONJ    | ADV           |
| впоследствии | ADV      | ADV           |
| задержать    | VERB     | ADV           |
| их           | PRON     | ADV           |
| .            | PUNCT    | ADV           |
+--------------+----------+---------------+


epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=300, train_loss=2.1823244094848633, val_loss=2.2516268062591553, val_acc=0.2707379162311554
	 step=600, train_loss=2.1538970470428467, val_loss=2.2316928482055665, val_acc=0.2707379162311554
	 step=900, train_loss=2.108912706375122, val_loss=2.197469720840454, val_acc=0.2941475808620453
	 step=1200, train_loss=1.7808514833450317, val_loss=1.858955512046814, val_acc=0.4396946430206299
	 step=1500, train_loss=1.6384191513061523, val_loss=1.6851587104797363, val_acc=0.46768447756767273


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=1800, train_loss=1.418698787689209, val_loss=1.4732735013961793, val_acc=0.5496183037757874
	 step=2100, train_loss=1.2784459590911865, val_loss=1.3512261056900023, val_acc=0.5511450171470642
	 step=2400, train_loss=1.2146025896072388, val_loss=1.2543799495697021, val_acc=0.5730279684066772
	 step=2700, train_loss=1.0866175889968872, val_loss=1.1594622087478639, val_acc=0.6213740706443787
	 step=3000, train_loss=0.9724507927894592, val_loss=1.0763964867591858, val_acc=0.6503816843032837


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=3300, train_loss=0.8940892815589905, val_loss=1.01496013879776, val_acc=0.6671755909919739
	 step=3600, train_loss=0.9253404140472412, val_loss=0.968252592086792, val_acc=0.6854962110519409
	 step=3900, train_loss=0.7495632767677307, val_loss=0.9061030411720276, val_acc=0.7160305380821228
	 step=4200, train_loss=0.7205329537391663, val_loss=0.8580114269256591, val_acc=0.7307888269424438
	 step=4500, train_loss=0.7301415205001831, val_loss=0.796689441204071, val_acc=0.7394402027130127


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=4800, train_loss=0.7610759139060974, val_loss=0.749108967781067, val_acc=0.7516539692878723
	 step=5100, train_loss=0.5817445516586304, val_loss=0.7095578980445861, val_acc=0.7659032940864563
	 step=5400, train_loss=0.5627533197402954, val_loss=0.6792437887191772, val_acc=0.7781170606613159
	 step=5700, train_loss=0.503745436668396, val_loss=0.6506291913986206, val_acc=0.7898218631744385
	 step=6000, train_loss=0.560690701007843, val_loss=0.6188961100578309, val_acc=0.8091602921485901


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=6300, train_loss=0.6004912853240967, val_loss=0.6005349373817443, val_acc=0.8142493963241577
	 step=6600, train_loss=0.45451492071151733, val_loss=0.5773879337310791, val_acc=0.8234096765518188
	 step=6900, train_loss=0.4223775267601013, val_loss=0.5646361064910889, val_acc=0.8284987211227417
	 step=7200, train_loss=0.4593382775783539, val_loss=0.540513322353363, val_acc=0.8356233835220337
	 step=7500, train_loss=0.4541116952896118, val_loss=0.5271411275863648, val_acc=0.8381679654121399
	 step=7800, train_loss=0.38744598627090454, val_loss=0.5146540141105652, val_acc=0.8432570099830627


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=8100, train_loss=0.34807971119880676, val_loss=0.5069481921195984, val_acc=0.8463104367256165
	 step=8400, train_loss=0.3578006327152252, val_loss=0.4974575209617615, val_acc=0.854452908039093
	 step=8700, train_loss=0.3329041302204132, val_loss=0.48436217069625853, val_acc=0.8580152988433838
	 step=9000, train_loss=0.3280828595161438, val_loss=0.4688332140445709, val_acc=0.8636132478713989
	 step=9300, train_loss=0.2837768793106079, val_loss=0.46587601661682126, val_acc=0.8636132478713989


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=9600, train_loss=0.33569803833961487, val_loss=0.46052619218826296, val_acc=0.8671755790710449
	 step=9900, train_loss=0.2967661917209625, val_loss=0.4550191831588745, val_acc=0.8697201013565063
	 step=10200, train_loss=0.26131144165992737, val_loss=0.4493402302265167, val_acc=0.8748091459274292
	 step=10500, train_loss=0.34802499413490295, val_loss=0.4400095772743225, val_acc=0.8732824325561523
	 step=10800, train_loss=0.26617127656936646, val_loss=0.43323188066482543, val_acc=0.876335859298706


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=11100, train_loss=0.2656514346599579, val_loss=0.4358349144458771, val_acc=0.8798982501029968
	 step=11400, train_loss=0.2528670132160187, val_loss=0.42638030409812927, val_acc=0.8793892860412598
	 step=11700, train_loss=0.2781057059764862, val_loss=0.4259167671203613, val_acc=0.8788803815841675
	 step=12000, train_loss=0.2788260281085968, val_loss=0.42039448380470273, val_acc=0.8844783902168274
	 step=12300, train_loss=0.23661145567893982, val_loss=0.41798457980155945, val_acc=0.8885496258735657


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=12600, train_loss=0.23424425721168518, val_loss=0.41880964279174804, val_acc=0.889058530330658
	 step=12900, train_loss=0.3135918378829956, val_loss=0.4183088374137878, val_acc=0.8951653838157654
	 step=13200, train_loss=0.24026606976985931, val_loss=0.42003442406654357, val_acc=0.9007633328437805
	 step=13500, train_loss=0.2658415138721466, val_loss=0.4128185188770294, val_acc=0.9007633328437805
	 step=13800, train_loss=0.21317988634109497, val_loss=0.41215975284576417, val_acc=0.9043257236480713


  0%|          | 0/1561 [00:00<?, ?it/s]

	 step=14100, train_loss=0.22402851283550262, val_loss=0.40328992128372193, val_acc=0.9134860038757324
	 step=14400, train_loss=0.3067173659801483, val_loss=0.4127724552154541, val_acc=0.9119592905044556
	 step=14700, train_loss=0.2108595222234726, val_loss=0.4097630798816681, val_acc=0.910941481590271
	 step=15000, train_loss=0.2996213734149933, val_loss=0.39885209798812865, val_acc=0.9155216217041016
	 step=15300, train_loss=0.15275312960147858, val_loss=0.3988737893104553, val_acc=0.9175572395324707
	 step=15600, train_loss=0.2547042965888977, val_loss=0.4040169024467468, val_acc=0.9175572395324707
loss: 0.39784971356391907, acc: 0.9175572395324707
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| Также      | ADV      | ADV           |
| не         | PART     | PART          |
| исключили  | VERB     | VERB          |
| переход    | NOUN     | NOUN          |
| на         | ADP      | ADP           |
| от

0it [00:00, ?it/s]

loss: 2.962384204864502, acc: 0.02566683292388916
+-----------+----------+---------------+
| Word      | True tag | Predicted tag |
+-----------+----------+---------------+
| 20        | ADJ      | CCONJ         |
| июня      | NOUN     | CCONJ         |
| клуб      | NOUN     | CCONJ         |
| продлил   | VERB     | CCONJ         |
| контракты | NOUN     | CCONJ         |
| с         | ADP      | CCONJ         |
| и         | CCONJ    | CCONJ         |
| Андреем   | PROPN    | CCONJ         |
| .         | PUNCT    | CCONJ         |
+-----------+----------+---------------+


epoch:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=300, train_loss=2.257784843444824, val_loss=2.2403133964538573, val_acc=0.2747860848903656
	 step=600, train_loss=2.2116317749023438, val_loss=2.206220645904541, val_acc=0.2833417057991028
	 step=900, train_loss=2.1563665866851807, val_loss=2.1916823768615723, val_acc=0.2833417057991028
	 step=1200, train_loss=1.7541184425354004, val_loss=1.7222108221054078, val_acc=0.49370908737182617
	 step=1500, train_loss=1.3611618280410767, val_loss=1.386070213317871, val_acc=0.5455459952354431
	 step=1800, train_loss=1.152496337890625, val_loss=1.1596536684036254, val_acc=0.6114745736122131
	 step=2100, train_loss=1.0111852884292603, val_loss=1.017269582748413, val_acc=0.6597886085510254


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=2400, train_loss=0.963224470615387, val_loss=0.9259310626983642, val_acc=0.6995469927787781
	 step=2700, train_loss=0.851319432258606, val_loss=0.8624996209144592, val_acc=0.7141419053077698
	 step=3000, train_loss=0.7809246182441711, val_loss=0.8134115862846375, val_acc=0.7221941947937012
	 step=3300, train_loss=0.7074428796768188, val_loss=0.7739262676239014, val_acc=0.7297433018684387
	 step=3600, train_loss=0.7531043291091919, val_loss=0.7364250469207764, val_acc=0.7377956509590149
	 step=3900, train_loss=0.7079928517341614, val_loss=0.6941439723968506, val_acc=0.7629591822624207
	 step=4200, train_loss=0.6135386228561401, val_loss=0.6542793869972229, val_acc=0.7775540351867676
	 step=4500, train_loss=0.5951389670372009, val_loss=0.6179312348365784, val_acc=0.7936587333679199


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=4800, train_loss=0.5168314576148987, val_loss=0.5855531334877014, val_acc=0.8178157806396484
	 step=5100, train_loss=0.5363302826881409, val_loss=0.5582571637630462, val_acc=0.8268746137619019
	 step=5400, train_loss=0.49608683586120605, val_loss=0.534521062374115, val_acc=0.8359335064888
	 step=5700, train_loss=0.494312584400177, val_loss=0.5143003714084625, val_acc=0.8384498953819275
	 step=6000, train_loss=0.49757805466651917, val_loss=0.49438767790794375, val_acc=0.8485153317451477
	 step=6300, train_loss=0.44846123456954956, val_loss=0.4755078339576721, val_acc=0.8555610775947571
	 step=6600, train_loss=0.47129157185554504, val_loss=0.4630428671836853, val_acc=0.8570709228515625
	 step=6900, train_loss=0.4384761154651642, val_loss=0.4485400760173798, val_acc=0.8636134266853333


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=7200, train_loss=0.43151095509529114, val_loss=0.43683913230896, val_acc=0.8721690773963928
	 step=7500, train_loss=0.4183478355407715, val_loss=0.4243521988391876, val_acc=0.8772017359733582
	 step=7800, train_loss=0.420198917388916, val_loss=0.40940109848976136, val_acc=0.8857573866844177
	 step=8100, train_loss=0.40688467025756836, val_loss=0.39675777792930605, val_acc=0.8907901048660278
	 step=8400, train_loss=0.34447163343429565, val_loss=0.37642072558403017, val_acc=0.8902868032455444
	 step=8700, train_loss=0.4198586940765381, val_loss=0.36538257598876955, val_acc=0.8902868032455444
	 step=9000, train_loss=0.4005824625492096, val_loss=0.35147714614868164, val_acc=0.8922998905181885
	 step=9300, train_loss=0.35067030787467957, val_loss=0.34271239876747134, val_acc=0.8933064341545105


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=9600, train_loss=0.30766770243644714, val_loss=0.3348023498058319, val_acc=0.8958228230476379
	 step=9900, train_loss=0.27623578906059265, val_loss=0.3297548878192902, val_acc=0.8973326086997986
	 step=10200, train_loss=0.30880409479141235, val_loss=0.32071205854415896, val_acc=0.9003522396087646
	 step=10500, train_loss=0.2700585722923279, val_loss=0.3150857162475586, val_acc=0.9013587832450867
	 step=10800, train_loss=0.27922695875167847, val_loss=0.3085281229019165, val_acc=0.9023653268814087
	 step=11100, train_loss=0.26930177211761475, val_loss=0.2992644846439362, val_acc=0.9028685688972473
	 step=11400, train_loss=0.24136239290237427, val_loss=0.2933707535266876, val_acc=0.9068947434425354
	 step=11700, train_loss=0.28259706497192383, val_loss=0.2875398689508438, val_acc=0.9073980450630188


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=12000, train_loss=0.25612619519233704, val_loss=0.28611698746681213, val_acc=0.9089078307151794
	 step=12300, train_loss=0.2472347468137741, val_loss=0.2800625103712082, val_acc=0.9109209179878235
	 step=12600, train_loss=0.21175505220890045, val_loss=0.274473232626915, val_acc=0.9134373068809509
	 step=12900, train_loss=0.20728902518749237, val_loss=0.26839071691036226, val_acc=0.9139405488967896
	 step=13200, train_loss=0.26228174567222595, val_loss=0.2660785973072052, val_acc=0.916456937789917
	 step=13500, train_loss=0.22402453422546387, val_loss=0.2543609893321991, val_acc=0.9209863543510437
	 step=13800, train_loss=0.2963131368160248, val_loss=0.254135200381279, val_acc=0.9245092868804932


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=14100, train_loss=0.26085394620895386, val_loss=0.2528988665342331, val_acc=0.9265223741531372
	 step=14400, train_loss=0.1902899444103241, val_loss=0.24614856421947479, val_acc=0.9285354018211365
	 step=14700, train_loss=0.19649729132652283, val_loss=0.24339707016944886, val_acc=0.9300452470779419
	 step=15000, train_loss=0.18944920599460602, val_loss=0.23391113698482513, val_acc=0.933064877986908
	 step=15300, train_loss=0.13658350706100464, val_loss=0.2288960462808609, val_acc=0.9335681200027466
	 step=15600, train_loss=0.22064122557640076, val_loss=0.2270075821876526, val_acc=0.9335681200027466
	 step=15900, train_loss=0.20279881358146667, val_loss=0.22087429583072662, val_acc=0.9380975961685181
	 step=16200, train_loss=0.17937059700489044, val_loss=0.21448033630847932, val_acc=0.9401106834411621


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=16500, train_loss=0.18736568093299866, val_loss=0.2144860190153122, val_acc=0.9416204690933228
	 step=16800, train_loss=0.1712111085653305, val_loss=0.21607169270515442, val_acc=0.9391041398048401
	 step=17100, train_loss=0.18433357775211334, val_loss=0.20838802874088289, val_acc=0.9401106834411621
	 step=17400, train_loss=0.13598328828811646, val_loss=0.20756876826286316, val_acc=0.9426270127296448
	 step=17700, train_loss=0.1456078439950943, val_loss=0.20535694122314452, val_acc=0.9431303143501282
	 step=18000, train_loss=0.18662507832050323, val_loss=0.2000361567735672, val_acc=0.9436335563659668
	 step=18300, train_loss=0.17808927595615387, val_loss=0.19702111780643464, val_acc=0.9451434016227722
	 step=18600, train_loss=0.17649084329605103, val_loss=0.19480936706066132, val_acc=0.9426270127296448


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=18900, train_loss=0.16719451546669006, val_loss=0.19709946751594543, val_acc=0.9451434016227722
	 step=19200, train_loss=0.12367787957191467, val_loss=0.1954007351398468, val_acc=0.9451434016227722
	 step=19500, train_loss=0.13029004633426666, val_loss=0.1919304597377777, val_acc=0.9451434016227722
	 step=19800, train_loss=0.15529634058475494, val_loss=0.1858449923992157, val_acc=0.9471564888954163
	 step=20100, train_loss=0.10111507773399353, val_loss=0.18420552790164949, val_acc=0.9456466436386108
	 step=20400, train_loss=0.1354769915342331, val_loss=0.18298677265644073, val_acc=0.9456466436386108
	 step=20700, train_loss=0.11335772275924683, val_loss=0.18566490650177003, val_acc=0.9436335563659668
	 step=21000, train_loss=0.14709541201591492, val_loss=0.1803484845161438, val_acc=0.9441368579864502


  0%|          | 0/2343 [00:00<?, ?it/s]

	 step=21300, train_loss=0.15767942368984222, val_loss=0.19300721287727357, val_acc=0.9461499452590942
	 step=21600, train_loss=0.15034255385398865, val_loss=0.18673058211803437, val_acc=0.9431303143501282
	 step=21900, train_loss=0.1252322494983673, val_loss=0.18280743360519408, val_acc=0.9471564888954163
	 step=22200, train_loss=0.10544056445360184, val_loss=0.17784545242786406, val_acc=0.9466531872749329
	 step=22500, train_loss=0.11438729614019394, val_loss=0.17644255697727204, val_acc=0.9471564888954163
	 step=22800, train_loss=0.10694093257188797, val_loss=0.17822689235210418, val_acc=0.9461499452590942
	 step=23100, train_loss=0.10398882627487183, val_loss=0.18195416688919067, val_acc=0.9446401000022888
	 step=23400, train_loss=0.1125495582818985, val_loss=0.17482637882232666, val_acc=0.9486662745475769
loss: 0.16918574810028075, acc: 0.9496728181838989
+------------------------+----------+---------------+
| Word                   | True tag | Predicted tag |
+------------------

Чем больше словарь, тем лучше, но и на маленьком словаре сеть хорошо обучилась.

Были проведены эксперименты над моделью LSTM.

Как мы видим из графиков: размер слоев надо поставить побольше, главное не уйти в переобучение; количество слоев надо поставить поменьше, так как очень долго учится, dropout надо использовать с параметром 0.5; batchnorm себя хорошо показал; размер словаря надо брать как можно больше, так как увеличиваем размер выборки.