### Conll 2003 evaluation

Data downloaded from [here](https://github.com/kyzhouhzau/BERT-NER/tree/master/NERdata).

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import warnings
import os
import sys

sys.path.append("../")

warnings.filterwarnings("ignore")

In [65]:
data_path = "/datadrive/conll-2003/"

train_path = data_path + "train.txt"
dev_path = data_path + "dev.txt"
test_path = data_path + "test.txt"

### 0. Prc data for csv format

In [66]:
import codecs


def read_data(input_file):
    """Reads a BIO data."""
    with codecs.open(input_file, "r", encoding="utf-8") as f:
        lines = []
        words = []
        labels = []
        for line in f:
            contends = line.strip()
            word = line.strip().split(' ')[0]
            label = line.strip().split(' ')[-1]
            if contends.startswith("-DOCSTART-"):
                words.append('')
                continue
            
            if len(contends) == 0 and not len(words):
                words.append("")
            
            if len(contends) == 0 and words[-1] == '.':
                l = ' '.join([label for label in labels if len(label) > 0])
                w = ' '.join([word for word in words if len(word) > 0])
                lines.append([l, w])
                words = []
                labels = []
                continue
            words.append(word)
            labels.append(label.replace("-", "_"))
        return lines


In [67]:
train_f = read_data(train_path)
dev_f = read_data(dev_path)
test_f = read_data(test_path)

In [None]:
[l for l in train_f]

In [68]:
len(train_f), len(dev_f), len(test_f)

(6973, 1739, 1559)

In [69]:
train_f[0]

['B_ORG O B_MISC O O O B_MISC O O',
 'EU rejects German call to boycott British lamb .']

In [70]:
import pandas as pd

In [71]:
train_df = pd.DataFrame(train_f, columns=["0", "1"])
train_df.to_csv(data_path + "train.csv", index=False)

In [72]:
valid_df = pd.DataFrame(dev_f, columns=["0", "1"])
valid_df.to_csv(data_path + "valid.csv", index=False)

In [73]:
test_df = pd.DataFrame(test_f, columns=["0", "1"])
test_df.to_csv(data_path + "test.csv", index=False)

### 1. Create data loaders

In [2]:
import os

data_path = "/datadrive/conll-2003/"
train_path = data_path + "train.csv"
valid_path = data_path + "valid.csv"
test_path = data_path + "test.csv"

model_dir = " /datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join("/datadrive/models/multi_cased_L-12_H-768_A-12/", "pytorch_model.bin")
bert_config_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "bert_config.json")
vocab_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "vocab.txt")

In [3]:
import torch
torch.cuda.set_device(0)
torch.cuda.is_available(), torch.cuda.current_device()

(True, 0)

In [4]:
from modules import BertNerData as NerData

INFO:summarizer.preprocessing.cleaner:'pattern' package not found; tag filters are not available for English


In [5]:
data = NerData.create(train_path, valid_path, vocab_file)

In [6]:
len(data.train_dl.dataset), len(data.valid_dl.dataset)

(6973, 1739)

In [7]:
print(data.id2label)

['<pad>', '[CLS]', '[SEP]', 'B_ORG', 'B_O', 'I_O', 'B_MISC', 'B_PER', 'I_PER', 'B_LOC', 'I_LOC', 'I_ORG', 'I_MISC']


In [8]:
sup_labels = ['B_ORG', 'B_MISC', 'B_PER', 'I_PER', 'B_LOC', 'I_LOC', 'I_ORG', 'I_MISC']

In [9]:
max([len(f.labels_ids) for f in data.train_dl.dataset])

424

### 2. Create model

In [10]:
from modules.models.bert_models import BertBiLSTMAttnCRF

In [18]:
model = BertBiLSTMAttnCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)

In [19]:
model.get_n_trainable_params()

1151739

#### TODO: fix bug with len

### 3. Create Learner

In [20]:
from modules import NerLearner

In [21]:
learner = NerLearner(model, data,
                     best_model_path="/datadrive/models/conll-2003/bilstm_attn_cased.cpt",
                     base_lr=0.0001, lr_max=0.005, clip=5.0, use_lr_scheduler=True, sup_labels=sup_labels)

INFO:root:Use lr OneCycleScheduler...


### 4. Start learning

In [None]:
learner.fit(25, target_metric='prec')

### 5. Evaluate dev set

In [23]:
from modules.data.bert_data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)

In [24]:
learner.load_model()

In [25]:
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

