In [None]:
! pip install -e .

In [1]:
import sys
sys.executable
from heterogt.utils.seed import set_random_seed

In [2]:
RANDOM_SEED = 123
set_random_seed(RANDOM_SEED)

[INFO] Detected Apple MPS backend (Mac with M1/M2/M3). Skipping CUDA seeds.
[INFO] Random seed set to 123


In [3]:
import pandas as pd
import numpy as np
import os
import pickle
from icdmappings import Mapper

In [4]:
data_path = "./MIMIC-IV-raw/"
# 'HADM_ID': encounter id
# 'SUBJECT_ID': patient id

In [5]:
med_file = os.path.join(data_path, "prescriptions.csv.gz")
procedure_file = os.path.join(data_path, "procedures_icd.csv.gz")
diag_file = os.path.join(data_path, "diagnoses_icd.csv.gz")
admission_file = os.path.join(data_path, "admissions.csv.gz")
lab_test_file = os.path.join(data_path, "labevents.csv.gz")
patient_file = os.path.join(data_path, "patients.csv.gz")

In [6]:
# drug code mapping files from GAMENet repo
# https://github.com/sjy1203/GAMENet/tree/master/data
ndc2atc_file = './GAMENet/ndc2atc_level4.csv' 
cid_atc = './GAMENet/drug-atc.csv'
ndc2rxnorm_file = './GAMENet/ndc2rxnorm_mapping.txt'

In [7]:
def process_med():
    med_pd = pd.read_csv(med_file, dtype={'ndc':'category'})
    med_pd.drop(columns=['pharmacy_id','poe_id','poe_seq','order_provider_id',
                     'stoptime','drug_type','drug','formulary_drug_cd',
                     'gsn','prod_strength','form_rx','dose_val_rx',
                      'dose_unit_rx','form_val_disp','form_unit_disp', 
                      'doses_per_24_hrs', 'route'], axis=1, inplace=True)
    med_pd.drop(index = med_pd[med_pd['ndc'] == '0'].index, axis=0, inplace=True)
    med_pd.sort_values(by=['subject_id', 'hadm_id', 'starttime'], inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['starttime'] = pd.to_datetime(med_pd['starttime'], format='%Y-%m-%d %H:%M:%S')    
    med_pd = med_pd.reset_index(drop=True)  
    
    def filter_first24hour_med(med_pd):
        med_pd_new = med_pd.drop(columns=['ndc'])
        med_pd_new = med_pd_new.groupby(by=['subject_id','hadm_id']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['subject_id','hadm_id', 'starttime'])
        med_pd_new = med_pd_new.drop(columns=['starttime'])
        return med_pd_new
        
    med_pd = filter_first24hour_med(med_pd) 
    med_pd = med_pd.drop_duplicates() 
    med_pd = med_pd.reset_index(drop=True)
    med_pd.rename(columns={'subject_id': 'SUBJECT_ID', 'hadm_id': 'HADM_ID', 'ndc': 'NDC'}, inplace=True)
    return med_pd

In [8]:
def convert_icd_code(pd):
    mapper = Mapper()
    mask_icd10 = pd['icd_version'] == 10
    icd10_codes = pd.loc[mask_icd10, 'icd_code']
    # Vectorized mapping: returns None if mapping fails
    def safe_map(x):
        try:
            mapped = mapper.map(x, source='icd10', target='icd9')
            return mapped if mapped else None
        except:
            return None

    icd9_mapped = icd10_codes.apply(safe_map)
    pd.loc[mask_icd10, 'icd_code'] = icd9_mapped
    pd.dropna(subset=['icd_code'], inplace=True)
    return pd

In [9]:
def process_procedure():
    pro_pd = pd.read_csv(procedure_file, dtype={'icd_code': 'str'})
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['subject_id', 'hadm_id', 'seq_num'], inplace=True)
    pro_pd.drop(columns=['seq_num', 'chartdate'], errors='ignore', inplace=True)

    # Map ICD-10 to ICD-9
    pro_pd = convert_icd_code(pro_pd)

    # Remove duplicates and reset index
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)

    # Convert to category if needed
    pro_pd['icd_code'] = pro_pd['icd_code'].astype('category')
    pro_pd.drop(columns=['icd_version'], errors='ignore', inplace=True)
    pro_pd.rename(columns={'subject_id': 'SUBJECT_ID', 'hadm_id': 'HADM_ID', 'icd_code': 'ICD9_CODE'}, inplace=True)
    pro_pd['ICD9_CODE'] = 'PRO_' + pro_pd['ICD9_CODE'].astype(str)
    return pro_pd

