In [28]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
import sys
import numpy as np

sys.path.append("../")

warnings.filterwarnings("ignore")

### 0. Parse data
Store data in NER format

In [330]:
import os

train_path = "/datadrive/AGRR-2019/train.csv"
valid_path = "/datadrive/AGRR-2019/dev.csv"

In [17]:
import pandas as pd

In [18]:
train_df = pd.read_csv(train_path, sep="\t")

In [19]:
train_df.head()

Unnamed: 0,text,class,cV,cR1,cR2,V,R1,R2
0,"Будучи в прошлый четверг в Софии, он назвал се...",0,,,,,,
1,Работа с двухбайтовыми наборами символов — про...,1,92:99,83:91,103:109,127:127,119:124,127:134
2,"Заместитель Генерального секретаря подчеркнул,...",0,,,,,,
3,Продажа недвижимости из собственных портфелей ...,0,,,,,,
4,"Новым является то, что повышенное давление кон...",0,,,,,,


In [21]:
sup_labels = ['cV', 'cR1', 'cR2', 'V', 'R1', 'R2']

In [22]:
def to_iob(words, tag):
    """Build the IOB tag sequence for the words of a single entity span.

    The first word is tagged ``B_<tag>`` and each remaining word ``I_<tag>``.

    Args:
        words: non-empty sequence of the words that make up the span.
        tag: entity label to embed into the IOB tags.

    Returns:
        A list of tags with the same length as ``words``.

    Raises:
        ValueError: if ``words`` is empty.
    """
    n_words = len(words)
    if n_words == 0:
        raise ValueError("Words should have len > 0.")
    begin_tag = "B_{}".format(tag)
    inside_tag = "I_{}".format(tag)
    return [begin_tag] + [inside_tag for _ in range(n_words - 1)]

In [23]:
def parse_row(row, origin_idx, sup_labels=['cV', 'cR1', 'cR2', 'V', 'R1', 'R2'], word_tokenizer=lambda x: x.split()):
    """Convert one annotated row into a single-row, NER-formatted DataFrame.

    The row text is cut into annotated spans and the unannotated text between
    them. Every piece is split into words with ``word_tokenizer``; words of an
    annotated span get IOB tags from ``to_iob``, all other words get ``"O"``.

    Args:
        row: record exposing ``.text`` and ``.get(key)`` (e.g. a pandas Series
            produced by ``DataFrame.iterrows``). ``row.get("class")`` is the
            sentence class; ``row.get(label)`` holds whitespace-separated
            ``start:end`` character offsets or NaN.
        origin_idx: index of the row in the source file, stored unchanged.
        sup_labels: annotation columns to read. When two spans start at the
            same offset, the column listed later wins.
        word_tokenizer: callable that splits a piece of text into words.

    Returns:
        A one-row DataFrame with columns ``0`` (space-joined tags), ``1``
        (space-joined words), ``2`` (class), ``start_pos`` and ``len``
        (space-joined character offsets / lengths of the words) and
        ``origin_idx``.
    """
    text = row.text
    len_doc = len(text)

    # Read the class from the row. The previous code tested the literal 0
    # instead of the row value, so the class was always 0 and spans were never
    # tagged. A missing or NaN class (e.g. unlabelled test data) maps to 0.
    raw_cls = row.get("class", 0)
    cls = 0 if pd.isna(raw_cls) else int(raw_cls)

    # Map span start offset -> span description. Only positive rows carry spans.
    tokens_pos = dict()
    if cls:
        for name in sup_labels:
            value = row.get(name)
            if not isinstance(value, str):
                # NaN / missing annotation for this label.
                continue
            for b in value.split():
                start_pos, end_pos = map(int, b.split(":"))
                tokens_pos[start_pos] = {"name": name, "text": text[start_pos:end_pos]}

    # Cut the text into alternating "O" pieces and annotated spans.
    splitted = []
    tags = []
    pos = 0
    cur_pos = 0
    while pos < len_doc:
        token = tokens_pos.get(pos)
        if token is None:
            pos += 1
            continue
        splitted.append(text[cur_pos:pos])
        tags.append("O")
        splitted.append(token["text"])
        tags.append(token["name"])
        cur_pos = pos + len(token["text"])
        # Resume scanning right at the end of the span so that a directly
        # adjacent span is not skipped; a zero-length span must still advance.
        pos = max(cur_pos, pos + 1)
    if cur_pos < len_doc:
        splitted.append(text[cur_pos:])
        tags.append("O")
    assert len(splitted) == len(tags)

    # Tokenize every piece and align each word with its offset in the text.
    cur_pos = 0
    res_labels = []
    res_tokens = []
    res_start_pos = []
    res_len = []
    for s_text, tag in zip(splitted, tags):
        words = word_tokenizer(s_text)
        if not len(words):
            continue
        if tag == "O":
            token_tags = ["O"] * len(words)
        else:
            token_tags = to_iob(words, tag)
        assert len(words) == len(token_tags)
        for word, tag_ in zip(words, token_tags):
            # Search forward only, so repeated words map to successive offsets.
            cur_pos = text.find(word, cur_pos)
            len_word = len(word)
            if len_word:
                res_labels.append(tag_)
                res_tokens.append(word)
                res_start_pos.append(str(cur_pos))
                res_len.append(str(len_word))
            cur_pos += len_word
    return pd.DataFrame({
        0: [" ".join(res_labels)],
        1: [" ".join(res_tokens)],
        2: [cls],
        "start_pos": [" ".join(res_start_pos)],
        "len": [" ".join(res_len)],
        "origin_idx": origin_idx
    }, columns=[0, 1, 2, "start_pos", "len", "origin_idx"])

In [24]:
from tqdm import tqdm

In [25]:
def prc_df(path, sup_labels):
    """Parse a tab-separated annotated file into the NER training format.

    Every row of the file is converted with ``parse_row`` and the resulting
    single-row frames are stacked in file order.

    Args:
        path: path to the tab-separated source file.
        sup_labels: annotation columns forwarded to ``parse_row``.

    Returns:
        DataFrame with columns ``0``, ``1``, ``2``, ``start_pos``, ``len`` and
        ``origin_idx``; empty (with these columns) when the file has no rows.
    """
    columns = [0, 1, 2, "start_pos", "len", "origin_idx"]
    df = pd.read_csv(path, sep="\t")
    # Collect the per-row frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole result on every iteration.
    frames = [
        parse_row(row, origin_idx, sup_labels)
        for origin_idx, row in tqdm(df.iterrows(), total=len(df), leave=False)
    ]
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames)

