# Clean & build the datasets

### Run this script after the SQL extraction. It creates .csv files with the data, and the NESDE train/test data as .pkl files.

#### It produces the following files:
1. Vancomycin_Dosing_SI_*hr_dataset.csv -- dataset of normalized, discretized features (side-information)
2. Vancomycin_Dosing_vanco_blood_*hr_dataset.csv -- dataset of vancomycin blood measurements at irregular times
3. Vancomycin_Dosing_vanco_input_*hr_dataset.csv -- dataset of vancomycin administrations (interventions) at irregular times
4. Vancomycin_Dosing_norm_vals.pkl -- min/max normalization coefficients of the features, useful for recovering the actual values.
5. Vancomycin_Dosing_train.pkl -- training set for NESDE
6. Vancomycin_Dosing_test.pkl -- test set for NESDE

**Note:** this script takes a while to run.

In [None]:
import pandas as pd
import numpy as np
import tqdm
import pickle

# path to the MIMIC-IV inputevents table, used for further cleaning (needs to be set):
input_events_path = "./inputevents.csv"

# time window (hours) over which the side-information is averaged
step_size = 4

# hours without vancomycin input after which the drug is considered cleared from the
# patient's body; gaps this long (converted to steps) are used to split trajectories
_VANCO_CLEARANCE_HOURS = 48
vanco_input_null_step_split = int(_VANCO_CLEARANCE_HOURS / step_size)

# minimal number of vancomycin blood samples per trajectory
min_vanco_blood_smp = 2

# maximal trajectory time (minutes), compared with the last vancomycin blood sample time
max_trj_time = 6000

# minimal number of time steps with a vancomycin input per stay
min_vanco_input_dlen = 2

# test set split ratio
test_ratio = 0.3
 

In [None]:
# Some dicts to help working with the data:
# chartevents / inputevents itemid -> feature name (several itemids may share one name)
d_itemid2name = {
    227454: 'Vancomycin_random',
    226065: 'Vancomycin_random',
    227455: 'Vancomycin_trough',
    226064: 'Vancomycin_trough',
    227453: 'Vancomycin_peak',
    225697: 'Vancomycin_peak',
    225798: 'Vancomycin_input',
    225667: 'ionized_Ca',
    227466: 'PTT',
    220235: 'CO2',
    220045: 'HR',
    220562: 'PTT',
    227456: 'Albumin',
    220574: 'Albumin',
    220051: 'Dia_BP',
    220050: 'Sys_BP',
    225651: 'Direct_Bili',
    225690: 'Total_Bili',
    220615: 'Creatinine',
    223900: 'GCS_Verbal',
    220739: 'GCS_Eye',
    223901: 'GCS_Motor',
    220545: 'Ht_serum',
    220228: 'Hb',
    227467: 'INR',
    220561: 'INR',
    223830: 'PH',
    225678: 'Platelet_Count',
    227457: 'Platelet_Count',
    227465: 'PT',
    220560: 'PT',
    220210: 'RR',
    220227: 'SAO2',
    223762: 'Temp_C',
    223761: 'Temp_F',
    227429: 'Troponin',
    220546: 'WBC',
    227468: 'Fibrinogen',
    220541: 'Fibrinogen',
    220612: 'CRP',
    227444: 'CRP',
    226512: 'Adm_Weight_Kg',
    226531: 'Adm_Weight_lb',
    220640: 'Potassium',
    227442: 'Potassium',
    225625: 'non_ionized_Ca',
    220645: 'Sodium',
    224639: 'Daily_Weight',
    225636: 'D_DIMER',
}


# labevents itemid -> feature name (names shared with chartevents are merged later)
d_labitemid2name = {
    51009: 'Vancomycin_blood',
    50861: 'ALT',
    50878: 'AST',
    51300: 'WBC',
    51274: 'PT',
    51275: 'PTT',
    50889: 'CRP',
    50912: 'Creatinine',
    51006: 'Urea_Nitrogen',
    50971: 'Potassium',
    50893: 'Total_Ca',
    50983: 'Sodium',
    52618: 'Sodium',
    50883: 'Direct_Bili',
    50885: 'Total_Bili',
    51221: 'Ht',
    51222: 'Hb',
    51237: 'INR',
    51265: 'Platelet_Count',
    51003: 'Troponin',
    52111: 'Fibrinogen',
    51214: 'Fibrinogen',
    51196: 'D_DIMER',
    50915: 'D_DIMER',
}


