## mimic 疾病处理初步

根据icd10大类将mimic中的疾病分类保存

In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.notebook import tqdm


icd9toicd10 = pd.read_csv('icd9toicd10cmgem.csv', usecols=['icd9cm', 'icd10cm'])
map9to10 = defaultdict(str)
for i in range(len(icd9toicd10)):
    icd9 = icd9toicd10.iloc[i]['icd9cm']
    icd10 = icd9toicd10.iloc[i]['icd10cm']
    while icd9.startswith('0'):
        icd9 = icd9[1:]
    if len(map9to10[icd9]) == 0 or map9to10[icd9] == 'NoDx':
        map9to10[icd9] = icd10
np.save('icd9toicd10.npy', map9to10)

def get_icd10_type(id: str) -> str:
    if id[0] in ['A', 'B']:
        return 'Certain infectious and parasitic diseases'
    elif id[0] == 'C':
        return 'Neoplasms'
    elif id[0] == 'D':
        if id[1:3].isdecimal() and int(id[1:3]) <= 49:
            return 'Neoplasms'
        else:
            return 'Diseases of the blood and blood-forming organs and certain disorders involving the immune mechanism'
    elif id[0] == 'E':
        return 'Endocrine, nutritional and metabolic diseases'
    elif id[0] == 'F':
        return 'Mental, Behavioral and Neurodevelopmental disorders'
    elif id[0] == 'G':
        return 'Diseases of the nervous system'
    elif id[0] == 'H':
        if id[1:3].isdecimal() and int(id[1:3]) <= 59:
            return 'Diseases of the eye and adnexa'
        else:
            return 'Diseases of the ear and mastoid process'
    elif id[0] == 'I':
        return 'Diseases of the circulatory system'
    elif id[0] == 'J':
        return 'Diseases of the respiratory system'
    elif id[0] == 'K':
        return 'Diseases of the digestive system'
    elif id[0] == 'L':
        return 'Diseases of the skin and subcutaneous tissue'
    elif id[0] == 'M':
        return 'Diseases of the musculoskeletal system and connective tissue'
    elif id[0] == 'N':
        return 'Diseases of the genitourinary system'
    elif id[0] == 'O':
        return 'Pregnancy, childbirth and the puerperium'
    elif id[0] == 'P':
        return 'Certain conditions originating in the perinatal period'
    elif id[0] == 'Q':
        return 'Congenital malformations, deformations and chromosomal abnormalities'
    elif id[0] == 'R':
        return 'Symptoms, signs and abnormal clinical and laboratory findings, not elsewhere classified'
    elif id[0] in ['S', 'T']:
        return 'Injury, poisoning and certain other consequences of external causes'
    elif id[0] == 'U':
        return 'Codes for special purposes'
    elif id[0] in ['V', 'W', 'X', 'Y']:
        return 'External causes of morbidity'
    elif id[0] == 'Z':
        return 'Factors influencing health status and contact with health services'

diagnosis = pd.read_csv('diagnosis_icd.csv')
all_type = defaultdict(list)
for i in tqdm(range(len(diagnosis))):
    code, version = diagnosis.iloc[i]['icd_code'], diagnosis.iloc[i]['icd_version']
    if version == 9:
        code = icd9toicd10[code]
    if len(code) == 0 or code == 'NoDx':
        continue
    all_type[get_icd10_type(code)].append([diagnosis.iloc[i]['subject_id'], diagnosis.iloc[i]['hadm_id']])
for k in all_type.keys():
    df = pd.DataFrame(all_type[k], columns=['subject_id', 'hadm_id'])
    df.to_csv(f'diagnosis/{k}.csv', index=False)
    print(k, len(all_type[k]))

提取药物相互作用关系

In [None]:
import xml.etree.ElementTree as et
import json
from collections import defaultdict
from tqdm.notebook import tqdm


tree = et.parse("drugbank_full database.xml")
all_drug = tree.findall("drug")
proc = lambda text: text.replace("\n", "").replace("\t", "").replace("\r", "").strip()
DDI = defaultdict(dict)

for drug in tqdm(all_drug):
    for i in drug.findall("drugbank-id"):
        if i.attrib.get("primary", False) == "true":
            drug_id = proc(i.text)
    for di in drug.find("drug-interactions").findall("drug-interaction"):
        DDI[drug_id][proc(di.find("drugbank-id").text)] = proc(di.find("description").text)

DDI_json = json.dumps(DDI, indent=4)
with open("./DDI.json", "w") as f:
    f.write(DDI_json)

drugbank中所有的药物及相关信息

{drug_id: {"names":[names], "description":description}}

In [None]:
import xml.etree.ElementTree as et
import json
from tqdm.notebook import tqdm


tree = et.parse("drugbank_full database.xml")
all_drug = tree.findall("drug")
proc = lambda text: text.replace("\n", "").replace("\t", "").replace("\r", "").strip()
fail = []
ans = dict()