In [10]:
def process_diag():
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    diag_pd.sort_values(by=['subject_id', 'hadm_id', 'seq_num'], inplace=True)
    diag_pd.drop(columns=['seq_num'], inplace=True)
    diag_pd = convert_icd_code(diag_pd)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd = diag_pd.reset_index(drop=True)
    diag_pd.drop(columns=['icd_version'], errors='ignore', inplace=True)
    diag_pd.rename(columns={'subject_id': 'SUBJECT_ID', 'hadm_id': 'HADM_ID', 'icd_code': 'ICD9_CODE'}, inplace=True)
    diag_pd["ICD9_CODE"] = "DIAG_" + diag_pd["ICD9_CODE"].astype(str)
    return diag_pd

In [11]:
def process_admission():
    ad_pd = pd.read_csv(admission_file)
    patient_pd = pd.read_csv(patient_file)
    ad_pd.drop(columns=['admit_provider_id', 'admission_location',
       'discharge_location', 'insurance', 'language', 'marital_status',
       'race', 'edregtime', 'edouttime', 'hospital_expire_flag'], axis=1, inplace=True)
    patient_pd.drop(columns=['anchor_year','anchor_year_group','dod'], axis=1, inplace=True)
    ad_pd["admittime"] = pd.to_datetime(ad_pd['admittime'], format='%Y-%m-%d %H:%M:%S', errors='coerce')
    ad_pd["dischtime"] = pd.to_datetime(ad_pd['dischtime'], format='%Y-%m-%d %H:%M:%S', errors='coerce')  # time for leaving hospital
    ad_pd["deathtime"] = pd.to_datetime(ad_pd['deathtime'], format='%Y-%m-%d %H:%M:%S', errors='coerce')  # time for leaving hospital
    ad_pd["STAY_DAYS"] = (ad_pd["dischtime"] - ad_pd["admittime"]).dt.days
    ad_pd = ad_pd.merge(patient_pd, on=['subject_id'], how='inner')
    ad_pd.loc[ad_pd["anchor_age"] >= 300, "anchor_age"] = 90
    age = ad_pd["anchor_age"]
    bins = np.linspace(age.min(), age.max(), 20 + 1)
    ad_pd['anchor_age'] = pd.cut(age, bins=bins, labels=False, include_lowest=True)
    ad_pd['anchor_age'] = "AGE_" + ad_pd["anchor_age"].astype("str")
    ad_pd["DEATH"] = ad_pd["deathtime"].notna()
    ad_pd['admittime'] = ad_pd['admittime'].astype(str)
    ad_pd.sort_values(by=['subject_id', 'hadm_id', 'admittime'], inplace=True)
    ad_pd['next_visit'] = ad_pd.groupby('subject_id')['hadm_id'].shift(-1)
    ad_pd['READMISSION'] = ad_pd['next_visit'].notnull().astype(int)
    ad_pd.drop('next_visit', axis=1, inplace=True)
    ad_pd.drop(columns=['dischtime', 'deathtime'], axis=1, inplace=True)
    ad_pd.drop_duplicates(inplace=True)
    ad_pd.reset_index(drop=True, inplace=True)
    ad_pd.rename(columns={'subject_id': 'SUBJECT_ID', 'hadm_id': 'HADM_ID', 'admittime': 'ADMITTIME',
                          'gender': 'GENDER', 'anchor_age': 'AGE', 'admission_type': 'ADMISSION_TYPE'}, inplace=True)
    return ad_pd


