## 读取训练语料和NER结果

In [3]:
import re
import pandas as pd
from tqdm.auto import tqdm
import json


data_path = '/home/cs/yangyuchen/guoyiqiu/kg_llm/data/medmcqa_dev.json'
ner_path = '/home/cs/yangyuchen/guoyiqiu/kg_llm/data/medmcqa_dev_bios_v2.2_release_eng.txt'

data_dir = "/".join(data_path.split('/')[:-1])
print('data_dir: ', data_dir)
data_name = '.'.join(data_path.split('/')[-1].split('.')[:-1]) 
print('data_name: ', data_name)

print(f"读取数据集和NER结果文件...")
with open(ner_path, encoding='utf-8') as f:
    ner_str = f.read()
    f.close()

ner_str = ner_str.replace('｜','|')
with open(ner_path,'w') as f:
    f.write(ner_str)
    f.close()

ner_df = pd.read_csv(ner_path, sep='|')
ner_df['LINE, BEGIN, END'] = ner_df['LINE, BEGIN, END'].apply(lambda x: eval(str(x)))
ner_df['LINE'] = ner_df['LINE, BEGIN, END'].apply(lambda x: x[0])
ner_df['BEGIN, END'] = ner_df['LINE, BEGIN, END'].apply(lambda x: (x[1],x[2]))
ner_df.drop('LINE, BEGIN, END', axis=1, inplace=True)

data_df = pd.read_json(data_path)
data_strs = open(data_path, encoding='utf-8').readlines()
data_df['input_entities'] = data_df.apply(lambda x: [], axis=1)
data_df['input_entities_cid'] = data_df.apply(lambda x: [], axis=1)
data_df['output_entities'] = data_df.apply(lambda x: [], axis=1)
data_df['output_entities_cid'] = data_df.apply(lambda x: [], axis=1)

num_999 = 0
num_not_word = 0
print(f"初步清洗NER结果...")
for row in tqdm(ner_df.itertuples(), total=ner_df.shape[0]):
    j = row.LINE//4
    real_str = data_strs[row.LINE][row[-1][0]:row[-1][1]]
    neighbor_str = data_strs[row.LINE]
    
    if row.STY.startswith('999'):
        num_999+=1
        continue
    if len(re.findall(f"[^\w]{re.escape(real_str)}[^\w]", neighbor_str))==0:
        print('neighbor_str: ', neighbor_str)
        num_not_word+=1
        continue
    if row.LINE%4 == 2:
        data_df.loc[j]['input_entities'].append(real_str)
        data_df.loc[j]['input_entities_cid'].append(row.CID)
    else:
        data_df.loc[j]['output_entities'].append(real_str)
        data_df.loc[j]['output_entities_cid'].append(row.CID)

json.dump(data_df.to_dict(orient='records'), open(f'{data_dir}/ner_results_{data_name}.json','w'),indent=4)
print(f"被清洗的未知类别(999)实体数量: {num_999}，占比: {num_999/ner_df.shape[0]}")
print(f"被清洗的非单词实体数量: {num_not_word}，占比: {num_not_word/ner_df.shape[0]}")
print('初步清洗后的数据集: ')
data_df

data_dir:  /home/cs/yangyuchen/guoyiqiu/kg_llm/data
data_name:  medmcqa_dev
读取数据集和NER结果文件...
初步清洗NER结果...


  0%|          | 0/26826 [00:00<?, ?it/s]

被清洗的未知类别(999)实体数量: 1544，占比: 0.057556102288824275
被清洗的非单词实体数量: 0，占比: 0.0
初步清洗后的数据集: 


