In [1]:
import os
import dill
import random
import jsonlines
import numpy as np
import pandas as pd
from utils import *

In [2]:
class Voc(object):
    '''Define the vocabulary (token) dict'''

    def __init__(self):

        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        '''add vocabulary to dict via a list of words'''
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)

# create voc set
def create_str_token_mapping(df, vocabulary_file):
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()

    for index, row in df.iterrows():
        diag_voc.add_sentence(row["icd_code"])
        med_voc.add_sentence(row["ATC3"])
        pro_voc.add_sentence(row["pro_code"])

    dill.dump(
        obj={"diag_voc": diag_voc, "med_voc": med_voc, "pro_voc": pro_voc},
        file=open(vocabulary_file, "wb"),
    )
    return diag_voc, med_voc, pro_voc

## Step 0:

Assign the STAY_ID to precription, procedure and diagnosis tables.

## Step 1:
Preprocess the raw MIMIC-III data as the original medication recommendation works

In [3]:
base_dir = ""   # base folder

## Some auxiliary info, such as DDI, ATC and ICD
RXCUI2atc4_file = os.path.join(base_dir, "./auxiliary/RXCUI2atc4.csv")
cid2atc6_file = os.path.join(base_dir, "./auxiliary/drug-atc.csv")
ndc2RXCUI_file = os.path.join(base_dir, "./auxiliary/ndc2RXCUI.txt")
ddi_file = os.path.join(base_dir, "./auxiliary/drug-DDI.csv")
drugbankinfo = os.path.join(base_dir, "./auxiliary/drugbank_drugs_info.csv")

In [4]:
med_file = os.path.join(base_dir, "./hosp/prescriptions.csv")
diag_file = os.path.join(base_dir, "./hosp/diagnoses_icd.csv")
procedure_file = (
    os.path.join(base_dir, "./hosp/procedures_icd.csv")
)

# input auxiliary files
med_structure_file = os.path.join(base_dir, "./handled/atc32SMILES.pkl")

# output files
ddi_adjacency_file = os.path.join(base_dir, "./handled/ddi_A_final.pkl")
ehr_adjacency_file = os.path.join(base_dir, "./handled/ehr_adj_final.pkl")
ehr_sequence_file = os.path.join(base_dir, "./handled/records_final.pkl")
vocabulary_file = os.path.join(base_dir, "./handled/voc_final.pkl")
ddi_mask_H_file = os.path.join(base_dir, "./handled/ddi_mask_H.pkl")
atc3toSMILES_file = os.path.join(base_dir, "./handled/atc3toSMILES.pkl")

In [6]:
# for med
med_pd = med_process(med_file)  # process the raw file
med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)   # remain the single-visit
# med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)   # filter out the patient has less 2 visits
med_pd = med_pd.merge(
    med_pd_lg2[["subject_id"]], on="subject_id", how="inner"
).reset_index(drop=True)

med_pd = codeMapping2atc4(med_pd, ndc2RXCUI_file, RXCUI2atc4_file)
med_pd = filter_300_most_med(med_pd)

# med to SMILES mapping
atc3toDrug = ATC3toDrug(med_pd)
druginfo = pd.read_csv(drugbankinfo)
atc3toSMILES = atc3toSMILES(atc3toDrug, druginfo)
dill.dump(atc3toSMILES, open(atc3toSMILES_file, "wb"))
med_pd = med_pd[med_pd.ATC3.isin(atc3toSMILES.keys())]
print("complete medication processing")

# for diagnosis
diag_pd = diag_process(diag_file)

print("complete diagnosis processing")

# for procedure
pro_pd = procedure_process(procedure_file)
# pro_pd = filter_1000_most_pro(pro_pd)

print("complete procedure processing")

# combine
data = combine_process(med_pd, diag_pd, pro_pd)
print("complete combining")

FileNotFoundError: [Errno 2] No such file or directory: './hosp/prescriptions.csv'

In [6]:
# only use a part of data
np.random.seed(42)
sample_patient = np.random.choice(data["subject_id"].unique(), int(data["subject_id"].unique().shape[0]*0.15), replace=False)
data = data[data["subject_id"].isin(sample_patient)]