In [12]:
def process_lab_test(n_bins=5):
    lab_pd = pd.read_csv(lab_test_file)
    # 每个(subject_id, itemid)只取第一次
    lab_pd = lab_pd.groupby(by=['subject_id','itemid']).head(1).reset_index(drop=True)
    # 仅保留有数值和有住院ID的记录（否则无法对 '___' 行用 valuenum）
    lab_pd = lab_pd[lab_pd["valuenum"].notna()]
    lab_pd = lab_pd[lab_pd["hadm_id"].notna()]
    # 某些版本可能无此列，安全起见用 errors='ignore'
    lab_pd.drop(columns=['labevent_id'], axis=1, inplace=True, errors='ignore')

    # 主循环：按 itemid 批处理
    for itemid in lab_pd['itemid'].unique():
        group_idx = (lab_pd['itemid'] == itemid)

        group_value = lab_pd.loc[group_idx, 'value'].astype(str)
        group_num_vn = pd.to_numeric(lab_pd.loc[group_idx, 'valuenum'], errors='coerce')  # 来自 valuenum 列
        group_num_val = pd.to_numeric(group_value, errors='coerce')                       # 来自 value 列的“数值字符串”

        # 三类掩码（互斥并覆盖全体）
        mask_underscore = (group_value == '___')                                  # 需用 valuenum
        mask_text_nonnum = group_value.notna() & (group_num_val.isna()) & (~mask_underscore)  # 非'___'文本
        mask_numeric_in_value = group_num_val.notna() & (~mask_underscore)        # value 为数值字符串

        # 初始化：先把结果设为原 value（确保“非'___'文本”能保留）
        lab_pd.loc[group_idx, 'value_bin'] = group_value

        # A) 处理 '___'：使用 valuenum（对属于 '___' 的子集分箱）
        if mask_underscore.any():
            vn_subset = group_num_vn[mask_underscore]
            if vn_subset.notna().sum() >= n_bins:
                lab_pd.loc[group_idx & mask_underscore, 'value_bin'] = pd.qcut(
                    vn_subset, q=n_bins, labels=False, duplicates='drop'
                ).astype('Int64').astype(str)
            else:
                # 仍然“使用 valuenum”，直接写回数值字符串
                lab_pd.loc[group_idx & mask_underscore, 'value_bin'] = vn_subset.astype('Float64').astype(str)

        # B) 处理数值字符串：仅对这些行分箱，保留其他行为原值
        if mask_numeric_in_value.any():
            val_subset = group_num_val[mask_numeric_in_value]
            if val_subset.notna().sum() >= n_bins:
                lab_pd.loc[group_idx & mask_numeric_in_value, 'value_bin'] = pd.qcut(
                    val_subset, q=n_bins, labels=False, duplicates='drop'
                ).astype('Int64').astype(str)
            else:
                # 数量不足：保留原 value（已在初始化时完成），无需额外操作
                pass

        # C) 非'___'文本行 mask_text_nonnum：已在初始化中保留，无需额外处理

    # 统一与收尾
    lab_pd["itemid"] = lab_pd["itemid"].astype(str)
    lab_pd["value_bin"] = lab_pd["value_bin"].astype(str)
    lab_pd["LAB_TEST"] = lab_pd[["itemid", "value_bin"]].apply("-".join, axis=1)

    lab_pd.drop(columns=[
        'specimen_id', 'order_provider_id', 'itemid', 'charttime', 'storetime', 'value', 'valuenum', 'valueuom',
        'ref_range_lower', 'ref_range_upper', 'flag', 'priority', 'comments', 'value_bin'
    ], axis=1, inplace=True, errors='ignore')

    lab_pd.drop_duplicates(inplace=True)
    lab_pd.reset_index(drop=True, inplace=True)
    lab_pd.rename(columns={'subject_id': 'SUBJECT_ID', 'hadm_id': 'HADM_ID'}, inplace=True)
    lab_pd['HADM_ID'] = lab_pd['HADM_ID'].astype('int64')
    lab_pd['LAB_TEST'] = "LAB_" + lab_pd["LAB_TEST"].astype(str)
    return lab_pd

In [13]:
def ndc2atc4(med_pd):
    with open(ndc2rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)
    
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4':'NDC'})
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()
    med_pd["NDC"] = "MED_" + med_pd["NDC"].astype(str)    
    return med_pd.reset_index(drop=True)

def filter_most_pro(pro_pd, threshold=800):
    pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:threshold, 'ICD9_CODE'])]
    return pro_pd.reset_index(drop=True)

def filter_most_diag(diag_pd, threshold=2000):
    diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:threshold, 'ICD9_CODE'])]
    return diag_pd.reset_index(drop=True)