Unnamed: 0,input,output,input_entities,input_entities_cid,output_entities,output_entities_cid
0,Which of the following is not true for myelina...,A,"[myelinated nerve fibers, Impulse, myelinated ...","[CN33499808, CN00044749, CN02095548, CN0736878...",[],[]
1,Which of the following is not true about glome...,A,"[glomerular capillaries, oncotic pressure, cap...","[CN00453049, CN00007663, CN00015232, CN0045865...",[],[]
2,A 29 yrs old woman with a pregnancy of 17 week...,C,"[pregnancy, down syndrome, down syndrome, advi...","[CN00454003, CN00016551, CN00016551, CN0001799...",[],[]
3,"Axonal transport is: Options:A: Antegrade, B: ...",C,[Axonal transport],[CN00473776],[],[]
4,Low insulin to glucagon ratio is seen in all o...,A,"[insulin, glucagon, Glycogen synthesis, Glycog...","[CN32853719, CN00117640, CN00056666, CN0000965...",[],[]
...,...,...,...,...,...,...
4178,A study is to be conducted with regards to the...,A,"[fat, expressed breast milk, coho]","[CN00506745, CN00150819, CN28173440]",[],[]
4179,"APGAR acronym stands for? Options:A: Activity,...",D,"[APGAR, pulse pressure, grimace, respiration, ...","[CN00033786, CN00029530, CN00086647, CN3292265...",[],[]
4180,Most commonly implicated drug for acute liver ...,A,"[acute liver failure, Paracetamol, Valproate, ...","[CN00452947, CN32931003, CN00236867, CN0005756...",[],[]
4181,A 9 year old boy has steroid dependent nephrot...,B,"[steroid dependent nephrotic syndrome, cushing...","[CN00461839, CN13977028, CN33056928, CN0019686...",[],[]


## 执行TF-IDF清洗

In [2]:
import re
from multiprocessing import Pool, cpu_count
import math
import pickle
from tqdm.auto import tqdm
from collections import defaultdict
import os
import json

os.environ['TOKENIZERS_PARALLELISM'] = 'true'
tqdm.pandas()

corpus = [i.input + " " + i.output for i in data_df.itertuples()]
all_e_list = [i.input_entities + i.output_entities for i in data_df.itertuples()]
all_e =  list(set([e for es in all_e_list for e in es]))
all_text = ' '.join(corpus)


def wc(e):
    return {e:len(re.findall(f"[^\w]{re.escape(e)}[^\w]", all_text))}

if os.path.exists(f'{data_dir}/wc_{data_name}.pkl'):
    print(f"读取缓存的单词计数{data_dir}/wc_{data_name}.pkl...")
    count = pickle.load(open(f'{data_dir}/wc_{data_name}.pkl', 'rb'))
else:
    print("正在进行单词计数...")
    pool = Pool(cpu_count())
    count = {}
    for o in tqdm(pool.imap_unordered(wc, all_e), total=len(all_e)):
        count.update(o)
    pool.close()
    pickle.dump(count, open(f'{data_dir}/wc_{data_name}.pkl', 'wb'))

print("正在计算TF_IDF...")
idf = {e: math.log(len(corpus)/c) for e,c in count.items()}
all_tf = [{e: len(re.findall(f"{re.escape(e)}", corpus[i])) for e in e_list} for i,e_list in enumerate(all_e_list)]
all_tf_idf = [{e: tf[e] * idf[e] for e in tf} for tf in all_tf]

def mean_e_tf_idf(all_tf_idf):
    mean_e_tf_idf = {}
    for tf_idf in all_tf_idf:
        for e in tf_idf:
            if e not in mean_e_tf_idf:
                mean_e_tf_idf[e] = [tf_idf[e]]
            else:
                mean_e_tf_idf[e].append(tf_idf[e])
    for e in mean_e_tf_idf:
        mean_e_tf_idf[e] = sum(mean_e_tf_idf[e]) / len(mean_e_tf_idf[e])

    df = pd.DataFrame(list(mean_e_tf_idf.keys()), columns=['e'])
    df['tf_idf'] = df['e'].apply(lambda x: mean_e_tf_idf[x])
    df.sort_values('tf_idf', ascending=False, inplace=True)
    return df


def wash(row):
    new_input_entities = []
    new_input_entities_cid = []
    new_output_entities = []
    new_output_entities_cid = []
    new_tf_idf = {}
    for e,cid in zip(row['input_entities'],row['input_entities_cid']):
        if row['tf_idf'][e] > TF_IDF_THRESHOLD:
            new_input_entities.append(e)
            new_input_entities_cid.append(cid)
            new_tf_idf[e] = row['tf_idf'][e]
    for e,cid in zip(row['output_entities'],row['output_entities_cid']):
        if row['tf_idf'][e] > TF_IDF_THRESHOLD:
            new_output_entities.append(e)
            new_output_entities_cid.append(cid)
            new_tf_idf[e] = row['tf_idf'][e]
    row['input_entities'] = new_input_entities
    row['input_entities_cid'] = new_input_entities_cid
    row['output_entities'] = new_output_entities
    row['output_entities_cid'] = new_output_entities_cid
    row['tf_idf'] = new_tf_idf
    return row
    
