In [11]:
from transformers import BertTokenizer
import torch
from torch import nn
from model import DRE_model, new_DREmodel
from torch.utils.data import Dataset
import json
import numpy as np

In [12]:
# Load the trained model checkpoint.
# Baseline checkpoint, kept for reference:
#model=torch.load('models/baseline/12.pth')
# NOTE(review): torch.load unpickles a full model object here; only load
# checkpoint files that come from a trusted source.
model=torch.load('models/method/21.pth').cuda()
# Put the model in inference mode so that any dropout / batch-norm layers
# behave deterministically while the dev and test sets are scored.
model.eval()
testpath="./data/test.json"
devpath="./data/dev.json"

In [13]:
def _build_pair_text(prefix, x, y):
    """Build the text fed to the tokenizer for one dialogue prefix and one (x, y) pair.

    The layout is kept exactly as the original evaluation code produced it,
    because the checkpoint is presumably trained on the same layout.
    NOTE(review): the literal "[CLS]"/"[SEP]" wrapper is only present when `x`
    is not a speaker, and a speaker-valued `y` is appended as "X"/"Y" without
    being replaced inside the dialogue text. Confirm against the training
    preprocessing before changing either behaviour.
    """
    if x == "Speaker 1":                       # x is the subject, y is the object
        newtext = prefix.replace("Speaker 1", "X") + " X"    # [S1]:X    [S2]:Y
    elif x == "Speaker 2":
        newtext = prefix.replace("Speaker 2", "Y") + " Y"
    else:
        newtext = "[CLS] " + prefix + " [SEP] " + " " + x
    newtext += " [SEP]"
    if y == "Speaker 1":
        newtext += " X"
    elif y == "Speaker 2":
        newtext += " Y"
    else:
        newtext = newtext + " " + y
    return newtext


def _encode_fixed_length(tokenizer, text, max_seq_len):
    """Tokenize `text` and return (input_ids, attention_mask) lists of exactly `max_seq_len` items."""
    encoded = tokenizer.encode_plus(text, return_attention_mask=True)
    ids = list(encoded['input_ids'])
    mask = list(encoded['attention_mask'])
    length = len(ids)
    if length > max_seq_len:
        # Keep the last `max_seq_len` tokens that precede the final token
        # (the final token itself is dropped, as in the original code).
        ids = ids[-max_seq_len-1:-1]
        mask = mask[-max_seq_len-1:-1]
    else:
        # Right-pad with zeros up to max_seq_len. The original padded by
        # `max_seq_len-1` (digit one instead of the length `l`), which produced
        # sequences of length l+255 instead of max_seq_len.
        padding = [0] * (max_seq_len - length)
        ids = ids + padding
        mask = mask + padding
    return ids, mask