In [26]:
train_df = prc_df(train_path, sup_labels)

                                                      

In [27]:
train_df.head()

Unnamed: 0,0,1,2,start_pos,len,origin_idx
0,O O O O O O O O O O O O O O O O O O O O O O O O O,"Будучи в прошлый четверг в Софии, он назвал се...",0,0 7 9 17 25 27 34 37 44 49 61 82 84 87 97 102 ...,6 1 7 7 1 6 2 6 4 11 20 1 2 9 4 4 5 7 1 4 7 8 ...,0
0,O O O O O O O O O O O O B_cR1 I_cR1 B_cV O B_c...,Работа с двухбайтовыми наборами символов — про...,1,0 7 9 23 32 41 43 50 57 61 75 79 83 89 92 100 ...,6 1 13 8 8 1 6 6 3 13 3 3 5 2 7 2 6 6 1 5 1 2 4 1,1
0,O O O O O O O O O O O O O O O O O O O O O O O ...,"Заместитель Генерального секретаря подчеркнул,...",0,0 12 25 35 47 51 60 71 74 81 97 108 110 118 12...,11 12 9 11 3 8 10 2 6 15 10 1 7 2 1 12 5 6 4 1...,2
0,O O O O O O O O O O O O O O O O O O O O O O O O,Продажа недвижимости из собственных портфелей ...,0,0 8 21 24 36 46 56 72 76 78 87 94 97 107 109 1...,7 12 2 11 9 9 15 3 1 8 6 2 9 1 6 1 13 9 8 3 9 ...,3
0,O O O O O O O O O O O O O O O O O O O O O O O ...,"Новым является то, что повышенное давление кон...",0,0 6 15 19 23 34 43 59 67 73 86 98 102 111 113 ...,5 8 3 3 10 8 15 7 5 12 11 3 8 1 7 5 11 9 1 15 ...,4


In [28]:
dev_df = prc_df(valid_path, sup_labels)

                                                    

In [29]:
dev_df.head()

Unnamed: 0,0,1,2,start_pos,len,origin_idx
0,O O O O O O O O O O O O O O O O O O O O O O O ...,"Центральная часть лунной тени, где наблюдается...",0,0 12 18 25 31 35 47 54 59 68 77 80 86 88 99 10...,11 5 6 5 3 11 6 4 8 8 2 5 1 10 9 5 5 12 10 1 7...,0
0,O B_cV B_cR1 I_cR1 B_cR2 I_cR2 O O B_R1 I_R1 I...,"Я превращу твое сердце в траву , а все, что ты...",1,0 2 11 16 23 25 30 32 34 39 43 46 52 54 56 60,1 8 4 6 1 5 1 1 4 3 2 6 1 1 4 1,1
0,O O O B_cR1 B_cV B_cR2 I_cR2 I_cR2 O O B_R1 B_...,В данном примере строки сгруппированы по назва...,1,0 2 9 17 24 38 41 50 56 58 60 68 71 81 88,1 6 7 6 13 2 8 6 1 1 7 2 9 7 1,2
0,O O O O O O O O O O O O,Ассоциация намерена занять на мировом рынке до...,0,0 11 20 27 30 38 44 50 57 63 71 79,10 8 6 2 7 5 5 6 5 7 7 14,3
0,O O O O O O O,Ты сама все портишь-твои слова хороший мой,0,0 3 8 12 25 31 39,2 4 3 12 5 7 3,4


In [30]:
train_df.to_csv("/datadrive/AGRR-2019/train_parsed.csv", index=False)

In [31]:
dev_df.to_csv("/datadrive/AGRR-2019/dev_parsed.csv", index=False)

### 1. Create dataloaders

In [32]:
from modules import BertNerData as NerData

In [196]:
import os


# Parsed datasets written by the "Parse data" section.
train_path = "/datadrive/AGRR-2019/train_parsed.csv"
valid_path = "/datadrive/AGRR-2019/dev_parsed.csv"
# Directory with the converted PyTorch checkpoint. The literal used to start
# with a space, which made it a relative path to a directory named " ".
model_dir = "/datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join(model_dir, "pytorch_model.bin")
# The BERT config and vocabulary are read from a different directory.
bert_config_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "bert_config.json")
vocab_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "vocab.txt")

In [None]:
data = NerData.create(train_path, valid_path, vocab_file, is_cls=True)

In [198]:
print(data.label2idx)

{'<pad>': 0, '[CLS]': 1, 'B_O': 2, 'X': 3, 'B_cR1': 4, 'I_cR1': 5, 'B_cV': 6, 'B_cR2': 7, 'B_R1': 8, 'B_R2': 9, 'I_R2': 10, 'I_cR2': 11, 'I_R1': 12, 'I_cV': 13}


In [199]:
len(data.label2idx)

14

### 2. Create model
To build the PyTorch model, instantiate one of the model classes from `modules.models.bert_models` (here `BertBiLSTMAttnNCRFJoint`).

In [200]:
from modules.models.bert_models import BertBiLSTMAttnNCRFJoint, BertBiLSTMAttnNCRF, BertBiLSTMNCRF

In [201]:
# for 0.9
model = BertBiLSTMAttnNCRFJoint.create(
    len(data.label2idx), len(data.cls2idx), bert_config_file, init_checkpoint_pt,
    enc_hidden_dim=1024, rnn_layers=1, num_heads=3, input_dropout=0.5, nbest=12)

build CRF...


In [202]:
model.decoder

AttnNCRFJointDecoder(
  (attn): MultiHeadAttention(
    (attention): _MultiHeadAttention(
      (attention): ScaledDotProductAttention(
        (softmax): Softmax()
        (dropout): Dropout(p=0.5)
      )
    )
    (proj): Linear(in_features=192, out_features=1024, bias=True)
    (dropout): Dropout(p=0.5)
    (layer_norm): LayerNormalization()
  )
  (linear): Linears(
    (linears): ModuleList(
      (0): Linear(in_features=1024, out_features=512, bias=True)
    )
    (output_linear): Linear(in_features=512, out_features=16, bias=True)
  )
  (crf): NCRF()
  (intent_out): PoolingLinearClassifier(
    (dropout): Dropout(p=0.5)
    (linear): Linears(
      (linears): ModuleList(
        (0): Linear(in_features=3072, out_features=512, bias=True)
      )
      (output_linear): Linear(in_features=512, out_features=2, bias=True)
    )
  )
  (intent_loss): CrossEntropyLoss()
)