# feature name -> (min, max) plausible range; values outside it are discarded
range_dic = {
    'Weight_Kg': (30, 300),
    'Daily_Weight': (30, 300),
    'Adm_Weight_Kg': (30, 300),
    'Adm_Weight_lb': (30 / 0.453592, 300 / 0.453592),  # same bounds as kg, in pounds
    'Creatinine': (0.1, 10),
    'Potassium': (1.5, 10),
    'Total_Bili': (0, 20),
    'PH': (6.5, 8.5),
    'Sodium': (115, 170),
    'Total_Ca': (4, 15),
    'WBC': (0, 50),
    'non_ionized_Ca': (5, 20),
    'SAO2': (0, 100),
    'Hb': (5, 20),
    'Anti-Xa': (0, 2),
    'INR': (0.8, 8),
    'Platelet_Count': (0, 2000),
    'HR': (30, 200),
    'Direct_Bili': (0, 20),
    'Ht': (10, 60),
    'Ht_serum': (10, 60),
    'Temp_C': (32, 42),
    'Temp_F': (89.6, 107.6),
    'PTT': (10, 120),
    'RR': (5, 60),
    'Dia_BP': (20, 200),
    'Sys_BP': (40, 280),
    'Albumin': (0, 6),
    'AST': (10, 10000),
    'ALT': (10, 8000),
    'CRP': (0, 100),
    'PT': (10, 120),
    'Troponin': (0, 60),
    'Urea_Nitrogen': (0, 300),
    'D_DIMER': (0, 6000),
    'CO2': (0, 180),
    'Fibrinogen': (0, 1000),
}


# Reverse lookups: feature name -> list of itemids (chartevents/inputevents, and labevents).
d_rev_name2itemid = dict()
for _itemid, _name in d_itemid2name.items():
    d_rev_name2itemid.setdefault(_name, []).append(_itemid)

d_rev_name2labitemid = dict()
for _itemid, _name in d_labitemid2name.items():
    d_rev_name2labitemid.setdefault(_name, []).append(_itemid)

# Give both lookups the same set of names; a source without items for a name gets [].
for _name in d_rev_name2itemid:
    d_rev_name2labitemid.setdefault(_name, [])

for _name in d_rev_name2labitemid:
    d_rev_name2itemid.setdefault(_name, [])

In [None]:
# load data and initiate df:
df = pd.read_csv('./VD_chartevents.csv', encoding='unicode_escape').drop_duplicates()
df_labs = pd.read_csv('./VD_labevents.csv', encoding='unicode_escape').drop_duplicates()
df_diags = pd.read_csv('./VD_diagnoses.csv', encoding='unicode_escape').drop_duplicates()
# Parse the first number out of the free-text lab value and use it where valuenum is missing.
# Raw string: '\d' is an invalid escape sequence in a plain string literal.
df_labs['value'] = df_labs['value'].str.extract(r'(\d+(\.\d+)?)', expand=False).astype(float)[0]
inds_to_replace = df_labs['valuenum'].isna() & df_labs['value'].notna()
df_labs['valuenum'] = np.where(inds_to_replace, df_labs['value'], df_labs['valuenum'])
df_diags = df_diags.loc[df_diags['category'].notna()]
diag_columns = df_diags['category'].unique().tolist()
columns = set(d_rev_name2itemid.keys())
columns |= set(d_rev_name2labitemid.keys())
columns = list(columns)
new_df_template = pd.DataFrame(columns=['timestep','stay_id','gender','age','race','hospital_expire_flag','admission_type','Weight_Kg'] + diag_columns + columns)
# The continuous-time vancomycin input records are appended with a 'rate' key (and the merge
# step reads 'rate'); naming the column 'amount' here left an extra all-NaN column behind.
df_cont_vanco_input_template = pd.DataFrame(columns=['stay_id','traj_ind','starttime','endtime','rate'])
df_cont_vanco_blood_template = pd.DataFrame(columns=['stay_id','traj_ind','charttime','value'])


# length of one discretization step (charttime/starttime/endtime are compared against it)
step_interval = 3600*step_size
curr_step = 0
data_dics = dict()  # stayid -> static attributes + diagnosis indicators
df_dic = dict()  # stayid -> discretized DataFrame
cont_df_dic = dict()  # stayid -> {'vanco_input': DataFrame, 'vanco_blood': DataFrame}

In [None]:
# compute some dicts to be used later:
# For every stay: an empty discretized frame, empty continuous-time frames for the
# vancomycin inputs and blood levels, and its static attributes + diagnosis indicators.
for stayid in df['stay_id'].unique().tolist():
    stay_rows = df.loc[df['stay_id'] == stayid]  # filter once per stay, not once per attribute
    df_dic[stayid] = new_df_template.copy(deep=True)
    # The blood frame must use the blood template (charttime/value): built from the input
    # template it carried empty starttime/endtime/amount columns ahead of the real data.
    cont_df_dic[stayid] = {'vanco_input': df_cont_vanco_input_template.copy(deep=True),
                           'vanco_blood': df_cont_vanco_blood_template.copy(deep=True)}
    data_dics[stayid] = dict()
    data_dics[stayid]['stay_id'] = stayid
    # static attributes: the first distinct value seen for the stay
    for attr in ('gender', 'age', 'race', 'hospital_expire_flag', 'admission_type'):
        data_dics[stayid][attr] = stay_rows[attr].unique().tolist()[0]
    curr_diag = df_diags.loc[df_diags['stay_id'] == stayid]['category'].unique().tolist()
    for dcol in diag_columns:
        data_dics[stayid][dcol] = 1 if dcol in curr_diag else 0

