根据icd10大类将mimic中的疾病分类保存

In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.notebook import tqdm


# Build an ICD-9-CM -> ICD-10-CM lookup from the GEM mapping file.
# Read everything as str: letting pandas infer dtypes can turn purely numeric
# ICD-9 codes into ints (losing leading zeros) and empty cells into float NaN,
# both of which would break the string operations below.
icd9toicd10 = pd.read_csv('icd9toicd10cmgem.csv', usecols=['icd9cm', 'icd10cm'],
                          dtype=str, keep_default_na=False)
map9to10 = defaultdict(str)
for icd9, icd10 in zip(icd9toicd10['icd9cm'], icd9toicd10['icd10cm']):
    # Keys are stored without leading zeros ('0389' -> '389'); lookups must
    # strip them the same way.
    icd9 = icd9.strip().lstrip('0')
    # An ICD-9 code can appear on several rows: keep the first real target, but
    # let a later real code replace an earlier 'NoDx' (no-diagnosis) placeholder.
    if not map9to10[icd9] or map9to10[icd9] == 'NoDx':
        map9to10[icd9] = icd10.strip()
# np.save wraps the dict in a 0-d object array (pickled); reload it with
# np.load('icd9toicd10.npy', allow_pickle=True).item()
np.save('icd9toicd10.npy', map9to10)

# First letter -> ICD-10 chapter title. 'D' and 'H' are absent because each of
# those letters spans two chapters; see _ICD10_SPLIT_LETTERS.
_ICD10_CHAPTER_BY_LETTER = {
    **dict.fromkeys('AB', 'Certain infectious and parasitic diseases'),
    'C': 'Neoplasms',
    'E': 'Endocrine, nutritional and metabolic diseases',
    'F': 'Mental, Behavioral and Neurodevelopmental disorders',
    'G': 'Diseases of the nervous system',
    'I': 'Diseases of the circulatory system',
    'J': 'Diseases of the respiratory system',
    'K': 'Diseases of the digestive system',
    'L': 'Diseases of the skin and subcutaneous tissue',
    'M': 'Diseases of the musculoskeletal system and connective tissue',
    'N': 'Diseases of the genitourinary system',
    'O': 'Pregnancy, childbirth and the puerperium',
    'P': 'Certain conditions originating in the perinatal period',
    'Q': 'Congenital malformations, deformations and chromosomal abnormalities',
    'R': 'Symptoms, signs and abnormal clinical and laboratory findings, not elsewhere classified',
    **dict.fromkeys('ST', 'Injury, poisoning and certain other consequences of external causes'),
    'U': 'Codes for special purposes',
    **dict.fromkeys('VWXY', 'External causes of morbidity'),
    'Z': 'Factors influencing health status and contact with health services',
}

# letter -> (last two-digit block of the first chapter, first chapter, second chapter)
_ICD10_SPLIT_LETTERS = {
    'D': (49, 'Neoplasms',
          'Diseases of the blood and blood-forming organs and certain disorders involving the immune mechanism'),
    'H': (59, 'Diseases of the eye and adnexa', 'Diseases of the ear and mastoid process'),
}


def get_icd10_type(id: str) -> str:
    """Return the ICD-10 chapter title for an ICD-10 code such as ``'I10'``.

    The chapter is picked from the first letter. For ``D`` and ``H`` the two
    characters after the letter choose between the two chapters sharing that
    letter (D00-D49 -> neoplasms, otherwise blood; H00-H59 -> eye, otherwise
    ear); if those characters are not decimal digits, the second chapter is used.

    Returns None when the first character is not a recognised (uppercase)
    chapter letter. An empty string raises IndexError.
    """
    letter = id[0]
    split = _ICD10_SPLIT_LETTERS.get(letter)
    if split is None:
        return _ICD10_CHAPTER_BY_LETTER.get(letter)
    last_block, first_chapter, second_chapter = split
    digits = id[1:3]
    if digits.isdecimal() and int(digits) <= last_block:
        return first_chapter
    return second_chapter

import os

# Group (subject_id, hadm_id) pairs by ICD-10 chapter and write one CSV per chapter.
# icd_code is read as str so pandas never parses numeric ICD-9 codes as ints
# (which would drop leading zeros and break the lookup below).
diagnosis = pd.read_csv('diagnosis_icd.csv', dtype={'icd_code': str})
all_type = defaultdict(list)
rows = zip(diagnosis['icd_code'], diagnosis['icd_version'],
           diagnosis['subject_id'], diagnosis['hadm_id'])
