In [None]:
! pip install -e .

In [1]:
import sys
sys.executable
from heterogt.utils.seed import set_random_seed

In [2]:
RANDOM_SEED = 123
set_random_seed(RANDOM_SEED)

[INFO] Detected Apple MPS backend (Mac with M1/M2/M3). Skipping CUDA seeds.
[INFO] Random seed set to 123


In [3]:
import pandas as pd
import numpy as np
import os
import pickle

In [4]:
data_path = "./MIMIC-III-raw/"
# 'HADM_ID': encounter id
# 'SUBJECT_ID': patient id

In [5]:
med_file = os.path.join(data_path, "PRESCRIPTIONS.csv.gz")
procedure_file = os.path.join(data_path, "PROCEDURES_ICD.csv.gz")
diag_file = os.path.join(data_path, "DIAGNOSES_ICD.csv.gz")
admission_file = os.path.join(data_path, "ADMISSIONS.csv.gz")
lab_test_file = os.path.join(data_path, "LABEVENTS.csv.gz")
patient_file = os.path.join(data_path, "PATIENTS.csv.gz")

In [6]:
# drug code mapping files from GAMENet repo
# https://github.com/sjy1203/GAMENet/tree/master/data
ndc2atc_file = './GAMENet/ndc2atc_level4.csv' 
cid_atc = './GAMENet/drug-atc.csv'
ndc2rxnorm_file = './GAMENet/ndc2rxnorm_mapping.txt'

# Data preprocessing