df_ce = df.loc[df['charttime'].notna()]  # rows with a charttime (point measurements)
df_ie = df.loc[df['amount'].notna()]  # rows with an amount (events with start/end times)

In [None]:
# build the dfs:
# it would take a little while...
# For each step of `step_size` hours, every stay with events at or after the step gets one
# discretized row: chart/lab values of each feature averaged over the step (values outside
# range_dic are ignored) and the vancomycin amount delivered during the step.
# Vancomycin inputs (as rates) and blood levels are also kept as continuous-time records.
# Rows are collected in lists and concatenated once per stay at the end: DataFrame.append
# was removed in pandas 2.0 (and appending row by row is quadratic).
invalid_stayids = set()  # stays with a vancomycin input whose amountuom is not 'dose'
pending_rows = dict()  # stayid -> discretized row dicts, in step order
pending_cont = dict()  # (stayid, 'vanco_input' | 'vanco_blood') -> record dicts, in order


def _add_vanco_blood(stayid, blood_dic):
    """Queue the vancomycin blood levels in ``blood_dic`` (time -> value), sorted by time."""
    records = pending_cont.setdefault((stayid, 'vanco_blood'), [])
    for ttime in sorted(blood_dic):
        records.append({'stay_id': stayid, 'traj_ind': 0, 'charttime': ttime, 'value': blood_dic[ttime]})


while True:
    print("current time step: ", curr_step)
    step_start = curr_step * step_interval
    step_end = (curr_step + 1) * step_interval
    ind1 = df_ce['charttime'] >= step_start
    ind2 = df_ie['endtime'] >= step_start
    ind3 = df_labs['charttime'] >= step_start
    # stop once no event of any kind reaches this step
    if not any(ind1) and not any(ind2) and not any(ind3):
        break
    valid_stay = set(df_ce.loc[ind1]['stay_id'].unique().tolist())
    valid_stay |= set(df_ie.loc[ind2]['stay_id'].unique().tolist())
    valid_stay |= set(df_labs.loc[ind3]['stay_id'].unique().tolist())
    # stays absent from the chartevents extract have no static attributes (data_dics)
    valid_stay = [stayid for stayid in valid_stay if stayid in data_dics]
    # restrict to events inside [step_start, step_end); input events only have to overlap it
    ind1 = ind1 & (df_ce['charttime'] < step_end) & df_ce['charttime'].notna()
    ind2 = ind2 & (df_ie['starttime'] < step_end) & df_ie['starttime'].notna() & df_ie['endtime'].notna()
    ind3 = ind3 & (df_labs['charttime'] < step_end) & df_labs['charttime'].notna()
    for stayid in tqdm.tqdm(valid_stay):
        curr_df_ce = df_ce.loc[ind1 & (df_ce['stay_id'] == stayid)]
        curr_df_ie = df_ie.loc[ind2 & (df_ie['stay_id'] == stayid)]
        curr_df_labs = df_labs.loc[ind3 & (df_labs['stay_id'] == stayid)]

        tmp_d = data_dics[stayid].copy()
        tmp_d['timestep'] = curr_step
        tmp_d['Weight_Kg'] = np.nan
        for key in columns:
            tmp_key_curr_df_ce = curr_df_ce.loc[curr_df_ce['itemid'].isin(d_rev_name2itemid[key])]
            tmp_key_curr_df_ie = curr_df_ie.loc[curr_df_ie['itemid'].isin(d_rev_name2itemid[key])]
            tmp_key_curr_df_labs = curr_df_labs.loc[curr_df_labs['itemid'].isin(d_rev_name2labitemid[key])]
            if key in range_dic:
                # drop implausible values (rows with a NaN valuenum are dropped as well)
                low, high = range_dic[key]
                ce_vals = tmp_key_curr_df_ce['valuenum']
                tmp_key_curr_df_ce = tmp_key_curr_df_ce.loc[(ce_vals >= low) & (ce_vals <= high)]
                lab_vals = tmp_key_curr_df_labs['valuenum']
                tmp_key_curr_df_labs = tmp_key_curr_df_labs.loc[(lab_vals >= low) & (lab_vals <= high)]

            assert len(tmp_key_curr_df_ce) == 0 or len(tmp_key_curr_df_ie) == 0
            vanco_blood_dic = dict()  # time -> value; only filled for 'Vancomycin_blood'
            if len(tmp_key_curr_df_ce) > 0:
                # chart values, plus lab values at times that have no chart value.
                # Both lists start empty so nothing leaks in from a previous key/stay.
                ce_valid = tmp_key_curr_df_ce.loc[tmp_key_curr_df_ce['valuenum'].notnull()]
                charts_times = ce_valid['charttime'].tolist()
                charts_vals = ce_valid['valuenum'].tolist()
                if key == 'Vancomycin_blood':
                    vanco_blood_dic.update(zip(charts_times, charts_vals))
                labs_valid = tmp_key_curr_df_labs.loc[tmp_key_curr_df_labs['valuenum'].notnull()]
                for ltime, lval in zip(labs_valid['charttime'].tolist(), labs_valid['valuenum'].tolist()):
                    if ltime not in charts_times:
                        charts_vals.append(lval)
                        if key == 'Vancomycin_blood':
                            vanco_blood_dic[ltime] = lval
                tmp_d[key] = np.mean(charts_vals) if charts_vals else np.nan
                if key == 'Vancomycin_blood':
                    _add_vanco_blood(stayid, vanco_blood_dic)
            elif len(tmp_key_curr_df_ie) > 0:
                val = 0
                w_low, w_high = range_dic['Weight_Kg']
                weights = tmp_key_curr_df_ie['patientweight']
                tmp_d['Weight_Kg'] = weights.loc[(weights >= w_low) & (weights <= w_high)].mean()
                for ind in tmp_key_curr_df_ie.index.tolist():
                    start = tmp_key_curr_df_ie.at[ind, 'starttime']
                    end = tmp_key_curr_df_ie.at[ind, 'endtime']
                    amount = tmp_key_curr_df_ie.at[ind, 'amount']
                    term1 = start >= step_start  # starts inside this step
                    term2 = end < step_end  # ends inside this step
                    if key == 'Vancomycin_input':
                        if tmp_key_curr_df_ie.at[ind, 'amountuom'] != 'dose':
                            invalid_stayids.add(stayid)
                        if end == start:
                            # instantaneous administration: spread over one time unit
                            rec_end, rate = end + 1, amount
                        else:
                            rec_end, rate = end, amount / (end - start)
                        pending_cont.setdefault((stayid, 'vanco_input'), []).append(
                            {'stay_id': stayid, 'traj_ind': 0, 'starttime': start, 'endtime': rec_end, 'rate': rate})
                    # amount delivered during this step, assuming a constant rate over [start, end]
                    if term1 and term2:
                        val += amount
                    elif term1:
                        val += amount * (step_end - start) / (end - start)
                    elif term2:
                        val += amount * (end - step_start) / (end - start)
                    else:
                        val += amount * step_interval / (end - start)
                tmp_d[key] = val
            elif len(tmp_key_curr_df_labs) > 0:
                labs_valid = tmp_key_curr_df_labs.loc[tmp_key_curr_df_labs['valuenum'].notnull()]
                labs_vals = labs_valid['valuenum'].tolist()
                if labs_vals:
                    tmp_d[key] = np.mean(labs_vals)
                    if key == 'Vancomycin_blood':
                        vanco_blood_dic.update(zip(labs_valid['charttime'].tolist(), labs_vals))
                        _add_vanco_blood(stayid, vanco_blood_dic)
                else:
                    tmp_d[key] = np.nan
            else:
                tmp_d[key] = np.nan

        pending_rows.setdefault(stayid, []).append(tmp_d)

    curr_step += 1

