In [63]:
import sys
# Display the interpreter backing this kernel, to confirm the intended conda env is active.
sys.executable

'/gpfs/gibbs/project/ying_rex/rkt23/conda_envs/ehr_project/bin/python'

In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import preprocessing
import copy
import os
import random
import json
import pickle

In [11]:
# Root directory of the eICU-CRD v2.0 CSV tables. Set the EICU_DATA_PATH
# environment variable to point elsewhere instead of editing this default.
data_path = os.environ.get(
    "EICU_DATA_PATH",
    "/home/anonymous/ehr_dataset/eicu/physionet.org/files/eicu-crd/2.0",
)
# Output naming used by `process_all`:
#   'HADM_ID'    <- patienthealthsystemstayid (hospital stay / encounter id)
#   'SUBJECT_ID' <- uniquepid (patient id)

In [3]:
# Names of the raw eICU tables this pipeline relies on (not referenced by the
# visible cells; the per-table paths are built separately).
file_list = ["patient.csv.gz", "medication.csv.gz", "lab.csv.gz",
             "diagnosis.csv.gz", "apachePatientResult.csv.gz"]

In [13]:
# Full paths of each eICU table under `data_path`.
# (procedure_file / treatment.csv.gz is defined but not read by the visible cells.)
(med_file, procedure_file, diag_file,
 patient_and_admission_file, patient_result_file, lab_test_file) = (
    os.path.join(data_path, table_name)
    for table_name in (
        "medication.csv.gz",
        "treatment.csv.gz",
        "diagnosis.csv.gz",
        "patient.csv.gz",
        "apachePatientResult.csv.gz",
        "lab.csv.gz",
    )
)

# Data preprocessing

In [39]:
def process_med(path=None):
    """Load the eICU ``medication`` table as de-duplicated (stay, drug code) rows.

    Parameters
    ----------
    path : str, optional
        Medication CSV (gzip allowed) to read. Defaults to the module-level
        ``med_file``; exposed so the function can run on a small fixture.

    Returns
    -------
    pandas.DataFrame
        Columns ``patientunitstayid`` (int64) and ``drughiclseqno`` (string,
        prefixed with ``"NDC_"``), de-duplicated, with a fresh RangeIndex.
        NOTE(review): the values come from ``drughiclseqno``; the ``NDC_``
        prefix is a naming convention kept for downstream code.
    """
    if path is None:
        path = med_file
    med_pd = pd.read_csv(path, dtype={'drughiclseqno': 'category'})
    # keep only the stay id, the start offset and the drug code
    med_pd = med_pd.drop(columns=['medicationid', 'drugorderoffset', 'drugivadmixture',
                                  'drugordercancelled', 'drugname', 'dosage', 'routeadmin',
                                  'frequency', 'loadingdose', 'prn', 'drugstopoffset',
                                  'gtc'])
    # rows whose drug code is '0' are discarded (missing codes are kept here)
    med_pd = med_pd[med_pd['drughiclseqno'] != '0']
    # `.ffill()` replaces the deprecated `fillna(method='pad')`.
    # NOTE(review): the fill runs in file order, before sorting and without
    # grouping by stay, so a missing value may be copied from another stay.
    # Kept as-is to preserve the existing dataset.
    med_pd = med_pd.ffill().dropna().drop_duplicates()
    med_pd = med_pd.astype({'patientunitstayid': 'int64', 'drugstartoffset': 'int64'})
    med_pd = med_pd.sort_values(by=['patientunitstayid', 'drugstartoffset']).reset_index(drop=True)

    def filter_first24hour_med(med_pd):
        # NOTE(review): despite the name, no 24-hour window is applied. The
        # groupby keys include 'drugstartoffset', so head(1) keeps every
        # distinct (stay, offset) pair and the inner merge restores every row;
        # the net effect is dropping the 'drugstartoffset' column.
        med_pd_new = med_pd.drop(columns=['drughiclseqno'])
        med_pd_new = med_pd_new.groupby(by=['patientunitstayid', 'drugstartoffset']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['patientunitstayid', 'drugstartoffset'])
        med_pd_new = med_pd_new.drop(columns=['drugstartoffset'])
        return med_pd_new

    med_pd = filter_first24hour_med(med_pd).drop_duplicates().reset_index(drop=True)
    med_pd["drughiclseqno"] = "NDC_" + med_pd["drughiclseqno"].astype(str)
    return med_pd

