In [1]:
import sys
sys.executable

'/usr/bin/python3'

In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
import gc
from sklearn import preprocessing
import copy
import os
import random
import json
import pickle

In [5]:
data_path = "/home/anonymous/ehr_dataset"
# 'HADM_ID': encounter id
# 'SUBJECT_ID': patient id

In [34]:
file_list = [
    "ADMISSIONS.csv.gz",
    "PRESCRIPTIONS.csv",
    "LABEVENTS.csv",
    "PATIENTS.csv",
    "DIAGNOSES_ICD.csv",
    "PROCEDURES_ICD"
]

In [6]:
med_file = os.path.join(data_path, "PRESCRIPTIONS.csv.gz")
procedure_file = os.path.join(data_path, "PROCEDURES_ICD.csv.gz")
diag_file = os.path.join(data_path, "DIAGNOSES_ICD.csv.gz")
admission_file = os.path.join(data_path, "ADMISSIONS.csv.gz")
lab_test_file = os.path.join(data_path, "LABEVENTS.csv.gz")
patient_file = os.path.join(data_path, "PATIENTS.csv.gz")

In [7]:
# drug code mapping files from GAMENet repo
# https://github.com/sjy1203/GAMENet/tree/master/data
ndc2atc_file = '/home/anonymous/GAMENet/data/ndc2atc_level4.csv' 
cid_atc = '/home/anonymous/GAMENet/data/drug-atc.csv'
ndc2rxnorm_file = '/home/anonymous/GAMENet/data/ndc2rxnorm_mapping.txt'

# Data preprocessing

In [8]:
def process_med():
    med_pd = pd.read_csv(med_file, dtype={'NDC':'category'})
    # filter
    med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
                     'FORMULARY_DRUG_CD','GSN','PROD_STRENGTH','DOSE_VAL_RX',
                     'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP','FORM_UNIT_DISP',
                      'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
    med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')    
    med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd = med_pd.reset_index(drop=True)
    
    def filter_first24hour_med(med_pd):
        med_pd_new = med_pd.drop(columns=['NDC'])
        med_pd_new = med_pd_new.groupby(by=['SUBJECT_ID','HADM_ID','ICUSTAY_ID']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','STARTDATE'])
        med_pd_new = med_pd_new.drop(columns=['STARTDATE'])
        return med_pd_new
    med_pd = filter_first24hour_med(med_pd) # or next line
#     med_pd = med_pd.drop(columns=['STARTDATE'])
    
    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
    med_pd = med_pd.drop_duplicates() 
    return med_pd.reset_index(drop=True)

def process_procedure():
    pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE':'category'})
    pro_pd.drop(columns=['ROW_ID'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
    pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
    pro_pd["ICD9_CODE"] = "PRO_" + pro_pd["ICD9_CODE"].astype(str)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)
    return pro_pd

def process_diag():
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
    return diag_pd.reset_index(drop=True)

def process_admission():
    ad_pd = pd.read_csv(admission_file)
    patient_pd = pd.read_csv(patient_file)
    
    ad_pd.drop(columns=['ROW_ID', 'ADMISSION_TYPE', 'ADMISSION_LOCATION',
       'DISCHARGE_LOCATION', 'INSURANCE', 'LANGUAGE', 'RELIGION',
       'MARITAL_STATUS', 'ETHNICITY', 'EDREGTIME', 'EDOUTTIME', 'DIAGNOSIS',
       'HOSPITAL_EXPIRE_FLAG', 'HAS_CHARTEVENTS_DATA'], axis=1, inplace=True)
    
    patient_pd.drop(columns=['ROW_ID','DOD','DOD_HOSP','DOD_SSN',"EXPIRE_FLAG"], axis=1, inplace=True)
    
    ad_pd["ADMITTIME"] = pd.to_datetime(ad_pd['ADMITTIME'], format='%Y-%m-%d %H:%M:%S')
    ad_pd["DISCHTIME"] = pd.to_datetime(ad_pd['DISCHTIME'], format='%Y-%m-%d %H:%M:%S')  # time for leaving hospital
    patient_pd['DOB'] = pd.to_datetime(patient_pd['DOB'], format='%Y-%m-%d %H:%M:%S')  # birthday
    ad_pd = ad_pd.merge(patient_pd, on=['SUBJECT_ID'], how='inner')
    
    # create features: age, death, number of days in this encounter, readmission (next visit)
    ad_pd["AGE"] = ad_pd['ADMITTIME'].dt.year - ad_pd['DOB'].dt.year
    ad_pd[ad_pd["AGE"] >= 300] = 90
    age = ad_pd["AGE"]
    bins = np.linspace(age.min(), age.max(), 20 + 1)
    ad_pd['AGE'] = pd.cut(age, bins=bins, labels=False, include_lowest=True)
    ad_pd['AGE'] = "AGE_" + ad_pd["AGE"].astype("str")
    
    ad_pd["DEATH"] = ad_pd["DEATHTIME"].notna()
    ad_pd["STAY_DAYS"] = (ad_pd["DISCHTIME"] - ad_pd["ADMITTIME"]).astype('timedelta64[h]').dt.days

    ad_pd['ADMITTIME'] = ad_pd['ADMITTIME'].astype(str)
    ad_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ADMITTIME'], inplace=True)
    ad_pd['next_visit'] = ad_pd.groupby('SUBJECT_ID')['HADM_ID'].shift(-1)
    ad_pd['READMISSION'] = ad_pd['next_visit'].notnull().astype(int)
    ad_pd.drop('next_visit', axis=1, inplace=True)

    ad_pd.drop(columns=['DISCHTIME', 'DOB', 'DEATHTIME'], axis=1, inplace=True)
    ad_pd.drop_duplicates(inplace=True)
    return ad_pd.reset_index(drop=True)

def process_lab_test(n_bins=5):
    lab_pd = pd.read_csv(lab_test_file)
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','ITEMID']).head(1).reset_index(drop=True)  # only consider the first value
    lab_pd = lab_pd[lab_pd["VALUENUM"].notna()]
    lab_pd = lab_pd[lab_pd["HADM_ID"].notna()]
    
    lab_pd.drop(columns=['ROW_ID'], axis=1, inplace=True)
    
    def contains_text(group):
        for item in group:
            try:
                float(item)
            except ValueError:
                return True
        return False

    for itemid in lab_pd['ITEMID'].unique():
        group = lab_pd[lab_pd['ITEMID'] == itemid]['VALUE']

        # if the lab test contains text value then directly copy the value
        if contains_text(group):
            lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = group
        else:
            # value->numeric
            values_numeric = pd.to_numeric(group, errors='coerce')

            if len(values_numeric.dropna()) < n_bins:
                lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = group
            else:
                lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = pd.qcut(values_numeric, q=n_bins, labels=False, duplicates='drop')
        
    lab_pd["ITEMID"] = lab_pd["ITEMID"].astype(str)
    lab_pd["value_bin"] = lab_pd["value_bin"].astype(str)
    lab_pd["LAB_TEST"] = lab_pd[["ITEMID", "value_bin"]].apply("-".join, axis=1)
    
    lab_pd.drop(columns=['CHARTTIME', 'VALUE', 'VALUENUM', 'VALUEUOM', 'FLAG', 'value_bin', 'ITEMID'], axis=1, inplace=True)
    lab_pd.drop_duplicates(inplace=True)
    return lab_pd.reset_index(drop=True)

In [9]:
def ndc2atc4(med_pd):
    with open(ndc2rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)
    
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4':'NDC'})
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()    
    return med_pd.reset_index(drop=True)

def filter_most_pro(pro_pd, threshold=800):
    pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:threshold, 'ICD9_CODE'])]
    return pro_pd.reset_index(drop=True)