for drug in tqdm(all_drug):
    temp = {"id":None, "name":None, "desc":None}
    names = []
    # id
    for id in drug.findall("drugbank-id"):
        if "primary" in id.attrib.keys() and id.attrib["primary"]=="true":
            temp["id"] = id.text
            break
    if type(temp["id"]) is not str and len(temp["id"]) == 0:
        fail.append(drug)
        continue
    # desc
    desc = ""
    ## description
    if drug.find("description").text:
        desc += proc(drug.find("description").text)
        desc += "\n"
    ## indication
    if drug.find("indication").text:
        desc += proc(drug.find("indication").text)
        desc += "\n"
    ## products
    products = []
    for product in drug.find("products").findall("product"):
        if product.find("name").text:
            p = proc(product.find("name").text)
            if p not in names:
                names.append(p)
        else:
            continue
        if product.find("dosage-form").text:
            p += ' {}'.format(proc(product.find("dosage-form").text).replace(",", ""))
        if product.find("strength").text:
            p += ' {}'.format(proc(product.find("strength").text))
        if product.find("route").text:
            p += ' {}'.format(proc(product.find("route").text))
        products.append(p)
    if len(products) > 0:
        products = "; ".join(list(set(products)))
        products = products[:-2]
        desc += products
        desc += "\n"
    ## food-interactions
    fis = ""
    for fi in drug.find("food-interactions").findall("food-interaction"):
        if fi.text:
            fis += proc(fi.text)
            fis += "; "
    if len(fis) > 0:
        desc += fis[:-2]
        desc += "\n"
    # name
    ## synonyms
    for syn in drug.find("synonyms").findall("synonym"):
        if syn.attrib.get("language", False) == "english":
            names.append(proc(syn.text))
    ## name
    names.append(proc(drug.find("name").text))
    temp["name"] = names

    temp["desc"] = desc

    ans[temp["id"]] = {"names":temp["name"], "description":temp["desc"]}

ans = json.dumps(ans, indent=4)
with open("drugs.json", "w", encoding="utf-8") as f:
    f.write(ans)

提取mimic数据集中循环系统大类中涉及的所有药物

In [None]:
from tqdm.notebook import tqdm
import pandas as pd


pres = pd.read_csv("./data/prescriptions.csv", usecols=["subject_id", 'hadm_id', 'drug'])
# 60271
dcs = pd.read_csv("./data/Diseases of the circulatory system.csv")

all_drugs = set()

for i in tqdm(range(len(dcs))):
    subject_id = dcs.iloc[i]["subject_id"]
    hadm_id = dcs.iloc[i]["hadm_id"]
    all_drugs = all_drugs.union(set(pres[(pres["subject_id"] == subject_id) & (pres["hadm_id"] == hadm_id)]["drug"].tolist()))

all_drugs = list(all_drugs)


import json

with open("./data/mimic_drugs.json", "w") as f:
    f.write(json.dumps(all_drugs))

## 获取药物训练数据（没有ndc数据前）

爬取mimic药物描述

In [None]:
from src.retriever.bing_retriever import BingRetriever
from src.utils.arguments import ModelArguments
import json
import urllib3
from tqdm.notebook import tqdm


urllib3.disable_warnings()

with open("./data/mimic_drugs_filter.json", "r") as f:
    all_drugs: list = json.load(f)

r = BingRetriever(ModelArguments("./"))

drug_desc = dict()

for drug_i, drug in enumerate(tqdm(all_drugs)):
    drug_desc[drug] = r(drug)
    if drug_i % 100 == 0:
        with open("./data/mimic_drug_search_results.json", "w") as f:
            f.write(json.dumps(drug_desc))

with open("./data/mimic_drug_search_results.json", "w") as f:
    f.write(json.dumps(drug_desc))

针对搜到的结果太少的药物重爬一遍

In [None]:
import json
from tqdm.notebook import tqdm
from src.retriever.bing_retriever import BingRetriever
from src.utils.arguments import ModelArguments
import urllib3

urllib3.disable_warnings()

with open("./data/mimic_drug_search_results.json", "r") as f:
    drug_desc = json.load(f)

r = BingRetriever(ModelArguments("./"))
drugs = [k for k,v in drug_desc.items() if len(v) < 50]

for drug in tqdm(drugs):
    drug_desc[drug] = r(drug)

看看PDD数据里能用的有多少

In [None]:
import json


with open("./data/mimicdrug2drugbank.json", "r") as f:
    mimic2drugbank = json.load(f)

with open("./data/mimic_drugs.json", "r") as f:
    all_drugs = json.load(f)

fail = []
for drug in all_drugs:
    if mimic2drugbank.get(drug, False):
        continue
    else:
        fail.append(drug)

with open("./data/mimic_drugs_filter.json", "w") as f:
    f.write(json.dumps(fail))

len(fail), len(all_drugs), len(mimic2drugbank), len(set(mimic2drugbank.values()))

(2137, 3989, 3276, 978)

对爬到的药物取topk

In [None]:
import json
import torch
from src.model.bgem3 import M3ForScore
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm import tqdm


with open("./data/mimic_drug_search_results.json", "r") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForScore(model_args, device="cuda", batch_size=4)

    for drug, data_list in tqdm(drugs.items()):
        drugs[drug] = model(drug, data_list, 5)

with open("./data/mimic_drug_desc.json", "w") as f:
    f.write(json.dumps(drugs))

对drugbank里的药物做嵌入

