In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np

from nerus import load_nerus
from torch.nn.utils.rnn import pad_sequence

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
class WordsTokenizer:
    def __init__(self, min_occurrences):
        self.idx2word = {0: '<PAD>', 1: '<BOS>', 2: '<EOS>', 3: '<UNK>'}
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3
        self.word2idx = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.min_occurrences = min_occurrences
    def __len__(self):
        return len(self.idx2word)

    def fit(self, sentences):
        word2occurrences = {}
        idx = len(self.idx2word)
        for s in sentences:
            for word in s:
                if word2occurrences.get(word) is None:
                    word2occurrences[word] = 1
                else: 
                    word2occurrences[word] += 1
        for word, n_occ in word2occurrences.items():
            if n_occ >= self.min_occurrences:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1
        return self
                
    def encode(self, tokens):
        ids = []
        for token in tokens:
            if self.word2idx.get(token) is not None:
                ids.append(self.word2idx[token])
            else:
                ids.append(self.word2idx["<UNK>"])
        return ids

    def __call__(self, tokens):
        return self.encode(tokens)

    def decode(self, ids):
        tokens = []
        for idx in ids:
            if self.idx2word.get(idx) is not None:
                tokens.append(self.idx2word[idx])
            else:
                tokens.append(self.unk_token_id)
        return tokens


class POSTokenizer:
    def __init__(self):
        self.idx2POS = {0: '<PAD>', 1: '<BOS>', 2: '<EOS>', 3: '<UNK>'}
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3
        self.POS2idx = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, '<UNK>': 3}
    def __len__(self):
        return len(self.idx2POS)
    
    def fit(self, sentences_poses):
        idx = len(self.idx2POS)
        for s in sentences_poses:
            for pos in s:
                if self.POS2idx.get(pos) is None:
                    self.POS2idx[pos] = idx
                    self.idx2POS[idx] = pos
                    idx += 1
        return self

    def encode(self, POSes):
        ids = []
        for pos in POSes:
            if self.POS2idx.get(pos) is not None:
                ids.append(self.POS2idx[pos])
            else:
                ids.append(self.POS2idx["<UNK>"])
        return ids

    def __call__(self, POSes):
        return self.encode(POSes)

    def decode(self, ids):
        POSes = []
        for idx in ids:
            if self.idx2POS.get(idx) is not None:
                tokens.append(self.idx2POS[idx])
            else:
                tokens.append(self.unk_token_id)
        return POSes


class NerusDataset(Dataset):
    def __init__(self, path="nerus_lenta.conllu.gz", 
                       min_occurrences=2, 
                       words_tokenizer = None, poses_tokenizer=None,
                       first_k=None, n_after_first_k=None):
        documents_generator = load_nerus(path)
        self.sentences_texts = []
        self.sentences_poses = []

        if first_k is not None:
            k = 0
            if n_after_first_k is not None:
                n = 0
        for doc in tqdm(documents_generator, leave=False):
            if n_after_first_k is None:
                for s in doc.sents:
                    self.sentences_texts.append([token.text for token in s.tokens])
                    self.sentences_poses.append([token.pos for token in s.tokens])
                if first_k is not None:
                    k += 1
                    if k >= first_k:
                        break
            else:
                if k < first_k:
                    k += 1
                    continue
                else:
                    for s in doc.sents:
                        self.sentences_texts.append([token.text for token in s.tokens])
                        self.sentences_poses.append([token.pos for token in s.tokens])
                    n += 1
                    if n >= n_after_first_k:
                        break
                
        if words_tokenizer is None:
            self.words_tokenizer = WordsTokenizer(min_occurrences)
        else:
            self.words_tokenizer = words_tokenizer.fit(self.sentences_texts)
        if poses_tokenizer is None:
            self.poses_tokenizer = POSTokenizer().fit(self.sentences_poses)
        else:
            self.poses_tokenizer = poses_tokenizer.fit(self.sentences_poses)

    def __len__(self):
        return len(self.sentences_texts)

    def __getitem__(self, idx):
        tokens_texts = self.sentences_texts[idx]
        tokens_poses = self.sentences_poses[idx]

        input_ids = [self.words_tokenizer.bos_token_id] + self.words_tokenizer(tokens_texts)
        target = self.poses_tokenizer(tokens_poses) + [self.poses_tokenizer.eos_token_id]

        input_ids_tensor = torch.tensor(input_ids)
        target_tensor = torch.tensor(target)

        return input_ids_tensor, target_tensor


