# 注意：
此文件已在colab上成功运行，经过测评效果效果排序从小到大依次为:bilstm,bilstm_crf,bert,bert_crf,具体效果参考代码执行结果

# 配置colab环境

In [1]:
#colab中运行jupyter文件的步骤：
# 1.挂载云盘
from google.colab import drive
drive.mount('/content/gdrive')

# 2.安装需要的软件
!pip3 install transformers
!pip3 install pytorch-crf

import os
def get_root_dir():
    if os.path.exists('/content/gdrive/MyDrive/第二次进行实体识别-面向课程_toColab/'):
        return '/content/gdrive/MyDrive/第二次进行实体识别-面向课程_toColab/' #在Colab里
    else:
        return './' #在本地

# 3.调用系统命令，切换到对应工程路径，相当于cd，但是直接!cd是不行的
print("path:",get_root_dir())
os.chdir(get_root_dir())

# 4.再次确认路径
!pwd
!ls

Mounted at /content/gdrive
Collecting transformers
  Downloading transformers-4.15.0-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 8.0 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 53.2 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 63.8 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 72.6 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[K     |████████████████████████████████| 61 kB 652 kB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingfac

# 开始正式工程

In [2]:
from collections import defaultdict
from operator import itemgetter
import numpy as np
import torch
import torch.nn.functional as F # pytorch 激活函数的类
import pickle as pk
import pandas as pd
from tqdm import tqdm
from model_zoo import *

# 此处是加载对应的模型和配置文件
def load_model(model_name):
    parameter = pk.load(open('parameter.pkl','rb'))
    #     parameter['device'] = torch.device('cpu')
    # 因为bert模型需要加载他对应的config文件，因此此处进行了一定的区分
    if 'bert' in model_name:
        if 'speed' in model_name:
            model = eval(model_name.split('-')[0]+"(config,parameter).to(parameter['device'])")
        else:
            model = eval(model_name+"(config,parameter).to(parameter['device'])")
    else:
        model = eval(model_name+"(parameter).to(parameter['device'])")
    model.load_state_dict(torch.load(model_name+'.h5'))
    model.eval() 
    return model,parameter

# 将数组转成pytorch支持的输入
def list2torch(ins):
    return torch.from_numpy(np.array(ins))

# 此处和之前的数据预处理方式一致，不过这边是考虑bert有自带的字典因此，进行了一定的区分
def batch_yield(parameter,shuffle = True,isTrain = True,isBert = False):
    data_set = parameter['data_set']['train'] if isTrain else parameter['data_set']['dev']
    Epoch = parameter['epoch'] if isTrain else 1
    parameter['batch_size'] = 10
    for epoch in range(Epoch):
        # 每轮对原始数据进行随机化
        if shuffle:
            random.shuffle(data_set)
        inputs,targets = [],[]
        max_len = 0
        for items in tqdm(data_set):
            if not isBert:
                input = itemgetter(*items[0])(parameter['word2ind'])
                input = input if type(input) == type(()) else (input,0)
            else:
                input = tokenizer.convert_tokens_to_ids(items[0])
            target = itemgetter(*items[1])(parameter['key2ind'])
            target = target if type(target) == type(()) else (target,0)
            if len(input) > max_len:
                max_len = len(input)
            inputs.append(list(input))
            targets.append(list(target))
            if len(inputs) >= parameter['batch_size']:
                inputs = [i+[0]*(max_len-len(i)) for i in inputs]
                targets = [i+[-1]*(max_len-len(i)) for i in targets]
                yield list2torch(inputs),list2torch(targets),None,False
                inputs,targets = [],[]
                max_len = 0
        inputs = [i+[0]*(max_len-len(i)) for i in inputs]
        targets = [i+[-1]*(max_len-len(i)) for i in targets]
        yield list2torch(inputs),list2torch(targets),epoch,False
        inputs,targets = [],[]
        max_len = 0
    yield None,None,None,True

# 这边是模型评估
def eval_model(model_name):
    # 加载相应训练好的模型
    model,parameter = load_model(model_name)
    # 准备好待统计的内容
    count_table = {}
    # 根据是否文件名中包含bert字样判断是否为bert模型，决定使用哪个数据迭代器
    if 'bert' not in model_name:
        test_yield = batch_yield(parameter,shuffle = False,isTrain = False)
    else:
        test_yield = batch_yield(parameter,shuffle = False,isTrain = False,isBert = True)
    while 1:
        # 数据迭代
        inputs,targets,_,keys = next(test_yield)
        if not keys:
            # 获取相应模型数据的预测结果
            pred = model(inputs.long().to(parameter['device']))
            # 因为crf模型和直接用softmax模型推理方面有一定的区别，因此根据crf模型或者softmax模型进行区分
            if 'crf' in model_name:
                # crf模型需要对内容进行解码，得到相应的结果
                predicted_index = np.array(model.crf.decode(pred))
                targets = targets.numpy()#.long().to(parameter['device'])
            else:
                # softmax模型直接使用softmax区最大值
                predicted_prob,predicted_index = torch.max(F.softmax(pred, 1), 1)
                predicted_index = predicted_index.reshape(inputs.shape)
                targets = targets.long().to(parameter['device'])
            # 此处注意，回忆一下精确度和召回率的定义；
            # 精确度是，大致可以描述为，判断正确的正例/预测中总共判断正例的数量
            # 召回率是，大致可以描述为，判断正确的正例/实际中总共正例的数量
            # 由此可以得到以下处理的方法：
            # 提前准备好tp，
            right = (targets == predicted_index)
            for i in range(1,parameter['output_size']):
                if i not in count_table:
                    count_table[i] = {
                    'pred':len(predicted_index[(predicted_index == i) & (targets != -1)]), # i标签下的，tp+fp，预测总正例
                    'real':len(targets[targets == i]),# i标签下的，tp+fn，实际总正例
                    'common':len(targets[right & (targets == i)])# i标签下的tp
                    }
                else:
                    count_table[i]['pred'] += len(predicted_index[predicted_index == i])
                    count_table[i]['real'] += len(targets[targets == i])
                    count_table[i]['common'] += len(targets[right & (targets == i)])
        else:
            break
    count_pandas = {}
    # 获取对应标签中文名，和相应统计值，从1开始，为了过滤标签O的统计
    name,count = list(parameter['key2ind'].keys())[1:],list(count_table.values())
    for ind,i in enumerate(name):
        # 'B-*','I-*','E-*','S-*'都可以用'-'分割，合并同样标签的内容
        i = i.split('-')[1]
        # 综合统计
        if i in count_pandas:
            count_pandas[i][0] += count[ind]['pred']
            count_pandas[i][1] += count[ind]['real']
            count_pandas[i][2] += count[ind]['common']
        else:
            count_pandas[i] = [0,0,0]
            count_pandas[i][0] = count[ind]['pred']
            count_pandas[i][1] = count[ind]['real']
            count_pandas[i][2] = count[ind]['common']
    # 计算总数
    count_pandas['all'] = [sum([count_pandas[i][0] for i in count_pandas]),
                      sum([count_pandas[i][1] for i in count_pandas]),
                      sum([count_pandas[i][2] for i in count_pandas])]
    name = count_pandas.keys()
    count_pandas = pd.DataFrame(count_pandas.values())
    count_pandas.columns = ['pred','real','common']
    # 基于tp、tp+fn、tp+fp计算相应的p、r以及计算f1；回忆一下f1计算公式：2pr/(p+r)，fn：(1+b^2)/(b^2)*(pr)/(p+r)，f1好处？
    count_pandas['p'] = count_pandas['common']/count_pandas['pred']
    count_pandas['r'] = count_pandas['common']/count_pandas['real']
    count_pandas['f1'] = 2*count_pandas['p']*count_pandas['r']/(count_pandas['p']+count_pandas['r'])
    count_pandas.index = list(name)
    return count_pandas

In [3]:
eval_model('bilstm')

100%|██████████| 1343/1343 [00:01<00:00, 866.57it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1522,1486,1136,0.746386,0.764468,0.755319
company,1669,1693,1274,0.763331,0.75251,0.757882
game,1727,1657,1418,0.821077,0.855763,0.838061
organization,1308,1454,959,0.73318,0.65956,0.694424
movie,972,1043,811,0.834362,0.777565,0.804963
address,1684,1702,1067,0.63361,0.62691,0.630242
position,1081,1201,823,0.761332,0.685262,0.721297
government,1407,1315,1071,0.761194,0.814449,0.786921
scene,793,931,499,0.629256,0.535983,0.578886
book,900,1031,713,0.792222,0.691562,0.738477


In [4]:
eval_model('bilstm_crf')

  score = torch.where(mask[i].unsqueeze(1), next_score, score)
100%|██████████| 1343/1343 [00:02<00:00, 634.75it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1513,1486,1187,0.784534,0.798789,0.791597
company,1624,1693,1271,0.782635,0.750738,0.766355
game,1669,1657,1407,0.84302,0.849125,0.846061
organization,1241,1454,950,0.765512,0.65337,0.705009
movie,959,1043,807,0.841502,0.77373,0.806194
address,1720,1702,1144,0.665116,0.67215,0.668615
position,1124,1201,858,0.763345,0.714405,0.738065
government,1372,1315,1075,0.783528,0.81749,0.800149
scene,718,931,503,0.700557,0.540279,0.610067
book,977,1031,741,0.758444,0.71872,0.738048


In [5]:
eval_model('bert')

100%|██████████| 1343/1343 [00:03<00:00, 375.15it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1544,1486,1337,0.865933,0.899731,0.882508
company,1841,1693,1452,0.788702,0.857649,0.821732
game,1787,1657,1491,0.834359,0.899819,0.865854
organization,1402,1454,1096,0.78174,0.753783,0.767507
movie,1045,1043,931,0.890909,0.892617,0.891762
address,1682,1702,1304,0.775268,0.766157,0.770686
position,1202,1201,982,0.816972,0.817652,0.817312
government,1482,1315,1208,0.815115,0.918631,0.863783
scene,1015,931,777,0.765517,0.834586,0.798561
book,966,1031,878,0.908903,0.8516,0.879319


In [6]:
eval_model('bert_crf')

100%|██████████| 1343/1343 [00:04<00:00, 327.61it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1591,1486,1358,0.853551,0.913863,0.882678
company,1772,1693,1442,0.81377,0.851742,0.832323
game,1787,1657,1503,0.841074,0.907061,0.872822
organization,1410,1454,1104,0.782979,0.759285,0.77095
movie,998,1043,907,0.908818,0.869607,0.88878
address,1783,1702,1358,0.761638,0.797885,0.77934
position,1221,1201,1005,0.823096,0.836803,0.829893
government,1524,1315,1213,0.795932,0.922433,0.854526
scene,846,931,706,0.834515,0.758324,0.794598
book,1008,1031,897,0.889881,0.870029,0.879843


In [None]:
def keyword_predict(input):
    input = list(input)
    input_id = tokenizer.convert_tokens_to_ids(input)
    predict = model.crf.decode(model(list2torch([input_id]).long().to(parameter['device'])))[0]
    predict = itemgetter(*predict)(parameter['ind2key'])
    print(predict)
    keys_list = []
    for ind,i in enumerate(predict):
        if i == 'O':
            continue
        if i[0] == 'S':
            if not(len(keys_list) == 0 or keys_list[-1][-1]):
                del keys_list[-1]
            keys_list.append([input[ind],[i],[ind],True])
            continue
        if i[0] == 'B':
            if not(len(keys_list) == 0 or keys_list[-1][-1]):
                del keys_list[-1]
            keys_list.append([input[ind],[i],[ind],False])
            continue
        if i[0] == 'I':
            if len(keys_list) > 0 and not keys_list[-1][-1] and \
            keys_list[-1][1][0].split('-')[1] == i.split('-')[1]:
                keys_list[-1][0] += input[ind]
                keys_list[-1][1] += [i]
                keys_list[-1][2] += [ind]
            else:
                if len(keys_list) > 0:
                    del keys_list[-1]
            continue
        if i[0] == 'E':
            if len(keys_list) > 0 and not keys_list[-1][-1] and \
            keys_list[-1][1][0].split('-')[1] == i.split('-')[1]:
                keys_list[-1][0] += input[ind]
                keys_list[-1][1] += [i]
                keys_list[-1][2] += [ind]
                keys_list[-1][3] = True
            else:
                if len(keys_list) > 0:
                    del keys_list[-1]
            continue
#     print(keys_list)
#     keys_list = [i[0] for i in keys_list]
    return keys_list

model,parameter = load_model('bert_crf')
tokenizer = tokenizer_class.from_pretrained("prev_trained_model")
model = model.to(parameter['device'])

In [None]:
tokenizer.convert_tokens_to_ids(['你','好'])

[872, 1962]

In [None]:
test_text = '浙商银行企业信贷部叶老桂博士则从另一个角度对五道门槛进行了解读。叶老桂认为，对目前国内商业银行而言'
keyword_predict(test_text)

('B-company', 'I-company', 'I-company', 'I-company', 'I-company', 'I-company', 'I-company', 'I-company', 'E-company', 'B-name', 'I-name', 'E-name', 'B-position', 'E-position', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-name', 'I-name', 'E-name', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O')


[['浙商银行企业信贷部',
  ['B-company',
   'I-company',
   'I-company',
   'I-company',
   'I-company',
   'I-company',
   'I-company',
   'I-company',
   'E-company'],
  [0, 1, 2, 3, 4, 5, 6, 7, 8],
  True],
 ['叶老桂', ['B-name', 'I-name', 'E-name'], [9, 10, 11], True],
 ['博士', ['B-position', 'E-position'], [12, 13], True],
 ['叶老桂', ['B-name', 'I-name', 'E-name'], [32, 33, 34], True]]