def process_diag(path=None):
    """Load the eICU ``diagnosis`` table as unique (stay, ICD-9 code) rows.

    Parameters
    ----------
    path : str, optional
        Diagnosis CSV (gzip allowed). Defaults to the module-level ``diag_file``.

    Returns
    -------
    pandas.DataFrame
        Columns ``patientunitstayid`` and ``icd9code``, de-duplicated, sorted
        by ``patientunitstayid``, with a fresh RangeIndex.
    """
    if path is None:
        path = diag_file
    diag_pd = pd.read_csv(path)
    # NOTE(review): dropna looks at *all* columns before the unused ones are
    # removed, so a row with a valid icd9code is still discarded if any other
    # column is missing. Kept to preserve the existing dataset.
    diag_pd = diag_pd.dropna()
    diag_pd = diag_pd.drop(columns=['activeupondischarge', 'diagnosisid', 'diagnosisoffset',
                                    'diagnosisstring', 'diagnosispriority'])
    # icd9code may hold several comma-separated codes; keep only the first.
    # NOTE(review): presumably the first entry is the ICD-9 code — confirm.
    # `.assign` builds a new frame instead of writing into a boolean slice
    # (which raised SettingWithCopyWarning).
    diag_pd = diag_pd.assign(icd9code=diag_pd['icd9code'].map(lambda x: x.split(", ")[0]))
    diag_pd = diag_pd.drop_duplicates().sort_values(by=['patientunitstayid'])
    return diag_pd.reset_index(drop=True)

def process_admission():
    """Join eICU ``patient`` rows with APACHE results and derive AGE / death features.

    Both tables are reduced to a few columns and inner-joined on
    ``patientunitstayid``, and rows with any missing value are dropped.
    ``age`` becomes an equal-width bin label ``"AGE_<k>"`` over 20 bins spanning
    the observed range ('> 89' is treated as 90). A boolean ``death`` column is
    derived from ``actualhospitalmortality == "EXPIRED"``, and
    ``actualhospitalmortality`` is then dropped.

    Returns
    -------
    pandas.DataFrame
        De-duplicated rows sorted by (uniquepid, patienthealthsystemstayid,
        patientunitstayid), with a fresh RangeIndex.
    """
    unused_patient_cols = ['ethnicity', 'hospitalid', 'wardid', 'apacheadmissiondx',
                           'admissionheight', 'hospitaladmittime24', 'hospitaladmitoffset',
                           'hospitaladmitsource', 'hospitaldischargeyear', 'hospitaldischargetime24',
                           'hospitaldischargeoffset', 'hospitaldischargelocation', 'hospitaldischargestatus',
                           'unittype', 'unitadmittime24', 'unitadmitsource', 'unitvisitnumber', 'unitstaytype',
                           'admissionweight', 'dischargeweight', 'unitdischargetime24', 'unitdischargeoffset',
                           'unitdischargelocation', 'unitdischargestatus']
    unused_result_cols = ['apachepatientresultsid', 'physicianspeciality', 'physicianinterventioncategory',
                          'acutephysiologyscore', 'apachescore', 'apacheversion', 'predictedicumortality',
                          'actualicumortality', 'predictediculos', 'actualiculos', 'predictedhospitalmortality',
                          'predictedhospitallos', 'actualhospitallos', 'preopmi', 'preopcardiaccath',
                          'ptcawithin24h', 'unabridgedunitlos', 'actualventdays', 'predventdays',
                          'unabridgedactualventdays']

    patients = pd.read_csv(patient_and_admission_file).drop(columns=unused_patient_cols)
    outcomes = pd.read_csv(patient_result_file).drop(columns=unused_result_cols)
    joined = patients.merge(outcomes, on=['patientunitstayid'], how='inner').dropna()

    # ages above 89 are stored as the string '> 89'; map them to 90 before binning
    ages = joined['age'].map(lambda a: '90' if a == '> 89' else a).astype('int64')
    edges = np.linspace(ages.min(), ages.max(), 20 + 1)
    age_labels = "AGE_" + pd.cut(ages, bins=edges, labels=False, include_lowest=True).astype("str")

    joined = joined.assign(
        age=age_labels,
        death=joined['actualhospitalmortality'] == "EXPIRED",
    )
    joined = joined.sort_values(by=['uniquepid', 'patienthealthsystemstayid', 'patientunitstayid'])
    joined = joined.drop(columns=['actualhospitalmortality']).drop_duplicates()
    return joined.reset_index(drop=True)

