In [2]:
from typing import List
import random
import pickle
import json
import copy
import re
from entity_recognize import kd_ana

In [4]:
# Load the train / dev / test splits.
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only load trusted data.

# Entity-type keys of a train-format utterance dict, mapped to their Chinese labels.
feature2chinese = {"Disease": "疾病", "Symptom": "症状", "Attribute": "属性", "Test": "检查", "Medicine": "药物"}

split_paths = (
    "../data/train/train.pk",
    "../data/evalution/dev.pk",   # path spelling kept exactly as in the original
    "../data/test/test.pk",
)
loaded_splits = []
for split_path in split_paths:
    with open(split_path, "rb") as f:
        loaded_splits.append(pickle.load(f))
train_data, dev_data, test_data = loaded_splits

In [5]:
# Merge the dev and test splits into one collection (per the original note: mainly
# used for pre-training). The result is used below as extra generation training data.
dev_test_data = dev_data + test_data

# 文本生成任务训练数据生成

In [6]:
# Use the dev/test dialog histories as text-generation training data.
# Trailing patient utterances are dropped so that every kept history ends with a doctor
# utterance (the generation target). Dialogs whose history is, or becomes, empty are skipped.
# The original wrapped this in a bare `except:`, which also hid unrelated errors
# (and KeyboardInterrupt); only the empty-history case is expected, so it is checked explicitly.
dev_test_data_end_with_doctor = copy.deepcopy(dev_test_data)  # deep copy: the pops below must not modify dev_test_data
dev_test_data_for_generation = []
for dtd in dev_test_data_end_with_doctor:
    history = dtd['history']
    while history and history[-1][:2] == '患者':
        history.pop()
    if history:
        dev_test_data_for_generation.append(dtd)
del dev_test_data_end_with_doctor

In [7]:
# 用于将同一句话中的不同特征进行整合
def cat_feature_info(dt):
    together_list = []
    for feature in feature2chinese.keys():
        together_list.extend(dt[feature])
    return together_list

def dev_type2train_type(input_list):
    """Convert dev/test-style history lines into train-style utterance dicts.

    Each line's first two characters are the role tag ("患者" = patient, anything else is
    treated as the doctor). The character at index 2 is skipped (presumably a separator —
    not visible here) and the rest is the sentence. All entities found by
    ``kd_ana.convert_sen_to_entity_set`` are stored under "Attribute". The other
    entity-type lists are left empty.
    """
    converted = []
    for line in input_list:
        speaker_tag = line[:2]
        sentence = line[3:]
        converted.append({
            "Attribute": list(kd_ana.convert_sen_to_entity_set(sentence)),
            'Disease': [],
            'Medicine': [],
            'Sentence': sentence,
            'Symptom': [],
            'Test': [],
            'id': "Patient" if speaker_tag == "患者" else "Doctor",
        })
    return converted

In [8]:
# Convert the generation-ready dev/test dialogs into the train-set format so that they
# can be processed together with the training data.
dev_test_data_for_generation2train_type = [
    dev_type2train_type(dialog["history"]) for dialog in dev_test_data_for_generation
]

In [11]:
# Append the converted dev/test dialogs to the training set (in place) so that they are
# also used for the text-generation task. NOTE: re-running this cell appends them again.
train_data.extend(dev_test_data_for_generation2train_type)