data_df['tf_idf'] = all_tf_idf

TF_IDF_THRESHOLD = 7
print(f"TF_IDF阈值: {TF_IDF_THRESHOLD} 开始执行TF-IDF清洗... ")
new_data_df = data_df.progress_apply(wash, axis=1)
json.dump(new_data_df.to_dict(orient='records'), open(f'{data_dir}/ner_results_{data_name}.json','w'), indent=4)

new_e_tf_idf = mean_e_tf_idf(new_data_df['tf_idf'])
old_e_tf_idf = mean_e_tf_idf(all_tf_idf)
new_e = set(new_e_tf_idf['e'].tolist())
old_e = set(old_e_tf_idf['e'].tolist())
wash_e = old_e - new_e

print(f"清洗了{len(wash_e)}个实体类别")
# for e in wash_e:
#     print(e)
print("TF-IDF清洗后, 实体总数剩余比例: ", (new_data_df['input_entities'].apply(len).sum()+new_data_df['output_entities'].apply(len).sum()) / (data_df['input_entities'].apply(len).sum() + data_df['output_entities'].apply(len).sum()))
print("TF-IDF清洗后, 实体类别剩余比例:", mean_e_tf_idf(new_data_df['tf_idf'].tolist()).shape[0]/mean_e_tf_idf(all_tf_idf).shape[0])

读取缓存的单词计数/home/cs/yangyuchen/guoyiqiu/kg_llm/data/wc_usmle_train.pkl...
正在计算TF_IDF...
TF_IDF阈值: 7 开始执行TF-IDF清洗... 


  0%|          | 0/10178 [00:00<?, ?it/s]

清洗了2836个实体类别
TF-IDF清洗后, 实体总数剩余比例:  0.4173416013721426
TF-IDF清洗后, 实体类别剩余比例: 0.9448817367306085


## 执行Grounding并导出kg_dataset.json

In [2]:
from collections import defaultdict
import pandas as pd

print("正在读取知识图谱...")
kg = pd.read_csv('data/bios_kg_with_def_detailed.csv')
c2i = defaultdict(list)

print("正在构建知识图谱索引...")
for row in tqdm(kg.itertuples(), total=kg.shape[0]):
    c2i[row[4]].append(row[0])

正在读取知识图谱...
正在构建知识图谱索引...


  0%|          | 0/35327128 [00:00<?, ?it/s]

In [4]:
def grounding(row):
    new_input_entities = []
    new_output_entities = []
    new_input_entities_cid = []
    new_output_entities_cid = []
    input_triplets = []
    output_triplets = []
    max_t = 9999999999999
    for e,cid in zip(row['input_entities'],row['input_entities_cid']):
        if cid in c2i and len(c2i[cid]) <= max_t:
            new_input_entities.append(e)
            new_input_entities_cid.append(cid)
            input_triplets.append(c2i[cid])
    for e,cid in zip(row['output_entities'],row['output_entities_cid']):
        if cid in c2i and len(c2i[cid]) <= max_t:
            new_output_entities.append(e)
            new_output_entities_cid.append(cid)
            output_triplets.append(c2i[cid])
    row['input_entities'] = new_input_entities
    row['output_entities'] = new_output_entities
    row['input_entities_cid'] = new_input_entities_cid
    row['output_entities_cid'] = new_output_entities_cid
    row['input_triplets'] = input_triplets
    row['output_triplets'] = output_triplets
    return row

print("正在添加知识三元组...")
new_data_df = new_data_df.progress_apply(grounding, axis=1)
kg_data_df = new_data_df.drop(['tf_idf','input_entities_cid','output_entities_cid'],axis=1,inplace=False)
import json
json.dump(kg_data_df.to_dict(orient='records'), open(f'{data_dir}/kg_{data_name}.json','w'),indent=4)
kg_data_df

正在添加知识三元组...


  0%|          | 0/10178 [00:00<?, ?it/s]

