In [1]:
import re
import sys, json
import torch
import os
import numpy as np
import opennre
from opennre import encoder, model, framework
import argparse
import pandas as pd
import itertools

In [2]:
# 导入关系抽取模型

parser = argparse.ArgumentParser()
parser.add_argument('--mask_entity', action='store_true', help='Mask entity mentions')
args = parser.parse_known_args()[0]

root_path = '.'
sys.path.append(root_path)
if not os.path.exists('ckpt'):
    os.mkdir('ckpt')
ckpt = 'ckpt/people_chinese_bert_softmax.pth.tar'

rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation/people-relation_rel2id.json'), encoding='utf-8'))

sentence_encoder = opennre.encoder.BERTEncoder(
    max_length=80, 
    pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'),
    mask_entity=args.mask_entity
)

model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
model.load_state_dict(torch.load(ckpt)['state_dict'])

<All keys matched successfully>

In [4]:
with open("Harry_Potter.txt", "r", encoding="utf-8") as f:
    total_lines = [line.strip() for line in f.readlines()]

total_lines = [line for line in total_lines if line != '']

In [5]:
# 分句
cutLineFlag = ["？", "！", "。", "!"]
sentenceList = []
for words in total_lines:
    oneSentence = ""
    for word in words:
        if word not in cutLineFlag:
            oneSentence = oneSentence + word
        else:
            oneSentence = oneSentence + word
            if oneSentence.__len__() > 4:
                sentenceList.append(oneSentence.strip())
            oneSentence = ""

In [6]:
sentenceList[1997]

'“哦，顺便说一句，你鼻子上有块脏东西，你知道吗?”她出去时，罗恩瞪了她一眼。'

In [7]:
# 获取所有的实体

csv = pd.read_csv("data/harryid.csv", header=0)
origin_id = list(csv.iloc[:, 1])

total_id = []
for id in origin_id:
    total_id.append(id)
    for sub_id in id.split('.'):
        total_id.append(sub_id)

total_id = list(set(total_id))
total_id.remove('')
print("共有实体：", len(total_id))

共有实体： 1079


In [8]:
new_data = []
for sentence in sentenceList:
    id_loc = []
    id_list = []
    for id in total_id:
        if id in sentence:
            loc = [(item.start(), item.end()-1) for item in re.finditer(id, sentence)]
            id_list.append(id)
            id_loc.append(loc[0])
            # print(id_loc)
        
    if len(id_loc) >= 2:
        permute = list(itertools.combinations(range(len(id_list)), 2))
        # print(len(permute))
        for idx in permute:
            if id_list[idx[0]] not in id_list[idx[1]] and id_list[idx[1]] not in id_list[idx[0]]:
                new_data.append({'text':sentence, 'h': {'pos': id_loc[idx[0]]}, 't': {'pos': id_loc[idx[1]]}})

In [9]:
print("共构造数据集:", len(new_data))

共构造数据集: 35962


In [10]:
new_data[0]

{'text': '在他十一岁生日那天，一切都发生了变化，信使猫头鹰带来了一封神秘的信：邀请哈利去一个他——以及所有读到哈利故事的人——会觉得永远难忘的、不可思议的地方——霍格沃茨魔法学校。',
 'h': {'pos': (76, 79)},
 't': {'pos': (36, 37)}}

In [11]:
from tqdm import tqdm
relation_list = []
for data in tqdm(new_data):
    text = data['text']
    t_pos = data['t']['pos']
    h_pos = data['h']['pos']
    rela = model.infer(data)
    relation_list.append([text[t_pos[0]:t_pos[1]+1], text[h_pos[0]:h_pos[1]+1], rela])

100%|██████████| 35962/35962 [51:38<00:00, 11.61it/s] 


In [45]:
relation_df = pd.DataFrame(relation_list)
relation_df.to_csv("data/relation_raw.csv", header=False, index=False)

In [89]:
# 获取已有的关系
existed_relation = pd.read_csv("data/harryRel.csv", header=0)
sub_obj_list = list(zip(existed_relation.iloc[:, 1], existed_relation.iloc[:, 2]))

from collections import Counter
id2count = dict(Counter(list(existed_relation.iloc[:, 1]) + list(existed_relation.iloc[:, 2])))

In [90]:
obj_list = {} # [{(sub_obj, obj): {'rela':(rela1, rela2, ...), 'prob':(..), 'sub':(...), 'obj':(...)}}]

for item in relation_list:
    cand_sub = []
    cand_obj = []
    if item[2][0] != 'unknown' and item[2][1] > 0.95:
        if (item[0], item[1]) not in obj_list.keys():
            obj_list[(item[0], item[1])] = {'relation': [item[2][0]], 'prob': [item[2][1]]}
        else:
            if item[2][0] not in obj_list[(item[0], item[1])]['relation']: # 合并重复的实体对，以及其所有的候选关系
                obj_list[(item[0], item[1])]['relation'].append(item[2][0])
                obj_list[(item[0], item[1])]['prob'].append(item[2][1])
        
        for id in origin_id:
            if item[0] in id:
                cand_sub.append(id)
            if item[1] in id:
                cand_obj.append(id)
            obj_list[(item[0], item[1])]['sub'] = cand_sub # 所有的候选实体
            obj_list[(item[0], item[1])]['obj'] = cand_obj

In [91]:
new_relation_list = []
for key, value in obj_list.items():
    if len(value['sub']) > 1: # 如果候选实体中只有一个实体，则直接替代
        sub_count = [id2count[i] for i in value['sub']]
        subobj = value['sub'][np.argmax(sub_count)]
    else:
        subobj = value['sub'][0]
    
    if len(value['obj']) > 1: # 若候选实体中有多个实体，则选出现频率最高的那个替代
        obj_count = [id2count[i] for i in value['obj']]
        obj = value['obj'][np.argmax(obj_count)]
    else:
        obj = value['obj'][0]
    
    if (subobj, obj) not in sub_obj_list and subobj != obj:
        rela = value['relation'][np.argmax(value['prob'])]
        new_relation_list.append([subobj, obj, rela, value['sub'], value['obj']])

In [18]:
relation_df = pd.DataFrame(new_relation_list)
relation_df.to_csv("data/relation_raw.csv", header=False, index=False)

In [19]:
# 已经清洗过的数据
# 推理一些空白关系
clean_relation_list = pd.read_csv("data/relation_clean.csv", header=0)
sub_list = list(clean_relation_list.iloc[:, 0])
obj_list = list(clean_relation_list.iloc[:, 1])
rela_list = list(clean_relation_list.iloc[:, 2])
clean_relation_list = list(zip(clean_relation_list.iloc[:, 0], clean_relation_list.iloc[:, 1], clean_relation_list.iloc[:, 2]))

In [21]:
infer_relation = []
for i in range(len(rela_list)):
    if rela_list[i] == '师生':
        infer_relation.append([obj_list[i], sub_list[i], '学生'])

for rela in infer_relation:
    if (rela[0], rela[1]) not in list(zip(sub_list, obj_list)):
        clean_relation_list.append(rela)

In [None]:
relation_df = pd.DataFrame(clean_relation_list)
relation_df.to_csv("data/relation_infered.csv", header=False, index=False)