# Богданов Александр Иванович, Б05-003

## Распознавания именованных сущностей на основе fasttext

In [121]:
import numpy as np
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

import fasttext
import fasttext.util

from prettytable import PrettyTable

from nerus import load_nerus

In [122]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

### Вспомогательные функции

In [123]:
def get_sent_tags(docs, size=10000):
    list_of_sent = []
    list_of_tags = []
    for doc in tqdm(docs):
        for sent in doc.sents:
            list_of_sent_toks = []
            list_of_sent_tags = []
            for tok in sent.tokens:
                list_of_sent_toks.append(tok.text)
                list_of_sent_tags.append(tok.pos)
        list_of_sent.append(list_of_sent_toks)
        list_of_tags.append(list_of_sent_tags)
        if len(list_of_sent) > size:
            break
    return list_of_sent, list_of_tags

In [124]:
def pos_dict(list_of_tags, test_size=100):
    pos2idx = {'<PAD>' : 0}
    idx2pos = ['<PAD>']

    for tags in list_of_tags[:-test_size]:
        for word in tags:
            if word not in pos2idx:
                pos2idx[word] = len(idx2pos)
                idx2pos.append(word)
    return pos2idx, idx2pos

In [125]:
def word_dict(list_of_tags, test_size=100):
    word2idx = {'<PAD>': 0, '<UNK>': 1}
    idx2word = ['<PAD>', '<UNK>']
    
    for sent in list_of_sent[:-test_size]:
        for word in sent:
            if word not in word2idx:
                word2idx[word] = len(idx2word)
                idx2word.append(word)
                
    return word2idx, idx2word

In [126]:
def get_matrix_fasttext(tf):

    matrix_fasttext = []
    
    for i, w in enumerate(tqdm(ft.get_words(on_unicode_error='replace'))):
        matrix_fasttext.append(ft.get_word_vector(w))
        
    for w in ['<PAD>', '<UNK>', '<CLS>', '<SEP>']:
        matrix_fasttext.append(np.zeros_like(matrix_fasttext[-1]))
                
    return matrix_fasttext

In [127]:
class NerusDataset(Dataset):
    def __init__(self, list_of_sent, list_of_tags, word2idx, pos2idx, train=True, test_size=100):
        self.X = []
        self.y = []
        
        if train:
            for sent in list_of_sent[:-test_size]:
                data = []
                for word in sent:
                    data.append(word2idx.get(word, 0))
                self.X.append(data)
            
            for tags in list_of_tags[:-test_size]:
                data = []
                for word in tags:
                    data.append(pos2idx.get(word, 0))
                self.y.append(data)
        else:
            for sent in list_of_sent[-test_size:]:
                data = []
                for word in sent:
                    data.append(word2idx.get(word, 0))
                self.X.append(data)
                
            for tags in list_of_tags[-test_size:]:
                data = []
                for word in tags:
                    data.append(pos2idx.get(word, 0))
                self.y.append(data)
            
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return torch.Tensor(self.X[idx]), torch.Tensor(self.y[idx])

In [128]:
def collate_fn(data):
    X, Y = [], []
    for x, y in data:
        X.append(x)
        Y.append(y)
    x_batch = torch.zeros((len(X), max(list(map(len, X)))), dtype=torch.long)
    y_batch = torch.zeros((len(Y), max(list(map(len, Y)))), dtype=torch.long)
    
    for i, sent in enumerate(X):
        x_batch[i, :len(sent)] = sent

    for i, sent in enumerate(Y):
        y_batch[i, :len(sent)] = sent
        
    return x_batch, y_batch

In [137]:
def check(batch_size, dataset, model, loss_function, idx2word, idx2pos):
    
    model.eval()
    
    batch_generator = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
            
    test_acc = 0
    test_loss = 0
    count = 0
    for i, (x_batch, y_batch) in enumerate(batch_generator):
        x_batch = x_batch.to(model.device)
        y_batch = y_batch.to(model.device)
                
        mask = (y_batch != 0)
        count += mask.sum()
                
        output = model(x_batch)

        test_loss += loss_function(output.transpose(1,2), y_batch).cpu().item()*len(x_batch)
        test_acc += (torch.argmax(output, dim=-1).cpu() == y_batch)[mask].sum().item()
            
    test_loss /= len(dataset)
    test_acc /= count

    print(f'loss: {test_loss}, acc: {test_acc}')
    
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    x, y = next(iter(dataloader))
    x = x.to(device)
    y = y.to(device)

    outputs = model(x)
    
    one_x = x[0].cpu().numpy()
    one_y = y[0].cpu().numpy()
    one_output = outputs[0].argmax(dim=-1).cpu().numpy()

    words = [idx2word[idx] for idx in one_x]
    true_tags = [idx2pos[idx] for idx in one_y]
    pred_tags = [idx2pos[idx] for idx in one_output]

    table = PrettyTable(["Word", "True tag", "Predicted tag"])
    table.align["Word"], table.align["True tag"], table.align["Predicted tag"] = "l", "l", "l"

    for word, true_tag, pred_tag in zip(words, true_tags, pred_tags):
        if word != idx2word[word2idx['<PAD>']]:
            table.add_row([word, true_tag, pred_tag])

    print(table)

    return test_loss, test_acc