In [203]:
model.get_n_trainable_params()

8148255

### 3. Create learner

To train the PyTorch model, we need to create a `NerLearner` object.

In [204]:
from modules import NerLearner

In [205]:
num_epochs = 100
learner = NerLearner(model, data,
                     best_model_path="/datadrive/AGRR-2019/big.cpt",
                     lr=0.001, clip=5.0, sup_labels=data.id2label[1:],
                     t_total=num_epochs * len(data.train_dl))

### 4. Learn your NER model
Call `learner.fit`

In [44]:
from sklearn_crfsuite.metrics import flat_classification_report

In [None]:
learner.fit(num_epochs, target_metric='f1')

2019-02-25 05:28:05,483 INFO: Resuming train... Current epoch 1.
2019-02-25 05:36:29,881 INFO:                                                      
epoch 2, average train epoch loss=12.336

2019-02-25 05:37:26,868 INFO: on epoch 1 by max_f1: 0.949
2019-02-25 05:37:26,869 INFO: on epoch {} classification report:
2019-02-25 05:37:26,870 INFO: Saving new best model...


             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.921     0.958     0.940     74268
          X      0.989     0.983     0.986    109877
      B_cR1      0.726     0.571     0.639      1382
      I_cR1      0.717     0.494     0.585      2022
       B_cV      0.809     0.760     0.784      1382
      B_cR2      0.761     0.681     0.719      1355
       B_R1      0.868     0.869     0.869      1500
       B_R2      0.849     0.857     0.853      1473
       I_R2      0.811     0.626     0.707      3255
      I_cR2      0.708     0.673     0.690      2395
       I_R1      0.864     0.732     0.793      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.949     0.950     0.949    204821

             precision    recall  f1-score   support

          0      0.789     0.995     0.880      2760
          1      0.979     0.469     0.634      1382

avg / total      0.852     0.819     0.79

2019-02-25 05:45:58,676 INFO:                                                      
epoch 3, average train epoch loss=6.9561

2019-02-25 05:46:54,367 INFO: on epoch 2 by max_f1: 0.961
2019-02-25 05:46:54,368 INFO: on epoch {} classification report:
2019-02-25 05:46:54,369 INFO: Saving new best model...


             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.948     0.959     0.953     74268
          X      0.985     0.994     0.989    109877
      B_cR1      0.852     0.599     0.703      1382
      I_cR1      0.837     0.522     0.643      2022
       B_cV      0.902     0.790     0.843      1382
      B_cR2      0.826     0.799     0.812      1355
       B_R1      0.893     0.903     0.898      1500
       B_R2      0.872     0.908     0.889      1473
       I_R2      0.849     0.816     0.832      3255
      I_cR2      0.822     0.746     0.782      2395
       I_R1      0.917     0.814     0.862      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.961     0.962     0.961    204821

             precision    recall  f1-score   support

          0      0.775     0.999     0.873      2760
          1      0.995     0.422     0.592      1382

avg / total      0.849     0.806     0.77

2019-02-25 05:55:34,985 INFO:                                                      
epoch 4, average train epoch loss=5.1567

2019-02-25 05:56:31,253 INFO: on epoch 3 by max_f1: 0.97
2019-02-25 05:56:31,254 INFO: on epoch {} classification report:
2019-02-25 05:56:31,255 INFO: Saving new best model...


             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.965     0.962     0.963     74268
          X      0.990     0.993     0.992    109877
      B_cR1      0.808     0.753     0.779      1382
      I_cR1      0.785     0.688     0.734      2022
       B_cV      0.862     0.922     0.891      1382
      B_cR2      0.837     0.859     0.848      1355
       B_R1      0.898     0.936     0.916      1500
       B_R2      0.893     0.935     0.913      1473
       I_R2      0.879     0.879     0.879      3255
      I_cR2      0.807     0.846     0.826      2395
       I_R1      0.924     0.874     0.898      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.970     0.971     0.970    204821

             precision    recall  f1-score   support

          0      0.850     0.998     0.918      2760
          1      0.993     0.648     0.784      1382

avg / total      0.898     0.881     0.87

2019-02-25 06:05:04,251 INFO:                                                     
epoch 5, average train epoch loss=4.1059

2019-02-25 06:05:59,580 INFO: on epoch 3 by max_f1: 0.97
2019-02-25 06:05:59,581 INFO: on epoch {} classification report:
  0%|          | 0/1026 [00:00<?, ?it/s]

             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.974     0.951     0.962     74268
          X      0.993     0.990     0.992    109877
      B_cR1      0.777     0.815     0.795      1382
      I_cR1      0.676     0.846     0.751      2022
       B_cV      0.876     0.949     0.911      1382
      B_cR2      0.839     0.869     0.854      1355
       B_R1      0.912     0.928     0.920      1500
       B_R2      0.905     0.942     0.923      1473
       I_R2      0.815     0.936     0.872      3255
      I_cR2      0.739     0.937     0.827      2395
       I_R1      0.852     0.928     0.888      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.972     0.970     0.970    204821

             precision    recall  f1-score   support

          0      0.959     0.980     0.969      2760
          1      0.958     0.917     0.937      1382

avg / total      0.959     0.959     0.95

2019-02-25 06:14:26,323 INFO:                                                      
epoch 6, average train epoch loss=3.3608

2019-02-25 06:15:17,202 INFO: on epoch 5 by max_f1: 0.974
2019-02-25 06:15:17,203 INFO: on epoch {} classification report:
2019-02-25 06:15:17,203 INFO: Saving new best model...


             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.964     0.972     0.968     74268
          X      0.995     0.990     0.992    109877
      B_cR1      0.817     0.783     0.800      1382
      I_cR1      0.809     0.772     0.790      2022
       B_cV      0.876     0.957     0.915      1382
      B_cR2      0.862     0.861     0.862      1355
       B_R1      0.922     0.923     0.922      1500
       B_R2      0.911     0.943     0.927      1473
       I_R2      0.891     0.914     0.903      3255
      I_cR2      0.853     0.838     0.845      2395
       I_R1      0.944     0.841     0.889      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.974     0.974     0.974    204821

             precision    recall  f1-score   support

          0      0.960     0.985     0.972      2760
          1      0.968     0.918     0.942      1382