class Collator:
    def __init__(self, token_pad_id, target_pad_id):
        self.token_pad_id = token_pad_id
        self.target_pad_id = target_pad_id
    def __call__(self, raw_batch):
        input_ids_batch = [elem[0] for elem in raw_batch]
        target_batch = [elem[1] for elem in raw_batch]

        input_ids_batch = pad_sequence(input_ids_batch, batch_first=True, padding_value=self.token_pad_id)
        target_batch = pad_sequence(target_batch, batch_first=True, padding_value=self.target_pad_id)
        return input_ids_batch, target_batch

In [12]:
class Encoder(nn.Module):
    def __init__(self, 
                 input_dict_size,
                 output_dict_size,
                 hidden_dim=128,
                 embedding_dim=64,
                 n_layers=1, 
                 batch_norm=False,
                 dropout=0):      
        super().__init__()

        self.embedding = nn.Embedding(input_dict_size, embedding_dim)
        self.n_layers = n_layers
        self.lstm = nn.LSTM(embedding_dim, 
                            hidden_size=hidden_dim//2, 
                            num_layers=n_layers,
                            batch_first=True,
                            dropout=dropout)
        
        self.linear = nn.Linear(hidden_dim//2, output_dict_size)
        self.batch_norm = nn.BatchNorm1d(output_dict_size) if batch_norm else False
        
    def forward(self, x):
        
        x = self.embedding(x)
 
        max_len = x.shape[1]
        logps = []
        
        i = 0
        h = None
        c = None
        while len(logps) < max_len:
            cur_token_emb = (x[:,i].to(device)).unsqueeze(1)
            i += 1
            output, (h, c) = self.lstm(cur_token_emb, (h, c) if h is not None and c is not None else None)
            output = self.linear(output)
            if self.batch_norm:
                output = torch.permute(self.batch_norm(torch.permute(output, (0, 2, 1))), (0, 2,1))
            next_logp = F.log_softmax(output, dim=-1)
            logps.append(next_logp)

        return torch.cat(logps, dim=1)

In [4]:
def train_epoch(train_loader, model, loss_function, optimizer, callback=None):
    epoch_loss = 0
    total = 0
    for it, (batch_of_x, batch_of_y) in enumerate(tqdm(train_loader, leave=False)):
        # batch_of_x, batch_of_y = model.tokenizer(batch_of_x, padding=True, truncation=True, max_length=max_length,
        #                          return_tensors="pt")["input_ids"], \
        #                          model.tokenizer(batch_of_y, padding=True, truncation=True, max_length=max_length,
        #                          return_tensors="pt")["input_ids"]
                            
        batch_loss = train_on_batch(model, batch_of_x, batch_of_y, optimizer, loss_function)
        
        if callback is not None:
            with torch.no_grad():
                callback(model, batch_loss)
            
        epoch_loss += batch_loss * len(batch_of_x)
        total += len(batch_of_x)
    
    return epoch_loss / total


def train_on_batch(model, x_batch, y_batch, optimizer, loss_function):
    model.train()
    optimizer.zero_grad()
    preds = model(x_batch.to(device))
    loss = loss_function(preds, y_batch.to(device))
    loss.backward()

    optimizer.step()
    return loss.detach().cpu().item()


def trainer(count_of_epoch, 
            batch_size, 
            loader,
            model, 
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):

    optima = optimizer(model.parameters(), lr=lr)
    
    iterations = tqdm(range(count_of_epoch), desc='epoch')
    iterations.set_postfix({'train epoch loss': np.nan})
    for it in iterations:
        
        
        epoch_loss = train_epoch(train_loader=loader, 
                    model=model, 
                    loss_function=loss_function,
                    optimizer=optima, 
                    callback=callback)
        
        iterations.set_postfix({'train epoch loss': epoch_loss})


class Callback():
    def __init__(self, writer, test_loader, loss_function, delimeter=100, batch_size=64):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.loader = test_loader

    def forward(self, model, loss):
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)
        
        if self.step % self.delimeter == 0:
            
            pred = []
            real = []
            model.eval()
            with torch.no_grad():
                for it, (x_batch, y_batch) in enumerate(tqdm(self.loader, leave=False)):
                    # x_batch, y_batch = model.tokenizer(x_batch, padding=True, truncation=True, max_length=max_length,
                    #                  return_tensors="pt")["input_ids"], \
                    #                    model.tokenizer(y_batch, padding=True, truncation=True, max_length=max_length,
                    #                  return_tensors="pt")["input_ids"]
                    
                    x_batch = x_batch.to(device)
    
                    output = model(x_batch).detach()
    
                    pred.extend(torch.argmax(output, dim=-1).cpu().view(-1).tolist())
                    real.extend(y_batch.view(-1).tolist())
                    
                test_acc = np.mean(np.array(pred) == np.array(real))
                
                self.writer.add_scalar('Acc/test', test_acc, self.step)

          
    def __call__(self, model, loss):
        return self.forward(model, loss)

