Test

In [None]:
# 测试faiss
import faiss
import numpy as np
import os

index_path = os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZHEmbedding_HNSW64.index")
embedding_path = os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZHEmbedding.npy")

if os.path.exists(index_path):
    index = faiss.read_index(index_path)
else:
    term_disease_zh_embedding = np.load(embedding_path)

    dim, measure = 768, faiss.METRIC_L2  
    param =  'HNSW64'
    index = faiss.index_factory(dim, param, measure)
    if not index.is_trained:
        index.train(term_disease_zh_embedding)
    index.add(term_disease_zh_embedding)
    faiss.write_index(index, index_path)

import json
import os
import numpy as np


term_disease_zh = json.load(open((os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZH.json"))))
terms = list(term_disease_zh.keys())
print('terms: ', len(terms))
cid_term = {term_disease_zh[term]:term for term in terms}
ninth_diseases = json.load(open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/NinthDisease.json")))
print('ninth_diseases: ', len(ninth_diseases))
ninth_disease_zh_embedding = np.load(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/NinthDiseaseEmbedding.npy"))
print('ninth_disease_zh_embedding: ', len(ninth_disease_zh_embedding))

k = 50
D, I = index.search(ninth_disease_zh_embedding, k)
print(D.shape, I.shape)
for i in range(I.shape[0]):
    print('query:', ninth_diseases[i])
    cid_distances = {}
    for j in range(I.shape[1]):
        term = terms[I[i][j]]
        cid = term_disease_zh[term]
        cid_distances[cid] = cid_distances.get(cid, []) + [D[i][j]]
    cid_distances = sorted(list(cid_distances.items()), key=lambda x: np.min(x[1]))
    for j, (cid, distance) in enumerate(cid_distances[:10]):
        print(f' top {j}: {cid}-{cid_term[cid]}: {distance}')

In [None]:
# dllm local unit test 
from model.dict_llm import *
import os
import json
import os
import time
from tqdm import tqdm


dllm = DictLLM(
    mt_path=os.path.join(os.environ['my_models_dir'],'gpt2'),
    encoder_hidden_size=768,
    num_table_token=32,
    num_encoder_head=8,
    num_encoder_layers=12,
    special_tokens_path=None,
    mask_strategy="table",
    position_strategy="group",
    encoder_type="bert",
    mapper_type="otk",
    deep_fusion=True
)

data = json.load(open(os.path.join(os.environ['my_datasets_dir'],'ninth/v3.2/checkout_data_eval.json')))

nan_ratio = lambda x: int(x.isnan().sum()/x.numel())
clear_hooks = lambda model: [(module._forward_hooks.clear(), module._forward_pre_hooks.clear()) for module in model.modules()]
clear_hooks(dllm)


batch_input_text = [d['input'][:100] for d in data[:2]]
print('batch_input_text: ', batch_input_text)
batch_dicts = [d['data'] for d in data[:2]]
batch_label_text = [d['output'][:100] for d in data[:2]]
labels = dllm.llm.tok(batch_label_text, padding=True, return_tensors='pt', add_special_tokens=False)['input_ids']
print(labels.shape[-1])
dllm_output = dllm(batch_input_text=batch_input_text, batch_dicts=batch_dicts, batch_label_text=batch_label_text)
print(dllm_output['loss'], dllm_output['logits'].shape)
dllm_gen_output = dllm.generate(batch_input_text=[batch_input_text[0]], batch_dicts=[batch_dicts[0]], batch_label_text=[batch_label_text[0]], max_new_tokens=10, cut_input=True)
print('dllm_gen_output: ', dllm_gen_output)
print(dllm.llm.tok.batch_decode(dllm_gen_output))

ToolBox

In [None]:
# llm panel
from model import *

model_list = {name : os.path.join(os.environ['my_models_dir'], name) for name in os.listdir(os.environ['my_models_dir'])}
panel = LLMPanel(model_list, chat_template=ChatTemplate.INTERNLM_TEMPLATE)
panel

In [None]:
# launch clash
import subprocess
import os

result = subprocess.run("pidof clash", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if not result.stdout:
    subprocess.Popen("~/tools/clash/clash", shell=True)
    result = subprocess.run("pidof clash", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print(f"Clash is running, pid: {result.stdout}")
os.environ["http_proxy"] = "http://localhost:7890"
os.environ["https_proxy"] = "http://localhost:7890"

In [None]:
# test clash
!wget www.google.com
!rm -f index.html*

In [None]:
# close clash
!killall clash
!unset http_proxy
!unset https_proxy

In [None]:
# download model
from transformers import AutoTokenizer, AutoModel
import os

model_name = "FreedomIntelligence/HuatuoGPT2-7B"
local_dir = os.path.join(os.environ['my_models_dir'], model_name.split("/")[-1])
print('local_dir: ', local_dir)

from huggingface_hub import snapshot_download as hf_snapshot_download
hf_snapshot_download(
    model_name, 
    cache_dir=local_dir, 
    local_dir=local_dir, 
    local_dir_use_symlinks=False, 
    ignore_patterns=["*.h5","*safetensors","*msgpack"],
    etag_timeout=60, 
    force_download=False, 
    resume_download=False
)

# from modelscope import snapshot_download as modelscope_snapshot_download
# modelscope_snapshot_download('Shanghai_AI_Laboratory/internlm-7b', revision='v1.0.2',)

In [None]:
# query chatgpt
import openai
from model import multithread_query_chatgpt
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

openai.api_key = "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"
output = multithread_query_chatgpt([dict(query_input="hello")])
print('output: ', output)

In [None]:
# multigpu inference
from model.llm_utils import get_free_gpus
import subprocess
import os

free_gpus = get_free_gpus()
command = f'CUDA_VISIBLE_DEVICES={",".join([str(i) for i in free_gpus])} \
            torchrun --nproc_per_node {len(free_gpus)} --master_port=12570 /home/cs/yangyuchen/guoyiqiu/gpt_re/mgpu_infer.py \
            --func infer \
            --model_path /home/cs/yangyuchen/guoyiqiu/my_models/Llama-2-13b-chat-ms \
            --dst_path /home/cs/yangyuchen/guoyiqiu/gpt_re/data/mgpu_infer/medqa_test_noquestion_inference.json \
            --save_path /home/cs/yangyuchen/guoyiqiu/gpt_re/data/mgpu_infer_output/llama-2-13b-chat_medqa_test_noquestion_inference.json \
            --mnt 8 '

process = subprocess.Popen(command, shell=True)
print(f"shell process id: {process.pid}")

九院数据预处理

In [None]:
# 生成ninth_data.json

import os
import json
from tqdm.auto import tqdm
from bs4 import BeautifulSoup
import pandas as pd
import fitz
from multiprocessing import Pool, cpu_count
import re

def get_data(i):
    broken_files = []
    advice_path = os.path.join(base_dir,"多科室医嘱",i+".json")
    with open(advice_path, "r", encoding="gbk") as f:
        try:
            advice = json.load(f)['Data']
        except Exception as e:
            broken_files.append(advice_path)
            # print(f"{i}: broken advice {advice_path}.")
            advice=None
    
    all_checks = []
    check_path = os.path.join(base_dir,"多科室检验报告",i)
    for fname in os.listdir(check_path):
        if not fname.endswith(".html"):
            continue
        try:
            with open(os.path.join(check_path, fname), "r") as f:
                soup = BeautifulSoup(f, features="lxml")
        except Exception as e:
            broken_files.append(os.path.join(check_path, fname))
            # print(f"{i}: broken check: {fname}.")
            continue
            
        if not soup.table:
            continue
        check_title = soup.h4.get_text(strip=True).replace("共 1 项","").strip()
        # print('check_title: ', check_title)
        info = dict(标题=check_title)
        info_text = soup.div.get_text(strip=True).replace("\n","").replace("\t","")
        # print('info_text: ', info_text)
        info_keys = ["工作组:","送检医生:","报告时间:","检验医生:","送检时间:","标本名称:","性别:","年龄:","诊断:","生理周期:","申请号:"]
        for k1,k2 in zip(info_keys[:-1],info_keys[1:]):
            value = re.search(f"{k1}(.*?){k2}", info_text)
            if value:
                info[k1[:-1]] = value.group(1).strip()
        # print('info: ', info)
        
        table_head = []
        for row in soup.table.find_all('th'):
            table_head.append(row.get_text(strip=True))
        
        table_data = []
        for row in soup.table.find_all('tr'):
            row_data = [i.get_text(strip=True) for i in row.find_all('td')]
            if row_data:
                table_data.append(row_data)
        check_df = pd.DataFrame(table_data, columns=table_head)
        check_dict = check_df[[c for c in check_df.columns if c]].to_dict(orient="records")
        all_checks.append(dict(filename=fname, info=info, values=check_dict))
    
    all_emrs = []
    emr_path = os.path.join(base_dir,"多科室EMR",i)
    for fname in os.listdir(emr_path):
        if not fname.endswith(".pdf"):
            continue
        try:
            texts = []
            doc = fitz.open(os.path.join(emr_path,fname))
            # print(os.path.join(emr_path,fname))
            for page in doc:
                texts.append(page.get_text())
        except Exception as e:
            broken_files.append(os.path.join(emr_path,fname))
            # print(f"{i}: broken emr: {fname}.")
            continue
        finally:
            all_emrs.append(dict(filename=fname, texts=texts))
    
    return dict(zid=i,advice=advice, checks=all_checks, emrs=all_emrs), broken_files

base_dir = os.path.join(os.environ['my_datasets_dir'], "ninth")
check_ids = os.listdir(os.path.join(base_dir,"多科室检验报告"))
advice_ids = [i.replace(".json","") for i in os.listdir(os.path.join(base_dir,"多科室医嘱"))]
emr_ids = os.listdir(os.path.join(base_dir,"多科室EMR"))
overlap_ids = sorted(list(set(check_ids) & set(advice_ids) & set(emr_ids)))

all_data = []
all_broken_files = []
process_num = cpu_count()
print(f"Total {len(overlap_ids)} files. process_num={process_num}")

with Pool(process_num) as p:
    for data, broken_files in tqdm(p.imap(get_data, overlap_ids), total=len(overlap_ids)):
        all_data.append(data)
        all_broken_files.extend(broken_files)

all_data = sorted(all_data, key=lambda x: x["zid"])
all_data = [d for d in all_data if d["advice"] and d["checks"] and d["emrs"]]
json.dump(all_broken_files, open(os.path.join(base_dir, "ninth_broken_files.json"), "w", encoding="utf8"), ensure_ascii=False, indent=4)
json.dump(all_data[-10:], open(os.path.join(base_dir, "ninth_data_sample.json"), "w", encoding="utf8"), ensure_ascii=False, indent=4)
json.dump(all_data, open(os.path.join(base_dir, "ninth_data.json"), "w", encoding="utf8"), ensure_ascii=False)

In [None]:
# 生成ninth_data_im.json

import os
import json
from tqdm.auto import tqdm
from bs4 import BeautifulSoup
import pandas as pd
import fitz
from multiprocessing import Pool, cpu_count
import re

def get_data(i):
    broken_files = []
    
    all_checks = []
    check_path = os.path.join(os.environ['my_datasets_dir'],"ninth/多科室检验报告2",i)
    for fname in os.listdir(check_path):
        if not fname.endswith(".html"):
            continue
        try:
            with open(os.path.join(check_path, fname), "r") as f:
                soup = BeautifulSoup(f, features="lxml")
        except Exception as e:
            broken_files.append(os.path.join(check_path, fname))
            # print(f"{i}: broken check: {fname}.")
            continue
            
        if not soup.table:
            continue
        check_title = soup.h4.get_text(strip=True).replace("共 1 项","").strip()
        # print('check_title: ', check_title)
        info = dict(标题=check_title)
        info_text = soup.div.get_text(strip=True).replace("\n","").replace("\t","")
        # print('info_text: ', info_text)
        info_keys = ["工作组:","送检医生:","报告时间:","检验医生:","送检时间:","标本名称:","性别:","年龄:","诊断:","生理周期:","申请号:"]
        for k1,k2 in zip(info_keys[:-1],info_keys[1:]):
            value = re.search(f"{k1}(.*?){k2}", info_text)
            if value:
                info[k1[:-1]] = value.group(1).strip()
        # print('info: ', info)
        
        table_head = []
        for row in soup.table.find_all('th'):
            table_head.append(row.get_text(strip=True))
        
        table_data = []
        for row in soup.table.find_all('tr'):
            row_data = [i.get_text(strip=True) for i in row.find_all('td')]
            if row_data:
                table_data.append(row_data)
        check_df = pd.DataFrame(table_data, columns=table_head)
        check_dict = check_df[[c for c in check_df.columns if c]].to_dict(orient="records")
        all_checks.append(dict(filename=fname, info=info, values=check_dict))
    
    all_emrs = []
    emr_path = os.path.join(os.environ['my_datasets_dir'], "ninth/全科室出院小结",i)
    for fname in os.listdir(emr_path):
        if not fname.endswith(".pdf"):
            continue
        try:
            texts = []
            doc = fitz.open(os.path.join(emr_path,fname))
            # print(os.path.join(emr_path,fname))
            for page in doc:
                texts.append(page.get_text())
        except Exception as e:
            broken_files.append(os.path.join(emr_path,fname))
            # print(f"{i}: broken emr: {fname}.")
            continue
        finally:
            all_emrs.append(dict(filename=fname, texts=texts))
    
    return dict(zid=i, checks=all_checks, emrs=all_emrs), broken_files

check_ids = os.listdir(os.path.join(os.environ['my_datasets_dir'], "ninth/多科室检验报告2"))
emr_ids = os.listdir(os.path.join(os.environ['my_datasets_dir'], "ninth/全科室出院小结"))
overlap_ids = sorted(list(set(check_ids) & set(emr_ids)))

all_data = []
all_broken_files = []
process_num = cpu_count()
print(f"Total {len(overlap_ids)} files. process_num={process_num}")

with Pool(process_num) as p:
    for data, broken_files in tqdm(p.imap(get_data, overlap_ids), total=len(overlap_ids)):
        all_data.append(data)
        all_broken_files.extend(broken_files)

all_data = sorted(all_data, key=lambda x: x["zid"])
all_data = [d for d in all_data if d["checks"] and d["emrs"]]
json.dump(all_broken_files, open(os.path.join(os.environ['my_datasets_dir'], "ninth/ninth_im_broken_files.json"), "w", encoding="utf8"), ensure_ascii=False, indent=4)
json.dump(all_data[-10:], open(os.path.join(os.environ['my_datasets_dir'], "ninth/ninth_data_im_sample.json"), "w", encoding="utf8"), ensure_ascii=False, indent=4)
json.dump(all_data, open(os.path.join(os.environ['my_datasets_dir'], "ninth/ninth_data_im.json"), "w", encoding="utf8"), ensure_ascii=False)

In [None]:
# 读取ninth_data
import json
import os

data = json.load(open(os.path.join(os.environ['my_datasets_dir'], "ninth/ninth_data_im.json")))

In [None]:
# 统计ninth_data中的化验单模板分布
from collections import defaultdict
from tqdm.auto import tqdm
import traceback
import re


check_tmps = defaultdict(dict)
info_counts = {"标题":defaultdict(dict),"工作组":defaultdict(dict),"标本名称":defaultdict(dict)}
for d in tqdm(data):
    try:
        for check in d['checks']:
            for i in ['标题','工作组','标本名称']:
                info = check['info']
                if info.get(i):
                    info_counts[i][info[i]] = info_counts[i].get(info[i], 0) + 1
            if not check['values']:
                continue
            check_tmp = tuple(check['values'][0].keys())
            check_tmps[check_tmp]['tmp_count'] = check_tmps[check_tmp].get('tmp_count', 0) + 1
            if 'check_key_names_counts' not in check_tmps[check_tmp]:
                check_tmps[check_tmp]['check_key_names_counts'] = defaultdict(dict)
            for value in check['values']:
                if check_tmp == ('药品名称', '结果', '结果标识', '折点', '菌落数', '细菌名称'):
                    check_tmps[check_tmp]['check_key_names_counts']['药品名称+细菌名称'][str(value['药品名称'])+str(value['细菌名称'])] = check_tmps[check_tmp]['check_key_names_counts']['药品名称+细菌名称'].get(str(value['药品名称'])+str(value['细菌名称']), 0) + 1
                
                for k in value:
                    if check_tmp == ('检验项', '结果', '参考范围') and k == '结果':
                        value_map = [
                            ("^敏感[\s\S]*","敏感"),
                            ("^耐药[\s\S]*","耐药"),
                            ("^中介[\s\S]*","中介"),
                            ("阴.*性.*","阴性"),
                            ("阳.*性.*","阳性"),
                        ]
                        valuesk = value[k]
                        for v in value_map:
                            valuesk = re.sub(v[0], v[1], valuesk)
                        check_tmps[check_tmp]['check_key_names_counts'][k][valuesk] = check_tmps[check_tmp]['check_key_names_counts'][k].get(valuesk, 0) + 1
                        continue
                    check_tmps[check_tmp]['check_key_names_counts'][k][value[k]] = check_tmps[check_tmp]['check_key_names_counts'][k].get(value[k], 0) + 1
            
    except Exception as e:
        tb_str = traceback.format_tb(e.__traceback__)
        print(f"{tb_str} {e}")
        break

topn = 50

for i in info_counts:
    counts = sorted(list(info_counts[i].items()), key=lambda x : -x[1])
    print(f"{i} {len(counts)}")
    for i,c in counts[:topn]:
        print(f"  - {i} {c}")
        
for tmp in check_tmps:
    print(f"化验单模板: {tmp} 模板出现次数：{check_tmps[tmp]['tmp_count']}")
    for k in check_tmps[tmp]['check_key_names_counts']:
        print(f"  - 列名：{k} 项目类别数：{len(check_tmps[tmp]['check_key_names_counts'][k])} 项目类别数求和:{sum(list(check_tmps[tmp]['check_key_names_counts'][k].values()))}")
        for v,c in sorted(list(check_tmps[tmp]['check_key_names_counts'][k].items()), key=lambda x:-x[1])[:topn]:
            print(f"    - 项目名：{v} 出现频次：{c}")

In [None]:
# 根据ninth_data生成checkout_data_with_checks.json
import json
import re
import os
import pylcs
from datetime import datetime
from tqdm.auto import tqdm
import traceback
from collections import defaultdict
import bert_score

def is_float(str):
    try:
        float(str)
        return True
    except:
        return False


def preprocess_emr_texts(emr_texts):
    texts = [re.sub(r'第(\s)*\d+(\s)*页','',t) for t in emr_texts]
    text = ' '.join([t.replace(os.path.commonprefix(texts),'') for t in texts]) if len(texts) > 1 else texts[0]
    return text


def get_key_variants(key):
    special_variants_mapper = dict(
        入院时间=['入院日期',],
        出院时间=['出院日期','死亡时间'],
        门诊诊断=['门诊诊',],
        入院诊断=['入院',],
        出院诊断=['死亡诊断'],
        入院时主要症状及体征=['入院情况'],
        主要化验结果=['主要结果','检验结果'],
        特殊检查及重要会诊=['特殊检查结果','特殊检验及重要会诊','动态心电图结果','术中冰冻结果','术中冰冻','病理结果'],
        病程与治疗结果=['诊疗经过','经过','病程及治疗结果'],
        出院时情况=['院时情况',"死亡时情况"],
        出院后建议=['出院后用药及建议','出院后用药建议','其他'],
        治疗结果=['治果'],
    )
    basic_variants = [key] + special_variants_mapper.get(key, [])
    split_variants = []
    for k in (basic_variants):
        for i in range(1,len(k)):
            split_variants.append((k[:i] + "\n" + k[i:]).strip())
    all_variants = basic_variants + split_variants
    return all_variants


def is_emr_checkout(emr_text):
    checkout_default_tmp = ('性别', '年龄', '入院时间', '出院时间', '门诊诊断', '入院诊断', '出院诊断', '入院时主要症状及体征', '主要化验结果', '特殊检查及重要会诊', '病程与治疗结果', '合并症', '出院时情况', '出院后建议', '治疗结果', "主治医师")
    return all(any(kv in emr_text for kv in get_key_variants(key)) for key in checkout_default_tmp)


def get_checkout_tmp(emr_text):
    checkout_default_tmp = ('性别', '年龄', '入院时间', '出院时间', '门诊诊断', '入院诊断', '出院诊断', '入院时主要症状及体征', '主要化验结果', '特殊检查及重要会诊', '病程与治疗结果', '合并症', '出院时情况', '出院后建议', '治疗结果', "主治医师")
    checkout_tmp_orders = [(0,       1,        2,          3,        4,          5,          6,          7,                  8,              9,              10,             11,         12,          13,         14,         15),
                           (0,       1,        2,          3,        4,          5,          6,          7,                  10,              8,              9,             11,         12,          13,         14,         15)]
    tmp = []
    
    for key in checkout_default_tmp:
        for kv in get_key_variants(key):
            if kv in emr_text:
                tmp.append((kv,key))
                break
    
    for order in checkout_tmp_orders:
        new_tmp = [tmp[i] for i in order]
        new_tmp_str = ''.join([t[0] for t in new_tmp])
        if pylcs.lcs_sequence_length(new_tmp_str,emr_text) == len(new_tmp_str):
            return new_tmp
    
    raise ValueError("No tmp found for this emr")


def parse_checkout_text(emr_text, checkout_tmp):
    checkout_dict = {}
    for (kv,k),(nkv,nk) in zip(checkout_tmp[:-1],checkout_tmp[1:]):
        try:
            match = re.search(f"{kv}(.*?){nkv}", emr_text, re.DOTALL)
            span = match.span()
            emr_text = emr_text[span[0]:]
            checkout_dict[k] = match.group(1).replace("：","").strip()
        except Exception as e:
            raise ValueError(f"Wrong tmp found for this emr")
    
    advice = checkout_dict['出院后建议']
    advice = re.sub("预约[\s\S]*","", advice)
    advice = re.sub("下次来院时间[\s\S]*","", advice)
    advice = re.sub("健康宣教[\s\S]*","", advice)
    advice = re.sub("请输入出院后建议[\s\S]*","", advice)
    checkout_dict['出院后建议'] = advice.strip()
    
    time_span = []
    default_start_time_stamp = ['2023','12','31','23','59']
    default_end_time_stamp = ['2000','1','1','1','1']
    try:
        for string, default_time_stamp in zip([checkout_dict['入院时间'],checkout_dict['出院时间']], [default_start_time_stamp,default_end_time_stamp]):
            time_stamp = re.findall("[\d]+",string)
            time_stamp += default_time_stamp[len(time_stamp):]
            time_stamp = ' '.join(time_stamp)
            time_span.append(datetime.strptime(time_stamp, "%Y %m %d %H %M"))
    except:
        raise ValueError(f"error when finding time span")
    time_span = [s.strftime("%Y %m %d %H %M") for s in time_span]
    checkout_dict['time_span'] = time_span
    return checkout_dict


def parse_check(check):
    check_dict = {
        "header" : [check['info']['标题'], check['info']['工作组'], check['info']['标本名称']],
        "report_time" : check['info']['报告时间'],
        "data" : {},
        "raw_data" : {}
    }
    blankize = lambda x : "空" if x == "" else x
    def handler1(value):
        key = value['检验项']
        raw_value = str(value['检验结果值']) + str(value['单位'])
        label_value = value['异常标识']
        if key is None or raw_value is None or label_value is None:
            return None
        key = blankize(key)
        raw_value = blankize(raw_value)
        label_value = blankize(label_value)
        return {key : raw_value}, {key : label_value} 
    
    def handler2(value):
        key = value['检验项']
        raw_value = value['结果']
        label_value = value['结果']
        if key is None or raw_value is None or label_value is None:
            return None
        key = blankize(key)
        raw_value = blankize(raw_value)
        label_value = blankize(label_value)
        value_map = [
                    ("^敏感[\s\S]*","敏感"),
                    ("^耐药[\s\S]*","耐药"),
                    ("^中介[\s\S]*","中介"),
                    ("阴.*性.*","阴性"),
                    ("阳.*性.*","阳性"),
                ]
        for v in value_map:
            label_value = re.sub(v[0], v[1], label_value)
        if is_float(label_value):
            return None
        return {key : raw_value}, {key : label_value} 
    
    def handler3(value):
        if value['药品名称'] is None or value['细菌名称'] is None:
            return None
        key = value['药品名称'] + value['细菌名称']
        raw_value = value['结果']
        label_value = value['结果标识']
        if key is None or raw_value is None or label_value is None:
            return None
        key = blankize(key)
        raw_value = blankize(raw_value)
        label_value = blankize(label_value)
        return {key : raw_value}, {key : label_value} 
    
    check_tmp_handler = {
        ('检验项', '检验结果值', '单位', '参考范围', '异常标识'): handler1,
        ('检验项', '结果', '参考范围'): handler2,
        ('药品名称', '结果', '结果标识', '折点', '菌落数', '细菌名称'): handler3,
    }
    if not check['values']:
        return check_dict
    check_tmp = tuple(check['values'][0].keys())
    if check_tmp not in check_tmp_handler:
        return check_dict
    handler = check_tmp_handler[check_tmp]
    for value in check['values']:
        res = handler(value)
        if res is None:
            continue
        raw_data, label_data = res
        check_dict['data'].update(label_data)
        check_dict['raw_data'].update(raw_data)
    return check_dict


def get_related_checks(checkout_dict, checks):
    related_checks = []
    for check in checks:
        check_time = datetime.strptime(check['info']['报告时间'], "%Y/%m/%d %H:%M")
        time_span = [datetime.strptime(t, "%Y %m %d %H %M") for t in checkout_dict['time_span']]
        if not (time_span[0] < check_time < time_span[1]):
            continue
        check_dict = parse_check(check)
        if not check_dict['data']:
            continue
        related_checks.append(check_dict)
    return related_checks


def get_checkout_data(emr_checks):
    emr = emr_checks[0]
    checks = emr_checks[1]
    checkout_dict = {}
    error_dict = defaultdict(list)
    try:
        emr_id = emr['emr_id']
        
        emr_text = preprocess_emr_texts(emr['texts'])
        
        if not is_emr_checkout(emr_text):
            if ("出院小结" in emr_id and "小时入出院记录" not in emr_text):
                raise ValueError(f"strange not checkout")
            raise ValueError(f"normal not checkout")
        
        checkout_tmp = get_checkout_tmp(emr_text)
        
        checkout_dict = parse_checkout_text(emr_text, checkout_tmp)
        
        checkout_dict['id'] = emr_id
        
        related_checks = get_related_checks(checkout_dict, checks)
        
        if not related_checks:
            raise ValueError("has no related checks")
        
        checkout_dict['完整化验结果'] = related_checks
    except Exception as e:
        lineno = traceback.extract_tb(e.__traceback__)[-1].lineno
        error_dict[f"{str(e)} line{lineno}"].append(emr_id)
        checkout_dict = {}
        return checkout_dict, error_dict
    return checkout_dict, error_dict


all_emr_checks = [(emr.update({'emr_id':d['zid']+"/"+emr['filename']}) or emr, d['checks']) for d in data for emr in d['emrs'] if emr['texts']]
checkout_data = []
error_dict = defaultdict(list)


for emr_checks in tqdm(all_emr_checks, total=len(all_emr_checks)):
    checkout_dict, local_error_dict = get_checkout_data(emr_checks)
    if checkout_dict:
        checkout_data.append(checkout_dict)
    for k,v in local_error_dict.items():
        error_dict[k].extend(v)



print(f"success num :{len(checkout_data)}\nnum error:{sum([len(e) for e in list(error_dict.values())])}")
for e,l in error_dict.items():
    print(f"{e} {len(l)}")

# 后处理统计
checkout_data['入院年月'] = checkout_data['time_span'].apply(lambda x : '-'.join(x[0].split(" ")[:2]))
checkout_data['住院天数'] = checkout_data['time_span'].apply(lambda x : (datetime.strptime(x[1], "%Y %m %d %H %M") - datetime.strptime(x[0], "%Y %m %d %H %M")).days)
checkout_data['化验单数量'] = checkout_data['完整化验结果'].apply(lambda x : len(x))
checkout_data['化验单总项目数量'] = checkout_data['完整化验结果'].apply(lambda checks : sum([len(check['data']) for check in checks]))
part_checkout_data = checkout_data[(checkout_data['门诊诊断'].apply(len) > 0)&(checkout_data['出院诊断'].apply(len) > 0)]
p,d,f1 = bert_score.score(part_checkout_data['门诊诊断'].to_list(), part_checkout_data['出院诊断'].to_list(), lang="zh", verbose=True)
bert_scores = {'p':p, 'r':d, 'f1':f1}
checkout_data['门诊出院bert_score_p'] = pd.DataFrame(bert_scores['p']).set_index(part_checkout_data.index)
checkout_data['门诊出院bert_score_r'] = pd.DataFrame(bert_scores['r']).set_index(part_checkout_data.index)
checkout_data['门诊出院bert_score_f1'] = pd.DataFrame(bert_scores['f1']).set_index(part_checkout_data.index)

output_sample_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks_sample.json")
json.dump(checkout_data[:20], open(output_sample_path,'w'), ensure_ascii=False)

output_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
json.dump(checkout_data, open(output_path,'w'), ensure_ascii=False)

In [None]:
# checkout_data_with_checks.json数据分析
import os
import pandas as pd

checkout_data = pd.read_json(os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_with_checks.json"))
print('checkout_data num: ', checkout_data.shape[0])
checkout_data.sort_values(by='入院年月')['入院年月'].hist(backend='plotly',title="住院时间分布").show()
print(f"平均住院天数：{checkout_data['住院天数'].mean()}")
checkout_data['住院天数'].hist(backend='plotly',log_y=True, title="住院天数分布").show()
print(f"平均化验单数量：{checkout_data['化验单数量'].mean()}")
checkout_data['化验单数量'].hist(backend='plotly',log_y=True, title="化验单数量分布").show()
checkout_data['化验单总项目数量'] = checkout_data['完整化验结果'].apply(lambda checks : sum([len(check['data']) for check in checks]))
print(f"平均总化验项目数量：{checkout_data['化验单总项目数量'].mean()}")
checkout_data['化验单总项目数量'].hist(backend='plotly',log_y=True, title="化验单总项目数量分布").show()
checkout_data['门诊出院bert_score_f1'].hist(backend='plotly',title="门诊出院bert_score_f1分布").show()
print(f"出院诊断平均长度：{checkout_data['出院诊断'].apply(len).mean()}")
checkout_data['出院诊断'].apply(len).hist(backend='plotly',log_y=True, title="出院诊断长度分布").show()

In [None]:
# 生成checkout_data_special_tokens.json
import json
import os
from collections import defaultdict
from tqdm.auto import tqdm
import pandas as pd

checkout_data = pd.read_json(os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_with_checks.json"))
special_token_ratio = 0.9

key_counts = defaultdict(int)
label_counts = defaultdict(int)

for index, row in checkout_data.iterrows():
    for check in row['完整化验结果']:
        for k in check['data']:
            key_counts[k] += 1
            label_counts[check['data'][k]] += 1

key_counts = pd.DataFrame(list(key_counts.items()), columns=['word','count']).sort_values(by='count',ascending=False)
key_counts['cumsum'] = key_counts['count'].cumsum()
label_counts = pd.DataFrame(list(label_counts.items()), columns=['word','count']).sort_values(by='count',ascending=False)
label_counts['cumsum'] = label_counts['count'].cumsum()
key_counts = key_counts[key_counts['cumsum'] <= key_counts['cumsum'].max()*special_token_ratio]
print('key num: ', key_counts.shape[0])
print('label num: ', label_counts.shape[0])
special_tokens = list(key_counts['word']) + list(label_counts['word'])
special_tokens = list(set(special_tokens))
print('special_token nums: ', len(special_tokens))

json.dump(special_tokens, open(os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_special_tokens.json"), 'w'), ensure_ascii=False)

In [None]:
# 生成v2/checkout_data_train训练数据集
# 数据来源：肿瘤外科，骨科，使用not in 策略筛选与化验相关度高的数据
import json
import os
import pandas as pd
from transformers import AutoTokenizer
from tqdm import tqdm
tqdm.pandas()

llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'],'internlm-7b'), trust_remote_code=True)
data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_with_checks.json")
ft_eval_ratio = 0.05
max_length = 768
max_pt_token_num = 1600
seed = 42
version = 2

df = pd.read_json(data_path)
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

print('full df size: ', df.shape[0])

df = df.progress_apply(lambda r: r if (r['门诊诊断'] and r['出院诊断'] and r['门诊诊断'] not in r['出院诊断']) else None, axis=1).dropna() # v2
print(f"filterd by 门诊and出院: {df.shape[0]} left")

data_df = df.progress_apply(lambda r: pd.Series(
    dict(
        input="性别:"+r['性别']+"\n年龄:"+r['年龄']+"\n入院时主要症状及体征:"+r['入院时主要症状及体征'].replace("\n","")+"\n特殊检查及重要会诊:"+r['特殊检查及重要会诊'].replace("\n","")+"\n出院诊断:",
        data=r['完整化验结果'], 
        output=r['出院诊断'],
    )
), axis=1)

data_df = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=r['input'].replace(r['output'],"").replace("「","").replace("」",""),
        data=r['data'],
        output=r['output'].replace("「","").replace("」",""),
    )
), axis=1)

data_df['num_tokens'] = data_df.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df['num_tokens'] < max_length]
print(f"filterd by length: {data_df.shape[0]} left")

data_df_text_dicts = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input='完整化验结果:'+', '.join([f"{k}:{v.replace('[','').replace(']','')}" for dict in r['data'] for (k,v) in dict['data'].items()]) + "\n" + r['input'],
        output=r['output']
    )
), axis=1)
data_df_text_dicts['num_tokens'] = data_df_text_dicts.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df_text_dicts['num_tokens'] < max_pt_token_num]
data_df_text_dicts = data_df_text_dicts[data_df_text_dicts['num_tokens'] < max_pt_token_num]
print(f'filterd by length 2: {data_df.shape[0]} left')

data_df_no_dicts = data_df.drop(columns=['data'])
data_df_raw = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=r['input'],
        data=[dict(header=d['header'],data=d['raw_data']) for d in r['data']], 
        output=r['output'],
    )
) ,axis=1)
data_df_label = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=r['input'],
        data=[dict(header=d['header'],data=d['data']) for d in r['data']], 
        output=r['output'],
    )
) ,axis=1)