In [None]:
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm.notebook import tqdm
import json
import torch

with open("./data/drugs.json") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForInference(model_args, device="cuda")

    drug_embed = []
    drug_embed_id = []
    for k,v in drugs.items():
        if len(v["description"])>50:
            drug_embed.append(v["description"])
            drug_embed_id.append(k)

    results = []
    batch_size = 4
    for i in tqdm(range(0, len(drug_embed), batch_size)):
        if i+batch_size>=len(drug_embed):
            results.append(model(drug_embed[i:]))
        else:
            results.append(model(drug_embed[i:i+batch_size]))
    results = torch.cat(results, dim=0).detach().cpu()

torch.save(results, "./data/drugs_embed.pt")

with open("./data/drugs_embed_id.json", "w") as f:
    f.write(json.dumps(drug_embed_id))

对爬取到的mimic的药物描述做嵌入

In [None]:
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm import tqdm
import json
import torch

with open("./data/mimic_drug_desc.json", "r") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForInference(model_args, device="cuda")

    results = []
    for k,v in tqdm(drugs.items()):
        results.append(model(k+" "+v))
    results = torch.stack(results, dim=0).detach().cpu()

torch.save(results, "./data/mimic_drugs_embed.pt")

with open("./data/mimic_drugs_embed_id.json", "w") as f:
    f.write(json.dumps(list(drugs.keys())))

根据嵌入结果先在mimic内部消歧

In [None]:
import faiss
import torch
import numpy as np
import json


mimic_drugs = torch.load("./data/mimic_drugs_embed.pt").numpy()
with open("./data/mimic_drugs_embed_id.json", "r") as f:
    mimic_drugs_id = np.array(json.load(f))

index = faiss.IndexFlatIP(mimic_drugs.shape[-1])
index.add(mimic_drugs)
# print(index.ntotal)



import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

D, I = index.search(mimic_drugs, 10)

sim_drug = []
all_sim_drug = []
for idx, (d, i) in enumerate(zip(D, I)):
    sim = mimic_drugs_id[i[d>0.85]][1:]
    if len(sim) > 0 and idx not in all_sim_drug:
        sim_drug.append(i[d>0.85].tolist())
        all_sim_drug.extend(i[d>0.85].tolist())
sim_drug

