In [1]:
import os
import pickle as pickle

import pandas as pd
import pandas.io.sql as sqlio
import psycopg2
import getpass
import csv

In [2]:
# Connect to the local PostgreSQL server that hosts the `mimic` database.
# The password is requested interactively through getpass, so it is never written into the notebook.
conn = psycopg2.connect(
    database="mimic",
    user="postgres",
    password=getpass.getpass("Enter postgres password"),
    host="127.0.0.1",
    port="5432",
    # sets the session search_path so the queries below can use unqualified table names
    options="-c search_path=mimiciii",
)

In [3]:
# Choose the disease type of the constructed data; it selects the cohort query,
# the item-frequency threshold, the event-count cut-off and the validation size below.
DISEASE_TYPE = 'CARDIOVASCULAR' # `CARDIOVASCULAR` or `SEPSIS` or `ARDS` (acute respiratory distress syndrome)

## 1. Identify all filtered patient ids, using first 3-digit category ICD-9 codes

In [4]:
"""
lists of ICD 9 codes (related to heart diseases):

393-398  Chronic Rheumatic Heart Disease
410-414  Ischemic Heart Disease
420-429  Other Forms Of Heart Disease
"""

# Cohort query: one row per patient (DISTINCT subject_id) with at least one diagnosis
# whose ICD-9 code starts with one of the 3-digit categories listed above.
heart_disease_query = """
    SELECT DISTINCT(subject_id)
    FROM diagnoses_icd
    WHERE (
        icd9_code LIKE '393%' OR
        icd9_code LIKE '394%' OR
        icd9_code LIKE '395%' OR
        icd9_code LIKE '396%' OR
        icd9_code LIKE '397%' OR
        icd9_code LIKE '398%' OR
        icd9_code LIKE '410%' OR
        icd9_code LIKE '411%' OR
        icd9_code LIKE '412%' OR
        icd9_code LIKE '413%' OR
        icd9_code LIKE '414%' OR
        icd9_code LIKE '420%' OR
        icd9_code LIKE '421%' OR
        icd9_code LIKE '422%' OR
        icd9_code LIKE '423%' OR
        icd9_code LIKE '424%' OR
        icd9_code LIKE '425%' OR
        icd9_code LIKE '426%' OR
        icd9_code LIKE '427%' OR
        icd9_code LIKE '428%' OR
        icd9_code LIKE '429%' 
    );
    """

In [5]:
"""
lists of ICD 9 codes (related to sepsis diseases):

038 - Septicaemia
054.5 - Herpetic septicemia
670.2 - Major puerperal infection : Puerperal sepsis (?)
785.52 - Septic shock
995.91 - Sepsis
995.92 - Sepsis, with acute organ dysfunction/multiple organ dysfunction/severe

NOTE(review): 670.2 is listed above but has no matching pattern in `sepsis_query`
below -- confirm whether it was left out on purpose.
"""

# Cohort query: one row per patient (DISTINCT subject_id) with a sepsis-related diagnosis.
# Codes are matched without the decimal point; patterns ending in '%' match by prefix,
# patterns without '%' match that exact code only.
sepsis_query = """
    SELECT DISTINCT(subject_id)
    FROM diagnoses_icd
    WHERE (
        icd9_code LIKE '038%' OR
        icd9_code LIKE '0545%' OR
        icd9_code LIKE '78552' OR
        icd9_code LIKE '99591' OR
        icd9_code LIKE '99592'
    );
    """

In [6]:
"""
lists of ICD 9 codes (related to ARDS / acute respiratory failure):
518.5 - ARDS
518.81 - Respiratory failure, acute
"""

# Cohort query: one row per patient (DISTINCT subject_id) whose diagnoses include a code
# starting with 5185 or the exact code 51881 (codes are matched without the decimal point).
ards_query = """
    SELECT DISTINCT(subject_id)
    FROM diagnoses_icd
    WHERE (
        icd9_code LIKE '5185%' OR
        icd9_code LIKE '51881'
    );
    """

In [7]:
# Get all the patient ids with the chosen disease.
if DISEASE_TYPE == 'CARDIOVASCULAR':
    filter_query = heart_disease_query
elif DISEASE_TYPE == 'SEPSIS':
    filter_query = sepsis_query
elif DISEASE_TYPE == 'ARDS':
    filter_query = ards_query
