In [1]:
#!/usr/bin/python
# -*- coding: UTF-8 -*-

'''
@version:0.1
@author:Cai Qingpeng
@file: test.py
@time: 2020/3/18 7:30 PM
'''



import os
import numpy as np
from conlleval import evaluate
from flair.data import Corpus
from flair.datasets import ColumnCorpus

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# define columns
columns = {0: 'text', 1: '_', 2: '_', 3: 'ner'}

# this is the folder in which train, test and dev files reside
data_folder = './data'  # /path/to/data/folder

# init a corpus using column format, data folder and the names of the train, dev and test files
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file='train.txt',
                              test_file='test.txt',
                              dev_file='valid.txt')

print(corpus)

# 2. what tag do we want to predict?
tag_type = 'ner'

# 3. make the tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary)


2020-03-30 15:32:39,196 Reading data from data
2020-03-30 15:32:39,197 Train: data/train.txt
2020-03-30 15:32:39,198 Dev: data/valid.txt
2020-03-30 15:32:39,198 Test: data/test.txt
Corpus: 14041 train + 3250 dev + 3453 test sentences
Dictionary with 12 tags: <unk>, O, B-ORG, B-MISC, B-PER, I-PER, B-LOC, I-ORG, I-MISC, I-LOC, <START>, <STOP>


In [2]:
real = []

for sentence in corpus.test:
    for token in sentence.tokens:
        real.append(token.get_tag("ner").value)
str(real)

"['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'B-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '

In [3]:
os.listdir("./log/")

['pool_flair_f_20200330002549',
 'bert',
 'mix_ebx_20200324145239',
 'mix_ebx_20200324144951',
 'mix_ebx_20200324145746',
 'elmo_20200323220158',
 'flair_20200320202343',
 'elmo',
 'xlnet_20200320131431',
 'bert_20200318225447']

In [6]:
from flair.models import SequenceTagger
labels = tag_dictionary.get_items()
print(labels)

bert_model = SequenceTagger.load('./log/bert_20200318225447/best-model.pt')
elmo_model = SequenceTagger.load('./log/elmo/best-model.pt')
xlnet_model = SequenceTagger.load('./log/xlnet_20200320131431/best-model.pt')
flair_model = SequenceTagger.load('./log/flair_20200320202343/best-model.pt')


['<unk>', 'O', 'B-ORG', 'B-MISC', 'B-PER', 'I-PER', 'B-LOC', 'I-ORG', 'I-MISC', 'I-LOC', '<START>', '<STOP>']
2020-03-30 15:39:37,885 loading file ./log/bert_20200318225447/best-model.pt
2020-03-30 15:39:38,340 loading file ./log/elmo/best-model.pt
2020-03-30 15:39:38,462 loading file ./log/xlnet_20200320131431/best-model.pt
2020-03-30 15:39:51,690 loading file ./log/flair_20200320202343/best-model.pt


In [8]:
model_dict = {
    "bert": bert_model,
    "elmo": elmo_model,
    "xlnet": xlnet_model,
    "flair": flair_model,
}
model_dict