In [11]:
%load_ext tensorboard
%tensorboard --logdir ./ --port=6005

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [6]:
class LSTM_loss():
    def __init__(self, vocab_size, ignore_index):
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        
    def __call__(self, pred, target):
        
        pred_shifted = pred.contiguous().view(-1, self.vocab_size)
        target_shifted = target.contiguous().view(-1)
        
        return F.nll_loss(pred_shifted, target_shifted, 
                        ignore_index=self.ignore_index)

In [7]:
train_dataset = NerusDataset(min_occurrences=2, first_k=100000, n_after_first_k=None)
test_dataset = NerusDataset(min_occurrences=2, 
                            words_tokenizer=train_dataset.words_tokenizer, 
                            poses_tokenizer=train_dataset.poses_tokenizer,
                            first_k=100000, n_after_first_k=5000)
# first_k = k --- takes only first k elements of dataset if n_after_first_k is None
# if n_after_first_k is not None takes n element after first k (x_{k}, x_{k+1} .... x_{k+n-1})
# It seems impossible to use whole dataset in the task, so we would go for a small subset

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [8]:
loss_function = LSTM_loss(vocab_size=len(train_dataset.poses_tokenizer), 
                          ignore_index=train_dataset.poses_tokenizer.pad_token_id)

optimizer = torch.optim.Adam
lr = 1e-3
hidden_dims = [32,64,128,256,512]
batch_size = 1500
test_step_size = 100
n_epochs=4



train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=train_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=train_dataset.poses_tokenizer.pad_token_id))
        
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=test_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=test_dataset.poses_tokenizer.pad_token_id))

for hidden_dim in hidden_dims:

    model = Encoder(input_dict_size=len(train_dataset.words_tokenizer),
                    output_dict_size=len(train_dataset.poses_tokenizer),
                    hidden_dim=hidden_dim,
                    embedding_dim=64,
                    n_layers=1, 
                    batch_norm=False,
                    dropout=0).to(device)
    
    writer = SummaryWriter(log_dir=f'different_hidden_dims/{hidden_dim}')

    callback = Callback(writer, test_loader, loss_function, delimeter=test_step_size)

    trainer(count_of_epoch=n_epochs, 
            batch_size=batch_size, 
            loader=train_loader,
            model=model, 
            loss_function=loss_function,
            optimizer=optimizer,
            lr=lr,
            callback=callback)

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

Последовательное удвоение размера скрытого состояния дает свои плоды и метрика растет с разом, судя по графикам предел не достигнут, но время и ресурсы ограничены, поэтому проверять дальше не будем. Для остальных экспериментов используем промежуточный вариант hidden_dim = 128

Рассмотрим зависимость от числа слоев в LSTM

In [10]:
loss_function = LSTM_loss(vocab_size=len(train_dataset.poses_tokenizer), 
                          ignore_index=train_dataset.poses_tokenizer.pad_token_id)

optimizer = torch.optim.Adam
lr = 1e-3
n_layers = [1,2,3]
batch_size = 1500
test_step_size = 100
n_epochs=6



train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=train_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=train_dataset.poses_tokenizer.pad_token_id))
        
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=test_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=test_dataset.poses_tokenizer.pad_token_id))

for n_layer in n_layers:

    model = Encoder(input_dict_size=len(train_dataset.words_tokenizer),
                    output_dict_size=len(train_dataset.poses_tokenizer),
                    hidden_dim=128,
                    embedding_dim=64,
                    n_layers=n_layer, 
                    batch_norm=False,
                    dropout=0).to(device)
    
    writer = SummaryWriter(log_dir=f'different_number_of_layers/{n_layer}')

    callback = Callback(writer, test_loader, loss_function, delimeter=test_step_size)

    trainer(count_of_epoch=n_epochs, 
            batch_size=batch_size, 
            loader=train_loader,
            model=model, 
            loss_function=loss_function,
            optimizer=optimizer,
            lr=lr,
            callback=callback)

epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

Даже на небольшом hidden_dim и увеличенном числе эпох до 6, варианты с больше чем 1 слоем выступают не лучше бейзлайна, либо для задачи достаточно возможностей lstm с одним слоем, либо выбранная часть датасета слишком маленькая, чтобы почувствовать разницу, либо даже 6 эпох -- слишком мало чтобы увеличенная сеть обучилась лучше (проверять еще больше не будем, поскольку итак считается довольно долго)