else:
    # Fail fast. Only printing a message here would let the next line crash with a
    # NameError on a fresh kernel, or silently reuse a stale `filter_query` from an
    # earlier run with a different disease type.
    raise ValueError(f"Unsupported DISEASE_TYPE: {DISEASE_TYPE!r}")

# One-column frame (`subject_id`) holding every patient matched by the cohort query.
filtered_subject_ids = pd.read_sql(filter_query, conn)

In [8]:
# Collect the cohort's patient ids into a set, used for filtering in the later queries
subject_id_set = set(filtered_subject_ids['subject_id'].tolist())

In [9]:
len(subject_id_set)

24138

## 2. Retrieve all admission ids within the 12 months before each patient's last admission

In [10]:
# For every admission of a cohort patient, compute how far it lies from that patient's
# most recent admission. `diff_from_last` is (admission date - last admission date) in years,
# rounded to 2 decimals, so it is 0 for the last admission and negative for earlier ones.
# The sub-query `last_admission_time` holds MAX(admittime) per subject_id.
# NOTE(review): an empty `subject_id_set` would produce `IN ()`, which is presumably rejected
# by the server -- confirm if empty cohorts are possible.
admissions_diff = pd.read_sql(
    """
    SELECT a.subject_id, a.hadm_id,
    ROUND((cast(a.admittime as date)-cast(last_admission_time.max_admittime as date))/365.242,2) AS diff_from_last 
    FROM admissions AS a
    LEFT JOIN
        (SELECT subject_id,  MAX(admittime) AS max_admittime
        FROM admissions
        GROUP BY subject_id
        ) AS last_admission_time
    ON a.subject_id=last_admission_time.subject_id
    WHERE a.subject_id IN %(subject_id_set)s;
    """, 
    con=conn,
    params={'subject_id_set': tuple(subject_id_set)}) # the tuple fills the `IN %(subject_id_set)s` placeholder, restricting rows to cohort patients

In [11]:
admissions_diff.shape

(33872, 3)

In [12]:
admissions_diff['diff_from_last'].describe()

count    33872.000000
mean        -0.536699
std          1.446672
min        -11.560000
25%         -0.100000
50%          0.000000
75%          0.000000
max          0.000000
Name: diff_from_last, dtype: float64

In [13]:
# # uncomment to sanity check, using an id from `filtered_subject_ids`
# # filtered_subject_ids.head()
# admissions_diff[admissions_diff['subject_id']==36]

In [14]:
# Keep the admissions that fall within one year before each patient's last admission
# (i.e. 'diff_from_last' >= -1, measured in years)
admissions_last_year = admissions_diff.loc[admissions_diff['diff_from_last'].ge(-1)]

In [15]:
# # uncomment to sanity check
# admissions_last_year[admissions_last_year['subject_id']==36]

In [16]:
# Collect the retained admission ids (hadm_id) into a set
hadm_id_set = set(admissions_last_year['hadm_id'].tolist())

In [17]:
len(hadm_id_set)

29059

## 3. Get all drug events and process them by removing 'stopword' events

In [18]:
# Get all the drug events from table 'inputevents_mv', joined with 'd_items' to attach each
# item's abbreviation, restricted to the cohort's patients and to their last-year admissions.
# update: add "order by" for correct ordered formats -> https://stackoverflow.com/questions/369362/postgresql-changing-returned-rows-order
# Rows come back ordered by patient, then start time, then item id; the later de-duplication
# and per-admission grouping steps rely on this row order.
# NOTE(review): the statement ends with a doubled ';' -- the captured output shows the query
# ran, so it appears harmless, but the extra ';' can be removed.
drug_events = pd.read_sql(
    """
    SELECT im.subject_id, im.hadm_id, im.starttime, im.itemid, di.abbreviation
    FROM inputevents_mv as im
    JOIN d_items as di
    ON im.itemid=di.itemid
    WHERE im.subject_id IN %(subject_id_set)s
    AND im.hadm_id IN %(hadm_id_set)s
    ORDER BY im.subject_id ASC, 
            im.starttime ASC, 
            im.itemid ASC;;
    """, 
    con=conn,
    params={'subject_id_set': tuple(subject_id_set),
           'hadm_id_set': tuple(hadm_id_set)}) # the tuples fill the two `IN` placeholders (patient ids and admission ids)

In [19]:
drug_events.shape

(2325288, 5)

In [20]:
len(set(drug_events['subject_id']))

10790

