In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
import fasttext.util

from nerus import load_nerus
from torch.nn.utils.rnn import pad_sequence

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# To force CPU execution (e.g. when debugging CUDA errors), uncomment:
# device = "cpu"

In [2]:
# One-time download of the Russian fastText vectors; uncomment on the first run:
# fasttext.util.download_model('ru', if_exists='ignore')

In [3]:
class NERTokenizer:
    """Bidirectional mapping between NER tag strings and integer ids.

    Ids 0-3 are reserved for the special labels ``<PAD>``, ``<BOS>``,
    ``<EOS>`` and ``<UNK>``; tags seen by :meth:`fit` are numbered from 4
    onwards in order of first appearance.
    """

    def __init__(self):
        self.idx2NER = {0: '<PAD>', 1: '<BOS>', 2: '<EOS>', 3: '<UNK>'}
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3
        self.NER2idx = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, '<UNK>': 3}

    def __len__(self):
        """Return the vocabulary size, special labels included."""
        return len(self.idx2NER)

    def fit(self, sentences_ners):
        """Register every tag that is not yet in the vocabulary.

        Args:
            sentences_ners: iterable of sentences, each an iterable of tags.

        Returns:
            ``self``, so the call can be chained after the constructor.
        """
        idx = len(self.idx2NER)
        for s in sentences_ners:
            for ner in s:
                if ner not in self.NER2idx:
                    self.NER2idx[ner] = idx
                    self.idx2NER[idx] = ner
                    idx += 1
        return self

    def encode(self, NERs):
        """Map tags to ids; tags missing from the vocabulary map to ``<UNK>``."""
        unk_id = self.NER2idx["<UNK>"]
        return [self.NER2idx.get(ner, unk_id) for ner in NERs]

    def __call__(self, NERs):
        """Alias for :meth:`encode`."""
        return self.encode(NERs)

    def decode(self, ids):
        """Map ids back to tags; ids missing from the vocabulary map to ``'<UNK>'``.

        Elements exposing ``.item()`` (e.g. tensor or numpy scalars) are
        unwrapped first so they can be used as dictionary keys.
        """
        unk_label = self.idx2NER[self.unk_token_id]
        NERs = []
        for idx in ids:
            key = idx.item() if hasattr(idx, "item") else idx
            NERs.append(self.idx2NER.get(key, unk_label))
        return NERs


class NerusDataset(Dataset):
    """Sentence-level NER dataset built from a window of Nerus documents.

    Args:
        path: path handed to ``load_nerus``.
        min_occurrences: accepted for interface compatibility; not used by
            this class.
        ners_tokenizer: tag tokenizer to reuse (e.g. the one fitted on the
            training split). When ``None`` a new ``NERTokenizer`` is fitted
            on the tags collected here.
        first_k: when ``n_after_first_k`` is ``None``, stop after the first
            ``first_k`` documents; otherwise the number of leading documents
            to skip. ``None`` means "no limit" / "skip nothing".
        n_after_first_k: when given, stop after this many documents have
            been taken following the skipped ones.

    Each item is ``(tokens, target)`` where ``tokens`` is the list of token
    strings of one sentence and ``target`` is a tensor holding the encoded
    tags followed by the EOS id.
    """

    def __init__(self, path="nerus_lenta.conllu.gz", 
                       min_occurrences=2, 
                       ners_tokenizer=None,
                       first_k=None, n_after_first_k=None):
        documents_generator = load_nerus(path)
        self.sentences_texts = []
        self.sentences_ners = []

        # Express both modes as one [start, stop) window over document indices.
        if n_after_first_k is not None:
            # A missing first_k used to raise NameError; treat it as "skip nothing".
            start = max(first_k, 0) if first_k is not None else 0
            stop = start + n_after_first_k
        else:
            start = 0
            stop = first_k  # None -> read the whole corpus

        for doc_idx, doc in enumerate(tqdm(documents_generator, leave=False)):
            if doc_idx < start:
                continue
            for s in doc.sents:
                self.sentences_texts.append([token.text for token in s.tokens])
                self.sentences_ners.append([token.tag for token in s.tokens])
            # Checked after the document is consumed, so the generator is
            # never advanced past the last document that is needed.
            if stop is not None and doc_idx + 1 >= stop:
                break

        if ners_tokenizer is None:
            self.ners_tokenizer = NERTokenizer().fit(self.sentences_ners)
        else:
            self.ners_tokenizer = ners_tokenizer

    def __len__(self):
        """Return the number of collected sentences."""
        return len(self.sentences_texts)

    def __getitem__(self, idx):
        """Return ``(token strings, tensor of tag ids + EOS id)`` for sentence ``idx``."""
        tokens_texts = self.sentences_texts[idx]
        tokens_ners = self.sentences_ners[idx]

        target = self.ners_tokenizer(tokens_ners) + [self.ners_tokenizer.eos_token_id]

        return tokens_texts, torch.tensor(target)