Проверим зависимость от длины словаря. Единственный способ корректно ее менять -- изменение трешхолда для минимальной частоты встречаемости в датасете (токены с меньшей помечаем как unknown)

In [7]:
optimizer = torch.optim.Adam
lr = 1e-3
min_occurrences_options = [1,2,5,10,20]
batch_size = 3000
test_step_size = 50
n_epochs=6


for min_occurrences in min_occurrences_options:

    train_dataset = NerusDataset(min_occurrences=min_occurrences, first_k=100000, n_after_first_k=None)
    test_dataset = NerusDataset(min_occurrences=min_occurrences, 
                                words_tokenizer=train_dataset.words_tokenizer, 
                                poses_tokenizer=train_dataset.poses_tokenizer,
                                first_k=100000, n_after_first_k=5000)

    loss_function = LSTM_loss(vocab_size=len(train_dataset.poses_tokenizer), 
                          ignore_index=train_dataset.poses_tokenizer.pad_token_id)

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=train_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=train_dataset.poses_tokenizer.pad_token_id))
        
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, 
                              collate_fn=Collator(token_pad_id=test_dataset.words_tokenizer.pad_token_id,
                                                  target_pad_id=test_dataset.poses_tokenizer.pad_token_id))

    
    model = Encoder(input_dict_size=len(train_dataset.words_tokenizer),
                    output_dict_size=len(train_dataset.poses_tokenizer),
                    hidden_dim=128,
                    embedding_dim=64,
                    n_layers=1, 
                    batch_norm=False,
                    dropout=0).to(device)

    dict_size = len(train_dataset.words_tokenizer)
    print(dict_size)
    
    writer = SummaryWriter(log_dir=f'different_dict_sizes/{dict_size}')

    callback = Callback(writer, test_loader, loss_function, delimeter=test_step_size)

    trainer(count_of_epoch=n_epochs, 
            batch_size=batch_size, 
            loader=train_loader,
            model=model, 
            loss_function=loss_function,
            optimizer=optimizer,
            lr=lr,
            callback=callback)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

87798


epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

45192


epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

19562


epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

10563


epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

5531


epoch:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/373 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Лосс получается тем меньше, чем больше размер словаря, т.е. большой словарь, вероятно, не мешает нашей сети учиться, метрика при этом ведет себя соответствующе, за исключением самого маленького словаря -- для него наблюдается непонятный всплеск в середине обучения.

Наконец проверим помогает ли добавление batch norm-а

In [8]:
train_dataset = NerusDataset(min_occurrences=2, first_k=100000, n_after_first_k=None)
test_dataset = NerusDataset(min_occurrences=2, 
                            words_tokenizer=train_dataset.words_tokenizer, 
                            poses_tokenizer=train_dataset.poses_tokenizer,
                            first_k=100000, n_after_first_k=5000)
# first_k = k --- takes only first k elements of dataset if n_after_first_k is None
# if n_after_first_k is not None takes n element after first k (x_{k}, x_{k+1} .... x_{k+n-1})
# It seems impossible to use whole dataset in the task, so we would go for a small subset

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [13]:
loss_function = LSTM_loss(vocab_size=len(train_dataset.poses_tokenizer), 
                          ignore_index=train_dataset.poses_tokenizer.pad_token_id)

optimizer = torch.optim.Adam
lr = 1e-3
batch_norms = [True, False]
batch_size = 1500
test_step_size = 100
n_epochs=4


train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=train_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=train_dataset.poses_tokenizer.pad_token_id))
        
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, 
                          collate_fn=Collator(token_pad_id=test_dataset.words_tokenizer.pad_token_id,
                                              target_pad_id=test_dataset.poses_tokenizer.pad_token_id))

for batch_norm in batch_norms:

    model = Encoder(input_dict_size=len(train_dataset.words_tokenizer),
                    output_dict_size=len(train_dataset.poses_tokenizer),
                    hidden_dim=128,
                    embedding_dim=64,
                    n_layers=1, 
                    batch_norm=batch_norm,
                    dropout=0).to(device)
    
    writer = SummaryWriter(log_dir=f'batch_norms/{batch_norm}')

    callback = Callback(writer, test_loader, loss_function, delimeter=test_step_size)

    trainer(count_of_epoch=n_epochs, 
            batch_size=batch_size, 
            loader=train_loader,
            model=model, 
            loss_function=loss_function,
            optimizer=optimizer,
            lr=lr,
            callback=callback)

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/745 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

При включенном batch_norm лосс оказывается выше, однако метрика снова ведет себя странно. 