def process_lab_test(n_bins=5, path=None):
    """Load the eICU ``lab`` table and encode each result as an ``"<item>-<bin>"`` token.

    Each distinct ``labname`` gets an integer item id (``groupby(...).ngroup()``),
    and one row is kept per (stay, item). If every ``labresulttext`` of an item
    parses as a float and it has at least ``n_bins`` non-missing values, the item
    is quantile-binned and its bins render as floats (e.g. ``"4.0"``). Otherwise
    the raw ``labresulttext`` string is used as the "bin".

    Parameters
    ----------
    n_bins : int
        Number of quantile bins for numeric items (fewer when quantile edges
        coincide, because of ``duplicates='drop'``).
    path : str, optional
        Lab CSV (gzip allowed). Defaults to the module-level ``lab_test_file``.

    Returns
    -------
    pandas.DataFrame
        The remaining lab columns plus ``lab_test``, de-duplicated, with a
        fresh RangeIndex.
    """
    if path is None:
        path = lab_test_file
    lab_pd = pd.read_csv(path)
    lab_pd['itemid'] = lab_pd.groupby(by=['labname']).ngroup()

    # Keep one row per (stay, item).
    # NOTE(review): this is the first row in *file* order (the frame is not
    # sorted by labresultoffset), and the missing-value filters run afterwards,
    # so a (stay, item) whose first row lacks a result is dropped entirely.
    lab_pd = lab_pd.groupby(by=['patientunitstayid', 'itemid']).head(1).reset_index(drop=True)
    lab_pd = lab_pd[lab_pd["patientunitstayid"].notna() & lab_pd["labresult"].notna()]
    lab_pd = lab_pd.drop(columns=['labresult'])

    def contains_text(values):
        # True if any value cannot be parsed as a float.
        for item in values:
            try:
                float(item)
            except ValueError:
                return True
        return False

    def bin_item(texts):
        # Bin labels (object dtype) for one item's labresulttext values.
        if contains_text(texts):
            return texts.astype(object)
        values_numeric = pd.to_numeric(texts, errors='coerce')
        if len(values_numeric.dropna()) < n_bins:
            return texts.astype(object)
        bins = pd.qcut(values_numeric, q=n_bins, labels=False, duplicates='drop')
        # float labels so every numeric bin renders as "<k>.0", regardless of
        # which items were processed before this one
        return bins.astype('float64').astype(object)

    # Build the whole column once (object dtype) instead of writing mixed
    # float/str values item by item into a column whose dtype keeps changing.
    pieces = [bin_item(texts) for _, texts in lab_pd.groupby('itemid', sort=False)['labresulttext']]
    value_bin = pd.concat(pieces) if pieces else pd.Series(index=lab_pd.index, dtype=object)
    lab_pd = lab_pd.assign(value_bin=value_bin)

    lab_pd = lab_pd.assign(
        lab_test=lab_pd["itemid"].astype(str) + "-" + lab_pd["value_bin"].astype(str)
    )
    lab_pd = lab_pd.drop(columns=['labresulttext', 'value_bin', 'itemid']).drop_duplicates()
    return lab_pd.reset_index(drop=True)

In [6]:
def filter_most_diag(diag_pd, threshold=2000):
    """Keep only rows whose ``icd9code`` is among the ``threshold`` most frequent codes.

    Codes are ranked by row count, descending. Ties keep the ascending key order
    produced by ``groupby`` (stable sort), so the selection is deterministic.

    Returns a new frame with a fresh RangeIndex.
    """
    code_counts = diag_pd.groupby('icd9code').size().sort_values(ascending=False, kind='stable')
    # Positional slicing is end-exclusive, so exactly `threshold` codes are kept
    # (label-based `.loc[:threshold]` is end-inclusive and kept threshold + 1).
    top_codes = code_counts.index[:threshold]
    return diag_pd[diag_pd['icd9code'].isin(top_codes)].reset_index(drop=True)