class Collator:
    """Batch collation: embeds words with fastText and pads tag targets.

    Args:
        pad_token: word string used to right-pad shorter sentences before
            they are embedded.
        target_pad_id: id used to right-pad the target tag sequences.
    """

    def __init__(self, pad_token, target_pad_id):
        self.pad_token = pad_token
        self.target_pad_id = target_pad_id

        # Load the Russian fastText model and shrink its vectors to 100 dims.
        self.words_tokenizer = fasttext.load_model('cc.ru.300.bin')
        fasttext.util.reduce_model(self.words_tokenizer, 100)

    def __call__(self, raw_batch):
        """Turn ``[(words, target_tensor), ...]`` into ``(inputs, targets)`` tensors.

        Every sentence becomes ``[BOS] + words + padding`` and each of those
        strings is replaced by its word vector; targets are padded to the
        longest target in the batch.

        NOTE(review): because of the leading "[BOS]" vector, the embedding at
        position i belongs to word i-1, while NerusDataset targets are
        ``tags + [EOS]`` — confirm this one-step shift is intended.
        """
        sentences = [sample[0] for sample in raw_batch]
        targets = [sample[1] for sample in raw_batch]

        longest = max(len(words) for words in sentences)
        embed = self.words_tokenizer.get_word_vector

        embedded = []
        for words in sentences:
            filler = [self.pad_token] * (longest - len(words))
            padded_words = ["[BOS]"] + words + filler
            embedded.append([embed(word) for word in padded_words])

        inputs = torch.tensor(np.stack(embedded))
        padded_targets = pad_sequence(targets,
                                      batch_first=True,
                                      padding_value=self.target_pad_id)

        return inputs, padded_targets

