In [1]:
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K

from tensorflow.keras.preprocessing import sequence
import numpy as np
import json


import pandas as pd
import jieba
from joblib import dump, load
import warnings
warnings.filterwarnings("ignore")
import time

# 命名实体识别

In [11]:
class question_ner:
    """BiLSTM-CRF named-entity tagger for a single question.

    Loads the character/label lookup tables and a saved Keras model, then
    Viterbi-decodes per-character tags using the transition matrix taken from
    the model's ``crf`` layer.
    """

    def __init__(self, data_dir='data', model_path='model/output/bilstm_crf_ner'):
        """Load the lookup tables and the trained model.

        Args:
            data_dir: directory containing idx2vocab.json, vocab2idx.json,
                idx2label.json and label2idx.json.
            model_path: path handed to ``models.load_model``.
        """
        # Preprocessed character/label lookup tables.
        self.idx2vocab = self._load_json(data_dir, 'idx2vocab.json')
        self.vocab2idx = self._load_json(data_dir, 'vocab2idx.json')
        self.idx2label = self._load_json(data_dir, 'idx2label.json')
        self.label2idx = self._load_json(data_dir, 'label2idx.json')

        # Trained model; not compiled because this class only runs inference.
        self.model = models.load_model(model_path, compile=False)

        # Transition matrix of the CRF layer, used for Viterbi decoding.
        self.trans_params = self.model.get_layer('crf').get_weights()[0]
        # Sub-model exposing the emission logits (output of the 'dense' layer)
        # that feed the CRF layer.
        self.sub_model = models.Model(inputs=self.model.get_layer('input_ids').input,
                                      outputs=self.model.get_layer('dense').output)

    @staticmethod
    def _load_json(data_dir, filename):
        """Read one UTF-8 JSON file located in ``data_dir``."""
        with open(f'{data_dir}/{filename}', encoding="utf-8") as file_obj:
            return json.load(file_obj)

    def predict(self, inputs, input_lens):
        """Viterbi-decode tag ids for a batch of character-id rows.

        Args:
            inputs: integer array of character ids, one row per sequence.
            input_lens: sequence lengths passed to ``tfa.text.crf_decode``.

        Returns:
            The decoded tag-id tensor produced by ``crf_decode``.
        """
        # Emission scores from the BiLSTM part of the model.
        logits = self.sub_model.predict(inputs)
        # crf_decode: Viterbi decoding with the CRF transition matrix.
        pred_seq, _viterbi_score = tfa.text.crf_decode(logits, self.trans_params, input_lens)
        return pred_seq

    def pre_data(self, sentence, maxlen=100):
        """Tag every character of ``sentence``.

        Characters are mapped to ids (unknown characters use the 'UNK' id),
        truncated to ``maxlen`` and left-padded with id 0 up to ``maxlen``.

        Args:
            sentence: question text.
            maxlen: padded row length; 100 is the value the original code used
                (presumably the training sequence length -- confirm).

        Returns:
            A list with one tag string per kept character (at most ``maxlen``
            entries); an empty list for an empty sentence.
        """
        sent_chars = list(sentence)
        # Map characters to ids and keep at most maxlen of them.
        sent2id = [self.vocab2idx[ch] if ch in self.vocab2idx else self.vocab2idx['UNK']
                   for ch in sent_chars][:maxlen]
        n_kept = len(sent2id)
        if n_kept == 0:
            # Previously the [-0:] slice returned the tags of the whole padded
            # row for an empty sentence instead of an empty list.
            return []

        # Left-pad with 0 so the real characters occupy the last positions.
        padded = np.array([[0] * (maxlen - n_kept) + sent2id])
        # Decode over the full padded row, as the original code did.
        seq_lens = np.array([maxlen])

        pred_seq = self.predict(padded, seq_lens)
        tag_ids = pred_seq.numpy().reshape(-1)
        # Only the trailing n_kept positions correspond to real characters.
        return [self.idx2label[str(i)] for i in tag_ids[-n_kept:]]

# 贝叶斯分类

In [24]:
class question_classify():
    """Question-type classifier over TF-IDF features of jieba tokens.

    Presumably a Naive-Bayes model (the section is titled as Bayes
    classification and the default file is NBmodel.joblib) -- confirm.
    """

    # Predicted class id -> question-type code. Per the original training-file
    # list: L1 method, L2 symptoms, L3 name explanation (释名),
    # L4 properties (气味), L5 sub-section (子部), L6 section (部门).
    _LABEL_NAMES = {1: 'L1', 2: 'L2', 3: 'L3', 4: 'L4', 5: 'L5', 6: 'L6'}

    def __init__(self, model_path='model/NBmodel.joblib', tfidf_path='model/tf-idf.joblib'):
        """Load the classifier and its TF-IDF vectorizer.

        NOTE: joblib.load unpickles the file and can execute arbitrary code;
        only load model files from a trusted source.
        """
        self.model = load(model_path)
        self.tf = load(tfidf_path)

    def classify_question(self, test_word):
        """Classify a question (with entities already removed).

        Args:
            test_word: question text.

        Returns:
            One of 'L1'..'L6', or the int -1 when the predicted label is not
            one of the known class ids 1..6.
        """
        # jieba tokens joined by spaces, each followed by a space (keeps the
        # original trailing space).
        tokenized = ''.join(token + ' ' for token in jieba.cut(test_word))
        test_features = self.tf.transform([tokenized])
        label = self.model.predict(test_features)[0]
        return self._LABEL_NAMES.get(label, -1)

# 融合部分

