In [49]:
import pandas as pd
import pandas.io.sql as sqlio
import psycopg2
import getpass
import csv

In [50]:
# Set up postgres connection
conn = psycopg2.connect(
    database="mimic", 
    user='postgres',
    password=getpass.getpass("Enter postgres password"), 
    host="127.0.0.1", 
    port="5432",
    options=f'-c search_path=mimiciii')

## 0. Config

In [51]:
out_folder = "processed_mimic_data"

## 1. Identify all 'heart disease' patients

In [52]:
# Get all the patient ids with heart disease
"""
lists of ICD 9 codes (related to heart diseases):

393-398  Chronic Rheumatic Heart Disease
410-414  Ischemic Heart Disease
420-429  Other Forms Of Heart Disease
"""

heart_disease_subject_ids = pd.read_sql(
    """
    SELECT DISTINCT(subject_id)
    FROM diagnoses_icd
    WHERE (
        icd9_code LIKE '393%' OR
        icd9_code LIKE '394%' OR
        icd9_code LIKE '395%' OR
        icd9_code LIKE '396%' OR
        icd9_code LIKE '397%' OR
        icd9_code LIKE '398%' OR
        icd9_code LIKE '410%' OR
        icd9_code LIKE '411%' OR
        icd9_code LIKE '412%' OR
        icd9_code LIKE '413%' OR
        icd9_code LIKE '414%' OR
        icd9_code LIKE '420%' OR
        icd9_code LIKE '421%' OR
        icd9_code LIKE '422%' OR
        icd9_code LIKE '423%' OR
        icd9_code LIKE '424%' OR
        icd9_code LIKE '425%' OR
        icd9_code LIKE '426%' OR
        icd9_code LIKE '427%' OR
        icd9_code LIKE '428%' OR
        icd9_code LIKE '429%' 
    );
    """, conn)

In [53]:
heart_disease_subject_ids.shape

(24138, 1)

In [54]:
# Convert to a set for filtering
heart_disease_id_set = set(heart_disease_subject_ids['subject_id'])

## 2. Retrieve all admission ids from last 12 months since each patient's last admission

In [55]:
# Get the subtraction between all admission times and the last admission by each patient; in year unit
admissions_diff = pd.read_sql(
    """
    SELECT a.subject_id, a.hadm_id,
    ROUND((cast(a.admittime as date)-cast(last_admission_time.max_admittime as date))/365.242,2) AS diff_from_last 
    FROM admissions AS a
    LEFT JOIN
        (SELECT subject_id,  MAX(admittime) AS max_admittime
        FROM admissions
        GROUP BY subject_id
        ) AS last_admission_time
    ON a.subject_id=last_admission_time.subject_id;
    """, conn)

In [56]:
admissions_diff.head()

Unnamed: 0,subject_id,hadm_id,diff_from_last
0,22,165315,0.0
1,23,152223,-4.12
2,23,124321,0.0
3,24,161859,0.0
4,25,129635,0.0


In [57]:
# Get all the admissions from last 12 months since each patient's last admission ('diff_from_last' >= -1 <year>)
admissions_last_year = admissions_diff[admissions_diff['diff_from_last'] >= -1]

In [58]:
admissions_last_year.head()

Unnamed: 0,subject_id,hadm_id,diff_from_last
0,22,165315,0.0
2,23,124321,0.0
3,24,161859,0.0
4,25,129635,0.0
5,26,197661,0.0


In [59]:
# Convert all hadm_id's into a set
hadm_id_set = set(admissions_last_year['hadm_id'])

## 3. Get all drug events and process them by remove 'stopword' events

In [60]:
# Get all the drug events from table 'inputevents_mv'
drug_events = pd.read_sql(
    """
    SELECT im.subject_id, im.hadm_id, im.starttime, im.itemid, di.abbreviation
    FROM inputevents_mv as im
    JOIN d_items as di
    ON im.itemid=di.itemid;
    """, conn)

In [61]:
drug_events.shape

(3618991, 5)