def _model_device(model):
    """Return the device of the model's first parameter, or a default device if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def GetOutput(path,model):
    """Run `model` on every (dialogue prefix, entity pair) example of a JSON data file.

    The file is parsed as a JSON list whose items are read as ``item[0]``
    (list of utterance strings) and ``item[1]`` (list of dicts with ``'x'`` and
    ``'y'`` entries). For every pair one input is built per cumulative
    dialogue prefix (first 1, 2, ... utterances), so each item contributes
    ``len(item[1]) * len(item[0])`` examples, ordered by pair and then by prefix.

    Args:
        path: Path of the UTF-8 encoded JSON file.
        model: Callable invoked as ``model(input_ids, attention_mask)`` with two
            long tensors of shape ``(1, 256)`` placed on the model's device.
            If it exposes ``modules()``, it is switched to eval mode for the
            duration of the call and every module's previous mode is restored
            afterwards.

    Returns:
        A list with one numpy array per example: the model output moved to the
        CPU with its leading batch dimension squeezed out.
    """
    tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
    max_seq_len=256
    InputIds=[]
    AttentionMasks=[]
    with open(path, encoding='utf8') as fp:
        data = json.load(fp)
    for item in data:
        # Cumulative prefixes of the dialogue, with double quotes normalised.
        texts=[]
        text = ""
        for sentence in item[0]:
            text += sentence
            text += " "
            text = text.replace('"', "'")
            texts.append(text)
        for d in item[1]:
            for prefix in texts:
                newtext = _build_pair_text(prefix, d['x'], d['y'])
                ids, mask = _encode_fixed_length(tokenizer, newtext, max_seq_len)  # tokenize
                InputIds.append(ids)
                AttentionMasks.append(mask)

    device = _model_device(model)
    # Remember each module's mode so inference does not permanently alter it.
    saved_modes = [(m, m.training) for m in model.modules()] if hasattr(model, "modules") else []
    if saved_modes:
        model.eval()
    Outputs=[]
    try:
        with torch.no_grad():
            for i, (ids, mask) in enumerate(zip(InputIds, AttentionMasks)):
                if(i%1000==0):
                    print(i)
                input_ids = torch.tensor([ids], dtype=torch.long, device=device)
                attention_masks = torch.tensor([mask], dtype=torch.long, device=device)
                output = model(input_ids, attention_masks)
                Outputs.append(output.cpu().squeeze(0).detach().numpy())
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return Outputs
   
    

In [14]:
def ReadData(path):
    """Load the annotated dialogue file and shift every relation id down by one.

    The file is parsed as JSON; for each item, every dict in ``item[1]`` has
    each entry of its ``"rid"`` list decremented in place.

    Args:
        path: Path of the UTF-8 encoded JSON file.

    Returns:
        The parsed JSON structure with the adjusted ``"rid"`` lists.
    """
    with open(path, "r", encoding='utf8') as handle:
        dialogues = json.load(handle)

    for dialogue in dialogues:
        for annotation in dialogue[1]:
            relation_ids = annotation["rid"]
            for position, relation_id in enumerate(relation_ids):
                # Per the original author's note, the 37th relation is "unanswerable".
                relation_ids[position] = relation_id - 1

    return dialogues

def Vec2Label(PredVec, Threshold=0.5, NoRelation=36):
    """Convert per-example score vectors into lists of predicted label ids.

    A label index is selected when its score is strictly greater than
    `Threshold`; an example with no selected label gets ``[NoRelation]``.

    The input is left untouched and a new list is returned. The previous
    implementation overwrote `PredVec` in place, so scoring the same
    predictions twice silently produced wrong labels.

    An alternative two-threshold scheme was sketched here earlier: select the
    labels above T1 and, when none is selected, fall back to the arg-max label
    if its score exceeds T2, otherwise to the no-relation label. It is not
    implemented.

    Args:
        PredVec: Sequence of score vectors (one per example).
        Threshold: Scores strictly above this value count as predicted.
        NoRelation: Label id used when no score passes the threshold.

    Returns:
        A new list with one list of label ids per example.
    """
    Labels = []
    for scores in PredVec:
        selected = [j for j, score in enumerate(scores) if score > Threshold]
        Labels.append(selected if selected else [NoRelation])
    return Labels

def F1_c(PredVec,Data,Threshold=0.5):
    """Compute the conversational precision, recall and F1 (P_c, R_c, F1_c).

    `PredVec` holds the predicted score vectors that correspond to `Data`
    (37-dimensional per the original author's note): one vector for every
    dialogue turn of every annotated entity pair of every dialogue, ordered by
    dialogue, then pair, then turn.

    For each entity pair the turns are walked in order. At turn k a relation r
    is "evaluable" once both entity strings have appeared (case-insensitive
    substring match) in the turns seen so far and either r is not a gold
    relation of the pair, or the gold trigger text of r has appeared, or k is
    the last turn (where everything becomes evaluable). Predictions and gold
    labels are only counted over evaluable relations. Per-pair precision
    defaults to 1 when nothing was predicted and per-pair recall to 0 when no
    gold relation was evaluable; the per-pair values are then averaged.

    Args:
        PredVec: List of score vectors, see above. It is not modified.
        Data: Parsed dataset; ``Data[i][0]`` is the list of turns and
            ``Data[i][1]`` the list of annotations with ``"x"``, ``"y"``,
            ``"rid"`` and ``"t"`` entries. Relation id 36 is treated as
            "no relation".
        Threshold: Score threshold forwarded to `Vec2Label`.

    Returns:
        Tuple ``(precision, recall, f_1)``; all zeros when `Data` contains no
        entity pair.

    Raises:
        ValueError: If the number of prediction vectors does not match the
            number of (pair, turn) combinations in `Data`.
    """
    NO_RELATION = 36
    NUM_RELATIONS = 36

    expected = sum(len(dialogue[0]) * len(dialogue[1]) for dialogue in Data)
    if len(PredVec) != expected:
        raise ValueError(
            "PredVec has %d vectors but Data requires %d (turns x entity pairs)"
            % (len(PredVec), expected)
        )

    index=0   # position in PredLabel
    # Pass a shallow copy so that label conversion can never overwrite the
    # caller's score vectors; this keeps repeated calls on the same
    # predictions consistent.
    PredLabel=Vec2Label(list(PredVec), Threshold=Threshold)
    precisions = []  # P_c of every entity pair
    recalls = []     # R_c of every entity pair
    for dialogue in Data:   # iterate over the dataset
        turns = [turn.lower() for turn in dialogue[0]]
        last_turn = len(turns) - 1
        for pair in dialogue[1]:   # iterate over the annotations of one dialogue
            correct_sys, all_sys = 0, 0
            correct_gt = 0
            x = pair["x"].lower().strip()   # head entity
            y = pair["y"].lower().strip()   # tail entity
            t = {}                          # relation id -> trigger text
            for k, rid in enumerate(pair["rid"]):
                if rid != NO_RELATION:
                    t[rid] = pair["t"][k].lower().strip()

            l = set(pair["rid"]) - {NO_RELATION}   # gold relations of this pair

            ex, ey = False, False
            # Non-gold relations are evaluable from the start; gold relations
            # become evaluable once their trigger has been seen.
            et = {r: r not in l for r in range(NUM_RELATIONS)}

            for k, turn in enumerate(turns):
                o = set(PredLabel[index]) - {NO_RELATION}    # O: predicted relations
                e = set()                                    # E: evaluable relations
                if x in turn:
                    ex = True
                if y in turn:
                    ey = True
                if k == last_turn:
                    # Everything is evaluable on the final turn.
                    ex = ey = True
                    for r in range(NUM_RELATIONS):
                        et[r] = True
                for r in range(NUM_RELATIONS):
                    trigger = t.get(r, "")
                    if trigger != "" and trigger in turn:   # trigger of relation r was uttered
                        et[r] = True
                    if ex and ey and et[r]:
                        e.add(r)
                correct_sys += len(o & l & e)
                all_sys += len(o & e)
                correct_gt += len(l & e)
                index += 1

            precisions += [correct_sys/all_sys if all_sys != 0 else 1]
            recalls += [correct_sys/correct_gt if correct_gt != 0 else 0]
    print(index)
    if not precisions:
        # No entity pair at all: avoid dividing by zero.
        return 0.0, 0.0, 0.0
    precision = sum(precisions) / len(precisions)
    recall = sum(recalls) / len(recalls)
    f_1 = 2*precision*recall/(precision+recall) if precision+recall != 0 else 0

    return precision, recall, f_1

In [None]:
# Score the development split: run the model on it, load the gold
# annotations, then report the conversational metrics as percentages.
DevPred = GetOutput(devpath, model)
DevData = ReadData(devpath)
P, R, F1 = F1_c(DevPred, DevData)
for _metric_name, _metric_value in (("Precision_c = ", P), ("Recall_c = ", R), ("F1_c = ", F1)):
    print(_metric_name, _metric_value * 100, "%")

In [None]:
# Score the test split: run the model on it, load the gold annotations,
# then report the conversational metrics as percentages.
TestPred = GetOutput(testpath, model)
TestData = ReadData(testpath)
P, R, F1 = F1_c(TestPred, TestData)
for _metric_name, _metric_value in (("Precision_c = ", P), ("Recall_c = ", R), ("F1_c = ", F1)):
    print(_metric_name, _metric_value * 100, "%")