In [25]:
def get_valid_nertag(input_data, result_tags):
    """Extract tagged entities and return the question text without them.

    Walks the per-character tag sequence: a tag containing 'B' opens a span
    (closing any span already open, which ends just before it), a tag
    containing 'E' closes the open span, and a tag containing 'S' marks a
    single-character entity. Tags containing 'che' are skipped entirely.

    Args:
        input_data: the question string that was tagged.
        result_tags: one tag string per character of ``input_data``.

    Returns:
        ``(remaining_text, entities)``. ``entities`` is a list of
        ``[entity_text, entity_type]`` pairs, where the type is the last three
        characters of the opening tag. ``remaining_text`` is ``input_data``
        with every entity character removed.
    """
    entity = []
    pos = []          # [start, end) spans, or [i] for single-character entities
    in_span = False
    label = ''
    start = 0
    for i, tag in enumerate(result_tags):
        if 'che' in tag:
            continue
        if 'B' in tag and not in_span:
            in_span = True
            label = tag[-3:]
            start = i
        elif 'B' in tag and in_span:
            # A new B closes the previous span just before this position.
            entity.append([input_data[start:i], label])
            pos.append([start, i])
            start = i
            label = tag[-3:]
        elif 'E' in tag and in_span:
            entity.append([input_data[start:i + 1], label])
            pos.append([start, i + 1])
            start = 0
            in_span = False
            label = ''
        elif 'S' in tag:
            # NOTE(review): single-character entities are always typed 'dru'
            # regardless of the tag suffix -- confirm this is intended.
            entity.append([input_data[i], 'dru'])
            pos.append([i])
    # NOTE(review): a span still open at the end (B without E) is dropped.

    # Remove entity characters by position (rather than via a '*' sentinel,
    # which also deleted any literal '*' present in the input).
    masked = set()
    for item in pos:
        if len(item) == 2:
            masked.update(range(item[0], item[1]))
        else:
            masked.add(item[0])
    ques = ''.join(ch for k, ch in enumerate(input_data) if k not in masked)
    return ques, entity

In [26]:
def question_init():
    """Construct the entity tagger and the question-type classifier.

    Returns:
        A ``(question_ner, question_classify)`` pair, built in that order.
    """
    return question_ner(), question_classify()

In [67]:
def question_(question, ner, classify):
    """Parse a question into its entities and question type.

    The question is reduced to allowed characters and tagged with ``ner``.
    The entities are removed, the remainder is classified with ``classify``,
    and the entity type is checked for compatibility with the question type.

    Args:
        question: raw question text.
        ner: object with ``pre_data(text)`` returning one tag per character
            (e.g. ``question_ner``).
        classify: object with ``classify_question(text)`` (e.g.
            ``question_classify``).

    Returns:
        ``{'实体': [entity texts], '实体类型': type, '问句类型': label}`` on
        success. Returns -1 in any of these cases:
        - no entity is found;
        - entities of different types are mixed;
        - nothing but the entities remains;
        - the entity type does not fit the question type.
    """
    question = question.strip()
    # Drop illegal characters: keep CJK ideographs plus a few digits and
    # punctuation marks. (The original loop stopped one character early and
    # silently dropped the last character of every question.)
    allowed = '：，,:0123456789.%'
    text = ''.join(ch for ch in question
                   if '\u4e00' <= ch <= '\u9fff' or ch in allowed)

    result = ner.pre_data(text)
    ques, entity = get_valid_nertag(text, result)

    # Exactly one entity type per question is supported.
    if not entity:
        return -1
    entity_types = {e[1] for e in entity}
    if len(entity_types) > 1:
        return -1
    entity_type = next(iter(entity_types))

    # Classify what remains after removing the entities, trimming at most
    # one leading and one trailing comma / full stop.
    ques = ques.strip()
    if ques.startswith(('，', '。')):
        ques = ques[1:]
    if ques.endswith(('，', '。')):
        ques = ques[:-1]
    if not ques:
        # Nothing left to classify (previously this could raise IndexError).
        return -1
    label = classify.classify_question(ques)

    # The question type must be compatible with the entity type.
    if entity_type in ('sym', 'dis'):
        if label != 'L1':
            return -1
    elif entity_type == 'dru':
        if label not in ('L2', 'L3', 'L4', 'L5'):
            return -1
    elif entity_type == 'dep':
        if label != 'L6':
            return -1

    return {
        '实体': [e[0] for e in entity],
        '实体类型': entity_type,
        '问句类型': label,
    }

In [51]:
if __name__ == "__main__":
    # Load both models once; later cells reuse `ner` and `classify`.
    ner, classify = question_init()
    # Quick end-to-end run on a sample question.
    dic = question_('紫檀对治什么病比较有用', ner, classify)

# 生成测试问题的答案

In [80]:
# Read the test questions: take the first tab-separated field of each line and
# split it on whitespace; every resulting token is one question.
# Uses `with` so the file handle is closed (the bare open() leaked it).
ques = []
with open('data/test_question.txt', encoding='UTF-8') as question_file:
    for line in question_file:
        first_field = line.rstrip().split('\t')[0]
        ques.extend(first_field.strip().split())

In [81]:
# Run the full pipeline on every test question (unparseable ones yield -1).
result = [question_(q, ner, classify) for q in ques]

In [84]:
# Save the answers; ensure_ascii=False keeps the Chinese text human-readable.
filename = 'data/test_result.json'
with open(filename, mode='w', encoding='utf-8') as result_file:
    json.dump(result, result_file, indent=4, ensure_ascii=False)