In [62]:
drug_events.head()

Unnamed: 0,subject_id,hadm_id,starttime,itemid,abbreviation
0,27063,139787,2133-02-05 06:29:00,225166,Potassium Chloride - KCL
1,27063,139787,2133-02-05 05:34:00,225944,Sterile Water
2,27063,139787,2133-02-05 05:34:00,225166,Potassium Chloride - KCL
3,27063,139787,2133-02-03 12:00:00,225893,Piperacillin/Tazobactam (Zosyn)
4,27063,139787,2133-02-03 12:00:00,220949,Dextrose 5%


### 3.1 Filter drug events by the admission id set and heart disease patient ids

In [63]:
# Filter all drug events that are in the admission id set (admissions from the last 12 month of each patient's last admission)
drug_events_last_year = drug_events[drug_events['hadm_id'].isin(hadm_id_set)]

In [64]:
# Filter by heart disease patient ids
drug_events_filtered = drug_events_last_year[drug_events_last_year['subject_id'].isin(heart_disease_id_set)]

In [65]:
drug_events_filtered.head()

Unnamed: 0,subject_id,hadm_id,starttime,itemid,abbreviation
0,27063,139787,2133-02-05 06:29:00,225166,Potassium Chloride - KCL
1,27063,139787,2133-02-05 05:34:00,225944,Sterile Water
2,27063,139787,2133-02-05 05:34:00,225166,Potassium Chloride - KCL
3,27063,139787,2133-02-03 12:00:00,225893,Piperacillin/Tazobactam (Zosyn)
4,27063,139787,2133-02-03 12:00:00,220949,Dextrose 5%


### 3.2 Drop the duplicated items (which indicates different doses in the same session)

In [66]:
# # Uncomment to view duplicated 'itemid's from the same session
# drug_events_filtered.groupby(by='subject_id').apply(lambda x: x.sort_values('itemid'))

In [67]:
# Drop the duplicates (due to different doses) in the same input session
drug_events_filtered2 = drug_events_filtered.drop_duplicates()

In [68]:
# # Sanity check, uncomment to view that there are no more duplicates from the same input session
# drug_events_filtered2.groupby(by='subject_id').apply(lambda x: x.sort_values('itemid'))

### 3.3 Remove 'stopword' events (too frequent counts or too rare)

In [69]:
# Count itemid values
itemid_counts = drug_events_filtered2['itemid'].value_counts()

itemid_counts2 = itemid_counts.reset_index()
itemid_counts2 = itemid_counts2.rename(columns={"index": "itemid", "itemid":"counts"})

In [70]:
itemid_counts2.head()

Unnamed: 0,itemid,counts
0,225158,276499
1,220949,230833
2,225943,99743
3,226452,85691
4,223258,75492


In [71]:
# Add 'proportion' of each itemid to the table 
itemid_counts2['proportion'] = itemid_counts2['counts']/sum(itemid_counts2['counts'])

In [72]:
itemid_counts2.head(10)

Unnamed: 0,itemid,counts,proportion
0,225158,276499,0.142412
1,220949,230833,0.118892
2,225943,99743,0.051373
3,226452,85691,0.044136
4,223258,75492,0.038883
5,222168,66435,0.034218
6,225799,60339,0.031078
7,221749,58045,0.029896
8,221906,51334,0.02644
9,221744,45548,0.02346


In [73]:
# Remove these proportions that are larger than 4.1% or the count is less than 5
itemid_counts3 = itemid_counts2[(itemid_counts2['proportion'] <= 0.041) & (itemid_counts2['counts'] >= 5)]

In [74]:
# Convert to a set - which contains item ids that are neither too frequent nor too rare
itemid_set = set(itemid_counts3['itemid'])

In [75]:
drug_events_filtered3 = drug_events_filtered2[drug_events_filtered2['itemid'].isin(itemid_set)]

In [76]:
drug_events_filtered3.head()