def filter_most_lab(lab_pd, threshold=1500):
    """Keep only rows whose ``lab_test`` token is among the ``threshold`` most frequent.

    Tokens are ranked by row count, descending. Ties keep the ascending key order
    produced by ``groupby`` (stable sort), so the selection is deterministic.

    Returns a new frame with a fresh RangeIndex.
    """
    token_counts = lab_pd.groupby('lab_test').size().sort_values(ascending=False, kind='stable')
    # Positional slicing is end-exclusive, so exactly `threshold` tokens are kept
    # (label-based `.loc[:threshold]` is end-inclusive and kept threshold + 1).
    top_tokens = token_counts.index[:threshold]
    return lab_pd[lab_pd['lab_test'].isin(top_tokens)].reset_index(drop=True)

In [40]:
# from datetime import timedelta

def process_all():
    """Build the per-hospital-stay eICU dataset.

    Steps:
    1. Load medications, diagnoses (top codes), admissions/outcomes and labs
       (top tokens).
    2. Keep only unit stays present in all four sources.
    3. Aggregate codes per (uniquepid, patienthealthsystemstayid).
    4. Rename columns to SUBJECT_ID, HADM_ID, ICD9_CODE, NDC, LAB_TEST, GENDER,
       AGE, DEATH and STAY_DAYS.

    ICD9_CODE holds the arrays returned by ``unique()``; NDC and LAB_TEST are
    converted to Python lists.

    Returns
    -------
    pandas.DataFrame
        One row per (SUBJECT_ID, HADM_ID), sorted by SUBJECT_ID then HADM_ID,
        with a fresh RangeIndex.
    """
    stay_key = ['patientunitstayid']
    visit_key = ['uniquepid', 'patienthealthsystemstayid']

    print('process_med')
    med_pd = process_med()

    print('process_diag')
    diag_pd = filter_most_diag(process_diag())

    print('process_ad')
    ad_pd = process_admission()

    print('process_lab')
    lab_pd = filter_most_lab(process_lab_test())

    # unit stays that appear in every source, annotated with their patient/hospital-stay ids
    combined_key = med_pd[stay_key].drop_duplicates()
    combined_key = combined_key.merge(diag_pd[stay_key].drop_duplicates(), on=stay_key, how='inner')
    combined_key = combined_key.merge(lab_pd[stay_key].drop_duplicates(), on=stay_key, how='inner')
    combined_key = combined_key.merge(ad_pd[visit_key + stay_key].drop_duplicates(),
                                      on=stay_key, how='inner')

    diag_pd = diag_pd.merge(combined_key, on=stay_key, how='inner')
    med_pd = med_pd.merge(combined_key, on=stay_key, how='inner')
    lab_pd = lab_pd.merge(combined_key, on=stay_key, how='inner')
    ad_pd = ad_pd.merge(combined_key, on=visit_key + stay_key, how='inner')

    # flatten: one row per hospital stay holding the unique codes of each kind
    diag_pd = diag_pd.groupby(by=visit_key)['icd9code'].unique().reset_index()
    med_pd = med_pd.groupby(by=visit_key)['drughiclseqno'].unique().reset_index()
    lab_pd = lab_pd.groupby(by=visit_key)['lab_test'].unique().reset_index()
    med_pd['drughiclseqno'] = med_pd['drughiclseqno'].map(lambda x: list(x))
    lab_pd['lab_test'] = lab_pd['lab_test'].map(lambda x: list(x))

    data = diag_pd.merge(med_pd, on=visit_key, how='inner')
    data = data.merge(lab_pd, on=visit_key, how='inner')
    data = data.merge(ad_pd, on=visit_key, how='inner')

    # ad_pd contributes one row per unit stay; keep only the first row per hospital stay
    data = data.drop(columns=['patientunitstayid'])
    data = data.drop_duplicates(subset=visit_key)

    # Sort by patient *and* hospital stay. Sorting on uniquepid alone used the
    # default (unstable) quicksort, which may reorder a patient's stays.
    # NOTE(review): patienthealthsystemstayid order is not verified to be chronological.
    data = data.sort_values(by=visit_key)
    data = data.rename(columns={'uniquepid': 'SUBJECT_ID', 'patienthealthsystemstayid': 'HADM_ID',
                                'icd9code': 'ICD9_CODE', 'drughiclseqno': 'NDC', 'lab_test': 'LAB_TEST',
                                'gender': 'GENDER', 'age': 'AGE', 'death': 'DEATH',
                                'unabridgedhosplos': 'STAY_DAYS'})
    return data.reset_index(drop=True)