def filter_most_lab(lab_pd, threshold=1500):
    lab_count = lab_pd.groupby(by=['LAB_TEST']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    lab_pd = lab_pd[lab_pd['LAB_TEST'].isin(lab_count.loc[:threshold, 'LAB_TEST'])]
    return lab_pd.reset_index(drop=True)

In [14]:
def remove_high_visit_patients(df: pd.DataFrame, visit_threshold: int = 8) -> pd.DataFrame:
    """
    Remove all rows corresponding to patients with more than `visit_threshold` unique visits.

    Parameters:
        df (pd.DataFrame): The input EHR dataframe. Must contain 'SUBJECT_ID' and 'HADM_ID'.
        visit_threshold (int): The maximum number of visits allowed per patient.

    Returns:
        pd.DataFrame: A filtered dataframe containing only patients with ≤ visit_threshold visits.
    """
    # Step 1: Count visits per patient
    visit_counts = df.groupby("SUBJECT_ID")["HADM_ID"].nunique()

    # Step 2: Identify valid SUBJECT_IDs
    valid_subjects = visit_counts[visit_counts <= visit_threshold].index

    # Step 3: Filter the dataframe
    filtered_df = df[df["SUBJECT_ID"].isin(valid_subjects)].copy()

    # Optional: logging
    num_removed_patients = df["SUBJECT_ID"].nunique() - filtered_df["SUBJECT_ID"].nunique()
    num_removed_rows = len(df) - len(filtered_df)
    print(f"Removed {num_removed_patients} patients and {num_removed_rows} rows exceeding {visit_threshold} visits.")

    return filtered_df

In [15]:
from datetime import timedelta

def process_all():
    print('process_med')
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)

    print('process_diag')
    diag_pd = process_diag()
    diag_pd = filter_most_diag(diag_pd)

    print('process_pro')
    pro_pd = process_procedure()
    pro_pd = filter_most_pro(pro_pd)

    print('process_ad')
    ad_pd = process_admission()
    
    print('process_lab')
    lab_pd = process_lab_test()
    lab_pd = filter_most_lab(lab_pd)
    
    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    lab_pd_key = lab_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    ad_pd_key = ad_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    
    # filter key
    combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(lab_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(ad_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    lab_pd = lab_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    ad_pd = ad_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['LAB_TEST'].unique().reset_index()
    
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
    pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
    lab_pd['LAB_TEST'] = lab_pd['LAB_TEST'].map(lambda x: list(x))
    
    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(lab_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(ad_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    data = data.sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
    
    # create feature: readmission within 30/90 days
    data['ADMITTIME'] = pd.to_datetime(data['ADMITTIME'])
    data['READMISSION_1M'] = data.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1) - data['ADMITTIME']
    data['READMISSION_3M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=90) else 0)
    data['READMISSION_1M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=30) else 0)
    
    # create feature: diease in next 6/12 months    
    data['NEXT_DIAG_6M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                              (data['ADMITTIME'] > x['ADMITTIME']) & 
                                              (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=180))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_12M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                                   (data['ADMITTIME'] > x['ADMITTIME']) & 
                                                   (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=365))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_6M'] = data['NEXT_DIAG_6M'].apply(lambda x: x[0] if x else float('nan'))
    data['NEXT_DIAG_12M'] = data['NEXT_DIAG_12M'].apply(lambda x: x[0] if x else float('nan'))
    
    data.drop(columns=['ADMITTIME'], axis=1, inplace=True)
    data = remove_high_visit_patients(data, visit_threshold=8)
    return data.reset_index(drop=True)

In [16]:
from datetime import timedelta

def process_all():
    print('process_med')
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)

    print('process_diag')
    diag_pd = process_diag()
    diag_pd = filter_most_diag(diag_pd)

    print('process_pro')
    pro_pd = process_procedure()
    pro_pd = filter_most_pro(pro_pd)

    print('process_ad')
    ad_pd = process_admission()
    
    print('process_lab')
    lab_pd = process_lab_test()
    lab_pd = filter_most_lab(lab_pd)
    
    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    lab_pd_key = lab_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    ad_pd_key = ad_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    
    # filter key
    combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(lab_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(ad_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    lab_pd = lab_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    ad_pd = ad_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
    lab_pd = lab_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['LAB_TEST'].unique().reset_index()
    
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
    pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
    lab_pd['LAB_TEST'] = lab_pd['LAB_TEST'].map(lambda x: list(x))
    
    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(lab_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(ad_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    data = data.sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
    
    # filter rows without disease codes
    data = data[data['ICD9_CODE'].notnull()]
    
    # create feature: readmission within 30/90 days
    data['ADMITTIME'] = pd.to_datetime(data['ADMITTIME'])
    data['READMISSION_1M'] = data.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1) - data['ADMITTIME']
    data['READMISSION_3M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=90) else 0)
    data['READMISSION_1M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=30) else 0)
    
    # create feature: diease in next 6/12 months    
    data['NEXT_DIAG_6M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                              (data['ADMITTIME'] > x['ADMITTIME']) & 
                                              (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=180))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_12M'] = data.apply(lambda x: data[(data['SUBJECT_ID'] == x['SUBJECT_ID']) & 
                                                   (data['ADMITTIME'] > x['ADMITTIME']) & 
                                                   (data['ADMITTIME'] <= x['ADMITTIME'] + timedelta(days=365))]['ICD9_CODE'].tolist(), axis=1)
    data['NEXT_DIAG_6M'] = data['NEXT_DIAG_6M'].apply(lambda x: x[0] if x else float('nan'))
    data['NEXT_DIAG_12M'] = data['NEXT_DIAG_12M'].apply(lambda x: x[0] if x else float('nan'))
    
    data.drop(columns=['ADMITTIME'], axis=1, inplace=True)
    data = remove_high_visit_patients(data, visit_threshold=8)
    return data.reset_index(drop=True)

In [17]:
def statistics(data):
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))
    
    diag = data['ICD9_CODE'].values
    med = data['NDC'].values
    pro = data['PRO_CODE'].values
    lab_test = data['LAB_TEST'].values
    
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    unique_lab = set([j for i in lab_test for j in list(i)])
    
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    print('#lab', len(unique_lab))
    
    avg_diag = avg_med = avg_pro = avg_lab = 0
    max_diag = max_med = max_pro = max_lab = 0
    cnt = max_visit = avg_visit = 0

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x, y, z, k = [], [], [], []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
            k.extend(list(row['LAB_TEST']))
        x, y, z, k = set(x), set(y), set(z), set(k)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_lab += len(k)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if len(k) > max_lab:
            max_lab = len(k)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    print('#avg of diagnoses ', avg_diag/ cnt)
    print('#avg of medicines ', avg_med/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of lab_test ', avg_lab/ cnt)
    print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of lab_test ', max_lab)
    print('#max of visit ', max_visit)

In [18]:
data = process_all()

process_med


  med_pd = pd.read_csv(med_file, dtype={'ndc':'category'})
  med_pd.fillna(method='pad', inplace=True)


process_diag
process_pro
process_ad
process_lab
Removed 65 patients and 647 rows exceeding 8 visits.


In [19]:
statistics(data)

#patients  (60709,)
#clinical events  84495
#diagnosis  1983
#med  140
#procedure 801
#lab 1281
#avg of diagnoses  10.262867625303272
#avg of medicines  2.976105094976034
#avg of procedures  2.867246582638026
#avg of lab_test  15.083472394816262
#avg of vists  1.3918035217183613
#max of diagnoses  101
#max of medicines  30
#max of procedures  41
#max of lab_test  139
#max of visit  8


In [20]:
with open("./MIMIC-IV-processed/mimic.pkl", "wb") as outfile:
    pickle.dump(data, outfile)

# Data Splits

In [21]:
mimic_data = pickle.load(open("./MIMIC-IV-processed/mimic.pkl", 'rb'))

In [22]:
mimic_data.columns

Index(['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'NDC', 'PRO_CODE', 'LAB_TEST',
       'ADMISSION_TYPE', 'STAY_DAYS', 'GENDER', 'AGE', 'DEATH', 'READMISSION',
       'READMISSION_1M', 'READMISSION_3M', 'NEXT_DIAG_6M', 'NEXT_DIAG_12M'],
      dtype='object')

In [23]:
len(mimic_data["SUBJECT_ID"].unique())

60709

In [24]:
# get the patient id with the labels
pat_readmission = set(mimic_data[mimic_data["READMISSION_1M"] == 1]["SUBJECT_ID"].values.tolist())
print(len(pat_readmission))
pat_nextdiag_6m = set(mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_6m))
# note that here we extract at the patient level, not the encounter level
# so patients with at least one encounter with a next diagnosis in 6 / 12 months will still have
# some encounters without a next diagnosis. This part is addressed in the HBERTFinetuneEHRDataset.
pat_nextdiag_12m = set(mimic_data[mimic_data["NEXT_DIAG_12M"].notna()]["SUBJECT_ID"].values.tolist())
print(len(pat_nextdiag_12m))
pat_death = set(mimic_data[mimic_data["DEATH"]]["SUBJECT_ID"].values.tolist())
print(len(pat_death))
pat_all_label = list(pat_readmission | pat_nextdiag_6m | pat_nextdiag_12m | pat_death)

4604
9667
11197
3577


In [25]:
pat_all = mimic_data["SUBJECT_ID"].unique().tolist()

In [26]:
n_pretrain_patient = int(len(pat_all) * 0.7)
np.random.seed(RANDOM_SEED)
pretrain_patient = np.random.choice(list(set(pat_all) - set(pat_all_label)), n_pretrain_patient, replace=False).tolist()
downstream_patient = list(set(pat_all) - set(pretrain_patient))
print(len(pretrain_patient), len(downstream_patient))

42496 18213


In [27]:
pretrain_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(pretrain_patient))]

In [28]:
# downstream task: PLOS, death prediction, readmission
train_ratio, val_ratio = 0.2, 0.4
n_finetune_pat, n_val_pat = int(len(downstream_patient) * train_ratio), int(len(downstream_patient) * val_ratio)
downstream_patient = sorted(downstream_patient) 
np.random.seed(RANDOM_SEED + 1)
np.random.shuffle(downstream_patient)
finetune_pat, val_pat, test_pat = downstream_patient[:n_finetune_pat], \
                                    downstream_patient[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    downstream_patient[n_finetune_pat+n_val_pat:]
finetune_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [29]:
# downstream task: next diagnosis prediction
# only use the patient that has multiple visits
# 6m
train_ratio, val_ratio = 0.4, 0.3
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_6m) * train_ratio), int(len(pat_nextdiag_6m) * val_ratio)
pat_nextdiag_6m = list(pat_nextdiag_6m)
pat_nextdiag_6m = sorted(list(pat_nextdiag_6m))  # 先排序
np.random.seed(RANDOM_SEED + 2)  # 设置种子
np.random.shuffle(pat_nextdiag_6m)
finetune_pat, val_pat, test_pat = pat_nextdiag_6m[:n_finetune_pat], \
                                    pat_nextdiag_6m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_6m[n_finetune_pat+n_val_pat:]
finetune_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset6m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]
# 12m
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_12m) * train_ratio), int(len(pat_nextdiag_12m) * val_ratio)
pat_nextdiag_12m = list(pat_nextdiag_12m)
pat_nextdiag_12m = sorted(list(pat_nextdiag_12m))  # 先排序
np.random.seed(RANDOM_SEED + 3)  # 设置种子
np.random.shuffle(pat_nextdiag_12m)
finetune_pat, val_pat, test_pat = pat_nextdiag_12m[:n_finetune_pat], \
                                    pat_nextdiag_12m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_12m[n_finetune_pat+n_val_pat:]
finetune_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(finetune_pat))]
val_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(val_pat))]
test_dataset12m = mimic_data[mimic_data["SUBJECT_ID"].isin(set(test_pat))]

In [30]:
with open("./MIMIC-IV-processed/mimic_pretrain.pkl", "wb") as outfile:
    pickle.dump(pretrain_dataset, outfile)
with open("./MIMIC-IV-processed/mimic_downstream.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset, val_dataset, test_dataset], outfile)
with open("./MIMIC-IV-processed/mimic_nextdiag_6m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset6m, val_dataset6m, test_dataset6m], outfile)
with open("./MIMIC-IV-processed/mimic_nextdiag_12m.pkl", "wb") as outfile:
    pickle.dump([finetune_dataset12m, val_dataset12m, test_dataset12m], outfile)