Unnamed: 0,subject_id,hadm_id,starttime,itemid,abbreviation
0,27063,139787,2133-02-05 06:29:00,225166,Potassium Chloride - KCL
1,27063,139787,2133-02-05 05:34:00,225944,Sterile Water
2,27063,139787,2133-02-05 05:34:00,225166,Potassium Chloride - KCL
3,27063,139787,2133-02-03 12:00:00,225893,Piperacillin/Tazobactam (Zosyn)
7,27063,139787,2133-02-05 09:43:00,225944,Sterile Water


In [77]:
drug_events_filtered3.shape

(1248736, 5)

### 3.4 Group drug events by each patient, and sort by session time and then 'itemid'

In [78]:
# Group by 'subject_id', and sort by 'starttime' and then 'itemid'
drug_events_only = drug_events_filtered3.groupby(by='subject_id').apply(lambda x: x.sort_values('starttime'))['itemid'].reset_index(level=[1], drop=True)

In [79]:
drug_events_only.head()

subject_id
23    222051
23    226364
23    222051
23    222051
23    225798
Name: itemid, dtype: int64

In [80]:
# Convert to a sequnce of drug events for each patient
drug_events_by_patient = drug_events_only.groupby(by='subject_id').apply(list)

drug_events_by_patient2 = drug_events_by_patient.reset_index()

In [81]:
drug_events_by_patient2.head()

Unnamed: 0,subject_id,itemid
0,23,"[222051, 226364, 222051, 222051, 225798, 22587..."
1,34,"[226361, 225942, 221668, 220970]"
2,36,"[221833, 225823, 225152, 225823, 221794, 22587..."
3,85,"[226361, 225973, 221468, 225974, 225851, 22585..."
4,107,"[225910, 225168, 225168, 225942, 222168, 22174..."


In [82]:
drug_events_by_patient2.shape

(10718, 2)

### 3.5 Filter by event length (resulted sequence length is between 3 to 50)

In [83]:
# Add 'count' of itemid to the table
drug_events_by_patient2['count'] = [len(events) for events in drug_events_by_patient2['itemid']]

In [84]:
drug_events_by_patient2['count'].describe()

count    10718.000000
mean       116.508304
std        220.813025
min          1.000000
25%         16.000000
50%         47.000000
75%        109.000000
max       4011.000000
Name: count, dtype: float64

In [85]:
# Filter by sequence length
drug_events_by_patient3 = drug_events_by_patient2[drug_events_by_patient2['count'].apply(lambda x: True if x >=3 and x <=50 else False)]

In [86]:
drug_events_by_patient3.shape

(5325, 3)

## 4. Get all procedure codes

In [87]:
procedure_codes = pd.read_sql(
    """
    SELECT a.admittime, procedures.* 
    FROM admissions AS a
    RIGHT JOIN
        (SELECT pi.subject_id, pi.hadm_id, pi.seq_num, pi.icd9_code, dip.short_title
        FROM procedures_icd AS pi
        JOIN d_icd_procedures AS dip
        ON pi.icd9_code=dip.icd9_code) AS procedures
    ON a.hadm_id=procedures.hadm_id;
    """, conn)

In [88]:
procedure_codes.head()

Unnamed: 0,admittime,subject_id,hadm_id,seq_num,icd9_code,short_title
0,2143-07-23 07:15:00,62641,154460,3,3404,Insert intercostal cath
1,2183-06-05 21:02:00,2592,130856,1,9671,Cont inv mec ven <96 hrs
2,2183-06-05 21:02:00,2592,130856,2,3893,Venous cath NEC
3,2187-06-08 02:24:00,55357,119355,1,9672,Cont inv mec ven 96+ hrs
4,2187-06-08 02:24:00,55357,119355,2,331,Spinal tap


### 4.1 Filter by the admission id set and heart disease patient ids

In [89]:
# Filter by the admission id set (admissions from the last 12 month of each patient's last admission)
procedure_codes_last_year = procedure_codes[procedure_codes['hadm_id'].isin(hadm_id_set)]