Unnamed: 0,input,output,input_entities,output_entities,input_triplets,output_triplets
0,A 23-year-old pregnant woman at 22 weeks gesta...,E,"[cranberry extract, gravid uterus]",[],"[[18182139, 18182140, 18182141], [18429743, 18...",[]
1,A 3-month-old baby died suddenly at night whil...,A,[death],[],[[285185]],[]
2,A mother brings her 3-week-old infant to the p...,A,"[feeding habits, ventral pancreatic bud, proxi...",[],"[[11800904, 11800905], [16762920, 16762921, 16...",[]
3,A pulmonary autopsy specimen from a 58-year-ol...,A,"[pulmonary autopsy, acute hypoxic respiratory ...",[],"[[22661828], [17896622], [235796, 644860, 6448...",[]
4,A 20-year-old woman presents with menorrhagia ...,E,"[bruising, PT 12, PTT 43, Factor V Leiden, Lup...",[],"[[22110595], [21066621, 21066622, 21066623], [...",[]
...,...,...,...,...,...,...
10173,A 60-year-old man presents to the emergency de...,B,"[preceding symptoms, preserved ejection fracti...",[],"[[17472335], [10934878, 10934879], [1346950, 1...",[]
10174,A 45-year-old male with a 15-year history of d...,B,"[renal impairment, sensitive test, renal impai...",[],"[[35315253, 35315254, 35315255, 35315256, 3531...",[]
10175,After receiving a positive newborn screening r...,B,"[sweat test, DNA sequencing, base pair deletio...",[],"[[362529, 18348081, 18348082, 18348083, 183480...",[]
10176,A 25-year-old man comes to the office because ...,C,"[point tenderness, shoulder, Branched-chain al...",[],"[[26845955, 26845956, 26845957], [46416, 23295...",[]


In [5]:
def get_options(row):
    option_text = row['input'].split("Options:")[1]
    option_text = option_text.replace(" A: ","").replace(" B: ","##").replace(" C: ","##").replace(" D: ","##").replace(" E: ","##")
    row['options'] = option_text.split("##")
    return row
kg_data_df_with_options = kg_data_df.apply(get_options, axis=1)

def options_in_entities(row):
    oine_num = 0
    for option in row['options']:
        for e in row['input_entities']:
            if option == e:
                oine_num+=1
                break
        
    return oine_num
kg_data_df=kg_data_df[kg_data_df_with_options.apply(options_in_entities,axis=1)==5]
json.dump(kg_data_df.to_dict(orient='records'), open(f'{data_dir}/kg_{data_name}_440.json','w'),indent=4)
kg_data_df

Unnamed: 0,input,output,input_entities,output_entities,input_triplets,output_triplets
39,A 31-year-old G2P2 female at 40 weeks gestatio...,D,"[Fetal heart tracing, variable decelerations, ...",[],"[[20692702, 20692703, 20692704, 20692705], [17...",[]
63,A 4-year-old boy is brought to the physician b...,D,"[Enalapril therapy, Furosemide therapy, Anti-s...",[],"[[28316168, 28316169], [12215050, 12215051, 12...",[]
71,A 37-year-old woman comes to the physician bec...,B,"[white spots, D-xylose, Gluten-free diet, Panc...",[],"[[325958, 363212, 15126349, 15126350, 15126351...",[]
111,A 37-year-old patient is being evaluated for i...,D,"[involuntary movements, neuromediators, trinuc...",[],"[[76701, 9007766, 9007767, 9007768], [142281, ...",[]
134,A 3000-g (6.6-lb) female newborn is delivered ...,C,"[auditory screening tests, Congenital parvovir...",[],"[[1857389, 1857390], [16465312, 16465313, 1646...",[]
...,...,...,...,...,...,...
10020,A 29-year-old internal medicine resident prese...,E,"[Schistosoma haematobium, Onchocerca volvulus,...",[],"[[181070], [8453394], [284370, 6615965, 661596...",[]
10026,A 45-year-old woman presents to her primary ca...,C,"[badminton, snuffbox, Nodules, Ulnar deviation...",[],"[[289004, 28518992, 28518993], [35839, 291383,...",[]
10071,A 24-day-old neonate is brought to the emergen...,C,"[sick people, rouse, neonatal meningitis, empi...",[],"[[151528], [22285, 28852735, 28852736], [69755...",[]
10166,A 61-year-old man presents to his primary care...,B,"[cocaine abuse, bilateral lung bases, Palpable...",[],"[[59943, 15143363, 15143364], [23788962, 23788...",[]