In [7]:
statistics(data)

#patients  (8906,)
#clinical events  23046
#diagnosis  1998
#med  125
#procedure 1001
#avg of diagnoses  8.414345222598282
#avg of medicines  7.017399982643409
#avg of procedures  2.1107784431137726
#avg of vists  2.5876936896474287
#max of diagnoses  220
#max of medicines  72
#max of procedures  49
#max of visit  48


In [8]:
# create vocab
diag_voc, med_voc, pro_voc = create_str_token_mapping(data, vocabulary_file)
print("obtain voc")

# create ehr sequence data
records = create_patient_record(data, diag_voc, med_voc, pro_voc, ehr_sequence_file)
print("obtain ehr sequence data")

# create ddi adj matrix
ddi_adj = get_ddi_matrix(records, med_voc, ddi_file, cid2atc6_file, ehr_adjacency_file, ddi_adjacency_file)
print("obtain ddi adj matrix")

# get ddi_mask_H
ddi_mask_H = get_ddi_mask(atc3toSMILES, med_voc)
dill.dump(ddi_mask_H, open(ddi_mask_H_file, "wb"))

obtain voc
obtain ehr sequence data
obtain ddi adj matrix


## Step 2: Get side info
Extract side information of patients from other csv

In [9]:
def get_side(source_df, side_df, side_columns, aligh_column):

    side_df = side_df[side_columns]
    source_df = pd.merge(source_df, side_df, how="left", on=aligh_column)

    return source_df

In [10]:
admission = pd.read_csv("./hosp/admissions.csv")
data = get_side(data, admission, 
                ["hadm_id", "insurance", "language", "admission_type", "marital_status", "race"],
                "hadm_id"
                )

In [11]:
data.fillna(value="unknown", inplace=True)

## Step 3: Map ATC to drugname
Resolve the mapping. In the original preprocessed data, the drug is represented by ATC code, but we need the drugname for LLM.


In [12]:
RXCUI2atc4 = pd.read_csv(RXCUI2atc4_file)
RXCUI2atc4["NDC"] = RXCUI2atc4["NDC"].map(lambda x: x.replace("-", ""))
with open(ndc2RXCUI_file, "r") as f:
    ndc2RXCUI = eval(f.read())

In [13]:
RXCUI2ndc = dict(zip(ndc2RXCUI.values(), ndc2RXCUI.keys()))
RXCUI2atc4["RXCUI"] = RXCUI2atc4["RXCUI"].astype("str")
RXCUI2atc4["NDC"] = RXCUI2atc4["RXCUI"].map(RXCUI2ndc)
RXCUI2atc4.dropna(axis=0, how="any", inplace=True)
RXCUI2atc4.drop_duplicates(inplace=True)

In [14]:
RXCUI2atc4.shape, RXCUI2atc4.nunique()

((32732, 5),
 YEAR       73
 MONTH      12
 NDC      2037
 RXCUI    2037
 ATC4      445
 dtype: int64)

In [15]:
RXCUI2atc4.drop_duplicates(inplace=True)
RXCUI2atc4.shape

(32732, 5)

In [16]:
med_pd.head(5)

Unnamed: 0,subject_id,hadm_id,starttime,drug,ATC3
0,10000032,22595853,2180-05-07 00:00:00,Heparin,B01A
1,10000032,22841357,2180-06-26 22:00:00,Heparin,B01A
2,10000032,25742920,2180-08-06 03:00:00,Heparin,B01A
3,10000032,29079034,2180-07-23 15:00:00,Heparin,B01A
4,10000117,22927623,2181-11-15 13:00:00,Heparin,B01A


In [17]:
med_pd = pd.read_csv(med_file, dtype={"ndc": "category"})
med_pd.columns.values

  med_pd = pd.read_csv(med_file, dtype={"ndc": "category"})


array(['subject_id', 'hadm_id', 'pharmacy_id', 'poe_id', 'poe_seq',
       'order_provider_id', 'starttime', 'stoptime', 'drug_type', 'drug',
       'formulary_drug_cd', 'gsn', 'ndc', 'prod_strength', 'form_rx',
       'dose_val_rx', 'dose_unit_rx', 'form_val_disp', 'form_unit_disp',
       'doses_per_24_hrs', 'route'], dtype=object)