def filter_most_diag(diag_pd, threshold=2000):
    diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:threshold, 'ICD9_CODE'])]
    return diag_pd.reset_index(drop=True)

def filter_most_lab(lab_pd, threshold=1500):
    lab_count = lab_pd.groupby(by=['LAB_TEST']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    lab_pd = lab_pd[lab_pd['LAB_TEST'].isin(lab_count.loc[:threshold, 'LAB_TEST'])]
    return lab_pd.reset_index(drop=True)

In [10]:
from datetime import timedelta

def process_all():
    print('process_med')
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)

    print('process_diag')
    diag_pd = process_diag()
    diag_pd = filter_most_diag(diag_pd)

    print('process_pro')
    pro_pd = process_procedure()
    pro_pd = filter_most_pro(pro_pd)

    print('process_ad')
    ad_pd = process_admission()
    
    print('process_lab')
    lab_pd = process_lab_test()
    lab_pd = filter_most_lab(lab_pd)
    
    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    lab_pd_key = lab_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    ad_pd_key = ad_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    
    # filter key
    combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(lab_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(ad_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    lab_pd = lab_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    ad_pd = ad_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['LAB_TEST'].unique().reset_index()
    
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
    pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
    lab_pd['LAB_TEST'] = lab_pd['LAB_TEST'].map(lambda x: list(x))
    
    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(lab_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(ad_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
#     data['NDC_Len'] = data['NDC'].map(lambda x: len(x))

    data = data.sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
    
    # create feature: readmission within 30/90 days
    data['ADMITTIME'] = pd.to_datetime(data['ADMITTIME'])
    data['READMISSION_1M'] = data.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1) - data['ADMITTIME']
    data['READMISSION_3M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=90) else 0)
    data['READMISSION_1M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=30) else 0)
    
    # create feature: diease in next 6/12 months    
    data['NEXT_DIAG_6M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                              (data['ADMITTIME'] > x['ADMITTIME']) & 
                                              (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=180))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_12M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                                   (data['ADMITTIME'] > x['ADMITTIME']) & 
                                                   (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=365))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_6M'] = data['NEXT_DIAG_6M'].apply(lambda x: x[0] if x else float('nan'))
    data['NEXT_DIAG_12M'] = data['NEXT_DIAG_12M'].apply(lambda x: x[0] if x else float('nan'))
    
    data.drop(columns=['ADMITTIME'], axis=1, inplace=True)
    return data.reset_index(drop=True)

In [11]:
def statistics(data):
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))
    
    diag = data['ICD9_CODE'].values
    med = data['NDC'].values
    pro = data['PRO_CODE'].values
    lab_test = data['LAB_TEST'].values
    
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    unique_lab = set([j for i in lab_test for j in list(i)])
    
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    print('#lab', len(unique_lab))
    
    avg_diag = avg_med = avg_pro = avg_lab = 0
    max_diag = max_med = max_pro = max_lab = 0
    cnt = max_visit = avg_visit = 0

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x, y, z, k = [], [], [], []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
            k.extend(list(row['LAB_TEST']))
        x, y, z, k = set(x), set(y), set(z), set(k)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_lab += len(k)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if len(k) > max_lab:
            max_lab = len(k)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    print('#avg of diagnoses ', avg_diag/ cnt)
    print('#avg of medicines ', avg_med/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of lab_test ', avg_lab/ cnt)
    print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of lab_test ', max_lab)
    print('#max of visit ', max_visit)

In [13]:
data = process_all()

process_med


  exec(code_obj, self.user_global_ns, self.user_ns)


process_diag
process_pro
process_ad
process_lab


In [14]:
statistics(data)

#patients  (33582,)
#clinical events  40770
#diagnosis  1998
#med  149
#procedure 801
#lab 1500
#avg of diagnoses  10.765979887171941
#avg of medicines  11.248099092469953
#avg of procedures  4.459234731420162
#avg of lab_test  41.59244542555801
#avg of vists  1.2140432374486332
#max of diagnoses  122
#max of medicines  53
#max of procedures  48
#max of lab_test  169
#max of visit  15


In [15]:
with open("./dataset/mimic.pkl", "wb") as outfile:
    pickle.dump(data, outfile)

# Dataset split

In [17]:
mimic_data = pickle.load(open("./dataset/mimic.pkl", 'rb'))

In [18]:
mimic_data.columns

Index(['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'NDC', 'PRO_CODE', 'LAB_TEST',
       'GENDER', 'AGE', 'DEATH', 'STAY_DAYS', 'READMISSION', 'READMISSION_1M',
       'READMISSION_3M', 'NEXT_DIAG_6M', 'NEXT_DIAG_12M'],
      dtype='object')

In [84]:
mimic_data.head(10)

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,PRO_CODE,LAB_TEST,GENDER,AGE,DEATH,STAY_DAYS,READMISSION,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
0,2,163353,"[V3001, V053, V290]",[J01C],[PRO_9955],"[51143-0.0, 51144-0.0, 51146-0.0, 51200-0.0, 5...",M,AGE_0,False,3,0,0,0,,
1,4,185777,"[042, 1363, 7994, 2763, 7907, 5715, 04111, V09...","[A10A, R05D, P01B, A12B, N03A, A07E, D10A, H03...","[PRO_3893, PRO_8872, PRO_3323]","[50810-1.0, 50811-1.0, 50815-2.0, 50816-0.0, 5...",F,AGE_10,False,7,0,0,0,,
2,6,107064,"[40391, 4440, 9972, 2766, 2767, 2859, 2753, V1...","[A06A, L04A, A12C, C07A, C08C, B01A, D11A, A03...","[PRO_5569, PRO_0091, PRO_3957, PRO_3806, PRO_9...","[50971-4.0, 50983-0.0, 50986-4.0, 51006-4.0, 5...",F,AGE_14,False,16,0,0,0,,
3,8,159514,"[V3001, 7706, 7746, V290, V502, V053]","[J01C, D06A]","[PRO_9390, PRO_640, PRO_9955]","[51143-0.0, 51144-0.0, 51146-0.0, 51200-4.0, 5...",M,AGE_0,False,4,0,0,0,,
4,9,150750,"[431, 5070, 4280, 5849, 2765, 4019]","[B05C, A12C, J01M, A01A, C02D, N01A, A02B, A12A]","[PRO_9672, PRO_9604]","[50819-0.0, 50820-2.0, 50821-3.0, 50826-2.0, 5...",M,AGE_9,True,4,0,0,0,,
5,10,184167,"[V3000, 7742, 76525, 76515, V290]",[J01C],"[PRO_9983, PRO_9915, PRO_966]","[51143-0.0, 51144-0.0, 51146-0.0, 51200-4.0, 5...",F,AGE_0,False,8,0,0,0,,
6,11,194540,[1913],"[A06A, A02B, N02B, A01A, A04A, N03A]","[PRO_0159, PRO_0113, PRO_9229, PRO_9925]","[50868-1.0, 50882-2.0, 50902-2.0, 50912-0.0, 5...",F,AGE_11,False,25,0,0,0,,
7,12,112213,"[1570, 57410, 9971, 4275, 99811, 4019, 5680, 5...","[B01A, A12C, B05C, N01A, A02B, A12A, A01A, C07...","[PRO_5137, PRO_5459, PRO_5351, PRO_9915, PRO_3...","[50802-0.0, 50804-0.0, 50806-4.0, 50808-4.0, 5...",M,AGE_15,True,12,0,0,0,,
8,13,143045,"[41401, 4111, 25000, 4019, 2720]","[A12B, N02B, A06A, A02B, C07A, B05C, C02D, N01...","[PRO_3612, PRO_3615, PRO_3961, PRO_3761, PRO_8...","[50809-3.0, 50818-0.0, 50820-1.0, 50821-3.0, 5...",F,AGE_8,False,6,0,0,0,,
9,17,194023,"[7455, 45829, V1259, 2724]","[A06A, A02B, N02B, B05C, C02D, A12C, A12A, A10...","[PRO_3571, PRO_3961, PRO_8872]","[50960-3.0, 50970-0.0, 50971-4.0, 50983-0.0, 5...",F,AGE_10,False,4,0,0,0,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]","[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]"


In [4]:
len(mimic_data["SUBJECT_ID"].unique())

33582

In [104]:
# get the patient id with the labels
pat_readmission = set(mimic_data[mimic_data["READMISSION_1M"] == 1]["SUBJECT_ID"].values.tolist())
print(len(pat_readmission))
pat_nextdiag_6m = set(mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_6m))
pat_nextdiag_12m = set(mimic_data[mimic_data["NEXT_DIAG_12M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_12m))
pat_death = set(mimic_data[mimic_data["DEATH"]]["SUBJECT_ID"].values.tolist())
print(len(pat_death))
pat_all_label = list(pat_readmission | pat_nextdiag_6m | pat_nextdiag_12m | pat_death)

1121
2967
3501
4420


In [22]:
pat_all = mimic_data["SUBJECT_ID"].unique().tolist()

In [23]:
n_pretrain_patient = int(len(pat_all) * 0.7)
pretrain_patient = np.random.choice(list(set(pat_all) - set(pat_all_label)), n_pretrain_patient, replace=False).tolist()
downstream_patient = list(set(pat_all) - set(pretrain_patient))
print(len(pretrain_patient), len(downstream_patient))

23507 10075


In [24]:
pretrain_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(pretrain_patient))]

In [37]:
# downstream task: PLOS, death prediction, readmission
train_ratio, val_ratio = 0.2, 0.4
n_finetune_pat, n_val_pat = int(len(downstream_patient) * train_ratio), int(len(downstream_patient) * val_ratio)
np.random.shuffle(downstream_patient)
finetune_pat, val_pat, test_pat = downstream_patient[:n_finetune_pat], \
                                    downstream_patient[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    downstream_patient[n_finetune_pat+n_val_pat:]
finetune_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [38]:
# downstream task: next diagnosis prediction
# only use the patient that has multiple visits
# 6m
train_ratio, val_ratio = 0.4, 0.3
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_6m) * train_ratio), int(len(pat_nextdiag_6m) * val_ratio)
pat_nextdiag_6m = list(pat_nextdiag_6m)
np.random.shuffle(pat_nextdiag_6m)
finetune_pat, val_pat, test_pat = pat_nextdiag_6m[:n_finetune_pat], \
                                    pat_nextdiag_6m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_6m[n_finetune_pat+n_val_pat:]
finetune_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]
# 12m
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_12m) * train_ratio), int(len(pat_nextdiag_12m) * val_ratio)
pat_nextdiag_12m = list(pat_nextdiag_12m)
np.random.shuffle(pat_nextdiag_12m)
finetune_pat, val_pat, test_pat = pat_nextdiag_12m[:n_finetune_pat], \
                                    pat_nextdiag_12m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_12m[n_finetune_pat+n_val_pat:]
finetune_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [39]:
with open("./dataset/mimic_pretrain.pkl", "wb") as outfile:
    pickle.dump(pretrain_dataset, outfile)
with open("./dataset/mimic_downstream.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset, val_dataset, test_dataset], outfile)
with open("./dataset/mimic_nextdiag_6m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset6m, val_dataset6m, test_dataset6m], outfile)
with open("./dataset/mimic_nextdiag_12m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset12m, val_dataset12m, test_dataset12m], outfile)