IOB precision

In [38]:
from modules.train.train import validate_step
print(validate_step(learner.data.valid_dl, learner.model, learner.data.id2label, learner.sup_labels))

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

              precision    recall  f1-score   support

       B_ORG      0.917     0.924     0.920      1282
      B_MISC      0.925     0.864     0.894       905
       B_PER      0.971     0.976     0.973      1686
       I_PER      0.981     0.970     0.976      3488
       B_LOC      0.966     0.950     0.958      1669
       I_LOC      0.957     0.909     0.932      1913
       I_ORG      0.922     0.888     0.905      2129
      I_MISC      0.920     0.675     0.779      1061

   micro avg      0.953     0.915     0.933     14133
   macro avg      0.945     0.894     0.917     14133
weighted avg      0.952     0.915     0.932     14133



Span precision

In [44]:
from modules.utils.plot_metrics import get_bert_span_report
clf_report = get_bert_span_report(dl, preds, [])
print(clf_report)

              precision    recall  f1-score   support

         LOC      0.892     0.877     0.885      1669
         ORG      0.828     0.832     0.830      1282
           O      0.988     0.990     0.989     41846
        MISC      0.899     0.840     0.869       905
         PER      0.934     0.938     0.936      1686

   micro avg      0.977     0.977     0.977     47388
   macro avg      0.908     0.895     0.902     47388
weighted avg      0.977     0.977     0.977     47388



### 6. Evaluate test set

In [45]:
from modules.data.bert_data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "test.csv", learner)

In [46]:
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

IOB precision

In [47]:
from modules.train.train import validate_step
print(validate_step(dl, learner.model, learner.data.id2label, learner.sup_labels))

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

              precision    recall  f1-score   support

       B_ORG      0.917     0.924     0.920      1282
      B_MISC      0.925     0.864     0.894       905
       B_PER      0.971     0.976     0.973      1686
       I_PER      0.981     0.970     0.976      3488
       B_LOC      0.966     0.950     0.958      1669
       I_LOC      0.957     0.909     0.932      1913
       I_ORG      0.922     0.888     0.905      2129
      I_MISC      0.920     0.675     0.779      1061

   micro avg      0.953     0.915     0.933     14133
   macro avg      0.945     0.894     0.917     14133
weighted avg      0.952     0.915     0.932     14133



Span precision

In [59]:
from modules.utils.plot_metrics import get_bert_span_report
clf_report = get_bert_span_report(dl, preds, [])
print(clf_report)

              precision    recall  f1-score   support

         LOC      0.864     0.851     0.858      1570
         ORG      0.714     0.721     0.717      1533
           O      0.981     0.983     0.982     37683
        MISC      0.820     0.753     0.785       688
         PER      0.911     0.901     0.906      1566

   micro avg      0.963     0.963     0.963     43040
   macro avg      0.858     0.842     0.850     43040
weighted avg      0.962     0.963     0.962     43040



### 7. Get mean and stdv on 10 runs

In [None]:
from modules.utils.plot_metrics import *
from modules import NerLearner


