# Named Entity Recognition с нуля

## Зависимости

In [1]:
%%writefile requirements.txt
pytorch_lightning
numpy
pandas
scikit-learn
transformers
razdel == 0.4.0
ipymarkup == 0.9.0
slovnet == 0.5.0
natasha == 1.4.0
fasttext
deeppavlov

Writing requirements.txt


In [2]:
!pip install --upgrade -r requirements.txt

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/81/d0/84a2f072cd407f93a1e50dff059656bce305f084e63a45cbbceb2fdb67b4/pytorch_lightning-1.1.0-py3-none-any.whl (665kB)
[K     |████████████████████████████████| 675kB 13.0MB/s 
[?25hCollecting numpy
[?25l  Downloading https://files.pythonhosted.org/packages/87/86/753182c9085ba4936c0076269a571613387cdb77ae2bf537448bfd63472c/numpy-1.19.4-cp36-cp36m-manylinux2010_x86_64.whl (14.5MB)
[K     |████████████████████████████████| 14.5MB 199kB/s 
[?25hRequirement already up-to-date: pandas in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 3)) (1.1.5)
Collecting scikit-learn
[?25l  Downloading https://files.pythonhosted.org/packages/5c/a1/273def87037a7fb010512bbc5901c31cfddfca8080bc63b42b26e3cc55b3/scikit_learn-0.23.2-cp36-cp36m-manylinux1_x86_64.whl (6.8MB)
[K     |████████████████████████████████| 6.8MB 52.9MB/s 
[?25hCollecting transformers
[?25l  Downloading https://files.pyt

## Датасет
Новый день - новый датасет!

In [3]:
!rm -f Persons-1000.zip
!wget -q http://ai-center.botik.ru/Airec/ai-resources/Persons-1000.zip
!unzip -qq Persons-1000.zip

In [4]:
!cat Persons-1000/collection/001/anno.markup.xml  

﻿<markup>
<entry>
<id>1</id>
<offset>308</offset>
<length>16</length>
<class>AAA_Estimate_Person</class>
<attribute>
<name>Canonical</name>
<value>ГРИГОРИЙ КАРАСИН</value>
</attribute>
</entry>
<entry>
<id>2</id>
<offset>387</offset>
<length>15</length>
<class>AAA_Estimate_Person</class>
<attribute>
<name>Canonical</name>
<value>ДЭНИЭЛ ФРИД</value>
</attribute>
</entry>
</markup>


Named Entity Recognition - распознавание именных сущностей. Выделяем в тексте спаны PER, LOC, ORG.

В случае с Persons-1000 только PER. Считаем датасет из XML.

In [5]:
import os
import xml.etree.ElementTree as ET
from ipymarkup import show_box_markup
from ipymarkup.palette import palette, BLUE, RED, GREEN

directory = "Persons-1000/collection/"

def read_text_with_markup(directory):
    MARKUP_FILE_NAME = "anno.markup.xml"
    TEXT_FILE_NAME = "text.txt"
    ENCODING = "windows-1251"
    ENTRY_TAG = "entry"
    START_POS_TAG = "offset"
    LENGTH_TAG = "length"

    markup_file_name = os.path.join(directory, MARKUP_FILE_NAME)
    text_file_name = os.path.join(directory, TEXT_FILE_NAME)
    with open(text_file_name, "r", encoding=ENCODING) as r:
        text = r.read()
    text = text.replace("\n", "\r\n")
    
    root = ET.parse(markup_file_name).getroot()
    spans = []
    for entry in root.findall(ENTRY_TAG):
        start_pos = int(entry.find(START_POS_TAG).text)
        end_pos = start_pos + int(entry.find(LENGTH_TAG).text)
        spans.append((start_pos, end_pos, "PER"))
    return text, spans

data = []
for sample_name in os.listdir(directory):
    sample_path = os.path.join(directory, sample_name)
    data.append(read_text_with_markup(sample_path))

ipymarkup - модуль для вывода NER разметки в ipynb

In [6]:
show_box_markup(data[0][0], data[0][1], palette=palette(PER=BLUE, ORG=RED, LOC=GREEN))

## BIO

BIO разметка: B - begin, I - inner, O - outer. Преобразуем задачу разметки спанов в задачу классификации каждого слова.

0 = O = outer, 1 = B = begin, 2 = I = inner

In [7]:
from enum import IntEnum
from collections import namedtuple

from razdel import tokenize

class Label(IntEnum):
    OUTER = 0
    BEGIN = 1
    INNER = 2

Sample = namedtuple("Sample", "text,tokens,spans,labels")

def text_span_to_sample(text, spans):
    labels = []
    tokens = list(tokenize(text))
    for token in tokens:
        label = Label.OUTER
        for span in spans:
            # Начало или часть какого-то спана
            # Не совсем корректно из-за возможных различий разметки и токенизации
            span_begin, span_end, tag = span
            if token.start == span_begin:
                label = Label.BEGIN
            elif token.start > span_begin and token.stop <= span_end:
                label = Label.INNER
        labels.append(label)
    # Проверяем инвариант отсутствия последовательных пар (0, 2)
    assert len([1 for i in range(len(labels) - 1) if (labels[i], labels[i+1]) == (0, 2)]) == 0
    return Sample(text, tokens, spans, labels)

samples = []
for text, spans in data:
    samples.append(text_span_to_sample(text, spans))

show_box_markup(samples[0].text, samples[0].spans, palette=palette(PER=BLUE, ORG=RED, LOC=GREEN))
print(samples[0].labels)
print(len(samples))

[<Label.OUTER: 0>, <Label.OUTER: 0>, <Label.BEGIN: 1>, <Label.INNER: 2>, <Label.INNER: 2>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.BEGIN: 1>, <Label.INNER: 2>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.BEGIN: 1>, <Label.INNER: 2>, <Label.INNER: 2>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OUTER: 0>, <Label.OU

Бьём на выборки

In [8]:
import random
random.shuffle(samples)

train = samples[:700]
val = samples[700:850]
test = samples[850:]

In [9]:
char_set = ["<pad>", "<unk>"] + list({ch for sample in train for token in sample.tokens for ch in token.text})
print(char_set)

['<pad>', '<unk>', 'Б', ';', '_', 'F', 'c', 'т', 'Ж', '>', '%', 'V', '"', 'v', 'и', 'Д', 'О', 'O', 'Й', '#', 'А', 'h', 'я', 'И', 'б', 'K', 'ь', '.', '*', '4', 'C', '\xad', 'T', 'a', ',', 'х', 'Р', 'X', '©', 'M', '+', '=', ']', 'r', 'i', 'л', 'f', '-', 'B', '!', 'N', '«', 'u', 'y', 'П', '7', 'З', 'k', '–', 'ж', 'ё', 'Ц', 'U', 'x', 'Е', 'S', 'Y', 'm', 'о', 'W', 'L', ')', 'У', '8', 'J', 'z', 'ъ', 'к', 'с', 'Н', '&', '—', '6', 'Э', 'I', 'е', ':', 'з', 'ч', 'Ь', '|', 'д', '5', 'Т', 'ф', 'Ъ', 'ю', 'Г', '“', 'Ы', 'ы', 'н', 'p', '3', '[', 'Ф', 'A', 'щ', 'Ш', 'Я', 'в', 'і', '?', 'd', 'j', 'D', 'q', 'e', 'Ч', 'o', 'а', 'H', '”', 'E', 'Z', 'К', 'Щ', 'Л', 'у', 'й', 't', 'g', 'P', 'Х', '»', 'ц', 'l', 'М', 'г', 'р', 's', 'п', 'R', 'Ю', '2', '№', 'С', 'м', 'b', '/', '1', '€', 'Ё', 'э', 'G', '$', "'", 'В', 'w', '0', 'n', '(', 'ш', '9', '…', '<']


Для каждого слова сохраняем его символьный состав, а в остальном старый добрый пайплайн

In [10]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, RandomSampler

class NerCharsDataset(Dataset):
    def __init__(self, samples, char_set, max_seq_len=500, max_char_seq_len=50):
        assert len(samples) != 0
        self.samples = []
        self.tokens = []
        self.texts = []
        for sample in samples:
            inputs = torch.zeros((max_seq_len, max_char_seq_len), dtype=torch.long)
            for token_num, token in enumerate(sample.tokens[:max_seq_len]):
                for ch_num, ch in enumerate(token.text[:max_char_seq_len]):
                    char_index = char_set.index(ch) if ch in char_set else char_set.index("<unk>")
                    inputs[token_num][ch_num] = char_index
            labels = torch.zeros((max_seq_len,), dtype=torch.long)
            input_labels = [int(i) for i in sample.labels[:max_seq_len]]
            labels[:len(input_labels)] = torch.LongTensor(input_labels)
            self.samples.append((torch.LongTensor(inputs), torch.LongTensor(labels)))
            self.tokens.append(sample.tokens[:max_seq_len])
            self.texts.append(sample.text)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        return self.samples[index]

    def get_tokens(self, index):
        return self.tokens[index]
    
    def get_text(self, index):
        return self.texts[index]


BATCH_SIZE = 32

train_data = NerCharsDataset(train, char_set)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=train_sampler)

val_data = NerCharsDataset(val, char_set)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

test_data = NerCharsDataset(test, char_set)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [11]:
test_data[0]

(tensor([[ 79, 120,  35,  ...,   0,   0,   0],
         [110,   0,   0,  ...,   0,   0,   0],
         [139,  68,  87,  ...,   0,   0,   0],
         ...,
         [ 20, 139, 147,  ...,   0,   0,   0],
         [ 27,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0,
         0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [12]:
for sample in train_loader:
    inputs, labels = sample
    print(inputs.size())
    print(labels.size())
    break

# inputs: batch_size x num_words x num_chars
# labels: batch_size x num_words

torch.Size([32, 500, 50])
torch.Size([32, 500])


In [13]:
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
 

class TemplateModel(LightningModule):
    def __init__(self):
        super().__init__()
        
        self.loss = nn.CrossEntropyLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
    
    def forward(self, inputs, labels):
        raise NotImplementedError("forward not implemented")
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]
    
    def training_step(self, batch, _):
        inputs, labels = batch
        loss, logits = self(inputs, labels)
        return loss
    
    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        self.valid_accuracy.update(logits, labels)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        self.test_accuracy.update(logits, labels)
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)

## Бесконтекстная модель

In [15]:
import torch
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

class SuperSimpleModel(TemplateModel):
    def __init__(self, char_set_size, char_embedding_dim=32, classes_count=3, char_max_seq_len=50):
        super().__init__()
        
        self.embeddings_layer = nn.Embedding(char_set_size, char_embedding_dim) # а -> [0.1, ... 0.3]
        self.out_layer = nn.Linear(char_max_seq_len * char_embedding_dim, classes_count)

    def forward(self, inputs, labels):
        projections = self.embeddings_layer.forward(inputs)
        projections = projections.reshape(projections.size(0), projections.size(1), -1) # <-
        logits = self.out_layer.forward(projections)
        logits = logits.transpose(1, 2)
        loss = self.loss(logits, labels)
        return loss, logits


super_simple_model = SuperSimpleModel(len(char_set))
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=5,
    verbose=True,
    mode="min" 
)
trainer = Trainer(
    gpus=1,
    checkpoint_callback=False,
    accumulate_grad_batches=1,
    max_epochs=150,
    progress_bar_refresh_rate=10,
    callbacks=[early_stop_callback])
trainer.fit(super_simple_model, train_loader, val_loader)
trainer.test(super_simple_model, test_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


KeyboardInterrupt: ignored

## Метрики

Можно использовать как классические мультиклассификационнные метрики, так и метрики специально для NER.

Например, число точных и частичных совпадений спанов, пропущенных и лишних спанов.

In [14]:
def get_spans(labels, tokens):
    spans = []
    for i, (label, token) in enumerate(zip(labels, tokens)):
        if label == 1:
            spans.append((token.start, token.stop, "PER"))
        elif label == 2:
            assert len(spans) != 0, "Incorrect label sequence: {}".format(labels)
            old_begin, _, old_tag = spans[-1]
            spans[-1] = (old_begin, token.stop, old_tag)
    return spans


def compare_span_sets(left_spans, right_spans):
    exact, partial, missing = 0, 0, 0
    for left_span in left_spans:
        is_missing = True
        for right_span in right_spans:
            if left_span == right_span:
                exact += 1
                is_missing = False
                break
            ls, le, _ = left_span
            rs, re, _ = right_span
            # [ls le] [rs re]
            # [rs re] [ls le]
            if not (ls <= le <= rs <= re or rs <= re <= ls <= le):
                is_missing = False
                partial += 1
                break            
        if is_missing:
            missing += 1
    return exact, partial, missing


def calc_metrics(true_labels, predicted_labels, tokens):
    # Метрики классификации
    one_tp, one_fp, one_fn = 0, 0, 0
    for true, predicted in zip(true_labels, predicted_labels):
        for l1, l2 in zip(true, predicted):
            if l1 == 1 and l2 == 1:
                one_tp += 1
            elif l1 != 1 and l2 == 1:
                one_fp += 1
            elif l1 == 1 and l2 !=1:
                one_fn += 1
    if one_tp + one_fp == 0:
        print("No positives!")
    else:
        print("1 Precision: {}, 1 Recall: {}".format(float(one_tp)/(one_tp + one_fp), float(one_tp)/(one_tp + one_fn)))

    # Специализированные метрики
    e, p, m, s = 0, 0, 0, 0
    for (true, predicted), sample_tokens in zip(zip(true_labels, predicted_labels), tokens):
        true_spans = get_spans(true, sample_tokens)
        predicted_spans = get_spans(predicted, sample_tokens)
        exact, partial, missing = compare_span_sets(true_spans, predicted_spans)
        _, _, spurius = compare_span_sets(predicted_spans, true_spans)
        e += exact
        p += partial
        m += missing
        s += spurius
    print("Exact: {}, partial: {}, missing: {}, spurius: {}".format(e, p, m, s))
            


def predict(model, test_loader, show_sample_index=51):
    model.eval()
    all_true_labels, all_predicted_labels, all_tokens, all_texts = [], [], [], []
    for batch_index, batch in enumerate(test_loader):
        inputs, true_labels = batch
        batch_size = inputs.size(0)
        _, logits = model(inputs.to("cuda"), true_labels.to("cuda"))
        predicted_labels = logits.max(dim=1)[1].detach().cpu()

        # Убираем неконсистентность BIO
        for sample_num, sample in enumerate(predicted_labels):
            for token_num, label in enumerate(sample):
                if token_num == 0 and label == 2:
                    predicted_labels[sample_num][0] = 1
                    continue
                prev_label = sample[token_num - 1]
                if label == 2 and prev_label == 0:
                    predicted_labels[sample_num][token_num] = 1

        all_true_labels.extend(true_labels)
        all_predicted_labels.extend(predicted_labels)
        for i in range(batch_size):
            all_tokens.append(test_data.get_tokens(batch_index * batch_size + i))
            all_texts.append(test_data.get_text(batch_index * batch_size + i))

    calc_metrics(all_true_labels, all_predicted_labels, all_tokens)
    print("PREDICTED:")
    show_box_markup(all_texts[show_sample_index],
                    get_spans(all_predicted_labels[show_sample_index], all_tokens[show_sample_index]),
                    palette=palette(PER=BLUE, ORG=RED, LOC=GREEN))

In [None]:
predict(super_simple_model, test_loader)

1 Precision: 0.6146964856230032, 1 Recall: 0.6291693917593199
Exact: 494, partial: 650, missing: 360, spurius: 203
PREDICTED:


## Контекстная модель: LSTM над конкатенацией

In [None]:
import torch
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

class LstmModel(TemplateModel):
    def __init__(self, char_set_size, char_embedding_dim=4, classes_count=3,
                 lstm_embedding_dim=8, char_max_seq_len=50):
        super().__init__()
        
        self.embeddings_layer = nn.Embedding(char_set_size, char_embedding_dim)
        self.dropout = nn.Dropout(0.4)
        self.lstm_layer = nn.LSTM(char_embedding_dim * char_max_seq_len, lstm_embedding_dim // 2,
                                  batch_first=True, bidirectional=True)
        self.out_layer = nn.Linear(lstm_embedding_dim, classes_count)

    def forward(self, inputs, labels):
        batch_size = inputs.size(0)
        seq_len = inputs.size(1)
        projections = self.embeddings_layer.forward(inputs)
        projections = projections.reshape(projections.size(0), projections.size(1), -1)
        output, _= self.lstm_layer(projections)
        output = self.dropout(output)
        logits = self.out_layer.forward(output)
        logits = logits.transpose(1, 2)
        loss = self.loss(logits, labels)
        return loss, logits


lstm_model = LstmModel(len(char_set))
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=5,
    verbose=True,
    mode="min" 
)
trainer = Trainer(
    gpus=1,
    checkpoint_callback=False,
    accumulate_grad_batches=1,
    max_epochs=200,
    progress_bar_refresh_rate=10,
    callbacks=[early_stop_callback])
trainer.fit(lstm_model, train_loader, val_loader)
trainer.test(lstm_model, test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params
------------------------------------------------------
0 | loss             | CrossEntropyLoss | 0     
1 | valid_accuracy   | Accuracy         | 0     
2 | test_accuracy    | Accuracy         | 0     
3 | embeddings_layer | Embedding        | 672   
4 | dropout          | Dropout          | 0     
5 | lstm_layer       | LSTM             | 6.6 K 
6 | out_layer        | Linear           | 27    


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9880, device='cuda:0'),
 'test_acc_epoch': tensor(0.9880, device='cuda:0'),
 'test_loss': tensor(0.0384, device='cuda:0'),
 'val_acc': tensor(0.9872, device='cuda:0'),
 'val_acc_epoch': tensor(0.9872, device='cuda:0'),
 'val_loss': tensor(0.0418, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_acc': 0.98798668384552,
  'test_acc_epoch': 0.98798668384552,
  'test_loss': 0.03836929425597191,
  'val_acc': 0.9872400164604187,
  'val_acc_epoch': 0.9872400164604187,
  'val_loss': 0.04180573299527168}]

In [None]:
predict(lstm_model, test_loader)

1 Precision: 0.6135569531795947, 1 Recall: 0.5679172056921087
Exact: 520, partial: 575, missing: 377, spurius: 166
PREDICTED:


## Контекстная модель: LSTM над CharFF

In [None]:
import torch
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

class CharFFLstmModel(TemplateModel):
    def __init__(self, char_set_size, char_embedding_dim=4, classes_count=3,
                 word_embedding_dim=16, lstm_embedding_dim=16, char_max_seq_len=50):
        super().__init__()
        
        self.embeddings_layer = nn.Embedding(char_set_size, char_embedding_dim)
        self.dropout = nn.Dropout(0.3)
        self.linear = nn.Linear(char_embedding_dim * char_max_seq_len, word_embedding_dim)
        self.relu = nn.ReLU()
        self.lstm_layer = nn.LSTM(word_embedding_dim, lstm_embedding_dim // 2, batch_first=True, bidirectional=True)
        self.out_layer = nn.Linear(lstm_embedding_dim, classes_count)

    def forward(self, inputs, labels):
        projections = self.embeddings_layer.forward(inputs)
        projections = projections.reshape(projections.size(0), projections.size(1), -1)
        projections = self.relu(self.linear(projections))
        projections = self.dropout(projections)
        output, _= self.lstm_layer(projections)
        output = self.dropout(output)
        logits = self.out_layer.forward(output)
        logits = logits.transpose(1, 2)
        loss = self.loss(logits, labels)
        return loss, logits


char_ff_lstm_model = CharFFLstmModel(len(char_set))
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=8,
    verbose=True,
    mode="min" 
)
trainer = Trainer(
    gpus=1,
    checkpoint_callback=False,
    accumulate_grad_batches=1,
    max_epochs=150,
    progress_bar_refresh_rate=10,
    callbacks=[early_stop_callback])
trainer.fit(char_ff_lstm_model, train_loader, val_loader)
trainer.test(char_ff_lstm_model, test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params
------------------------------------------------------
0 | loss             | CrossEntropyLoss | 0     
1 | valid_accuracy   | Accuracy         | 0     
2 | test_accuracy    | Accuracy         | 0     
3 | embeddings_layer | Embedding        | 672   
4 | dropout          | Dropout          | 0     
5 | linear           | Linear           | 3.2 K 
6 | relu             | ReLU             | 0     
7 | lstm_layer       | LSTM             | 1.7 K 
8 | out_layer        | Linear           | 51    


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9899, device='cuda:0'),
 'test_acc_epoch': tensor(0.9899, device='cuda:0'),
 'test_loss': tensor(0.0281, device='cuda:0'),
 'val_acc': tensor(0.9901, device='cuda:0'),
 'val_acc_epoch': tensor(0.9901, device='cuda:0'),
 'val_loss': tensor(0.0296, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_acc': 0.9898666739463806,
  'test_acc_epoch': 0.9898666739463806,
  'test_loss': 0.028131583705544472,
  'val_acc': 0.9900799989700317,
  'val_acc_epoch': 0.9900799989700317,
  'val_loss': 0.02958780527114868}]

In [None]:
predict(char_ff_lstm_model, test_loader)

1 Precision: 0.8388969521044993, 1 Recall: 0.7560497056899934
Exact: 1085, partial: 124, missing: 295, spurius: 144
PREDICTED:


## Задание 1.1
Сделайте то же самое, но с bidirectional LSTM на уровне символов

In [16]:
import torch
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

class CharLstmModel(TemplateModel):
    def __init__(self, char_set_size, char_embedding_dim=4, classes_count=3,
                 lstm_embedding_dim=16, char_max_seq_len=50):
        super().__init__()
        
        self.embeddings_layer = nn.Embedding(char_set_size, char_embedding_dim)
        self.dropout = nn.Dropout(0.3)
        self.lstm_layer = nn.LSTM(char_embedding_dim, lstm_embedding_dim // 2, batch_first=True, bidirectional=True)
        self.out_layer = nn.Linear(lstm_embedding_dim * char_max_seq_len, classes_count)

    def forward(self, inputs, labels):
        projections = self.embeddings_layer.forward(inputs)
        char_max_seq_len = projections.size(2)
        projections = projections.reshape(projections.size(0), -1, projections.size(3))
        output, _= self.lstm_layer(projections)
        output = self.dropout(output)
        output = output.reshape(output.size(0), -1, output.size(2)*char_max_seq_len)
        logits = self.out_layer.forward(output)
        logits = logits.transpose(1, 2)
        loss = self.loss(logits, labels)
        return loss, logits


char_lstm_model = CharLstmModel(len(char_set))
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=8,
    verbose=True,
    mode="min" 
)
trainer = Trainer(
    gpus=1,
    checkpoint_callback=False,
    accumulate_grad_batches=1,
    max_epochs=150,
    progress_bar_refresh_rate=10,
    callbacks=[early_stop_callback])
trainer.fit(char_lstm_model, train_loader, val_loader)
trainer.test(char_lstm_model, test_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type             | Params
------------------------------------------------------
0 | loss             | CrossEntropyLoss | 0     
1 | valid_accuracy   | Accuracy         | 0     
2 | test_accuracy    | Accuracy         | 0     
3 | embeddings_layer | Embedding        | 664   
4 | dropout          | Dropout          | 0     
5 | lstm_layer       | LSTM             | 896   
6 | out_layer        | Linear           | 2.4 K 
------------------------------------------------------
4.0 K     Trainable params
0         Non-trainable params
4.0 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9789, device='cuda:0'),
 'test_acc_epoch': tensor(0.9789, device='cuda:0'),
 'test_loss': tensor(0.0561, device='cuda:0'),
 'val_acc': tensor(0.9801, device='cuda:0'),
 'val_acc_epoch': tensor(0.9801, device='cuda:0'),
 'val_loss': tensor(0.0522, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_acc': 0.9788533449172974,
  'test_acc_epoch': 0.9788533449172974,
  'test_loss': 0.05610419437289238,
  'val_acc': 0.9801466464996338,
  'val_acc_epoch': 0.9801466464996338,
  'val_loss': 0.05219695344567299}]

In [17]:
predict(char_lstm_model, test_loader)

1 Precision: 0.6564195298372514, 1 Recall: 0.6989730423620025
Exact: 804, partial: 408, missing: 282, spurius: 293
PREDICTED:


## Задание 1.2
Сделайте то же самое, но со свёртками на уровне символов

In [18]:
import torch
import torch.nn as nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping  

class CharCNNModel(TemplateModel):
    def __init__(self, char_set_size, char_embedding_dim=4, classes_count=3,
                 kernel_size=5, channels=500, char_max_seq_len=50):
        super().__init__()
        
        self.embedding = nn.Embedding(char_set_size, char_embedding_dim)
        self.cnn = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=char_embedding_dim*kernel_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.out_layer = nn.Linear(char_embedding_dim*char_max_seq_len - (char_embedding_dim*kernel_size - 1), classes_count)
    
    def forward(self, inputs, labels):
        projections = self.embedding.forward(inputs)
        projections = projections.reshape(projections.size(0), projections.size(1), -1)
        convolved = self.dropout(self.relu(self.cnn(projections)))
        logits = self.out_layer.forward(convolved)
        
        logits = logits.transpose(1, 2)
        loss = self.loss(logits, labels)
        return loss, logits

char_cnn_model = CharCNNModel(len(char_set))
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=8,
    verbose=True,
    mode="min" 
)
trainer = Trainer(
    gpus=1,
    checkpoint_callback=False,
    accumulate_grad_batches=1,
    max_epochs=150,
    progress_bar_refresh_rate=10,
    callbacks=[early_stop_callback])
trainer.fit(char_cnn_model, train_loader, val_loader)
trainer.test(char_cnn_model, test_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type             | Params
----------------------------------------------------
0 | loss           | CrossEntropyLoss | 0     
1 | valid_accuracy | Accuracy         | 0     
2 | test_accuracy  | Accuracy         | 0     
3 | embedding      | Embedding        | 664   
4 | cnn            | Conv1d           | 5.0 M 
5 | relu           | ReLU             | 0     
6 | dropout        | Dropout          | 0     
7 | out_layer      | Linear           | 546   
----------------------------------------------------
5.0 M     Trainable params
0         Non-trainable params
5.0 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9531, device='cuda:0'),
 'test_acc_epoch': tensor(0.9531, device='cuda:0'),
 'test_loss': tensor(0.3005, device='cuda:0'),
 'val_acc': tensor(0.9583, device='cuda:0'),
 'val_acc_epoch': tensor(0.9583, device='cuda:0'),
 'val_loss': tensor(0.2573, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_acc': 0.9531199932098389,
  'test_acc_epoch': 0.9531199932098389,
  'test_loss': 0.3004660904407501,
  'val_acc': 0.9582933187484741,
  'val_acc_epoch': 0.9582933187484741,
  'val_loss': 0.25733622908592224}]

In [19]:
predict(char_cnn_model, test_loader)

1 Precision: 0.12626262626262627, 1 Recall: 0.03209242618741977
Exact: 28, partial: 49, missing: 1417, spurius: 315
PREDICTED:


## Задание 1.3
Сделайте то же самое, но с CRF над головой

## Задание 2*
Сделайте то же самое, но с XLMRobertaForTokenClassification. Не забудьте про word -> subwords!

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
for sample in train:
    for token, label in zip(sample.tokens, sample.labels):
        input_ids = tokenizer(token.text, add_special_tokens=False)["input_ids"]
        print(token.text, tokenizer.tokenize(token.text), input_ids)
    break

Компания ['▁Компания'] [95425]
Luxoft ['▁Lux', 'o', 'ft'] [85561, 31, 2480]
проведет ['▁проведе', 'т'] [49958, 222]
IPO ['▁', 'IPO'] [6, 102402]
в ['▁в'] [49]
Нью-Йорке ['▁Нью', '-', 'Йорк', 'е'] [63326, 9, 71244, 103]
в ['▁в'] [49]
ближайшее ['▁ближайше', 'е'] [152280, 103]
время ['▁время'] [3699]
Российская ['▁', 'Российская'] [6, 241167]
компания ['▁компания'] [21313]
Luxoft ['▁Lux', 'o', 'ft'] [85561, 31, 2480]
, ['▁', ','] [6, 4]
входящая ['▁в', 'ходящ', 'ая'] [49, 124387, 3997]
в ['▁в'] [49]
IBS ['▁I', 'BS'] [87, 20429]
Group ['▁Group'] [10760]
, ['▁', ','] [6, 4]
в ['▁в'] [49]
ходе ['▁ходе'] [68642]
IPO ['▁', 'IPO'] [6, 102402]
предложит ['▁предложи', 'т'] [83293, 222]
4 ['▁4'] [201]
миллиона ['▁миллион', 'а'] [38608, 59]
или ['▁или'] [637]
12,5 ['▁12,5'] [155068]
% ['▁%'] [1745]
всех ['▁всех'] [8808]
акций ['▁акций'] [156804]
компании ['▁компании'] [13117]
. ['▁', '.'] [6, 5]
Компания ['▁Компания'] [95425]
Luxoft ['▁Lux', 'o', 'ft'] [85561, 31, 2480]
проведет ['▁проведе', 'т'] 

# NER из коробки

https://github.com/natasha/natasha

https://github.com/deepmipt/DeepPavlov

https://pypi.org/project/polyglot/

http://www.pullenti.ru/

## Natasha

In [None]:
from natasha import (
    Segmenter,
    NewsEmbedding,
    NewsNERTagger,
    Doc
)
	

text = '''
Иисус ходил по воде. Простите, еще несколько цитат из приговора. «…Отрицал существование
Иисуса и пророка Мухаммеда», «наделял Иисуса Христа качествами
ожившего мертвеца — зомби» [и] «качествами покемонов —
представителей бестиария японской мифологии, тем самым совершил
преступление, предусмотренное статьей 148 УК РФ
'''

emb = NewsEmbedding()
segmenter = Segmenter()
ner_tagger = NewsNERTagger(emb)

doc = Doc(text)
doc.segment(segmenter)
doc.tag_ner(ner_tagger)
doc.ner.print()
for span in doc.spans:
    print(span)

Иисус ходил по воде. Простите, еще несколько цитат из приговора. 
PER──                                                            
«…Отрицал существование
Иисуса и пророка Мухаммеда», «наделял Иисуса Христа качествами
                                      PER──────────           
ожившего мертвеца — зомби» [и] «качествами покемонов —
представителей бестиария японской мифологии, тем самым совершил
преступление, предусмотренное статьей 148 УК РФ
                                             LO
DocSpan(start=1, stop=6, type='PER', text='Иисус', tokens=[...])
DocSpan(start=128, stop=141, type='PER', text='Иисуса Христа', tokens=[...])
DocSpan(start=317, stop=319, type='LOC', text='РФ', tokens=[...])


In [None]:
def calc_metrics_spans_only(true_spans, predicted_spans):
    exact, partial, missing = compare_span_sets(true_spans, predicted_spans)
    _, _, spurius = compare_span_sets(predicted_spans, true_spans)
    return exact, partial, missing, spurius

In [None]:
aexact = 0
apartial = 0
amissing = 0
aspurius = 0
for sample in test:
    true_spans = sample.spans
    doc = Doc(sample.text)
    doc.segment(segmenter)
    doc.tag_ner(ner_tagger)
    predicted_spans = [(span.start, span.stop, span.type) for span in doc.spans if span.type == "PER"]
    exact, partial, missing, spurius = calc_metrics_spans_only(true_spans, predicted_spans)
    aexact += exact
    apartial += partial
    amissing += missing
    aspurius += spurius
print("Exact: {}, partial: {}, missing: {}, spurius: {}".format(aexact, apartial, amissing, aspurius))

Exact: 1593, partial: 20, missing: 41, spurius: 20


## Задание 3
Оцените готовые модели из deeppavlov-ner на нашем тестовом датасете

In [None]:
from deeppavlov import configs, build_model

ner_model = build_model(configs.ner.ner_rus, download=True)

In [None]:
predict(ner_model, test_loader)