In [90]:
# Filter by heart disease patient ids
procedure_codes_filtered = procedure_codes_last_year[procedure_codes_last_year['subject_id'].isin(heart_disease_id_set)]

In [91]:
procedure_codes_filtered.head()

Unnamed: 0,admittime,subject_id,hadm_id,seq_num,icd9_code,short_title
79,2120-07-03 18:46:00,16052,137667,7,3893,Venous cath NEC
80,2120-07-03 18:46:00,16052,137667,8,3893,Venous cath NEC
81,2120-07-03 18:46:00,16052,137667,9,9923,Inject steroid
82,2103-08-08 19:34:00,7221,179572,1,3950,Angio oth non-coronary
83,2103-08-08 19:34:00,7221,179572,2,9910,Inject/inf thrombo agent


In [92]:
procedure_codes_filtered.shape

(135959, 6)

In [93]:
procedure_codes_filtered2 = procedure_codes_filtered.drop(['short_title', 'hadm_id'], axis=1)

### 4.2 Group by subject_id and sort by admittime and seq_num

In [94]:
# group by subject_id and sort by admittime (first) and seq_num (second)
procedure_codes_filtered3 = procedure_codes_filtered2.groupby(by='subject_id').apply(lambda x: x.sort_values(['admittime', 'seq_num']))

In [95]:
procedure_codes_filtered3.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,admittime,subject_id,seq_num,icd9_code
subject_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
3,44088,2101-10-20 19:08:00,3,1,9604
3,44089,2101-10-20 19:08:00,3,2,9962
3,44090,2101-10-20 19:08:00,3,3,8964
3,44091,2101-10-20 19:08:00,3,4,9672
3,44092,2101-10-20 19:08:00,3,5,3893


In [96]:
procedure_codes_filtered4 = procedure_codes_filtered3.reset_index(level=[0,1], drop=True)

In [97]:
procedure_codes_filtered4.head()

Unnamed: 0,admittime,subject_id,seq_num,icd9_code
0,2101-10-20 19:08:00,3,1,9604
1,2101-10-20 19:08:00,3,2,9962
2,2101-10-20 19:08:00,3,3,8964
3,2101-10-20 19:08:00,3,4,9672
4,2101-10-20 19:08:00,3,5,3893


In [98]:
# Group all the sequential events by patients
procedures_by_patient = procedure_codes_filtered4.groupby(by='subject_id', axis=0)['icd9_code'].apply(list)

procedures_by_patient2 = procedures_by_patient.reset_index()

In [99]:
procedures_by_patient2.head()

Unnamed: 0,subject_id,icd9_code
0,3,"[9604, 9962, 8964, 9672, 3893, 966]"
1,9,"[9672, 9604]"
2,12,"[5137, 5212, 5459, 5351, 9915, 3893, 9960, 541..."
3,13,"[3612, 3615, 3961, 3761, 8872]"
4,17,"[3571, 3961, 8872, 3731, 8872, 3893]"


### 4.3 Check the sequence length of procedures for each patient

In [100]:
# Add 'count' of itemid to the table
procedures_by_patient2['count'] = [len(codes) for codes in procedures_by_patient2['icd9_code']]

In [101]:
procedures_by_patient2['count'].describe()

count    22293.000000
mean         6.098731
std          5.005845
min          1.000000
25%          3.000000
50%          5.000000
75%          8.000000
max         65.000000
Name: count, dtype: float64

## 5. Merge drug events and procedures

In [102]:
# Merge drug events (filtered by event length between 3 and 50) and procedures
drug_events_procedures_merged = pd.merge(drug_events_by_patient3, procedures_by_patient2, how='inner', on='subject_id')

In [103]:
drug_events_procedures_merged.shape

(4439, 5)

In [104]:
drug_events_procedures_merged.head()