{'bert': SequenceTagger(
   (embeddings): BertEmbeddings(
     (model): BertModel(
       (embeddings): BertEmbeddings(
         (word_embeddings): Embedding(30522, 768, padding_idx=0)
         (position_embeddings): Embedding(512, 768)
         (token_type_embeddings): Embedding(2, 768)
         (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
         (dropout): Dropout(p=0.1, inplace=False)
       )
       (encoder): BertEncoder(
         (layer): ModuleList(
           (0): BertLayer(
             (attention): BertAttention(
               (self): BertSelfAttention(
                 (query): Linear(in_features=768, out_features=768, bias=True)
                 (key): Linear(in_features=768, out_features=768, bias=True)
                 (value): Linear(in_features=768, out_features=768, bias=True)
                 (dropout): Dropout(p=0.1, inplace=False)
               )
               (output): BertSelfOutput(
                 (dense): Linear(in_features=768, out_

In [9]:
len(corpus.test)

3453

In [10]:
def model_prediction(model):
    model_pred = []
    for sentence in corpus.test:
        model.predict(sentence)
        for token in sentence.tokens:
            model_pred.append(token.get_tag("ner").value)
    return model_pred


In [11]:
print("****** bert prediction ******")
bert_pred = model_prediction(bert_model)
print(evaluate(real,bert_pred))
print("****** elmo prediction ******")
elmo_pred = model_prediction(elmo_model)
print(evaluate(real,elmo_pred))
print("****** xlnet prediction ******")
xlnet_pred = model_prediction(xlnet_model)
print(evaluate(real,xlnet_pred))
print("****** flair prediction ******")
flair_pred = model_prediction(flair_model)
print(evaluate(real,flair_pred))

****** bert prediction ******
processed 46435 tokens with 5648 phrases; found: 5678 phrases; correct: 5121.
accuracy:  91.86%; (non-O)
accuracy:  97.94%; precision:  90.19%; recall:  90.67%; FB1:  90.43
              LOC: precision:  91.47%; recall:  91.97%; FB1:  91.72  1677
             MISC: precision:  81.10%; recall:  79.49%; FB1:  80.29  688
              ORG: precision:  86.41%; recall:  88.44%; FB1:  87.41  1700
              PER: precision:  96.71%; recall:  96.47%; FB1:  96.59  1613
(90.19020781965482, 90.66926345609065, 90.4291011831185)
****** elmo prediction ******
processed 46435 tokens with 5648 phrases; found: 5684 phrases; correct: 5116.
accuracy:  91.75%; (non-O)
accuracy:  97.93%; precision:  90.01%; recall:  90.58%; FB1:  90.29
              LOC: precision:  92.21%; recall:  92.27%; FB1:  92.24  1669
             MISC: precision:  77.12%; recall:  80.20%; FB1:  78.63  730
              ORG: precision:  88.33%; recall:  88.44%; FB1:  88.39  1663
              PER: pr

In [17]:
pool_flair_model = SequenceTagger.load('./log/pool_flair_f_20200330002549/best-model.pt')
print("****** pool_flair prediction ******")
pool_flair_pred = model_prediction(pool_flair_model)
print(evaluate(real,pool_flair_pred))

2020-03-30 16:15:19,251 loading file ./log/pool_flair_f_20200330002549/best-model.pt
****** pool_flair prediction ******


RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'other' in call to _th_min

In [18]:
def cal_proba_score(model,sentence):
    model.predict(sentence,all_tag_prob=True)
    score = []
    for t_id, token in enumerate(sentence.tokens):
        # print(token.get_tag("ner").value)
        # print(token.get_tags_proba_dist("ner"))
        for index,item in enumerate(token.get_tags_proba_dist("ner")):
            # print(item.value)
            # print(item.score)
            score.append(item.score)
    return score

def get_mix_preds(models,method="avg"):
    mix_pred = []
    for sentence in corpus.test:
        scores = []
        for model in models:
            model_score = cal_proba_score(model,sentence)
            scores.append(model_score) 
        
        if method == "avg":
            scores = np.mean(scores,axis=0)
        elif method == "confidence":
            scores = np.max(scores,axis=0)
        else:
            raise NotImplementedError(method)
        result = np.reshape(scores,(len(sentence),len(labels)))

        id_result = np.argmax(result,axis=1)
        la_result = [tag_dictionary.get_item_for_index(i) for i in id_result]

        mix_pred.extend(la_result)
    return mix_pred

In [19]:
print("****** avg prediction ******")
avg_pred = get_mix_preds([bert_model,elmo_model],"avg")
print(evaluate(real,avg_pred))
print("****** confidence prediction ******")
confidence_pred = get_mix_preds([bert_model,elmo_model],"confidence")
print(evaluate(real,confidence_pred))

****** avg prediction ******
processed 46435 tokens with 5648 phrases; found: 5683 phrases; correct: 5196.
accuracy:  92.95%; (non-O)
accuracy:  98.38%; precision:  91.43%; recall:  92.00%; FB1:  91.71
              LOC: precision:  92.73%; recall:  92.51%; FB1:  92.62  1664
             MISC: precision:  80.75%; recall:  82.48%; FB1:  81.61  717
              ORG: precision:  89.17%; recall:  90.25%; FB1:  89.71  1681
              PER: precision:  97.16%; recall:  97.40%; FB1:  97.28  1621
(91.43058243885271, 91.9971671388102, 91.71299973523959)
****** confidence prediction ******
processed 46435 tokens with 5648 phrases; found: 5681 phrases; correct: 5194.
accuracy:  92.89%; (non-O)
accuracy:  98.36%; precision:  91.43%; recall:  91.96%; FB1:  91.69
              LOC: precision:  92.62%; recall:  92.57%; FB1:  92.59  1667
             MISC: precision:  81.61%; recall:  82.19%; FB1:  81.90  707
              ORG: precision:  89.01%; recall:  90.25%; FB1:  89.63  1684
              PE

In [20]:
print("****** avg prediction ******")
avg_pred = get_mix_preds([bert_model,xlnet_model],"avg")
print(evaluate(real,avg_pred))
print("****** confidence prediction ******")
confidence_pred = get_mix_preds([bert_model,xlnet_model],"confidence")
print(evaluate(real,confidence_pred))

****** avg prediction ******
processed 46435 tokens with 5648 phrases; found: 5706 phrases; correct: 5141.
accuracy:  92.28%; (non-O)
accuracy:  98.25%; precision:  90.10%; recall:  91.02%; FB1:  90.56
              LOC: precision:  91.43%; recall:  92.75%; FB1:  92.08  1692
             MISC: precision:  78.37%; recall:  81.05%; FB1:  79.69  726
              ORG: precision:  87.95%; recall:  87.42%; FB1:  87.68  1651
              PER: precision:  96.09%; recall:  97.28%; FB1:  96.68  1637
(90.0981423063442, 91.02337110481587, 90.55839351770301)
****** confidence prediction ******
processed 46435 tokens with 5648 phrases; found: 5704 phrases; correct: 5142.
accuracy:  92.25%; (non-O)
accuracy:  98.24%; precision:  90.15%; recall:  91.04%; FB1:  90.59
              LOC: precision:  91.54%; recall:  92.81%; FB1:  92.17  1691
             MISC: precision:  78.97%; recall:  80.77%; FB1:  79.86  718
              ORG: precision:  87.98%; recall:  87.66%; FB1:  87.82  1655
              PE

In [21]:
print("****** avg prediction ******")
avg_pred = get_mix_preds([elmo_model,xlnet_model],"avg")
print(evaluate(real,avg_pred))
print("****** confidence prediction ******")
confidence_pred = get_mix_preds([elmo_model,xlnet_model],"confidence")
print(evaluate(real,confidence_pred))

****** avg prediction ******
processed 46435 tokens with 5648 phrases; found: 5784 phrases; correct: 5136.
accuracy:  92.15%; (non-O)
accuracy:  97.97%; precision:  88.80%; recall:  90.93%; FB1:  89.85
              LOC: precision:  91.34%; recall:  92.27%; FB1:  91.80  1685
             MISC: precision:  75.36%; recall:  81.48%; FB1:  78.30  759
              ORG: precision:  86.56%; recall:  88.38%; FB1:  87.46  1696
              PER: precision:  94.71%; recall:  96.29%; FB1:  95.49  1644
(88.79668049792531, 90.93484419263456, 89.85304408677396)
****** confidence prediction ******
processed 46435 tokens with 5648 phrases; found: 5781 phrases; correct: 5135.
accuracy:  92.21%; (non-O)
accuracy:  97.99%; precision:  88.83%; recall:  90.92%; FB1:  89.86
              LOC: precision:  91.30%; recall:  92.51%; FB1:  91.90  1690
             MISC: precision:  75.23%; recall:  81.34%; FB1:  78.17  759
              ORG: precision:  86.68%; recall:  88.14%; FB1:  87.40  1689
              P

In [22]:
print("****** avg prediction ******")
avg_pred = get_mix_preds([bert_model,elmo_model,xlnet_model],"avg")
print(evaluate(real,avg_pred))
print("****** confidence prediction ******")
confidence_pred = get_mix_preds([bert_model,elmo_model,xlnet_model],"confidence")
print(evaluate(real,confidence_pred))

****** avg prediction ******
processed 46435 tokens with 5648 phrases; found: 5702 phrases; correct: 5175.
accuracy:  92.91%; (non-O)
accuracy:  98.28%; precision:  90.76%; recall:  91.63%; FB1:  91.19
              LOC: precision:  92.52%; recall:  92.69%; FB1:  92.60  1671
             MISC: precision:  79.59%; recall:  82.19%; FB1:  80.87  725
              ORG: precision:  88.31%; recall:  89.10%; FB1:  88.70  1676
              PER: precision:  96.44%; recall:  97.22%; FB1:  96.83  1630
(90.7576289021396, 91.62535410764873, 91.18942731277532)
****** confidence prediction ******
processed 46435 tokens with 5648 phrases; found: 5707 phrases; correct: 5178.
accuracy:  92.75%; (non-O)
accuracy:  98.34%; precision:  90.73%; recall:  91.68%; FB1:  91.20
              LOC: precision:  92.20%; recall:  92.87%; FB1:  92.53  1680
             MISC: precision:  79.36%; recall:  81.62%; FB1:  80.48  722
              ORG: precision:  88.42%; recall:  89.16%; FB1:  88.79  1675
              PE

In [23]:
print("****** avg prediction ******")
avg_pred = get_mix_preds([bert_model,elmo_model,xlnet_model,flair_model],"avg")
print(evaluate(real,avg_pred))
print("****** confidence prediction ******")
confidence_pred = get_mix_preds([bert_model,elmo_model,xlnet_model,flair_model],"confidence")
print(evaluate(real,confidence_pred))

****** avg prediction ******
processed 46435 tokens with 5648 phrases; found: 5704 phrases; correct: 5166.
accuracy:  92.78%; (non-O)
accuracy:  98.26%; precision:  90.57%; recall:  91.47%; FB1:  91.01
              LOC: precision:  92.24%; recall:  92.63%; FB1:  92.43  1675
             MISC: precision:  80.42%; recall:  81.91%; FB1:  81.16  715
              ORG: precision:  88.11%; recall:  88.80%; FB1:  88.46  1674
              PER: precision:  95.79%; recall:  97.16%; FB1:  96.47  1640
(90.56802244039271, 91.46600566572238, 91.01479915433404)
****** confidence prediction ******
processed 46435 tokens with 5648 phrases; found: 5701 phrases; correct: 5166.
accuracy:  92.55%; (non-O)
accuracy:  98.32%; precision:  90.62%; recall:  91.47%; FB1:  91.04
              LOC: precision:  91.87%; recall:  92.87%; FB1:  92.37  1686
             MISC: precision:  80.31%; recall:  81.34%; FB1:  80.82  711
              ORG: precision:  87.93%; recall:  88.56%; FB1:  88.24  1673
              P

In [42]:
mix_ebx_model = SequenceTagger.load('./log/mix_ebx_20200323214754/best-model.pt')
print("****** mix_ebx prediction ******")
mix_ebx_pred = model_prediction(mix_ebx_model)
print(evaluate(real,mix_ebx_pred))

2020-03-24 14:37:28,293 loading file ./log/mix_ebx_20200323214754/best-model.pt


FileNotFoundError: [Errno 2] No such file or directory: './log/mix_ebx_20200323214754/best-model.pt'

In [43]:
print("****** elmo(middle) prediction ******")
elmo_m_model = SequenceTagger.load('./log/elmo_20200323220158/best-model.pt')
elmo_m_pred = model_prediction(elmo_m_model)
print(evaluate(real,elmo_m_pred))

****** elmo(middle) prediction ******
2020-03-24 14:40:16,294 loading file ./log/elmo_20200323220158/best-model.pt
processed 46435 tokens with 5648 phrases; found: 5717 phrases; correct: 5160.
accuracy:  92.50%; (non-O)
accuracy:  98.04%; precision:  90.26%; recall:  91.36%; FB1:  90.81
              LOC: precision:  92.23%; recall:  91.85%; FB1:  92.04  1661
             MISC: precision:  79.01%; recall:  82.05%; FB1:  80.50  729
              ORG: precision:  86.98%; recall:  90.07%; FB1:  88.49  1720
              PER: precision:  96.83%; recall:  96.23%; FB1:  96.53  1607
(90.25712786426448, 91.35977337110481, 90.80510338759349)
