# 完成实体的关系分类模型
[文章地址](http://blackedu.vip/2021/04/关系抽取模型初步/)

关系分类模型采用 BERT，训练数据使用[百度 DuIE 2.0 公开比赛数据](https://aistudio.baidu.com/aistudio/competition/detail/46)。

In [1]:
import pandas as pd
import numpy as np

In [2]:
# All DuIE 2.0 files live in one directory; build each path from its file name.
data_dir = 'data/DuIE_2_0/'
train_json, dev_json, test_json, schema_json = (
    data_dir + fname for fname in ('train.json', 'dev.json', 'test.json', 'schema.json')
)

In [3]:
import json
from tqdm import tqdm

def read_json_file(json_name, encoding='utf-8'):
    """Read a JSON Lines file (one JSON object per line) into a list.

    Args:
        json_name: path of the file to read; every non-blank line must hold
            one JSON value.
        encoding: text encoding of the file. Defaults to UTF-8 instead of the
            platform locale encoding, because the DuIE files contain Chinese text.

    Returns:
        list of the decoded objects, in file order.
    """
    data = []
    with open(json_name, encoding=encoding) as reader:
        lines = reader.read().splitlines()
    # Skip blank lines rather than blindly dropping the last line: slicing off
    # the final element lost the last record whenever the file did not end
    # with a trailing newline.
    for line in tqdm(lines, desc='read ...'):
        if line.strip():
            data.append(json.loads(line))
    return data

In [4]:
# Load every relation schema from schema.json, one dict per line, for example:
#   {'object_type': {'@value': '学校'}, 'predicate': '毕业院校', 'subject_type': '人物'}
#   {'object_type': {'@value': '人物'}, 'predicate': '嘉宾', 'subject_type': '电视综艺'}
#   {'object_type': {'inWork': '影视作品', '@value': '人物'},
#    'predicate': '配音', 'subject_type': '娱乐人物'}
schema = read_json_file(json_name=schema_json)

read ...: 100%|██████████| 48/48 [00:00<00:00, 207126.12it/s]


In [5]:
# One predicate name per schema entry, in file order.
relations = list(map(lambda entry: entry['predicate'], schema))

In [6]:
relations[0:5]  # first five relation names

['毕业院校', '嘉宾', '配音', '主题曲', '代言人']

In [7]:
# Map every distinct predicate to a contiguous id starting at 1. Id 0 stays
# reserved as the "unknown relation" label (``rel2id.get(p, 0)`` in DataLoader).
# ``dict.fromkeys`` dedupes while keeping first-seen order. The ids are therefore
# always 1..len(rel2id), and ``len(rel2id) + 1`` output classes is enough even if
# the schema lists the same predicate more than once.
rel2id = {val: idx for idx, val in enumerate(dict.fromkeys(relations), 1)}
id2rel = {v: k for k, v in rel2id.items()}

In [8]:
# bert config
import os

model_path = 'bert_models/chinese_L-12_H-768_A-12/'
model_type = 'bert'

# Config, checkpoint and vocabulary files of the pretrained Chinese BERT.
bert_config, check_point, vocab = (
    os.path.join(model_path, fname)
    for fname in ('bert_config.json', 'bert_model.ckpt', 'vocab.txt')
)

# Destination of the fine-tuned weights with the lowest validation loss so far.
export_model_name = 'weights/rel-bert-base-best.weights'

In [9]:
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.models import build_transformer_model

# Build the BERT tokenizer from the checkpoint vocabulary, lower-casing its input.
token_dict = load_vocab(dict_path=vocab)
tokenizer = Tokenizer(token_dict, do_lower_case=True)

Using TensorFlow backend.


In [10]:
# Load the train and dev splits (JSON Lines); train is read first.
train, dev = (read_json_file(path) for path in (train_json, dev_json))

read ...: 100%|██████████| 171293/171293 [00:02<00:00, 84004.11it/s]
read ...: 100%|██████████| 20674/20674 [00:00<00:00, 142845.44it/s]


In [11]:
train[0:2]  # first two training records

[{'text': '《邪少兵王》是冰火未央写的网络小说连载于旗峰天下',
  'spo_list': [{'predicate': '作者',
    'object_type': {'@value': '人物'},
    'subject_type': '图书作品',
    'object': {'@value': '冰火未央'},
    'subject': '邪少兵王'}]},
 {'text': 'GV-971由中国海洋大学、中国科学院上海药物研究所（下称“上海药物所”）和上海绿谷制药有限公司（下称“绿谷制药”）联合研发，不同于传统靶向抗体药物，GV-971是从海藻中提取的海洋寡糖类分子',
  'spo_list': [{'predicate': '简称',
    'object_type': {'@value': 'Text'},
    'subject_type': '机构',
    'object': {'@value': '上海药物所'},
    'subject': '中国科学院上海药物研究所'}]}]

In [12]:
tokenizer.token_to_id(token='[SEP]')  # vocabulary id of the [SEP] separator

102

In [13]:
from keras.utils import Sequence, to_categorical
from bert4keras.snippets import sequence_padding, to_array, DataGenerator

MAXLEN = 200


class DataLoader(DataGenerator):
    """Turn DuIE records into ``([token_ids, segment_ids], labels)`` batches.

    Each SPO triple of each record becomes one sample. Its token ids are the
    output of ``tokenizer.encode(text, subject)``, followed by the object's
    character ids and a closing ``[SEP]``. The appended part gets segment id 1.
    The label is ``[rel2id[predicate]]``, or ``[0]`` when the predicate is not
    in the schema.
    """

    def __init__(self, data, batch_size=32):
        # DataGenerator.__init__ is intentionally not called: its ``steps``
        # would count records, not SPO samples. Callers pass ``steps_per_epoch``.
        self.data = data
        self.batch_size = batch_size

    def read_data(self, random=False):
        """Yield ``(text, predicate, subject, object)`` for every SPO triple.

        Args:
            random: if True, visit the records in a freshly shuffled order
                (the underlying data is not modified).
        """
        items = self.data
        if random:
            items = list(items)
            np.random.shuffle(items)
        for item in items:
            text = item['text']
            for spo in item['spo_list']:
                yield text, spo['predicate'], spo['subject'], spo['object']['@value']

    def __iter__(self, random=False):
        """Yield padded batches of at most ``batch_size`` samples.

        Args:
            random: shuffle the record order first. bert4keras' ``forfit()``
                passes True; plain iteration uses False.

        The last, possibly smaller, batch is yielded too instead of being
        dropped. Only text+subject are truncated to MAXLEN by ``encode``; the
        object tokens are appended afterwards, so sequences can be longer.
        """
        sep_id = tokenizer.token_to_id('[SEP]')
        batch_token, batch_segment, batch_label = [], [], []
        for text, p, s, o in self.read_data(random):
            token_id, segment_id = tokenizer.encode(first_text=text, second_text=s, maxlen=MAXLEN)
            # NOTE(review): ``tokens_to_ids`` on a str looks up each character
            # on its own, bypassing lower-casing and sub-word splitting. It is
            # kept as-is so training stays consistent with ``evalute_relation``.
            obj_token_id = tokenizer.tokens_to_ids(o)
            batch_token.append(token_id + obj_token_id + [sep_id])
            batch_segment.append(segment_id + [1] * (len(obj_token_id) + 1))
            batch_label.append([rel2id.get(p, 0)])

            if len(batch_label) == self.batch_size:
                yield self._pack(batch_token, batch_segment, batch_label)
                batch_token, batch_segment, batch_label = [], [], []

        # Flush the remainder so trailing samples are not silently skipped.
        if batch_label:
            yield self._pack(batch_token, batch_segment, batch_label)

    @staticmethod
    def _pack(batch_token, batch_segment, batch_label):
        """Pad token/segment id lists to a common length and convert to arrays."""
        return [sequence_padding(batch_token), sequence_padding(batch_segment)], to_array(batch_label)

In [14]:
# Smoke-test the loader: print token, segment and label shapes of the first six dev batches.
for i, batch in enumerate(DataLoader(dev)):
    (batch_token, batch_segment), batch_label = batch
    for arr in (batch_token, batch_segment, batch_label):
        print(arr.shape)
    if i >= 5:
        break

(32, 210)
(32, 210)
(32, 1)
(32, 191)
(32, 191)
(32, 1)
(32, 172)
(32, 172)
(32, 1)
(32, 189)
(32, 189)
(32, 1)
(32, 175)
(32, 175)
(32, 1)
(32, 206)
(32, 206)
(32, 1)


In [15]:
# Pretrained BERT encoder restored from the checkpoint.
bert = build_transformer_model(config_path=bert_config, checkpoint_path=check_point, model=model_type)

In [16]:
from keras.layers import Lambda, Dense
from keras.models import Model

# Classification head: take position 0 ([CLS]) of BERT's sequence output and
# predict one of len(rel2id) relation ids, plus class 0 for "unknown".
n_classes = len(rel2id) + 1
cls_vector = Lambda(lambda x: x[:, 0], name='CLS')(bert.output)
output = Dense(n_classes, activation='softmax')(cls_vector)
model = Model(inputs=bert.input, outputs=output)

In [17]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, None)         0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     (None, None, 768)    16226304    Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]              
____________________________________________________________________________________________

