In [1]:
from collections import defaultdict
from operator import itemgetter
import numpy as np
import torch
import torch.nn.functional as F # pytorch 激活函数的类
import pickle as pk
import pandas as pd
from tqdm import tqdm
from model_zoo import *

def load_model(model_name):
    parameter = pk.load(open('parameter.pkl','rb'))
#     parameter['device'] = torch.device('cpu')
    if 'bert' in model_name:
        if 'speed' in model_name:
            model = eval(model_name.split('-')[0]+"(config,parameter).to(parameter['device'])")
        else:
            model = eval(model_name+"(config,parameter).to(parameter['device'])")
    else:
        model = eval(model_name+"(parameter).to(parameter['device'])")
    model.load_state_dict(torch.load(model_name+'.h5'))
    model.eval() 
    return model,parameter

def list2torch(ins):
    return torch.from_numpy(np.array(ins))

def batch_yield(parameter,shuffle = True,isTrain = True,isBert = False):
    data_set = parameter['data_set']['train'] if isTrain else parameter['data_set']['dev']
    Epoch = parameter['epoch'] if isTrain else 1
    for epoch in range(Epoch):
        # 每轮对原始数据进行随机化
        if shuffle:
            random.shuffle(data_set)
        inputs,targets = [],[]
        max_len = 0
        for items in tqdm(data_set):
            if not isBert:
                input = itemgetter(*items[0])(parameter['word2ind'])
                input = input if type(input) == type(()) else (input,0)
            else:
                input = tokenizer.convert_tokens_to_ids(items[0])
            target = itemgetter(*items[1])(parameter['key2ind'])
            target = target if type(target) == type(()) else (target,0)
            if len(input) > max_len:
                max_len = len(input)
            inputs.append(list(input))
            targets.append(list(target))
            if len(inputs) >= parameter['batch_size']:
                inputs = [i+[0]*(max_len-len(i)) for i in inputs]
                targets = [i+[-1]*(max_len-len(i)) for i in targets]
                yield list2torch(inputs),list2torch(targets),None,False
                inputs,targets = [],[]
                max_len = 0
        inputs = [i+[0]*(max_len-len(i)) for i in inputs]
        targets = [i+[-1]*(max_len-len(i)) for i in targets]
        yield list2torch(inputs),list2torch(targets),epoch,False
        inputs,targets = [],[]
        max_len = 0
    yield None,None,None,True


def eval_model(model_name):
    model,parameter = load_model(model_name)
    count_table = {}
    if 'bert' not in model_name:
        test_yield = batch_yield(parameter,shuffle = False,isTrain = False)
    else:
        test_yield = batch_yield(parameter,shuffle = False,isTrain = False,isBert = True)
    while 1:
        inputs,targets,_,keys = next(test_yield)
        if not keys:
            pred = model(inputs.long().to(parameter['device']))
            if 'crf' in model_name:
                pred = model(inputs.long().to(parameter['device']))
                predicted_index = np.array(model.crf.decode(pred))
                targets = targets.numpy()#.long().to(parameter['device'])
            else:
                predicted_prob,predicted_index = torch.max(F.softmax(pred, 1), 1)
                predicted_index = predicted_index.reshape(inputs.shape)
                targets = targets.long().to(parameter['device'])
            right = (targets == predicted_index)
            for i in range(1,parameter['output_size']):
                if i not in count_table:
                    count_table[i] = {
                    'pred':len(predicted_index[(predicted_index == i) & (targets != -1)]),
                    'real':len(targets[targets == i]),
                    'common':len(targets[right & (targets == i)])
                    }
                else:
                    count_table[i]['pred'] += len(predicted_index[predicted_index == i])
                    count_table[i]['real'] += len(targets[targets == i])
                    count_table[i]['common'] += len(targets[right & (targets == i)])
        else:
            break
    count_pandas = {}
    name,count = list(parameter['key2ind'].keys())[1:],list(count_table.values())
    for ind,i in enumerate(name):
        i = i.split('-')[1]
        if i in count_pandas:
            count_pandas[i][0] += count[ind]['pred']
            count_pandas[i][1] += count[ind]['real']
            count_pandas[i][2] += count[ind]['common']
        else:
            count_pandas[i] = [0,0,0]
            count_pandas[i][0] = count[ind]['pred']
            count_pandas[i][1] = count[ind]['real']
            count_pandas[i][2] = count[ind]['common']
    count_pandas['all'] = [sum([count_pandas[i][0] for i in count_pandas]),
                      sum([count_pandas[i][1] for i in count_pandas]),
                      sum([count_pandas[i][2] for i in count_pandas])]
    name = count_pandas.keys()
    count_pandas = pd.DataFrame(count_pandas.values())
    count_pandas.columns = ['pred','real','common']
    count_pandas['p'] = count_pandas['common']/count_pandas['pred']
    count_pandas['r'] = count_pandas['common']/count_pandas['real']
    count_pandas['f1'] = 2*count_pandas['p']*count_pandas['r']/(count_pandas['p']+count_pandas['r'])
    count_pandas.index = list(name)
    return count_pandas