In [18]:
RXCUI2atc4 = RXCUI2atc4.rename({"NDC": "ndc"}, axis=1)

In [19]:
med_pd["ndc"].astype("str")
med_pd = pd.merge(med_pd, RXCUI2atc4, how="left", on="ndc")

In [20]:
atc2drug = pd.read_csv("./auxiliary/WHO ATC-DDD 2021-12-03.csv")
atc2drug["code_len"] = atc2drug["atc_code"].map(lambda x: len(x))
atc2drug = atc2drug[atc2drug["code_len"]==4]    # all levels are included. We only need the 4th level, i.e., ATC4
atc2drug.rename(columns={"atc_code": "ATC4"}, inplace=True)
atc2drug.drop(columns=["ddd", "uom", "adm_r", "note", "code_len"], axis=1, inplace=True)

In [21]:
atc2drug.head(3)

Unnamed: 0,ATC4,atc_name
2,A01A,STOMATOLOGICAL PREPARATIONS
46,A02A,ANTACIDS
79,A02B,DRUGS FOR PEPTIC ULCER AND GASTRO-OESOPHAGEAL ...


In [22]:
RXCUI2atc4["ATC4"] = RXCUI2atc4["ATC4"].map(lambda x: x[:4])

In [23]:
# all atc code in original data can be mapped to drugname by atc2drug.
# means that we use the same data as the traditional medication recommendation models.
pd.merge(RXCUI2atc4, atc2drug, on="ATC4", how="left")["atc_name"].isna().sum()

0

In [24]:
atc2drug["atc_name"] = atc2drug["atc_name"].map(lambda x: x.lower())

In [25]:
# get the atc2drug and drug2atc mapping dict
atc2drug_dict = dict(zip(atc2drug["ATC4"].values, atc2drug["atc_name"].values))
drug2atc_dict = dict(zip(atc2drug["atc_name"].values, atc2drug["ATC4"].values))

In [26]:
import json
json.dump({"atc2drug": atc2drug_dict, "drug2atc": drug2atc_dict}, open("./handled/atc2drug.json", "w"))

In [27]:
# get the diagnosis and procedure mapping dict, which both use the ICD. these mappings are in raw MIMIC dataset
icd2diag = pd.read_csv("./hosp/d_icd_diagnoses.csv")
icd2diag_dict = dict(zip(icd2diag["icd_code"].astype(str).values, icd2diag["long_title"].values))

In [28]:
icd2proc = pd.read_csv("./hosp/d_icd_procedures.csv")
icd2proc_dict = dict(zip(icd2proc["icd_code"].astype(str).values, icd2proc["long_title"].values))

In [29]:
def decode(code_list, decoder):
    # decode a list of code into corresponding names
    miss_match = 0
    target_list = []
    for code in code_list:
        try:
            target_list.append(decoder[code])
        except:
            miss_match += 1
    
    #print(miss_match)

    return target_list

In [30]:
data["drug"] = data["ATC3"].map(lambda x: decode(x, atc2drug_dict))
data["diagnosis"] = data["icd_code"].map(lambda x: decode(x, icd2diag_dict))
data["procedure"] = data["pro_code"].map(lambda x: decode(x, icd2proc_dict))

some miss matches occurs in diagnosis and procedures, but no for drug

In [31]:
data.iloc[1]["pro_code"]

['8669']

In [32]:
def profile_tokenization(df, profile_columns):
    prof_dict = {"word2idx":{}, "idx2word": {}}
    for prof in profile_columns:
        prof_dict["idx2word"][prof] = dict(zip(range(df[prof].nunique()), df[prof].unique()))
        prof_dict["word2idx"][prof] = dict(zip(df[prof].unique(), range(df[prof].nunique())))
    return prof_dict

In [33]:
profile_dict = profile_tokenization(data, ["insurance", "language", "admission_type", "marital_status", "race"])
json.dump(profile_dict, open("./handled/profile_dict.json", "w"))