In [18]:
from keras.optimizers import Adam
from keras.losses import SparseCategoricalCrossentropy
from keras.metrics import accuracy, Recall

# Labels are integer ids of shape (batch, 1), hence the sparse cross-entropy.
model.compile(
    optimizer=Adam(1e-5),
    loss=SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [19]:
from keras.callbacks import Callback


def evalute_relation(text, subject, obj):
    """Predict the relation name linking ``subject`` and ``obj`` in ``text``.

    The sample is encoded exactly as ``DataLoader.__iter__`` builds training
    samples. That includes truncating text+subject to MAXLEN, which was
    previously missing here, so long inputs were encoded differently from
    training.

    Args:
        text: sentence containing both entities.
        subject: subject entity string.
        obj: object entity string.

    Returns:
        the predicted predicate name, or '未知' ("unknown") when the model
        picks class 0, which has no entry in ``id2rel``.
    """
    token_id, segment_id = tokenizer.encode(first_text=text, second_text=subject, maxlen=MAXLEN)
    # Characters of the object are looked up one by one, matching DataLoader.
    obj_id = tokenizer.tokens_to_ids(obj)
    sep_id = tokenizer.token_to_id('[SEP]')
    token_ids = to_array([token_id + obj_id + [sep_id]])
    segment_ids = to_array([segment_id + [1] * (len(obj_id) + 1)])
    y_pred = model.predict([token_ids, segment_ids])
    return id2rel.get(int(y_pred[0].argmax()), '未知')


class EvalCallback(Callback):
    """Keep the best weights by ``val_loss`` and print sample predictions.

    At the end of every epoch:
    - if the validation loss beats the best seen so far, the model weights are
      saved to ``export_model_name``, creating its directory if needed;
    - a few hand-written examples are run through ``evalute_relation`` and the
      predicted relations are printed.
    """

    def __init__(self):
        # Let keras.callbacks.Callback initialise its own attributes.
        super().__init__()
        self.loss_lower = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val_loss = logs.get('val_loss')
        if val_loss is None:
            # No validation data was given to fit(): nothing to compare against.
            print('EvalCallback: val_loss not available, weights not saved')
        elif val_loss < self.loss_lower:
            self.loss_lower = val_loss
            save_dir = os.path.dirname(export_model_name)
            if save_dir:
                # save_weights does not create missing parent directories.
                os.makedirs(save_dir, exist_ok=True)
            model.save_weights(export_model_name)

        self.just_show()

    @staticmethod
    def just_show():
        """Print the predicted relation for each built-in example."""
        data = [{'text': ('根据启信宝的数据显示，泡泡玛特主体公司北京泡泡玛特文化创意'
                          '有限公司，法定代表人为王宁，该公司曾于2017年2月登陆新三板，在2019年4月终止挂牌'),
                 'subject': '北京泡泡玛特文化创意',
                 'obj': '王宁'},
                {'text': "《明早起飞》是由明太鱼作词，满江作曲，戴娆演唱的一首歌曲",
                 'obj': '满江',
                 'subject': '明早起飞'},
                {'text': "《明早起飞》是由明太鱼作词，满江作曲，戴娆演唱的一首歌曲",
                 'subject': '明早起飞',
                 'obj': '明太鱼'},
                {'text': "【#真正男子汉#新兵欧豪报到】男子汉，我一直在渴望 我不怕什么挑战，就要证明给你们看 穿上这身军装，我就是军人",
                 'subject': '真正男子汉',
                 'obj': '欧豪'}
                ]
        for each in data:
            print(evalute_relation(**each))


In [None]:
# Fine-tune on the train split, validating on dev after each epoch. The
# ``forfit()`` generators presumably repeat forever, so the epoch length is set
# explicitly through the step counts.
BATCH_SIZE = 16
train_iter = DataLoader(train, batch_size=BATCH_SIZE)
dev_iter = DataLoader(dev, batch_size=BATCH_SIZE)

# Resume from the best checkpoint of an earlier run, when one exists.
if os.path.exists(export_model_name):
    model.load_weights(export_model_name)
    print('==== load model weight!!! =====')

model.fit(
    train_iter.forfit(),
    steps_per_epoch=1000,
    epochs=30,
    validation_data=dev_iter.forfit(),
    validation_steps=1000,
    callbacks=[EvalCallback()],
)



Epoch 1/30
董事长
作曲
作词
嘉宾
Epoch 2/30
董事长
作曲
作词
嘉宾
Epoch 3/30
董事长
作曲
作词
嘉宾
Epoch 4/30
董事长
作曲
作词
嘉宾
Epoch 5/30
董事长
作曲
作词
嘉宾
Epoch 6/30
董事长
作曲
作词
嘉宾
Epoch 7/30