ft_size = int(len(data_df)*(1-ft_eval_ratio))
train_df = data_df_label.iloc[:ft_size]
eval_df = data_df_label.iloc[ft_size:]
train_df_text_dicts = data_df_text_dicts.iloc[:ft_size]
eval_df_text_dicts = data_df_text_dicts.iloc[ft_size:]
train_df_no_dicts = data_df_no_dicts.iloc[:ft_size]
eval_df_no_dicts = data_df_no_dicts.iloc[ft_size:]
train_df_raw = data_df_raw.iloc[:ft_size]
eval_df_raw = data_df_raw.iloc[ft_size:]

json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"),'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"),'w'), ensure_ascii=False)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"),'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"),'w'), ensure_ascii=False)

print(data_df['num_tokens'].describe())
print(data_df_text_dicts['num_tokens'].describe())
data_df_text_dicts['num_tokens'].hist(bins=1000,log=True)

In [None]:
# 生成v3/checkout_data_train训练数据集
# 数据来源：各大内科，使用bert_score筛选与化验相关度高的数据
import json
import os
import pandas as pd
from transformers import AutoTokenizer
from tqdm import tqdm
tqdm.pandas()

llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'],'internlm-7b'), trust_remote_code=True)
bert_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'],'bert-base-chinese'), trust_remote_code=True)
data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
ft_eval_ratio = 0.05
bert_score_f1_threshold = 0.8
max_dict_token_num = 768 # max length for input text
max_pt_token_num = 1600 # max length for input text + text dicts
max_ft_text_token_num = 4096 # max_length for dicts encoder
seed = 42
version = "3"
table_token='[TABLE]'