In [130]:
def train_on_batch(model, x_batch, y_batch, optimizer, loss_function):
    model.train()
    model.zero_grad()
    
    output = model(x_batch.to(device))
    
    loss = loss_function(output.transpose(1, 2), y_batch.to(device))
    loss.backward()

    optimizer.step()
    return loss.cpu().item()

In [131]:
def train_epoch(train_generator, model, loss_function, optimizer, callback = None):
    epoch_loss = 0
    total = 0
    for it, (batch_of_x, batch_of_y) in enumerate(train_generator):
        batch_loss = train_on_batch(model, batch_of_x.to(device), batch_of_y.to(device), optimizer, loss_function)
        train_generator.set_postfix({'train batch loss': batch_loss})
        
        if callback is not None:
            callback(model, batch_loss)
            
        epoch_loss += batch_loss*len(batch_of_x)
        total += len(batch_of_x)
    
    return epoch_loss/total

In [132]:
def trainer(count_of_epoch, 
            batch_size, 
            dataset,
            model, 
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):

    optima = optimizer(model.parameters(), lr=lr)
    
    iterations = tqdm(range(count_of_epoch), desc='epoch')
    iterations.set_postfix({'train epoch loss': np.nan})
    for it in iterations:
        batch_generator = tqdm(
            torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn), 
            leave=False, total=len(dataset)//batch_size+(len(dataset)%batch_size> 0))
        
        epoch_loss = train_epoch(
                    train_generator=batch_generator, 
                    model=model, 
                    loss_function=loss_function, 
                    optimizer=optima, 
                    callback=callback)
        
        iterations.set_postfix({'train epoch loss': epoch_loss})

In [133]:
class callback():
    def __init__(self, writer, dataset, loss_function, delimeter = 300, batch_size=64):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.dataset = dataset

    def forward(self, model, loss):
        model.eval()
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)
        
        if self.step % self.delimeter == 0:
            
            batch_generator = torch.utils.data.DataLoader(dataset=self.dataset, 
                                                          batch_size=self.batch_size,
                                                          collate_fn=collate_fn)
            
            test_acc = 0
            test_loss = 0
            count = 0
            for it, (x_batch, y_batch) in enumerate(batch_generator):
                x_batch = x_batch.to(model.device)
                y_batch = y_batch.to(model.device)
                
                mask = (y_batch != 0)
                count += mask.sum()
                
                output = model(x_batch)

                test_loss += self.loss_function(output.transpose(1,2), y_batch).cpu().item()*len(x_batch)

                test_acc += (torch.argmax(output, dim=-1).cpu() == y_batch)[mask].sum().item()
            
            test_loss /= len(self.dataset)
            test_acc /= count

            print(f"\t step={self.step}, train_loss={loss}, val_loss={test_loss}, val_acc={test_acc}")
            
            self.writer.add_scalar('LOSS/test', test_loss, self.step)
            self.writer.add_scalar('ACC/test', test_acc, self.step)
          
    def __call__(self, model, loss):
        return self.forward(model, loss)

## Модель

In [110]:
class RNN(torch.nn.Module):
    @property
    def device(self):
        return next(self.parameters()).device
        
    def __init__(self,
                 vocab_dim,
                 output_dim = 12,
                 emb_dim = 300, 
                 hidden_dim = 10,
                 num_layers = 3,
                 dropout = 0.7,
                 batch_norm = True):
        super(RNN, self).__init__()
        
        self.embedding = torch.nn.Embedding(vocab_dim, emb_dim, padding_idx=0)
        self.encoder = torch.nn.LSTM(emb_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)
        if batch_norm:
            self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)
        else:
            self.batch_norm = None
        self.linear = torch.nn.Linear(hidden_dim, output_dim)
        
    def forward(self, input):
        out = self.embedding(input)
        out, _ = self.encoder(out)
        if self.batch_norm is not None:
            out = self.batch_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.linear(out)
        return out

## Обработка данных

In [101]:
docs = load_nerus("nerus_lenta.conllu.gz")

In [140]:
ft = fasttext.load_model('cc.ru.300.bin')



In [141]:
list_of_sent, list_of_tags = get_sent_tags(docs, 10000)

0it [00:00, ?it/s]

In [142]:
pos2idx, idx2pos = pos_dict(list_of_tags)

In [143]:
matrix_fasttext = get_matrix_fasttext(ft)

  0%|          | 0/2000000 [00:00<?, ?it/s]