## 可视化数据集统计信息

In [None]:
import re
from transformers import AutoTokenizer
import pandas as pd
import matplotlib.pyplot as plt

tok = AutoTokenizer.from_pretrained("/home/cs/yangyuchen/yushengliao/Medical_LLM/llama-7b-hf")

LEN_DEF = kg[kg['edge'] == "has definition of "].shape[0]

def da(row):
    row_da = pd.Series()
    all_text = (row['input'] + row['output']).lower()
    row_da['sent_token_num'] = len(tok(row['input']+row['output'])['input_ids'])
    et = [(e,ts) for e,ts in zip(row['input_entities'] + row['output_entities'], row['input_triplets'] + row['output_triplets'])]
    et = sorted(et, key=lambda x: -len(tok(x[0])['input_ids']))
    row_da['e_num'] = len(et)
    uni_ts = set([t for (e,ts) in et for t in ts])
    row_da['uni_t_num'] = len(uni_ts)
    row_da['def_t_num'] = len([t for t in uni_ts if t < LEN_DEF])
    row_da['kg_t_num'] = row_da['uni_t_num'] - row_da['def_t_num']
    kg_token_num = 0
    e_counts = []
    ts_token_num = []
    for e,ts in et:
        e_count = len(re.findall(f"[^\w]{re.escape(e.lower())}[^\w]", all_text))
        e_counts.append(e_count)
        all_text = re.sub(f"[^\w]{re.escape(e.lower())}[^\w]", '', all_text)
        all_t_token_num = [len(tok(kg.iloc[t]['edge'] + kg.iloc[t]['target'])['input_ids']) + 1 for t in ts]
        ts_token_num.append(sum(all_t_token_num))
        kg_token_num += sum(all_t_token_num) * e_count
    row_da['e_counts'] = e_counts
    e_ts_num = [len(ts) for e,ts in et]
    row_da['e_ts_num'] = e_ts_num
    row_da['ts_token_num'] = ts_token_num
    row_da['kg_token_num'] = kg_token_num       
    return row_da

def plot_series(series, title, xlabel, ylabel):
    series.hist(bins=series.max(), range=(0, series.max()))
    print(f"average {xlabel} :{series.mean()}")
    plt.axvline(x=series.mean(), color='r', linestyle='--', )
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

data_da_df = kg_data_df.progress_apply(da, axis=1)

plot_series(data_da_df['sent_token_num'], 'corpus length distribution', 'corpus length', 'frequency')
plot_series(data_da_df['kg_token_num'], 'kg token length distribution', 'kg token length', 'frequency')
plot_series(data_da_df['kg_token_num'] + data_da_df['sent_token_num'], 'corpus with all kg length distribution', 'corpus with all kg length', 'frequency')
plot_series(data_da_df['e_num'], 'entity num each corpus distribution', 'entity num each corpus', 'frequency')
plot_series(data_da_df['uni_t_num'], 'unique triplets num each corpus distribution', 'unique triplets num each corpus', 'frequency')
plot_series(data_da_df['def_t_num'], 'definition triplets num each corpus distribution', 'definition triplets num each corpus', 'frequency')
plot_series(data_da_df['kg_t_num'], 'kg triplets num each corpus distribution', 'kg triplets num each corpus', 'frequency')
plot_series(data_da_df['e_counts'].apply(lambda x: max(x) if x else 0), 'max e_counts num each corpus distribution', 'max e_counts num each corpus', 'frequency')

## 生成COT推理数据集-训练

In [117]:
import json
import pandas as pd
import random

COT_INPUT_PROMPT = "{INPUT} The medical entities in the text include: {ENTITIES}."
COT_OUTPUT_PROMPT = "The related knowledge of the medical entities include: {KNOWLEDGE}\n"
COT2_OUTPUT_PROMPT = "The answer is {ANSWER}"

def cot_step1(row):
    new_row = row.copy()
    ENTITIES = ', '.join(row['input_entities'])
    KNOWLEDGE = ''
    for e,ts in zip(row['input_entities'],row['input_triplets']):
        for tid in ts:
            know = kg.iloc[tid]
            KNOWLEDGE += f"{e} {know['edge']} {know['target']}; "
    ANSWER = row['output'].strip()
    new_row['input'] = COT_INPUT_PROMPT.format(INPUT=row['input'], ENTITIES=ENTITIES)
    new_row['output'] = COT_OUTPUT_PROMPT.format(KNOWLEDGE=KNOWLEDGE)
    
    return new_row