# materialize the collected rows (template columns first, then any new ones, like append did)
for stayid, rows in pending_rows.items():
    df_dic[stayid] = pd.concat([df_dic[stayid], pd.DataFrame(rows)], ignore_index=True)
for (stayid, kind), records in pending_cont.items():
    cont_df_dic[stayid][kind] = pd.concat([cont_df_dic[stayid][kind], pd.DataFrame(records)], ignore_index=True)

# remove patients with invalid units of vancomycin (a set, so each stay is removed once)
for stayid in invalid_stayids:
    cont_df_dic.pop(stayid, None)
    df_dic.pop(stayid, None)
# backup, to save time in case things go wrong:
with open('./tmp_BU.pkl', 'wb') as f:
    pickle.dump({'df_dic': df_dic, 'cont_df_dic': cont_df_dic}, f)

In [None]:
# load backup (uncomment and run):
# with open('./tmp_BU.pkl','rb') as f:
#     dd = pickle.load(f)
# df_dic = dd['df_dic']
# cont_df_dic = dd['cont_df_dic']

In [None]:
# process the irregular data:
# Every stay is split into trajectories at gaps of more than `vanco_input_null_step_split`
# steps without vancomycin input, and leading/trailing steps further than that from any
# input are cropped. Discretized rows of each trajectory are re-indexed from 0, and the
# continuous-time records falling in the same steps are shifted so the trajectory starts
# at time 0 and tagged with its trajectory index.