In [12]:
# Truncate dialogs whose text length exceeds 1020 characters.
def cut_dialog_without_entity(line_info, max_lenght=1020):
    """Truncate a dialog sample in place so that its input fits ``max_lenght`` characters.

    The length of ``line_info["input"]`` (a list of turns, each a list of sentence strings)
    is the sum of ``len(sentence) + 1`` over all sentences; the ``+ 1`` is one separator
    per sentence.

    While the sample is too long:
      * with 10 or more turns, the oldest sentence is dropped. If the first turn has a
        single sentence, the whole turn is dropped and ``begin_role`` is flipped, since
        turns alternate between speakers. The matching entry of ``line_info["output"]``
        is removed as well.
      * otherwise, characters are removed from the *start* of the sentences of the
        earliest not-yet-trimmed turn, one turn per iteration.

    Args:
        line_info: dict with keys "begin_role" (0/1), "input" and "output". It is
            modified in place.
        max_lenght: character budget. The misspelled name is kept for backward
            compatibility.

    Returns:
        True if anything was removed, False if the sample already fitted. If every turn
        has been emptied and the sample is still too long (only the per-sentence
        separators remain), trimming stops and the over-long sample is kept. The
        previous version raised IndexError in that case.
    """
    inp_list = line_info["input"]
    cut_id = 0
    cut_flag = False
    while True:
        cnt = sum(len(sent) + 1 for turn in inp_list for sent in turn)
        if cnt <= max_lenght:
            break
        ned_cut = cnt - max_lenght
        if len(inp_list) >= 10:
            if len(inp_list[0]) == 1:
                # Drop the whole first turn; the next turn belongs to the other speaker.
                inp_list = inp_list[1:]
                line_info["begin_role"] = 1 if line_info["begin_role"] == 0 else 0
                line_info["output"].pop(0)
            else:
                # Drop only the oldest sentence of the first turn.
                inp_list[0] = inp_list[0][1:]
                line_info["output"][0].pop(0)
        else:
            if cut_id >= len(inp_list):
                # Every turn has already been emptied; nothing left to trim.
                break
            turn = inp_list[cut_id]
            for i, sent in enumerate(turn):
                # Remove characters from the start, keeping the most recent text.
                to_cut = min(ned_cut, len(sent))
                turn[i] = sent[to_cut:]
                ned_cut -= to_cut
                if ned_cut <= 0:
                    break
            cut_id += 1
        cut_flag = True
    line_info["input"] = inp_list
    return cut_flag

In [13]:
# Generate the text-generation training data.
# Strategy: every dialog contains several rounds. Two training strategies are possible:
#   1. teacher-force all doctor utterances of a dialog within a single sample;
#   2. split a dialog into several sub-samples, each teacher-forcing only the doctor
#      utterance of one round.
# Strategy 2 is used: for i = 1, 2, ... one sub-sample per dialog is written that ends at
# its i-th doctor segment (`cnt` counts doctor *segments*, not merged turns).
train_num = len(train_data)   
_train_data = []
with open("../data/generation_train_final.json", "w", encoding="utf8") as f:
    for t in range(10):  # write 10 epochs of training data, reshuffling the dialogs for each epoch
        random.shuffle(train_data)   # shuffle the training dialogs manually (NOTE(review): no random seed is set in this section, so the file differs between runs)
        for r,dt in enumerate(train_data):
            _train_data.append(dt)       # dialogs are processed in chunks of 1000: within a chunk all samples for i=1 are
                                        # written, then i=2, and so on. Chunking avoids training on all 1-round samples first,
                                        # then all 2-round samples, etc., which could trap the model in a local optimum.
            if len(_train_data) == 1000 or r == train_num - 1:  # chunk is full, or the last dialog of the epoch was added
                # NOTE(review): the original comment says "at most 30 rounds", but range(1, 30) yields i = 1..29,
                # so doctor segments after the 29th never become the end of a sample.
                for i in range(1, 30):
                    break_forword = False  # set when at least one dialog of the chunk produced a sample for this i
                    for _dt in _train_data:
                        line_info = {"begin_role": 0, "input": [], "output": []}
                        if _dt[0]["id"] == "Doctor":
                            line_info["begin_role"] = 1  # the doctor speaks first
                        patient_in_flag = False  # becomes True once a patient segment has been seen
                        last_id = ""
                        cnt = 0  # number of doctor segments seen so far in this dialog
                        for seg in _dt:
                            feature_info = cat_feature_info(seg)
                            dialo_sentence = seg['Sentence']
#                             dialo_sentence = whole_word_process(seg['Sentence'], feature_info)
                            if seg["id"] == "Doctor":
                                cnt += 1
                                # consecutive segments of the same speaker are merged into one turn
                                if last_id == "Doctor":
                                    line_info["input"][-1].append(dialo_sentence)
                                    line_info["output"][-1].append(feature_info)
                                else:
                                    line_info["input"].append([dialo_sentence])
                                    line_info["output"].append([feature_info])
                                if cnt == i and line_info["begin_role"] == 0:
                                    # Patient speaks first: the sample ends right after the i-th doctor segment.
                                    break_forword =True
                                    flag = cut_dialog_without_entity(line_info)  # truncates in place; the returned flag is unused