num_runs = 10
best_reports = []
for i in range(num_runs):
    model = BertBiLSTMAttnCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)
    best_model_path = "/datadrive/models/conll-2003/exp_{}_attn_cased.cpt".format(i)
    learner = NerLearner(model, data,
                         best_model_path=best_model_path, verbose=False,
                         base_lr=0.0001, lr_max=0.001, clip=5.0, use_lr_scheduler=True, sup_labels=data.id2label[5:])
    learner.fit(100, target_metric='prec')
    idx, res = get_mean_max_metric(learner.history, "f1", True)
    best_reports.append(learner.history[idx])

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009988594480391467
INFO:root:
epoch 1, average train epoch loss=8.9381



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009954377770788668
INFO:root:
epoch 2, average train epoch loss=2.6673



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009897506632721709
INFO:root:
epoch 3, average train epoch loss=1.5642



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009818136929810007
INFO:root:
epoch 4, average train epoch loss=0.95785



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009829645918699543
INFO:root:
epoch 5, average train epoch loss=0.67193



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009863965743218892
INFO:root:
epoch 6, average train epoch loss=0.45879



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009920732813267874
INFO:root:
epoch 7, average train epoch loss=0.34029



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.00099997931621759
INFO:root:
epoch 8, average train epoch loss=0.29414



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009899006236929822
INFO:root:
epoch 9, average train epoch loss=0.2418



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009775817476504662
INFO:root:
epoch 10, average train epoch loss=0.19245



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009630791719937361
INFO:root:
epoch 11, average train epoch loss=0.19405



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009464079206871408
INFO:root:
epoch 12, average train epoch loss=0.13536



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009564965809716738
INFO:root:
epoch 13, average train epoch loss=0.13186



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009688253964396686
INFO:root:
epoch 14, average train epoch loss=0.12742



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009833179407998218
INFO:root:
epoch 15, average train epoch loss=0.11116



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009999593729379682
INFO:root:
epoch 16, average train epoch loss=0.11768



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009812650576713183
INFO:root:
epoch 17, average train epoch loss=0.11684



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009603700114833495
INFO:root:
epoch 18, average train epoch loss=0.11418



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009373700593990951
INFO:root:
epoch 19, average train epoch loss=0.10928



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009122796832781196
INFO:root:
epoch 20, average train epoch loss=0.070534



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009309836066412781
INFO:root:
epoch 21, average train epoch loss=0.087153



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009518882336136918
INFO:root:
epoch 22, average train epoch loss=0.075947



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009748785163573578
INFO:root:
epoch 23, average train epoch loss=0.10177



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009999401492651726
INFO:root:
epoch 24, average train epoch loss=0.10823



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009729410859499124
INFO:root:
epoch 25, average train epoch loss=0.080489



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009437793207548021
INFO:root:
epoch 26, average train epoch loss=0.063801



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.000912588600293419
INFO:root:
epoch 27, average train epoch loss=0.058319



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008793828838811322
INFO:root:
epoch 28, average train epoch loss=0.059022



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009063912086071017
INFO:root:
epoch 29, average train epoch loss=0.045475



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.00093556220888635
INFO:root:
epoch 30, average train epoch loss=0.059438



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009667436089027941
INFO:root:
epoch 31, average train epoch loss=0.059907



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009999216192338662
INFO:root:
epoch 32, average train epoch loss=0.051479



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009649174653736884
INFO:root:
epoch 33, average train epoch loss=0.056105



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009277872664856822
INFO:root:
epoch 34, average train epoch loss=0.053368



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008887013224597101
INFO:root:
epoch 35, average train epoch loss=0.064154



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008476730889217369
INFO:root:
epoch 36, average train epoch loss=0.040207



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008826861700161747
INFO:root:
epoch 37, average train epoch loss=0.050739



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009198252707618027
INFO:root:
epoch 38, average train epoch loss=0.028476



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009589022306493978
INFO:root:
epoch 39, average train epoch loss=0.079055



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009999037578156096
INFO:root:
epoch 40, average train epoch loss=0.054015



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009571833584706517
INFO:root:
epoch 41, average train epoch loss=0.040161



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009123722482727018
INFO:root:
epoch 42, average train epoch loss=0.048201



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008656759614480127
INFO:root:
epoch 43, average train epoch loss=0.041866



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008171074681076045
INFO:root:
epoch 44, average train epoch loss=0.050747



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008598364725683528
INFO:root:
epoch 45, average train epoch loss=0.025911



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009046561634211117
INFO:root:
epoch 46, average train epoch loss=0.028397



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009513437902791463
INFO:root:
epoch 47, average train epoch loss=0.031361



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.000999886540885058
INFO:root:
epoch 48, average train epoch loss=0.034734



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009497283188137605
INFO:root:
epoch 49, average train epoch loss=0.02559



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008975134451128505
INFO:root:
epoch 50, average train epoch loss=0.029419



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008434814169959513
INFO:root:
epoch 51, average train epoch loss=0.035056



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007876447365778154
INFO:root:
epoch 52, average train epoch loss=0.025335



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008378112532693288
INFO:root:
epoch 53, average train epoch loss=0.038271



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008900343980120909
INFO:root:
epoch 54, average train epoch loss=0.026468



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009440580786370757
INFO:root:
epoch 55, average train epoch loss=0.019331



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009998699451873722
INFO:root:
epoch 56, average train epoch loss=0.027795



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.000942542276910978
INFO:root:
epoch 57, average train epoch loss=0.025859



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008831907872805493
INFO:root:
epoch 58, average train epoch loss=0.020375



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008220877110217311
INFO:root:
epoch 59, average train epoch loss=0.017998



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007592450991395642
INFO:root:
epoch 60, average train epoch loss=0.05226



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008165807627441099
INFO:root:
epoch 61, average train epoch loss=0.028018



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008759402249750965
INFO:root:
epoch 62, average train epoch loss=0.016244



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009370352549418169
INFO:root:
epoch 63, average train epoch loss=0.010752



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.00099985394830681
INFO:root:
epoch 64, average train epoch loss=0.023369



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009356155266044499
INFO:root:
epoch 65, average train epoch loss=0.18327



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008693849292195535
INFO:root:
epoch 66, average train epoch loss=0.15579



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008014659471328631
INFO:root:
epoch 67, average train epoch loss=0.02635



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007318701965169524
INFO:root:
epoch 68, average train epoch loss=0.014648



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007961163250546575
INFO:root:
epoch 69, average train epoch loss=0.013891



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008623546073673774
INFO:root:
epoch 70, average train epoch loss=0.01578



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009302658334936922
INFO:root:
epoch 71, average train epoch loss=0.015194



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009998385286364502
INFO:root:
epoch 72, average train epoch loss=0.02194



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009289387119604383
INFO:root:
epoch 73, average train epoch loss=0.018623



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008560772234129887
INFO:root:
epoch 74, average train epoch loss=0.026317



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007815882715959252
INFO:root:
epoch 75, average train epoch loss=0.019105



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007054830535392704
INFO:root:
epoch 76, average train epoch loss=0.009021



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007763902989674113
INFO:root:
epoch 77, average train epoch loss=0.013462



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008492591951499509
INFO:root:
epoch 78, average train epoch loss=0.014465



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009237406708624169
INFO:root:
epoch 79, average train epoch loss=0.015538



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.000999823665349005
INFO:root:
epoch 80, average train epoch loss=0.024188



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009225028146322963
INFO:root:
epoch 81, average train epoch loss=0.02295



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008432496951962235
INFO:root:
epoch 82, average train epoch loss=0.033862



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007624278357146333
INFO:root:
epoch 83, average train epoch loss=0.016722



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0006800480291987818
INFO:root:
epoch 84, average train epoch loss=0.14866



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007573760406183898
INFO:root:
epoch 85, average train epoch loss=0.0059774



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008366363004022818
INFO:root:
epoch 86, average train epoch loss=0.008101



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009174509535371078
INFO:root:
epoch 87, average train epoch loss=0.017693



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009998093383686922
INFO:root:
epoch 88, average train epoch loss=0.018855



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0009162991416794254
INFO:root:
epoch 89, average train epoch loss=0.019623



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0008308850184785628
INFO:root:
epoch 90, average train epoch loss=0.022019



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