avg / total      0.962     0.962     0.96

2019-02-25 06:23:59,292 INFO:                                                      
epoch 7, average train epoch loss=2.9361

2019-02-25 06:24:50,833 INFO: on epoch 6 by max_f1: 0.976
2019-02-25 06:24:50,834 INFO: on epoch {} classification report:
2019-02-25 06:24:50,834 INFO: Saving new best model...


             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.964     0.976     0.970     74268
          X      0.994     0.991     0.993    109877
      B_cR1      0.821     0.842     0.831      1382
      I_cR1      0.775     0.831     0.802      2022
       B_cV      0.943     0.905     0.924      1382
      B_cR2      0.898     0.838     0.867      1355
       B_R1      0.940     0.926     0.933      1500
       B_R2      0.907     0.939     0.923      1473
       I_R2      0.933     0.861     0.896      3255
      I_cR2      0.880     0.824     0.851      2395
       I_R1      0.962     0.859     0.907      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.976     0.976     0.976    204821

             precision    recall  f1-score   support

          0      0.961     0.991     0.976      2760
          1      0.981     0.920     0.950      1382

avg / total      0.968     0.967     0.96

2019-02-25 06:33:35,156 INFO:                                                      
epoch 8, average train epoch loss=2.7444

2019-02-25 06:34:26,910 INFO: on epoch 6 by max_f1: 0.976
2019-02-25 06:34:26,911 INFO: on epoch {} classification report:
  0%|          | 0/1026 [00:00<?, ?it/s]

             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.971     0.971     0.971     74268
          X      0.992     0.993     0.993    109877
      B_cR1      0.862     0.812     0.836      1382
      I_cR1      0.751     0.867     0.805      2022
       B_cV      0.963     0.894     0.927      1382
      B_cR2      0.892     0.856     0.874      1355
       B_R1      0.944     0.917     0.930      1500
       B_R2      0.937     0.924     0.930      1473
       I_R2      0.915     0.906     0.910      3255
      I_cR2      0.874     0.863     0.869      2395
       I_R1      0.912     0.918     0.915      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.977     0.976     0.976    204821

             precision    recall  f1-score   support

          0      0.961     0.989     0.975      2760
          1      0.976     0.920     0.947      1382

avg / total      0.966     0.966     0.96

train loss: 2.3024029034573035:  32%|███▏      | 333/1026 [02:49<05:52,  1.97it/s]

In [63]:
len(learner.history)

16

### 5. Evaluate
Create new data loader from existing path.

In [206]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [None]:
dl = get_bert_data_loader_for_predict(valid_path, learner)

In [209]:
learner.load_model()

In [None]:
preds_res, preds_cls = learner.predict(dl)
# preds_res = learner.predict(dl)

IOB precision

In [None]:
from modules.train.train import validate_step
p_rep, p_cls = validate_step(
    learner.data.valid_dl, learner.model, learner.data.id2label, learner.data.id2label[1:], learner.data.cls2idx)

In [106]:
print(p_rep)

             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      0.976     0.975     0.975     74268
          X      0.993     0.993     0.993    109877
      B_cR1      0.860     0.851     0.855      1382
      I_cR1      0.855     0.857     0.856      2022
       B_cV      0.956     0.937     0.946      1382
      B_cR2      0.893     0.895     0.894      1355
       B_R1      0.946     0.923     0.935      1500
       B_R2      0.933     0.948     0.940      1473
       I_R2      0.932     0.911     0.922      3255
      I_cR2      0.858     0.921     0.888      2395
       I_R1      0.915     0.922     0.919      1769
       I_cV      0.000     0.000     0.000         1

avg / total      0.980     0.980     0.980    204821



In [70]:
print(p_cls)

             precision    recall  f1-score   support

          0      0.887     0.999     0.940      2760
          1      0.998     0.745     0.853      1382

avg / total      0.924     0.915     0.911      4142



Tokens report

In [211]:
from sklearn_crfsuite.metrics import flat_classification_report

In [212]:
from modules.utils.utils import bert_labels2tokens, first_choicer

In [240]:
pred_tokens, pred_labels = bert_labels2tokens(dl, preds_res, first_choicer)
true_tokens, true_labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset], first_choicer)

In [214]:
assert pred_tokens == true_tokens
tokens_report = flat_classification_report(true_labels, pred_labels)

#### By tokens reports

In [215]:
from sklearn.metrics import accuracy_score, f1_score

# Flatten the per-sentence label lists so the metrics are computed per token.
pl = [p for line in pred_labels for p in line]
tl = [p for line in true_labels for p in line]
accuracy_score(tl, pl), f1_score(tl, pl, average="macro")

(0.9704301667364155, 0.8328229828166069)

In [216]:
print(tokens_report)

             precision    recall  f1-score   support

        B_O       0.98      0.98      0.98     74268
       B_R1       0.95      0.92      0.94      1500
       B_R2       0.94      0.95      0.94      1473
      B_cR1       0.86      0.85      0.86      1382
      B_cR2       0.89      0.90      0.89      1355
       B_cV       0.96      0.94      0.95      1382
       I_R1       0.93      0.92      0.93      1769
       I_R2       0.94      0.91      0.93      3255
      I_cR1       0.86      0.86      0.86      2022
      I_cR2       0.87      0.92      0.89      2395
       I_cV       0.00      0.00      0.00         1

avg / total       0.97      0.97      0.97     90802



### 6. Convert to the needed format

In [217]:
import pandas as pd