## Step 4: Construct Prompt
Design the prompt templates and construct the prompt

In [34]:
# prompt templates
main_template = "The patient has <VISIT_NUM> times ICU visits. \n <HISTORY> In this visit, he has diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. Then, the patient should be prescribed: "
hist_template = "In <VISIT_NO> visit, the patient had diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. The patient was prescribed drugs: <MEDICATION>. \n"

In [35]:
# add some patient's profiles
# main_template = "The patient's insurance type is <INSU>, language is <LANG>, admission type is <ADMTYPE>, marital status is <MARITAL>, race is <RACE>. The patient has <VISIT_NUM> times ICU visits. \n <HISTORY> In this visit, he has diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. Then, the patient should be prescribed: "

In [36]:
def concat_str(str_list):
    # concat a list of drug / diagnosis / procedures
    target_str = ""
    for meta_str in str_list:
        target_str = target_str + meta_str + ", "
    target_str = target_str[:-2]    # remove the last comma

    return target_str

In [37]:
llm_data = []

for subject_id in data["subject_id"].unique():
    item_df = data[data["subject_id"] == subject_id]
    visit_num = item_df.shape[0] - 1
    patient = []

    profile = item_df.iloc[0]
    patient_str = main_template.replace("<INSU>", profile["insurance"].lower())\
                               .replace("<LANG>", profile["language"].lower())\
                               .replace("<ADMTYPE>", profile["admission_type"].lower())\
                               .replace("<MARITAL>", profile["marital_status"].lower())\
                               .replace("<RACE>", profile["race"].lower())

    patient_profile = {"insurance": profile["insurance"], "language": profile["language"],
                       "admission_type": profile["admission_type"], "marital_status": profile["marital_status"],
                       "race": profile["race"]}

    # get each historical visit string
    for visit_no, (_, row) in enumerate(item_df.iterrows()):
        drug, diag, proc = concat_str(row["drug"]), concat_str(row["diagnosis"]), concat_str(row["procedure"])
        patient.append(hist_template.replace("<VISIT_NO>", str(visit_no+1))\
                                    .replace("<DIGNOSIS>", diag)\
                                    .replace("<PROCEDURE>", proc)\
                                    .replace("<MEDICATION>", drug))
    patient.pop()   # remove the ground truth record

    # filter out the patients with more than 3 times visits
    if len(patient) > 3:
        patient = patient[-3:]

    # concat all historical visit strings and get hist strings
    hist_str = ""
    for meta_hist in patient:
        hist_str += meta_hist
    
    patient_str = patient_str.replace("<VISIT_NUM>", str(visit_num))\
                             .replace("<HISTORY>", hist_str)\
                             .replace("<DIAGNOSIS>", diag)\
                             .replace("<PROCEDURE>", proc)
    
    drug_code = [str(x) for x in row["ATC3"]]

    hist = {"diagnosis": [], "procedure": [], "medication": []}
    for _, row in item_df.iterrows():
        hist["diagnosis"].append([str(x) for x in row["icd_code"]])
        hist["procedure"].append([str(x) for x in row["pro_code"]])
        hist["medication"].append([str(x) for x in row["ATC3"]])
        
    llm_data.append({"input": patient_str, "target": drug, 
                     "subject_id": int(subject_id), "drug_code": drug_code,
                     "records": hist, "profile": patient_profile})
        

In [38]:
file_path = "./handled/"

def read_data(data_path):
    '''read data from jsonlines file'''
    data = []

    with jsonlines.open(file_path + data_path, "r") as f:
        for meta_data in f:
            data.append(meta_data)

    return data


def save_data(data_path, data):
    '''write all_data list to a new jsonl'''
    with jsonlines.open(file_path + data_path, "w") as w:
        for meta_data in data:
            w.write(meta_data)

In [39]:
# split the dataset: 8:1:1
train_split = int(len(llm_data) * 0.8)
val_split = int(len(llm_data) * 0.1)
train = llm_data[:train_split]
val = llm_data[train_split:train_split+val_split]
test = llm_data[train_split+val_split:]

In [None]:
save_data("train_0104.json", train)
save_data("val_0104.json", val)
save_data("test_0104.json", test)