### Conll 2003 evaluation

Data downloaded from [here](https://github.com/kyzhouhzau/BERT-NER/tree/master/NERdata).

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import warnings
import os

warnings.filterwarnings("ignore")

In [65]:
data_path = "/datadrive/conll-2003/"

train_path = data_path + "train.txt"
dev_path = data_path + "dev.txt"
test_path = data_path + "test.txt"

### 0. Prc data for csv format

In [66]:
import codecs


def read_data(input_file):
    """Reads a BIO data."""
    with codecs.open(input_file, "r", encoding="utf-8") as f:
        lines = []
        words = []
        labels = []
        for line in f:
            contends = line.strip()
            word = line.strip().split(' ')[0]
            label = line.strip().split(' ')[-1]
            if contends.startswith("-DOCSTART-"):
                words.append('')
                continue
            
            if len(contends) == 0 and not len(words):
                words.append("")
            
            if len(contends) == 0 and words[-1] == '.':
                l = ' '.join([label for label in labels if len(label) > 0])
                w = ' '.join([word for word in words if len(word) > 0])
                lines.append([l, w])
                words = []
                labels = []
                continue
            words.append(word)
            labels.append(label.replace("-", "_"))
        return lines


In [67]:
train_f = read_data(train_path)
dev_f = read_data(dev_path)
test_f = read_data(test_path)

In [None]:
[l for l in train_f]

In [68]:
len(train_f), len(dev_f), len(test_f)

(6973, 1739, 1559)

In [69]:
train_f[0]

['B_ORG O B_MISC O O O B_MISC O O',
 'EU rejects German call to boycott British lamb .']

In [70]:
import pandas as pd

In [71]:
train_df = pd.DataFrame(train_f, columns=["0", "1"])
train_df.to_csv(data_path + "train.csv", index=False)

In [72]:
valid_df = pd.DataFrame(dev_f, columns=["0", "1"])
valid_df.to_csv(data_path + "valid.csv", index=False)

In [73]:
test_df = pd.DataFrame(test_f, columns=["0", "1"])
test_df.to_csv(data_path + "test.csv", index=False)

### 1. Create data loaders

In [2]:
import os

data_path = "/datadrive/conll-2003/"
train_path = data_path + "train.csv"
valid_path = data_path + "valid.csv"
test_path = data_path + "test.csv"

model_dir = " /datadrive/models/multi_cased_L-12_H-768_A-12/"
init_checkpoint_pt = os.path.join("/datadrive/models/multi_cased_L-12_H-768_A-12/", "pytorch_model.bin")
bert_config_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "bert_config.json")
vocab_file = os.path.join("/datadrive/bert/multi_cased_L-12_H-768_A-12/", "vocab.txt")

In [3]:
import torch
torch.cuda.set_device(1)
torch.cuda.is_available(), torch.cuda.current_device()

(True, 1)

In [4]:
from modules import NerData

INFO:summarizer.preprocessing.cleaner:'pattern' package not found; tag filters are not available for English


In [12]:
data = NerData.create(train_path, valid_path, vocab_file)

In [13]:
len(data.train_dl.dataset), len(data.valid_dl.dataset)

(6973, 1739)

In [14]:
print(data.id2label)

['<pad>', '[CLS]', '[SEP]', 'B_ORG', 'B_O', 'I_O', 'B_MISC', 'B_PER', 'I_PER', 'B_LOC', 'I_LOC', 'I_ORG', 'I_MISC']


In [15]:
sup_labels = ['B_ORG', 'B_MISC', 'B_PER', 'I_PER', 'B_LOC', 'I_LOC', 'I_ORG', 'I_MISC']

In [16]:
max([len(f.labels_ids) for f in data.train_dl.dataset])

424

### 2. Create model

In [17]:
from modules.models.models import BertBiLSTMAttnCRF

In [18]:
model = BertBiLSTMAttnCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)

In [19]:
model.get_n_trainable_params()

1151739

#### TODO: fix bug with len

### 3. Create Learner

In [20]:
from modules import NerLearner

In [21]:
learner = NerLearner(model, data,
                     best_model_path="/datadrive/models/conll-2003/bilstm_attn_cased.cpt",
                     base_lr=0.0001, lr_max=0.005, clip=5.0, use_lr_scheduler=True, sup_labels=sup_labels)

INFO:root:Use lr OneCycleScheduler...


### 4. Start learning

In [None]:
learner.fit(25, target_metric='prec')

### 5. Evaluate dev set

In [23]:
from modules.data.data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)

In [24]:
learner.load_model()

In [25]:
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

IOB precision

In [38]:
from modules.train.train import validate_step
print(validate_step(learner.data.valid_dl, learner.model, learner.data.id2label, learner.sup_labels))

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

              precision    recall  f1-score   support

       B_ORG      0.917     0.924     0.920      1282
      B_MISC      0.925     0.864     0.894       905
       B_PER      0.971     0.976     0.973      1686
       I_PER      0.981     0.970     0.976      3488
       B_LOC      0.966     0.950     0.958      1669
       I_LOC      0.957     0.909     0.932      1913
       I_ORG      0.922     0.888     0.905      2129
      I_MISC      0.920     0.675     0.779      1061

   micro avg      0.953     0.915     0.933     14133
   macro avg      0.945     0.894     0.917     14133
weighted avg      0.952     0.915     0.932     14133



Span precision

In [44]:
from modules.utils.plot_metrics import get_bert_span_report
clf_report = get_bert_span_report(dl, preds, [])
print(clf_report)

              precision    recall  f1-score   support

         LOC      0.892     0.877     0.885      1669
         ORG      0.828     0.832     0.830      1282
           O      0.988     0.990     0.989     41846
        MISC      0.899     0.840     0.869       905
         PER      0.934     0.938     0.936      1686

   micro avg      0.977     0.977     0.977     47388
   macro avg      0.908     0.895     0.902     47388
weighted avg      0.977     0.977     0.977     47388



### 6. Evaluate test set

In [45]:
from modules.data.data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "test.csv", learner)

In [46]:
preds = learner.predict(dl)

HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

IOB precision

In [47]:
from modules.train.train import validate_step
print(validate_step(learner.data.valid_dl, learner.model, learner.data.id2label, learner.sup_labels))

HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

              precision    recall  f1-score   support

       B_ORG      0.917     0.924     0.920      1282
      B_MISC      0.925     0.864     0.894       905
       B_PER      0.971     0.976     0.973      1686
       I_PER      0.981     0.970     0.976      3488
       B_LOC      0.966     0.950     0.958      1669
       I_LOC      0.957     0.909     0.932      1913
       I_ORG      0.922     0.888     0.905      2129
      I_MISC      0.920     0.675     0.779      1061

   micro avg      0.953     0.915     0.933     14133
   macro avg      0.945     0.894     0.917     14133
weighted avg      0.952     0.915     0.932     14133



Span precision

In [59]:
from modules.utils.plot_metrics import get_bert_span_report
clf_report = get_bert_span_report(dl, preds, [])
print(clf_report)

              precision    recall  f1-score   support

         LOC      0.864     0.851     0.858      1570
         ORG      0.714     0.721     0.717      1533
           O      0.981     0.983     0.982     37683
        MISC      0.820     0.753     0.785       688
         PER      0.911     0.901     0.906      1566

   micro avg      0.963     0.963     0.963     43040
   macro avg      0.858     0.842     0.850     43040
weighted avg      0.962     0.963     0.962     43040