def _assign_cont_segment(cont, key_col, shift_cols, lo, hi, traj, assigned):
    """Assign the not-yet-assigned records with ``lo <= cont[key_col] < hi`` to ``traj``.

    ``lo`` is subtracted from every column in ``shift_cols`` and ``traj_ind`` is set to
    ``traj`` for the selected rows; ``cont`` is modified in place. Returns the updated
    boolean mask of assigned rows.
    """
    sel = ~assigned & (cont[key_col] >= lo) & (cont[key_col] < hi)
    for col in shift_cols:
        cont.loc[sel, col] = cont.loc[sel, col] - lo
    cont.loc[sel, 'traj_ind'] = traj
    return assigned | sel


new_df_dic = dict()
key_list = list(df_dic.keys())
patients_skipped = 0
for stayid in tqdm.tqdm(key_list):
    stay_df = df_dic[stayid]
    ts_ar = np.arange(len(stay_df))
    # positions of the steps with a vancomycin input (rows are stored in step order);
    # flatnonzero also copes with stays that have no rows at all
    valid_arr = np.flatnonzero(stay_df['Vancomycin_input'].notna().to_numpy())
    if len(valid_arr) == 0 or len(valid_arr) < min_vanco_input_dlen:
        patients_skipped += 1
        cont_df_dic.pop(stayid, None)
        continue

    new_df_dic[stayid] = []
    start_ind = 0
    end_ind = ts_ar[-1]
    if valid_arr[0] > vanco_input_null_step_split:
        start_ind = valid_arr[0] - vanco_input_null_step_split
    if valid_arr[-1] < ts_ar[-1] - vanco_input_null_step_split:
        end_ind = valid_arr[-1] + vanco_input_null_step_split

    # split in the middle of every long gap between consecutive inputs
    split_inds = []
    for i in range(1, len(valid_arr)):
        if valid_arr[i - 1] + vanco_input_null_step_split >= end_ind:
            break
        if valid_arr[i] > valid_arr[i - 1] + vanco_input_null_step_split:
            split_inds.append(valid_arr[i - 1] + np.ceil((valid_arr[i] - valid_arr[i - 1]) / 2.0))

    cont = cont_df_dic[stayid]
    vanco_input_cont = cont['vanco_input']
    vanco_blood_cont = cont['vanco_blood']
    input_assigned = pd.Series(False, index=vanco_input_cont.index)
    blood_assigned = pd.Series(False, index=vanco_blood_cont.index)
    # trajectory k covers the steps (bounds[k], bounds[k + 1]]
    bounds = [start_ind - 1] + split_inds + [end_ind]
    for traj, (prev_ind, ind) in enumerate(zip(bounds[:-1], bounds[1:])):
        # crop beginning and end:
        cropped_ind = (ts_ar > prev_ind) & (ts_ar <= ind)
        minidf = stay_df.loc[cropped_ind].copy(deep=True).reset_index(drop=True).sort_values(by=['timestep'])
        minidf['timestep'] = np.arange(len(minidf))
        minidf['traj_ind'] = traj * np.ones(len(minidf))
        new_df_dic[stayid].append(minidf)
        # The continuous records of the same steps lie in [(prev_ind + 1) * I, (ind + 1) * I).
        # The upper bound used to be ind * I, which left the last step's records unassigned.
        lo = (prev_ind + 1) * step_interval
        hi = (ind + 1) * step_interval
        input_assigned = _assign_cont_segment(vanco_input_cont, 'starttime', ['starttime', 'endtime'],
                                              lo, hi, traj, input_assigned)
        blood_assigned = _assign_cont_segment(vanco_blood_cont, 'charttime', ['charttime'],
                                              lo, hi, traj, blood_assigned)

    # Records outside every trajectory (cropped head/tail) are dropped; otherwise they would
    # silently stay in trajectory 0 with unshifted times.
    cont['vanco_input'] = vanco_input_cont.loc[input_assigned]
    cont['vanco_blood'] = vanco_blood_cont.loc[blood_assigned]
    del df_dic[stayid]

In [None]:
# reorganize the columns:
# Stack every trajectory of every stay (stay order, then trajectory order) into one table.
final_df = pd.concat([traj_df for stay_trajs in new_df_dic.values() for traj_df in stay_trajs],
                     ignore_index=True)

# Move 'traj_ind' directly after the first column (only when another column follows it).
_all_cols = list(final_df.columns)
_kept_cols = [col for col in _all_cols if col != 'traj_ind']
_first_kept_pos = next((pos for pos, col in enumerate(_all_cols) if col != 'traj_ind'), None)
if _first_kept_pos is not None and _first_kept_pos + 1 < len(_all_cols):
    new_columns = _kept_cols[:1] + ['traj_ind'] + _kept_cols[1:]
else:
    new_columns = _kept_cols
final_df = final_df[new_columns]


# Continuous-time records of all remaining stays.
final_vanco_input_df = pd.concat([stay_cont['vanco_input'] for stay_cont in cont_df_dic.values()],
                                 ignore_index=True)
final_vanco_blood_df = pd.concat([stay_cont['vanco_blood'] for stay_cont in cont_df_dic.values()],
                                 ignore_index=True)