#                                     if flag:
                                    f.write(json.dumps(line_info, ensure_ascii=False) + '\n')
                                    break
                                elif cnt > i and line_info["begin_role"] == 1 and patient_in_flag:
                                    # Doctor speaks first: the opening doctor segments are context, and the sample ends
                                    # at the first doctor segment with cnt > i that follows a patient segment.
                                    # NOTE(review): if the dialog opens with k > 1 doctor segments, iterations i = 1..k
                                    # all stop at the same segment and write identical samples.
                                    break_forword = True
                                    flag = cut_dialog_without_entity(line_info)  # truncates in place; the returned flag is unused
#                                     if flag:
                                    f.write(json.dumps(line_info, ensure_ascii=False) + '\n')
                                    break
                                last_id = "Doctor"
                            elif seg["id"] == "Patient":
                                patient_in_flag = True
                                if last_id == "Patient":
                                    line_info["input"][-1].append(dialo_sentence)
                                    line_info["output"][-1].append(feature_info)
                                else:
                                    line_info["input"].append([dialo_sentence])
                                    line_info["output"].append([feature_info])
                                last_id = "Patient"

                    if not break_forword:
                        # no dialog in this chunk has an i-th target segment: move on to the next chunk
                        print("finished forword:{}".format(i))
                        break
                _train_data = []  # start collecting the next chunk

finished forword:27
finished forword:27
finished forword:26
finished forword:25
finished forword:23
finished forword:25
finished forword:25
finished forword:25
finished forword:26
finished forword:24
finished forword:28
finished forword:29
finished forword:25
finished forword:24
finished forword:25
finished forword:28
finished forword:28
finished forword:28
finished forword:26
finished forword:27
finished forword:21


# 用于entity train data生成

In [14]:
# Treat each entity as a single unit during training by wrapping it in markers.
def whole_word_process_without_entity(sentence: str, feature_list: List) -> str:
    """Wrap each entity occurrence in ``sentence`` with ``"##"`` before and ``"$"`` after.

    The markers let ``tokenizer.encode`` recognise the wrapped span and encode it as a
    single token (per the original design note).

    Args:
        sentence: utterance text.
        feature_list: entity strings to mark.
            * Entries are matched literally: regex metacharacters such as ``(`` or ``+``
              are escaped. Previously they were interpreted as regex syntax, which
              mis-matched or raised ``re.error``.
            * Empty strings are ignored, because an empty alternative matches at every
              position.

    Returns:
        ``sentence`` with every non-overlapping match wrapped in the markers. Matches are
        scanned left to right, and at a given position the alternatives are tried in
        ``feature_list`` order. ``sentence`` is returned unchanged when there is nothing
        to match.
    """
    entities = [entity for entity in feature_list if entity]
    if not entities:
        return sentence
    pattern = "|".join(re.escape(entity) for entity in entities)
    return re.sub(pattern, lambda match: "##" + match.group(0) + "$", sentence)

In [15]:
# Character-level entity training data: every dialog is written as one sample (no
# per-round splitting), with entities wrapped in "##"/"$" markers, for 15 reshuffled epochs.
# Dead state from the original (`patient_in_flag`, `flag`, `_train_data`) and the
# copy-pasted Doctor/Patient branches have been removed/merged.
train_num = len(train_data)
print(train_num)
with open("../data/train_total_dialog_for_entity_char.json", "w", encoding="utf8") as f:
    for _epoch in range(15):
        random.shuffle(train_data)   # reshuffle the dialogs for every epoch
        for dialog in train_data:
            sample = {"begin_role": 0, "input": [], "output": []}
            if dialog[0]["id"] == "Doctor":
                sample["begin_role"] = 1  # the doctor speaks first
            prev_speaker = ""
            for seg in dialog:
                feature_info = cat_feature_info(seg)
                marked_sentence = whole_word_process_without_entity(seg['Sentence'], feature_info)
                speaker = seg["id"]
                if speaker not in ("Doctor", "Patient"):
                    continue  # segments with any other id are ignored
                if speaker == prev_speaker:
                    # same speaker as the previous segment: extend the current turn
                    sample["input"][-1].append(marked_sentence)
                    sample["output"][-1].append(feature_info)
                else:
                    # speaker changed: open a new turn
                    sample["input"].append([marked_sentence])
                    sample["output"].append([feature_info])
                prev_speaker = speaker
            # Truncate in place to the length budget; whether anything was cut is not needed here.
            cut_dialog_without_entity(sample)
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')