df = pd.read_json(data_path)
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
print('full df size: ', df.shape[0])

df = df[(df['门诊诊断'].apply(len)>0)&(df['出院诊断'].apply(len)>0)&(df['门诊出院bert_score_f1']<bert_score_f1_threshold)]
print(f"filterd by 门诊 and 出院and bert_score>{bert_score_f1_threshold}: {df.shape[0]} left")

data_df = df.progress_apply(lambda r: pd.Series(
    dict(
        input="性别:"+r['性别']+"\n年龄:"+r['年龄']+"\n入院时主要症状及体征:"+r['入院时主要症状及体征'].replace("\n","")+"\n特殊检查及重要会诊:"+r['特殊检查及重要会诊'].replace("\n","")+"\n出院诊断:",
        data=r['完整化验结果'], 
        output=r['出院诊断'],
    )
), axis=1)

# wash
data_df = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=r['input'].replace(r['output'].replace("「","").replace("」","").replace("\n","").replace(" ",""),"").replace("「","").replace("」",""),
        data=r['data'],
        output=r['output'].replace("「","").replace("」","").replace("\n","").replace(" ",""),
    )
), axis=1)

data_df['num_tokens'] = data_df.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df['num_tokens'] < max_dict_token_num]
print(f"filterd by max_length1: {data_df.shape[0]} left")

data_df_text_dicts = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input='完整化验结果:'+', '.join([f"{k}:{v.replace('[','').replace(']','')}" for dict in r['data'] for (k,v) in dict['data'].items()]) + "\n" + r['input'],
        output=r['output']
    )
), axis=1)
data_df_text_dicts['num_tokens'] = data_df_text_dicts.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df_text_dicts['num_tokens'] < max_pt_token_num]
data_df_text_dicts = data_df_text_dicts[data_df_text_dicts['num_tokens'] < max_pt_token_num]
print(f'filterd by max_length2: {data_df.shape[0]} left')

data_df['num_encoder_tokens'] = data_df.progress_apply(lambda r : len(bert_tok('[SEP]'.join([k+v for d in r['data'] for (k,v) in d['data'].items()]), add_special_tokens=False)['input_ids']), axis=1)
data_df = data_df[data_df['num_encoder_tokens'] < max_ft_text_token_num]
data_df_text_dicts = data_df_text_dicts[data_df['num_encoder_tokens'] < max_ft_text_token_num]
print(f'filterd by max_length3: {data_df.shape[0]} left')

data_df_no_dicts = data_df.drop(columns=['data'])
data_df_raw = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=table_token+r['input'],
        data=[[dict(header=d['header'],data=d['raw_data']) for d in r['data']]], 
        output=r['output'],
    )
) ,axis=1)
data_df_label = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=table_token+r['input'],
        data=[[dict(header=d['header'],data=d['data']) for d in r['data']]], 
        output=r['output'],
    )
) ,axis=1)


ft_size = int(len(data_df)*(1-ft_eval_ratio))
train_df = data_df_label.iloc[:ft_size]
eval_df = data_df_label.iloc[ft_size:]
train_df_text_dicts = data_df_text_dicts.iloc[:ft_size]
eval_df_text_dicts = data_df_text_dicts.iloc[ft_size:]
train_df_no_dicts = data_df_no_dicts.iloc[:ft_size]
eval_df_no_dicts = data_df_no_dicts.iloc[ft_size:]
train_df_raw = data_df_raw.iloc[:ft_size]
eval_df_raw = data_df_raw.iloc[ft_size:]


print(data_df['num_tokens'].describe())
print(data_df_text_dicts['num_tokens'].describe())
data_df_text_dicts['num_tokens'].hist(backend='plotly', title='text dict token num').show()
data_df['num_encoder_tokens'].hist(backend='plotly', title='dict token num').show()


json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"),'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"),'w'), ensure_ascii=False)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"),'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"),'w'), ensure_ascii=False)

In [None]:
# 生成v3.1/checkout_data_train训练数据集
# 数据来源：各大内科，使用bert_score筛选与化验相关度高的数据，进一步清洗了出院诊断
import json
import os
import pandas as pd
from transformers import AutoTokenizer
from tqdm import tqdm
import re
tqdm.pandas()

llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'],'internlm-7b'), trust_remote_code=True)
bert_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'],'bert-base-chinese'), trust_remote_code=True)
data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
ft_eval_ratio = 0.05
bert_score_f1_threshold = 0.8
max_dict_token_num = 768 # max length for input text
max_pt_token_num = 1600 # max length for input text + text dicts
max_ft_text_token_num = 4096 # max_length for dicts encoder
seed = 42
version = "3.1"
table_token='[TABLE]'

df = pd.read_json(data_path)
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
print('full df size: ', df.shape[0])

df = df[(df['门诊诊断'].apply(len)>0)&(df['出院诊断'].apply(len)>0)&(df['门诊出院bert_score_f1']<bert_score_f1_threshold)]
print(f"filterd by 门诊 and 出院and bert_score>{bert_score_f1_threshold}: {df.shape[0]} left")

data_df = df.progress_apply(lambda r: pd.Series(
    dict(
        input="性别:"+r['性别']+"\n年龄:"+r['年龄']+"\n入院时主要症状及体征:"+r['入院时主要症状及体征'].replace("\n","")+"\n特殊检查及重要会诊:"+r['特殊检查及重要会诊'].replace("\n","")+"\n出院诊断:",
        data=r['完整化验结果'], 
        output=r['出院诊断'],
    )
), axis=1)

# wash
def wash(r):
    new_output = r['output'].replace("「","").replace("」","").replace("\n","").replace(" ","").replace("(","（").replace(")","）")
    new_output = re.sub("（[^）]*）","",new_output)
    new_input = r['input'].replace("「","").replace("」","")
    for o in new_output.split("，"):
        new_input = new_input.replace(o,"")
    return pd.Series(dict(input=new_input, data=r['data'], output=new_output))