In [4]:
class Encoder(nn.Module):
    """LSTM token tagger producing per-position log-probabilities over tags.

    Args:
        output_dict_size: number of output classes (tag vocabulary size).
        hidden_dim: the LSTM hidden size is ``hidden_dim // 2``.
        embedding_dim: size of each input word vector.
        n_layers: number of stacked LSTM layers.
        batch_norm: when truthy, a ``BatchNorm1d`` over the class dimension
            is applied to the logits of every time step.
        dropout: dropout probability passed to ``nn.LSTM``.
    """

    def __init__(self, 
                 output_dict_size,
                 hidden_dim=128,
                 embedding_dim=64,
                 n_layers=1, 
                 batch_norm=False,
                 dropout=0):      
        super().__init__()

        self.n_layers = n_layers
        self.lstm = nn.LSTM(embedding_dim, 
                            hidden_size=hidden_dim//2, 
                            num_layers=n_layers,
                            batch_first=True,
                            dropout=dropout)

        self.linear = nn.Linear(hidden_dim//2, output_dict_size)
        self.batch_norm = nn.BatchNorm1d(output_dict_size) if batch_norm else False

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, seq_len, output_dict_size)``.

        Args:
            x: tensor of shape ``(batch, seq_len, embedding_dim)``
                (the LSTM is constructed with ``batch_first=True``).
        """
        # Follow the module's own weights instead of a notebook-level global,
        # so the model works wherever it has been moved.
        x = x.to(self.linear.weight.device)

        # One call over the whole sequence replaces the per-token Python loop;
        # the initial hidden/cell state defaults to zeros, as before.
        output, _ = self.lstm(x)
        output = self.linear(output)

        if self.batch_norm:
            # Normalise each time step on its own, so batch statistics are
            # computed per step and running statistics are updated once per
            # step, in sequence order.
            output = torch.stack(
                [self.batch_norm(output[:, t]) for t in range(output.shape[1])],
                dim=1,
            )

        return F.log_softmax(output, dim=-1)

In [5]:
def train_epoch(train_loader, model, loss_function, optimizer, callback=None):
    """Run one pass over ``train_loader`` and return the sample-weighted mean loss.

    After every batch the optional ``callback(model, batch_loss)`` is invoked
    with gradient tracking disabled.
    """
    weighted_loss_sum = 0
    n_samples = 0
    for batch_of_x, batch_of_y in tqdm(train_loader, leave=False):
        batch_loss = train_on_batch(model, batch_of_x, batch_of_y, optimizer, loss_function)

        if callback is not None:
            with torch.no_grad():
                callback(model, batch_loss)

        # Weight each batch loss by its size so a short last batch does not skew the mean.
        batch_len = len(batch_of_x)
        weighted_loss_sum += batch_loss * batch_len
        n_samples += batch_len

    return weighted_loss_sum / n_samples


def train_on_batch(model, x_batch, y_batch, optimizer, loss_function):
    """Perform a single optimisation step and return the batch loss as a float.

    Both tensors are moved to the notebook-level ``device`` and the model is
    switched to training mode before the forward pass.
    """
    inputs = x_batch.to(device)
    targets = y_batch.to(device)

    model.train()
    optimizer.zero_grad()

    loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.step()

    return loss.item()


def trainer(count_of_epoch, 
            batch_size, 
            loader,
            model, 
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):
    """Train ``model`` for ``count_of_epoch`` epochs, showing the loss in a progress bar.

    ``optimizer`` is an optimizer class/factory; it is instantiated here as
    ``optimizer(model.parameters(), lr=lr)``. ``batch_size`` is accepted but
    not used by this function.
    """
    optima = optimizer(model.parameters(), lr=lr)

    postfix_key = 'train epoch loss'
    progress = tqdm(range(count_of_epoch), desc='epoch')
    progress.set_postfix({postfix_key: np.nan})  # placeholder until the first epoch ends

    for _ in progress:
        mean_loss = train_epoch(train_loader=loader,
                                model=model,
                                loss_function=loss_function,
                                optimizer=optima,
                                callback=callback)
        progress.set_postfix({postfix_key: mean_loss})


class Callback():
    """Training callback that logs the train loss and periodic test accuracy.

    Args:
        writer: object with ``add_scalar(tag, value, step)`` (e.g. a
            TensorBoard ``SummaryWriter``).
        test_loader: iterable of ``(x_batch, y_batch)`` evaluation batches.
        loss_function: stored on the instance; if it exposes an
            ``ignore_index`` attribute, target positions equal to it are
            excluded from the accuracy.
        delimeter: evaluate on ``test_loader`` every ``delimeter`` calls.
        batch_size: stored on the instance; not used by this class.
    """

    def __init__(self, writer, test_loader, loss_function, delimeter=100, batch_size=64):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.loader = test_loader

    def forward(self, model, loss):
        """Log ``loss`` for this step and, every ``delimeter`` steps, the test accuracy."""
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)

        if self.step % self.delimeter != 0:
            return

        # Padding positions are ignored by the loss, so predictions there are
        # arbitrary; counting them would distort the accuracy.
        ignore_index = getattr(self.loss_function, 'ignore_index', None)

        # Remember the current mode so evaluation does not leave the model in eval().
        was_training = getattr(model, 'training', True)
        model.eval()

        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for x_batch, y_batch in tqdm(self.loader, leave=False):
                    output = model(x_batch.to(device))

                    predicted = torch.argmax(output, dim=-1).cpu().reshape(-1)
                    expected = y_batch.cpu().reshape(-1)

                    if ignore_index is not None:
                        keep = expected != ignore_index
                        predicted = predicted[keep]
                        expected = expected[keep]

                    correct += (predicted == expected).sum().item()
                    total += expected.numel()
        finally:
            if hasattr(model, 'train'):
                model.train(was_training)

        # With nothing to score, report NaN rather than dividing by zero.
        test_acc = correct / total if total else float('nan')
        self.writer.add_scalar('Acc/test', test_acc, self.step)

    def __call__(self, model, loss):
        """Alias for :meth:`forward`."""
        return self.forward(model, loss)

In [6]:
class LSTM_loss():
    """Negative log-likelihood over all positions of a batch of sequences.

    Args:
        vocab_size: size of the last dimension of the predictions.
        ignore_index: target value whose positions do not contribute to the loss.
    """

    def __init__(self, vocab_size, ignore_index):
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def __call__(self, pred, target):
        """Flatten ``pred`` to ``(-1, vocab_size)`` and ``target`` to ``(-1,)``, then apply NLL."""
        flat_log_probs = pred.contiguous().view(-1, self.vocab_size)
        flat_targets = target.contiguous().view(-1)

        return F.nll_loss(
            input=flat_log_probs,
            target=flat_targets,
            ignore_index=self.ignore_index,
        )

In [7]:
# The full corpus is too large for this task, so only a small subset is used.
# With n_after_first_k=None, first_k=k keeps just the first k documents.
# With n_after_first_k=n, the first k documents are skipped and the next n
# are taken (x_k, x_{k+1}, ..., x_{k+n-1}).

# Training split: the first 100000 documents.
train_dataset = NerusDataset(
    min_occurrences=2,
    first_k=100000,
    n_after_first_k=None,
)

# Test split: the 4000 documents that follow, encoded with the training tag vocabulary.
test_dataset = NerusDataset(
    min_occurrences=2,
    ners_tokenizer=train_dataset.ners_tokenizer,
    first_k=100000,
    n_after_first_k=4000,
)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [8]:
# Words are padded with the "[PAD]" string; targets with the tokenizer's PAD id.
collator = Collator(
    "[PAD]",
    train_dataset.ners_tokenizer.pad_token_id,
)

In [11]:
# Load the TensorBoard notebook extension and start it on port 6001,
# reading logs from the current directory.
%load_ext tensorboard
%tensorboard --logdir ./ --port=6001

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [12]:
# Hyper-parameters of the run.
optimizer = torch.optim.Adam
lr = 3e-4
hidden_dim = 384
batch_size = 64
test_step_size = 1000  # evaluate on the test split every this many training steps
n_epochs = 4

# The loss skips padded target positions.
loss_function = LSTM_loss(
    vocab_size=len(train_dataset.ners_tokenizer),
    ignore_index=train_dataset.ners_tokenizer.pad_token_id,
)

# Both loaders share the collator; only the training data is shuffled.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collator,
)
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=collator,
)

# The input size follows the dimension of the fastText vectors held by the collator.
model = Encoder(
    output_dict_size=len(train_dataset.ners_tokenizer),
    hidden_dim=hidden_dim,
    embedding_dim=collator.words_tokenizer.get_dimension(),
    n_layers=3,
    batch_norm=False,
    dropout=0,
).to(device)

writer = SummaryWriter(log_dir='./run0')

callback = Callback(
    writer,
    test_loader,
    loss_function,
    delimeter=test_step_size,
)

trainer(
    count_of_epoch=n_epochs,
    batch_size=batch_size,
    loader=train_loader,
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    lr=lr,
    callback=callback,
)

epoch:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/17453 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/17453 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/17453 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/17453 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

  0%|          | 0/673 [00:00<?, ?it/s]

Модель учится медленно, но стабильно: лосс падает, точность на тесте растёт. С моделью большего размера код падает с непонятными ошибками CUDA, хотя запас по памяти большой — возможно, проблема в самой видеокарте. В идеале стоило бы взять предобученный RuBERT и дообучить голову для разметки токенов, так как обучать LSTM с нуля довольно долго и затратно по ресурсам, но в задании требуется именно LSTM, поэтому оставлю как есть.