In [218]:
def tokens2spans_(tokens_, labels_, start_pos):
    """Merge token-level IOB labels into labelled spans.

    Args:
        tokens_: sequence of tokens.
        labels_: sequence of labels aligned with ``tokens_`` (``"O"``,
            ``"B_O"``, ``"I_O"``, ``"B_<label>"``, ``"I_<label>"`` or one of
            the special labels ``[CLS]``, ``<bos>``, ``[SEP]``, ``<eos>``).
        start_pos: start position of every token, aligned with ``tokens_``.

    Returns:
        A list of ``(text, label, pos)`` tuples:
        * outside tokens yield ``(token, "O", None)``;
        * ``[CLS]`` / ``<bos>`` yield ``(token, label, start position)``;
        * every other label opens a span that absorbs the following
          ``I_<same label>`` tokens; the span text is the tokens joined by a
          single space and ``pos`` is the start position of its first token.
        Processing stops at the first ``[SEP]`` / ``<eos>`` label.

    Raises:
        IndexError: if ``start_pos`` has no entry for the first token of a span.
    """
    outside = ("I_O", "B_O", "O")

    def continues(label, span_label):
        # True when `label` is an "I_<span_label>" continuation tag.
        if label in outside:
            return False
        parts = label.split("_")
        # A label without "_" cannot continue a span (this used to raise).
        return len(parts) > 1 and parts[0] == "I" and parts[1] == span_label

    res = []
    idx_ = 0
    n_labels = len(labels_)
    while idx_ < n_labels:
        label = labels_[idx_]
        if label in outside:
            res.append((tokens_[idx_], "O", None))
            idx_ += 1
        elif label == "[SEP]" or label == "<eos>":
            break
        elif label == "[CLS]" or label == "<bos>":
            res.append((tokens_[idx_], label, start_pos[idx_]))
            idx_ += 1
        else:
            span = [tokens_[idx_]]
            try:
                pos = start_pos[idx_]
            except IndexError:
                # Fail with context instead of re-indexing inside the handler.
                raise IndexError(
                    "No start position for token #{} ({!r}, label {!r}): "
                    "{} labels but only {} start positions.".format(
                        idx_, tokens_[idx_], label, n_labels, len(start_pos)))
            parts = label.split("_")
            if len(parts) > 1:
                span_label = parts[1]
            else:
                # Unexpected label without a "B_"/"I_" prefix: report it and
                # emit the token as a span with no label.
                print(label, parts)
                span_label = None
            idx_ += 1
            while idx_ < n_labels:
                if continues(labels_[idx_], span_label):
                    span.append(tokens_[idx_])
                    idx_ += 1
                # Bridge a one-token gap: if the token after this one continues
                # the span, absorb both. NOTE(review): the current token's own
                # label is not checked, so any single token is bridged, not
                # only an "O" - confirm this is intended.
                elif idx_ + 1 < n_labels and continues(labels_[idx_ + 1], span_label):
                    span.append(tokens_[idx_])
                    idx_ += 1
                    span.append(tokens_[idx_])
                    idx_ += 1
                else:
                    break
            res.append((" ".join(span), span_label, pos))
    return res

In [274]:
text = ""

In [275]:
def post_prc_row(text, tokens, labels, cls, start_pos):
    """Turn the predicted token labels of one sentence into one output row.

    Token start positions are recomputed by searching every token in ``text``,
    the labelled tokens are merged into spans with ``tokens2spans_`` and the
    spans are written as ``start:end`` character offsets per label.

    Args:
        text: original sentence text.
        tokens: predicted tokens of the sentence.
        labels: predicted labels aligned with ``tokens``.
        cls: ignored - the value is overwritten below and the class is derived
            from whether any non-"O" span was found.
        start_pos: ignored - the list is rebuilt from ``text`` and ``tokens``.

    Returns:
        A one-row DataFrame restricted to the columns ``text``, ``class``,
        ``cV``, ``cR1``, ``cR2``, ``V``, ``R1``, ``R2``. ``text`` is the
        sentence after the character substitutions below, not the raw input.
    """
    # Replace characters that break token/text alignment with a placeholder
    # "U" (one character each, so the offsets are preserved).
    text = text.replace('\x97', "U")
    text = text.replace('\uf076', "U")
    text = text.replace("\ue405", "U")
    text = text.replace("\ue105", "U")
    text = text.replace("\ue415", "U")
    text = text.replace('\x07', "U")
    # The incoming start_pos argument is discarded and rebuilt here.
    start_pos = []
    pos = 0
    for tok in tokens:
        # Presumably "unk" marks an unknown-token placeholder from the
        # tokenizer - TODO confirm. It is mapped to the same "U" placeholder.
        if tok.find("unk") > -1:
            tok = tok.replace("unk", "U")
        # Search from slightly before the previous match so the scan moves
        # forward through the text.
        pos = text.find(tok, max(pos-2, 0))
        # Alignment failed for this token: -1 is stored as its position and the
        # next search restarts from the beginning of the text.
        # NOTE(review): the text is rewritten here, which can shift the offsets
        # found so far, and the reassigned `tok` is only printed.
        if pos == -1:
            text = text.replace("unk", "UUU")
            tok = "UUU"
            print(tok)
        start_pos.append(pos)
    spans = tokens2spans_(tokens, labels, start_pos)
    res = {}
    # The incoming cls argument is discarded; the class becomes 1 as soon as
    # one non-"O" span exists.
    cls = 0
    for token, label, pos in spans:
        if label != "O":
            cls = 1
            pos = int(pos)
            if label in res:
                # Another span with a label seen before: append it to the
                # existing value, separated by a space.
                # NOTE(review): `tok` is the stale loop variable of the loop
                # above (its last token), so these two lines do not affect
                # `token`; `token` was probably intended.
                if tok.find("unk") > -1:
                    tok = tok.replace("unk", "U")
                pp = pos
                # Re-locate the whole span text; fall back to the token-level
                # position when the joined text is not found.
                pos = text.find(token, max(pos-1, 0))
                if pos == -1:
                    print(pp, text[pp:], token, text[pp:].find(token))
                    pos = pp
                # NOTE(review): the end offset uses the length of the
                # space-joined span text, which can differ from the real extent
                # in `text` when the original whitespace differs.
                res[label] += " " + ":".join(map(str, [pos, pos + len(token)]))
                # Every R2 span also produces a zero-length V span at its start.
                if label == "R2":
                    res["V"] += " " + ":".join(map(str, [pos, pos]))
            else:
                res[label] = ":".join(map(str, [pos, pos + len(token)]))
                if label == "R2":
                    res["V"] = ":".join(map(str, [pos, pos]))
    # R1 predicted without any R2: place the zero-length V spans at the R1 starts.
    if res.get("R1") and res.get("R2", None) is None:
        try:
            res["V"] = " ".join(["{}:{}".format(f.split(":")[0], f.split(":")[0]) for f in res.get("R1").split()])
        except:
            print(res.get("R1"), res.get("R1").split(), )
            raise
    res["text"] = text
    res["class"] = cls
    # Wrap every value in a list so the dict describes a single-row frame.
    for key in res:
        res[key] = [res[key]]
    return pd.DataFrame(res, columns=['text', 'class', 'cV', 'cR1', 'cR2', 'V', 'R1', 'R2'])