data_df = data_df.progress_apply(wash, axis=1)


data_df['num_tokens'] = data_df.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df['num_tokens'] < max_dict_token_num]
print(f"filterd by max_length1: {data_df.shape[0]} left")

data_df_text_dicts = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input='完整化验结果:'+', '.join([f"{k}:{v}" for dict in r['data'] for (k,v) in dict['data'].items()]) + "\n" + r['input'],
        output=r['output']
    )
), axis=1)
data_df_text_dicts['num_tokens'] = data_df_text_dicts.progress_apply(lambda r : len(llm_tok(r['input']+r['output'])['input_ids']),axis=1)
data_df = data_df[data_df_text_dicts['num_tokens'] < max_pt_token_num]
data_df_text_dicts = data_df_text_dicts[data_df_text_dicts['num_tokens'] < max_pt_token_num]
print(f'filterd by max_length2: {data_df.shape[0]} left')

data_df['num_encoder_tokens'] = data_df.progress_apply(lambda r : len(bert_tok('[SEP]'.join([k+v for d in r['data'] for (k,v) in d['data'].items()]), add_special_tokens=False)['input_ids']), axis=1)
data_df = data_df[data_df['num_encoder_tokens'] < max_ft_text_token_num]
data_df_text_dicts = data_df_text_dicts[data_df['num_encoder_tokens'] < max_ft_text_token_num]
print(f'filterd by max_length3: {data_df.shape[0]} left')

data_df_no_dicts = data_df.drop(columns=['data'])
data_df_raw = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=table_token+r['input'],
        data=[[dict(header=d['header'],data=d['raw_data']) for d in r['data']]], 
        output=r['output'],
    )
) ,axis=1)
data_df_label = data_df.progress_apply(lambda r: pd.Series(
    dict(
        input=table_token+r['input'],
        data=[[dict(header=d['header'],data=d['data']) for d in r['data']]], 
        output=r['output'],
    )
) ,axis=1)


ft_size = int(len(data_df)*(1-ft_eval_ratio))
train_df = data_df_label.iloc[:ft_size]
eval_df = data_df_label.iloc[ft_size:]
train_df_text_dicts = data_df_text_dicts.iloc[:ft_size]
eval_df_text_dicts = data_df_text_dicts.iloc[ft_size:]
train_df_no_dicts = data_df_no_dicts.iloc[:ft_size]
eval_df_no_dicts = data_df_no_dicts.iloc[ft_size:]
train_df_raw = data_df_raw.iloc[:ft_size]
eval_df_raw = data_df_raw.iloc[ft_size:]


print(data_df['num_tokens'].describe())
data_df['num_tokens'].hist(backend='plotly', title='text dict token num').show()
print(data_df_text_dicts['num_tokens'].describe())
data_df_text_dicts['num_tokens'].hist(backend='plotly', title='text dict token num').show()
data_df['num_encoder_tokens'].hist(backend='plotly', title='dict token num').show()


json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"),'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"),'w'), ensure_ascii=False)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"),'w'), ensure_ascii=False)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"),'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"),'w'), ensure_ascii=False)

In [None]:
# 生成v3.2/checkout_data_train训练数据集
# 数据来源：各大内科，使用bert_score筛选与化验相关度高的数据，进一步清洗了出院诊断
# 进一步增加了stage1 Encoder预训练数据，任务：还原化验单的异常表项
import json
import os
import pandas as pd
from transformers import AutoTokenizer
from tqdm import tqdm
import re
import random
from multiprocessing import Pool, cpu_count

data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
ft_eval_ratio = 0.05
pt_eval_ratio = 0.01
bert_score_f1_threshold = 0.8
max_dict_token_num = 2048  # max_length for dict token num
max_pt_token_num = 80  # max length for pretrain_input text + pt_output text
max_ft_token_num = 768  # max length for ft_input text + ft_output text
max_ft_text_token_num = 1600  # max length for ft_input text + text dicts + ft_output text
max_pt_data_size = 20000  # Pretrain 训练数据集规模
max_ft_data_size = 10000  # FT 训练数据集规模
max_num_abnormal_sample = 10  # 每个化验单最多sample的异常检验项目数
seed = 42
version = "3.2"
table_token = '[TABLE]'

pt_templates = [
    f"{table_token} 请根据化验单结果输出病人可能患有的疾病。输出:",
    f"{table_token} 请依据化验单的结果判断病人可能的疾病。输出:",
    f"{table_token} 根据化验单结果，请推测病人可能罹患的疾病。输出:",
    f"{table_token} 请分析化验单结果并指出病人可能遭受的疾病。输出:",
    f"{table_token} 基于化验单结果，预测病人可能患有哪些疾病。",
    f"{table_token} 请从化验单结果中判断出病人可能的疾病。输出:",
    f"{table_token} 根据化验单的结果，推断病人可能的健康问题。输出:",
    f"{table_token} 请查看化验单结果并识别可能的疾病。输出:",
    f"{table_token} 请解读化验单结果，判断病人可能面临的疾病。输出:",
    f"{table_token} 根据化验单结果，确定病人可能患的疾病。输出:",
    f"{table_token} 分析化验单结果并预判病人可能的疾病。输出:"
]
ft_text_template = "性别:{}\n年龄:{}\n入院时主要症状及体征:{}\n特殊检查及重要会诊:{}\n出院诊断:"
ft_template = "化验信息:{}文字信息:{}"
llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'internlm-7b'), trust_remote_code=True)
bert_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'bert-base-chinese'), trust_remote_code=True)
bert_tok.add_tokens([table_token])
random.seed(seed)
tqdm.pandas()

data = json.load(open(data_path))
data = random.sample(data, len(data))
print(f'full df size: {len(data)}')

def topk_index(index: pd.Series, k):
    count = 0
    for i in index.index:
        if index[i]:
            count += 1
            if count > k:
                index[i] = False
    return index

def get_abnormal_items(dict):
    abnormal_items = []
    for d in dict:
        for (k, v) in d['data'].items():
            if v in ["空", "阳性", "阳性(+)", "弱阳性", "极弱阳性"]:
                abnormal_items.append(k)
    if abnormal_items:
        num_sample = min(len(abnormal_items), max_num_abnormal_sample)
        abnormal_items = random.sample(abnormal_items, num_sample)
        return "，".join(abnormal_items)
    else:
        return "无"

def get_label_data_str(dict):
    dict_str = ""
    for d in dict:
        for (k, v) in d['data'].items():
            dict_str += f"{k}{v} "
        dict_str += "[SEP]"
    return dict_str

def prepare(r):    
    label_data = [[dict(header=d['header'], data=d['data']) for d in r['完整化验结果']]]
    raw_data = [[dict(header=d['header'], data=d['raw_data']) for d in r['完整化验结果']]]
    label_data_str = get_label_data_str(r['完整化验结果'])
    
    ft_output = r['出院诊断']
    ft_output = ft_output.replace("「", "").replace("」", "").replace("\n", "").replace(" ", "").replace("(", "（").replace(")", "）")
    ft_output = re.sub("（[^）]*）", "", ft_output)
    ft_input = ft_text_template.format(r['性别'], r['年龄'], r['入院时主要症状及体征'], r['特殊检查及重要会诊'])
    ft_input = ft_input.replace("「", "").replace("」", "")
    for o in ft_output.split("，"):
        ft_input = ft_input.replace(o, "")
    ft_text_dict_input = ft_template.format(label_data_str, ft_input)
    
    pt_input = random.choice(pt_templates)
    pt_output = ft_output
    
    num_dict_token = len(bert_tok(label_data_str, add_special_tokens=False)['input_ids'])
    num_ft_text_dict_token = len(llm_tok(ft_text_dict_input, add_special_tokens=False)['input_ids'])
    num_pt_token = len(llm_tok(pt_input+pt_output)['input_ids'])
    num_ft_token = len(llm_tok(ft_input+ft_output)['input_ids'])
    
    r['pt_input'] = pt_input
    r['pt_output'] = pt_output
    r['label_data'] = label_data
    r['raw_data'] = raw_data
    r['label_data_str'] = label_data_str
    r['ft_input'] = ft_input
    r['ft_text_dict_input'] = ft_text_dict_input
    r['ft_output'] = ft_output
    r['num_dict_token'] = num_dict_token
    r['num_ft_text_dict_token'] = num_ft_text_dict_token
    r['num_pt_token'] = num_pt_token
    r['num_ft_token'] = num_ft_token
    
    return r

print(f"generating data_df")
data_df = []

process_num = 8
pool = Pool(process_num)
print(f"using {process_num} process")
for output in tqdm(pool.imap(prepare, data), total=len(data)):  
	data_df.append(output)
pool.close()

data_df = pd.DataFrame(data_df)

ft_index = (data_df['门诊诊断'].apply(len) > 0) \
    & (data_df['出院诊断'].apply(len) > 0) \
    & (data_df['门诊出院bert_score_f1'] < bert_score_f1_threshold) \
    & (data_df['num_ft_token'] < max_ft_token_num) \
    & (data_df['num_dict_token'] < max_dict_token_num) \
    & (data_df['num_ft_text_dict_token'] < max_ft_text_token_num)
print('ft_index: ', ft_index.sum())

pt_index = (data_df['num_dict_token'] < max_dict_token_num) \
    & (data_df['num_pt_token'] < max_pt_token_num)
print('pt_index: ', pt_index.sum()) 

ft_index = topk_index(ft_index, max_ft_data_size)
intersection = ft_index & pt_index
print('intersection: ', intersection.sum())
pt_index = pt_index & (~intersection)
pt_index = topk_index(pt_index, max_pt_data_size)

pt_df = data_df[pt_index]
ft_df = data_df[ft_index]

pt_normal_df = pt_df.progress_apply(lambda r: pd.Series(dict(input=r['pt_input'],data=r['label_data'],output=r['pt_output'])), axis=1)
ft_no_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_input'], output=r['ft_output'])), axis=1)
ft_text_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_text_dict_input'], output=r['ft_output'])), axis=1)
ft_normal_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['label_data'], output=r['ft_output'])), axis=1)
ft_raw_data_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['raw_data'], output=r['ft_output'])), axis=1)

pt_size = int(len(pt_df)*(1-pt_eval_ratio))
train_df_pretrain = pt_normal_df.iloc[:pt_size]
eval_df_pretrain = pt_normal_df.iloc[pt_size:]

ft_size = int(len(ft_df)*(1-ft_eval_ratio))
train_df = ft_normal_df.iloc[:ft_size]
eval_df = ft_normal_df.iloc[ft_size:]
train_df_text_dicts = ft_text_dict_df.iloc[:ft_size]
eval_df_text_dicts = ft_text_dict_df.iloc[ft_size:]
train_df_no_dicts = ft_no_dict_df.iloc[:ft_size]
eval_df_no_dicts = ft_no_dict_df.iloc[ft_size:]
train_df_raw = ft_raw_data_df.iloc[:ft_size]
eval_df_raw = ft_raw_data_df.iloc[ft_size:]

data_df['num_dict_token'].hist(backend='plotly', title='num_dict_token').show()
pt_df['num_pt_token'].hist(backend='plotly', title='num_pt_token').show()
ft_df['num_ft_token'].hist(backend='plotly', title='num_ft_token').show()
ft_df['num_ft_text_dict_token'].hist(backend='plotly', title='num_ft_text_dict_token').show()

json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"), 'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_pretrain.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_pretrain.json"), 'w'), ensure_ascii=False, indent=4)

In [15]:
# 生成v3.3/checkout_data_train训练数据集
# 数据来源：各大内科，对出院诊断做归一化，过滤了长度过长的部分病例，删除了出现比例小于千分之一的长尾疾病，平衡了训练集和测试集中的疾病种类数量
# 进一步增加了stage1 Encoder预训练数据，任务：还原化验单的异常表项
import json
import os
import pandas as pd
from transformers import AutoTokenizer, BertModel
from tqdm import tqdm
import re
import random
from collections import Counter
import multiprocessing
import torch
import faiss


data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
ft_eval_ratio = 0.1
pt_eval_ratio = 0.01
bert_score_f1_threshold = 0.8
max_dict_token_num = 4096  # 字典最长token数
max_pt_token_num = 128  # max length for pretrain_input text + pt_output text
max_ft_token_num = 1024  # max length for ft_input text + ft_output text
max_ft_data_size = 20000  # FT 训练数据集最大规模
max_pt_data_size = 20000  # Pretrain 训练数据集规模
max_num_abnormal_sample = 10  # 每个化验单最多sample的异常检验项目数
long_tail_lb_ratio = 1e-3  # 过滤掉出院诊断中包含长尾疾病的病例
seed = 42
version = "3.3"
table_token = '[TABLE]'
pt_templates = [
    f"{table_token} 请根据化验单结果输出病人可能患有的疾病。输出:",
    f"{table_token} 请依据化验单的结果判断病人可能的疾病。输出:",
    f"{table_token} 根据化验单结果，请推测病人可能罹患的疾病。输出:",
    f"{table_token} 请分析化验单结果并指出病人可能遭受的疾病。输出:",
    f"{table_token} 基于化验单结果，预测病人可能患有哪些疾病。",
    f"{table_token} 请从化验单结果中判断出病人可能的疾病。输出:",
    f"{table_token} 根据化验单的结果，推断病人可能的健康问题。输出:",
    f"{table_token} 请查看化验单结果并识别可能的疾病。输出:",
    f"{table_token} 请解读化验单结果，判断病人可能面临的疾病。输出:",
    f"{table_token} 根据化验单结果，确定病人可能患的疾病。输出:",
    f"{table_token} 分析化验单结果并预判病人可能的疾病。输出:"
]
ft_text_template = "性别:{}\n年龄:{}\n入院时主要症状及体征:{}\n特殊检查及重要会诊:{}\n出院诊断:"
ft_template = "化验信息:{}文字信息:{}"

llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'internlm-7b'), trust_remote_code=True)
bert_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'bert-base-chinese'), trust_remote_code=True)
bert_tok.add_tokens([table_token])
random.seed(seed)
tqdm.pandas()

# get data_df
data = json.load(open(data_path))
data = random.sample(data, len(data))
print(f'full df size: {len(data)}')

term2bterm = json.load(open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/term2bterm.json")))
print(f"original diagnosis terms num: {len(term2bterm.keys())}")
print(f"normalized diagnosis terms num: {len(set(list(term2bterm.values())))}")

def get_label_data_str(dict):
    dict_str = ""
    for d in dict:
        for (k, v) in d['data'].items():
            if v == "N":
                v = "正常"
            elif v == "空":
                v = "异常"
            dict_str += f"{k} {v} "
        dict_str += "[SEP]"
    return dict_str

def prepare(d):    
    label_data = [[dict(header=d['header'], data=d['data']) for d in d['完整化验结果']]]
    raw_data = [[dict(header=d['header'], data=d['raw_data']) for d in d['完整化验结果']]]
    label_data_str = get_label_data_str(d['完整化验结果'])
    diagnosis = d['出院诊断']
    diagnosis = diagnosis.replace("「", "").replace("」", "").replace("\n", "").replace(" ", "").replace("(", "（").replace(")", "）")
    diagnosis = re.sub("（[^）]*）", "", diagnosis)
    bterms = [term2bterm.get(t.strip(), t.strip()) for t in diagnosis.split("，")]
    ft_input = ft_text_template.format(d['性别'], d['年龄'], d['入院时主要症状及体征'], d['特殊检查及重要会诊'])
    ft_input = ft_input.replace("「", "").replace("」", "")
    for t in bterms:
        ft_input = ft_input.replace(t, "")
    ft_text_dict_input = ft_template.format(label_data_str, ft_input)
    pt_input = random.choice(pt_templates)
    ft_output = "，".join(bterms)
    pt_output = ft_output
    num_dict_token = len(bert_tok(label_data_str, add_special_tokens=False)['input_ids'])
    num_pt_token = len(llm_tok(pt_input+pt_output)['input_ids'])
    num_ft_token = len(llm_tok(ft_input+ft_output)['input_ids'])
    num_ft_text_dict_token = len(llm_tok(ft_text_dict_input)['input_ids'])
    
    d['ft_output'] = ft_output
    d['pt_output'] = pt_output
    d['num_dict_token'] = num_dict_token
    d['num_pt_token'] = num_pt_token
    d['num_ft_token'] = num_ft_token
    d['num_ft_text_dict_token'] = num_ft_text_dict_token
    d['pt_input'] = pt_input
    d['label_data'] = label_data
    d['raw_data'] = raw_data
    d['label_data_str'] = label_data_str
    d['ft_input'] = ft_input
    d['ft_text_dict_input'] = ft_text_dict_input
    d['bterms'] = bterms
    
    return d

print(f"generating data_df")
prepared_data = []
process_num = 8
pool = multiprocessing.Pool(process_num)
print(f"using {process_num} process")

for output in tqdm(pool.imap(prepare, data), total=len(data)):  
	prepared_data.append(output)

pool.close()

data_df = pd.DataFrame(prepared_data)

# filter and organize data_df
ft_index = (data_df['出院诊断'].apply(len) > 0) \
    & (data_df['num_ft_token'] < max_ft_token_num) \
    & (data_df['num_dict_token'] < max_dict_token_num)
print('ft_index_by_length: ', ft_index.sum())

print('ft_index: ', ft_index.sum())

pt_index = (data_df['num_dict_token'] < max_dict_token_num) \
    & (data_df['num_pt_token'] < max_pt_token_num)
print('pt_index: ', pt_index.sum()) 

def topk_index(index: pd.Series, k):
    count = 0
    for i in index.index:
        if index[i]:
            count += 1
            if count > k:
                index[i] = False
    return index

ft_index = topk_index(ft_index, max_ft_data_size)
intersection = ft_index & pt_index
print('intersection: ', intersection.sum())
pt_index = pt_index & (~intersection)
pt_index = topk_index(pt_index, max_pt_data_size)

pt_df = data_df[pt_index]
ft_df = data_df[ft_index]

# remove long tail
bterm_count = dict(Counter([t for d in ft_df['bterms'].to_list() for t in d]))
print('bterms class num: ', len(bterm_count))
valid_bterms = set([bterm for bterm in bterm_count if bterm_count[bterm] > long_tail_lb_ratio*ft_df.shape[0]])
print(f'vaild_bterms class > {long_tail_lb_ratio*ft_df.shape[0]} num: ', len(valid_bterms))

def remove_long_tail(d):
    bterms = d['bterms']
    bterms = list(set(bterms) & valid_bterms)
    ft_output = "，".join(bterms)
    d['bterms'] = bterms
    d['ft_output'] = ft_output
    return d

ft_df = ft_df.progress_apply(remove_long_tail, axis=1)
ft_df = ft_df[ft_df['bterms'].apply(len) > 0]

# generate train_df and eval_df
pt_normal_df = pt_df.progress_apply(lambda r: pd.Series(dict(input=r['pt_input'],data=r['label_data'],output=r['pt_output'])), axis=1)
ft_no_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_input'], output=r['ft_output'])), axis=1)
ft_text_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_text_dict_input'], output=r['ft_output'])), axis=1)
ft_normal_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['label_data'], output=r['ft_output'])), axis=1)
ft_raw_data_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['raw_data'], output=r['ft_output'])), axis=1)

pt_size = int(len(pt_df)*(1-pt_eval_ratio))
train_df_pretrain = pt_normal_df.iloc[:pt_size]
eval_df_pretrain = pt_normal_df.iloc[pt_size:]

ft_size = int(len(ft_df)*(1-ft_eval_ratio))
train_df = ft_normal_df.iloc[:ft_size]
eval_df = ft_normal_df.iloc[ft_size:]
train_df_text_dicts = ft_text_dict_df.iloc[:ft_size]
eval_df_text_dicts = ft_text_dict_df.iloc[ft_size:]
train_df_no_dicts = ft_no_dict_df.iloc[:ft_size]
eval_df_no_dicts = ft_no_dict_df.iloc[ft_size:]
train_df_raw = ft_raw_data_df.iloc[:ft_size]
eval_df_raw = ft_raw_data_df.iloc[ft_size:]

data_df['num_dict_token'].hist(backend='plotly', title='num_dict_token').show()
ft_df['num_ft_token'].hist(backend='plotly', title='num_ft_token').show()
ft_df['num_ft_text_dict_token'].hist(backend='plotly', title='num_ft_text_dict_token').show()
train_bterm_count = pd.DataFrame.from_dict(Counter(train_df['output'].apply(lambda x: x.split("，")).sum()), orient='index', columns=['train_count'])
print('train_bterm_count: ', train_bterm_count.shape)
eval_bterm_count = pd.DataFrame.from_dict(Counter(eval_df['output'].apply(lambda x: x.split("，")).sum()), orient='index', columns=['eval_count'])
print('eval_bterm_count: ', eval_bterm_count.shape)
merge_bterm_count = pd.concat([train_bterm_count, eval_bterm_count], axis=1).fillna(0).sort_values(by=['train_count'], ascending=False)
merge_bterm_count.hist(backend='plotly', title='bterm_count', x=merge_bterm_count.index, y=merge_bterm_count['train_count']+merge_bterm_count['eval_count'],log_y=True).show()
merge_bterm_count

if not os.path.exists(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/")):
    os.mkdir(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/"))
json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"), 'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_pretrain.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_pretrain.json"), 'w'), ensure_ascii=False, indent=4)

full df size: 47541
original diagnosis terms num: 20123
normalized diagnosis terms num: 11290
generating data_df
using 8 process


100%|██████████| 47541/47541 [02:47<00:00, 283.95it/s]


ft_index_by_length:  43888
ft_index:  43888
pt_index:  45731
intersection:  19966
bterms class num:  6899
vaild_bterms class > 20.0 num:  550


100%|██████████| 20000/20000 [00:01<00:00, 12287.83it/s]
100%|██████████| 20000/20000 [00:04<00:00, 4425.54it/s]
100%|██████████| 18595/18595 [00:03<00:00, 4859.17it/s]
100%|██████████| 18595/18595 [00:03<00:00, 4821.78it/s]
100%|██████████| 18595/18595 [00:12<00:00, 1433.50it/s]
100%|██████████| 18595/18595 [00:03<00:00, 4741.91it/s]


train_bterm_count:  (550, 1)
eval_bterm_count:  (540, 1)


In [74]:
# eval_result analysis
import pandas as pd
from collections import Counter
import numpy as np
import json
import plotly.express as px
from transformers import AutoTokenizer
import os
from tqdm import tqdm

tqdm.pandas()

def analyse_result(eval_data_path, csv_path):
    llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'internlm-7b'), trust_remote_code=True)
    eval_df = pd.read_json(eval_data_path, orient='records')
    df = pd.read_csv(csv_path)
    eval_df = pd.concat([eval_df, df], axis=1)
    eval_df['input_length'] = eval_df['input'].progress_apply(lambda x: len(llm_tok(x)['input_ids']))
    ref_terms = df['ref_terms'].apply(lambda x: eval(x)).sum()
    ref_terms_count = dict(Counter(ref_terms))

    ref_terms_recall = {}
    for row in df.itertuples():
        ref_terms = eval(row.ref_terms)
        ref_topk_cids = eval(row.ref_topk_cid)
        pred_terms = eval(row.pred_terms)
        pred_topk_cids = eval(row.pred_topk_cid)
        for ref_term, ref_topk_cid in zip(ref_terms, ref_topk_cids):
            if len(pred_topk_cids) == 0:
                break
            # 优先匹配最相似的项
            intersect = [len(pred_topk_cid & ref_topk_cid) for pred_topk_cid in pred_topk_cids]
            if max(intersect) == 0:
                continue
            else:
                ref_terms_recall[ref_term] = ref_terms_recall.get(ref_term, 0) + 1
                pred_topk_cids.pop(np.argmax(intersect))
    
    ref_terms_count_df = pd.DataFrame.from_dict(ref_terms_count, orient='index', columns=['count'])
    ref_terms_count_df['recall'] = pd.Series(ref_terms_recall)
    ref_terms_count_df.fillna(0, inplace=True)
    ref_terms_count_df.sort_values(by='count', ascending=False, inplace=True)
    ref_terms_count_df['recall'] = ref_terms_count_df['recall'] / ref_terms_count_df['count']
    return ref_terms_count_df, eval_df

def get_nice_terms(ct_eval_data_path, bl_eval_data_path, ct_eval_result, bl_eval_result):
    ref_terms_count_df1, eval_df1 = analyse_result(ct_eval_data_path, ct_eval_result)
    ref_terms_count_df2, eval_df2 = analyse_result(bl_eval_data_path, bl_eval_result) 
    ref_terms_count_df1.hist(backend='plotly', x=ref_terms_count_df2.index, y='count', title='疾病count的分布情况').update_layout(width=1600, height=900).show()
    ref_terms_count_df1.hist(backend='plotly', x=ref_terms_count_df2.index, y='recall', title='疾病recall的分布情况').update_layout(width=1600, height=900).show()
    
    merge_df = pd.merge(ref_terms_count_df1, ref_terms_count_df2, left_index=True, right_index=True, suffixes=('_没化验单', '_有化验单'))
    print('terms: ', merge_df.shape[0])
    merge_df['化验单带来的提升'] = merge_df['recall_有化验单'] - merge_df['recall_没化验单']
    merge_df.hist(backend='plotly', x=merge_df.index, y='化验单带来的提升', title='化验单带来的提升的分布情况').show()
    eval_df1['bin'] = pd.cut(eval_df1['input_length'], bins=20)
    result = eval_df1.groupby('bin')['bios_scores_f'].mean().reset_index()
    result.columns = ['bin', 'mean_score']
    result['bin'] = result['bin'].apply(str)
    px.line(result, x='bin', y='mean_score', title='不带化验单的性能-长度的关系').show()
    eval_df2['bin'] = pd.cut(eval_df2['input_length'], bins=20)
    result = eval_df2.groupby('bin')['bios_scores_f'].mean().reset_index()
    result.columns = ['bin', 'mean_score']
    result['bin'] = result['bin'].apply(str)
    px.line(result, x='bin', y='mean_score', title='带化验单的性能-长度的关系').show()
    merge_df.sort_values(by='化验单带来的提升', ascending=False, inplace=True)
    print('merge_df: ', merge_df.index.to_list())
    nice_terms = merge_df[(merge_df['count_没化验单'] > 5) & (merge_df['化验单带来的提升'] > 0)]
    print('nice_terms: ', nice_terms.shape)
    nice_terms
    return nice_terms.index.to_list()

ct_eval_data_path = "/mnt/petrelfs/guoyiqiu/coding/my_datasets/ninth/v3.3/checkout_data_eval_no_dicts.json"
bl_eval_data_path = "/mnt/petrelfs/guoyiqiu/coding/my_datasets/ninth/v3.3/checkout_data_eval_text_dicts.json"
ct_eval_result = "/mnt/petrelfs/guoyiqiu/coding/my_models/ct-internlm-7b-v3.3/eval_result_2024_01_15_13_12_02.csv"
bl_eval_result = "/mnt/petrelfs/guoyiqiu/coding/my_models/bl2-internlm-7b-v3.3/eval_result_2024_01_16_12_43_10.csv"
nice_bterms = get_nice_terms(ct_eval_data_path, bl_eval_data_path, ct_eval_result, bl_eval_result)

json.dump(nice_bterms, open(os.path.join(os.environ['my_datasets_dir'], "ninth/v3.3/nice_bterms.json"),'w'), ensure_ascii=False)


100%|██████████| 1860/1860 [00:02<00:00, 698.46it/s]
100%|██████████| 1860/1860 [00:10<00:00, 170.33it/s]


terms:  540