Unnamed: 0,subject_id,itemid,count_x,icd9_code,count_y
0,23,"[222051, 226364, 222051, 222051, 225798, 22587...",18,[0151],1
1,34,"[226361, 225942, 221668, 220970]",4,"[3772, 3783, 8945]",3
2,85,"[226361, 225973, 221468, 225974, 225851, 22585...",22,"[3142, 3897]",2
3,107,"[225910, 225168, 225168, 225942, 222168, 22174...",27,"[3995, 4443, 4513, 3995]",4
4,111,"[221749, 222315, 221906, 227533, 220952, 22231...",21,"[9671, 9604, 3893, 3891]",4


In [105]:
drug_events_procedures_merged['total_count'] = drug_events_procedures_merged['count_x'] + drug_events_procedures_merged['count_y']

In [106]:
drug_events_procedures_merged['total_count'].describe()

count    4439.000000
mean       26.663212
std        14.715137
min         4.000000
25%        14.000000
50%        24.000000
75%        39.000000
max        74.000000
Name: total_count, dtype: float64

In [107]:
# Remove NA values
drug_events_procedures_merged2 = drug_events_procedures_merged[drug_events_procedures_merged['icd9_code'].notna()]

# Remove columns of counts
drug_events_procedures_merged3 = drug_events_procedures_merged2.drop(['count_x', 'count_y', 'total_count'], axis=1)

# Rename columns
drug_events_procedures_merged4 = drug_events_procedures_merged3.rename(columns={"icd9_code": "procedure_codes", "itemid":"drug_events"})

In [108]:
drug_events_procedures_merged4.head()

Unnamed: 0,subject_id,drug_events,procedure_codes
0,23,"[222051, 226364, 222051, 222051, 225798, 22587...",[0151]
1,34,"[226361, 225942, 221668, 220970]","[3772, 3783, 8945]"
2,85,"[226361, 225973, 221468, 225974, 225851, 22585...","[3142, 3897]"
3,107,"[225910, 225168, 225168, 225942, 222168, 22174...","[3995, 4443, 4513, 3995]"
4,111,"[221749, 222315, 221906, 227533, 220952, 22231...","[9671, 9604, 3893, 3891]"


In [109]:
drug_events_procedures_merged4.shape

(4439, 3)

## 6.1 Add survival flag (1: survival, 0: death)

In [110]:
#expire_flag: 1 indicates death in the hospital, and 0 indicates survival to hospital discharge.
survival_subject_ids = pd.read_sql(
    """
    SELECT subject_id FROM patients
    WHERE expire_flag=0;
    """, conn)

In [111]:
# Convert to a set of survival ids
survival_id_set = set(survival_subject_ids['subject_id'])

In [112]:
drug_events_procedures_merged4['survival'] = [1 if idx in survival_id_set else 0 for idx in drug_events_procedures_merged4['subject_id']]

In [113]:
drug_events_procedures_merged4.head()

Unnamed: 0,subject_id,drug_events,procedure_codes,survival
0,23,"[222051, 226364, 222051, 222051, 225798, 22587...",[0151],1
1,34,"[226361, 225942, 221668, 220970]","[3772, 3783, 8945]",0
2,85,"[226361, 225973, 221468, 225974, 225851, 22585...","[3142, 3897]",0
3,107,"[225910, 225168, 225168, 225942, 222168, 22174...","[3995, 4443, 4513, 3995]",1
4,111,"[221749, 222315, 221906, 227533, 220952, 22231...","[9671, 9604, 3893, 3891]",0


In [114]:
drug_events_procedures_merged4['survival'].value_counts()

1    3018
0    1421
Name: survival, dtype: int64

## 6.2 Add static patient features.

In [115]:
patients = pd.read_sql(
    """
    SELECT subject_id, gender,dob FROM patients;
    """, conn)

## 7. Export as the input format for DRG framework

In [116]:
final_merged = drug_events_procedures_merged4.copy().merge(patients, on="subject_id", how='left')

In [117]:
drug_events_procedures_merged4.shape

(4439, 4)

In [118]:
final_merged.shape

(4439, 6)