def cot_step2(row):
    new_row = row.copy()
    ENTITIES = ', '.join(row['input_entities'])
    KNOWLEDGE = ''
    for e,ts in zip(row['input_entities'],row['input_triplets']):
        for tid in ts:
            know = kg.iloc[tid]
            KNOWLEDGE += f"{e} {know['edge']} {know['target']}; "
    ANSWER = row['output'].strip()
    new_row['input'] = COT_INPUT_PROMPT.format(INPUT=row['input'], ENTITIES=ENTITIES) + COT_OUTPUT_PROMPT.format(KNOWLEDGE=KNOWLEDGE)
    new_row['output'] = COT2_OUTPUT_PROMPT.format(ANSWER=ANSWER)
    
    return new_row

    
cot_kg_out_data_df = pd.concat([kg_data_df.apply(cot_step1, axis=1), kg_data_df.apply(cot_step2, axis=1)])
json.dump(cot_kg_out_data_df.to_dict('records'), open(f'{data_dir}/cot_kg_{data_name}.json','w'), indent=4)
cot_kg_out_data_df

Unnamed: 0,input,output,input_entities,output_entities,input_triplets,output_triplets
0,A 23-year-old pregnant woman at 22 weeks gesta...,The related knowledge of the medical entities ...,"[cranberry extract, gravid uterus]",[],"[[18182139, 18182140, 18182141], [18429743, 18...",[]
1,A 3-month-old baby died suddenly at night whil...,The related knowledge of the medical entities ...,[death],[],[[285185]],[]
2,A mother brings her 3-week-old infant to the p...,The related knowledge of the medical entities ...,"[feeding habits, ventral pancreatic bud, proxi...",[],"[[11800904, 11800905], [16762920, 16762921, 16...",[]
3,A pulmonary autopsy specimen from a 58-year-ol...,The related knowledge of the medical entities ...,"[pulmonary autopsy, acute hypoxic respiratory ...",[],"[[22661828], [17896622], [235796, 644860, 6448...",[]
4,A 20-year-old woman presents with menorrhagia ...,The related knowledge of the medical entities ...,"[bruising, PT 12, PTT 43, Factor V Leiden, Lup...",[],"[[22110595], [21066621, 21066622, 21066623], [...",[]
...,...,...,...,...,...,...
10173,A 60-year-old man presents to the emergency de...,The answer is B,"[preceding symptoms, preserved ejection fracti...",[],"[[17472335], [10934878, 10934879], [1346950, 1...",[]
10174,A 45-year-old male with a 15-year history of d...,The answer is B,"[renal impairment, sensitive test, renal impai...",[],"[[35315253, 35315254, 35315255, 35315256, 3531...",[]
10175,After receiving a positive newborn screening r...,The answer is B,"[sweat test, DNA sequencing, base pair deletio...",[],"[[362529, 18348081, 18348082, 18348083, 183480...",[]
10176,A 25-year-old man comes to the office because ...,The answer is C,"[point tenderness, shoulder, Branched-chain al...",[],"[[26845955, 26845956, 26845957], [46416, 23295...",[]


## 生成COT推理数据集-测试

In [65]:
import json
import pandas as pd
import random

COT_INPUT_PROMPT = "{INPUT} The medical entities in the text include: {ENTITIES}."
COT_OUTPUT_PROMPT = "The related knowledge of the medical entities include: {KNOWLEDGE}\n"
COT2_OUTPUT_PROMPT = "The answer is {ANSWER}"

def cot_step12(row):
    new_row = pd.Series()
    ENTITIES = ', '.join(row['input_entities'])
    KNOWLEDGE = ''
    for e,ts in zip(row['input_entities'],row['input_triplets']):
        for tid in ts:
            know = kg.iloc[tid]
            KNOWLEDGE += f"{e} {know['edge']} {know['target']}; "
    ANSWER = row['output'].strip()
    new_row['input'] = COT_INPUT_PROMPT.format(INPUT=row['input'], ENTITIES=ENTITIES)
    new_row['output'] = COT_OUTPUT_PROMPT.format(KNOWLEDGE=KNOWLEDGE) + COT2_OUTPUT_PROMPT.format(ANSWER=ANSWER)
    return new_row