22208


# 评测数据集生成

In [2]:
# Evaluation inputs: test dialog histories, their reference responses, and the entity
# vocabulary used to mark entities.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only load files
# from a trusted source.
# NOTE(review): absolute local path — adjust EVAL_DATA_DIR for your machine.
EVAL_DATA_DIR = r"E:\BaiduNetdiskDownload\ccks21_mdg_evaluation"

with open(EVAL_DATA_DIR + r"\test_sample.pk", "rb") as f:
    dt = pickle.load(f)

with open(EVAL_DATA_DIR + r"\test_sample_reference.pk", "rb") as f:
    dt_response = pickle.load(f)

# One entity per line. The file is now closed after reading; the original left the
# handle open.
with open("../data/entity_list.txt", "r", encoding="utf8") as f:
    feature_list = [line.strip() for line in f]

In [3]:
def whole_word_process_without_entity(sentence: str, feature_list: List) -> str:
    """Wrap each entity occurrence in ``sentence`` with ``"##"`` before and ``"$"`` after.

    This section redefines the helper used for the entity training data; keep both
    definitions in sync so that training and evaluation text are marked identically.

    Args:
        sentence: utterance text.
        feature_list: entity strings to mark.
            * Entries are matched literally: regex metacharacters such as ``(`` or ``+``
              are escaped. Previously they were interpreted as regex syntax, which
              mis-matched or raised ``re.error``.
            * Empty strings are ignored, because an empty alternative matches at every
              position (e.g. a blank line in entity_list.txt).

    Returns:
        ``sentence`` with every non-overlapping match wrapped in the markers. Matches are
        scanned left to right, and at a given position the alternatives are tried in
        ``feature_list`` order. ``sentence`` is returned unchanged when there is nothing
        to match.
    """
    entities = [entity for entity in feature_list if entity]
    if not entities:
        return sentence
    pattern = "|".join(re.escape(entity) for entity in entities)
    return re.sub(pattern, lambda match: "##" + match.group(0) + "$", sentence)

In [4]:
# Build the evaluation file: one JSON line per test dialog, in the same
# {"begin_role", "input", "output"} format as the training samples, with the reference
# response appended as the final utterance.
with open("../data/evalution_data.txt", "w", encoding="utf8") as f:
    for dialog, response in zip(dt, dt_response):
        normalization_data = {"begin_role": 0, "input":[], "output":[]}
        response = whole_word_process_without_entity(response, feature_list)
        # NOTE(review): entities are recognised on the marker-wrapped text ("##...$"), not on the raw text.
        response_entity = list(kd_ana.convert_sen_to_entity_set(response))
        last_role = ""
        history = dialog["history"]
        # begin_role = 1 when the doctor speaks first. With an empty history the reference
        # response is the first utterance. The original raised IndexError on history[0] here.
        if not history or history[0].startswith("医生"):
            normalization_data["begin_role"] = 1
        for utterance in history:
            # First two characters are the role tag ("患者"/"医生"). Index 2 is skipped
            # (presumably a separator — confirm against the data).
            role, text = utterance[:2], utterance[3:]
            text = whole_word_process_without_entity(text, feature_list)
            entity = list(kd_ana.convert_sen_to_entity_set(text))
            if role != last_role:
                # speaker changed: open a new turn
                normalization_data["input"].append([text])
                normalization_data["output"].append([entity])
            else:
                # same speaker as the previous line: extend the current turn
                normalization_data["input"][-1].append(text)
                normalization_data["output"][-1].append(entity)
            last_role = role
        # The reference response (presumably the doctor's reply) opens a new turn after a
        # patient turn, or when there is no history at all. Otherwise it extends the last turn.
        if last_role == "患者" or not normalization_data["input"]:
            normalization_data["input"].append([response])
            normalization_data["output"].append([response_entity])
        else:
            normalization_data["input"][-1].append(response)
            normalization_data["output"][-1].append(response_entity)
        f.write(json.dumps(normalization_data, ensure_ascii=False) + '\n')

# generation 阶段采用 word level 的 RoFormer 模型；训练集中部分新出现的词不在其 vocab 中，需要筛选出来并手动加入