In [None]:
# merge temp: convert the Fahrenheit readings to Celsius and use them where Temp_C is missing.
# Plain column assignment instead of chained `fillna(..., inplace=True)`, which does not
# modify final_df under pandas Copy-on-Write.
final_df["Temp_F"] = (5.0/9.0)*(final_df["Temp_F"] - 32)
final_df["Temp_C"] = final_df["Temp_C"].fillna(final_df["Temp_F"])
final_df = final_df.drop(columns="Temp_F")

# merge Ht: fall back to Ht_serum where Ht is missing
final_df["Ht"] = final_df["Ht"].fillna(final_df["Ht_serum"])
final_df = final_df.drop(columns="Ht_serum")

In [None]:
# impute weight:
# Within each trajectory (stay_id, traj_ind), in timestep order, a missing weight takes the
# most recent earlier value; leading missing values take the first observed value of the
# trajectory. That is a grouped forward fill followed by a grouped backward fill. This
# replaces a per-row loop that rescanned the whole table for every row.
_weight_cols = ['Weight_Kg', 'Daily_Weight', 'Adm_Weight_Kg', 'Adm_Weight_lb']
_ordered = final_df.sort_values(by=['stay_id', 'traj_ind', 'timestep'])
_traj_keys = [_ordered['stay_id'], _ordered['traj_ind']]
_filled = _ordered[_weight_cols].groupby(_traj_keys, dropna=False).ffill()
_filled = _filled.groupby(_traj_keys, dropna=False).bfill()
for _col in _weight_cols:
    final_df[_col] = _filled[_col]  # aligned on the (unique) index of final_df

In [None]:
# filter patients that recieved other anticoagulants:
# input_events = pd.read_csv(input_events_path)
# drug_itemids = [225906,225908,225148,229781,225147] + [225975,229597,230044]
# invalid_stays = input_events[input_events['itemid'].isin(drug_itemids)]['stay_id']
# invalid_stays = np.unique(invalid_stays.to_numpy())
# input_events = None
# final_vanco_input_df = final_vanco_input_df[~final_vanco_input_df['stay_id'].isin(invalid_stays)]
# final_vanco_blood_df = final_vanco_blood_df[~final_vanco_blood_df['stay_id'].isin(invalid_stays)]
# final_df = final_df[~final_df['stay_id'].isin(invalid_stays)]

# # make sure that there are no nan values for weight:
# w_not_nan = final_df['Weight_Kg'].notna()
# w_not_nan = w_not_nan | final_df['Daily_Weight'].notna()
# w_not_nan = w_not_nan | final_df['Adm_Weight_Kg'].notna()
# w_not_nan = w_not_nan | final_df['Adm_Weight_lb'].notna()
# final_df = final_df[w_not_nan]

In [None]:
# Keep only Weight_Kg: drop the rows that still have no Weight_Kg after the imputation
# above (a row is still missing it only if its whole trajectory never recorded one).
final_df = final_df.loc[final_df['Weight_Kg'].notna()]

# only 2 patients are problematic (others have fully observed Weight_Kg at this point)
_problematic_stays = [34303520, 37530120]
final_df = final_df.loc[~final_df['stay_id'].isin(_problematic_stays)]
final_df = final_df.drop(columns=["Daily_Weight", "Adm_Weight_lb", "Adm_Weight_Kg"])

In [None]:
# merge parallel dosing:
# Within each trajectory, overlapping inputs are merged into consecutive non-overlapping
# intervals (between all distinct start/end times). The rate of each interval is the sum
# of the rates of the inputs covering it, and 0 in gaps.
final_vanco_input_df = final_vanco_input_df.drop_duplicates()
_merged_parts = []
for stayid in tqdm.tqdm(final_vanco_input_df['stay_id'].unique()):
    stayid_idx = final_vanco_input_df['stay_id'] == stayid
    for trajid in final_vanco_input_df.loc[stayid_idx]['traj_ind'].unique():
        traj_idx = stayid_idx & (final_vanco_input_df['traj_ind'] == trajid)
        mini_df = final_vanco_input_df.loc[traj_idx].sort_values(by=['starttime', 'endtime'])
        mini_times = np.unique(mini_df[['starttime', 'endtime']].to_numpy(dtype=float).reshape(-1))
        starttimes = mini_times[:-1]
        endtimes = mini_times[1:]
        # explicit float accumulator: zeros_like on integer times would truncate the rates
        vanco_input_vals = np.zeros(len(starttimes), dtype=float)
        for line_s_time, line_e_time, line_rate in zip(mini_df['starttime'].to_numpy(dtype=float),
                                                       mini_df['endtime'].to_numpy(dtype=float),
                                                       mini_df['rate'].to_numpy(dtype=float)):
            locs = (starttimes < line_e_time) & (endtimes > line_s_time)
            vanco_input_vals[locs] = vanco_input_vals[locs] + line_rate
        _merged_parts.append(pd.DataFrame({'stay_id': stayid, 'traj_ind': trajid,
                                           'starttime': starttimes, 'endtime': endtimes,
                                           'rate': vanco_input_vals},
                                          columns=final_vanco_input_df.columns))