In [7]:
def process_med():
    med_pd = pd.read_csv(med_file, dtype={'NDC':'category'})
    # filter
    med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
                     'FORMULARY_DRUG_CD','GSN','PROD_STRENGTH','DOSE_VAL_RX',
                     'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP','FORM_UNIT_DISP',
                      'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
    med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
    med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')    
    med_pd = med_pd.reset_index(drop=True)
    
    def filter_first24hour_med(med_pd):
        med_pd_new = med_pd.drop(columns=['NDC'])
        med_pd_new = med_pd_new.groupby(by=['SUBJECT_ID','HADM_ID','ICUSTAY_ID']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','STARTDATE'])
        med_pd_new = med_pd_new.drop(columns=['STARTDATE'])
        return med_pd_new
    
    med_pd = filter_first24hour_med(med_pd) 
    
    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
    med_pd = med_pd.drop_duplicates() 
    return med_pd.reset_index(drop=True)

def process_procedure():
    pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE':'category'})
    pro_pd.drop(columns=['ROW_ID'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
    pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)
    pro_pd['ICD9_CODE'] = 'PRO_' + pro_pd['ICD9_CODE'].astype(str)
    return pro_pd

def process_diag():
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
    diag_pd["ICD9_CODE"] = "DIAG_" + diag_pd["ICD9_CODE"].astype(str)
    return diag_pd.reset_index(drop=True)

def process_admission():
    ad_pd = pd.read_csv(admission_file)
    patient_pd = pd.read_csv(patient_file)
    
    ad_pd.drop(columns=['ROW_ID', 'ADMISSION_LOCATION',
       'DISCHARGE_LOCATION', 'INSURANCE', 'LANGUAGE', 'RELIGION',
       'MARITAL_STATUS', 'ETHNICITY', 'EDREGTIME', 'EDOUTTIME', 'DIAGNOSIS',
       'HOSPITAL_EXPIRE_FLAG', 'HAS_CHARTEVENTS_DATA'], axis=1, inplace=True)
    
    patient_pd.drop(columns=['ROW_ID','DOD','DOD_HOSP','DOD_SSN',"EXPIRE_FLAG"], axis=1, inplace=True)
    
    ad_pd["ADMITTIME"] = pd.to_datetime(ad_pd['ADMITTIME'], format='%Y-%m-%d %H:%M:%S', errors='coerce')
    ad_pd["DISCHTIME"] = pd.to_datetime(ad_pd['DISCHTIME'], format='%Y-%m-%d %H:%M:%S', errors='coerce')  # time for leaving hospital
    ad_pd["STAY_DAYS"] = (ad_pd["DISCHTIME"] - ad_pd["ADMITTIME"]).dt.days
    patient_pd['DOB'] = pd.to_datetime(patient_pd['DOB'], format='%Y-%m-%d %H:%M:%S', errors='coerce')  # birthday
    
    ad_pd = ad_pd.merge(patient_pd, on=['SUBJECT_ID'], how='inner')
    
    # create features: age, death, number of days in this encounter, readmission (next visit)
    ad_pd["AGE"] = ad_pd['ADMITTIME'].dt.year - ad_pd['DOB'].dt.year
    ad_pd[ad_pd["AGE"] >= 300] = 90
    age = ad_pd["AGE"]
    bins = np.linspace(age.min(), age.max(), 20 + 1)
    ad_pd['AGE'] = pd.cut(age, bins=bins, labels=False, include_lowest=True)
    ad_pd['AGE'] = "AGE_" + ad_pd["AGE"].astype("str")
    
    ad_pd["DEATH"] = ad_pd["DEATHTIME"].notna()

    ad_pd['ADMITTIME'] = ad_pd['ADMITTIME'].astype(str)
    ad_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ADMITTIME'], inplace=True)
    ad_pd['next_visit'] = ad_pd.groupby('SUBJECT_ID')['HADM_ID'].shift(-1)
    ad_pd['READMISSION'] = ad_pd['next_visit'].notnull().astype(int)
    ad_pd.drop('next_visit', axis=1, inplace=True)

    ad_pd.drop(columns=['DISCHTIME', 'DOB', 'DEATHTIME'], axis=1, inplace=True)
    ad_pd.drop_duplicates(inplace=True)
    return ad_pd.reset_index(drop=True)

def process_lab_test(n_bins=5):
    lab_pd = pd.read_csv(lab_test_file)
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','ITEMID']).head(1).reset_index(drop=True)  # only consider the first value
    lab_pd = lab_pd[lab_pd["VALUENUM"].notna()]
    lab_pd = lab_pd[lab_pd["HADM_ID"].notna()]
    
    lab_pd.drop(columns=['ROW_ID'], axis=1, inplace=True)
    
    def contains_text(group):
        for item in group:
            try:
                float(item)
            except ValueError:
                return True
        return False

    for itemid in lab_pd['ITEMID'].unique():
        group = lab_pd[lab_pd['ITEMID'] == itemid]['VALUE']

        # if the lab test contains text value then directly copy the value
        if contains_text(group):
            lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = group
        else:
            # value->numeric
            values_numeric = pd.to_numeric(group, errors='coerce')

            if len(values_numeric.dropna()) < n_bins:
                lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = group
            else:
                lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = pd.qcut(values_numeric, q=n_bins, labels=False, duplicates='drop')
        
    lab_pd["ITEMID"] = lab_pd["ITEMID"].astype(str)
    lab_pd["value_bin"] = lab_pd["value_bin"].astype(str)
    lab_pd["LAB_TEST"] = lab_pd[["ITEMID", "value_bin"]].apply("-".join, axis=1)
    
    lab_pd.drop(columns=['CHARTTIME', 'VALUE', 'VALUENUM', 'VALUEUOM', 'FLAG', 'value_bin', 'ITEMID'], axis=1, inplace=True)
    lab_pd.drop_duplicates(inplace=True)
    lab_pd.reset_index(drop=True, inplace=True)
    lab_pd['LAB_TEST'] = "LAB_" + lab_pd["LAB_TEST"].astype(str)
    return lab_pd

In [8]:
def ndc2atc4(med_pd):
    with open(ndc2rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)
    
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4':'NDC'})
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()
    med_pd["NDC"] = "MED_" + med_pd["NDC"].astype(str)    
    return med_pd.reset_index(drop=True)

def filter_most_pro(pro_pd, threshold=800):
    pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:threshold, 'ICD9_CODE'])]
    return pro_pd.reset_index(drop=True)

def filter_most_diag(diag_pd, threshold=2000):
    diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:threshold, 'ICD9_CODE'])]
    return diag_pd.reset_index(drop=True)