In [5]:
import jieba
# Register every entry of entity_list.txt as a jieba user-dictionary word, so that
# jieba.cut below can segment entities as whole words.
jieba.load_userdict("../data/entity_list.txt")

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\LASTFI~1\AppData\Local\Temp\jieba.cache
Loading model cost 0.631 seconds.
Prefix dict has been built successfully.


In [7]:
# Entity vocabulary, one entity per line. It is used below to exclude entities from the
# extra-word list. The file is closed after reading; the original left the handle open.
with open("../data/entity_list.txt", "r", encoding="utf8") as entity_file:
    entity_list = [line.strip() for line in entity_file]

In [9]:
# Entity-type keys of a train-format utterance dict, mapped to their Chinese labels.
feature2chinese = {"Disease": "疾病", "Symptom": "症状", "Attribute": "属性", "Test": "检查", "Medicine": "药物"}

# NOTE(review): absolute local path. pickle.load executes code from the file, so only
# load trusted data.
train_pickle_path = r"E:\BaiduNetdiskDownload\ccks21_mdg_dataset\train.pk"
with open(train_pickle_path, "rb") as fh:
    train_data = pickle.load(fh)

In [10]:
# Count how often each multi-character jieba word occurs in the training sentences,
# then keep only the frequent words.
word_cnt_dict = {}
for dialog in train_data:
    for utterance in dialog:
        for word in jieba.cut(utterance["Sentence"]):
            if len(word) <= 1:
                continue  # single-character tokens are not counted
            word_cnt_dict.setdefault(word, 0)
            word_cnt_dict[word] += 1

MIN_WORD_COUNT = 15  # if pre-training data is added, this threshold can be lowered
word_cont_dict_filter = {}
for word, count in word_cnt_dict.items():
    if count >= MIN_WORD_COUNT:
        word_cont_dict_filter[word] = count

In [11]:
# RoFormer tokenizer vocabulary, one token per line. The file is closed after reading;
# the original left the handle open.
with open(r"E:\BaiduNetdiskDownload\chinese_roformer_L-12_H-768_A-12\vocab.txt", "r", encoding="utf8") as vocab_file:
    tokenier_already_haven_word = [line.strip() for line in vocab_file]

# Set views give O(1) membership tests. The vocab list has 50000 entries, so list
# lookups in the loop were needlessly slow.
entity_set = set(entity_list)
vocab_set = set(tokenier_already_haven_word)

# roformer_add_whole_word_except_entity.txt: frequent words of the current dataset,
# excluding words from entity_list.
# NOTE(review): the condition keeps words that ARE already in the tokenizer vocab, while
# the original note calls them "newly added" words — confirm which is intended.
with open("../data/roformer_add_whole_word_except_entity.txt", "w", encoding="utf8") as f:
    for key in word_cont_dict_filter.keys():
        if key not in entity_set and key in vocab_set:
            f.write(key + "\n")

In [15]:
# Words to add to jieba's dictionary: entities and frequent dataset words that are
# missing from the RoFormer vocab. They are written to ../data/jieba_add.txt, one per line.
jieba_add_word = []

with open(r"E:\BaiduNetdiskDownload\chinese_roformer_L-12_H-768_A-12\vocab.txt", "r", encoding="utf8") as vocab_file:
    tokenier_already_haven_word = [line.strip() for line in vocab_file]
print("wobert vocab number:", len(tokenier_already_haven_word))
vocab_set = set(tokenier_already_haven_word)  # O(1) membership tests

with open("../data/entity_list.txt", "r", encoding="utf8") as entity_file:
    for line in entity_file:
        word = line.strip()
        if word not in vocab_set:
            jieba_add_word.append(word)

for key in word_cont_dict_filter:
    if key not in vocab_set:  # per the original note, the tokenizer only added 20000 high-frequency words
        jieba_add_word.append(key)

# Deduplicate while keeping first-seen order. Iterating a set of str depends on hash
# randomization, so the original file order changed between runs.
unique_add_words = list(dict.fromkeys(jieba_add_word))
print("jieba add_word number:", len(unique_add_words))
with open("../data/jieba_add.txt", "w", encoding="utf8") as f:
    for word in unique_add_words:
        f.write(word + "\n")

wobert vocab number: 50000
jieba add_word number: 2362