# One concat (DataFrame.append was removed in pandas 2.0). ignore_index: each part is indexed
# from 0, and duplicate labels make later label-based drops remove rows of other stays.
if _merged_parts:
    merged_final_vanco_input_df = pd.concat(_merged_parts, ignore_index=True)
else:
    merged_final_vanco_input_df = pd.DataFrame(columns=final_vanco_input_df.columns)

final_vanco_input_df = merged_final_vanco_input_df

In [None]:
# categorize values, normalize the side-information (SI), change times to minutes

# irregular data: times are divided by 60 (seconds -> minutes)
final_vanco_input_df['starttime'] = (1.0 / 60.0) * final_vanco_input_df['starttime']
final_vanco_input_df['endtime'] = (1.0 / 60.0) * final_vanco_input_df['endtime']
final_vanco_blood_df['charttime'] = (1.0 / 60.0) * final_vanco_blood_df['charttime']
final_vanco_blood_df.rename(columns={'charttime': 'time'}, inplace=True)


# discretized data: step index -> minutes since the trajectory start
stay_ids = final_df['stay_id']
traj_inds = final_df['traj_ind']
final_df['timestep'] = step_size * 60 * final_df['timestep']
time = final_df['timestep']

# categorical attributes -> integer codes (missing -> -1)
final_df.race = pd.Categorical(final_df.race)
final_df.gender = pd.Categorical(final_df.gender)
final_df.admission_type = pd.Categorical(final_df.admission_type)
final_df["race"] = final_df.race.cat.codes
final_df["gender"] = final_df.gender.cat.codes
final_df["admission_type"] = final_df.admission_type.cat.codes
final_df.rename(columns={'timestep': 'time'}, inplace=True)

# relevant columns
dcols = ['time','stay_id','traj_ind','gender','age','race','admission_type','Weight_Kg','Renal','Infectious','Pulmonary','CVS', \
                 'Hematological','Met','Smoking','GI','Endocrine','Psych','Obes','GCS_Motor','Creatinine','Dia_BP', \
                 'Total_Ca','PT','Ht','GCS_Eye','Total_Bili','Potassium','RR','Troponin','Sys_BP','Urea_Nitrogen', \
                 'ALT','CO2','AST','Sodium','PH','WBC','Platelet_Count','Temp_C','INR','HR','GCS_Verbal','Hb', \
                 'non_ionized_Ca']
final_df = final_df[dcols]

# TODO: divide the vanco_input rate by the patient's weight. This is not implemented: the
# previous per-trajectory loop had its body commented out and did nothing. 'org_rate' is an
# unmodified copy of 'rate' (the dataloader cell reads both).
final_vanco_input_df['org_rate'] = final_vanco_input_df['rate']

# normalize SI (min-max); the first three columns (time, stay_id, traj_ind) are restored below
min_vals = final_df.min()
max_vals = final_df.max()
norm_vals_dic = {'min':min_vals.to_numpy()[3:],'max':max_vals.to_numpy()[3:]}

# save normalization values, for inverse transformation;
with open('./Vancomycin_Dosing_norm_vals.pkl','wb') as f:
    pickle.dump(norm_vals_dic,f)
# Constant columns would divide 0 by 0, giving NaN that is later indistinguishable from a
# missing value; map them to 0 instead (the inverse transform still yields min).
value_range = max_vals - min_vals
value_range[value_range == 0] = 1
final_df = (final_df - min_vals) / value_range
final_df['stay_id'] = stay_ids
final_df['traj_ind'] = traj_inds
final_df['time'] = time


In [None]:
# drop patients according to vanco_blood times:
# A trajectory is removed from all three tables when its last vancomycin blood level is later
# than max_trj_time (minutes), or when it has fewer than min_vanco_blood_smp levels.
vanco_blood_drop_idx = final_vanco_blood_df['stay_id'].isna() # all false, no nans
vanco_input_drop_idx = final_vanco_input_df['stay_id'].isna()
drop_idx = final_df['stay_id'].isna()

for stayid in final_vanco_blood_df['stay_id'].unique():
    vanco_blood_stayid_idx = final_vanco_blood_df['stay_id'] == stayid
    vanco_input_stayid_idx = final_vanco_input_df['stay_id'] == stayid
    stayid_idx = final_df['stay_id'] == stayid
    for trajid in final_vanco_blood_df.loc[vanco_blood_stayid_idx]['traj_ind'].unique():
        vanco_blood_traj_idx = vanco_blood_stayid_idx & (final_vanco_blood_df['traj_ind'] == trajid)
        vanco_input_traj_idx = vanco_input_stayid_idx & (final_vanco_input_df['traj_ind'] == trajid)
        traj_idx = stayid_idx & (final_df['traj_ind'] == trajid)
        blood_times = final_vanco_blood_df.loc[vanco_blood_traj_idx, 'time']
        if (blood_times.max() > max_trj_time) or (len(blood_times) < min_vanco_blood_smp):
            vanco_blood_drop_idx = (vanco_blood_drop_idx | vanco_blood_traj_idx)
            vanco_input_drop_idx = (vanco_input_drop_idx | vanco_input_traj_idx)
            drop_idx = (drop_idx | traj_idx)