In [20]:
def statistics(data):
    """Print summary counts for a per-visit dataset.

    Expects columns SUBJECT_ID, ICD9_CODE, NDC and LAB_TEST. Each row is one
    visit, and each code column holds an iterable of codes.

    NOTE(review): each "#avg of ..." line sums, over *patients*, the number of
    distinct codes per patient, then divides by the number of *visits*. The
    result is neither a per-visit nor a per-patient mean. It is kept unchanged
    so printed numbers stay comparable with earlier runs. Procedures are not
    part of this dataset, so the procedure lines always print 0.
    """
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))

    def distinct(column):
        # union of all codes across an iterable-valued column
        return set(code for codes in column for code in codes)

    print('#diagnosis ', len(distinct(data['ICD9_CODE'].values)))
    print('#med ', len(distinct(data['NDC'].values)))
    print('#lab', len(distinct(data['LAB_TEST'].values)))

    code_columns = {'diag': 'ICD9_CODE', 'med': 'NDC', 'lab': 'LAB_TEST'}
    sums = dict.fromkeys(code_columns, 0)
    maxes = dict.fromkeys(code_columns, 0)
    avg_pro = max_pro = 0
    cnt = max_visit = avg_visit = 0

    # One groupby pass instead of re-masking the whole frame for every patient
    # (which was O(patients x rows)).
    for _, visits in data.groupby('SUBJECT_ID', sort=False):
        visit_cnt = len(visits)
        cnt += visit_cnt
        avg_visit += visit_cnt
        max_visit = max(max_visit, visit_cnt)
        for key, column in code_columns.items():
            n_codes = len(distinct(visits[column]))
            sums[key] += n_codes
            maxes[key] = max(maxes[key], n_codes)

    print('#avg of diagnoses ', sums['diag']/ cnt)
    print('#avg of medicines ', sums['med']/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of lab_test ', sums['lab']/ cnt)
    print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', maxes['diag'])
    print('#max of medicines ', maxes['med'])
    print('#max of procedures ', max_pro)
    print('#max of lab_test ', maxes['lab'])
    print('#max of visit ', max_visit)

In [41]:
# Run the full preprocessing pipeline (med, diag, admission, lab) into one frame.
data = process_all()

process_med


  med_pd = process_med()


process_diag
process_ad
process_lab


In [42]:
data.iloc[:10]  # first ten hospital stays

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,LAB_TEST,GENDER,AGE,STAY_DAYS,DEATH
0,002-10018,178200,"[586, 995.90, 038.9, 458.9]","[NDC_5175, NDC_1777, NDC_525, NDC_1874, NDC_35...","[2-1.0, 4-0.0, 52-4.0, 145-3.0, 44-1.0, 134-4....",Female,AGE_6,16.5083,False
1,002-10034,141169,[573.9],"[NDC_21772, NDC_1744, NDC_1447, NDC_3937, NDC_...","[110-0.0, 145-3.0, 3-3.0, 106-4.0, 107-1.0, 11...",Female,AGE_5,5.2104,False
2,002-10052,137239,"[518.82, 486, 491.20, 584.9, 790.6, 289.9, 785...","[NDC_12383, NDC_8738, NDC_8255, NDC_36346, NDC...","[157-4.0, 110-3.0, 93-2.0, 82-2.0, 148-2.0, 11...",Female,AGE_14,4.4292,False
3,002-10066,185872,[585.6],"[NDC_19078, NDC_926, NDC_585, NDC_13209, NDC_2...","[1-0.0, 110-0.0, 134-1.0, 33-4.0, 82-0.0, 133-...",Female,AGE_9,5.8722,False
4,002-10079,136669,"[275.41, 458.9, 008.45, 785.52, 197.0, 276.2, ...","[NDC_1866, NDC_8255, NDC_6306, NDC_926, NDC_12...","[116-2.0, 110-0.0, 125-0.0, 146-1.0, 15-0.0, 1...",Female,AGE_13,6.4396,True
5,002-1010,154941,"[287.5, 578.9]","[NDC_1628, NDC_11249, NDC_6306, NDC_33598, NDC...","[128-0.0, 110-0.0, 97-0.0, 43-1.0, 112-1.0, 72...",Female,AGE_14,5.0549,False
6,002-1012,162659,"[560.81, 427.31, 401.9, 560.9, 511.9]","[NDC_1730, NDC_549, NDC_1866, NDC_21772, NDC_1...","[135-0.0, 110-2.0, 63-0.0, 150-3.0, 128-2.0, 1...",Female,AGE_17,27.2063,False
7,002-10122,140376,"[518.81, 780.57, 401.9]","[NDC_8255, NDC_2102, NDC_1866, NDC_18979, NDC_...","[44-2.0, 111-3.0, 43-3.0, 62-0.0, 148-2.0, 63-...",Female,AGE_12,3.0403,False
8,002-10148,172762,"[162.9, 427.31, 414.00, 584.9, 276.2, 518.81]","[NDC_926, NDC_807, NDC_19078, NDC_540, NDC_585...","[110-3.0, 130-1.0, 132-2.0, 131-3.0, 17-3.0, 3...",Female,AGE_15,14.0,False
9,002-1015,176710,"[458.9, 571.5, 287.5, 851.80, 852.20, 852.00]","[NDC_1866, NDC_915, NDC_549, NDC_19078, NDC_80...","[43-3.0, 146-4.0, 110-2.0, 64-2.0, 63-3.0, 44-...",Female,AGE_17,7.7938,False


In [21]:
# Print dataset-level summary counts for the processed data.
statistics(data)

#patients  (86290,)
#clinical events  100839
#diagnosis  903
#med  2044
#lab 756
#avg of diagnoses  3.3438848064736857
#avg of medicines  24.03591864258868
#avg of procedures  0.0
#avg of lab_test  39.778617400013886
#avg of vists  1.168605863947155
#max of diagnoses  85
#max of medicines  194
#max of procedures  0
#max of lab_test  236
#max of visit  24


In [43]:
# Persist the processed dataset. Create the output directory first so a fresh
# checkout does not fail with FileNotFoundError.
os.makedirs("./dataset", exist_ok=True)
with open("./dataset/eicu.pkl", "wb") as outfile:
    pickle.dump(data, outfile)

# Dataset split

In [44]:
# Reload the processed dataset. `with` closes the handle deterministically;
# the bare `open(...)` passed to pickle.load left it open.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open("./dataset/eicu.pkl", 'rb') as infile:
    eicu_data = pickle.load(infile)

In [45]:
eicu_data.keys()  # DataFrame.keys() is the column Index

Index(['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'NDC', 'LAB_TEST', 'GENDER',
       'AGE', 'STAY_DAYS', 'DEATH'],
      dtype='object')

In [46]:
eicu_data.iloc[:10]  # first ten rows of the reloaded dataset

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,LAB_TEST,GENDER,AGE,STAY_DAYS,DEATH
0,002-10018,178200,"[586, 995.90, 038.9, 458.9]","[NDC_5175, NDC_1777, NDC_525, NDC_1874, NDC_35...","[2-1.0, 4-0.0, 52-4.0, 145-3.0, 44-1.0, 134-4....",Female,AGE_6,16.5083,False
1,002-10034,141169,[573.9],"[NDC_21772, NDC_1744, NDC_1447, NDC_3937, NDC_...","[110-0.0, 145-3.0, 3-3.0, 106-4.0, 107-1.0, 11...",Female,AGE_5,5.2104,False
2,002-10052,137239,"[518.82, 486, 491.20, 584.9, 790.6, 289.9, 785...","[NDC_12383, NDC_8738, NDC_8255, NDC_36346, NDC...","[157-4.0, 110-3.0, 93-2.0, 82-2.0, 148-2.0, 11...",Female,AGE_14,4.4292,False
3,002-10066,185872,[585.6],"[NDC_19078, NDC_926, NDC_585, NDC_13209, NDC_2...","[1-0.0, 110-0.0, 134-1.0, 33-4.0, 82-0.0, 133-...",Female,AGE_9,5.8722,False
4,002-10079,136669,"[275.41, 458.9, 008.45, 785.52, 197.0, 276.2, ...","[NDC_1866, NDC_8255, NDC_6306, NDC_926, NDC_12...","[116-2.0, 110-0.0, 125-0.0, 146-1.0, 15-0.0, 1...",Female,AGE_13,6.4396,True
5,002-1010,154941,"[287.5, 578.9]","[NDC_1628, NDC_11249, NDC_6306, NDC_33598, NDC...","[128-0.0, 110-0.0, 97-0.0, 43-1.0, 112-1.0, 72...",Female,AGE_14,5.0549,False
6,002-1012,162659,"[560.81, 427.31, 401.9, 560.9, 511.9]","[NDC_1730, NDC_549, NDC_1866, NDC_21772, NDC_1...","[135-0.0, 110-2.0, 63-0.0, 150-3.0, 128-2.0, 1...",Female,AGE_17,27.2063,False
7,002-10122,140376,"[518.81, 780.57, 401.9]","[NDC_8255, NDC_2102, NDC_1866, NDC_18979, NDC_...","[44-2.0, 111-3.0, 43-3.0, 62-0.0, 148-2.0, 63-...",Female,AGE_12,3.0403,False
8,002-10148,172762,"[162.9, 427.31, 414.00, 584.9, 276.2, 518.81]","[NDC_926, NDC_807, NDC_19078, NDC_540, NDC_585...","[110-3.0, 130-1.0, 132-2.0, 131-3.0, 17-3.0, 3...",Female,AGE_15,14.0,False
9,002-1015,176710,"[458.9, 571.5, 287.5, 851.80, 852.20, 852.00]","[NDC_1866, NDC_915, NDC_549, NDC_19078, NDC_80...","[43-3.0, 146-4.0, 110-2.0, 64-2.0, 63-3.0, 44-...",Female,AGE_17,7.7938,False


In [47]:
eicu_data["SUBJECT_ID"].nunique(dropna=False)  # number of distinct patients

86290

In [48]:
# Patients with at least one hospital stay flagged DEATH.
pat_death = set(eicu_data.loc[eicu_data["DEATH"], "SUBJECT_ID"].tolist())
print(len(pat_death))
# Only the death label is in use; these patients form the labelled set.
pat_all_label = list(pat_death)

9835


In [49]:
pat_all = pd.unique(eicu_data["SUBJECT_ID"]).tolist()  # every patient id, first-seen order

In [50]:
# Reproducible pretrain/downstream split: seed NumPy's global RNG and sort the
# candidate ids. Iteration order of a set of strings changes between
# interpreter runs (hash randomization), so unsorted ids gave a different split
# even with a fixed seed.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
n_pretrain_patient = int(len(pat_all) * 0.7)
# Pretraining draws only from patients without any positive label.
# NOTE: np.random.choice raises ValueError if fewer unlabelled patients exist
# than n_pretrain_patient.
unlabelled_patient = sorted(set(pat_all) - set(pat_all_label))
pretrain_patient = np.random.choice(unlabelled_patient, n_pretrain_patient, replace=False).tolist()
downstream_patient = sorted(set(pat_all) - set(pretrain_patient))
print(len(pretrain_patient), len(downstream_patient))

60402 25888


In [51]:
pretrain_dataset = eicu_data.loc[eicu_data["SUBJECT_ID"].isin(pretrain_patient)]  # all stays of pretraining patients

In [52]:
# downstream task: PLOS, death prediction
# Split downstream patients 20% fine-tune / 40% eval / remainder test.
n_finetune_pat, n_eval_pat = int(len(downstream_patient) * 0.2), int(len(downstream_patient) * 0.4)
# Shuffle a sorted copy with a dedicated seeded generator. The split is
# reproducible, and re-running this cell does not reshuffle `downstream_patient`
# in place.
shuffled_downstream = sorted(downstream_patient)
np.random.default_rng(42).shuffle(shuffled_downstream)
finetune_pat = shuffled_downstream[:n_finetune_pat]
eval_pat = shuffled_downstream[n_finetune_pat:n_finetune_pat + n_eval_pat]
test_pat = shuffled_downstream[n_finetune_pat + n_eval_pat:]
finetune_dataset = eicu_data[eicu_data["SUBJECT_ID"].isin(set(finetune_pat))]
eval_dataset = eicu_data[eicu_data["SUBJECT_ID"].isin(set(eval_pat))]
test_dataset = eicu_data[eicu_data["SUBJECT_ID"].isin(set(test_pat))]

In [53]:
# Write the pretraining split and the [finetune, eval, test] downstream splits.
for out_path, payload in (
    ("./dataset/eicu_pretrain.pkl", pretrain_dataset),
    ("./dataset/eicu_downstream.pkl", [finetune_dataset, eval_dataset, test_dataset]),
):
    with open(out_path, "wb") as fh:
        pickle.dump(payload, fh)