In [144]:
word2idx, idx2word = word_dict(ft)

In [145]:
train_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=True)
test_data = NerusDataset(list_of_sent, list_of_tags, word2idx, pos2idx, train=False)

## Подключим tensorboard

In [187]:
%load_ext tensorboard
%tensorboard --logdir tensorboard_1/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 60721), started 0:00:03 ago. (Use '!kill 60721' to kill it.)

## Обучение

In [188]:
loss_function = torch.nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam

In [189]:
model = RNN(vocab_dim=len(matrix_fasttext), output_dim=len(pos2idx))
model.embedding.weight.data.copy_(torch.tensor(matrix_fasttext))
for param in model.embedding.parameters():
    param.requires_grad = False
model.to(device)

RNN(
  (embedding): Embedding(2000004, 300, padding_idx=0)
  (encoder): LSTM(300, 10, num_layers=3, batch_first=True, dropout=0.7)
  (batch_norm): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear): Linear(in_features=10, out_features=18, bias=True)
)

In [190]:
writer = SummaryWriter(log_dir=f'tensorboard_1/experiment')
call = callback(writer, test_data, loss_function)
    
check(64, test_data, model, loss_function, idx2word, idx2pos)
trainer(count_of_epoch=20, 
        batch_size=64, 
        dataset=train_data,
        model=model, 
        loss_function=loss_function,
        optimizer = optimizer,
        callback=call)
check(64, test_data, model, loss_function, idx2word, idx2pos)

loss: 2.8875919914245607, acc: 0.0
+--------------------+----------+---------------+
| Word               | True tag | Predicted tag |
+--------------------+----------+---------------+
| Он                 | PRON     | <PAD>         |
| призвал            | VERB     | <PAD>         |
| бороться           | VERB     | <PAD>         |
| с                  | ADP      | <PAD>         |
| мздоимством        | NOUN     | <PAD>         |
| ,                  | PUNCT    | <PAD>         |
| нанося             | VERB     | <PAD>         |
| удары              | NOUN     | <PAD>         |
| «                  | PUNCT    | <PAD>         |
| и                  | CCONJ    | <PAD>         |
| по                 | ADP      | <PAD>         |
| тигру              | NOUN     | <PAD>         |
| ,                  | PUNCT    | <PAD>         |
| и                  | CCONJ    | <PAD>         |
| по                 | ADP      | <PAD>         |
| мухе               | NOUN     | <PAD>         |
| »            

epoch:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=300, train_loss=1.9143242835998535, val_loss=1.9298494005203246, val_acc=0.3870220184326172


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=600, train_loss=1.6940072774887085, val_loss=1.8164882040023804, val_acc=0.4154113531112671


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=900, train_loss=1.6826220750808716, val_loss=1.6774484634399414, val_acc=0.4779837727546692


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1200, train_loss=1.561474323272705, val_loss=1.6148735475540161, val_acc=0.5196987390518188


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1500, train_loss=1.5243782997131348, val_loss=1.5608360052108765, val_acc=0.5260718464851379


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=1800, train_loss=1.4919233322143555, val_loss=1.57381441116333, val_acc=0.5376592874526978


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=2100, train_loss=1.4440556764602661, val_loss=1.575224094390869, val_acc=0.5469292998313904


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=2400, train_loss=1.4258780479431152, val_loss=1.6800995445251465, val_acc=0.5521436929702759


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=2700, train_loss=1.3678654432296753, val_loss=1.692449049949646, val_acc=0.5463499426841736


  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

	 step=3000, train_loss=1.3479853868484497, val_loss=1.7745773029327392, val_acc=0.5469292998313904
loss: 1.8341864824295044, acc: 0.5498261451721191
+------------+----------+---------------+
| Word       | True tag | Predicted tag |
+------------+----------+---------------+
| Он         | PRON     | NOUN          |
| подлетел   | VERB     | NOUN          |
| к          | ADP      | ADP           |
| планете    | NOUN     | NOUN          |
| на         | ADP      | ADP           |
| расстоянии | NOUN     | NOUN          |
| около      | ADP      | NOUN          |
| 203        | NUM      | NOUN          |
| тысяч      | NOUN     | NOUN          |
| километров | NOUN     | NOUN          |
| (          | PUNCT    | PUNCT         |
| примерно   | ADV      | VERB          |
| половина   | NOUN     | NOUN          |
| расстояния | NOUN     | NOUN          |
| до         | ADP      | ADP           |
| Луны       | PROPN    | NOUN          |
| )          | PUNCT    | CCONJ         |
| .       

(1.8341864824295044, tensor(0.5498))

Качество классфикации достаточно хорошее - мы используем не очень большой датасет, не очень сложную архитектуру и небольшое количество эпох, мы видим, что FastText играет существенную роль: FastText создает эмбендинги слов, которые хорошо учитывают отношения между словами.