for code, version, subject_id, hadm_id in tqdm(rows, total=len(diagnosis)):
    if not isinstance(code, str):  # missing code (NaN)
        continue
    code = code.strip()
    if version == 9:
        # Translate with the dict built from the GEM file (NOT the raw DataFrame,
        # which the previous version indexed by mistake). Its keys have leading
        # zeros stripped, so strip them here too.
        code = map9to10[code.lstrip('0')]
    if len(code) == 0 or code == 'NoDx':
        continue
    chapter = get_icd10_type(code)
    if chapter is None:  # first character is not an ICD-10 chapter letter
        continue
    all_type[chapter].append([subject_id, hadm_id])

os.makedirs('diagnosis', exist_ok=True)
for k, pairs in all_type.items():
    df = pd.DataFrame(pairs, columns=['subject_id', 'hadm_id'])
    df.to_csv(f'diagnosis/{k}.csv', index=False)
    print(k, len(pairs))

提取药物相互作用关系

In [None]:
import xml.etree.ElementTree as et
import json
from collections import defaultdict
from tqdm.notebook import tqdm


tree = et.parse("drugbank_full database.xml")
all_drug = tree.findall("drug")
proc = lambda text: text.replace("\n", "").replace("\t", "").replace("\r", "").strip()
# DDI[drug_a][drug_b] = interaction description text as listed under drug_a.
DDI = defaultdict(dict)

for drug in tqdm(all_drug):
    # Reset per drug: previously a drug without a primary id silently reused the
    # previous drug's id (or raised NameError on the very first drug).
    drug_id = None
    for id_elem in drug.findall("drugbank-id"):
        if id_elem.attrib.get("primary", False) == "true" and id_elem.text:
            drug_id = proc(id_elem.text)
            break
    if not drug_id:
        continue
    # "a/b" path syntax yields nothing (instead of crashing) when <drug-interactions> is absent.
    for di in drug.findall("drug-interactions/drug-interaction"):
        partner = di.find("drugbank-id")
        if partner is None or not partner.text:
            continue
        description = di.find("description")
        text = proc(description.text) if description is not None and description.text else ""
        DDI[drug_id][proc(partner.text)] = text

with open("./DDI.json", "w") as f:
    json.dump(DDI, f, indent=4)

修改药物相互作用关系为能够用datasets加载的知识图谱的边数据集

In [None]:
import json
from collections import defaultdict
from tqdm.notebook import tqdm

with open("./DDI.json", "r") as f:
    ddi: dict = json.load(f)

# Treat interactions as undirected: (a, b) and (b, a) share one key, the sorted id
# pair. The first non-empty description seen for a pair is kept; an empty one can
# still be replaced by a later direction.
links = defaultdict(str)
total_len = 0
for head, tails in tqdm(ddi.items()):
    total_len += len(tails)
    for tail, description in tails.items():
        pair = tuple(sorted((head, tail)))
        if not links[pair]:
            links[pair] = description
# Fraction of directed entries that remain after merging the two directions.
print(len(links)/total_len)

records = [
    {"entity1": first, "entity2": second, "description": text}
    for (first, second), text in links.items()
]
with open("./links.json", "w") as f:
    f.write(json.dumps(records))

drugbank中所有的药物及相关信息

{drug_id: {"names":[names], "description":description}}

In [None]:
import xml.etree.ElementTree as et
import json
from tqdm.notebook import tqdm


tree = et.parse("drugbank_full database.xml")
all_drug = tree.findall("drug")
proc = lambda text: text.replace("\n", "").replace("\t", "").replace("\r", "").strip()
fail = []  # <drug> elements skipped because they have no usable primary id
ans = dict()


def _child_text(node, tag):
    """Return the cleaned text of ``node``'s first ``tag`` child, or "" if missing/empty."""
    child = node.find(tag)
    if child is None or not child.text:
        return ""
    return proc(child.text)