In [119]:
# Split into positive/negative data
neg_data = final_merged[final_merged['survival'] == 0].drop(columns=['survival'])
pos_data = final_merged[final_merged['survival'] == 1].drop(columns=['survival'])

### 7.1 Export negative samples

In [120]:
neg_data2 = neg_data.copy()
neg_data2.drop(['drug_events', 'procedure_codes'], axis=1)

neg_data2['drug_events'] = neg_data['drug_events'].apply(lambda x: ' '.join(str(i) for i in x))
neg_data2['procedure_codes'] = neg_data['procedure_codes'].apply(lambda x: ' '.join(x))

In [121]:
neg_data

Unnamed: 0,subject_id,drug_events,procedure_codes,gender,dob
1,34,"[226361, 225942, 221668, 220970]","[3772, 3783, 8945]",M,1886-07-18
2,85,"[226361, 225973, 221468, 225974, 225851, 22585...","[3142, 3897]",M,2090-09-18
4,111,"[221749, 222315, 221906, 227533, 220952, 22231...","[9671, 9604, 3893, 3891]",F,2075-07-16
8,222,"[226361, 220995, 225168, 225168, 226361, 22636...","[3239, 403, 4523, 3893]",F,2073-07-25
9,236,"[226363, 225159, 221794, 225154, 222011, 22179...","[0066, 3722, 8856, 3606, 0046, 0040, 4443, 332...",M,2081-12-05
...,...,...,...,...,...
4428,99814,"[225828, 222168, 221749, 225168, 223258, 22205...","[3812, 0040, 9904, 9604, 9671]",F,2055-07-17
4430,99847,"[225154, 221794, 225154, 221833, 225154, 22179...","[3324, 9390]",F,2115-08-27
4433,99935,"[222168, 221828, 225798, 225166, 226364, 22636...",[0131],M,2064-05-19
4434,99944,"[226361, 225975, 221794, 222056, 221794, 22201...","[3723, 8856]",F,2075-08-25


In [122]:
# Split into train/validation; 200 validation samples and the rest are train
validation_neg = neg_data2.sample(n=200, random_state=3)
train_neg = neg_data2.drop(validation_neg.index)

In [123]:
# Write as txt files into '../mimic_data/'
train_neg.to_csv(path_or_buf=f'../{out_folder}/train_neg.txt', index=False)
validation_neg.to_csv(path_or_buf=f'../{out_folder}/validation_neg.txt', index=False)

### 7.2 Export positive samples

In [124]:
pos_data2 = pos_data.copy()
pos_data2.drop(['drug_events', 'procedure_codes'], axis=1)

pos_data2['drug_events'] = pos_data['drug_events'].apply(lambda x: ' '.join(str(i) for i in x))
pos_data2['procedure_codes'] = pos_data['procedure_codes'].apply(lambda x: ' '.join(x))

In [125]:
pos_data2.head()

Unnamed: 0,subject_id,drug_events,procedure_codes,gender,dob
0,23,222051 226364 222051 222051 225798 225879 2232...,0151,M,2082-07-17
3,107,225910 225168 225168 225942 222168 221744 2217...,3995 4443 4513 3995,M,2052-04-02
5,154,225151 225152 226361 225159 225154 225152 2251...,8856,M,2073-07-26
6,165,225975 221794 225845 225855,9390,M,2084-04-09
7,209,225152 221828 221828 225152 225975 221828 2232...,3995,M,2054-01-13


In [126]:
# Split into train/validation; 200 validation samples and the rest are train
validation_pos = pos_data2.sample(n=200, random_state=3)
train_pos = pos_data2.drop(validation_pos.index)

In [127]:
# Write as txt files into '../mimic_data/'
train_pos.to_csv(path_or_buf=f'../{out_folder}/train_pos.txt', index=False)
validation_pos.to_csv(path_or_buf=f'../{out_folder}/validation_pos.txt', index=False)

## 7.3 export all samples

In [128]:
train_all = pd.concat([train_pos,train_neg])
train_all.to_csv(path_or_buf=f'../{out_folder}/train_all.txt', index=False)