merge_df:  ['2型糖尿病肾病iii期', '出血性扩张术后', '双侧大脑中动脉狭窄', '陈旧性心尖梗塞', '右侧肾结石', '手术后左侧胸膜肺炎', '结肠良性肿瘤', '肥胖疾病状态', '右肺腺癌', '弥漫性神经精神疾病', '术前肾母细胞瘤治疗', '双侧颈动脉斑块', '急性胃肠炎', '右斜疝', '结肠病变', '慢性肾小球性尿毒症', '重叠综合征', '双侧颈内动脉狭窄', '风湿性脊柱炎', '膝关节骨性关节炎', '永久性连续性房颤', 'i2b2肥胖', '残余心脏功能', '骨关节炎疾病状态', '糖尿病周围神经病变', '静息性心绞痛症状', '亚急性心肌梗死', '频发室性期外搏动', '重度贫血', '右侧甲状腺肿瘤', '亚临床支气管炎症', '肾结石性肾结石', '术后主动脉假性动脉瘤', '结直肠多发性息肉', '间质性浸润性肺炎', '白细胞减少', '动脉粥样硬化性动脉硬化', '慢性心力衰竭急性加重', 'a427肺癌', '脑梗死后', '慢性胆囊炎伴胆石', '结直肠肿瘤靶向治疗', 'c-相关性血管炎', '心肌桥', '抑郁症状状态', '照射后头颈部鳞状细胞癌', '右肺肿瘤', '慢性尿液感染', '冠状动脉心肌桥', '心肌梗塞组织修复术后', '梗塞性脑损伤', '颈动脉硬化症', '鼻咽狭窄术后', 'avm消化道出血', '骨肾疾病', '缺血性短暂性脑缺血发作', '胃溃疡炎症', '3型呼吸衰竭', '高磷血症状态', '支气管炎性哮喘', '高脂血症s', '喉癌术后', '阵发性心室颤动', '高血压3级', '室性异位早搏', '慢性阻塞性肺疾病加重', '左侧肾脓肿', 'ns肾病综合征', '同侧甲状腺结节', '痛风性障碍', '慢性血糖控制不佳', '未特指的慢性支气管炎', '石源性胆管炎', '低钾血症状态', '陈旧性心肌梗死', 'ua不稳定型心绞痛', '未特指的贫血', '腔隙性脑梗死', 'ii级心脏功能', '椎板切除术后', '阻塞性肺炎表现', '肝硬化功能障碍', '糖尿病和外周血管疾病', '腰椎压缩性骨折', '左肾结石', '右侧甲状腺乳头状癌', '肝占位性病变', '瘰疬', '腔隙性脑桥梗塞', '原发性甲减', '呼吸性咯血', '

In [76]:
# 生成v3.4/checkout_data_train训练数据集
# 数据来源：各大内科，对出院诊断做归一化，过滤了长度过长的部分病例，根据3.3版本ct-bl结果选择了少量疾病种类。
# 进一步增加了stage1 Encoder预训练数据，任务：还原化验单的异常表项
import json
import os
import pandas as pd
from transformers import AutoTokenizer, BertModel
from tqdm import tqdm
import re
import random
from collections import Counter
import multiprocessing
import torch
import faiss


data_path = os.path.join(os.environ['my_datasets_dir'], "ninth/checkout_data_im_with_checks.json")
ft_eval_ratio = 0.1
pt_eval_ratio = 0.01
bert_score_f1_threshold = 0.8
max_dict_token_num = 4096  # 字典最长token数
max_pt_token_num = 128  # max length for pretrain_input text + pt_output text
max_ft_token_num = 1024  # max length for ft_input text + ft_output text
max_ft_data_size = 99999  # FT 训练数据集最大规模
max_pt_data_size = 20000  # Pretrain 训练数据集规模
max_num_abnormal_sample = 10  # 每个化验单最多sample的异常检验项目数
long_tail_lb_ratio = 1e-3  # 过滤掉出院诊断中包含长尾疾病的病例
seed = 42
version = "3.4"
table_token = '[TABLE]'
pt_templates = [
    f"{table_token} 请根据化验单结果输出病人可能患有的疾病。输出:",
    f"{table_token} 请依据化验单的结果判断病人可能的疾病。输出:",
    f"{table_token} 根据化验单结果，请推测病人可能罹患的疾病。输出:",
    f"{table_token} 请分析化验单结果并指出病人可能遭受的疾病。输出:",
    f"{table_token} 基于化验单结果，预测病人可能患有哪些疾病。",
    f"{table_token} 请从化验单结果中判断出病人可能的疾病。输出:",
    f"{table_token} 根据化验单的结果，推断病人可能的健康问题。输出:",
    f"{table_token} 请查看化验单结果并识别可能的疾病。输出:",
    f"{table_token} 请解读化验单结果，判断病人可能面临的疾病。输出:",
    f"{table_token} 根据化验单结果，确定病人可能患的疾病。输出:",
    f"{table_token} 分析化验单结果并预判病人可能的疾病。输出:"
]
ft_text_template = "性别:{}\n年龄:{}\n入院时主要症状及体征:{}\n特殊检查及重要会诊:{}\n出院诊断:"
ft_template = "化验信息:{}文字信息:{}"

llm_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'internlm-7b'), trust_remote_code=True)
bert_tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'bert-base-chinese'), trust_remote_code=True)
bert_tok.add_tokens([table_token])
random.seed(seed)
tqdm.pandas()

# get data_df
data = json.load(open(data_path))
data = random.sample(data, len(data))
print(f'full df size: {len(data)}')

term2bterm = json.load(open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/term2bterm.json")))
print(f"original diagnosis terms num: {len(term2bterm.keys())}")
print(f"normalized diagnosis terms num: {len(set(list(term2bterm.values())))}")

# get nice bterms
nice_bterms = json.load(open(os.path.join(os.environ['my_datasets_dir'], "ninth/v3.3/nice_bterms.json")))

def get_label_data_str(dict):
    dict_str = ""
    for d in dict:
        for (k, v) in d['data'].items():
            if v == "N":
                v = "正常"
            elif v == "空":
                v = "异常"
            dict_str += f"{k} {v} "
        dict_str += "[SEP]"
    return dict_str

def prepare(d):    
    label_data = [[dict(header=d['header'], data=d['data']) for d in d['完整化验结果']]]
    raw_data = [[dict(header=d['header'], data=d['raw_data']) for d in d['完整化验结果']]]
    label_data_str = get_label_data_str(d['完整化验结果'])
    diagnosis = d['出院诊断']
    diagnosis = diagnosis.replace("「", "").replace("」", "").replace("\n", "").replace(" ", "").replace("(", "（").replace(")", "）")
    diagnosis = re.sub("（[^）]*）", "", diagnosis)
    bterms = [term2bterm.get(t.strip(), t.strip()) for t in diagnosis.split("，")]
    bterms = list(set(bterms) & set(nice_bterms))
    ft_input = ft_text_template.format(d['性别'], d['年龄'], d['入院时主要症状及体征'], d['特殊检查及重要会诊'])
    ft_input = ft_input.replace("「", "").replace("」", "")
    for t in bterms:
        ft_input = ft_input.replace(t, "")
    ft_text_dict_input = ft_template.format(label_data_str, ft_input)
    pt_input = random.choice(pt_templates)
    ft_output = "，".join(bterms)
    pt_output = ft_output
    num_dict_token = len(bert_tok(label_data_str, add_special_tokens=False)['input_ids'])
    num_pt_token = len(llm_tok(pt_input+pt_output)['input_ids'])
    num_ft_token = len(llm_tok(ft_input+ft_output)['input_ids'])
    
    d['ft_output'] = ft_output
    d['pt_output'] = pt_output
    d['num_dict_token'] = num_dict_token
    d['num_pt_token'] = num_pt_token
    d['num_ft_token'] = num_ft_token
    d['pt_input'] = pt_input
    d['label_data'] = label_data
    d['raw_data'] = raw_data
    d['label_data_str'] = label_data_str
    d['ft_input'] = ft_input
    d['ft_text_dict_input'] = ft_text_dict_input
    d['bterms'] = bterms
    
    return d

print(f"generating data_df")
prepared_data = []
process_num = 8
pool = multiprocessing.Pool(process_num)
print(f"using {process_num} process")

for output in tqdm(pool.imap(prepare, data), total=len(data)):  
	prepared_data.append(output)

pool.close()

data_df = pd.DataFrame(prepared_data)

# filter and organize data_df
ft_index = (data_df['bterms'].apply(len) > 0) \
    & (data_df['num_ft_token'] < max_ft_token_num) \
    & (data_df['num_dict_token'] < max_dict_token_num)
print('ft_index_by_length: ', ft_index.sum())

print('ft_index: ', ft_index.sum())

pt_index = (data_df['bterms'].apply(len) > 0) \
    & (data_df['num_dict_token'] < max_dict_token_num) \
    & (data_df['num_pt_token'] < max_pt_token_num)
print('pt_index: ', pt_index.sum()) 

def topk_index(index: pd.Series, k):
    count = 0
    for i in index.index:
        if index[i]:
            count += 1
            if count > k:
                index[i] = False
    return index

ft_index = topk_index(ft_index, max_ft_data_size)
intersection = ft_index & pt_index
print('intersection: ', intersection.sum())
pt_index = pt_index & (~intersection)
pt_index = topk_index(pt_index, max_pt_data_size)

pt_df = data_df[pt_index]
ft_df = data_df[ft_index]

# generate train_df and eval_df
pt_normal_df = pt_df.progress_apply(lambda r: pd.Series(dict(input=r['pt_input'],data=r['label_data'],output=r['pt_output'])), axis=1)
ft_no_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_input'], output=r['ft_output'])), axis=1)
ft_text_dict_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=r['ft_text_dict_input'], output=r['ft_output'])), axis=1)
ft_normal_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['label_data'], output=r['ft_output'])), axis=1)
ft_raw_data_df = ft_df.progress_apply(lambda r: pd.Series(dict(input=table_token+r['ft_input'], data=r['raw_data'], output=r['ft_output'])), axis=1)

pt_size = int(len(pt_df)*(1-pt_eval_ratio))
train_df_pretrain = pt_normal_df.iloc[:pt_size]
eval_df_pretrain = pt_normal_df.iloc[pt_size:]

ft_size = int(len(ft_df)*(1-ft_eval_ratio))
train_df = ft_normal_df.iloc[:ft_size]
eval_df = ft_normal_df.iloc[ft_size:]
train_df_text_dicts = ft_text_dict_df.iloc[:ft_size]
eval_df_text_dicts = ft_text_dict_df.iloc[ft_size:]
train_df_no_dicts = ft_no_dict_df.iloc[:ft_size]
eval_df_no_dicts = ft_no_dict_df.iloc[ft_size:]
train_df_raw = ft_raw_data_df.iloc[:ft_size]
eval_df_raw = ft_raw_data_df.iloc[ft_size:]

data_df['num_dict_token'].hist(backend='plotly', title='num_dict_token').show()
ft_df['num_ft_token'].hist(backend='plotly', title='num_ft_token').show()
ft_df['num_ft_text_dict_token'].hist(backend='plotly', title='num_ft_text_dict_token').show()
train_bterm_count = pd.DataFrame.from_dict(Counter(train_df['output'].apply(lambda x: x.split("，")).sum()), orient='index', columns=['train_count'])
print('train_bterm_count: ', train_bterm_count.shape)
eval_bterm_count = pd.DataFrame.from_dict(Counter(eval_df['output'].apply(lambda x: x.split("，")).sum()), orient='index', columns=['eval_count'])
print('eval_bterm_count: ', eval_bterm_count.shape)
merge_bterm_count = pd.concat([train_bterm_count, eval_bterm_count], axis=1).fillna(0).sort_values(by=['train_count'], ascending=False)
merge_bterm_count.hist(backend='plotly', title='bterm_count', x=merge_bterm_count.index, y=merge_bterm_count['train_count']+merge_bterm_count['eval_count'],log_y=True).show()
merge_bterm_count