# Filter positionally with the masks instead of dropping by index label: a table with
# duplicate index labels would also lose rows of trajectories that are kept.
final_df = final_df.loc[~drop_idx.to_numpy()]
final_vanco_blood_df = final_vanco_blood_df.loc[~vanco_blood_drop_idx.to_numpy()]
final_vanco_input_df = final_vanco_input_df.loc[~vanco_input_drop_idx.to_numpy()]
        


In [None]:
# save final csvs (file name prefix, table), e.g. ./Vancomycin_Dosing_SI_4hr_dataset.csv:
_csv_outputs = [
    ('./Vancomycin_Dosing_SI_', final_df),
    ('./Vancomycin_Dosing_vanco_blood_', final_vanco_blood_df),
    ('./Vancomycin_Dosing_vanco_input_', final_vanco_input_df),
]
for _prefix, _table in _csv_outputs:
    _table.to_csv(_prefix + str(step_size) + 'hr_dataset.csv', index=False)

In [None]:
# build the pickle files for the NESDE DataLoader:
import torch
import random

out_path = './Vancomycin_Dosing_train.pkl'
test_out_path = './Vancomycin_Dosing_test.pkl'
_csv_suffix = str(step_size) + 'hr_dataset.csv'
si_path = './Vancomycin_Dosing_SI_' + _csv_suffix
vanco_input_path = './Vancomycin_Dosing_vanco_input_' + _csv_suffix
vanco_blood_path = './Vancomycin_Dosing_vanco_blood_' + _csv_suffix

# reload the CSVs written above (SI, then vanco_input, then vanco_blood)
SI, vanco_input, vanco_blood = [pd.read_csv(_path, encoding='unicode_escape')
                                for _path in (si_path, vanco_input_path, vanco_blood_path)]

# NOTE(review): the loop that follows calls `tqdm(...)`, but this file does `import tqdm`
# (the module), so that call raises TypeError; it should be `tqdm.tqdm(...)`.
data = []
# only running on vanco_blood indices because without vanco_blood impossible to train/eval
for stayid in tqdm(vanco_blood['stay_id'].unique()):
    SI_stay_idx = SI['stay_id'] == stayid
    vanco_input_stay_idx = vanco_input['stay_id'] == stayid
    vanco_blood_stay_idx = vanco_blood['stay_id'] == stayid
    for trajid in vanco_blood[vanco_blood_stay_idx]['traj_ind'].unique():
        SI_traj_idx = SI_stay_idx & (SI['traj_ind'] == trajid)
        vanco_input_traj_idx = vanco_input_stay_idx & (vanco_input['traj_ind'] == trajid)
        vanco_blood_traj_idx = vanco_blood_stay_idx & (vanco_blood['traj_ind'] == trajid)
        SI_vals = SI[SI_traj_idx].drop(columns=['stay_id','traj_ind']).sort_values(by=['time']).to_numpy()
        vanco_input_vals = vanco_input[vanco_input_traj_idx].drop(columns=['stay_id','traj_ind','org_rate']).sort_values(by=['starttime','endtime']).to_numpy()
        vanco_input_org_vals = vanco_input[vanco_input_traj_idx].drop(columns=['stay_id','traj_ind','rate']).sort_values(by=['starttime','endtime']).to_numpy()
        vanco_blood_dat = vanco_blood[vanco_blood_traj_idx].drop(columns=['stay_id','traj_ind']).sort_values(by=['time']).to_numpy()
        vanco_blood_times = torch.Tensor(vanco_blood_dat[:,0]).view(-1,1)
        vanco_blood_vals = torch.Tensor(vanco_blood_dat[:,1]).view(-1,1)
        if SI_vals.shape[0] == 0:
            # skip
            continue
        SI_vals = torch.nan_to_num(torch.Tensor(SI_vals),nan=-1.0)
        SI_vals[:,1:] = SI_vals[:,1:] + 1.0
        data.append({'SI':SI_vals,'times':vanco_blood_times,'obs':vanco_blood_vals,'mask':torch.ones_like(vanco_blood_vals),'U':torch.Tensor(vanco_input_vals),'U_org':torch.Tensor(vanco_input_org_vals)})


        
# save datasets:
test_size = int(test_ratio * len(data))

random.shuffle(data)


with open(test_out_path,'wb') as f:
    pickle.dump({'data':data[:test_size]},f)

with open(out_path,'wb') as f:
    pickle.dump({'data':data[test_size:]},f)