In [243]:
def post_prc(origin_path, parsed_path, pred_tokens, pred_labels, pred_cls):
    """Convert predictions for a whole file into the output table.

    Rows of the original and parsed files are walked in parallel with the
    predictions, and every sentence is converted with ``post_prc_row``.

    Args:
        origin_path: tab-separated source file with a ``text`` column.
        parsed_path: parsed CSV file with a ``start_pos`` column.
        pred_tokens: predicted tokens, one sequence per row.
        pred_labels: predicted labels, one sequence per row.
        pred_cls: predicted sentence classes, one per row.

    Returns:
        DataFrame with columns ``text``, ``class``, ``cV``, ``cR1``, ``cR2``,
        ``V``, ``R1``, ``R2``; one row per processed sentence. Iteration stops
        at the shortest of the five inputs.
    """
    columns = ['text', 'class', 'cV', 'cR1', 'cR2', 'V', 'R1', 'R2']
    origin_df = pd.read_csv(origin_path, sep="\t")
    parsed_df = pd.read_csv(parsed_path)
    # Collect the per-row frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole result on every iteration.
    frames = []
    rows = zip(origin_df.iterrows(), parsed_df.iterrows(), pred_tokens, pred_labels, pred_cls)
    for idx, ((_, o_row), (_, p_row), tokens, labels, cls) in enumerate(rows):
        # An empty start_pos cell is read back as NaN; treat it as no positions
        # instead of crashing on `.split()`.
        raw_start_pos = p_row.start_pos
        start_pos = raw_start_pos.split() if isinstance(raw_start_pos, str) else []
        try:
            frames.append(post_prc_row(o_row.text, tokens, labels, cls, start_pos))
        except Exception:
            # Print the failing row for debugging, then let the error propagate.
            print(o_row.text)
            print(" ".join(tokens))
            print(len(start_pos), len(tokens))
            print(idx)
            raise
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames)

In [None]:
dev_pred = post_prc(
    "/datadrive/AGRR-2019/dev.csv",
    "/datadrive/AGRR-2019/dev_parsed.csv",
    pred_tokens, pred_labels, preds_cls)

In [222]:
dev_pred.head()

Unnamed: 0,text,class,cV,cR1,cR2,V,R1,R2
0,"Центральная часть лунной тени, где наблюдается...",0,,,,,,
0,"Я превращу твое сердце в траву, а все, что ты ...",1,2:10,11:22,23:30,54:54,34:52,54:60
0,В данном примере строки сгруппированы по назва...,1,24:37,17:23,38:56,68:68,60:67,68:88
0,Ассоциация намерена занять на мировом рынке до...,0,,,,,,
0,Ты сама все портишь-твои слова хороший мой,0,,,,,,


In [223]:
dev_pred.to_csv("/datadrive/AGRR-2019/dev_pred.csv", sep="\t", index=False)

In [224]:
!python3 /datadrive/AGRR-2019/agrr_metrics.py -r /datadrive/AGRR-2019/dev.csv /datadrive/AGRR-2019/dev_pred.csv

Binary classification quality (f1-score): 0.9583778014941302
Gapping resolution quality (symbol-wise f-measure): 0.9576077060524616


### 7. Make prediction

In [225]:
test_path = "/datadrive/AGRR-2019/test.csv"

In [None]:
test_df = prc_df(test_path, sup_labels)

In [227]:
test_df.to_csv("/datadrive/AGRR-2019/test_parsed.csv", index=False)

In [228]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [None]:
dl = get_bert_data_loader_for_predict("/datadrive/AGRR-2019/test_parsed.csv", learner)

In [337]:
learner.load_model()

In [None]:
preds_res, preds_cls = learner.predict(dl)
# preds_res = learner.predict(dl)

In [231]:
from modules.utils.utils import bert_labels2tokens, first_choicer

In [232]:
pred_tokens, pred_labels = bert_labels2tokens(dl, preds_res, first_choicer)
true_tokens, true_labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset], first_choicer)

##### Check for encoding or whitespace problems :(

In [276]:
test_pred = post_prc(
    "/datadrive/AGRR-2019/test.csv",
    "/datadrive/AGRR-2019/test_parsed.csv",
    pred_tokens, pred_labels, preds_cls)

167 1 890 989 (один миллион восемьсот девяносто тысяч девятьсот восемьдесят девять) рублей 09 копеек, в том числе НДС (18 %) – 288 455 (двести восемьдесят восемь тысяч четыреста пятьдесят пять) рублей 96 копеек. 1 890 989 (один миллион -1
UUU
87 О  log  n   — «довольно дешево». О log n -1
277 1997 г. – 41). 1997 г. -1
227 относительно соответствующих месяцев 1997 г. достигло 30%, а в декабре — 40%. относительно соответствующих месяцев 1997 г. -1


In [277]:
test_pred.head()

Unnamed: 0,text,class,cV,cR1,cR2,V,R1,R2
0,"В РП и торцевой стенке, расположенной со сторо...",0,,,,,,
0,Диаграммы каротажа сопротивлений помогают в вы...,1,33:41,0:32,42:91,142:142,94:139,142:167
0,При оценке структуры в тенговом эквиваленте 99...,1,89:99,44:88,100:107,123:123 149:149,108:122 133:146,123:130 149:156
0,Ты самый лучший во всём Мироздании и я люблю т...,0,,,,,,
0,Способ позволяет ликвидировать обострение забо...,0,,,,,,


In [278]:
# Write without the index column, consistent with the dev predictions export.
test_pred.to_csv("/datadrive/AGRR-2019/test_pred.csv", sep="\t", index=False)

### 8. Merge train and dev

In [2]:
import pandas as pd

In [3]:
train_path = "/datadrive/AGRR-2019/train_parsed.csv"
valid_path = "/datadrive/AGRR-2019/dev_parsed.csv"

In [347]:
train_df = pd.read_csv(train_path)
dev_df = pd.read_csv(valid_path)

In [348]:
# DataFrame.append was removed in pandas 2.0; pd.concat stacks the rows the same way.
ful_df = pd.concat([train_df, dev_df])

In [349]:
len(ful_df)

20548