In [2]:
eval_model('bilstm')

100%|████████████████████████████████████████████████████████████████████████████| 1343/1343 [00:00<00:00, 1596.53it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1468,1486,1064,0.724796,0.716016,0.720379
company,1675,1693,1267,0.756418,0.748376,0.752375
game,1621,1657,1368,0.843924,0.825588,0.834655
organization,1351,1454,979,0.724648,0.673315,0.698039
movie,1090,1043,835,0.766055,0.800575,0.782935
address,1665,1702,1040,0.624625,0.611046,0.617761
position,1144,1201,824,0.72028,0.686095,0.702772
government,1334,1315,1049,0.786357,0.797719,0.791997
scene,878,931,564,0.642369,0.6058,0.623549
book,926,1031,714,0.771058,0.692532,0.729688


In [3]:
eval_model('bilstm_crf')

100%|█████████████████████████████████████████████████████████████████████████████| 1343/1343 [00:04<00:00, 309.59it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1438,1486,1153,0.801808,0.775908,0.788646
company,1664,1693,1289,0.774639,0.76137,0.767948
game,1724,1657,1392,0.807425,0.840072,0.823425
organization,1248,1454,936,0.75,0.643741,0.69282
movie,947,1043,767,0.809926,0.735379,0.770854
address,1679,1702,1120,0.667064,0.658049,0.662526
position,1091,1201,861,0.789184,0.716903,0.751309
government,1418,1315,1062,0.748942,0.807605,0.777168
scene,750,931,537,0.716,0.576799,0.638905
book,954,1031,736,0.771488,0.71387,0.741562


In [4]:
eval_model('bert')

100%|█████████████████████████████████████████████████████████████████████████████| 1343/1343 [00:03<00:00, 388.70it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1515,1486,1316,0.868647,0.885599,0.877041
company,1795,1693,1412,0.78663,0.834022,0.809633
game,1818,1657,1531,0.842134,0.923959,0.881151
organization,1436,1454,1083,0.754178,0.744842,0.749481
movie,1000,1043,920,0.92,0.882071,0.900636
address,1740,1702,1269,0.72931,0.745593,0.737362
position,1236,1201,984,0.796117,0.819317,0.80755
government,1461,1315,1171,0.801506,0.890494,0.84366
scene,902,931,702,0.778271,0.754028,0.765957
book,988,1031,893,0.903846,0.866149,0.884596


In [5]:
eval_model('bert_crf')

100%|█████████████████████████████████████████████████████████████████████████████| 1343/1343 [00:09<00:00, 139.03it/s]


Unnamed: 0,pred,real,common,p,r,f1
name,1528,1486,1323,0.865838,0.89031,0.877903
company,1845,1693,1459,0.790786,0.861784,0.82476
game,1773,1657,1525,0.860124,0.920338,0.889213
organization,1414,1454,1106,0.782178,0.76066,0.771269
movie,1058,1043,955,0.902647,0.915628,0.909091
address,1742,1702,1304,0.748565,0.766157,0.757259
position,1243,1201,999,0.803701,0.831807,0.817512
government,1454,1315,1186,0.815681,0.901901,0.856627
scene,911,931,727,0.798024,0.780881,0.789359
book,979,1031,888,0.907048,0.8613,0.883582