In [21]:
# # uncomment to sanity check
# drug_events[drug_events['subject_id'] == 36]

### 3.1 Remove 'stopword' events (items that occur too frequently or too rarely)

In [22]:
# Count how often each itemid occurs among the retrieved drug events.
itemid_counts = drug_events['itemid'].value_counts()

# Build a two-column table: 'itemid' and its occurrence count 'counts'.
# The index and the values are named explicitly instead of renaming whatever columns
# `reset_index()` produces by default: those default names differ between pandas versions
# (older releases: "index"/"itemid"; pandas >= 2.0: "itemid"/"count"), and a rename mapping
# written for the old names leaves the table without an 'itemid' column on newer pandas.
itemid_counts2 = itemid_counts.rename_axis('itemid').reset_index(name='counts')

In [23]:
# itemid_counts2.head()

In [24]:
# Attach each itemid's share of all drug events as the 'proportion' column
itemid_counts2['proportion'] = itemid_counts2['counts'] / itemid_counts2['counts'].sum()

In [25]:
# itemid_counts2.head(10)

In [26]:
# Remove the items whose proportion is larger than the `threshold` or whose count is less than 5;
# update: threshold for heart disease: 3.7% (0.037); for sepsis: 5.3% (0.053) ('PO-intake' not in the list);
# each threshold was chosen by consultation with medical experts based on the frequency
if DISEASE_TYPE == 'CARDIOVASCULAR':
    prop_threshold = 0.037
elif DISEASE_TYPE == 'SEPSIS':
    prop_threshold = 0.053
elif DISEASE_TYPE == 'ARDS':
    prop_threshold = 0.073
else:
    # Fail fast: printing a message would leave `prop_threshold` undefined (or stale from a
    # previous run), and the filter below would then crash or use the wrong threshold.
    raise ValueError(f"Unsupported DISEASE_TYPE: {DISEASE_TYPE!r}")

# Keep items that are neither too frequent (proportion <= threshold) nor too rare (count >= 5).
itemid_counts3 = itemid_counts2[(itemid_counts2['proportion'] <= prop_threshold) & (itemid_counts2['counts'] >= 5)]

In [27]:
# Set of item ids that survived the frequency filter (neither too frequent nor too rare)
itemid_set = set(itemid_counts3['itemid'].tolist())

In [28]:
# Restrict the drug events to the retained item ids
drug_events_filtered2 = drug_events.loc[drug_events['itemid'].isin(itemid_set)]

drug_events_filtered2.shape

(1372601, 5)

In [29]:
# # uncomment to sanity check
# drug_events_filtered2[drug_events_filtered2['subject_id']==36]

### 3.2 Drop the duplicated items (which indicates different doses in the same session)

In [30]:
# # Uncomment to view duplicated 'itemid's from the same session
# drug_events_filtered2.groupby(by='subject_id').apply(lambda x: x.sort_values('itemid'))

In [31]:
# Remove rows that are identical in every column (the same item recorded more than once
# in one input session, e.g. with different doses), keeping the first occurrence
drug_events_filtered3 = drug_events_filtered2.drop_duplicates(keep='first')

drug_events_filtered3.shape

(1182305, 5)

In [32]:
# Order the events by patient id, then by start time
drug_events_filtered3 = drug_events_filtered3.sort_values(by=['subject_id', 'starttime'])

In [33]:
# Drop 'starttime' so that consecutive rows of the same admission can compare equal
# in the de-duplication step that follows
drug_events_filtered3 = drug_events_filtered3.drop(columns=['starttime'])

In [34]:
# Collapse runs of identical consecutive rows: a row is kept when at least one of its
# columns differs from the row directly above it
drug_events_filtered4 = drug_events_filtered3.loc[
    drug_events_filtered3.ne(drug_events_filtered3.shift()).any(axis=1)
]

drug_events_filtered4.shape # the count of all drug events, after filtering

(1039380, 4)

In [35]:
# # sanity check
# drug_events_filtered4[drug_events_filtered4['subject_id']==36]

