### FactRuEval NER evaluation with the `BertBiLSTMAttnNMT` (NMT-decoder) model

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
import sys

sys.path.append("../")

warnings.filterwarnings("ignore")

In [2]:
import os


# FactRuEval train/validation CSV files.
data_path = "/home/lis/ner/ulmfit/data/factrueval/"
train_path = os.path.join(data_path, "train_with_pos.csv")
valid_path = os.path.join(data_path, "valid_with_pos.csv")

# Directory holding the PyTorch BERT checkpoint.
# No leading whitespace: " /datadrive/..." would be a relative path, not the
# intended absolute one.
model_dir = "/datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join(model_dir, "pytorch_model.bin")

# The BERT config and vocabulary live in a different directory from the checkpoint.
bert_dir = "/datadrive/bert/multi_cased_L-12_H-768_A-12/"
bert_config_file = os.path.join(bert_dir, "bert_config.json")
vocab_file = os.path.join(bert_dir, "vocab.txt")

In [3]:
import random

import torch

# Seed Python's and PyTorch's RNGs so training runs are reproducible.
# Other libraries used by `modules` (e.g. numpy) may need seeding too.
SEED = 42
# Index of the GPU to run on; adjust for machines with a different GPU layout.
CUDA_DEVICE = 1

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.set_device(CUDA_DEVICE)
torch.cuda.manual_seed_all(SEED)
torch.cuda.is_available(), torch.cuda.current_device()

(True, 1)

### 1. Create dataloaders

In [4]:
from modules import BertNerData as NerData

INFO:summarizer.preprocessing.cleaner:'pattern' package not found; tag filters are not available for English


In [5]:
# Build the data bundle from the train/valid CSV files and the BERT vocab file.
# Later cells use data.label2idx, data.id2label, data.train_dl and data.valid_dl.
data = NerData.create(train_path, valid_path, vocab_file)

For FactRuEval we use the following set of labels:

In [6]:
# Label vocabulary: maps each tag string to its integer index.
# A bare last expression uses rich display instead of print().
data.label2idx

{'<pad>': 0, '[CLS]': 1, '[SEP]': 2, 'B_O': 3, 'I_O': 4, 'B_ORG': 5, 'I_ORG': 6, 'B_LOC': 7, 'I_LOC': 8, 'B_PER': 9, 'I_PER': 10}


### 2. Create model
To create the PyTorch model, we build a `BertBiLSTMAttnNMT` object using its `create` factory.

In [8]:
from modules.models.bert_models import BertBiLSTMAttnNMT

In [10]:
# Sizes of the encoder/decoder layers stacked on top of BERT.
ENC_HIDDEN_DIM = 128
DEC_HIDDEN_DIM = 128
DEC_EMBEDDING_DIM = 16
n_labels = len(data.label2idx)

model = BertBiLSTMAttnNMT.create(
    n_labels, bert_config_file, init_checkpoint_pt,
    enc_hidden_dim=ENC_HIDDEN_DIM,
    dec_hidden_dim=DEC_HIDDEN_DIM,
    dec_embedding_dim=DEC_EMBEDDING_DIM,
)

In [11]:
# Inspect the decoder submodule; its repr is displayed as the cell output.
model.decoder

NMTDecoder(
  (embedding): Embedding(11, 16)
  (lstm): LSTM(272, 128, batch_first=True)
  (attn): Linear(in_features=128, out_features=128, bias=True)
  (slot_out): Linear(in_features=256, out_features=11, bias=True)
  (loss): CrossEntropyLoss()
)

In [12]:
# Number of trainable parameters. The value is far below BERT's size, so BERT
# weights are presumably frozen. TODO: confirm.
model.get_n_trainable_params()

652360

### 3. Create learner

To train the PyTorch model, we need to create a `NerLearner` object.

In [13]:
from modules import NerLearner

In [14]:
num_epochs = 100

# Labels excluded from the reported metrics: padding, BERT special tokens and
# the outside-of-entity tags. Selecting by name instead of slicing
# `id2label[5:]` keeps the metric labels correct if the label order changes.
NON_ENTITY_LABELS = {"<pad>", "[CLS]", "[SEP]", "B_O", "I_O"}
sup_labels = [label for label in data.id2label if label not in NON_ENTITY_LABELS]

learner = NerLearner(model, data,
                     best_model_path="/datadrive/models/factrueval/final_attn_cased_nmt.cpt",
                     lr=0.01, clip=1.0, sup_labels=sup_labels,
                     t_total=num_epochs * len(data.train_dl))

INFO:root:Don't use lr scheduler...


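### 4. Train the NER model
Call `learner.fit`. Note: this cell was not executed in the saved notebook (`In [None]`), so the evaluation below uses a previously saved checkpoint.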

In [None]:
# NOTE(review): this cell was not executed in the saved notebook (In [None]);
# the evaluation cells below presumably load a checkpoint saved by an earlier run.
# target_metric='prec' presumably selects the best checkpoint by precision
# rather than F1. Confirm this is intended.
learner.fit(num_epochs, target_metric='prec')

### 5. Evaluate
Create a new prediction data loader from the existing validation file.

In [16]:
from modules.data.bert_data import get_bert_data_loader_for_predict

In [17]:
# Reuse `valid_path` from the config cell instead of re-concatenating the
# directory and file name, so the evaluated file cannot drift from the
# configured validation file.
dl = get_bert_data_loader_for_predict(valid_path, learner)

In [18]:
# Load saved weights back into the learner (presumably from best_model_path). TODO: confirm.
learner.load_model()

In [19]:
# Run inference over the prediction loader; `preds` feeds the span report below.
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=26), HTML(value='')))

Token-level (IOB tag) metrics

In [20]:
from modules.train.train import validate_step

# Token-level (IOB tag) report on the validation loader for learner.sup_labels.
token_level_report = validate_step(
    learner.data.valid_dl,
    learner.model,
    learner.data.id2label,
    learner.sup_labels,
)
print(token_level_report)

HBox(children=(IntProgress(value=0, max=26), HTML(value='')))

              precision    recall  f1-score   support

       B_ORG      0.870     0.776     0.820       259
       I_ORG      0.937     0.775     0.848      1000
       B_LOC      0.914     0.880     0.897       192
       I_LOC      0.894     0.835     0.863       303
       B_PER      0.958     0.979     0.968       188
       I_PER      0.974     0.978     0.976       649

   micro avg      0.935     0.856     0.894      2591
   macro avg      0.925     0.871     0.896      2591
weighted avg      0.934     0.856     0.892      2591



Span-level (entity) metrics

In [21]:
from modules.utils.plot_metrics import get_bert_span_report
# Entity-span-level classification report built from the loader `dl` and `preds`.
clf_report = get_bert_span_report(dl, preds)
print(clf_report)

              precision    recall  f1-score   support

         LOC      0.840     0.823     0.832       192
         PER      0.870     0.888     0.879       188
         ORG      0.770     0.726     0.748       259

   micro avg      0.822     0.803     0.812       639
   macro avg      0.827     0.812     0.819       639
weighted avg      0.821     0.803     0.811       639