In [350]:
# Write without the index column so the file has the same layout as
# train_parsed.csv and dev_parsed.csv.
ful_df.to_csv("/datadrive/AGRR-2019/train_dev_parsed.csv", index=False)

### 9. Train full

In [279]:
from modules import BertNerData as NerData

In [280]:
import os


# Train on train + dev; the dev set is still used for validation.
train_path = "/datadrive/AGRR-2019/train_dev_parsed.csv"
valid_path = "/datadrive/AGRR-2019/dev_parsed.csv"
# Directory with the converted PyTorch checkpoint. The literal used to start
# with a space, which made it a relative path to a directory named " ".
model_dir = "/datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join(model_dir, "pytorch_model.bin")
# The BERT config and vocabulary are read from a different directory.
bert_config_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "bert_config.json")
vocab_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "vocab.txt")

In [None]:
data = NerData.create(train_path, valid_path, vocab_file, is_cls=True)

In [283]:
print(data.label2idx)

{'<pad>': 0, '[CLS]': 1, 'B_O': 2, 'X': 3, 'B_cR1': 4, 'I_cR1': 5, 'B_cV': 6, 'B_cR2': 7, 'B_R1': 8, 'B_R2': 9, 'I_R2': 10, 'I_cR2': 11, 'I_R1': 12, 'I_cV': 13}


In [284]:
len(data.label2idx)

14

### 9.2. Create model
To build the PyTorch model, instantiate one of the model classes from `modules.models.bert_models` (here `BertBiLSTMAttnNCRFJoint`).

In [285]:
from modules.models.bert_models import BertBiLSTMAttnNCRFJoint, BertBiLSTMAttnNCRF, BertBiLSTMNCRF

In [286]:
# for 0.9
model = BertBiLSTMAttnNCRFJoint.create(
    len(data.label2idx), len(data.cls2idx), bert_config_file, init_checkpoint_pt,
    enc_hidden_dim=1024, rnn_layers=1, num_heads=3, input_dropout=0.5, nbest=12)

build CRF...


In [287]:
model.decoder

AttnNCRFJointDecoder(
  (attn): MultiHeadAttention(
    (attention): _MultiHeadAttention(
      (attention): ScaledDotProductAttention(
        (softmax): Softmax()
        (dropout): Dropout(p=0.5)
      )
    )
    (proj): Linear(in_features=192, out_features=1024, bias=True)
    (dropout): Dropout(p=0.5)
    (layer_norm): LayerNormalization()
  )
  (linear): Linears(
    (linears): ModuleList(
      (0): Linear(in_features=1024, out_features=512, bias=True)
    )
    (output_linear): Linear(in_features=512, out_features=16, bias=True)
  )
  (crf): NCRF()
  (intent_out): PoolingLinearClassifier(
    (dropout): Dropout(p=0.5)
    (linear): Linears(
      (linears): ModuleList(
        (0): Linear(in_features=3072, out_features=512, bias=True)
      )
      (output_linear): Linear(in_features=512, out_features=2, bias=True)
    )
  )
  (intent_loss): CrossEntropyLoss()
)

In [288]:
model.get_n_trainable_params()

8148255

### 9.3. Create learner

To train the PyTorch model, we need to create a `NerLearner` object.

In [289]:
from modules import NerLearner

In [290]:
num_epochs = 50
# sup_labels skips the first entry of id2label (index 0 is '<pad>' in label2idx,
# assuming id2label is ordered by index — confirm).
# t_total = epochs * number of training batches, i.e. the total number of
# training steps (NOTE(review): presumably consumed by the LR schedule — confirm).
learner = NerLearner(model, data,
                     best_model_path="/datadrive/AGRR-2019/big_full.cpt",
                     lr=0.001, clip=5.0, sup_labels=data.id2label[1:],
                     t_total=num_epochs * len(data.train_dl))

### 9.4. Learn your NER model
Call `learner.fit`

In [None]:
# Train for num_epochs epochs, using F1 as the target metric.
# NOTE(review): presumably the best model by this metric is saved to
# best_model_path — confirm in NerLearner.fit.
learner.fit(num_epochs, target_metric='f1')

#### How to overfit :)

In [34]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [None]:
# Prediction data loader over the validation file.
dl = get_bert_data_loader_for_predict(valid_path, learner)

In [None]:
# The joint model yields two results: tag predictions and class predictions.
preds_res, preds_cls = learner.predict(dl)
# Variant kept for a model without the classification output
# (NOTE(review): presumably predict then returns a single value — confirm):
# preds_res = learner.predict(dl)

IOB precision

In [None]:
from modules.train.train import validate_step
# Evaluate on the validation loader. Two reports come back: p_rep (per-tag)
# and p_cls (per-class); both are printed in the following cells.
# The label list skips the first entry of id2label, as in the learner setup.
p_rep, p_cls = validate_step(
    learner.data.valid_dl, learner.model, learner.data.id2label, learner.data.id2label[1:], learner.data.cls2idx)

In [38]:
# Per-tag precision / recall / F1 report returned by validate_step.
print(p_rep)

             precision    recall  f1-score   support

      [CLS]      1.000     1.000     1.000      4142
        B_O      1.000     1.000     1.000     74268
          X      1.000     1.000     1.000    109877
      B_cR1      0.999     0.999     0.999      1382
      I_cR1      1.000     1.000     1.000      2022
       B_cV      1.000     1.000     1.000      1382
      B_cR2      1.000     0.999     1.000      1355
       B_R1      1.000     1.000     1.000      1500
       B_R2      1.000     1.000     1.000      1473
       I_R2      1.000     1.000     1.000      3255
      I_cR2      1.000     1.000     1.000      2395
       I_R1      1.000     1.000     1.000      1769
       I_cV      1.000     1.000     1.000         1

avg / total      1.000     1.000     1.000    204821



In [39]:
# Sentence-class precision / recall / F1 report returned by validate_step.
print(p_cls)

             precision    recall  f1-score   support

          0      1.000     1.000     1.000      2760
          1      1.000     1.000     1.000      1382

avg / total      1.000     1.000     1.000      4142



Tokens report

In [40]:
from sklearn_crfsuite.metrics import flat_classification_report

In [41]:
from modules.utils.utils import bert_labels2tokens, first_choicer

In [42]:
# Map labels back to tokens, once for the predictions and once for the
# labels stored in the dataset, so the two can be compared token by token.
# NOTE(review): first_choicer presumably keeps the label of the first
# sub-token of each word — confirm in modules.utils.utils.
pred_tokens, pred_labels = bert_labels2tokens(dl, preds_res, first_choicer)
true_tokens, true_labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset], first_choicer)