[[0, 936, 15],
 [2, 1503, 1209],
 [8, 2108, 1674],
 [9, 2017],
 [14, 1186],
 [17, 1093, 548],
 [18, 1433, 36],
 [19, 855],
 [22, 2090, 1150, 2022],
 [27, 157],
 [29, 485, 1756, 198],
 [31, 2072],
 [34, 72, 125],
 [42, 401],
 [44, 1461],
 [48, 317],
 [51, 1193],
 [52, 183, 677, 1376],
 [55, 1181],
 [59, 1860, 185],
 [61, 1515],
 [66, 1243, 734],
 [71, 1455],
 [79, 280],
 [80, 948, 267],
 [83, 970, 1833, 1749, 1961],
 [84, 380, 1934, 465],
 [86, 1041],
 [91, 1699],
 [94, 525],
 [95, 496],
 [96, 1386],
 [99, 1593, 1261],
 [107, 221, 245],
 [108, 1783],
 [109, 662],
 [111, 861],
 [112, 878, 226, 959],
 [113, 1114],
 [114, 1779, 1341, 854],
 [120, 1949, 371],
 [128, 259, 1413],
 [131, 714],
 [132, 1887],
 [136, 254, 1780],
 [137, 1814],
 [139, 1322, 638, 414, 1459, 1830],
 [140, 2097, 1992],
 [144, 421],
 [148, 543],
 [150, 448, 1154],
 [152, 1197],
 [154, 408],
 [155, 1538, 1405, 1358, 804],
 [164, 1208],
 [165, 1778, 1384, 708, 1144, 1942, 1567],
 [166, 1836],
 [170, 972, 1912],
 [172, 16

In [None]:
mimic_drugs_id[[186, 717, 1130, 271, 1391, 1204, 760, 469, 1297]]

array(['Aspiri', 'aspi', 'Aspirin 81mg TAB', 'Aspirin 81mg',
       'Aspirin 81mg (1cap)', 'Aspirin 325 mg', 'Aspirin 325mg', 'Aspir',
       'Aspirin 325'], dtype='<U84')

mimic-drugbank药物对齐

In [None]:
import faiss
import torch
import numpy as np
import json


drugbank_drugs = torch.load("./data/drugs_embed.pt").numpy()
with open("./data/drugs_embed_id.json", "r") as f:
    drugbank_drugs_id = json.load(f)

index = faiss.IndexFlatIP(drugbank_drugs.shape[-1])
index.add(drugbank_drugs)
# print(index.ntotal)



import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

D, I = index.search(mimic_drugs, 1)
D = D.squeeze()
I = I.squeeze()

for sd in sim_drug:
    if 142 in I[sd]:
        print(sd, I[sd], D[sd])
    I[sd] = I[sd][np.argmax(D[sd])]
    D[sd] = np.max(D[sd])

In [None]:
print(f"\
0.9 - 1.0   {np.sum(D>0.9)/len(D):.2f}\n\
0.8 - 0.9   {np.sum((D>0.8) & (D<0.9))/len(D):.2f}\n\
0.7 - 0.8   {np.sum((D>0.7) & (D<0.8))/len(D):.2f}\n\
0.6 - 0.7   {np.sum((D>0.6) & (D<0.7))/len(D):.2f}\n\
0.0 - 0.6   {np.sum(D<0.6)/len(D):.2f}")

0.9 - 1.0   0.03
0.8 - 0.9   0.22
0.7 - 0.8   0.41
0.6 - 0.7   0.31
0.0 - 0.6   0.03


mimic to drugbank

In [None]:
import json


with open("./data/mimicdrug2drugbank.json", "r") as f:
    PDD_map = json.load(f)
with open("./data/mimic_drugs.json", "r") as f:
    all_drugs = json.load(f)

mimic2drugbank = dict()
for drug in all_drugs:
    if PDD_map.get(drug, False):
        mimic2drugbank[drug] = PDD_map[drug]


threshold = lambda x: x>0.7

for idx, (d,i) in enumerate(zip(D, I)):
    if threshold(d):
        mimic2drugbank[mimic_drugs_id[idx]] = drugbank_drugs_id[i]

with open("./data/mimicdrug2drugbank.json", "w") as f:
    f.write(json.dumps(mimic2drugbank))

len(mimic2drugbank) / len(all_drugs)

0.8019553772875407

PDD里面有5个旧drugbank数据库中现已弃用的数据，手动查线上drugbank对应一下

In [None]:
with open("./data/mimicdrug2drugbank.json", "r") as f:
    drugs = json.load(f)
for k,v in drugs.items():
    if v == "DB11122":
        drugs[k] = "DB09255"
    if v == "DB09396":
        drugs[k] = "DB00647"
    if v == "DB09323":
        drugs[k] = "DB01053"
    if v == "DB00021":
        drugs[k] = "DB09532"
    if v == "DB11280":
        drugs[k] = "DB01914"
with open("./data/mimicdrug2drugbank.json", "w") as f:
    f.write(json.dumps(drugs))

把drugbank中要用的药物筛出来

In [19]:
import json

with open("./mimicdrug2drugbank.json", "r") as f:
    drugs = list(set(json.load(f).values()))

with open("./drugs.json", "r") as f:
    drugbank_drugs = json.load(f)

drugbank_drugs_filtered = {}
for d in drugs:
    if drugbank_drugs[d]["description"]:
        drugbank_drugs_filtered[d] = drugbank_drugs[d]

with open("./drugs_filtered.json", "w") as f:
    f.write(json.dumps(dict(sorted(drugbank_drugs_filtered.items(), key=lambda obj: obj[0][2:]))))

## 获取药物训练数据（使用ndc数据）

按照其代码处理prescription

In [None]:
import numpy as np
import pandas as pd


med = pres = pd.read_csv("./raw_data/prescriptions.csv", usecols=['subject_id', 'hadm_id', 'drug', 'ndc'])

def ndc_meds(med:pd.DataFrame, mapping:str, use_map_cols=None) -> pd.DataFrame:
    med.ndc = med.ndc.fillna(-1)
    med.ndc = med.ndc.astype("Int64")

    def to_str(ndc):
        if ndc < 0:
            return np.nan
        ndc = str(ndc)
        return (("0"*(11 - len(ndc))) + ndc)[0:-2]

    def format_ndc_table(ndc):
        parts = ndc.split("-")
        return ("0"*(5 - len(parts[0])) + parts[0]) + ("0"*(4 - len(parts[1])) + parts[1])

    def read_ndc_mapping2(map_path):
        ndc_map = pd.read_csv(map_path, encoding = 'latin1')
        ndc_map.columns = list(map(str.lower, ndc_map.columns))
        return ndc_map

    ndc_map = read_ndc_mapping2(mapping)
    if use_map_cols:
        ndc_map = ndc_map[use_map_cols]

    ndc_map['new_ndc'] = ndc_map.productndc.apply(format_ndc_table)
    ndc_map.drop_duplicates(subset=['new_ndc'], inplace=True)
    med['new_ndc'] = med.ndc.apply(to_str)

    med = med.merge(ndc_map, how='inner', left_on='new_ndc', right_on='new_ndc')
    return med

med = ndc_meds(med, "./MIMIC-IV-Data-Pipeline/utils/mappings/NDC_product_table.csv",
               ['productndc', 'proprietaryname', 'nonproprietaryname', 'substancename']
               )[["subject_id", "hadm_id", 'proprietaryname', 'nonproprietaryname', 'substancename']]

In [None]:
med.to_csv("./mimic_drugs.csv", index=False, encoding="utf-8")

取出指定疾病的所有能用于搜索的药物名称

In [None]:
import json
import pandas as pd
from tqdm.notebook import tqdm


mimic_drugs = pd.read_csv("./mimic_drugs.csv")
with open("./drugs.json", "r") as f:
    drugs = json.load(f)
for k,v in drugs.items():
    drugs[k] = [d.strip().lower() for d in v["names"]]

filtered_diagnosis = pd.read_csv("./raw_data/Diseases of the circulatory system.csv")
filtered_mimic_drugs = pd.merge(filtered_diagnosis, mimic_drugs, how='left', on=["subject_id", "hadm_id"])


fail = []
results = dict()

for row_i in tqdm(range(len(filtered_mimic_drugs))):
    row = filtered_mimic_drugs.iloc[row_i]
    names = []
    for k in ["proprietaryname", "nonproprietaryname", "substancename"]:
        if pd.notna(row[k]):
            names.extend(row[k].strip().lower().split(";"))
    names = list(set(names))
    ans = None
    for name in names:
        if ans:
            break
        for k,v in drugs.items():
            if name in v:
                ans = k
                break
    if not ans:
        fail.append(row_i)
    else:
        results[row_i] = ans

with open("mimic2drugbank.json", "w") as f:
    f.write(json.dumps(results))
with open("mimic2drugbank_fail.json", "w") as f:
    f.write(json.dumps(fail))

在循环系统疾病数据中有4739条数据(0.3%)没有处方数据，可能就是单纯没开药

将结果合并到疾病的表中

In [None]:
import json
import pandas as pd
import numpy as np


with open("mimic2drugbank.json", "r") as f:
    results:dict = json.load(f)

mimic_drugs = pd.read_csv("./mimic_drugs.csv")
filtered_diagnosis = pd.read_csv("./raw_data/Diseases of the circulatory system.csv")
filtered_mimic_drugs = pd.merge(filtered_diagnosis, mimic_drugs, how='left', on=["subject_id", "hadm_id"])

for i in range(len(filtered_mimic_drugs)):
    if not results.get(str(i), False):
        results[str(i)] = np.nan
for i in (set(range(len(filtered_mimic_drugs))) - set(map(int, results.keys()))):
    results[i] = np.nan
k,v = zip(*sorted({int(k):v for k,v in results.items()}.items()))
filtered_mimic_drugs["drug_with_drugbank_id"] = v
filtered_mimic_drugs.to_csv("./diseases_circulatory_system_with_drugbank_id.csv",
                            index=False,
                            columns=["subject_id", "hadm_id", "drug_with_drugbank_id"])

把要用的药物描述择出来保存下

In [None]:
import pandas as pd
import numpy as np
import json

with open("./drugs.json", "r") as f:
    drugs = json.load(f)
mimic_drugs = list(set(pd.read_csv("./diseases_circulatory_system_with_drugbank_id.csv")["drug_with_drugbank_id"].tolist()))
drugs_filtered = dict()

for d in mimic_drugs:
    if not pd.isna(d):
        drugs_filtered[d] = drugs[d]

drugs_filtered = dict(sorted(drugs_filtered.items()))
with open("./drugs_filtered.json", "w") as f:
    f.write(json.dumps(drugs_filtered))

### 药物知识图谱-边训练数据集

修改药物相互作用关系为能够用datasets加载的知识图谱的边数据集

In [1]:
import json
from collections import defaultdict
from typing import Dict, List
from tqdm.notebook import tqdm


with open("./DDI.json", "r") as f:
    ddi:Dict[str, Dict[str, str]] = json.load(f)
with open("./drugs_filtered.json", "r") as f:
    drugs:List[str] = json.load(f).keys()

links = defaultdict(str)
fails = []
for head in tqdm(drugs):
    if ddi.get(head, False):
        for tail, desc in ddi[head].items():
            if (tail in drugs) and (not links[tuple(sorted((head, tail)))]):
                links[tuple(sorted((head, tail)))] = desc
    else:
        fails.append(head)
print(len(fails)/len(drugs))

with open("./links.json", "w") as f:
    f.write(json.dumps([{"entity1":k[0], "entity2":k[1], "description":v} for k, v in links.items()]))

  0%|          | 0/759 [00:00<?, ?it/s]

0.03820816864295125


运行utils/hn_mine.py生成neg

最终

In [21]:
import datasets
import json

def _load_json(file_path) -> dict:
        with open(file_path, "r") as f:
            return json.load(f)

drug_data = _load_json("./drugs_filtered.json")
pos2neg = _load_json("./drugs_neg.json")
link_data = datasets.load_dataset("json", data_files="./links.json", split="train")

In [22]:
drug_data["DB00001"]

{'names': ['Refludan',
  'Hirudin variant-1',
  'Lepirudin recombinant',
  '[Leu1, Thr2]-63-desulfohirudin',
  'Lepirudin',
  'Desulfatohirudin',
  'R-hirudin'],
 'description': 'Lepirudin is a recombinant hirudin formed by 65 amino acids that acts as a highly specific and direct thrombin inhibitor.[L41539,L41569] Natural hirudin is an endogenous anticoagulant found in _Hirudo medicinalis_ leeches.[L41539] Lepirudin is produced in yeast cells and is identical to natural hirudin except for the absence of sulfate on the tyrosine residue at position 63 and the substitution of leucine for isoleucine at position 1 (N-terminal end).[A246609] Lepirudin is used as an anticoagulant in patients with heparin-induced thrombocytopenia (HIT), an immune reaction associated with a high risk of thromboembolic complications.[A3, L41539] HIT is caused by the expression of immunoglobulin G (IgG) antibodies that bind to the complex formed by heparin and platelet factor 4. This activates endothelial cells a

In [23]:
pos2neg["DB00001"]

['DB06812',
 'DB01363',
 'DB01211',
 'DB00478',
 'DB00184',
 'DB00140',
 'DB09512',
 'DB00736',
 'DB00796',
 'DB01001',
 'DB01075',
 'DB00669',
 'DB00216',
 'DB10317',
 'DB00271']

In [24]:
link_data[0]

{'entity2': 'DB06605',
 'entity1': 'DB00001',
 'description': 'Apixaban may increase the anticoagulant activities of Lepirudin.'}

## mimic 训练数据 疾病-药物

In [None]:
import pandas as pd

# 加载带drugbank药物数据的mimic数据 去除空值
drugs = pd.read_csv("./data/diseases_circulatory_system_with_drugbank_id.csv")
drugs = drugs[~drugs.isna()]
drugs["drug_with_drugbank_id"] = drugs["drug_with_drugbank_id"].astype('str')

# 每次就诊时医生开出的所有药物数据聚集在一行里
mimic_data = drugs.groupby(["subject_id", "hadm_id"])["drug_with_drugbank_id"].agg(lambda x: ','.join(x))
discharge = pd.read_csv("./data/mimic-note/discharge.csv", usecols=["subject_id", "hadm_id", "text"])
# 将症状描述表和开药数据表通过就诊id连接
mimic_with_text = pd.merge(discharge, right=mimic_data, how='right', on=["subject_id", "hadm_id"])
# 去除空值
traindata = mimic_with_text[(~mimic_with_text.text.isna()) & (~mimic_with_text.drug_with_drugbank_id.isna())].reset_index(drop=True)
traindata.head()

Unnamed: 0,subject_id,hadm_id,text,drug_with_drugbank_id
0,11000566,28693615,\nName: ___ ___ No: ___\n \nAdm...,"DB00316,DB06605,DB01015,DB00332,DB00653,DB0065..."
1,11000680,27035960,\nName: ___ Unit No: ...,"DB01390,DB09020,DB09153,DB00327,DB00818,DB0092..."
2,11000722,24717411,\nName: ___ Unit No: ___\n...,"DB01234,DB00213,DB01109,DB09020,DB00264,DB0123..."
3,11000743,24317015,\nName: ___ Unit No: ___\n ...,"DB01164,DB01606,DB09153,DB09413,DB00332,DB0050..."
4,11000743,29132287,\nName: ___ Unit No: ___\n ...,"DB09153,DB00512,DB00451,DB09153,DB09020,DB0064..."


In [None]:
import re
import numpy as np

def extract_data(text:str):
    ChiefComplaint = re.findall(r"[ \n]+(Chief|___) Complaint:\n(.*)?[ \n]+", text, re.IGNORECASE)
    if len(ChiefComplaint) != 1:
        ChiefComplaint = np.nan
    else:
        ChiefComplaint = ChiefComplaint[0][1].strip()
    HistoryOfPresentIllness = re.findall(r"History of Present Illness:([\s\S]*)?Past Medical History", text, re.IGNORECASE)
    if len(HistoryOfPresentIllness) != 1:
        HistoryOfPresentIllness = np.nan
    else:
        HistoryOfPresentIllness = HistoryOfPresentIllness[0].replace("\n", "").strip()
        HistoryOfPresentIllness = re.sub(r"\s+", " ", HistoryOfPresentIllness)
    return ChiefComplaint, HistoryOfPresentIllness

# 提取主诉和现病史
results = traindata.text.apply(extract_data)
ChiefComplaint, HistoryOfPresentIllness = zip(*results.tolist())
print("ChiefComplaint空值占比", sum(list(map(lambda x:type(x)!=str, ChiefComplaint)))/len(ChiefComplaint))
print("HistoryOfPresentIllness空值占比", sum(list(map(lambda x:type(x)!=str, HistoryOfPresentIllness)))/len(HistoryOfPresentIllness))
traindata["ChiefComplaint"] = ChiefComplaint
traindata["HistoryOfPresentIllness"] = HistoryOfPresentIllness
traindata = traindata[(~traindata.ChiefComplaint.isna()) & (~traindata.HistoryOfPresentIllness.isna())].reset_index(drop=True)
traindata.drop(columns=["text"]).head()

ChiefComplaint空值占比 0.012735589424336174
HistoryOfPresentIllness空值占比 0.019017974073300813


Unnamed: 0,subject_id,hadm_id,drug_with_drugbank_id,ChiefComplaint,HistoryOfPresentIllness
0,11000566,28693615,"DB00316,DB06605,DB01015,DB00332,DB00653,DB0065...","Cold symptoms, shortness of breath",Ms. ___ is a ___ year old female with hx of NP...
1,11000680,27035960,"DB01390,DB09020,DB09153,DB00327,DB00818,DB0092...",found down,HPI: Patient is a ___ with complex medical his...
2,11000722,24717411,"DB01234,DB00213,DB01109,DB09020,DB00264,DB0123...",___ with right ventricular mass,Mr. ___ is a ___ y/o male with no prior past m...
3,11000743,24317015,"DB01164,DB01606,DB09153,DB09413,DB00332,DB0050...",respiratory distress,"___ yo M w/ h/o Down's syndrome, non-verbal at..."
4,11000743,29132287,"DB09153,DB00512,DB00451,DB09153,DB09020,DB0064...","cough, respiratory distress","___ yo M w/ h/o Down's syndrome, non-verbal at..."


In [None]:
traindata.to_csv("./data/mimic_train_data.csv", index=False)

In [None]:
# M3ForInference(
#   (model): XLMRobertaModel(
#     (embeddings): XLMRobertaEmbeddings(
#       (word_embeddings): Embedding(250002, 1024, padding_idx=1)
#       (position_embeddings): Embedding(8194, 1024, padding_idx=1)
#       (token_type_embeddings): Embedding(1, 1024)
#       (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#       (dropout): Dropout(p=0.1, inplace=False)
#     )
#     (encoder): XLMRobertaEncoder(
#       (layer): ModuleList(
#         (0-23): 24 x XLMRobertaLayer(
#           (attention): XLMRobertaAttention(
#             (self): XLMRobertaSelfAttention(
#               (query): Linear(in_features=1024, out_features=1024, bias=True)
#               (key): Linear(in_features=1024, out_features=1024, bias=True)
#               (value): Linear(in_features=1024, out_features=1024, bias=True)
#               (dropout): Dropout(p=0.1, inplace=False)
#             )
#             (output): XLMRobertaSelfOutput(
#               (dense): Linear(in_features=1024, out_features=1024, bias=True)
#               (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#               (dropout): Dropout(p=0.1, inplace=False)
#             )
#           )
#           (intermediate): XLMRobertaIntermediate(
#             (dense): Linear(in_features=1024, out_features=4096, bias=True)
#             (intermediate_act_fn): GELUActivation()
#           )
#           (output): XLMRobertaOutput(
#             (dense): Linear(in_features=4096, out_features=1024, bias=True)
#             (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
#             (dropout): Dropout(p=0.1, inplace=False)
#           )
#         )
#       )
#     )
#     (pooler): XLMRobertaPooler(
#       (dense): Linear(in_features=1024, out_features=1024, bias=True)
#       (activation): Tanh()
#     )
#   )
#   (cross_entropy): CrossEntropyLoss()
# )

## 疾病知识图谱

### 获取mimic iv中疾病所需数据

In [None]:
import pandas as pd

all_disease = pd.read_csv("./raw_data/Diseases of the circulatory system.csv", index_col=0)
icd2name = pd.read_csv("./raw_data/d_icd_diagnoses.csv")
# 疾病当前的seq_num和最大的seq_num 算占比作为权重 计算嵌入向量
id2icd = pd.read_csv("./raw_data/diagnosis_icd.csv")

循环系统疾病

In [None]:
used_disease = pd.merge(all_disease, id2icd, how='left')
used_disease.head()

Unnamed: 0,subject_id,hadm_id,icd_code,icd_version,seq_num
0,11000566,28693615,B974,10,1
1,11000566,28693615,I2699,10,2
2,11000566,28693615,D89813,10,3
3,11000566,28693615,C9202,10,4
4,11000566,28693615,T865,10,5


In [None]:
used_disease = pd.merge(used_disease, used_disease.groupby(["subject_id", "hadm_id"])[["seq_num"]].max(), on=["subject_id", "hadm_id"])

有81种疾病暂时没找到对应的icd10

In [None]:
icd9to10 = pd.read_table("./raw_data/ICD9_to_ICD10_mapping.txt", sep='\t')
icd9to10 = {a.replace('.', ''):b for a, b in zip(icd9to10['diagnosis_code'].tolist(), icd9to10['icd10cm'].tolist())}

# res = used_disease[used_disease['icd_version']==9]['icd_code'].map(icd9to10)
# len(set(used_disease.iloc[res[res.isna()].index]['icd_code'].tolist())), len(set(res.tolist()))

转换能找到对应icd10编码的

In [None]:
origin = used_disease[used_disease['icd_version']==9]['icd_code'].tolist()
res = used_disease[used_disease['icd_version']==9]['icd_code'].map(icd9to10).tolist()

used_disease.loc[used_disease['icd_version']==9, 'icd_code'] = [b if b!='NoDx' and pd.notna(b) else a for a,b in zip(origin, res)]
used_disease.loc[used_disease['icd_version']==9, 'icd_version'] = [10 if i!='NoDx' and pd.notna(i) else 9 for i in res]

如果没找到icd9对应的icd10也先保存原样了

icd10中循环系统疾病为编号I开头的icd_code, 这里只筛选了icd-10的疾病数据

In [None]:
id2name = pd.merge(used_disease, icd2name, how='left')
id2name = id2name[id2name['icd_version']==10][id2name[id2name['icd_version']==10]['icd_code'].map(lambda x: x.startswith('I'))].drop_duplicates()
id2name = id2name[["subject_id", "hadm_id", "icd_code", "long_title", "seq_num_x", "seq_num_y"]].rename(columns={"seq_num_x": "seq_num", "seq_num_y":"total_seq_num"})
id2name

Unnamed: 0,subject_id,hadm_id,icd_code,long_title,seq_num,total_seq_num
1,11000566,28693615,I2699,Other pulmonary embolism without acute cor pul...,2,11
57,11000680,27035960,I120,Hypertensive chronic kidney disease with stage...,3,8
96,11000722,24717411,I619,"Nontraumatic intracerebral hemorrhage, unspeci...",2,5
97,11000722,24717411,I97811,Intraoperative cerebrovascular infarction duri...,3,5
126,11000743,24317015,I959,"Hypotension, unspecified",7,23
...,...,...,...,...,...,...
4138192,13180748,28208689,I10,Essential (primary) hypertension,3,10
4138240,13180830,20263113,I2510,Atherosclerotic heart disease of native corona...,1,10
4138242,13180830,20263113,I129,Hypertensive chronic kidney disease with stage...,3,10
4138295,13181123,20624663,I10,Essential (primary) hypertension,6,11


对于60247就诊记录的162875条数据

In [None]:
len(id2name), len(id2name[['subject_id', 'hadm_id']].drop_duplicates())

(162875, 60247)

In [None]:
id2name.to_csv("./processed_data/subject_id_to_icd&name.csv", index=False)

所有疾病列表

In [None]:
import pandas as pd

data = pd.read_csv("./processed_data/subject_id_to_icd&name.csv")[['icd_code', 'long_title']].drop_duplicates()
data = data.sort_values('icd_code')
data.to_csv("./processed_data/all_disease.csv", index=False)

### 构建为所需数据文件格式

获取icd对应的主诉

In [None]:
import pandas as pd
import json
from collections import defaultdict
from itertools import combinations
from tqdm.auto import tqdm

mimic_data = pd.read_csv("./data/mimic_train_data.csv", usecols=['subject_id', 'hadm_id', 'ChiefComplaint'])
icd_data = pd.read_csv("data/subject_id_to_icd&name.csv", usecols=['subject_id', 'hadm_id', 'icd_code', 'seq_num', 'total_seq_num', 'long_title'])

tmp = pd.merge(icd_data, mimic_data, on=['subject_id', 'hadm_id'])
# tmp.head()

res = defaultdict(list)
for name, group in tqdm(tmp.groupby('icd_code')):
    group = group[group['ChiefComplaint'].notna()]
    for i in range(len(group)):
        res[name].append({"ChiefComplaint": group.iloc[i]['ChiefComplaint'], 'weight': f"{group.iloc[i]['seq_num']}/{group.iloc[i]['total_seq_num']}"})

with open("./data/icd2text.json", 'w', encoding='utf-8') as f:
    f.write(json.dumps(res))

In [None]:
tmp[tmp["ChiefComplaint"].notna()][['icd_code', 'long_title']].drop_duplicates().sort_values('icd_code').to_csv("./data/all_disease.csv", index=False)

保存kg

In [None]:
nodes = tmp["icd_code"].drop_duplicates().tolist()
links = defaultdict(int)

tmp = tmp[tmp["ChiefComplaint"].notna()]
for name, group in tmp.groupby(['subject_id', 'hadm_id']):
    for comb in combinations(group['icd_code'].tolist(), 2):
        links[tuple(sorted(comb))] += 1
nodes = sorted(nodes)
graph_data = {'nodes': [{'id': i, 'label': node} for i, node in enumerate(nodes)],
              'links': [{'source': nodes.index(k[0]), 'target': nodes.index(k[1]), 'weight': v} for k, v in links.items()]}

print(len(nodes), len(links.keys()))

with open("./cache/disease_graph.json", 'w', encoding='utf-8') as f:
    f.write(json.dumps(graph_data))

按seq_num加权计算embedding vector

In [None]:
import json
import torch
from torch.nn.functional import softmax
from tqdm.auto import tqdm
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments


with open("./data/icd2text.json", encoding="utf-8") as f:
    data = json.load(f)

args = ModelArguments(encode_sub_batch_size=-1)
model = M3ForInference(args)

with torch.no_grad():
    res = []
    for icd_code, v in tqdm(data.items()):
        if len(v) > 1:
            # (batch_size, 1)
            weight = softmax(torch.tensor([1-eval(d["weight"]) for d in v], device="cuda"), dim=0).unsqueeze_(1)
            # (batch_size, embed_dim)
            embed = model([d["ChiefComplaint"] for d in v])
            res.append(torch.sum(embed * weight, dim=0))
        else:
            embed = model(v[0]["ChiefComplaint"])
            res.append(embed)
    res = torch.stack(res)
    print(f"embedding matrix shape: {res.shape}, node num: {len(data)}")
    torch.save(res, "./data/disease_weight_embed.pt")

按搜索结果计算embedding vector

In [None]:
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments
from tqdm.auto import tqdm
import torch
import json


args = ModelArguments(encode_sub_batch_size=-1)
model = M3ForInference(args)
with torch.no_grad():
    with open("./data/icd2text.json", encoding='utf-8') as f:
        data:dict = json.load(f)
    res = []
    for k,v in tqdm(data.items()):
        res.append(model(",".join(v)).detach().to('cpu'))
    res = torch.stack(v)
    torch.save(res, "./data/embeddings/disease_embed.pt")