In [36]:
# Convert 'itemid' from int to str for output.
# `assign` returns a new frame; assigning the column directly into `drug_events_filtered4`
# (a row selection of another frame) raised pandas' SettingWithCopyWarning.
drug_events_filtered4 = drug_events_filtered4.assign(
    itemid=drug_events_filtered4['itemid'].astype(str)
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [37]:
# drug_events_filtered4.head()

In [38]:
# Drug events without the abbreviation column, tagged with event_type=1
# (used as a sanity check when the drug and procedure tables are merged)
drug_events_only = (
    drug_events_filtered4
    .drop(columns=['abbreviation'])
    .assign(event_type=1)
)

# drug_events_only.head()
len(set(drug_events_only['subject_id']))

10717

In [39]:
# drug_events_only[drug_events_only['subject_id'] == 36] # sanity check

## 4. Get all the diagnosis codes by filtered admission ids (and then group by patient)

In [40]:
# Admission ids that still have at least one drug event after filtering
filtered_hadm_id_set = set(drug_events_only['hadm_id'].tolist())

In [41]:
# Fetch every diagnosis row of the filtered admissions, together with the admission time
# (taken from `admissions`) so diagnoses can later be ordered chronologically.
diagnosis_codes = pd.read_sql(
    sql="""
        SELECT di.*, a.admittime
        FROM diagnoses_icd AS di
        JOIN admissions AS a
        ON di.hadm_id=a.hadm_id
        WHERE di.hadm_id IN %(filtered_hadm_id_set)s;
        """, 
    con=conn,
    params={'filtered_hadm_id_set': tuple(filtered_hadm_id_set)})  # the tuple fills the `IN` placeholder

In [42]:
# note: the category is taken as the first 3 characters of the full ICD-9 diagnosis code
diagnosis_codes['short_code'] = diagnosis_codes['icd9_code'].map(lambda icd9: icd9[:3])

# diagnosis_codes.head()

In [43]:
# Order each patient's diagnoses chronologically: group by patient, sort every group by
# admission time and then by seq_num (both ascending), and keep only the columns needed later.
sorted_diagnosis_codes = diagnosis_codes.groupby(by=['subject_id']).apply(
    lambda x: x.sort_values(['admittime', 'seq_num'], ascending=[True, True]))[['hadm_id', 'seq_num', 'short_code', 'admittime']]

# Move the index produced by groupby.apply back into columns and discard 'level_1'
# (presumably the original row label -- the only part of that index not needed afterwards).
sorted_diagnosis_codes = sorted_diagnosis_codes.reset_index().drop('level_1', axis=1)

In [44]:
# # uncomment to sanity check 
# sorted_diagnosis_codes[sorted_diagnosis_codes['subject_id'] == 36]

## 5. Get all procedure codes

In [45]:
# Fetch the ICD-9 procedures of the filtered admissions. The inner query joins each procedure
# row with its short title from `d_icd_procedures`; the RIGHT JOIN then attaches the admission
# time while keeping every procedure row.
procedure_codes = pd.read_sql(
    """
    SELECT a.admittime, procedures.* 
    FROM admissions AS a
    RIGHT JOIN
        (SELECT pi.subject_id, pi.hadm_id, pi.seq_num, pi.icd9_code, dip.short_title
        FROM procedures_icd AS pi
        JOIN d_icd_procedures AS dip
        ON pi.icd9_code=dip.icd9_code) AS procedures
    ON a.hadm_id=procedures.hadm_id
    WHERE procedures.hadm_id IN %(filtered_hadm_id_set)s;
    """, 
    con=conn,
    params={'filtered_hadm_id_set': tuple(filtered_hadm_id_set)}) # the tuple fills the `IN` placeholder, restricting rows to the filtered admission ids

In [46]:
procedure_codes.shape

(53101, 6)

In [47]:
len(set(procedure_codes['subject_id']))

9642

In [48]:
# # uncomment to sanity check
# procedure_codes[procedure_codes['subject_id']==36]

In [49]:
# The short title is not needed from here on
procedure_codes_filtered2 = procedure_codes.drop(columns=['short_title'])

procedure_codes_filtered2.shape

(53101, 5)

### 5.1 Group by subject_id and sort by admittime and seq_num

In [50]:
# Group by subject_id and sort every patient's procedures by admittime (first) and seq_num (second).
# `sort=False` keeps the groups in order of first appearance instead of sorting them by subject_id.
procedure_codes_filtered3 = procedure_codes_filtered2.groupby(by='subject_id', sort=False).apply(
    lambda x: x.sort_values(['admittime', 'seq_num']))

# procedure_codes_filtered3.head()

In [51]:
# Discard both index levels left by groupby.apply, giving a plain 0..n-1 row index
procedure_codes_filtered4 = procedure_codes_filtered3.reset_index(level=[0,1], drop=True)

# procedure_codes_filtered4.head()

In [52]:
# # sanity check
# procedure_codes_filtered4[procedure_codes_filtered4['subject_id']==36]

## 6. Merge drug events and procedures, and then output

In [53]:
# drug_events_only.head()

In [54]:
# Tag procedures with event_type=2 and expose the ICD-9 procedure code under the column
# name 'itemid', so the table lines up with `drug_events_only` when the two are merged
procedure_codes_filtered4 = procedure_codes_filtered4.assign(
    event_type=2,
    itemid=procedure_codes_filtered4['icd9_code'],
)

# procedure_codes_filtered4.head()
procedure_codes_filtered4.shape # the count of all procedures, after filtered

(53101, 7)

In [55]:
# Stack drug events on top of procedures and renumber the rows 0..n-1
merged = pd.concat([drug_events_only, procedure_codes_filtered4], ignore_index=True)

# merged.head()

In [56]:
merged[merged['event_type'] == 1].shape

(1039380, 7)

In [57]:
merged[merged['event_type'] == 2].shape

(53101, 7)

In [58]:
# # sanity check to verify how many subjects are overlapped
# overlapped_sbj = set(drug_events_only['subject_id']).intersection(set(procedure_codes_filtered4['subject_id']))

# len(overlapped_sbj)

In [59]:
#sanity check
# drug_events_only[drug_events_only['hadm_id'] == 183791]

In [60]:
#sanity check
# procedure_codes_filtered4[procedure_codes_filtered4['hadm_id'] == 183791]

In [61]:
# Collect each admission's events into one list, in row order.
# note: groupby sorts its keys by default; `sort=False` is required because `hadm_id`
# is randomly generated and carries no chronological order.
merged_by_adm = (
    merged
    .groupby(['subject_id', 'hadm_id'], sort=False)['itemid']
    .apply(list)
    .reset_index()
)


In [62]:
# merged_by_adm.head()

In [63]:
# # sanity check
# merged_by_adm[merged_by_adm['subject_id']==36]

In [64]:
# Collapse runs of the same event/treatment inside each admission's list
# (repeats can end up adjacent once rows with different timestamps are grouped together)
merged_by_adm['itemid'] = merged_by_adm['itemid'].apply(
    lambda events: events[:1] + [cur for prev, cur in zip(events, events[1:]) if cur != prev]
)


In [65]:
# One row per patient: a list holding that patient's per-admission event lists
merged_by_sbj = (
    merged_by_adm
    .groupby('subject_id')['itemid']
    .apply(list)
    .reset_index()
)

# merged_by_sbj.head()

In [66]:
# 'count' = total number of events of a patient, summed over all of their admissions
merged_by_sbj['count'] = merged_by_sbj['itemid'].apply(
    lambda admissions: sum(len(adm) for adm in admissions)
)

# merged_by_sbj.head()

In [67]:
merged_by_sbj['count'].describe()

count    10717.000000
mean       101.805449
std        195.192724
min          1.000000
25%         16.000000
50%         41.000000
75%         92.000000
max       3596.000000
Name: count, dtype: float64

In [68]:
# filter/cut the event length to the event count threshold
if DISEASE_TYPE == 'CARDIOVASCULAR':
    event_count_threshold = 92 # 75% percentile
elif DISEASE_TYPE == 'SEPSIS':
    event_count_threshold = 92 # cut at 92 (~1400 samples); not 75% percentile, to reduce computational costs (len=281)
elif DISEASE_TYPE == 'ARDS':
    event_count_threshold = 92 # cut at 92 (~1782 samples); not 75% percentile, to reduce computational costs (len=319) (~3177 samples)
else:
    # Without this branch an unknown disease type leaves `event_count_threshold` undefined
    # (or stale from a previous run) and the filter below fails or uses the wrong value.
    raise ValueError(f"Unsupported DISEASE_TYPE: {DISEASE_TYPE!r}")

# Keep the patients whose total event count does not exceed the threshold.
# `.copy()` makes `filtered_events` an independent frame: later cells add columns to it
# ('final_seq', 'final_count', 'survival'), which on a plain row selection triggers
# pandas' SettingWithCopyWarning.
filtered_events = merged_by_sbj[merged_by_sbj['count'] <= event_count_threshold].copy()

len(filtered_events)

8059

In [69]:
# filtered_events['itemid'].apply(lambda x: len(x)).hist() # check how many admission visits for each patient, after filtering

In [70]:
# Flatten each patient's admissions into a single space-separated string in which every
# admission is wrapped by the markers 'admin' (admission in) and 'admout' (admission out)
filtered_events['final_seq'] = filtered_events['itemid'].apply(
    lambda admissions: "admin "
    + " admout admin ".join(" ".join(adm) for adm in admissions)
    + " admout"
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  after removing the cwd from sys.path.


In [71]:
# Number of whitespace-separated tokens in the final sequence (events plus admission markers)
filtered_events['final_count'] = filtered_events['final_seq'].map(lambda seq: len(seq.split()))

# filtered_events.head()
filtered_events['final_count'].describe()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


count    8059.000000
mean       34.741903
std        24.284513
min         3.000000
25%        14.000000
50%        29.000000
75%        52.000000
max       102.000000
Name: final_count, dtype: float64

In [72]:
# # sanity check
# filtered_events['final_count'].idxmax()
# filtered_events.loc[282]

In [73]:
# Patient ids that remain after the event-count filter
final_subject_id_set = set(filtered_events['subject_id'].tolist())

### 6.1 Update: Get the `final_subject_id_set` and then add the diagnosis columns

In [74]:
# Keep the diagnosis rows of the final patient set only
filtered_diagnosis_codes = sorted_diagnosis_codes.loc[
    sorted_diagnosis_codes['subject_id'].isin(final_subject_id_set)
]

# filtered_diagnosis_codes.head()

In [75]:
len(set(filtered_diagnosis_codes['subject_id']))

8059

In [76]:
# One list of short diagnosis codes per admission; sort=False keeps the rows' existing order
filtered_diagnosis_codes2 = (
    filtered_diagnosis_codes
    .groupby(['subject_id', 'hadm_id'], sort=False)['short_code']
    .apply(list)
    .reset_index()
)

# filtered_diagnosis_codes2.head()

In [77]:
# One row per patient: a list holding that patient's per-admission diagnosis-code lists
diagnosis_by_sbj = (
    filtered_diagnosis_codes2
    .groupby('subject_id')['short_code']
    .apply(list)
    .reset_index()
)

# diagnosis_by_sbj.head()

In [78]:
# Work on an independent copy so that adding columns leaves `diagnosis_by_sbj` untouched
filtered_diagnosis_codes3 = diagnosis_by_sbj.copy(deep=True)

In [79]:
# Flatten each patient's per-admission diagnosis codes into one space-separated string,
# wrapping every admission with the markers 'admin' and 'admout'
filtered_diagnosis_codes3['final_diag'] = diagnosis_by_sbj['short_code'].apply(
    lambda admissions: "admin "
    + " admout admin ".join(" ".join(adm) for adm in admissions)
    + " admout"
)

# filtered_diagnosis_codes3.head()

In [80]:
# Number of whitespace-separated tokens in the diagnosis sequence (codes plus admission markers)
filtered_diagnosis_codes3['final_count_diag'] = filtered_diagnosis_codes3['final_diag'].map(
    lambda diag: len(diag.split())
)

filtered_diagnosis_codes3['final_count_diag'].max()

146

In [81]:
# # sanity check
# filtered_diagnosis_codes3[filtered_diagnosis_codes3['subject_id']==689]

### 6.2 Update: Get the `final_subject_id_set` and then find the drug/procedure coexistence

In [82]:
## check the coexistence of drug events and procedures:
## for every admission, the set of distinct items (drugs and procedures) recorded in it
coexist_table = (
    merged
    .groupby(['subject_id', 'hadm_id'], sort=False)['itemid']
    .apply(set)
    .reset_index()
)

# coexist_table.head()

In [83]:
# Restrict the co-occurrence table to the final patient set
coexist_table_filtered = coexist_table.loc[coexist_table['subject_id'].isin(final_subject_id_set)]

In [84]:
# Rule-based dictionary of co-existing drugs/procedures: item -> set of items recorded in
# the same admission.
# NOTE(review): an item that occurs in several admissions ends up mapped to the item set of
# the LAST admission processed (earlier entries are overwritten, not unioned) -- confirm
# that this is the intended rule.
coexist_dict = {}

for item_set in coexist_table_filtered['itemid']:
    coexist_dict.update(dict.fromkeys(item_set, item_set))

In [85]:
# Write the same co-occurrence dictionary into each of the five output folders
# '../mimic_data_<disease><1..5>/'.
# NOTE(review): `output_datapath` keeps its last value (folder 5) after this loop and is
# reused by the single-split export cells further down; the 5-split export writes to
# '../full_experiments1/...' instead -- confirm the folders are meant to differ.
for output_idx in range(1, 6):
    output_datapath = f"../mimic_data_{DISEASE_TYPE.lower()}{output_idx}/"

    os.makedirs(output_datapath, exist_ok=True)
    with open(output_datapath + 'coexist_dict.pkl', 'wb') as pkl_file:
        pickle.dump(coexist_dict, pkl_file, protocol=pickle.HIGHEST_PROTOCOL)

### 6.3 Add survival flag (1: survival, 0: death)

In [86]:
# Select the patients whose `expire_flag` in the `patients` table is 0; they are labelled
# as survivors (survival=1) below, everyone else as deceased (survival=0).
# NOTE(review): the original comment described the flag as "1 indicates death in the
# hospital, and 0 indicates survival to hospital discharge". That wording looks like the
# description of `admissions.hospital_expire_flag`; `patients.expire_flag` presumably marks
# any recorded death, in or out of hospital -- confirm which label is intended.
survival_subject_ids = pd.read_sql(
    """
    SELECT subject_id FROM patients
    WHERE expire_flag=0;
    """, conn)

In [87]:
# Set of patient ids labelled as survivors
survival_id_set = set(survival_subject_ids['subject_id'].tolist())

In [88]:
# survival = 1 when the patient id is in the survivor set, otherwise 0
filtered_events['survival'] = filtered_events['subject_id'].isin(survival_id_set).astype('int64')

filtered_events.shape

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


(8059, 6)

In [89]:
filtered_events['survival'].value_counts()

1    5525
0    2534
Name: survival, dtype: int64

### 6.4 Export as the input format for DRG framework

In [90]:
filtered_diagnosis_codes3.shape

(8059, 4)

In [91]:
# Join the event sequences with the diagnosis sequences, one row per patient
final_merged = pd.merge(
    filtered_events,
    filtered_diagnosis_codes3,
    on='subject_id',
    how='inner',
)

final_merged.shape

(8059, 9)

In [92]:
# final_merged.head()

In [93]:
# Recompute the diagnosis token count on the merged table
final_merged['final_count_diag'] = final_merged['final_diag'].map(lambda diag: len(diag.split()))

In [94]:
final_merged['final_count_diag'].describe() # max: 166 (updated)

count    8059.000000
mean       18.434297
std        10.406458
min         3.000000
25%        12.000000
50%        16.000000
75%        22.000000
max       146.000000
Name: final_count_diag, dtype: float64

In [95]:
# # update: check if there is some patient with different admission in diagnosis codes and drug events, should be no false
# compared = final_merged['itemid'].apply(len) == final_merged['short_code'].apply(len)

# compared[compared==False] 

In [96]:
# # sanity check 
# final_merged[final_merged['subject_id'] == 689]

In [97]:
# Split the event sequences by outcome: negative = deceased (survival == 0),
# positive = survived (survival == 1)
neg_data = final_merged.loc[final_merged['survival'] == 0, 'final_seq']
pos_data = final_merged.loc[final_merged['survival'] == 1, 'final_seq']

# update: the same split for the diagnosis-code sequences
neg_diag = final_merged.loc[final_merged['survival'] == 0, 'final_diag']
pos_diag = final_merged.loc[final_merged['survival'] == 1, 'final_diag']

#### Export negative samples

In [98]:
# Number of validation samples held out per class (positive / negative) for each disease type.
if DISEASE_TYPE == 'CARDIOVASCULAR':
    val_data_size = 200
elif DISEASE_TYPE in ('SEPSIS', 'ARDS'):
    val_data_size = 100
else:
    # Fail fast: printing a message would leave `val_data_size` undefined (or stale from a
    # previous run) and the sampling cells below would crash or use the wrong size.
    raise ValueError(f"Unsupported DISEASE_TYPE: {DISEASE_TYPE!r}")

In [99]:
# Hold out `val_data_size` randomly drawn negative samples for validation (fixed seed 3);
# everything else becomes the negative training set
validation_neg = neg_data.sample(n=val_data_size, random_state=3)
train_neg = neg_data.drop(index=validation_neg.index)

In [100]:
# Write the negative event sequences as plain-text files, one sequence per line.
# NOTE(review): `output_datapath` is whatever the coexistence-export loop assigned last
# when the notebook is run top to bottom, not a dedicated '../mimic_data/' folder.
train_neg.to_csv(
    path_or_buf=f"{output_datapath}train_neg.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)
validation_neg.to_csv(
    path_or_buf=f"{output_datapath}validation_neg.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)

In [101]:
# Write the DIAGNOSIS CODE sequences of the same negative train/validation rows,
# selected through the index of the event-sequence split so both files stay aligned
neg_diag.loc[train_neg.index].to_csv(
    path_or_buf=f"{output_datapath}train_neg_diag.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)
neg_diag.loc[validation_neg.index].to_csv(
    path_or_buf=f"{output_datapath}validation_neg_diag.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)

#### Export positive samples

In [102]:
# Hold out `val_data_size` randomly drawn positive samples for validation (fixed seed 3);
# everything else becomes the positive training set
validation_pos = pos_data.sample(n=val_data_size, random_state=3)
train_pos = pos_data.drop(index=validation_pos.index)

In [103]:
# Write the positive event sequences as plain-text files, one sequence per line.
# NOTE(review): `output_datapath` is whatever the coexistence-export loop assigned last
# when the notebook is run top to bottom, not a dedicated '../mimic_data/' folder.
train_pos.to_csv(
    path_or_buf=f"{output_datapath}train_pos.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)
validation_pos.to_csv(
    path_or_buf=f"{output_datapath}validation_pos.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)

In [104]:
# Write the DIAGNOSIS CODE sequences of the same positive train/validation rows,
# selected through the index of the event-sequence split so both files stay aligned
pos_diag.loc[train_pos.index].to_csv(
    path_or_buf=f"{output_datapath}train_pos_diag.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)
pos_diag.loc[validation_pos.index].to_csv(
    path_or_buf=f"{output_datapath}validation_pos_diag.txt",
    index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ',
)

### 6.5 Export samples in 5 random train/test splits

In [107]:
# Create 5 different train/validation splits, one per random seed. Split i is written to
# '../full_experiments1/mimic_data_<disease><i>/'.
RANDOM_SEEDS = [3, 6, 9, 12, 15]

# Shared `to_csv` options: no index, no header, no quoting -- one sequence per line.
txt_options = dict(index=False, header=False, sep=' ', quoting=csv.QUOTE_NONE, escapechar=' ')

for output_idx, seed in enumerate(RANDOM_SEEDS, start=1):
    output_datapath = '../full_experiments1/mimic_data_' + DISEASE_TYPE.lower() + str(output_idx) + '/'

    # makedirs also creates the parent '../full_experiments1' folder when it is missing
    os.makedirs(output_datapath, exist_ok=True)

    # Negative class: hold out `val_data_size` samples for validation, train on the rest
    validation_neg = neg_data.sample(n=val_data_size, random_state=seed)
    train_neg = neg_data.drop(validation_neg.index)

    train_neg.to_csv(path_or_buf=output_datapath+'train_neg.txt', **txt_options)
    validation_neg.to_csv(path_or_buf=output_datapath+'validation_neg.txt', **txt_options)

    # update: DIAGNOSIS CODES of the same rows, selected through the split's index
    neg_diag.loc[train_neg.index].to_csv(path_or_buf=output_datapath+'train_neg_diag.txt', **txt_options)
    neg_diag.loc[validation_neg.index].to_csv(path_or_buf=output_datapath+'validation_neg_diag.txt', **txt_options)

    # Positive class: hold out `val_data_size` samples for validation, train on the rest
    validation_pos = pos_data.sample(n=val_data_size, random_state=seed)
    train_pos = pos_data.drop(validation_pos.index)

    # The file name was misspelled 'train_poms.txt'; every other export uses 'train_pos.txt'
    train_pos.to_csv(path_or_buf=output_datapath+'train_pos.txt', **txt_options)
    validation_pos.to_csv(path_or_buf=output_datapath+'validation_pos.txt', **txt_options)

    # update: DIAGNOSIS CODES of the same rows, selected through the split's index
    pos_diag.loc[train_pos.index].to_csv(path_or_buf=output_datapath+'train_pos_diag.txt', **txt_options)
    pos_diag.loc[validation_pos.index].to_csv(path_or_buf=output_datapath+'validation_pos_diag.txt', **txt_options)