for drug in tqdm(all_drug):
    # --- id: the <drugbank-id primary="true"> element ---
    drug_id = None
    for id_elem in drug.findall("drugbank-id"):
        if id_elem.attrib.get("primary") == "true":
            drug_id = proc(id_elem.text) if id_elem.text else None
            break
    # The old check (`not str and len(...) == 0`) evaluated len(None) and crashed
    # instead of recording the drug; skip anything without a usable id.
    if not drug_id:
        fail.append(drug)
        continue

    names = []
    desc = ""

    # --- desc: description and indication, one section per line ---
    for tag in ("description", "indication"):
        text = _child_text(drug, tag)
        if text:
            desc += text + "\n"

    # --- desc: products as "<name> <dosage form> <strength> <route>" ---
    products = []
    for product in drug.findall("products/product"):
        label = _child_text(product, "name")
        if not label:
            continue
        if label not in names:
            names.append(label)
        dosage_form = _child_text(product, "dosage-form").replace(",", "")
        for extra in (dosage_form, _child_text(product, "strength"), _child_text(product, "route")):
            if extra:
                label += " " + extra
        products.append(label)
    if products:
        # De-duplicate keeping first-seen order. The previous set() produced a
        # run-dependent order, and its trailing [:-2] cut two real characters off
        # the last product (join() leaves no trailing separator to strip).
        desc += "; ".join(dict.fromkeys(products)) + "\n"

    # --- desc: food interactions ---
    food = [proc(fi.text) for fi in drug.findall("food-interactions/food-interaction") if fi.text]
    if food:
        desc += "; ".join(food) + "\n"

    # --- names: product names (collected above), English synonyms, then the drug name ---
    for syn in drug.findall("synonyms/synonym"):
        if syn.attrib.get("language") == "english" and syn.text:
            names.append(proc(syn.text))
    name = _child_text(drug, "name")
    if name:
        names.append(name)

    ans[drug_id] = {"names": names, "description": desc}

with open("drugs.json", "w", encoding="utf-8") as f:
    json.dump(ans, f, indent=4)

提取mimic数据集中循环系统大类中涉及的所有药物

In [None]:
from tqdm.notebook import tqdm
import pandas as pd


pres = pd.read_csv("./data/prescriptions.csv", usecols=["subject_id", 'hadm_id', 'drug'])
# 60271 -- presumably len(dcs) when this was last run
dcs = pd.read_csv("./data/Diseases of the circulatory system.csv")

# Keep prescriptions whose (subject_id, hadm_id) admission is in the circulatory
# cohort. One merge replaces the former per-row boolean scan of the whole
# prescriptions table (O(len(dcs) * len(pres))).
keys = ["subject_id", "hadm_id"]
# Drop NaN keys: the old `==` filter never matched NaN, whereas merge would.
cohort = dcs[keys].dropna().drop_duplicates()
matched = pres.merge(cohort, on=keys, how="inner")
# NaN drug names would be dumped as a bare NaN token and later fed to string ops.
all_drugs = list(set(matched["drug"].dropna().tolist()))


import json

with open("./data/mimic_drugs.json", "w") as f:
    f.write(json.dumps(all_drugs))

爬取mimic药物描述

In [None]:
from src.retriever.bing_retriever import BingRetriever
from src.utils.arguments import ModelArguments
import json
import urllib3
from tqdm.notebook import tqdm


urllib3.disable_warnings()

with open("./data/mimic_drugs_filter.json", "r") as f:
    all_drugs: list = json.load(f)

r = BingRetriever(ModelArguments("./"))


def _checkpoint():
    """Overwrite the results file with everything fetched so far."""
    with open("./data/mimic_drug_search_results.json", "w") as f:
        f.write(json.dumps(drug_desc))


drug_desc = dict()
for position, name in enumerate(tqdm(all_drugs)):
    drug_desc[name] = r(name)
    # Checkpoint after item 0 and then every 100 items, so an interruption
    # mid-crawl loses at most ~100 results.
    if position % 100 == 0:
        _checkpoint()

_checkpoint()

针对搜到的结果太少的药物重爬一遍

In [None]:
import json
from tqdm.notebook import tqdm
from src.retriever.bing_retriever import BingRetriever
from src.utils.arguments import ModelArguments
import urllib3

urllib3.disable_warnings()

with open("./data/mimic_drug_search_results.json", "r") as f:
    drug_desc = json.load(f)

r = BingRetriever(ModelArguments("./"))
# Re-query drugs whose stored result has length < 50 (len() of whatever r()
# returned -- presumably a list of hits or a string; TODO confirm).
drugs = [k for k, v in drug_desc.items() if len(v) < 50]

for drug in tqdm(drugs):
    drug_desc[drug] = r(drug)

# Persist the refreshed results. Without this the re-crawl only lived in memory,
# and the top-k cell, which reads this file, never saw it.
with open("./data/mimic_drug_search_results.json", "w") as f:
    f.write(json.dumps(drug_desc))