In [43]:
# Both conversions must produce the same tokens, otherwise the label
# sequences below would not be aligned.
assert pred_tokens == true_tokens
# Token-level classification report over the nested (per-sentence) label lists.
tokens_report = flat_classification_report(true_labels, pred_labels)

#### By tokens reports

In [44]:
from sklearn.metrics import accuracy_score, f1_score

# Flatten the per-sentence label lists into one flat list each, because
# accuracy_score / f1_score operate on 1-D label sequences.
# Comprehensions replace the manual append loops and do not leave loop
# variables behind in the notebook namespace.
pl = [label for sentence in pred_labels for label in sentence]
tl = [label for sentence in true_labels for label in sentence]

# Cell output: (token-level accuracy, macro-averaged F1 over the tags).
accuracy_score(tl, pl), f1_score(tl, pl, average="macro")

(0.9999669610801524, 0.9998804619684648)

In [45]:
# Token-level report computed by flat_classification_report.
print(tokens_report)

             precision    recall  f1-score   support

        B_O       1.00      1.00      1.00     74268
       B_R1       1.00      1.00      1.00      1500
       B_R2       1.00      1.00      1.00      1473
      B_cR1       1.00      1.00      1.00      1382
      B_cR2       1.00      1.00      1.00      1355
       B_cV       1.00      1.00      1.00      1382
       I_R1       1.00      1.00      1.00      1769
       I_R2       1.00      1.00      1.00      3255
      I_cR1       1.00      1.00      1.00      2022
      I_cR2       1.00      1.00      1.00      2395
       I_cV       1.00      1.00      1.00         1

avg / total       1.00      1.00      1.00     90802



In [188]:
# Post-process the token-level predictions for the dev set. post_prc is given
# both the raw and the parsed dev csv together with the predicted tokens,
# tags and classes.
# NOTE(review): presumably it maps tags back to character spans in the raw
# text — confirm in post_prc.
dev_pred = post_prc(
    "/datadrive/AGRR-2019/dev.csv",
    "/datadrive/AGRR-2019/dev_parsed.csv",
    pred_tokens, pred_labels, preds_cls)

In [189]:
# Preview the first rows of the post-processed dev predictions.
dev_pred.head()

Unnamed: 0,text,class,cV,cR1,cR2,V,R1,R2
0,"Центральная часть лунной тени, где наблюдается...",0,,,,,,
0,"Я превращу твое сердце в траву, а все, что ты ...",1,2:10,11:22,23:30,54:54,34:52,54:60
0,В данном примере строки сгруппированы по назва...,1,24:37,17:23,38:56,68:68,60:67,68:88
0,Ассоциация намерена занять на мировом рынке до...,0,,,,,,
0,Ты сама все портишь-твои слова хороший мой,0,,,,,,


In [190]:
# Save the dev predictions as a tab-separated file without the frame index.
dev_pred.to_csv("/datadrive/AGRR-2019/dev_pred_overfit.csv", sep="\t", index=False)

### Why is the symbol-wise score below not exactly 1.0? Possibly errors with spaces or similar

In [195]:
!python3 /datadrive/AGRR-2019/agrr_metrics.py -r /datadrive/AGRR-2019/dev.csv /datadrive/AGRR-2019/dev_pred_overfit.csv

Binary classification quality (f1-score): 1.0
Gapping resolution quality (symbol-wise f-measure): 0.9982856495331368


### 9.5. Make predictions with the model trained on full data

In [291]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [None]:
# Prediction data loader over the parsed test file (rebinds `dl`).
dl = get_bert_data_loader_for_predict("/datadrive/AGRR-2019/test_parsed.csv", learner)

In [294]:
# NOTE(review): presumably restores the weights saved at the learner's
# best_model_path — confirm in NerLearner.load_model.
learner.load_model()

In [None]:
# The joint model yields two results: tag predictions and class predictions.
preds_res, preds_cls = learner.predict(dl)
# Variant kept for a model without the classification output
# (NOTE(review): presumably predict then returns a single value — confirm):
# preds_res = learner.predict(dl)

In [297]:
from modules.utils.utils import bert_labels2tokens, first_choicer

In [298]:
# Map the predicted labels back to tokens for the test set.
pred_tokens, pred_labels = bert_labels2tokens(dl, preds_res, first_choicer)
# Same conversion for the labels stored in the dataset; true_tokens and
# true_labels are not used by the remaining cells of this section.
true_tokens, true_labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset], first_choicer)

##### Note: some spans below fail to align with the raw text — probably encoding or whitespace issues :(

In [299]:
# Post-process the token-level predictions for the test set. post_prc is
# given both the raw and the parsed test csv together with the predicted
# tokens, tags and classes.
# The lines printed below come from post_prc.
# NOTE(review): they look like spans that could not be located in the raw
# text (-1) — confirm in post_prc.
test_pred = post_prc(
    "/datadrive/AGRR-2019/test.csv",
    "/datadrive/AGRR-2019/test_parsed.csv",
    pred_tokens, pred_labels, preds_cls)

163 т в  год. т в год. -1
UUU
87 О  log  n   — «довольно дешево». О log n -1
277 1997 г. – 41). 1997 г. -1


In [300]:
# Preview the first rows of the post-processed test predictions.
test_pred.head()

Unnamed: 0,text,class,cV,cR1,cR2,V,R1,R2
0,"В РП и торцевой стенке, расположенной со сторо...",0,,,,,,
0,Диаграммы каротажа сопротивлений помогают в вы...,1,33:41,0:32,42:91,142:142,94:139,142:167
0,При оценке структуры в тенговом эквиваленте 99...,1,89:99,44:88,100:107,123:123 149:149,108:122 133:146,123:130 149:156
0,Ты самый лучший во всём Мироздании и я люблю т...,0,,,,,,
0,Способ позволяет ликвидировать обострение забо...,0,,,,,,


In [301]:
# Save the test predictions as a tab-separated file. index=False keeps the
# file in the same column layout as the dev export (dev_pred_overfit.csv):
# the frame index is not a data column, and writing it would prepend an
# extra unnamed column to every row.
test_pred.to_csv("/datadrive/AGRR-2019/test_pred_full.csv", sep="\t", index=False)