def filter_most_lab(lab_pd, threshold=1500):
    lab_count = lab_pd.groupby(by=['LAB_TEST']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    lab_pd = lab_pd[lab_pd['LAB_TEST'].isin(lab_count.loc[:threshold, 'LAB_TEST'])]
    return lab_pd.reset_index(drop=True)

In [9]:
def remove_high_visit_patients(df: pd.DataFrame, visit_threshold: int = 8) -> pd.DataFrame:
    """
    Remove all rows corresponding to patients with more than `visit_threshold` unique visits.

    Parameters:
        df (pd.DataFrame): The input EHR dataframe. Must contain 'SUBJECT_ID' and 'HADM_ID'.
        visit_threshold (int): The maximum number of visits allowed per patient.

    Returns:
        pd.DataFrame: A filtered dataframe containing only patients with ≤ visit_threshold visits.
    """
    # Step 1: Count visits per patient
    visit_counts = df.groupby("SUBJECT_ID")["HADM_ID"].nunique()

    # Step 2: Identify valid SUBJECT_IDs
    valid_subjects = visit_counts[visit_counts <= visit_threshold].index

    # Step 3: Filter the dataframe
    filtered_df = df[df["SUBJECT_ID"].isin(valid_subjects)].copy()

    # Optional: logging
    num_removed_patients = df["SUBJECT_ID"].nunique() - filtered_df["SUBJECT_ID"].nunique()
    num_removed_rows = len(df) - len(filtered_df)
    print(f"Removed {num_removed_patients} patients and {num_removed_rows} rows exceeding {visit_threshold} visits.")

    return filtered_df

In [10]:
from datetime import timedelta

def process_all():
    print('process_med')
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)

    print('process_diag')
    diag_pd = process_diag()
    diag_pd = filter_most_diag(diag_pd)

    print('process_pro')
    pro_pd = process_procedure()
    pro_pd = filter_most_pro(pro_pd)

    print('process_ad')
    ad_pd = process_admission()
    
    print('process_lab')
    lab_pd = process_lab_test()
    lab_pd = filter_most_lab(lab_pd)
    
    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    lab_pd_key = lab_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    ad_pd_key = ad_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    
    # filter key
    combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(lab_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(ad_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    lab_pd = lab_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    ad_pd = ad_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['LAB_TEST'].unique().reset_index()
    
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
    pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
    lab_pd['LAB_TEST'] = lab_pd['LAB_TEST'].map(lambda x: list(x))
    
    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(lab_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(ad_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    data = data.sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
    
    # filter rows without disease codes
    data = data[data['ICD9_CODE'].notnull()]

    # create feature: readmission within 30/90 days
    data['ADMITTIME'] = pd.to_datetime(data['ADMITTIME'])
    data['READMISSION_1M'] = data.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1) - data['ADMITTIME']
    data['READMISSION_3M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=90) else 0)
    data['READMISSION_1M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=30) else 0)
    
    # create feature: diease in next 6/12 months    
    data['NEXT_DIAG_6M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                              (data['ADMITTIME'] > x['ADMITTIME']) & 
                                              (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=180))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_12M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                                   (data['ADMITTIME'] > x['ADMITTIME']) & 
                                                   (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=365))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_6M'] = data['NEXT_DIAG_6M'].apply(lambda x: x[0] if x else float('nan'))
    data['NEXT_DIAG_12M'] = data['NEXT_DIAG_12M'].apply(lambda x: x[0] if x else float('nan'))
    
    data.drop(columns=['ADMITTIME'], axis=1, inplace=True)
    data = remove_high_visit_patients(data, visit_threshold=8)
    return data.reset_index(drop=True)

In [11]:
def statistics(data):
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))
    
    diag = data['ICD9_CODE'].values
    med = data['NDC'].values
    pro = data['PRO_CODE'].values
    lab_test = data['LAB_TEST'].values
    
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    unique_lab = set([j for i in lab_test for j in list(i)])
    
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    print('#lab', len(unique_lab))
    
    avg_diag = avg_med = avg_pro = avg_lab = 0
    max_diag = max_med = max_pro = max_lab = 0
    cnt = max_visit = avg_visit = 0

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x, y, z, k = [], [], [], []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
            k.extend(list(row['LAB_TEST']))
        x, y, z, k = set(x), set(y), set(z), set(k)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_lab += len(k)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if len(k) > max_lab:
            max_lab = len(k)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    print('#avg of diagnoses ', avg_diag/ cnt)
    print('#avg of medicines ', avg_med/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of lab_test ', avg_lab/ cnt)
    print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of lab_test ', max_lab)
    print('#max of visit ', max_visit)

In [12]:
data = process_all()

process_med


  med_pd = pd.read_csv(med_file, dtype={'NDC':'category'})
  med_pd.fillna(method='pad', inplace=True)


process_diag
process_pro
process_ad


  ad_pd[ad_pd["AGE"] >= 300] = 90
  ad_pd[ad_pd["AGE"] >= 300] = 90
  ad_pd[ad_pd["AGE"] >= 300] = 90


process_lab


  lab_pd.loc[lab_pd['ITEMID'] == itemid, 'value_bin'] = group


Removed 20 patients and 212 rows exceeding 8 visits.


In [13]:
statistics(data)

#patients  (33067,)
#clinical events  39976
#diagnosis  1998
#med  145
#procedure 801
#lab 1500
#avg of diagnoses  10.831999199519712
#avg of medicines  7.820442265359215
#avg of procedures  4.481113668200921
#avg of lab_test  41.873649189513706
#avg of vists  1.2089394260138506
#max of diagnoses  105
#max of medicines  50
#max of procedures  48
#max of lab_test  169
#max of visit  8


In [14]:
with open("./MIMIC-III-processed/mimic.pkl", "wb") as outfile:
    pickle.dump(data, outfile)

# Dataset split

In [15]:
mimic_data = pickle.load(open("./MIMIC-III-processed/mimic.pkl", 'rb'))

In [16]:
len(mimic_data["SUBJECT_ID"].unique())

33067

In [17]:
# get the patient id with the labels
pat_readmission = set(mimic_data[mimic_data["READMISSION_1M"] == 1]["SUBJECT_ID"].values.tolist())
print(len(pat_readmission))
pat_nextdiag_6m = set(mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_6m))
# note that here we extract at the patient level, not the encounter level
# so patients with at least one encounter with a next diagnosis in 6 / 12 months will still have
# some encounters without a next diagnosis. This part is addressed in the HBERTFinetuneEHRDataset.
pat_nextdiag_12m = set(mimic_data[mimic_data["NEXT_DIAG_12M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_12m))
pat_death = set(mimic_data[mimic_data["DEATH"]]["SUBJECT_ID"].values.tolist())
print(len(pat_death))
pat_all_label = list(pat_readmission | pat_nextdiag_6m | pat_nextdiag_12m | pat_death)

1086
2904
3438
4367


In [18]:
pat_all = mimic_data["SUBJECT_ID"].unique().tolist()

In [19]:
n_pretrain_patient = int(len(pat_all) * 0.7)
np.random.seed(RANDOM_SEED)
pretrain_patient = np.random.choice(list(set(pat_all) - set(pat_all_label)), n_pretrain_patient, replace=False).tolist()
downstream_patient = list(set(pat_all) - set(pretrain_patient))
print(len(pretrain_patient), len(downstream_patient))

23146 9921


In [20]:
pretrain_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(pretrain_patient))]

In [21]:
# downstream task: PLOS, death prediction, readmission
train_ratio, val_ratio = 0.2, 0.4
n_finetune_pat, n_val_pat = int(len(downstream_patient) * train_ratio), int(len(downstream_patient) * val_ratio)
downstream_patient = sorted(downstream_patient) 
np.random.seed(RANDOM_SEED + 1)
np.random.shuffle(downstream_patient)
finetune_pat, val_pat, test_pat = downstream_patient[:n_finetune_pat], \
                                    downstream_patient[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    downstream_patient[n_finetune_pat+n_val_pat:]
finetune_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [None]:
# downstream task: next diagnosis prediction
# only use the patient that has multiple visits
# 6m
train_ratio, val_ratio = 0.4, 0.3
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_6m) * train_ratio), int(len(pat_nextdiag_6m) * val_ratio)
pat_nextdiag_6m = list(pat_nextdiag_6m)
pat_nextdiag_6m = sorted(list(pat_nextdiag_6m)) 
np.random.seed(RANDOM_SEED + 2) 
np.random.shuffle(pat_nextdiag_6m)
finetune_pat, val_pat, test_pat = pat_nextdiag_6m[:n_finetune_pat], \
                                    pat_nextdiag_6m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_6m[n_finetune_pat+n_val_pat:]
finetune_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]
# 12m
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_12m) * train_ratio), int(len(pat_nextdiag_12m) * val_ratio)
pat_nextdiag_12m = list(pat_nextdiag_12m)
pat_nextdiag_12m = sorted(list(pat_nextdiag_12m))
np.random.seed(RANDOM_SEED + 3) 
np.random.shuffle(pat_nextdiag_12m)
finetune_pat, val_pat, test_pat = pat_nextdiag_12m[:n_finetune_pat], \
                                    pat_nextdiag_12m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_12m[n_finetune_pat+n_val_pat:]
finetune_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [23]:
with open("./MIMIC-III-processed/mimic_pretrain.pkl", "wb") as outfile:
    pickle.dump(pretrain_dataset, outfile)
with open("./MIMIC-III-processed/mimic_downstream.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset, val_dataset, test_dataset], outfile)
with open("./MIMIC-III-processed/mimic_nextdiag_6m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset6m, val_dataset6m, test_dataset6m], outfile)
with open("./MIMIC-III-processed/mimic_nextdiag_12m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset12m, val_dataset12m, test_dataset12m], outfile)