INFO:root:
lr after epoch: 0.0007439587595654187
INFO:root:
epoch 91, average train epoch loss=0.017575



HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

HBox(children=(IntProgress(value=0, max=436), HTML(value='')))

In [17]:
import numpy as np

#### f1

Mean and std

In [18]:
np.mean([get_mean_max_metric([r]) for r in best_reports]), np.round(np.std([get_mean_max_metric([r]) for r in best_reports]), 3)

(0.949, 0.002)

Best

In [19]:
get_mean_max_metric(best_reports)

0.951

#### precision

Mean and std

In [20]:
np.mean([get_mean_max_metric([r], "prec") for r in best_reports]), np.round(np.std([get_mean_max_metric([r], "prec") for r in best_reports]), 3)

(0.9558333333333332, 0.002)

Best

In [21]:
get_mean_max_metric(best_reports, "prec")

0.959

#### Test set

In [24]:
idx = np.array([get_mean_max_metric([r]) for r in best_reports]).argmax()

In [25]:
learner.load_model("/datadrive/models/conll-2003/exp_{}_attn_cased.cpt".format(idx))

In [27]:
from modules.data.bert_data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "test.csv", learner)

In [37]:
from modules.train.train import validate_step
print(validate_step(dl, learner.model, learner.data.id2label, learner.sup_labels))

HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

              precision    recall  f1-score   support

         I_O      0.968     0.972     0.970     10257
      B_MISC      0.873     0.817     0.844       688
       B_PER      0.961     0.963     0.962      1566
       I_PER      0.970     0.971     0.970      3347
       B_LOC      0.932     0.926     0.929      1570
       I_LOC      0.904     0.870     0.887      1444
       I_ORG      0.888     0.922     0.905      2546
      I_MISC      0.746     0.611     0.672       839

   micro avg      0.942     0.937     0.940     22257
   macro avg      0.905     0.882     0.892     22257
weighted avg      0.941     0.937     0.939     22257