看看PDD数据里能用的有多少

In [None]:
import json


with open("./data/mimicdrug2drugbank.json", "r") as f:
    mimic2drugbank = json.load(f)

with open("./data/mimic_drugs.json", "r") as f:
    all_drugs = json.load(f)

# MIMIC drugs the PDD mapping does not cover (missing key or falsy target). This
# list is what the crawl cell reads from mimic_drugs_filter.json.
# NOTE(review): the later "mimic to drugbank" cell overwrites
# mimicdrug2drugbank.json with an augmented mapping, so re-running this cell after
# it yields a different (smaller) list.
fail = [name for name in all_drugs if not mimic2drugbank.get(name, False)]

with open("./data/mimic_drugs_filter.json", "w") as f:
    f.write(json.dumps(fail))

len(fail), len(all_drugs), len(mimic2drugbank), len(set(mimic2drugbank.values()))

(2137, 3989, 3276, 978)

对爬到的药物取topk

In [None]:
import json
import torch
from src.model.bgem3 import M3ForScore
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm import tqdm


with open("./data/mimic_drug_search_results.json", "r") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForScore(model_args, device="cuda", batch_size=4)

    # Replace each drug's raw search hits, in place, with the scorer's output.
    # The third argument (5) is presumably k for top-k; see M3ForScore.
    for name in tqdm(list(drugs)):
        drugs[name] = model(name, drugs[name], 5)

with open("./data/mimic_drug_desc.json", "w") as f:
    json.dump(drugs, f)

对drugbank里的药物做嵌入

In [None]:
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm.notebook import tqdm
import json
import torch

with open("./data/drugs.json") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForInference(model_args, device="cuda")

    # Embed only drugs whose description is longer than 50 characters. Ids are
    # kept in the same order as the texts, so (assuming the model returns one row
    # per input text) row k of the saved tensor belongs to drug_embed_id[k].
    selected = [(k, v["description"]) for k, v in drugs.items() if len(v["description"]) > 50]
    drug_embed_id = [k for k, _ in selected]
    drug_embed = [text for _, text in selected]

    batch_size = 4
    # Slicing past the end is safe, so the final (possibly short) batch needs no special case.
    results = [
        model(drug_embed[start:start + batch_size])
        for start in tqdm(range(0, len(drug_embed), batch_size))
    ]
    results = torch.cat(results, dim=0).detach().cpu()

torch.save(results, "./data/drugs_embed.pt")

with open("./data/drugs_embed_id.json", "w") as f:
    json.dump(drug_embed_id, f)

对爬取到的mimic的药物描述做嵌入

In [None]:
from src.model.bgem3 import M3ForInference
from src.utils.arguments import ModelArguments
from transformers import HfArgumentParser
from tqdm import tqdm
import json
import torch

with open("./data/mimic_drug_desc.json", "r") as f:
    drugs = json.load(f)

with torch.no_grad():
    parser = HfArgumentParser((ModelArguments, ))
    model_args = parser.parse_dict({"model_path":"./checkpoint/m3/"})[0]
    model = M3ForInference(model_args, device="cuda")

    # One forward pass per drug on "<name> <selected text>". The value is
    # presumably a string produced by the top-k cell -- TODO confirm.
    results = [model(name + " " + text) for name, text in tqdm(drugs.items())]
    results = torch.stack(results, dim=0).detach().cpu()

torch.save(results, "./data/mimic_drugs_embed.pt")

# Row k of mimic_drugs_embed.pt corresponds to entry k of this id list.
with open("./data/mimic_drugs_embed_id.json", "w") as f:
    json.dump(list(drugs.keys()), f)

根据嵌入结果先在mimic内部消歧

In [None]:
import os
# Set before importing faiss/torch so it is already in place when their OpenMP
# runtimes load (it used to be set only after the imports).
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import faiss
import torch
import numpy as np
import json


mimic_drugs = torch.load("./data/mimic_drugs_embed.pt").numpy()
with open("./data/mimic_drugs_embed_id.json", "r") as f:
    mimic_drugs_id = np.array(json.load(f))

# Exact inner-product search. This equals cosine similarity only if the
# embeddings are L2-normalised -- TODO confirm for M3ForInference.
index = faiss.IndexFlatIP(mimic_drugs.shape[-1])
index.add(mimic_drugs)

# The 10 nearest MIMIC drugs of every MIMIC drug (including itself).
D, I = index.search(mimic_drugs, 10)