if not os.path.exists(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/")):
    os.mkdir(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/"))
json.dump(train_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train.json"), 'w'), ensure_ascii=False)
json.dump(eval_df.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_no_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_no_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_no_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_text_dicts.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_text_dicts.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_text_dicts.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_raw_data.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_raw.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_raw_data.json"), 'w'), ensure_ascii=False, indent=4)
json.dump(train_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_train_pretrain.json"), 'w'), ensure_ascii=False)
json.dump(eval_df_pretrain.to_dict(orient="records"), open(os.path.join(os.environ['my_datasets_dir'], f"ninth/v{version}/checkout_data_eval_pretrain.json"), 'w'), ensure_ascii=False, indent=4)

full df size: 47541
original diagnosis terms num: 20123
normalized diagnosis terms num: 11290
generating data_df
using 8 process


100%|██████████| 47541/47541 [02:02<00:00, 386.95it/s]


ft_index_by_length:  21070
ft_index:  21070
pt_index:  22183
intersection:  21070


100%|██████████| 1113/1113 [00:00<00:00, 4374.93it/s]
100%|██████████| 21070/21070 [00:04<00:00, 4328.58it/s]
100%|██████████| 21070/21070 [00:04<00:00, 4620.71it/s]
100%|██████████| 21070/21070 [00:14<00:00, 1478.45it/s]
100%|██████████| 21070/21070 [00:04<00:00, 4310.89it/s]


KeyError: 'num_ft_text_dict_token'

仁济住院数据预处理

In [None]:
# 仁济统计
import pandas as pd
import os

file_name_list = [name for name in os.listdir('./') if name.endswith('.csv')]
dfs = {file_name:pd.read_csv(file_name, low_memory=False) for file_name in file_name_list}
for df in dfs.values():
    df.rename(columns={'jzzsy': '住院号'}, inplace=True)
print("文件名:行数")
[(i,len(dfs[i])) for i in dfs]
print("文件名:住院号数量")
[(i,len(set(dfs[i]['住院号']))) for i in dfs]
intersection = set(dfs['202303_medical_history_enter.csv']['住院号'])
black_list = set([
    '202303_medical_history_leave_24h.csv',
    '202303_medical_history_op_first_disease.csv',
    '202303_medical_history_operation.csv',
    '202303_medical_history_routine.csv'
])
union = set()
for i in dfs:
    union = union.union(set(dfs[i]['住院号']))

for i in set(file_name_list) - black_list:
    new_set = set(dfs[i]['住院号'])
    intersection = new_set & intersection
print(f"住院号交集数量:{len(intersection)}, 重合比例:{len(intersection)}/{len(union)} {len(intersection)/len(union)*100:.2f}%")

In [None]:
# 生成renji_data.json
import json
from collections import defaultdict
from tqdm.auto import tqdm
import pandas as pd

advice = pd.read_csv("data/202303出院/202303出院有放射报告-医嘱.csv", low_memory=False, encoding='gbk').rename(columns={'jzzsy': '住院号'})
check = pd.read_csv("data/202303出院/202303出院有放射报告-检验.csv", low_memory=False, encoding='gbk').rename(columns={'jzzsy': '住院号'})
test = pd.read_csv("data/202303出院/202303出院有放射报告-检查.csv", low_memory=False, encoding='gbk').rename(columns={'jzzsy': '住院号'})
xlsx = pd.ExcelFile('data/202303出院/202303出院有放射报告-病史相关.xlsx')
history = {sheet_name : xlsx.parse(sheet_name).rename(columns={'jzzsy': '住院号'}) for sheet_name in xlsx.sheet_names}


zids = set(history['病案首页']['住院号'])
chuyuan = []
for zid in tqdm(zids):
    d = {}
    d['住院号'] = zid
    for sheet_name in history:
        sheet = history[sheet_name]
        if "住院号" in sheet.columns:
            d[sheet_name] = sheet[sheet['住院号'] == zid].to_dict(orient='records')
    d['医嘱'] = advice[advice['住院号'] == zid].to_dict(orient='records')
    d['检验'] = check[check['住院号'] == zid].to_dict(orient='records')
    d['检查'] = test[test['住院号'] == zid].to_dict(orient='records')
    chuyuan.append(d)
json.dump(chuyuan[:20], open("data/chuyuan_data_sample.json", 'w',encoding='utf-8',),ensure_ascii=False, indent=4)
json.dump(chuyuan, open("data/chuyuan_data.json", 'w',encoding='utf-8',),ensure_ascii=False, indent=4)

check_nums = [len(d['检验']) for d in chuyuan]
check_normal = [len([c for c in d['检验'] if c['结果值异常标志'] == 'NO']) for d in chuyuan]
check_abnormal = [(i-j) for i,j in zip(check_nums, check_normal)]
print(f"average_check_num: {(sum(check_nums)/len(check_nums))}\n" \
    f"average normal: {sum(check_normal)/len(check_normal)}\n" \
    f"average abnormal: {sum(check_abnormal)/len(check_abnormal)}")

In [None]:
# 生成time_stamps
from model import Stdout2File
from datetime import datetime

int2date = lambda x: datetime.strptime(str(x), "%Y%m%d%H%M%S").strftime("%Y/%m/%d %H:%M:%S").replace(r"2023/3",r"2023/03")
date2int = lambda x: int(datetime.strptime(x, "%Y/%m/%d %H:%M:%S").strftime("%Y%m%d%H%M%S"))
chuyuan_sample = json.load(open("data/chuyuan_data_sample.json"))

def time_stamps(d):
    time_stamps = []
    time_stamps.append((int2date(str(d['出院记录'][0]['入院日期时间'])),  f"入院时间 {d['病案首页'][0]['入院诊断名称']}"))
    time_stamps.append((int2date(str(d['出院记录'][0]['出院日期时间'])[:14]),  '出院时间'))
    if d.get("手术记录"):
        for i,op in enumerate(d['手术记录']):
            time_stamps.append((int2date(str(op['手术开始日期时间'])),  f"手术{i}:{op['手术及操作编码对应名称']} 开始时间"))
            if op.get('手术结束日期时间').strip():
                time_stamps.append((int2date(str(op['手术结束日期时间'][:14])),  f"手术{i}:{op['手术及操作编码对应名称']} 结束时间"))
    if d.get("日常病程记录"):
        for i,rec in enumerate(d['日常病程记录']):
            time_stamps.append((int2date(str(rec['记录日期'])),  f"日常病程{i} 记录时间:"))
    if d.get('首次病程记录'):
        for i, rec in enumerate(d['首次病程记录']):
            time_stamps.append((int2date(str(rec['记录日期'])),  f"首次病程记录{i} 记录时间:"))
    if d.get('术后首次病程记录'):
        for i, rec in enumerate(d['术后首次病程记录']):
            time_stamps.append((int2date(str(rec['记录日期'])),  f"术后首次病程记录{i} 记录时间:"))
    if d.get('医嘱'):
        for i, rec in enumerate(d['医嘱']):
            time_stamps.append((rec['开始时间'].replace(r"2023/3",r"2023/03"),  f"医嘱{i}:{rec['医嘱名称']} 开始时间:"))
            if rec.get('结束时间'):
                time_stamps.append((rec['结束时间'].replace(r"2023/3",r"2023/03"),  f"医嘱{i}:{rec['医嘱名称']} 结束时间:"))
    if d.get('检验'):
        for i, rec in enumerate(d['检验']):
            time_stamps.append((rec['报告日期'].replace(r"2023/3",r"2023/03"),  f"检验{i}:{rec['报告名称']} {rec['检验项目']} 报告日期:"))
    if d.get('检查'):
        for i, rec in enumerate(d['检查']):
            time_stamps.append((rec['报告日期'].replace(r"2023/3",r"2023/03"), f"检查{i}:{rec['报告名称']} 报告日期:"))
    sorted_time_stamps = sorted(time_stamps, key=lambda x: date2int(x[0]))
    return sorted_time_stamps

with Stdout2File("data/chuyuan_time_stamps.txt"):
    for d in chuyuan_sample:
        for s in time_stamps(d):
            print(s)
        print("---------------------------")

In [None]:
# 生成医嘱训练数据
import json
from tqdm.auto import tqdm
from datetime import datetime

date2int = lambda x: int(datetime.strptime(x, "%Y/%m/%d %H:%M:%S").strftime("%Y%m%d%H%M%S"))
data = json.load(open("data/chuyuan/chuyuan_data.json"))

def advice_type(advice):
    if advice.get("规格") and advice['单次用量']!=1.0:
        return "药物"
    # if "会诊" in advice['医嘱名称']:
    #     return "会诊"
    keywords = ["平扫","脑电","心电","CT","检查","分析","试验","检测","MRI","静脉血","超声"]
    for k in keywords:
        if k in advice['医嘱名称']:
            return "检查"
    return False

train_data = []

for d in tqdm(data):
    if (not d.get("入院记录") and not d.get("24小时出入院记录")) or (not d.get("医嘱") or not d.get("检验")):
        continue
    text_dict =dict(
        病人年龄=d['检验'][0]['年龄'],
        病人性别=d['病案首页'][0]['性别'],
        入院诊断名称=d['病案首页'][0]['入院诊断名称'],
        主诉=d['入院记录'][0]['主诉'] if d.get('入院记录') else d['24小时出入院记录'][0]['主诉'],
    )
    
    checks = []
    for c in d['检验']:
        try:
            c['报告日期'] = date2int(c['报告日期'])
            checks.append(c)
        except:
            continue
    checks = sorted(checks, key=lambda x: x['报告日期'])
    operations = sorted(d['手术记录'], key=lambda x: x['手术开始日期时间'])
    # tests = sorted(d['检查'], key=lambda x: x['报告日期'])
    
    advices = []
    for a in d['医嘱']:
        try:
            a['开始时间'] = date2int(a['开始时间'])
            advices.append(a)
        except:
            continue
    advices = sorted(advices, key=lambda x: x['开始时间'])
    
    # 如果有手术记录，过滤掉在第一次手术之后的所有信息
    if operations:
        first_op_time = operations[0]['手术开始日期时间']
        checks = [c for c in checks if c['报告日期'] < first_op_time]
        advices = [a for a in advices if a['开始时间'] < first_op_time]
    
    if not checks or not advices:
        continue
    
    # 过滤掉在第一次检验报告出来之前的医嘱
    advices = [a for a in advices if a['开始时间'] > checks[0]['报告日期']]
    
    # 过滤掉在最后一次检验报告出来之前的医嘱
    # advices = [a for a in advices if a['开始时间'] > checks[-1]['报告日期']]
    
    # 只保留药物，会诊，检查三种医嘱
    advices = [a for a in advices if advice_type(a)]
    
    # 对医嘱按时间聚类
    new_advices = [dict(开始时间=a['开始时间'], 医嘱名称=a['医嘱名称'], 医嘱类型=advice_type(a)) for a in advices]
    check_times = sorted(list(set([c['报告日期'] for c in checks])))
    advice_groups = [[a for a in new_advices if check_times[i] < a['开始时间'] < check_times[i+1]] for i in range(len(check_times)-1)]
    advice_groups = [a for a in advice_groups if a]

    # 生成医嘱训练数据
    text = ",".join([f"{k}: {v}" for (k,v) in text_dict.items()])
    for advice_group in advice_groups:
        check_list = [dict(检验项目=c['检验项目'], 检验结果=c['检验结果'], 正常值范围=c['正常值范围'], 结果值异常标志=c['结果值异常标志'] if c['结果值异常标志']!="NO" else "正常") for c in checks if c['报告日期'] < advice_group[0]['开始时间']]
        advice_text = ",".join([f"{a['医嘱类型']}:{a['医嘱名称']}" for a in advice_group])
        train_data.append(dict(text=text,data=check_list,output=advice_text))
json.dump(train_data[:50], open("data/chuyuan/chuyuan_train_sample.json", 'w',encoding='utf-8',),ensure_ascii=False, indent=4)
json.dump(train_data, open("data/chuyuan/chuyuan_train.json", 'w',encoding='utf-8',),ensure_ascii=False, indent=4)

其他数据集预处理

In [None]:
# 读取BIOS
from tqdm.auto import tqdm
import igraph as ig
import matplotlib.pyplot as plt
from collections import defaultdict
import marisa_trie
import re
import os


print("Loading BIOS...")
# relations = open("data/bios_v2.2_release/CoreData/Relations.txt").readlines()
uni_relations = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/UniRelations.txt")).readlines()
concept_terms = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/ConceptTerms.txt")).readlines()
definitions = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/Definitions.txt")).readlines()
semantic_types = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/SemanticTypes.txt")).readlines()

print("Building concept2term...")
concept2term = defaultdict(list)
for line in tqdm(concept_terms):
    ls = line.strip().split('|')
    concept2term[ls[0]].append(ls[2])

print("Building term2concept...")
if os.path.exists('data/term2concept.marisa'):
    term2concept = marisa_trie.BytesTrie()
    term2concept.load("data/term2concept.marisa")
else:
    keys = []
    values = []
    for line in tqdm(concept_terms):
        ls = line.strip().split('|')
        keys.append(ls[2])
        values.append(bytes(ls[0], encoding='utf-8'))
    term2concept = marisa_trie.BytesTrie(zip(keys,values))
    term2concept.save('data/term2concept.marisa')

print("Building node2idx...")
node2idx = {}
for line in tqdm(uni_relations):
    rid, head, tail, relid, rel = line.strip().split("|")
    if not node2idx.get(head):
        node2idx[head] = len(node2idx)
    if not node2idx.get(tail):
        node2idx[tail] = len(node2idx)

print("Building idx2node...")
idx2node = [x[0] for x in sorted(node2idx.items(), key=lambda x: x[1])]

print("Building edges...")
edges = {}
for rl in tqdm(uni_relations):
    ls = rl.strip().split("|")
    edges[(node2idx[ls[1]],node2idx[ls[2]])]=ls[-1]

print("Building graph...")
g = ig.Graph(n=len(node2idx), edges=list(edges.keys()), directed=True)
g.simplify()

print("Adding node descriptions...")
has_chinese = lambda text:  bool(re.compile(r'[\u4e00-\u9fa5]').search(text))
for i in tqdm(range(len(g.vs))):
    g.vs[i]["cid"] = idx2node[i]
    g.vs[i]["name"] = concept2term[idx2node[i]][0]
    color = "blue"
    neighbors_num = len(g.neighbors(i))
    if neighbors_num > 10:
        color = "green"
    if neighbors_num > 50:
        color = "orange"
    if neighbors_num > 100:
        color = "red"
    if neighbors_num > 1000:
        color = "purple"
    g.vs[i]["color"] = color
    for i in concept2term[idx2node[i]]:
        if has_chinese(i):
            g.vs[i]["name"] = i
            break

print("Adding edge descriptions...")
for i in tqdm(range(len(g.es))):
    s,t = g.es[i].source, g.es[i].target
    g.es[i]["name"] = edges[(s,t)]

undi_g = g.as_undirected()

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']=(15,15)

def get_cid(str):
    if term2concept.get(str.lower()):
        return term2concept[str.lower()][0].decode('utf-8')
    return None


def get_n_hop_neighbors_subgraph(str,n=2):
    cid = get_cid(str)
    if not cid:
        print("Entity Not found in KG")
        return
    print(f'cid: {cid} idx: {node2idx[cid]} term: {g.vs[node2idx[cid]]["name"]}')

    root = node2idx[cid]
    g.vs[root]['hop'] = 0
    all_neighbors = [root]
    for i in range(1,n+1):
        i_hop_neighbors = []
        for node in all_neighbors:
            node_neighbors = g.neighbors(node)
            if len(node_neighbors) > 100:
                continue
            node_neighbors = node_neighbors[:min(len(node_neighbors), 5)]
            i_hop_neighbors.extend(node_neighbors)
        i_hop_neighbors = list(set(i_hop_neighbors))
        print(f'{i}_hop_neighbors num: {len(i_hop_neighbors)}')
        for x in i_hop_neighbors:
            if 'hop' not in g.vs[x].attributes().keys() or g.vs[x]['hop'] is None:
                g.vs[x]['hop'] = i
        all_neighbors.extend(i_hop_neighbors)
        all_neighbors = list(set(all_neighbors))
    print('all_neighbors: ', all_neighbors)
    subgraph = g.subgraph(all_neighbors)
    for edge in subgraph.es:
        if abs(subgraph.vs[edge.source]['hop'] - subgraph.vs[edge.target]['hop']) != 1:
            subgraph.delete_edges(edge)
    for node in all_neighbors:
        if 'hop' in g.vs[node].attributes():
            del g.vs[node]['hop']
    return subgraph

def get_n_str_subgraph(str_list):
    node_idxs = list(set([node2idx[get_cid(str)] for str in str_list if get_cid(str)]))
    print('node_idxs: ', node_idxs)
    old_colors = [g.vs[i]['color'] for i in node_idxs]
    for i in node_idxs:
        g.vs[i]['color'] = 'white'
    paths = []
    for i in range(len(node_idxs)):
        for j in range(i+1, len(node_idxs)):
            paths.extend(g.get_shortest_paths(node_idxs[i], node_idxs[j], mode=ig.ALL))
    all_nodes = list(set([n for path in paths for n in path]))
    print('node num: ', len(all_nodes))
    subgraph = g.subgraph(all_nodes)
    for i,color in zip(node_idxs,old_colors):
        g.vs[i]['color'] = color
    return subgraph

In [None]:
# 关于TQA和TFV的数据分析
from transformers import AutoTokenizer
import pandas as pd
import os
import json
from tqdm.auto import tqdm
tok = AutoTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], 'llama-7b'))

tqa_dst = json.load(open(os.path.join(os.environ['my_datasets_dir'], 'table/TQA_train_data.json')))
tfv_dst = json.load(open(os.path.join(os.environ['my_datasets_dir'], 'table/TFV_train_data.json')))

def count_token_num(d):
    return len(tok(d['instruction']+d['input'])['input_ids'])

from multiprocessing import Pool

pool = Pool(16)
tqa_dst_token_num = []
for output in tqdm(pool.imap(count_token_num, tqa_dst), total=len(tqa_dst)):  
	tqa_dst_token_num.append(output)
tfv_dst_token_num = []
for output in tqdm(pool.imap(count_token_num, tfv_dst), total=len(tfv_dst)):  
	tfv_dst_token_num.append(output)
pool.close()


In [None]:
# 生成true_false_dataset
HONESTY_EXTRACTION_TEMPLATE_EXP = """USER:Pretend you are an honest person making statements about the world.\nAssistant:"""
HONESTY_EXTRACTION_TEMPLATE_REF = """USER:Pretend you are a dishonest person making statements about the world.\nAssistant:"""

knows = {topic:{True:[],False:[]}for topic in topic_true_examples}

for o in outputs:
    ks = o['query_output'].split("<sep>")
    new_ks = []
    for k in ks:
        new_ks.extend(k.split("<sep"))
    ks = new_ks
    new_ks = []
    for k in ks:
        new_ks.extend(k.split("\n"))
    ks = new_ks
    ks = [re.sub("[0-9]+. ","", k).strip() for k in ks]
    ks = [k for k in ks if k]
    for k in ks:
        if k not in knows[o['topic']][o['label']]:
            knows[o['topic']][o['label']].append(k)
    
dst = []
for topic in knows:
    for i in range(min(len(knows[topic][True]),len(knows[topic][False]))):
        true_input = HONESTY_EXTRACTION_TEMPLATE_EXP
        false_input = HONESTY_EXTRACTION_TEMPLATE_REF
        true_output = knows[topic][True][i]
        false_output = knows[topic][False][i]
        dst.append([dict(input=true_input, output=true_output, topic=topic, label=True),dict(input=false_input, output=false_output, topic=topic, label=False)])
print('True-False Dst Size: ', len(dst))
json.dump(knows, open("data/knows.json", "w"), indent=4)
json.dump(dst, open("data/true_false_dataset.json", "w"), indent=4)

In [None]:
# 生成usmle_test_inference.json
import json
import jsonlines
from model import INTERNLM_TEMPLATE


usmle_test = list(jsonlines.open("data/usmle/questions/US/test.jsonl"))
usmle_test_for_infer = []

for d in usmle_test:
    options_strs = [f"{op}: {d['options'][op]}" for op in d['options']]
    input = f"Question: {d['question']} Options: {'; '.join(options_strs)}. Output: The correct answer is option"
    input = INTERNLM_TEMPLATE.format(input)
    d['input'] = input
    d['labels'] = ['A', 'B', 'C', 'D', 'E']
    usmle_test_for_infer.append(d)

json.dump(usmle_test_for_infer, open("data/mgpu_infer/usmle_test_inference.json", "w"), indent=4)

In [None]:
# 生成truthfulqa_inference.json
import json
from model import *
import os

tqa = json.load(open("/home/cs/yangyuchen/guoyiqiu/gpt_re/data/TruthfulQA-main/data/mc_task.json"))
print(tqa[0])
tqa_dst = [dict(input=INTERNLM_TEMPLATE.format(d['question']),labels=list(d['mc1_targets'].keys()),gt=list(d['mc1_targets'].values()).index(1)) for d in tqa]
print(tqa_dst[0])
json.dump(tqa_dst, open("data/mgpu_infer/truthfulqa_inference.json", "w"), indent=4)

In [None]:
# 生成medqa_test_noquestion_inference
from model import *

medqa_test = json.load(open("data/usmle/questions/US/test.json"))
print('medqa_test: ', medqa_test[0])
medqa_test_infer = []
for d in medqa_test:
    options = d['options']
    option_text = ";".join([f"{k}: {options[k]}" for k in options])
    # input_text = INTERNLM_TEMPLATE.format("Question:"+d['question']+"; Options:"+option_text) + "The correct answer is option"
    input_text = INTERNLM_TEMPLATE.format("Options:"+option_text) + "The correct answer is option"
    medqa_test_infer.append(dict(input=input_text, labels=list(options.keys()), gt=d['answer']))
print('medqa_test_infer: ', medqa_test_infer[0])
json.dump(medqa_test_infer, open("data/mgpu_infer/medqa_test_noquestion_inference.json", "w"), indent=4)

In [None]:
# 生成true-false数据集
true_query_template = "Generate 10 statements about the topic {topic}. The statements should be true and brief and contain factual knowledge. You can use the following statements as examples: {examples}. The statements should be split by <sep>."

false_query_template = "Generate 10 false statements about the topic {topic}. The statements should be incorrect and brief and contain wrong factual knowledge. You can use the following statements as examples: {examples}. The statements should be split by <sep>."

topic_true_examples = {
    "Cities": "Oranjestad is a city in Aruba." , 
    "Inventions": "Grace Hopper invented the COBOL programming language." , 
    "Chemical Elements": "Boron is used in the production of glass and ceramics." , 
    "Animals": "The llama has a diet of herbivore." , 
    "Companies": "Meta Platforms has headquarters in United States." , 
    "Scientific Facts": "The Earth’s tides are primarily caused by the gravitational pull of the moon.",
    "Medical": "Benign tumors typically grow slowly and do not invade surrounding tissues or spread to other areas."
}
topic_false_examples = {
    "Cities": "Wellington is a name of a country." ,
    "Inventions": "David Schwarz lived in France." ,
    "Chemical Elements": "Indium is in the Lanthanide group." ,
    "Animals": "The whale has a long, tubular snout, large ears, and a powerful digging ability to locate and consume termites and ants." ,
    "Companies": "KDDI operates in the industry of Materials." , 
    "Scientific Facts": "Ice sinks in water due to its higher density.",
    "Medical": "The normal range for human body temperature is 50-55 degrees Celsius."
}
true_queries = [dict(query_input=true_query_template.format(topic=k, examples=v),topic=k,label=True) for (k,v) in topic_true_examples.items()]
false_queries = [dict(query_input=false_query_template.format(topic=k, examples=v),topic=k,label=False) for (k,v) in topic_false_examples.items()]
inputs = true_queries + false_queries

outputs = []
for i in range(10):
    outputs.extend(multithread_query_chatgpt(inputs, thread_num=8, model_name='gpt-4-1106-preview'))

In [None]:
# 生成TermDiseaseZH.json
import os
import re
import json
from tqdm.auto import tqdm


contain_chinese = lambda text: bool(re.compile(r'[\u4e00-\u9fa5]').search(text))
semantic_types = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/SemanticTypes.txt"))
concept_terms = open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/ConceptTerms.txt"))

semantic_types = {line.split("|")[0]: line.split("|")[1] for line in tqdm(semantic_types)}
term_disease_zh = {line.split("|")[2]: line.split("|")[0] for line in tqdm(concept_terms) if semantic_types[line.split("|")[0]] in ["5","6","10","11","25"] and contain_chinese(line)}
print(len(term_disease_zh))

json.dump(term_disease_zh, open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZH.json"), "w"))

In [None]:
# 展示数据集出院诊断的特殊分布和长尾现象
import os
import pandas as pd
from collections import Counter

df_train = pd.read_json(os.path.join(os.environ['my_datasets_dir'], 'ninth/v3.1/checkout_data_train.json'))
df_eval = pd.read_json(os.path.join(os.environ['my_datasets_dir'], 'ninth/v3.1/checkout_data_eval.json'))
df = pd.concat([df_train,df_eval])
print('病例数量: ', df.shape[0])

def count_abnormal(row):
    num_item = 0
    num_abnormal = 0
    for dict in row['data']:
        for d in dict:
            num_item += len(d['data'])
            for k,v in d['data'].items():
                if v == '空':
                    num_abnormal += 1
    row['总项目数'] = num_item
    row['异常项目数'] = num_abnormal
    return row
df = df.apply(count_abnormal, axis=1)
print(f"平均异常比例：{(df['异常项目数'] / df['总项目数']).mean()}")
(df['异常项目数'] / df['总项目数']).hist(backend='plotly', title='异常项目数占比').update_layout(width=1600, height=900).show()
df['异常项目数'].hist(backend='plotly', title='异常项目数').update_layout(width=1600, height=900).show()
ninth_diseases = df['output'].apply(lambda x: x.split("，")).sum()
print('所有需要预测疾病的数量: ', len(ninth_diseases))
ninth_diseases_count = pd.DataFrame.from_dict(Counter(ninth_diseases), orient='index').sort_values(by=0, ascending=False)
print('所有需要预测疾病的种类: ', len(ninth_diseases_count))
print('疾病种类的分布情况: ', ninth_diseases_count.describe())
ninth_diseases_count.iloc[:100].hist(backend='plotly', x=ninth_diseases_count.iloc[:100].index, y=0, title='出现数量前一百的疾病的分布情况').update_layout(width=1600, height=900).show()

In [None]:
# 生成TermDiseaseZHEmbedding.npy
import json
from transformers import BertTokenizer, BertModel
import os
import torch
from tqdm.auto import tqdm
import numpy as np

term_disease_zh = json.load(open((os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZH.json"))))
terms = list(term_disease_zh.keys())
print('terms: ', len(terms))

bert = BertModel.from_pretrained(os.path.join(os.environ['my_models_dir'], "bert-base-chinese"))
tok = BertTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], "bert-base-chinese"))

bert.eval()
# bert.cuda()
batch_size = 4
terms = terms[:50]

term_disease_zh_embedding = []
for i in tqdm(range(0, len(terms), batch_size)):
    inp = tok(terms[i:i+batch_size], return_tensors='pt', padding=True)
    input_ids, attention_mask = inp['input_ids'], inp['attention_mask']
    # input_ids, attention_mask = input_ids.cuda(), attention_mask.cuda()
    with torch.no_grad():
        batch_embedding = bert(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'][:,0]
    term_disease_zh_embedding.append(batch_embedding.cpu().numpy())

term_disease_zh_embedding = np.concatenate(term_disease_zh_embedding, axis=0)
print('term_disease_zh_embedding: ', term_disease_zh_embedding.shape)
np.save(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/TermDiseaseZHEmbedding_sample.npy"), term_disease_zh_embedding)

In [None]:
# 生成NinthDiseaseEmbedding.npy
import pandas as pd
import os
from transformers import BertTokenizer, BertModel
import torch
from tqdm.auto import tqdm
import numpy as np
import json

df_train = pd.read_json(os.path.join(os.environ['my_datasets_dir'], 'ninth/v3.1/checkout_data_train.json'))
df_eval = pd.read_json(os.path.join(os.environ['my_datasets_dir'], 'ninth/v3.1/checkout_data_eval.json'))
df = pd.concat([df_train,df_eval])
ninth_diseases = list(set(df['output'].apply(lambda x: x.split("，")).sum()))
json.dump(ninth_diseases, open(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/NinthDisease.json"), 'w'), ensure_ascii=False)
print('所有需要预测疾病的种类: ', len(ninth_diseases))

bert = BertModel.from_pretrained(os.path.join(os.environ['my_models_dir'], "bert-base-chinese"))
tok = BertTokenizer.from_pretrained(os.path.join(os.environ['my_models_dir'], "bert-base-chinese"))

bert.eval()
# bert.cuda()
batch_size = 4
# ninth_diseases = ninth_diseases[:50]

ninth_disease_zh_embedding = []
for i in tqdm(range(0, len(ninth_diseases), batch_size)):
    inp = tok(ninth_diseases[i:i+batch_size], return_tensors='pt', padding=True)
    input_ids, attention_mask = inp['input_ids'], inp['attention_mask']
    # input_ids, attention_mask = input_ids.cuda(), attention_mask.cuda()
    with torch.no_grad():
        batch_embedding = bert(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'][:,0]
    ninth_disease_zh_embedding.append(batch_embedding.cpu().numpy())

ninth_disease_zh_embedding = np.concatenate(ninth_disease_zh_embedding, axis=0)
print('ninth_disease_zh_embedding: ', ninth_disease_zh_embedding.shape)
np.save(os.path.join(os.environ['my_datasets_dir'], "bios_v2.2_release/CoreData/NinthDiseaseEmbedding.npy"), ninth_disease_zh_embedding)