cot_kg_out_data_df = kg_data_df.apply(cot_step12, axis=1)
json.dump(cot_kg_out_data_df.to_dict('records'), open(f'{data_dir}/cot_kg_{data_name}.json','w'), indent=4)
cot_kg_out_data_df

Unnamed: 0,input,output
0,Which of the following is not true for myelina...,The related knowledge of the medical entities ...
1,Which of the following is not true about glome...,The related knowledge of the medical entities ...
2,A 29 yrs old woman with a pregnancy of 17 week...,The related knowledge of the medical entities ...
3,"Axonal transport is: Options:A: Antegrade, B: ...",The related knowledge of the medical entities ...
4,Low insulin to glucagon ratio is seen in all o...,The related knowledge of the medical entities ...
...,...,...
4178,A study is to be conducted with regards to the...,The related knowledge of the medical entities ...
4179,"APGAR acronym stands for? Options:A: Activity,...",The related knowledge of the medical entities ...
4180,Most commonly implicated drug for acute liver ...,The related knowledge of the medical entities ...
4181,A 9 year old boy has steroid dependent nephrot...,The related knowledge of the medical entities ...


## 生成混合triplets数据集

In [6]:
all_ets = {}
for row in kg_data_df.itertuples():
    es = row.input_entities + row.output_entities
    tss = row.input_triplets + row.output_triplets
    for (e,ts) in zip(es,tss):
        all_ets[e] = ts
all_triplets = []
for e,ts in all_ets.items():
    for t in ts:
        all_triplets.append(dict(input=f"{e.strip()} {kg.iloc[t]['edge'].strip()}", output=f"{kg.iloc[t]['target'].strip()}"))

triplets_df = pd.DataFrame(all_triplets)
data_df_with_tri = pd.concat([kg_data_df.drop(['input_entities','output_entities','input_triplets','output_triplets'], axis=1, inplace=False), triplets_df], axis=0)
json.dump(data_df_with_tri.to_dict(orient='records'),open(f"{data_dir}/tri_{data_name}.json",'w'), indent=4)

In [38]:
json.dump(triplets_df.to_dict(orient='records'),open(f"{data_dir}/tri_440.json",'w'), indent=4)

In [11]:
import random
all_targets = set()
for e,ts in all_ets.items():
    for t in ts:
        all_targets.add(kg.iloc[t]['target'].strip())
sr2t = {}
for triplet in all_triplets:
    sr2t[triplet['input']] = [triplet['output']] + sr2t.get(triplet['input']) if sr2t.get(triplet['input'],False) else [triplet['output']]

all_choices = []
for sr,t in sr2t.items():
    correct_ops = random.sample(t,min(3,len(t)))
    wrong_ops = random.sample(list(all_targets - set(t)),5-len(correct_ops))
    all_ops = correct_ops + wrong_ops
    random.shuffle(all_ops)
    op_str = ""
    output = ""
    for idx,op in zip(['A','B','C','D','E'],all_ops):
        op_str += f"{idx}: {op};"
        if op in correct_ops:
            output += f"{idx}: {op};"
    all_choices.append(dict(input=f"Question: {sr}? Options: {op_str} The correct answer is", output=output))
train_ratio = 0.4
train_choices = all_choices[:int(len(all_choices)*train_ratio)]
test_choices = all_choices[int(len(all_choices)*train_ratio):]
json.dump(train_choices,open(f"{data_dir}/tri_multi_choice_440_train.json",'w'), indent=4)
json.dump(test_choices,open(f"{data_dir}/tri_multi_choice_440_test.json",'w'), indent=4)

In [16]:
triplets_df

Unnamed: 0,input,output
0,phyllodes-like tumor has definition of,medical condition
1,estradiol substitution may treat,rec anorexia nervosa
2,estradiol substitution may treat,ovarian function insufficiency
3,estradiol substitution may treat,pituitary insufficiencies
4,estradiol substitution may treat,silent anovulation
...,...,...
11454,stomach adenocarcinoma tumor is a,cancers
11455,thyroid's left lobe is a,endocrine and exocrine glands
11456,thyroid's left lobe is a,thyroidium
11457,urinary stress urinary incontinence is a,kidney and-urine related disorders