SIM_THRESHOLD = 0.85
sim_drug = []       # groups of row indices considered the same drug
all_sim_drug = []   # every index already placed in a group (kept as a list for later cells)
_grouped = set()    # O(1) membership mirror of all_sim_drug (was a list scan per row)
for idx, (d, i) in enumerate(zip(D, I)):
    group = i[d > SIM_THRESHOLD]
    # The first hit is assumed to be the query itself, so a real group needs
    # more than one member.
    if len(group) > 1 and idx not in _grouped:
        members = group.tolist()
        sim_drug.append(members)
        all_sim_drug.extend(members)
        _grouped.update(members)
sim_drug

[[0, 936, 15],
 [2, 1503, 1209],
 [8, 2108, 1674],
 [9, 2017],
 [14, 1186],
 [17, 1093, 548],
 [18, 1433, 36],
 [19, 855],
 [22, 2090, 1150, 2022],
 [27, 157],
 [29, 485, 1756, 198],
 [31, 2072],
 [34, 72, 125],
 [42, 401],
 [44, 1461],
 [48, 317],
 [51, 1193],
 [52, 183, 677, 1376],
 [55, 1181],
 [59, 1860, 185],
 [61, 1515],
 [66, 1243, 734],
 [71, 1455],
 [79, 280],
 [80, 948, 267],
 [83, 970, 1833, 1749, 1961],
 [84, 380, 1934, 465],
 [86, 1041],
 [91, 1699],
 [94, 525],
 [95, 496],
 [96, 1386],
 [99, 1593, 1261],
 [107, 221, 245],
 [108, 1783],
 [109, 662],
 [111, 861],
 [112, 878, 226, 959],
 [113, 1114],
 [114, 1779, 1341, 854],
 [120, 1949, 371],
 [128, 259, 1413],
 [131, 714],
 [132, 1887],
 [136, 254, 1780],
 [137, 1814],
 [139, 1322, 638, 414, 1459, 1830],
 [140, 2097, 1992],
 [144, 421],
 [148, 543],
 [150, 448, 1154],
 [152, 1197],
 [154, 408],
 [155, 1538, 1405, 1358, 804],
 [164, 1208],
 [165, 1778, 1384, 708, 1144, 1942, 1567],
 [166, 1836],
 [170, 972, 1912],
 [172, 16

In [None]:
# Spot-check: show the drug names behind a hand-picked set of row indices
# (presumably one near-duplicate group from sim_drug). The output shown below
# them consists entirely of aspirin spellings.
mimic_drugs_id[[186, 717, 1130, 271, 1391, 1204, 760, 469, 1297]]

array(['Aspiri', 'aspi', 'Aspirin 81mg TAB', 'Aspirin 81mg',
       'Aspirin 81mg (1cap)', 'Aspirin 325 mg', 'Aspirin 325mg', 'Aspir',
       'Aspirin 325'], dtype='<U84')

mimic-drugbank药物对齐

In [None]:
import os
# Set before importing faiss/torch so it is already in place when their OpenMP runtimes load.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import faiss
import torch
import numpy as np
import json


drugbank_drugs = torch.load("./data/drugs_embed.pt").numpy()
with open("./data/drugs_embed_id.json", "r") as f:
    drugbank_drugs_id = json.load(f)

index = faiss.IndexFlatIP(drugbank_drugs.shape[-1])
index.add(drugbank_drugs)

# Best DrugBank match for every MIMIC drug. `mimic_drugs` and `sim_drug` come
# from the in-MIMIC de-duplication cell and must already be defined.
D, I = index.search(mimic_drugs, 1)
# Take column 0 rather than squeeze(): squeeze() collapses to a 0-d array when
# there is only one query, which breaks the fancy indexing below.
D = D[:, 0]
I = I[:, 0]

# Every member of a near-duplicate group inherits the group's single best match
# (the leftover debug print for DrugBank row 142 has been removed).
for sd in sim_drug:
    best = np.argmax(D[sd])
    I[sd] = I[sd][best]
    D[sd] = D[sd][best]

In [None]:
# Share of MIMIC drugs per best-match similarity band. Bands are (lo, hi], so every
# score is counted exactly once. The old strict `>`/`<` pairs dropped scores equal
# to 0.6/0.7/0.8/0.9. The (lo, hi] choice also matches the later `x > 0.7`
# acceptance threshold: the top three bands are exactly the accepted matches.
_bands = [
    ("0.9 - 1.0", D > 0.9),
    ("0.8 - 0.9", (D > 0.8) & (D <= 0.9)),
    ("0.7 - 0.8", (D > 0.7) & (D <= 0.8)),
    ("0.6 - 0.7", (D > 0.6) & (D <= 0.7)),
    ("0.0 - 0.6", D <= 0.6),
]
print("\n".join(f"{label}   {np.sum(mask)/len(D):.2f}" for label, mask in _bands))

0.9 - 1.0   0.03
0.8 - 0.9   0.22
0.7 - 0.8   0.41
0.6 - 0.7   0.31
0.0 - 0.6   0.03


mimic to drugbank：合并PDD映射与嵌入相似度匹配结果（相似度 > 0.7），生成mimic药物到drugbank的映射

In [None]:
import json


with open("./data/mimicdrug2drugbank.json", "r") as f:
    PDD_map = json.load(f)
with open("./data/mimic_drugs.json", "r") as f:
    all_drugs = json.load(f)

# NOTE(review): this cell overwrites the very file PDD_map was read from, so any
# cell that reads mimicdrug2drugbank.json afterwards (including a re-run of this
# one) sees the augmented mapping, not the original PDD data.

# 1) Start from every PDD mapping whose DrugBank target is truthy.
mimic2drugbank = {drug: PDD_map[drug] for drug in all_drugs if PDD_map.get(drug, False)}


def threshold(x):
    """Accept an embedding match whose similarity score is above 0.7."""
    return x > 0.7


# 2) Add embedding matches. D[k] / I[k] are the best score / DrugBank row for
#    MIMIC drug k (from the alignment cell).
for idx, (score, row) in enumerate(zip(D, I)):
    if not threshold(score):
        continue
    mimic2drugbank[mimic_drugs_id[idx]] = drugbank_drugs_id[row]

with open("./data/mimicdrug2drugbank.json", "w") as f:
    f.write(json.dumps(mimic2drugbank))

len(mimic2drugbank) / len(all_drugs)

0.8019553772875407

PDD里面有5个旧drugbank数据库中现已弃用的数据，手动查线上drugbank对应一下

In [None]:
# Five DrugBank ids referenced by PDD are retired in the current DrugBank release.
# Each is mapped to the replacement id looked up by hand on the online DrugBank site.
retired_to_current = {
    "DB11122": "DB09255",
    "DB09396": "DB00647",
    "DB09323": "DB01053",
    "DB00021": "DB09532",
    "DB11280": "DB01914",
}

with open("./data/mimicdrug2drugbank.json", "r") as f:
    drugs = json.load(f)
drugs = {k: retired_to_current.get(v, v) for k, v in drugs.items()}
with open("./data/mimicdrug2drugbank.json", "w") as f:
    f.write(json.dumps(drugs))

把drugbank中要用的药物筛出来

In [None]:
import json

with open("./data/mimicdrug2drugbank.json", "r") as f:
    drugs = list(set(json.load(f).values()))

with open("./data/drugs.json", "r") as f:
    drugbank_drugs = json.load(f)

# Keep only the DrugBank entries some MIMIC drug maps to
# (raises KeyError if a mapped id is absent from drugs.json).
drugbank_drugs_filtered = {d: drugbank_drugs[d] for d in drugs}

# Write the entries ordered by id with its first two characters (the "DB" prefix) dropped.
ordered = dict(sorted(drugbank_drugs_filtered.items(), key=lambda item: item[0][2:]))
with open("./data/drugs_filtered.json", "w") as f:
    json.dump(ordered, f)

组织成符合要求的格式

In [None]:
import json

with open("./data/drugs_filtered.json", "r") as f:
    drugs:dict = json.load(f)

# JSON Lines output: one {"query": <drug name>, "pos": [<description>]} per line.
with open("./data/drugs_dataset.json", "w") as f:
    for k, v in drugs.items():
        # The description is built from newline-terminated sections (description,
        # indication, products, food interactions). Replace the newlines with a
        # space instead of deleting them, which glued the end of one section onto
        # the start of the next.
        passage = v["description"].replace("\n", " ").strip()
        # NOTE(review): in drugs.json, names[0] is the first product/brand name
        # whenever the drug has products; the canonical DrugBank name is appended
        # last (names[-1]). Confirm which one the query is meant to use.
        f.write(json.dumps({"query": v["names"][0], "pos": [passage]}) + "\n")

运行utils/hn_mine.py生成neg