In [None]:
import os
import re
import logging

import numpy as np
import pandas as pd
# IterativeImputer is still experimental in scikit-learn: this enabling import must run
# first, otherwise `from sklearn.impute import IterativeImputer` raises ImportError.
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer

#### Step 0: Load Raw Data

In [2]:
def get_data_file_path(base_path, site_name):
    """Return the input, output and auxiliary directories for one site.

    Layout (every returned path ends with '/'):
      input  -> <base_path>/<site_name>/
      output -> <base_path>/<site_name>/processed_data/   (created if missing)
      aux    -> <base_path>/aux_files/

    Returns a list ``[input_path, output_path, aux_path]``.
    """
    site_dir = os.path.join(base_path, site_name)
    input_path, output_path, aux_path = (
        directory + '/'
        for directory in (site_dir,
                          os.path.join(site_dir, 'processed_data'),
                          os.path.join(base_path, 'aux_files'))
    )

    # Side effect: make sure the output directory exists before anything writes to it.
    os.makedirs(output_path, exist_ok=True)
    return [input_path, output_path, aux_path]

In [3]:
def load_onset_data(file_paths):
    """Load serum creatinine (SCr) labs and encounter dates, then average SCr per day.

    Reads ``AKI_LAB_SCR.pkl`` and ``AKI_ONSETS.pkl`` from ``file_paths[0]``.
    Lab rows are joined to their encounter's admit/discharge dates; rows with
    any missing value (including labs whose encounter is absent from the onset
    file) are dropped.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Daily mean SCr with PATID, ENCOUNTERID, SPECIMEN_DATE,
        DAYS_SINCE_ADMIT, ADMIT_DATE, RESULT_NUM (sorted by patient,
        encounter, date), and the de-duplicated encounter table
        (ENCOUNTERID, PATID, ADMIT_DATE, DISCHARGE_DATE).
    """
    input_dir = file_paths[0]
    lab_scr = pd.read_pickle(input_dir + 'AKI_LAB_SCR' + '.pkl')
    encounters = pd.read_pickle(input_dir + 'AKI_ONSETS' + '.pkl')
    encounters = encounters[['ENCOUNTERID', 'PATID', 'ADMIT_DATE', 'DISCHARGE_DATE']]

    scr = (lab_scr[['ENCOUNTERID', 'PATID', 'SPECIMEN_DATE', 'RESULT_NUM']]
           .merge(encounters, on=['ENCOUNTERID', 'PATID'], how='left')
           .dropna())
    scr['DAYS_SINCE_ADMIT'] = (scr['SPECIMEN_DATE'] - scr['ADMIT_DATE']).dt.days

    # Collapse multiple measurements on the same day into their mean.
    day_keys = ['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE', 'DAYS_SINCE_ADMIT', 'ADMIT_DATE']
    daily_scr = (scr[day_keys + ['RESULT_NUM']]
                 .groupby(day_keys).mean()
                 .sort_values(['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE'])
                 .reset_index())
    return daily_scr, encounters.drop_duplicates()

#### Step 1: Get AKI Onset Cohort

In [None]:
# Auxiliary table and functions for baseline SCr calculation.
# KDIGO_baseline: estimated baseline serum creatinine (presumably mg/dL) per age band
# (rows) and sex/race group (columns), as tabulated in the KDIGO AKI guideline.
# Intended as the lookup table passed to inverse_MDRD.
KDIGO_baseline = pd.DataFrame(
    [
        # Black M, Other M, Black F, Other F
        [1.5, 1.3, 1.2, 1.0],  # 20-24
        [1.5, 1.2, 1.1, 1.0],  # 25-29
        [1.4, 1.2, 1.1, 0.9],  # 30-39
        [1.3, 1.1, 1.0, 0.9],  # 40-54
        [1.3, 1.1, 1.0, 0.8],  # 55-65
        [1.2, 1.0, 0.9, 0.8],  # >65
    ],
    index=["20-24", "25-29", "30-39", "40-54", "55-65", ">65"],
    columns=["Black males", "Other males", "Black females", "Other females"],
    dtype=float,
)

def inverse_MDRD(row, KDIGO_baseline):
    """Look up the KDIGO estimated baseline SCr for one patient.

    Parameters
    ----------
    row : pandas.Series or mapping
        One record (e.g. a row passed by ``DataFrame.apply(..., axis=1)``)
        providing ``AGE``, ``MALE`` and ``RACE_BLACK``. ``MALE`` and
        ``RACE_BLACK`` are interpreted by truthiness.
    KDIGO_baseline : pandas.DataFrame
        Lookup table indexed by age band ("20-24", "25-29", "30-39", "40-54",
        "55-65", ">65") with columns "Black males", "Other males",
        "Black females" and "Other females".

    Returns
    -------
    float
        The table entry for the row's age band and sex/race group, or NaN
        when ``AGE`` is missing.

    Notes
    -----
    Age bands use exclusive upper bounds (< 25, < 30, < 40, < 55, <= 65,
    > 65), so fractional ages such as 24.5 or 54.5 still map to a band.
    Inclusive integer-only ranges would leave those ages unmatched.
    Ages below 20 use the "20-24" band.
    """
    age = row["AGE"]
    if pd.isna(age):
        return np.nan

    # bool() keeps the original truthiness semantics of the MALE / RACE_BLACK flags.
    sex_race_column = {
        (True, True): "Black males",
        (True, False): "Other males",
        (False, True): "Black females",
        (False, False): "Other females",
    }[(bool(row["MALE"]), bool(row["RACE_BLACK"]))]

    if age < 25:
        age_band = "20-24"
    elif age < 30:
        age_band = "25-29"
    elif age < 40:
        age_band = "30-39"
    elif age < 55:
        age_band = "40-54"
    elif age <= 65:
        age_band = "55-65"
    else:
        age_band = ">65"

    return KDIGO_baseline.loc[age_band, sex_race_column]
        
def inverse_MDRD_raw(row):
    """Back-calculate serum creatinine by inverting the MDRD equation.

    Solves eGFR = 175 * Scr^-1.154 * Age^-0.203 * [0.742 if female] * [1.212 if Black]
    for Scr, using ``row['DFLT_eGFR']`` as the eGFR.

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide ``DFLT_eGFR``, ``MALE``, ``RACE_BLACK`` and ``AGE``.

    Returns
    -------
    float
        The SCr value that yields ``DFLT_eGFR`` under the MDRD equation.
    """
    sex_multiplier = 1.0 if row['MALE'] else 0.742
    race_multiplier = 1.212 if row['RACE_BLACK'] else 1.0

    denominator = 175 * (row['AGE'] ** -0.203) * sex_multiplier * race_multiplier
    return (row['DFLT_eGFR'] / denominator) ** (1 / -1.154)

In [4]:
def get_scr_baseline_new(df_scr, df_admit, file_paths, aggfunc_7d='last', aggfunc_1y='mean', keep_ckd=False):
    """Estimate a baseline serum creatinine (SCr) for every encounter in ``df_scr``.

    The baseline is taken from the first rule that applies:

    1. ``BASELINE_EST_1``: min of the admission SCr and the SCr aggregated with
       ``aggfunc_7d`` over [ADMIT_DATE - 7 days, ADMIT_DATE).
    2. ``BASELINE_EST_2``: min of the admission SCr and the SCr aggregated with
       ``aggfunc_1y`` over [ADMIT_DATE - 365.25 days, ADMIT_DATE - 7 days).
    3. ``BASELINE_EST_3``: min of the admission SCr and an SCr back-calculated
       from an assumed eGFR. Encounters without a prior CKD stage 3-5 code use
       the KDIGO table (``inverse_MDRD``). CKD 3-5 encounters use the raw MDRD
       inversion (``inverse_MDRD_raw``) at a stage-specific default eGFR.

    "Admission SCr" is the mean SCr measured from ADMIT_DATE through
    ADMIT_DATE + 1 day (inclusive).

    Parameters
    ----------
    df_scr : pandas.DataFrame
        SCr measurements with at least ENCOUNTERID, PATID, ADMIT_DATE,
        SPECIMEN_DATE and RESULT_NUM (such as the first value returned by
        ``load_onset_data``).
    df_admit : pandas.DataFrame
        Encounters with PATID, ENCOUNTERID and ADMIT_DATE.
    file_paths : list of str
        ``file_paths[0]`` is the directory holding ``AKI_DX.pkl`` and
        ``AKI_DEMO.pkl``.
    aggfunc_7d, aggfunc_1y : str or callable
        Aggregations passed to ``groupby(...).agg`` for the 7-day and 1-year
        pre-admission windows.
    keep_ckd : bool
        If False, drop CKD 3-5 encounters whose baseline could only come from
        rule 3.

    Returns
    -------
    (pandas.DataFrame, dict)
        The first element is the SCr rows augmented with ADMISSION_SCR,
        ONE_WEEK_SCR, ONE_YEAR_SCR, BASELINE_EST_1/2/3, BASELINE_NO_INVERT and
        the final SERUM_CREAT_BASE. Rows without a baseline are removed and
        duplicates are dropped. The second element is a ``cohort_table`` dict
        of counts and encounter-id arrays describing which rule covered which
        encounters.

    Notes
    -----
    Rule 3 uses the module-level ``KDIGO_baseline`` table.
    """
    cohort_table = dict()
    pat_id_cols = ['PATID', 'ENCOUNTERID']

    # ------------------------------------------------------------------
    # Diagnoses: used by rule 3 to detect CKD stage 3-5 before admission
    # ------------------------------------------------------------------
    dx = pd.read_pickle(file_paths[0] + 'AKI_DX.pkl')

    # The DX extract may lack some of these columns; keep whichever are present.
    existing_columns = [col for col in ['PATID', 'ENCOUNTERID', 'DX', 'DX_DATE', 'DX_TYPE', 'DAYS_SINCE_ADMIT']
                        if col in dx.columns]
    dx = dx[existing_columns]
    dx = df_admit[['PATID', 'ENCOUNTERID', 'ADMIT_DATE']].merge(dx, on=['PATID', 'ENCOUNTERID'], how='inner')

    if 'DAYS_SINCE_ADMIT' not in dx.columns:
        dx['DAYS_SINCE_ADMIT'] = (dx['DX_DATE'] - dx['ADMIT_DATE']).dt.days

    # Reconstruct a missing DX_DATE from DAYS_SINCE_ADMIT. This uses the admit date of the
    # index encounter, not of the encounter that recorded the diagnosis.
    missing_dx_date = dx.DX_DATE.isna()
    dx.loc[missing_dx_date, 'DX_DATE'] = \
        dx.loc[missing_dx_date, 'ADMIT_DATE'] + \
        pd.to_timedelta(dx.loc[missing_dx_date, 'DAYS_SINCE_ADMIT'], unit='D')

    dx['DX'] = dx['DX'].astype(str)
    dx['DX_TYPE'] = dx['DX_TYPE'].astype(str)
    dx['DX_TYPE'] = dx['DX_TYPE'].replace('09', '9')

    # ------------------------------------------------------------------
    # Demographics: age / sex / race inputs to the MDRD inversion
    # ------------------------------------------------------------------
    demo = pd.read_pickle(file_paths[0] + 'AKI_DEMO' + '.pkl')
    demo['MALE'] = demo['SEX'] == 'M'
    # RACE codes '05' and '03' are mapped to white and Black respectively.
    demo['RACE_WHITE'] = demo['RACE'] == '05'
    demo['RACE_BLACK'] = demo['RACE'] == '03'
    demo = demo[['PATID', 'ENCOUNTERID', 'AGE', 'MALE', 'RACE_WHITE', 'RACE_BLACK']]
    demo = demo.drop_duplicates()

    complete_df = df_scr[['ENCOUNTERID', 'PATID', 'ADMIT_DATE', 'SPECIMEN_DATE', 'RESULT_NUM']]

    # ------------------------------------------------------------------
    # Admission SCr: mean over [ADMIT_DATE, ADMIT_DATE + 1 day]
    # ------------------------------------------------------------------
    admission_SCr = complete_df[(complete_df.SPECIMEN_DATE >= complete_df.ADMIT_DATE) &
                                (complete_df.SPECIMEN_DATE <= (complete_df.ADMIT_DATE + pd.Timedelta(days=1)))].copy()
    admission_SCr = admission_SCr.groupby(pat_id_cols)['RESULT_NUM'].mean().reset_index()
    admission_SCr = admission_SCr.rename(columns={'RESULT_NUM': 'ADMISSION_SCR'})
    complete_df = complete_df.merge(admission_SCr, on=pat_id_cols, how='left')

    # ------------------------------------------------------------------
    # Rule 1: SCr in [ADMIT_DATE - 7 days, ADMIT_DATE)
    # ------------------------------------------------------------------
    one_week_prior_admission = complete_df[(complete_df.SPECIMEN_DATE >= complete_df.ADMIT_DATE - pd.Timedelta(days=7)) &
                                           (complete_df.SPECIMEN_DATE < complete_df.ADMIT_DATE)].copy()
    one_week_prior_admission = one_week_prior_admission.sort_values(by=['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE'])

    scr_1w_time = one_week_prior_admission.copy()
    scr_1w_time['days_before_admit'] = (scr_1w_time['ADMIT_DATE'] - scr_1w_time['SPECIMEN_DATE']).dt.days
    cohort_table['scr_1w_df'] = scr_1w_time

    # Rows are date-sorted, so aggfunc_7d='last' takes the measurement closest to admission.
    one_week_prior_admission = one_week_prior_admission.groupby(pat_id_cols)['RESULT_NUM'].agg(aggfunc_7d).reset_index()
    one_week_prior_admission = one_week_prior_admission.rename(columns={'RESULT_NUM': 'ONE_WEEK_SCR'})
    complete_df = complete_df.merge(one_week_prior_admission, on=pat_id_cols, how='left')

    has_1w = complete_df.ONE_WEEK_SCR.notna()
    complete_df.loc[has_1w, 'BASELINE_EST_1'] = \
        np.nanmin(complete_df.loc[has_1w, ['ONE_WEEK_SCR', 'ADMISSION_SCR']], axis=1)

    # Encounter-level view (presumably one row per encounter) for the cohort counts.
    complete_dfe = complete_df.drop(['SPECIMEN_DATE', 'RESULT_NUM'], axis=1).drop_duplicates()
    cohort_table['ALL_ENCOUNTERS'] = len(complete_dfe[['PATID', 'ENCOUNTERID']].drop_duplicates())
    cohort_table['ALL_PATIENTS'] = complete_dfe.PATID.nunique()
    cohort_table['ADMISSION_SCR_YES'] = complete_dfe.ADMISSION_SCR.notna().sum()
    cohort_table['ADMISSION_SCR_NO'] = complete_dfe.ADMISSION_SCR.isna().sum()
    cohort_table['ONE_WEEK_SCR_YES'] = complete_dfe.ONE_WEEK_SCR.notna().sum()
    cohort_table['ONE_WEEK_SCR_NO'] = complete_dfe.ONE_WEEK_SCR.isna().sum()
    cohort_table['ADMISSION_AND_1W_SCR'] = (complete_dfe.ADMISSION_SCR.notna() & complete_dfe.ONE_WEEK_SCR.notna()).sum()
    cohort_table['ADMISSION_AND_1W_SCR_MIN'] = (complete_dfe.BASELINE_EST_1.notna()).sum()
    cohort_table['ADMISSION_OR_1W_SCR'] = (complete_dfe.ADMISSION_SCR.notna() | complete_dfe.ONE_WEEK_SCR.notna()).sum()
    cohort_table['ONE_WEEK_SCR_ENC'] = complete_dfe[(complete_dfe.ONE_WEEK_SCR.notna() & (complete_dfe['ONE_WEEK_SCR'] == complete_dfe['BASELINE_EST_1']))]['ENCOUNTERID'].unique()
    cohort_table['ADMISSION_SCR_1W_ENC'] = complete_dfe[(complete_dfe.ONE_WEEK_SCR.notna() & (complete_dfe['ONE_WEEK_SCR'] != complete_dfe['BASELINE_EST_1']))]['ENCOUNTERID'].unique()

    # ------------------------------------------------------------------
    # Rule 2: SCr in [ADMIT_DATE - 365.25 days, ADMIT_DATE - 7 days)
    # ------------------------------------------------------------------
    one_year_prior_admission = complete_df[(complete_df.SPECIMEN_DATE < (complete_df.ADMIT_DATE - pd.Timedelta(days=7))) &
                                           (complete_df.SPECIMEN_DATE >= (complete_df.ADMIT_DATE - pd.Timedelta(days=365.25)))].copy()
    one_year_prior_admission = one_year_prior_admission.sort_values(by=['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE'])

    # Keep the 1-year measurements only for encounters that lacked a 7-day SCr.
    # .copy() so that adding days_before_admit does not write into a slice of
    # one_year_prior_admission (SettingWithCopyWarning).
    no_1w_enc = complete_dfe[complete_dfe.ONE_WEEK_SCR.isna()].ENCOUNTERID.unique()
    scr_1y_time = one_year_prior_admission[one_year_prior_admission.ENCOUNTERID.isin(no_1w_enc)].copy()
    scr_1y_time['days_before_admit'] = (scr_1y_time['ADMIT_DATE'] - scr_1y_time['SPECIMEN_DATE']).dt.days
    cohort_table['scr_1y_df'] = scr_1y_time

    one_year_prior_admission = one_year_prior_admission.loc[:, pat_id_cols + ['RESULT_NUM']]
    one_year_prior_admission = one_year_prior_admission.groupby(pat_id_cols)['RESULT_NUM'].agg(aggfunc_1y).reset_index()
    one_year_prior_admission = one_year_prior_admission.rename(columns={'RESULT_NUM': 'ONE_YEAR_SCR'})
    complete_df = complete_df.merge(one_year_prior_admission, on=pat_id_cols, how='left')

    has_1y = complete_df.ONE_YEAR_SCR.notna()
    complete_df.loc[has_1y, 'BASELINE_EST_2'] = \
        np.nanmin(complete_df.loc[has_1y, ['ONE_YEAR_SCR', 'ADMISSION_SCR']], axis=1)

    # Rule 1 takes precedence; fall back to rule 2.
    complete_df['BASELINE_NO_INVERT'] = \
        np.where(complete_df['BASELINE_EST_1'].isna(), complete_df['BASELINE_EST_2'], complete_df['BASELINE_EST_1'])

    complete_dfe = complete_df.drop(['SPECIMEN_DATE', 'RESULT_NUM'], axis=1).drop_duplicates()
    cohort_table['ONE_YEAR_SCR_YES'] = (complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.notna()).sum()
    cohort_table['ONE_YEAR_SCR_NO'] = (complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.isna()).sum()

    cohort_table['ADMISSION_AND_1Y_SCR'] = (complete_dfe.ADMISSION_SCR.notna() & (complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.notna())).sum()
    cohort_table['ADMISSION_AND_1Y_SCR_MIN'] = (complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.BASELINE_EST_2.notna()).sum()

    cohort_table['ADMISSION_OR_1Y_SCR'] = (complete_dfe.ADMISSION_SCR.notna() | (complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.notna())).sum()

    cohort_table['ONE_YEAR_SCR_ENC'] = complete_dfe[(complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.notna() & (complete_dfe['ONE_YEAR_SCR'] == complete_dfe['BASELINE_EST_2']))]['ENCOUNTERID'].unique()
    cohort_table['ADMISSION_SCR_1Y_ENC'] = complete_dfe[(complete_dfe.ONE_WEEK_SCR.isna() & complete_dfe.ONE_YEAR_SCR.notna() & (complete_dfe['ONE_YEAR_SCR'] != complete_dfe['BASELINE_EST_2']))]['ENCOUNTERID'].unique()

    # ------------------------------------------------------------------
    # Rule 3: back-calculate a baseline from an assumed eGFR
    # ------------------------------------------------------------------
    pat_to_invert = complete_df.loc[complete_df.BASELINE_NO_INVERT.isna(), pat_id_cols + ['ADMIT_DATE', 'ADMISSION_SCR']]

    cohort_table['MDRD_TO_INVERT'] = pat_to_invert['ENCOUNTERID'].nunique()
    cohort_table['MDRD_TO_INVERT_ENC'] = pat_to_invert['ENCOUNTERID'].unique()
    # Reassign rather than inplace=True: pat_to_invert is a selection of complete_df.
    pat_to_invert = pat_to_invert.drop_duplicates(subset=pat_id_cols, keep='first')

    # Diagnoses from any of the patient's encounters in df_admit that are dated on or
    # before this admission. The outer merge re-adds encounters with no such diagnosis.
    pat_dx = pat_to_invert.merge(dx.drop(['ENCOUNTERID', 'ADMIT_DATE'], axis=1),
                                 on='PATID',
                                 how='left')
    pat_dx = pat_dx[pat_dx.DX_DATE <= pat_dx.ADMIT_DATE]
    pat_dx = pat_dx.merge(pat_to_invert[['PATID', 'ENCOUNTERID']],
                          on=['PATID', 'ENCOUNTERID'],
                          how='outer')

    # The default eGFR is 75 (no CKD). CKD 3/4/5 codes (585.x / N18.x) map to 45, 22.5
    # and 7.5, which appear to be the midpoints of those stages' eGFR ranges.
    # Start from a float so the fractional assignments below do not upcast an int
    # column (a deprecated setitem in recent pandas).
    pat_dx['DFLT_eGFR'] = 75.0
    pat_dx.loc[pat_dx['DX'].isin(['585.3', 'N18.3']), 'DFLT_eGFR'] = 90/2
    pat_dx.loc[pat_dx['DX'].isin(['585.4', 'N18.4']), 'DFLT_eGFR'] = 45/2
    pat_dx.loc[pat_dx['DX'].isin(['585.5', 'N18.5']), 'DFLT_eGFR'] = 15/2
    # The most severe stage (lowest default eGFR) wins per encounter.
    pat_def_egfr = pat_dx.groupby(pat_id_cols)['DFLT_eGFR'].min().reset_index()

    admit_encounters = df_admit['ENCOUNTERID'].unique()
    cohort_table['ALL_CKD3_ENC'] = dx[(dx['DX'].isin(['585.3', 'N18.3'])) & (dx['ENCOUNTERID'].isin(admit_encounters))]['ENCOUNTERID'].unique()
    cohort_table['ALL_CKD4_ENC'] = dx[(dx['DX'].isin(['585.4', 'N18.4'])) & (dx['ENCOUNTERID'].isin(admit_encounters))]['ENCOUNTERID'].unique()
    cohort_table['ALL_CKD5_ENC'] = dx[(dx['DX'].isin(['585.5', 'N18.5'])) & (dx['ENCOUNTERID'].isin(admit_encounters))]['ENCOUNTERID'].unique()

    cohort_table['MDRD_NOCKD'] = (pat_def_egfr['DFLT_eGFR'] == 75).sum()

    cohort_table['ADMISSION_OR_MDRD_NOCKD'] = (complete_dfe.ADMISSION_SCR.notna() | complete_dfe['ENCOUNTERID'].isin(pat_def_egfr[pat_def_egfr['DFLT_eGFR'] == 75]['ENCOUNTERID'].unique())).sum()

    cohort_table['MDRD_CKD3'] = (pat_def_egfr['DFLT_eGFR'] == 90/2).sum()
    cohort_table['MDRD_CKD4'] = (pat_def_egfr['DFLT_eGFR'] == 45/2).sum()
    cohort_table['MDRD_CKD5'] = (pat_def_egfr['DFLT_eGFR'] == 15/2).sum()

    pat_to_invert = pat_to_invert.merge(pat_def_egfr, on=pat_id_cols, how='left')
    pat_to_invert['DFLT_eGFR'] = pat_to_invert['DFLT_eGFR'].fillna(75)

    pat_to_invert['CKD345'] = pat_to_invert['DFLT_eGFR'] != 75
    cohort_table['CKD345_ENC'] = pat_to_invert[pat_to_invert['CKD345']].ENCOUNTERID.unique()
    pat_to_invert = pat_to_invert.merge(demo, on=pat_id_cols, how='left')

    # No CKD: KDIGO table lookup (module-level KDIGO_baseline).
    # CKD 3-5: invert MDRD at the stage-specific default eGFR.
    is_ckd = pat_to_invert['CKD345']
    pat_to_invert.loc[~is_ckd, 'BASELINE_INVERT'] = pat_to_invert.loc[~is_ckd, :].apply(inverse_MDRD, args=(KDIGO_baseline,), axis=1)
    pat_to_invert.loc[is_ckd, 'BASELINE_INVERT'] = pat_to_invert.loc[is_ckd, :].apply(inverse_MDRD_raw, axis=1)

    # Row-wise NaN-skipping minimum.
    pat_to_invert['BASELINE_EST_3'] = pat_to_invert[['ADMISSION_SCR', 'BASELINE_INVERT']].min(axis=1)

    cohort_table['MDRD_ENC'] = pat_to_invert[(~pat_to_invert['CKD345']) & (pat_to_invert['BASELINE_EST_3'] == pat_to_invert['BASELINE_INVERT'])]['ENCOUNTERID'].unique()
    cohort_table['ADMISSION_SCR_MDRD_ENC'] = pat_to_invert[(~pat_to_invert['CKD345']) & (pat_to_invert['BASELINE_EST_3'] != pat_to_invert['BASELINE_INVERT'])]['ENCOUNTERID'].unique()

    complete_df = complete_df.merge(pat_to_invert[pat_id_cols + ['BASELINE_EST_3', 'CKD345']],
                                    on=pat_id_cols,
                                    how='left')

    # BASELINE_EST_3 exists only where rules 1/2 gave nothing, so the NaN-skipping
    # row-wise min picks whichever estimate is present.
    complete_df['SERUM_CREAT_BASE'] = complete_df[['BASELINE_NO_INVERT', 'BASELINE_EST_3']].min(axis=1)

    if not keep_ckd:
        # Drop CKD 3-5 encounters whose baseline had to come from the MDRD inversion.
        complete_df = complete_df[~(complete_df['CKD345'] & complete_df['BASELINE_NO_INVERT'].isna())]

    complete_df = complete_df.drop('CKD345', axis=1)

    # Drop encounters for which no baseline could be determined.
    complete_df = complete_df.dropna(subset=['SERUM_CREAT_BASE'])

    return complete_df.drop_duplicates(), cohort_table

In [8]:
def eGFR_MDRD(df, scr_label):
    """Estimate GFR with the MDRD study equation (175 coefficient).

    eGFR = 175 * Scr^-1.154 * Age^-0.203 * [0.742 if female] * [1.212 if Black]

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``MALE``, ``RACE_BLACK``, ``AGE`` and ``scr_label``.
        SCr is assumed to be in mg/dL.
    scr_label : str
        Column holding serum creatinine.

    Returns
    -------
    pandas.Series
        eGFR for each row of ``df``.
    """
    sex_factor = np.where(df['MALE'], 1, 0.742)
    race_factor = np.where(df['RACE_BLACK'], 1.212, 1)
    return 175 * df[scr_label] ** -1.154 * df['AGE'] ** -0.203 * sex_factor * race_factor


def eGFR_CKDEPI09(df, scr_label):
    """Estimate GFR with the 2009 CKD-EPI creatinine equation.

    eGFR = 141 * min(Scr/k, 1)^a * max(Scr/k, 1)^-1.209 * 0.993^Age
           * [1.018 if female] * [1.159 if Black]
    with k = 0.9 (male) / 0.7 (female) and a = -0.411 (male) / -0.329 (female).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``MALE``, ``RACE_BLACK``, ``AGE`` and ``scr_label``.
    scr_label : str
        Column holding serum creatinine.

    Returns
    -------
    pandas.Series
        eGFR per row. A missing SCr yields NaN. ``np.minimum``/``np.maximum``
        propagate NaN, whereas selecting the min/max term with ``np.where`` on
        ``Scr/k <= 1`` / ``> 1`` would fall back to 1 for both and return a
        finite eGFR computed from age/sex/race alone.
    """
    kappa = np.where(df['MALE'], 0.9, 0.7)
    alpha = np.where(df['MALE'], -0.411, -0.329)

    gender_coeff = np.where(df['MALE'], 1, 1.018)
    race_coeff = np.where(df['RACE_BLACK'], 1.159, 1)

    scr_over_kappa = df[scr_label] / kappa
    min_term = np.minimum(scr_over_kappa, 1) ** alpha
    max_term = np.maximum(scr_over_kappa, 1) ** -1.209
    age_term = 0.993 ** df['AGE']

    return 141 * min_term * max_term * age_term * gender_coeff * race_coeff
    

def eGFR_CKDEPI21(df, scr_label):
    """Estimate GFR with the 2021 (race-free) CKD-EPI creatinine equation.

    eGFR = 142 * min(Scr/k, 1)^a * max(Scr/k, 1)^-1.200 * 0.9938^Age
           * [1.012 if female]
    with k = 0.9 (male) / 0.7 (female) and a = -0.302 (male) / -0.241 (female).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``MALE``, ``AGE`` and ``scr_label``.
    scr_label : str
        Column holding serum creatinine.

    Returns
    -------
    pandas.Series
        eGFR per row. A missing SCr yields NaN. ``np.minimum``/``np.maximum``
        propagate NaN, whereas selecting the terms with ``np.where`` would fall
        back to 1 for both and return a finite eGFR from age/sex alone.
    """
    kappa = np.where(df['MALE'], 0.9, 0.7)
    alpha = np.where(df['MALE'], -0.302, -0.241)

    gender_coeff = np.where(df['MALE'], 1, 1.012)

    scr_over_kappa = df[scr_label] / kappa
    min_term = np.minimum(scr_over_kappa, 1) ** alpha
    max_term = np.maximum(scr_over_kappa, 1) ** -1.200
    age_term = 0.9938 ** df['AGE']

    return 142 * min_term * max_term * age_term * gender_coeff
    
def ckd_staging(df, egfr_label):
    """Map eGFR values to a 0-5 stage code.

    Code by eGFR band: 0 for >= 90, 1 for [60, 90), 2 for [45, 60),
    3 for [30, 45), 4 for [15, 30), 5 for < 15. Missing eGFR maps to NaN.

    Returns
    -------
    numpy.ndarray
        Float array of stage codes aligned with the rows of ``df``.
    """
    egfr = df[egfr_label]
    lower_bounds = [90, 60, 45, 30, 15]

    # Top band is open-ended; middle bands are [lower, upper); bottom band is < 15.
    band_masks = [egfr >= lower_bounds[0]]
    for upper, lower in zip(lower_bounds, lower_bounds[1:]):
        band_masks.append((egfr >= lower) & (egfr < upper))
    band_masks.append(egfr < lower_bounds[-1])

    return np.select(band_masks, list(range(len(band_masks))), default=np.nan)

In [9]:
def get_rrt(df_admit, file_paths):
    """Find the earliest renal replacement therapy (RRT) procedure per encounter.

    RRT is a kidney transplant or dialysis code in ``AKI_PX.pkl`` (read from
    ``file_paths[0]``), matched by PX_TYPE: 'CH' (CPT/HCPCS), '9' or '09'
    (ICD-9) and '10' (ICD-10), following the PCORnet PX_TYPE convention.

    Parameters
    ----------
    df_admit : pandas.DataFrame
        Encounters with PATID, ENCOUNTERID and ADMIT_DATE. Any other columns
        are carried through.
    file_paths : list of str
        ``file_paths[0]`` is the directory containing ``AKI_PX.pkl``.

    Returns
    -------
    pandas.DataFrame
        One row per ENCOUNTERID with PATID, ENCOUNTERID, RRT_ONSET_DATE,
        RRT_SINCE_ADMIT (fractional days since ADMIT_DATE) and the other
        ``df_admit`` columns except ADMIT_DATE. Procedures without a usable
        date are ignored. Previously, an encounter whose RRT rows were all
        undated made ``idxmin`` return NaN, and ``.loc`` then raised KeyError.
    """
    px = pd.read_pickle(file_paths[0] + 'AKI_PX.pkl')

    icd9_transplant = ['55.51', '55.52', '55.53', '55.54', '55.61', '55.69']
    idx_transplant = (
        ((px['PX_TYPE'] == 'CH') & px['PX'].isin(['50300', '50320', '50323', '50325', '50327', '50328',
                                                   '50329', '50340', '50360', '50365', '50370', '50380'])) |
        (px['PX_TYPE'].isin(['09', '9']) & px['PX'].isin(icd9_transplant)) |
        ((px['PX_TYPE'] == '10') & px['PX'].isin(['0TY00Z0', '0TY00Z1', '0TY00Z2', '0TY10Z0', '0TY10Z1', '0TY10Z2',
                                                   '0TB00ZZ', '0TB10ZZ', '0TT00ZZ', '0TT10ZZ', '0TT20ZZ']))
    )

    # NOTE(review): 'V45.11' and 'Z99.2' look like dialysis-status diagnosis codes rather
    # than procedure codes; they are kept so existing results do not change.
    icd9_dialysis = ['39.93', '39.95', '54.98', 'V45.11']
    idx_dialysis = (
        ((px['PX_TYPE'] == 'CH') & px['PX'].isin(['90935', '90937'])) |
        ((px['PX_TYPE'] == 'CH') & pd.to_numeric(px['PX'], errors='coerce').between(90940, 90999)) |
        (px['PX_TYPE'].isin(['09', '9']) & px['PX'].isin(icd9_dialysis)) |
        ((px['PX_TYPE'] == '10') & px['PX'].isin(['5A1D00Z', '5A1D60Z', '5A1D70Z', '5A1D80Z', '5A1D90Z', 'Z99.2']))
    )

    rrt_stage = px.loc[idx_transplant | idx_dialysis, ['PATID', 'ENCOUNTERID', 'PX_DATE']]
    rrt_stage = rrt_stage.rename(columns={'PX_DATE': 'RRT_ONSET_DATE'})

    rrt_stage = rrt_stage.merge(df_admit, on=['PATID', 'ENCOUNTERID'], how='inner')
    rrt_stage['RRT_SINCE_ADMIT'] = (rrt_stage['RRT_ONSET_DATE'] - rrt_stage['ADMIT_DATE']).dt.total_seconds() / (3600 * 24)

    # Undated procedures cannot be ordered. Dropping them also prevents an all-NaN
    # group from producing a NaN idxmin label.
    rrt_stage = rrt_stage.dropna(subset=['RRT_SINCE_ADMIT'])
    earliest_idx = rrt_stage.groupby('ENCOUNTERID')['RRT_SINCE_ADMIT'].idxmin()
    rrt_stage = rrt_stage.loc[earliest_idx.to_numpy()]
    return rrt_stage.drop(columns='ADMIT_DATE')

In [10]:
def determine_initial_aki_stage(row):
    """Return the AKI stage whose onset day is earliest for this row.

    Reads AKI1_SINCE_ADMIT, AKI2_SINCE_ADMIT and AKI3_SINCE_ADMIT, ignoring
    missing values. When several stages share the earliest day, the highest
    stage wins. Returns NaN if all three are missing.
    """
    stage_columns = ((1, 'AKI1_SINCE_ADMIT'), (2, 'AKI2_SINCE_ADMIT'), (3, 'AKI3_SINCE_ADMIT'))
    observed = [(row[column], stage) for stage, column in stage_columns if not pd.isnull(row[column])]

    if len(observed) == 0:
        return np.nan

    earliest_day = min(day for day, _ in observed)
    return max(stage for day, stage in observed if day == earliest_day)

def _first_measurement_after_onset(aki1, aki_rows, meets_threshold):
    """Return, per encounter, the first AKI-qualifying SCr row on/after AKI1_DATE that satisfies ``meets_threshold``."""
    candidates = aki1.merge(aki_rows, on=['PATID', 'ENCOUNTERID'], how='left')
    candidates = candidates[candidates['SPECIMEN_DATE'] >= candidates['AKI1_DATE']]
    candidates = candidates[meets_threshold(candidates)]
    return candidates.groupby(['PATID', 'ENCOUNTERID']).first().reset_index()


def get_aki_onset(df_scr, df_admit, df_rrt, df_baseline, aki_criteria='either'):
    """Detect the first AKI episode per encounter and the stages it reaches.

    Parameters
    ----------
    df_scr : pandas.DataFrame
        Daily SCr with PATID, ENCOUNTERID, SPECIMEN_DATE, DAYS_SINCE_ADMIT and
        RESULT_NUM. It is assumed to be date-sorted within each encounter,
        because the time-based rolling window below needs monotonic dates.
    df_admit : pandas.DataFrame
        Encounters with PATID, ENCOUNTERID, ADMIT_DATE and DISCHARGE_DATE.
    df_rrt : pandas.DataFrame
        RRT table with PATID, ENCOUNTERID, RRT_ONSET_DATE and RRT_SINCE_ADMIT
        (as produced by ``get_rrt``).
    df_baseline : pandas.DataFrame
        Baseline table with PATID, ENCOUNTERID and SERUM_CREAT_BASE. Only
        encounters present here are analysed.
    aki_criteria : str
        Stage-1 onset criterion:
        - '2d': SCr - (2-day rolling minimum) >= 0.3.
        - '7d': SCr >= 1.5x baseline.
        - 'both': both of the above.
        - any other value (default 'either'): either of the above.

    Returns
    -------
    pandas.DataFrame
        One row per encounter with a stage-1 onset, containing:
        - onset date and day, and days from admission to stages 1/2/3;
        - onset SCr, and baseline (SCR_BASELINE) and reference (SCR_REFERENCE) SCr;
        - which criteria fired at onset (AKI1_7D, AKI1_2D);
        - FLAG: reached stage 2 or 3;
        - AKI_STAGE: highest stage reached;
        - AKI_INIT_STG: the stage with the earliest onset day.
    """
    keys = ['PATID', 'ENCOUNTERID']
    scr = df_scr.copy()
    admits = df_admit.copy()

    # Keep only encounters that have a baseline (e.g. CKD encounters dropped upstream).
    valid_combinations = df_baseline[['ENCOUNTERID', 'PATID']].drop_duplicates()
    scr = scr.merge(valid_combinations, on=['ENCOUNTERID', 'PATID'], how='inner')
    admits = admits.merge(valid_combinations, on=['ENCOUNTERID', 'PATID'], how='inner')

    # Baseline SCr, used by the ratio criteria (1.5x / 2x / 3x).
    baseline_scr = df_baseline[['PATID', 'ENCOUNTERID', 'SERUM_CREAT_BASE']].drop_duplicates()
    baseline_scr = baseline_scr.rename(columns={'SERUM_CREAT_BASE': 'RESULT_NUM_BASE_7d'})
    scr = scr.merge(baseline_scr, on=keys, how='left')

    # Reference SCr for the 0.3 criterion: minimum over a 2-day time window
    # ending at (and including) each measurement.
    rolling_min = (scr[['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE', 'RESULT_NUM']]
                   .groupby(keys).rolling('2d', on='SPECIMEN_DATE').min()
                   .reset_index())
    rolling_min = rolling_min[['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE', 'RESULT_NUM']]
    rolling_min = rolling_min.rename(columns={'RESULT_NUM': 'RESULT_NUM_BASE_2d'})
    scr = scr.merge(rolling_min, on=['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE'], how='left')

    # Stage-1 criteria, evaluated only from the day of admission onwards.
    after_admit = scr['DAYS_SINCE_ADMIT'] >= 0
    scr['AKI1.5'] = (scr['RESULT_NUM'] >= 1.5 * scr['RESULT_NUM_BASE_7d']) & after_admit
    scr['AKI0.3'] = (scr['RESULT_NUM'] - scr['RESULT_NUM_BASE_2d'] >= 0.3) & after_admit

    if aki_criteria == '2d':
        is_aki = scr['AKI0.3']
    elif aki_criteria == '7d':
        is_aki = scr['AKI1.5']
    elif aki_criteria == 'both':
        is_aki = scr['AKI0.3'] & scr['AKI1.5']
    else:  # 'either' and any unrecognised value
        is_aki = scr['AKI0.3'] | scr['AKI1.5']

    # All qualifying measurements, earliest first (highest SCr first on the same date).
    aki_rows = scr[is_aki].sort_values(['SPECIMEN_DATE', 'RESULT_NUM'], ascending=[True, False])

    # Stage-1 onset = first qualifying measurement per encounter.
    aki1 = aki_rows.groupby(keys).first().reset_index()
    aki1 = aki1.rename(columns={
        'DAYS_SINCE_ADMIT': 'AKI1_SINCE_ADMIT',
        'SPECIMEN_DATE': 'AKI1_DATE',
        'RESULT_NUM': 'AKI1_SCR',
        'RESULT_NUM_BASE_7d': 'SCR_BASELINE',
        'RESULT_NUM_BASE_2d': 'SCR_REFERENCE',
        'AKI1.5': 'AKI1_7D',
        'AKI0.3': 'AKI1_2D',
    })
    aki1 = aki1[['PATID', 'ENCOUNTERID', 'SCR_BASELINE', 'SCR_REFERENCE', 'AKI1_DATE', 'AKI1_SCR',
                 'AKI1_SINCE_ADMIT', 'AKI1_7D', 'AKI1_2D']]

    # Stage 2: SCr >= 2x baseline, on/after stage-1 onset.
    aki2 = _first_measurement_after_onset(
        aki1, aki_rows, lambda c: c['RESULT_NUM'] >= 2 * c['SCR_BASELINE'])
    aki2 = aki2.rename(columns={'DAYS_SINCE_ADMIT': 'AKI2_SINCE_ADMIT',
                                'SPECIMEN_DATE': 'AKI2_DATE',
                                'RESULT_NUM': 'AKI2_SCR'})
    aki2 = aki2[['PATID', 'ENCOUNTERID', 'AKI2_DATE', 'AKI2_SCR', 'AKI2_SINCE_ADMIT']]

    # Stage 3 by SCr: >= 3x baseline or >= 4, on/after stage-1 onset.
    aki3 = _first_measurement_after_onset(
        aki1, aki_rows, lambda c: (c['RESULT_NUM'] >= 3 * c['SCR_BASELINE']) | (c['RESULT_NUM'] >= 4))
    aki3 = aki3.rename(columns={'DAYS_SINCE_ADMIT': 'AKI3_SINCE_ADMIT',
                                'SPECIMEN_DATE': 'AKI3_DATE',
                                'RESULT_NUM': 'AKI3_SCR'})
    aki3 = aki3[['PATID', 'ENCOUNTERID', 'AKI3_DATE', 'AKI3_SINCE_ADMIT', 'AKI3_SCR']]

    # Stage 3 by RRT: initiation on/after stage-1 onset, used when it precedes the SCr criterion.
    rrt = df_rrt.merge(aki1[['PATID', 'ENCOUNTERID', 'AKI1_DATE']], on=keys, how='left')
    rrt = rrt[rrt['RRT_ONSET_DATE'] >= rrt['AKI1_DATE']]
    aki3b = aki3.merge(rrt, on=keys, how='outer')
    # RRT_SINCE_ADMIT holds fractional days; make the target float so assignment does not
    # rely on pandas' deprecated silent int -> float upcasting.
    aki3b['AKI3_SINCE_ADMIT'] = aki3b['AKI3_SINCE_ADMIT'].astype(float)
    use_rrt = (aki3b['RRT_SINCE_ADMIT'] < aki3b['AKI3_SINCE_ADMIT']) | \
              (aki3b['AKI3_SINCE_ADMIT'].isna() & aki3b['RRT_SINCE_ADMIT'].notna())
    aki3b.loc[use_rrt, 'AKI3_SINCE_ADMIT'] = aki3b.loc[use_rrt, 'RRT_SINCE_ADMIT']
    aki3b.loc[use_rrt, 'AKI3_DATE'] = aki3b.loc[use_rrt, 'RRT_ONSET_DATE']
    aki3_all = aki3b[['PATID', 'ENCOUNTERID', 'AKI3_DATE', 'AKI3_SINCE_ADMIT', 'AKI3_SCR']]

    # Assemble one row per encounter.
    onset = aki1.merge(aki2, on=keys, how='outer').merge(aki3_all, on=keys, how='outer')
    onset = onset.merge(admits, on=keys, how='left')

    onset.columns = onset.columns.str.upper()
    onset['ONSET_DATE'] = onset['AKI1_DATE'].copy()
    onset['SCR_ONSET'] = onset['AKI1_SCR'].copy()
    onset['DISCHARGE_SINCE_ONSET'] = (onset['DISCHARGE_DATE'] - onset['ONSET_DATE']).dt.days

    onset = onset[['PATID', 'ENCOUNTERID', 'ADMIT_DATE', 'DISCHARGE_DATE',
                   'ONSET_DATE', 'AKI1_SINCE_ADMIT', 'AKI2_SINCE_ADMIT',
                   'AKI3_SINCE_ADMIT', 'DISCHARGE_SINCE_ONSET', 'SCR_ONSET',
                   'SCR_BASELINE', 'SCR_REFERENCE', 'AKI1_7D', 'AKI1_2D']].copy()

    onset['FLAG'] = onset['AKI2_SINCE_ADMIT'].notna() | onset['AKI3_SINCE_ADMIT'].notna()
    onset['ONSET_SINCE_ADMIT'] = onset['AKI1_SINCE_ADMIT'].copy()

    # Highest stage reached during the episode (0 if none).
    onset['AKI_STAGE'] = np.select(
        [onset['AKI3_SINCE_ADMIT'].notna(),
         onset['AKI2_SINCE_ADMIT'].notna(),
         onset['AKI1_SINCE_ADMIT'].notna()],
        [3, 2, 1],
        default=0,
    )

    # Stage with the earliest onset day (ties go to the higher stage).
    onset['AKI_INIT_STG'] = onset.apply(determine_initial_aki_stage, axis=1)

    return onset.drop_duplicates()

#### Step 2: Get AKI Recovery Status

In [11]:
def load_and_filter_scr(onset, file_paths):
    """Load daily-mean SCr and keep only encounters present in ``onset``.

    Reads ``AKI_LAB_SCR.pkl`` from ``file_paths[0]``, averages RESULT_NUM per
    (PATID, ENCOUNTERID, SPECIMEN_DATE), inner-joins onto ``onset``, and adds
    DAYS_SINCE_ONSET = whole days from ONSET_DATE to SPECIMEN_DATE.

    Returns
    -------
    pandas.DataFrame
        ``onset`` columns followed by SPECIMEN_DATE, RESULT_NUM and
        DAYS_SINCE_ONSET.
    """
    day_keys = ['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE']
    raw_scr = pd.read_pickle(file_paths[0] + 'AKI_LAB_SCR.pkl')

    daily_scr = (raw_scr[day_keys + ['RESULT_NUM']]
                 .groupby(day_keys).mean()
                 .sort_values(day_keys)
                 .reset_index())

    joined = onset.merge(daily_scr, on=['ENCOUNTERID', 'PATID'], how='inner')
    joined['DAYS_SINCE_ONSET'] = (joined['SPECIMEN_DATE'] - joined['ONSET_DATE']).dt.days
    return joined

def akirecovery(onset, scr):
    """Flag whether AKI had recovered at the last in-hospital SCr after onset.

    Parameters
    ----------
    onset : pandas.DataFrame
        Onset table with PATID, ENCOUNTERID, DISCHARGE_DATE and
        DISCHARGE_SINCE_ONSET.
    scr : pandas.DataFrame
        SCr rows joined to the onset table, with at least PATID, ENCOUNTERID,
        SPECIMEN_DATE, RESULT_NUM, ONSET_DATE, SCR_BASELINE, SCR_REFERENCE,
        AKI1_7D and AKI1_2D (e.g. the output of ``load_and_filter_scr``).

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, AKI_RCV for encounters with at least one SCr
        strictly after onset and on/before discharge. AKI_RCV is True when
        the last such SCr meets either condition:
        - below 1.5x baseline, for onset via the 7-day criterion;
        - below reference + 0.3, for onset via the 2-day criterion.
    """
    keys = ['PATID', 'ENCOUNTERID']

    # ``load_and_filter_scr`` output already carries the onset columns. Re-merging
    # DISCHARGE_DATE from ``onset`` would create DISCHARGE_DATE_x/_y and the lookup
    # below would raise KeyError. Only bring in discharge columns that ``scr`` lacks.
    missing = [col for col in ['DISCHARGE_DATE', 'DISCHARGE_SINCE_ONSET'] if col not in scr.columns]
    if missing:
        df = onset[keys + missing].merge(scr, on=keys, how='right')
    else:
        df = scr

    in_window = (df['ONSET_DATE'] < df['SPECIMEN_DATE']) & (df['SPECIMEN_DATE'] <= df['DISCHARGE_DATE'])
    last_scr = df[in_window].sort_values('SPECIMEN_DATE').groupby(keys).last().reset_index()

    recovered_ratio = (last_scr['RESULT_NUM'] < (1.5 * last_scr['SCR_BASELINE'])) & last_scr['AKI1_7D']
    recovered_level = (last_scr['RESULT_NUM'] < (0.3 + last_scr['SCR_REFERENCE'])) & last_scr['AKI1_2D']
    last_scr['AKI_RCV'] = recovered_ratio | recovered_level

    return last_scr[['PATID', 'ENCOUNTERID', 'AKI_RCV']]

# Generate AKI resolving status (Kellum definition)
def akidisappearing(dall):
    """Flag measurements at which AKI criteria are no longer met after onset.

    A measurement is "disappeared" when it is strictly after onset and either:
    - value / baseline < 1.5, for episodes flagged ``aki1_7d``; or
    - value - reference < 0.3, for episodes flagged ``aki1_2d``.

    Parameters
    ----------
    dall : pandas.DataFrame
        Measurements with subject_id, hadm_id, charttime, onsettime, value,
        baseline, reference, aki1_7d and aki1_2d. NOTE: the helper columns
        AKI_disappearing, ddays, dvalues_rat, dvalues_lvl and disappeared are
        added to this frame in place.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The first frame is ``dall`` with the helper columns. In it,
        ``AKI_disappearing`` is the per-stay maximum (True if any measurement
        disappeared) and ``disappeared`` is the per-row flag.
        The second frame is a per-stay summary with PATID, ENCOUNTERID,
        AKI_DISP and FIRST_DISP_TIME (the earliest disappeared charttime,
        NaT if none).
    """
    stay_keys = ['subject_id', 'hadm_id']
    zero_span = pd.Timedelta(value=0, unit='s')

    dall['AKI_disappearing'] = False
    dall['ddays'] = dall['charttime'] - dall['onsettime']
    dall['dvalues_rat'] = dall['value'] / dall['baseline']
    dall['dvalues_lvl'] = dall['value'] - dall['reference']

    after_onset = dall['ddays'] > zero_span
    ratio_resolved = after_onset & (dall['dvalues_rat'] < 1.5) & dall['aki1_7d']
    level_resolved = after_onset & (dall['dvalues_lvl'] < 0.3) & dall['aki1_2d']
    dall.loc[ratio_resolved | level_resolved, 'AKI_disappearing'] = True
    dall['disappeared'] = dall['AKI_disappearing'].copy()

    # Broadcast "did AKI ever disappear during this stay" back onto every row.
    stay_max = (dall[stay_keys + ['AKI_disappearing']]
                .groupby(stay_keys).max()
                .reset_index()
                .rename(columns={'AKI_disappearing': 'max_AKI_disappearing'}))
    dall = dall.merge(stay_max, on=stay_keys, how='left')
    dall['AKI_disappearing'] = dall['max_AKI_disappearing']

    # Earliest charttime flagged as disappeared, per stay.
    flagged_rows = dall[dall['disappeared'] == True][stay_keys + ['charttime']]
    first_resolv = flagged_rows.sort_values('charttime').groupby(stay_keys).first().reset_index()

    stay_summary = dall[stay_keys + ['AKI_disappearing']].groupby(stay_keys)['AKI_disappearing'].max().reset_index()
    df_resolv = stay_summary.merge(first_resolv, on=stay_keys, how='left')
    df_resolv.columns = ['PATID', 'ENCOUNTERID', 'AKI_DISP', 'FIRST_DISP_TIME']

    return dall.drop('max_AKI_disappearing', axis=1), df_resolv

def akireverting(dall):
    """Summarise AKI reversal and relapse per hospital stay.

    Sorts ``dall`` by subject_id, hadm_id and charttime, then applies
    ``check_reverted_and_relapsed`` to each (subject_id, hadm_id) group. That
    function is expected to yield four values per stay. Presumably they are
    the reverted flag, first reversal time, relapsed flag and first relapse
    time; confirm against its definition.

    Returns
    -------
    pandas.DataFrame
        Columns PATID, ENCOUNTERID, AKI_RVRT, FIRST_RVRT_TIME, AKI_RELP,
        FIRST_RELP_TIME.
    """
    stay_keys = ['subject_id', 'hadm_id']
    ordered = dall.sort_values(stay_keys + ['charttime'])
    per_stay = ordered.groupby(stay_keys).apply(check_reverted_and_relapsed)

    summary = per_stay.reset_index()
    summary.columns = ['PATID', 'ENCOUNTERID', 'AKI_RVRT', 'FIRST_RVRT_TIME', 'AKI_RELP', 'FIRST_RELP_TIME']
    return summary

def interp_and_check_next_day(filtered_group, row):
    """
    Decide whether the AKI criteria would still be "disappeared" one day after `row`.

    ``dvalues_rat`` and ``dvalues_lvl`` are linearly interpolated from `row`
    towards the first record of `filtered_group`, evaluated at +1 day.
    The disappearing criteria are then re-applied: ratio < 1.5 with
    ``row['aki1_7d']``, or level < 0.3 with ``row['aki1_2d']``.

    Parameters
    ----------
    filtered_group : pandas.DataFrame
        Its first row is the next measurement (needs charttime, dvalues_rat,
        dvalues_lvl).
    row : pandas.Series
        Current record (charttime, dvalues_rat, dvalues_lvl, ddays, aki1_7d,
        aki1_2d).

    Returns
    -------
    bool-like
        True when the interpolated next-day values meet either criterion.
    """
    upcoming = filtered_group.iloc[0]
    gap_in_days = (upcoming['charttime'] - row['charttime']).total_seconds() / (60 * 60 * 24)

    # Linear interpolation evaluated one day after `row`.
    ratio_next_day = row['dvalues_rat'] + (upcoming['dvalues_rat'] - row['dvalues_rat']) / gap_in_days
    level_next_day = row['dvalues_lvl'] + (upcoming['dvalues_lvl'] - row['dvalues_lvl']) / gap_in_days
    ddays_next_day = row['ddays'] + pd.Timedelta(days=1)

    is_after_onset = ddays_next_day > pd.Timedelta(value=0, unit='s')
    rat_condition = is_after_onset & (ratio_next_day < 1.5) & row['aki1_7d']
    lvl_condition = is_after_onset & (level_next_day < 0.3) & row['aki1_2d']
    return rat_condition | lvl_condition

        
def check_reverted_and_relapsed(group):
    """
    Determine AKI reversal and relapse for one encounter's SCr records.

    Intended for ``groupby(...).apply`` (``akireverting`` sorts by charttime
    before grouping). Uses the ``charttime`` column and the per-record
    ``disappeared`` flag.

    A "disappeared" record is marked reverted when:
      * there is no later record in the encounter; or
      * the next record is more than one day later, and the values
        linearly interpolated to +1 day still meet the disappeared criteria
        (``interp_and_check_next_day``); or
      * every record in (charttime, charttime + 1 day] is also disappeared.

    Relapse: after the first reverted record, any later record that is not
    disappeared. Its earliest charttime is the relapse time.

    Side effect: adds/overwrites a ``reverted`` column on `group`.

    Returns
    -------
    pandas.Series
        reverted (any reverted record), first_revert_time (None if never),
        relap (bool), first_relap_time (None if no relapse).
    """
    group['reverted'] = False
    for i, row in group.iterrows():
        if row['disappeared']:
            # Earliest measurement after this one (NaT when this is the last record).
            next_time = group[group['charttime'] > row['charttime']]['charttime'].min()
            next_day = row['charttime'] + pd.Timedelta(days=1)
            if pd.isna(next_time) :
                # No later measurement: nothing contradicts the reversal.
                next_disappeared = True
                
                if next_disappeared:
                    group.loc[i, 'reverted'] = True
            elif next_time > next_day :
                # Gap longer than a day: judge the +1 day point by interpolation.
                filter_nexttime = (group['charttime'] == next_time)
                # (dead alternative kept from an earlier version:)
                #next_disappeared = group[filter_nexttime]['disappeared'].iloc[0] if not pd.isna(next_time) else 
                next_disappeared = interp_and_check_next_day(group[filter_nexttime], row)
                
                if next_disappeared:
                    group.loc[i, 'reverted'] = True
            else:
                # At least one measurement within the next day: all of them must still be disappeared.
                mask = (group['charttime'] > row['charttime']) & (group['charttime'] <= next_day)
                if group.loc[mask, 'disappeared'].all():
                    group.loc[i, 'reverted'] = True

    first_reverted_time = group.loc[group['reverted'], 'charttime'].min() if group['reverted'].any() else None
    group_reverted = group['reverted'].any()

    # Initialize relapsed as False
    relapsed = False
    first_relapsed_time = None

    # If group is reverted, check for relapse
    if group_reverted and first_reverted_time is not None:
        # Filter for records after first_reverted_time
        post_reverted_records = group[group['charttime'] > first_reverted_time]
        # Check if any of these records has 'disappeared' as False
        if not post_reverted_records.empty and (post_reverted_records['disappeared'] == False).any():
            relapsed = True
            first_relapsed_time = post_reverted_records[post_reverted_records['disappeared'] == False]['charttime'].min()

    return pd.Series({
        'reverted': group_reverted,
        'first_revert_time': first_reverted_time,
        'relap': relapsed,
        'first_relap_time': first_relapsed_time
    })


def get_aki_reverting(scr, onset):
    """
    Attach AKI disappearing / reversal / relapse / recovery outcomes to `onset`.

    Parameters
    ----------
    scr : pandas.DataFrame
        SCr records with PATID, ENCOUNTERID, SPECIMEN_DATE, RESULT_NUM,
        DAYS_SINCE_ONSET, ONSET_DATE, SCR_BASELINE, SCR_REFERENCE, AKI1_2D,
        AKI1_7D. Only records with DAYS_SINCE_ONSET >= 0 are used.
    onset : pandas.DataFrame
        One row per encounter (PATID, ENCOUNTERID, ONSET_DATE, ...).

    Returns
    -------
    pandas.DataFrame
        `onset` left-joined with:
          * AKI_RVRT / FIRST_RVRT_TIME and AKI_RELP / FIRST_RELP_TIME;
          * AKI_DISP / FIRST_DISP_TIME and AKI_RCV;
          * DISP/RVRT/RELP_SINCE_ONSET (days after ONSET_DATE);
          * the flags AKI_ERVRT, AKI_ESRVRT, AKI_LSRVRT, AKI_NRVRT,
            AKI_RLPRCV and AKI_RLPNRCV.
        It also gets AKI_STATUS:
          * 0: no reversal;
          * 1: reversal within 7 days of onset, no relapse;
          * 2: later reversal, no relapse;
          * 3: relapse with recovery;
          * 4: relapse without recovery.
        Encounters missing from the joins have their AKI_RCV/AKI_RVRT/AKI_RELP
        flags set to False, so they land in status 0. Duplicate rows are
        removed.
    """
    keys = ['PATID', 'ENCOUNTERID']

    scr = scr[scr['DAYS_SINCE_ONSET']>=0]
    # .copy(): this frame is renamed and then augmented in place by akidisappearing.
    scr = scr[['PATID', 'ENCOUNTERID', 'SPECIMEN_DATE', 'RESULT_NUM', 'DAYS_SINCE_ONSET',
               'ONSET_DATE', 'SCR_BASELINE', 'SCR_REFERENCE', 'AKI1_2D', 'AKI1_7D']].copy()

    aki_recov = akirecovery(onset, scr)

    scr.columns = ['subject_id', 'hadm_id', 'charttime', 'value', 'onset_day',
                   'onsettime', 'baseline', 'reference', 'aki1_2d', 'aki1_7d']

    xxx, df_resolv = akidisappearing(dall = scr)
    xxx = akireverting(dall = xxx)

    # Merge the results on PATID and ENCOUNTERID
    df_revert = xxx.merge(df_resolv, on=keys, how='left')
    onset = onset.merge(df_revert, on=keys, how='left')
    onset = onset.merge(aki_recov, on=keys, how='left')

    # Encounters without matching rows come out of the left joins with NaN flags.
    # Coerce them to strict booleans (NaN -> False) by assignment. The former
    # `onset['AKI_RCV'].fillna(False, inplace=True)` is chained assignment and is a
    # no-op under pandas Copy-on-Write. `~` on an object column holding NaN
    # raises TypeError.
    for flag in ['AKI_RCV', 'AKI_RVRT', 'AKI_RELP']:
        onset[flag] = onset[flag].eq(True)
    # Per-encounter times may arrive as object (Timestamp/None); normalise to datetime.
    for col in ['FIRST_DISP_TIME', 'FIRST_RVRT_TIME', 'FIRST_RELP_TIME']:
        onset[col] = pd.to_datetime(onset[col])

    onset['DISP_SINCE_ONSET'] = (onset['FIRST_DISP_TIME'] - onset['ONSET_DATE']).dt.days
    onset['RVRT_SINCE_ONSET'] = (onset['FIRST_RVRT_TIME'] - onset['ONSET_DATE']).dt.days
    onset['RELP_SINCE_ONSET'] = (onset['FIRST_RELP_TIME'] - onset['ONSET_DATE']).dt.days

    # Early reversal: first reversal within 7 days of onset.
    onset['AKI_ERVRT']   =  onset['AKI_RVRT'] & (onset['FIRST_RVRT_TIME'] <= (onset['ONSET_DATE'] + pd.Timedelta(days=7)))
    onset['AKI_ESRVRT']  = (onset['AKI_ERVRT'])  & (~onset['AKI_RELP'])
    onset['AKI_LSRVRT']  = (~onset['AKI_ERVRT']) & onset['AKI_RVRT']  & (~onset['AKI_RELP'])
    onset['AKI_NRVRT']   =  ~onset['AKI_RVRT']
    onset['AKI_RLPRCV']  = (onset['AKI_RELP']) & (onset['AKI_RCV'])
    onset['AKI_RLPNRCV'] = (onset['AKI_RELP']) & (~onset['AKI_RCV'])
    # The status flags are mutually exclusive (relapse implies reversal), so this is a category code.
    onset['AKI_STATUS']  = (0*onset['AKI_NRVRT'] +
                            1*onset['AKI_ESRVRT'] +
                            2*onset['AKI_LSRVRT'] +
                            3*onset['AKI_RLPRCV'] +
                            4*onset['AKI_RLPNRCV'])

    return onset.drop_duplicates()

def akiresolving(dall, time, dlevel = 0.3, ratio = 0.75):
    """
    Flag SCr records showing early resolution relative to the encounter's peak.

    Only records with ``time <= onset_day <= 3`` are kept. The peak is the
    highest value per (subject_id, hadm_id), earliest charttime on ties.
    A record after the peak is "resolving" when it is lower than the peak by
    at least `dlevel`, or when value / peak <= `ratio`.

    Parameters
    ----------
    dall : pandas.DataFrame
        subject_id, hadm_id, charttime, value, onset_day.
    time : int
        Lower bound of the onset_day window.
    dlevel : float
        Minimum absolute drop from the peak.
    ratio : float
        Maximum value / peak ratio.

    Returns
    -------
    pandas.DataFrame
        The windowed records plus charttime_max and value_max (the peak),
        AKI_resolving2 (per-record 0/1 flag) and AKI_resolving (its
        per-encounter maximum).
    """
    keys = ['subject_id', 'hadm_id']
    zero = pd.Timedelta(value=0, unit='s')

    def _in_window(frame):
        return (frame['onset_day'] <= 3) & (frame['onset_day'] >= time)

    dall = dall[_in_window(dall)].reset_index().drop('index', axis=1)

    # Peak SCr per encounter; the sort makes idxmax pick the earliest charttime among ties.
    by_peak = dall.sort_values(['value', 'charttime'], ascending=[False, True])[keys + ['charttime', 'value']]
    peak_rows = by_peak.loc[by_peak.groupby(keys)['value'].idxmax()]
    peak_rows.columns = keys + ['charttime_max', 'value_max']
    dall = dall.merge(peak_rows, on=keys, how='left')

    in_window = _in_window(dall)
    after_peak = (dall['charttime'] - dall['charttime_max']) > zero
    absolute_drop = (dall['value_max'] - dall['value']) >= dlevel
    relative_drop = (dall['value'] / dall['value_max']) <= ratio

    dall['AKI_resolving'] = 0
    dall.loc[after_peak & in_window & (absolute_drop | relative_drop), 'AKI_resolving'] = 1
    dall['AKI_resolving2'] = dall['AKI_resolving'].copy()

    encounter_max = dall[keys + ['AKI_resolving']].groupby(keys).max().reset_index()
    encounter_max.columns = keys + ['max_AKI_resolving']
    dall = dall.merge(encounter_max, on=keys, how='left')
    dall['AKI_resolving'] = dall['max_AKI_resolving']
    return dall.drop('max_AKI_resolving', axis=1)

def akiresolvingsustain(dall, time):
    """
    Check whether the first AKI-resolving record is sustained within the window.

    For each encounter, ``charttime_resolv`` is the earliest record with
    ``AKI_resolving2 == 1``. ``sustain`` is the minimum ``AKI_resolving2``
    over later records whose onset_day lies in [time, 3]:
      * 1: every such record still resolves;
      * 0: at least one does not.
    Encounters with no such later record get ``sustain = True``.

    Parameters
    ----------
    dall : pandas.DataFrame
        Output of ``akiresolving``; needs subject_id, hadm_id, charttime,
        onset_day and AKI_resolving2.
    time : int
        Lower bound of the onset_day window.

    Returns
    -------
    pandas.DataFrame
        `dall` with charttime_resolv, after_ddays and sustain added.
    """
    keys = ['subject_id', 'hadm_id']

    dresolv = dall[dall['AKI_resolving2']==1]
    dresolv = dresolv[keys + ['charttime']].sort_values('charttime').groupby(keys).first().reset_index()
    dresolv.columns = keys + ['charttime_resolv']

    dall = dall.merge(dresolv, on=keys, how='left')
    time_window = (dall['onset_day']<=3) & (dall['onset_day']>=time)
    # Records strictly after the first resolving record and inside the window.
    dall['after_ddays'] = ((dall['charttime']-dall['charttime_resolv']) > pd.Timedelta(value=0, unit='s')) & time_window

    dresolv2 = dall[dall['after_ddays']][keys + ['AKI_resolving2']].groupby(keys).min().reset_index()
    dresolv2.columns = keys + ['sustain']

    dall = dall.merge(dresolv2, on=keys, how='left')
    # Assign the result back. `dall['sustain'].fillna(True, inplace=True)` is
    # chained assignment and silently does nothing under pandas Copy-on-Write,
    # which would leave NaN where "sustained" was intended.
    dall['sustain'] = dall['sustain'].fillna(True)
    return dall

def get_aki_resolving(scr, onset, time = 0, dlevel = 0.3, ratio = 0.75):
    """
    Attach early AKI-resolving flags to the onset table.

    Parameters
    ----------
    scr : pandas.DataFrame
        SCr records (PATID, ENCOUNTERID, SPECIMEN_DATE, RESULT_NUM,
        DAYS_SINCE_ONSET). Only DAYS_SINCE_ONSET >= 0 is used.
    onset : pandas.DataFrame
        One row per encounter.
    time, dlevel, ratio :
        Forwarded to ``akiresolving`` (window start, minimum absolute drop
        and maximum ratio relative to the peak).

    Returns
    -------
    pandas.DataFrame
        `onset` left-joined with:
          * AKI_TRIGG: some record met the resolving criteria;
          * AKI_RESOL: the encounter resolved and the resolution was
            sustained.
        Duplicate rows are removed.
    """
    renames = {'PATID': 'subject_id', 'ENCOUNTERID': 'hadm_id', 'SPECIMEN_DATE': 'charttime',
               'RESULT_NUM': 'value', 'DAYS_SINCE_ONSET': 'onset_day'}
    post_onset = scr.loc[scr['DAYS_SINCE_ONSET'] >= 0, list(renames)].rename(columns=renames)

    flagged = akiresolving(dall=post_onset, time=time, dlevel=dlevel, ratio=ratio)
    flagged = akiresolvingsustain(dall=flagged, time=time)
    flagged['AKI_sustain'] = flagged['AKI_resolving'] * flagged['sustain']
    flagged['PATID'] = flagged['subject_id']
    flagged['ENCOUNTERID'] = flagged['hadm_id']

    per_encounter = (flagged[['PATID', 'ENCOUNTERID', 'AKI_resolving', 'AKI_sustain']]
                     .groupby(['PATID', 'ENCOUNTERID']).sum()
                     .reset_index())
    per_encounter['AKI_TRIGG'] = per_encounter['AKI_resolving'] > 0
    per_encounter['AKI_RESOL'] = per_encounter['AKI_sustain'] > 0

    result = onset.merge(per_encounter[['PATID', 'ENCOUNTERID', 'AKI_TRIGG', 'AKI_RESOL']],
                         on=['PATID', 'ENCOUNTERID'], how='left')
    return result.drop_duplicates()

In [13]:
def process_onset_and_recovery(file_paths, aggfunc_7d, aggfunc_1y, keep_ckd):
    """
    Run the AKI onset and recovery pipeline for one site.

    Steps:
      1. Load SCr and admissions, then compute SCr baselines and RRT.
      2. Compute AKI onset and save it as ``onset.pkl``.
      3. Compute resolving (AKI_TRIGG/AKI_RESOL) and reverting/relapse
         outcomes, merged and saved as ``outcome.pkl``.

    Parameters
    ----------
    file_paths : list[str]
        [input_path, output_path, aux_path] as returned by
        ``get_data_file_path`` (input_path ends with the site folder).
        These paths are now honoured. Previously the argument was overwritten
        from the notebook globals ``base_path`` and ``site_name``.
    aggfunc_7d, aggfunc_1y, keep_ckd :
        Forwarded unchanged to ``get_scr_baseline_new``.

    Returns
    -------
    (pandas.DataFrame, object)
        The onset table and the cohort table from ``get_scr_baseline_new``.
    """
    # get_data_file_path used to create the output folder; keep that guarantee.
    os.makedirs(file_paths[1], exist_ok=True)
    site_name = os.path.basename(os.path.normpath(file_paths[0]))

    df_scr, df_admit = load_onset_data(file_paths)
    df_baseline, cohort_table = get_scr_baseline_new(df_scr, df_admit, file_paths, aggfunc_7d, aggfunc_1y, keep_ckd)
    df_rrt = get_rrt(df_admit, file_paths)

    onset = get_aki_onset(df_scr, df_admit, df_rrt, df_baseline)
    onset.to_pickle(file_paths[1] + 'onset.pkl')

    df_scr = load_and_filter_scr(onset, file_paths)

    onset_resol = get_aki_resolving(scr = df_scr, onset = onset)
    onset_revert = get_aki_reverting(df_scr, onset)

    onset_combined = onset_revert.merge(onset_resol[['PATID', 'ENCOUNTERID', 'AKI_TRIGG', 'AKI_RESOL']],
                                        on = ['PATID', 'ENCOUNTERID'],
                                        how = 'left').drop_duplicates()

    onset_combined.to_pickle(file_paths[1]+'outcome.pkl')
    print(f"Finish generating AKI onset and recovery for {site_name}.", flush = True)
    return onset, cohort_table

#### Step 3: Extract and Process Features

In [15]:
def load_and_create_timeline(file_paths, filename, date_var, event_type, lapsed_day):
    """
    Load one raw table, join it to the onset cohort and add timeline columns.

    Parameters
    ----------
    file_paths : list[str]
        [input_path, output_path, aux_path], each ending in '/'. Reads
        ``input_path + filename + '.pkl'`` and ``output_path + 'onset.pkl'``.
    filename : str
        Table name without extension, e.g. 'AKI_VITAL'.
    date_var : str
        Date column of the table used to compute day offsets.
    event_type : str
        Case-insensitive.
          * 'ONSET' or 'PROGNOSIS': adds DAYS_SINCE_EVENT, the day offset from
            ONSET_DATE (DISCHARGE_DATE when there is no onset).
          * 'PROGNOSIS': additionally keeps only encounters with
            AKI_INIT_STG > 0.
    lapsed_day : int
        Non-'AKI_DX' tables keep records with DAYS_SINCE_ADMIT >= 0 and
        DAYS_SINCE_EVENT < lapsed_day.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The filtered joined records, and the (possibly restricted) onset
        table. 'AKI_DX' keeps records with -365.25/2 <= DAYS_SINCE_ADMIT < 0
        (the half year before admission).
    """
    input_path = file_paths[0]
    output_path = file_paths[1]
    df = pd.read_pickle(input_path + filename +'.pkl')
    onset = pd.read_pickle(output_path + 'onset.pkl')

    event = event_type.upper()
    if event == 'PROGNOSIS':
        onset = onset[onset['AKI_INIT_STG'] > 0]

    df_merged = onset.merge(df, on = ['PATID', 'ENCOUNTERID'], how = 'inner')

    # Recompute DAYS_SINCE_ADMIT from date_var, unless the table already carries
    # one that is more complete than date_var. The empty fallback Series is
    # typed explicitly, because an untyped pd.Series() emits a pandas
    # deprecation warning.
    if ('DAYS_SINCE_ADMIT' not in df_merged.columns) or (
            df_merged[date_var].isna().mean()
            <= df_merged.get('DAYS_SINCE_ADMIT', pd.Series(dtype='float64')).isna().mean()):
        df_merged['DAYS_SINCE_ADMIT'] = (df_merged[date_var]-df_merged['ADMIT_DATE']).dt.days

    if event in ('ONSET', 'PROGNOSIS'):
        # Event anchor: the AKI onset date, or the discharge date when there is no onset.
        # fillna keeps the pandas datetime column (np.where went through a bare numpy array).
        df_merged['MIN_ONSET_DISCHARGE_DATE'] = df_merged['ONSET_DATE'].fillna(df_merged['DISCHARGE_DATE'])

        if df_merged[date_var].isna().mean() <= df_merged['DAYS_SINCE_ADMIT'].isna().mean():
            df_merged['DAYS_SINCE_EVENT'] = (df_merged[date_var]-df_merged['MIN_ONSET_DISCHARGE_DATE']).dt.days
        else:
            print(f"No date variable for {filename} from site {file_paths[0].split('/')[-2]}.", flush = True)
            # Fall back to the precomputed admission offset, shifted to the event anchor.
            df_merged['DAYS_SINCE_EVENT'] = df_merged['DAYS_SINCE_ADMIT'] - (df_merged['MIN_ONSET_DISCHARGE_DATE'] - df_merged['ADMIT_DATE']).dt.days

    if filename == 'AKI_DX':
        # Diagnoses from the half year before admission.
        df_merged = df_merged[(df_merged['DAYS_SINCE_ADMIT']>=(-365.25)/2) & (df_merged['DAYS_SINCE_ADMIT'] < 0)]
    else:
        df_merged = df_merged[(df_merged['DAYS_SINCE_ADMIT']>=0) & (df_merged['DAYS_SINCE_EVENT'] < lapsed_day)]

    return df_merged, onset
    
def replace_outliers(df, var_id, col_to_filter, q_l = 0.01, q_u = 0.99):
    """
    Winsorize `col_to_filter`: clip its values to the [q_l, q_u] quantiles.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame. It is not modified; a new frame is returned.
    var_id : str, list of str, or None
        Grouping column(s). When given, quantiles are computed per group
        (e.g. per LAB_LOINC); None uses the whole column.
    col_to_filter : str
        Numeric column to clip. NaN values stay NaN.
    q_l, q_u : float
        Lower / upper quantiles used as clipping bounds.

    Returns
    -------
    pandas.DataFrame
        With `var_id`, the output matches the former groupby.apply version:
          * rows whose key is NaN are dropped;
          * rows are ordered by key (stable within a key);
          * the index is a fresh RangeIndex.
        The grouping columns are always kept. (groupby.apply on grouping
        columns is deprecated and drops them in pandas 3.)
    """
    if var_id is None:
        out = df.copy()
        col = out[col_to_filter]
        # Series.clip is the vectorized form of the former per-element lambda.
        out[col_to_filter] = col.clip(lower=col.quantile(q_l), upper=col.quantile(q_u))
        return out

    keys = var_id if isinstance(var_id, list) else [var_id]
    # groupby drops NaN keys; keep that behaviour explicitly.
    out = df.dropna(subset=keys).copy()
    grouped = out.groupby(keys)[col_to_filter]
    lower = grouped.transform(lambda s: s.quantile(q_l))
    upper = grouped.transform(lambda s: s.quantile(q_u))
    out[col_to_filter] = out[col_to_filter].clip(lower=lower, upper=upper)
    return out.sort_values(keys, kind='stable').reset_index(drop=True)

def filter_sparse_features(df, df_onset, var_id, trim=0.05):
    """
    Keep only features present in at least `trim` of the onset encounters.

    Prevalence of a feature = number of distinct (PATID, ENCOUNTERID) pairs
    carrying it, divided by the number of rows in `df_onset`.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format records with PATID, ENCOUNTERID and `var_id`.
    df_onset : pandas.DataFrame
        Cohort table; its row count is the prevalence denominator.
    var_id : str
        Feature column.
    trim : float
        Minimum prevalence to keep a feature.

    Returns
    -------
    pandas.DataFrame
        The rows of `df` whose feature is kept:
          * ordered by feature (the merge order);
          * `var_id` moved to the first column;
          * with a new index.
    """
    n_encounters = df_onset.shape[0]
    distinct_rows = df[['PATID', 'ENCOUNTERID', var_id]].drop_duplicates()
    encounters_per_feature = distinct_rows.groupby(var_id)['ENCOUNTERID'].count()
    prevalence = encounters_per_feature / n_encounters
    kept_features = prevalence[prevalence >= trim].index.to_frame(index=False)
    return kept_features.merge(df, on=[var_id], how='left')

In [16]:
def process_demo(file_paths):
    """
    Build boolean demographic features and save them as ``demo.pkl``.

    Reads ``AKI_DEMO.pkl`` from ``file_paths[0]``.
      * MALE: SEX == 'M'.
      * HISPANIC: HISPANIC == 'Y'.
      * RACE_WHITE: RACE == '05'.
      * RACE_BLACK: RACE == '03'.

    Returns
    -------
    pandas.DataFrame
        Distinct rows of PATID, ENCOUNTERID, AGE, MALE, RACE_WHITE, RACE_BLACK,
        HISPANIC (also written to ``file_paths[1] + 'demo.pkl'``).
    """
    raw = pd.read_pickle(file_paths[0]+'AKI_DEMO.pkl')
    flags = {
        'MALE': raw['SEX'] == 'M',
        'HISPANIC': raw['HISPANIC'] == 'Y',
        'RACE_WHITE': raw['RACE'] == '05',
        'RACE_BLACK': raw['RACE'] == '03',
    }
    selected = ['PATID', 'ENCOUNTERID', 'AGE', 'MALE', 'RACE_WHITE', 'RACE_BLACK', 'HISPANIC']
    demo = raw.assign(**flags)[selected].drop_duplicates()
    demo.to_pickle(file_paths[1]+'demo.pkl')
    print(f"Finished processing demo data from {file_paths[0].split('/')[-2]}.", flush = True)
    return demo

In [17]:
def _last_step_slope(daily, value_col, slope_col):
    """Slope between each encounter's two most recent rows, (last - previous) / day gap.

    Expects PATID, ENCOUNTERID, DAYS_SINCE_EVENT and `value_col`, sorted by
    (PATID, ENCOUNTERID, DAYS_SINCE_EVENT). Encounters with fewer than two rows
    are omitted (inner merge).
    """
    keys = ['PATID', 'ENCOUNTERID']
    # 0 = most recent row of the encounter, 1 = second most recent. cumcount gives
    # the same rows on every pandas version. Before pandas 2.0, GroupBy.nth moved
    # the group keys into the index, so reset_index(drop=True) discarded them and
    # the merge on PATID/ENCOUNTERID failed.
    rank_from_end = daily.groupby(keys).cumcount(ascending=False)
    recent = daily[rank_from_end == 0].reset_index(drop=True)
    previous = daily[rank_from_end == 1].reset_index(drop=True)
    pair = recent.merge(previous, on=keys, suffixes=('_recent', '_prev'))
    pair[slope_col] = ((pair[value_col + '_recent'] - pair[value_col + '_prev'])
                       / (pair['DAYS_SINCE_EVENT_recent'] - pair['DAYS_SINCE_EVENT_prev']))
    return pair[keys + [slope_col]]


def process_vital(file_paths, event_type, lapsed_day):
    """
    Build per-encounter vital-sign features; save as ``vital_d{lapsed_day}.pkl``.

    Numeric vitals (ORIGINAL_BMI, HT, WT, SYSTOLIC, DIASTOLIC) are clipped to
    their 1st-99th percentiles (``replace_outliers`` defaults), then averaged
    per day.

    Features per encounter:
      * WT, HT: most recent daily mean.
      * BMI: most recent ORIGINAL_BMI, falling back to WT / HT^2 * 703.
      * SYSTOLIC_MEAN, DIASTOLIC_MEAN: most recent daily mean.
      * SYSTOLIC_FD, DIASTOLIC_FD: slope between the two most recent days.
      * SMOKING_01/02/03/05/06/07: indicators of the most recent SMOKING code.

    Parameters
    ----------
    file_paths : list[str]
        [input_path, output_path, aux_path].
    event_type : str
        Passed to ``load_and_create_timeline``.
    lapsed_day : int
        Keep records with DAYS_SINCE_EVENT < lapsed_day.

    Returns
    -------
    pandas.DataFrame
        One row per (PATID, ENCOUNTERID), outer-joined across features.
    """
    keys = ['PATID', 'ENCOUNTERID']
    order = keys + ['DAYS_SINCE_EVENT']

    # Load data and calculate timelines
    vital, onset = load_and_create_timeline(file_paths, 'AKI_VITAL', 'MEASURE_DATE', event_type, lapsed_day)

    vital = vital[(vital['ORIGINAL_BMI'].notna()) | (vital['HT'].notna()) | (vital['WT'].notna())
                  | (vital['SYSTOLIC'].notna()) | (vital['DIASTOLIC'].notna())
                  | (vital['SMOKING'].notna())]

    # .copy(): the outlier step writes to this frame; avoid editing a slice of `vital`.
    vital_r = vital[['PATID', 'ENCOUNTERID', 'MEASURE_DATE',
                     'ORIGINAL_BMI', 'HT', 'WT', 'SMOKING', 'SYSTOLIC', 'DIASTOLIC',
                     'DAYS_SINCE_ADMIT', 'DAYS_SINCE_EVENT']].copy()

    # Replace the outliers
    for col in ['ORIGINAL_BMI', 'HT', 'WT', 'SYSTOLIC', 'DIASTOLIC']:
        vital_r = replace_outliers(vital_r, None, col)

    # Daily means of the numeric vitals.
    vital_mean = vital_r[order + ['ORIGINAL_BMI', 'HT', 'WT',
                                  'SYSTOLIC', 'DIASTOLIC']].groupby(order).mean().reset_index()

    vital_bmi = vital_mean[order + ['ORIGINAL_BMI']].dropna().sort_values(order)
    vital_wt = vital_mean[order + ['WT']].dropna().sort_values(order)
    vital_ht = vital_mean[order + ['HT']].dropna().sort_values(order)
    # SMOKING is a categorical code, so it comes from the raw rows rather than daily means.
    vital_smoking = vital_r[order + ['SMOKING']].dropna().sort_values(order)
    vital_sys = vital_mean[order + ['SYSTOLIC']].dropna().sort_values(order)
    vital_dias = vital_mean[order + ['DIASTOLIC']].dropna().sort_values(order)

    # Most recent value per encounter.
    vital_bmi_p = vital_bmi.groupby(keys).agg({'ORIGINAL_BMI':'last'}).reset_index()
    vital_wt_p = vital_wt.groupby(keys).agg({'WT':'last'}).reset_index()
    vital_ht_p = vital_ht.groupby(keys).agg({'HT':'last'}).reset_index()
    # NOTE: despite the *_MEAN names these are the most recent daily means.
    vital_sys_p = vital_sys.groupby(keys)['SYSTOLIC'].last().reset_index(name = 'SYSTOLIC_MEAN')
    vital_dias_p = vital_dias.groupby(keys)['DIASTOLIC'].last().reset_index(name = 'DIASTOLIC_MEAN')
    vital_smoking_p = vital_smoking.groupby(keys).agg({'SMOKING':'last'}).reset_index()

    vital_wt_ht = vital_wt_p.merge(vital_ht_p, on = keys, how = 'outer')
    vital_bmi_wt_ht = vital_wt_ht.merge(vital_bmi_p, on = keys, how = 'outer')

    # Fallback BMI with the imperial factor 703. This assumes WT is in pounds
    # and HT in inches; confirm against the source table.
    calculated_bmi = (vital_bmi_wt_ht['WT'] / (vital_bmi_wt_ht['HT'] ** 2)) * 703
    vital_bmi_wt_ht['BMI'] = vital_bmi_wt_ht['ORIGINAL_BMI'].fillna(calculated_bmi)

    sys_slope = _last_step_slope(vital_sys, 'SYSTOLIC', 'SYSTOLIC_FD')
    dias_slope = _last_step_slope(vital_dias, 'DIASTOLIC', 'DIASTOLIC_FD')
    vital_bp_fd = pd.merge(sys_slope, dias_slope, on=keys, how='outer')

    vital_t = pd.merge(vital_bmi_wt_ht.drop('ORIGINAL_BMI', axis =1), vital_smoking_p, on=keys, how='outer')
    vital_t = pd.merge(vital_t, vital_sys_p, on=keys, how='outer')
    vital_t = pd.merge(vital_t, vital_dias_p, on=keys, how='outer')
    vital_t = pd.merge(vital_t, vital_bp_fd, on=keys, how='outer')

    # One-hot SMOKING codes. '04' and any other code get no column.
    for code in ['01', '02', '03', '05', '06', '07']:
        vital_t['SMOKING_' + code] = vital_t['SMOKING'] == code
    vital_t = vital_t.drop('SMOKING', axis = 1)

    vital_t.to_pickle(file_paths[1] +'vital'+ '_d' + str(lapsed_day) + '.pkl')
    print(f"Finished processing vital data from site {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush = True)
    return vital_t

In [18]:
def get_ckd_from_diagnosis(dx):
    """
    Derive a pre-admission CKD flag and stage per encounter from diagnoses.

    Only rows with DAYS_SINCE_ADMIT < 0 are used. Codes are compared in a
    normalised form (text, dots removed, stripped, upper-case):
      * ICD-9 585.s must match exactly.
      * ICD-10 N18.s also matches finer sub-codes, e.g. N18.30/N18.31/N18.32
        for stage 3a/3b.
    Both dotted and undotted codes are therefore recognised.

    Stage mapping:
      * 585.1-585.6 / N18.1-N18.6 -> stages 1-6. 585.6 is now flagged
        consistently with N18.6.
      * Unspecified CKD (585.9 / N18.9) is flagged with stage 1, as before.

    Parameters
    ----------
    dx : pandas.DataFrame
        PATID, ENCOUNTERID, DX, DAYS_SINCE_ADMIT.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, PREADM_CKD_FLAG (bool), PREADM_CKD_STAGE (0 if
        there is no CKD code).
    """
    keys = ['PATID', 'ENCOUNTERID']
    dx0 = dx[dx['DAYS_SINCE_ADMIT'] < 0].copy()

    code = (dx0['DX'].astype(str)
            .str.replace('.', '', regex=False)
            .str.strip()
            .str.upper())

    def _ckd_code(suffix):
        # ICD-9 must match the whole code, so longer numeric codes (e.g.
        # SNOMED ids) starting with 585x are not picked up. ICD-10 matches
        # by prefix.
        return code.str.match(f'585{suffix}$|N18{suffix}')

    flag = _ckd_code('9')
    stage = pd.Series(0, index=dx0.index)
    stage.loc[flag] = 1
    for s in range(1, 7):
        is_stage = _ckd_code(str(s))
        flag = flag | is_stage
        stage.loc[is_stage] = s

    dx0['PREADM_CKD_FLAG'] = flag.astype(bool)
    dx0['PREADM_CKD_STAGE'] = stage
    dx1 = dx0.groupby(keys)['PREADM_CKD_FLAG'].max().reset_index()
    dx2 = dx0.groupby(keys)['PREADM_CKD_STAGE'].max().reset_index()
    ckd_all = dx1.merge(dx2, on=keys, how='left')
    return ckd_all.drop_duplicates()

In [19]:
def process_dx(file_paths, trim = 0.05, event_type = 'PROGNOSIS'):
    """
    Build diagnosis features plus the pre-admission CKD flag; save as ``dx.pkl``.

    Uses diagnoses from the half year before admission
    (``load_and_create_timeline`` with 'AKI_DX'). Processing:
      1. ICD-10 codes with an entry in ``2018_I10gem.csv`` (matched on the
         undotted code) are converted to ICD-9. A code with several GEM
         targets yields several rows.
      2. ICD-9 and remaining ICD-10 codes are truncated to 3 characters.
      3. Every code is prefixed with its DX_TYPE (e.g. '09_585').
      4. Codes are filtered by prevalence (`trim`) and one-hot encoded with
         prefix 'DX_'.

    Parameters
    ----------
    file_paths : list[str]
        [input_path, output_path, aux_path].
    trim : float
        Minimum fraction of onset encounters a code must appear in.
    event_type : str
        Passed to ``load_and_create_timeline``.

    Returns
    -------
    pandas.DataFrame
        One row per (PATID, ENCOUNTERID): DX_* indicators, outer-joined with
        PREADM_CKD_FLAG / PREADM_CKD_STAGE from ``get_ckd_from_diagnosis``.
    """
    output_path = file_paths[1]
    aux_path = file_paths[2]
    dx, onset = load_and_create_timeline(file_paths, 'AKI_DX', 'DX_DATE', event_type, 999)
    # Work on our own frame: load_and_create_timeline returns a filtered slice.
    dx = dx.copy()

    dx['DX_TYPE'] = dx['DX_TYPE'].astype(str)
    dx['DX_TYPE'] = dx['DX_TYPE'].where(dx['DX_TYPE'] != '9', '09')

    # Step 1: Get pre-existing CKD flag and stage
    dx_ckd = get_ckd_from_diagnosis(dx)

    # Step 2: Get all dx features
    icd10toicd09 = pd.read_csv(aux_path + '2018_I10gem.csv', sep=',')
    icd10toicd09.columns = ['DX', 'DX09']
    dx10 = dx[dx['DX_TYPE'] == '10'].copy()
    # Literal '.' (regex=False): explicit, and no FutureWarning on pandas 1.x.
    dx10['DX'] = dx10['DX'].str.replace('.', '', regex=False)
    dx10 = pd.merge(dx10, icd10toicd09, on='DX', how='left')
    mapped = dx10['DX09'].notna()
    dx10['DX_TYPE'] = np.where(mapped, '09', dx10['DX_TYPE'])
    dx10['DX'] = np.where(mapped, dx10['DX09'], dx10['DX'])
    dx10 = dx10.drop('DX09', axis=1)
    dx = pd.concat([dx[dx['DX_TYPE'] != '10'], dx10], axis=0, ignore_index=True)

    # Roll ICD-9 / ICD-10 codes up to their 3-character category. .str[:3]
    # yields NaN for a missing code instead of raising as x[0:3] did.
    is_icd = dx['DX_TYPE'].isin(['09', '10'])
    dx['DX'] = dx['DX'].where(~is_icd, dx['DX'].str[:3])
    dx['DX'] = dx['DX_TYPE'] + '_' + dx['DX']

    dx_r = filter_sparse_features(dx, onset, 'DX', trim)
    dx_dummies = pd.get_dummies(dx_r[['PATID','ENCOUNTERID', 'DX']], columns=['DX'], prefix='DX')
    dx_pivot = dx_dummies.groupby(['PATID', 'ENCOUNTERID']).max().reset_index()
    dx_final = dx_pivot.merge(dx_ckd, on = ['PATID', 'ENCOUNTERID'], how = 'outer')
    dx_final.to_pickle(output_path + 'dx.pkl')
    print(f"Finished processing dx data from {file_paths[0].split('/')[-2]}.", flush = True)
    return dx_final

In [20]:
def process_px(file_paths,  trim = 0.05, event_type = 'PROGNOSIS', lapsed_day = 1):
    """
    Build one-hot procedure features; save as ``px_d{lapsed_day}.pkl``.

    CPT/HCPCS ('CH') codes in the evaluation & management range 99202-99499
    and the pathology/laboratory range 80047-89398 are dropped. Remaining
    codes are prefixed with their PX_TYPE, filtered by prevalence (`trim`)
    and pivoted to indicator columns prefixed 'PX_'.

    Parameters
    ----------
    file_paths : list[str]
        [input_path, output_path, aux_path].
    trim : float
        Minimum fraction of onset encounters a code must appear in.
    event_type : str
        Passed to ``load_and_create_timeline``.
    lapsed_day : int
        Keep records with DAYS_SINCE_EVENT < lapsed_day.

    Returns
    -------
    pandas.DataFrame
        One row per (PATID, ENCOUNTERID) with PX_* indicator columns.
    """
    output_path = file_paths[1]
    px, onset = load_and_create_timeline(file_paths, 'AKI_PX', 'PX_DATE', event_type, lapsed_day)
    # Work on our own frame: load_and_create_timeline returns a filtered slice.
    px = px.copy()
    px['PX_TYPE'] = px['PX_TYPE'].where(px['PX_TYPE'] != '9', '09')

    # Integer value of all-digit codes, 0 otherwise. A missing PX makes
    # .str.isnumeric() return NaN. np.where treats that NaN as truthy, and the
    # int cast then fails, so coerce to a strict boolean first.
    is_digits = px['PX'].str.isnumeric().eq(True)
    px['PX_NUMERIC'] = np.where(is_digits, px['PX'], 0)
    px['PX_NUMERIC'] = px['PX_NUMERIC'].astype(int)
    in_em_range = px['PX_NUMERIC'].between(99202, 99499)
    in_lab_range = px['PX_NUMERIC'].between(80047, 89398)
    mask_admin = (px['PX_TYPE'] == 'CH') & (in_em_range | in_lab_range)

    px = px[~mask_admin].drop('PX_NUMERIC', axis = 1)
    px['PX'] = px['PX_TYPE'] + '_' + px['PX']
    px_r = filter_sparse_features(px, onset, 'PX', trim)
    px_dummies = pd.get_dummies(px_r[['PATID','ENCOUNTERID', 'PX']], columns=['PX'], prefix='PX')
    px_final = px_dummies.groupby(['PATID', 'ENCOUNTERID']).max().reset_index()
    px_final.to_pickle(output_path + 'px'+'_d'+str(lapsed_day)+'.pkl')

    print(f"Finished processing px data from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush = True)
    return px_final

In [21]:
def convert_rxnorm2atc_pmed(med, file_paths):
    """
    Map prescribed-medication RxNorm CUIs to ATC (4th level) codes where possible.

    Reads ``file_paths[2] + 'map/med_unified_conversion_rx2atc.parquet'``
    (columns RX, ATC).
      * Mapped records: RX_TYPE 'ATC', RX_CODE 'ATC_<atc>'.
      * Unmapped records: RX_TYPE 'RXN', RX_CODE 'RXN_<cui>'.
    A CUI with several ATC mappings yields one row per mapping.

    Side effect: ``med['RXNORM_CUI']`` is cast to str in place.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, RX_TYPE, RX_CODE.
    """
    aux_path = file_paths[2]+'map/'
    lookup = (pd.read_parquet(aux_path + 'med_unified_conversion_rx2atc.parquet')
              .rename(columns={'RX': 'RXNORM_CUI', 'ATC': 'ATC4th'})
              [['RXNORM_CUI', 'ATC4th']]
              .dropna()
              .drop_duplicates())
    med['RXNORM_CUI'] = med['RXNORM_CUI'].astype('str')
    lookup = lookup.astype({'RXNORM_CUI': 'str', 'ATC4th': 'str'})

    coded = med.merge(lookup, on='RXNORM_CUI', how='left')
    has_atc = coded['ATC4th'].notna()
    coded['RX_TYPE'] = np.where(has_atc, 'ATC', 'RXN')
    coded['RX_CODE'] = coded['ATC4th'].where(has_atc, coded['RXNORM_CUI'])
    coded['RX_CODE'] = coded['RX_TYPE'] + '_' + coded['RX_CODE']
    return coded[['PATID', 'ENCOUNTERID', 'RX_TYPE', 'RX_CODE']]

def process_pmed(file_paths, trim = 0.05, event_type = 'PROGNOSIS', lapsed_day = 1):
    """
    Build one-hot prescribed-medication features; save as ``pmed_d{lapsed_day}.pkl``.

    Codes come from ``convert_rxnorm2atc_pmed`` (ATC where mapped, RxNorm
    otherwise). They are kept if present in at least `trim` of onset
    encounters and pivoted to indicator columns prefixed 'RX_'.

    Returns
    -------
    pandas.DataFrame
        One row per (PATID, ENCOUNTERID) with RX_* indicator columns.
    """
    med, onset = load_and_create_timeline(file_paths, 'AKI_PMED', 'RX_START_DATE', event_type, lapsed_day)
    coded = convert_rxnorm2atc_pmed(med, file_paths)
    frequent = filter_sparse_features(coded, onset, 'RX_CODE', trim)
    indicators = pd.get_dummies(frequent[['PATID', 'ENCOUNTERID', 'RX_CODE']],
                                columns=['RX_CODE'],
                                prefix='RX')
    med_final = indicators.groupby(['PATID', 'ENCOUNTERID']).max().reset_index()

    output_path = file_paths[1]
    med_final.to_pickle(output_path + 'pmed'+'_d'+str(lapsed_day)+'.pkl')
    print(f"Finished processing pmed data from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush = True)
    return med_final

In [22]:
def convert_rxnorm2atc(med, file_paths):
    """
    Map administered-medication RxNorm codes (MEDADMIN_CODE) to ATC where possible.

    Reads ``file_paths[2] + 'map/med_unified_conversion_rx2atc.parquet'``
    (columns RX, ATC).
      * Mapped codes: MEDADMIN_TYPE 'ATC', MEDADMIN_CODE 'ATC_<atc>'.
      * Others: MEDADMIN_TYPE 'RXN', MEDADMIN_CODE 'RXN_<rxcui>'.

    Side effect: ``med['MEDADMIN_CODE']`` is cast to str in place.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, MEDADMIN_TYPE, MEDADMIN_CODE.
    """
    aux_path = file_paths[2]+'map/'
    lookup = (pd.read_parquet(aux_path+'med_unified_conversion_rx2atc.parquet')
              .rename(columns={'RX': 'MEDADMIN_CODE', 'ATC': 'ATC4th'})
              [['MEDADMIN_CODE', 'ATC4th']]
              .dropna()
              .drop_duplicates())
    med['MEDADMIN_CODE'] = med['MEDADMIN_CODE'].astype('str')
    lookup = lookup.astype({'MEDADMIN_CODE': 'str', 'ATC4th': 'str'})

    coded = med.merge(lookup, on='MEDADMIN_CODE', how='left')
    has_atc = coded['ATC4th'].notna()
    coded['MEDADMIN_TYPE'] = np.where(has_atc, 'ATC', 'RXN')
    coded['MEDADMIN_CODE'] = coded['ATC4th'].where(has_atc, coded['MEDADMIN_CODE'])
    coded['MEDADMIN_CODE'] = coded['MEDADMIN_TYPE'] + '_' + coded['MEDADMIN_CODE']
    return coded[['PATID', 'ENCOUNTERID', 'MEDADMIN_TYPE', 'MEDADMIN_CODE']]

def convert_ndc2atc(med, file_paths):
    """
    Map administered-medication NDC codes (MEDADMIN_CODE) to ATC via RxNorm.

    Chains ``med_unified_conversion_nd2rx.parquet`` (ND -> RX) and
    ``med_unified_conversion_rx2atc.parquet`` (RX -> ATC) from
    ``file_paths[2] + 'map/'``.
      * Mapped codes: MEDADMIN_TYPE 'ATC', MEDADMIN_CODE 'ATC_<atc>'.
      * Others: MEDADMIN_TYPE 'ND', MEDADMIN_CODE 'ND_<ndc>'.

    Side effect: ``med['MEDADMIN_CODE']`` is cast to str in place.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, MEDADMIN_TYPE, MEDADMIN_CODE.
    """
    aux_path = file_paths[2]+'map/'
    ndc2rx = pd.read_parquet(aux_path+'med_unified_conversion_nd2rx.parquet')
    rx2atc = pd.read_parquet(aux_path+'med_unified_conversion_rx2atc.parquet')
    lookup = (ndc2rx.merge(rx2atc, on='RX', how='inner')
              .rename(columns={'ND': 'MEDADMIN_CODE', 'ATC': 'ATC4th'})
              [['MEDADMIN_CODE', 'ATC4th']]
              .dropna()
              .drop_duplicates())
    med['MEDADMIN_CODE'] = med['MEDADMIN_CODE'].astype('str')
    lookup = lookup.astype({'MEDADMIN_CODE': 'str', 'ATC4th': 'str'})

    coded = med.merge(lookup, on='MEDADMIN_CODE', how='left')
    has_atc = coded['ATC4th'].notna()
    coded['MEDADMIN_TYPE'] = np.where(has_atc, 'ATC', 'ND')
    coded['MEDADMIN_CODE'] = coded['ATC4th'].where(has_atc, coded['MEDADMIN_CODE'])
    coded['MEDADMIN_CODE'] = coded['MEDADMIN_TYPE'] + '_' + coded['MEDADMIN_CODE']
    return coded[['PATID', 'ENCOUNTERID', 'MEDADMIN_TYPE', 'MEDADMIN_CODE']]

def process_amed(file_paths, trim = 0.05, event_type = 'PROGNOSIS', lapsed_day = 1):
    """
    Build one-hot administered-medication features; save as ``amed_d{lapsed_day}.pkl``.

    Conversion by MEDADMIN_TYPE:
      * 'RX' records go through ``convert_rxnorm2atc``.
      * 'ND' records go through ``convert_ndc2atc``.
      * Other types are dropped.
    Codes are then filtered by prevalence (`trim`) and pivoted to indicator
    columns prefixed 'RX_'. NOTE(review): this is the same prefix as
    ``process_pmed``; confirm downstream joins disambiguate the two.

    Returns
    -------
    pandas.DataFrame
        One row per (PATID, ENCOUNTERID) with RX_* indicator columns.
    """
    med, onset = load_and_create_timeline(file_paths, 'AKI_AMED', 'MEDADMIN_START_DATE', event_type, lapsed_day)
    code_cols = ['PATID', 'ENCOUNTERID', 'MEDADMIN_CODE']

    converters = {'RX': convert_rxnorm2atc, 'ND': convert_ndc2atc}
    converted_parts = []
    for code_type, convert in converters.items():
        subset = med[med['MEDADMIN_TYPE'] == code_type][code_cols]
        if not subset.empty:
            subset = convert(subset, file_paths)
        converted_parts.append(subset)

    med_atc = pd.concat(converted_parts, axis=0, ignore_index=True)
    frequent = filter_sparse_features(med_atc, onset, 'MEDADMIN_CODE', trim)

    indicators = pd.get_dummies(frequent[['PATID', 'ENCOUNTERID', 'MEDADMIN_CODE']],
                                columns=['MEDADMIN_CODE'],
                                prefix='RX')
    med_final = indicators.groupby(['PATID', 'ENCOUNTERID']).max().reset_index()

    output_path = file_paths[1]
    med_final.to_pickle(output_path + 'amed'+'_d'+str(lapsed_day)+'.pkl')
    print(f"Finished processing amed data from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush = True)
    return med_final

In [23]:
def process_lab_scr(file_paths, event_type, lapsed_day):
    """
    Build serum-creatinine (SCr) features; save as ``scr_d{lapsed_day}.pkl``.

    RESULT_NUM is clipped to its 1st-99th percentiles (``replace_outliers``)
    and averaged per day. Per encounter:
      * SCR_MEAN: the most recent daily mean.
      * SCR_FD: (latest - previous daily mean) / day gap. Present only for
        encounters with at least two measured days; NaN otherwise.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID, SCR_MEAN, SCR_FD.
    """
    keys = ['PATID', 'ENCOUNTERID']
    by_day = keys + ['DAYS_SINCE_EVENT']

    scr, onset = load_and_create_timeline(file_paths, 'AKI_LAB_SCR', 'SPECIMEN_DATE', event_type, lapsed_day)
    scr = replace_outliers(scr, None, 'RESULT_NUM')

    daily = (scr[by_day + ['RESULT_NUM']]
             .groupby(by_day)['RESULT_NUM'].mean()
             .reset_index(name='SCR_MEAN')
             .dropna()
             .sort_values(by_day))

    scr_mean_mrv = daily.groupby(keys).agg({'SCR_MEAN': 'last'}).reset_index()

    # 0 = most recent day of the encounter, 1 = the measured day before it.
    rank_from_end = daily.groupby(keys).cumcount(ascending=False)
    latest = daily[rank_from_end == 0]
    previous = daily[rank_from_end == 1]
    pair = latest.merge(previous, on=keys, how='inner', suffixes=('_last', '_prev'))
    pair['SCR_FD'] = ((pair['SCR_MEAN_last'] - pair['SCR_MEAN_prev'])
                      / (pair['DAYS_SINCE_EVENT_last'] - pair['DAYS_SINCE_EVENT_prev']))

    scr_final = scr_mean_mrv[['PATID', 'ENCOUNTERID', 'SCR_MEAN']].merge(
        pair[['PATID', 'ENCOUNTERID', 'SCR_FD']], on=['PATID', 'ENCOUNTERID'], how='outer')

    scr_final.to_pickle(file_paths[1] + 'scr' + '_d' + str(lapsed_day) + '.pkl')
    print(f"Finished processing lab scr data from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush=True)
    return scr_final

In [24]:
def unify_lab_units(file_paths, event_type = 'PROGNOSIS'):
    """
    Harmonise lab LOINC codes and units; save the result as ``lab_unified.pkl``.

    Lab records come from ``load_and_create_timeline`` ('AKI_LAB', date
    SPECIMEN_DATE, DAYS_SINCE_EVENT < 8). Three mapping passes are applied,
    using CSVs under ``file_paths[2] + 'map/'``:
      1. ``local_custom_convert.csv``: site-specific conversions matched on
         (LAB_LOINC, site). Sets RESULT_UNIT to TARGET_UNIT and multiplies
         RESULT_NUM by 'Multipliyer'.
      2. ``UCUMunitX.csv``: matched on (LAB_LOINC, RESULT_UNIT). Sets
         RESULT_UNIT to FINAL_UNIT, multiplies RESULT_NUM by factor_final
         (missing factor = 1) and regroups LAB_LOINC to GroupId.
      3. ``UCUMqualX.csv``: regroups any LAB_LOINC listed there to its
         GroupId.
    The site name is the last folder of ``file_paths[0]``.

    Returns
    -------
    pandas.DataFrame
        The de-duplicated unified lab records (helper columns removed).
    """
    aux_path = file_paths[2]+'map/'
    output_path = file_paths[1]
    site_name = file_paths[0].split('/')[-2]
    labtest, onset = load_and_create_timeline(file_paths, 'AKI_LAB', 'SPECIMEN_DATE', event_type, 8)
    # Tag records with the site so the site-specific conversion table can be joined.
    labtest['site']= site_name
    
    UCUMunitX = pd.read_csv(aux_path + 'UCUMunitX.csv')
    # A missing factor means "no scaling".
    UCUMunitX.factor_final = UCUMunitX.factor_final.fillna(1)
    UCUMqualX = pd.read_csv(aux_path + 'UCUMqualX.csv')
    local_custom_convert = pd.read_csv(aux_path + 'local_custom_convert.csv')
    # NOTE(review): loincmap3 is read but not used anywhere in this function.
    loincmap3 = pd.read_csv(aux_path + 'GroupLoincTerms.csv')

    # Pass 1: site-specific conversions. NOTE(review): the join is on
    # (LAB_LOINC, site) only, not SOURCE_UNIT, so the multiplier applies to every
    # unit reported for that LOINC at the site. Confirm the table has one unit per pair.
    # 'Multipliyer' must match the CSV header spelling.
    labtest2 = labtest.merge(local_custom_convert, on = ['LAB_LOINC', 'site'], how='left')
    labtest2['NEW_UNIT'] = np.where(labtest2['TARGET_UNIT'].notna(), labtest2['TARGET_UNIT'], labtest2['RESULT_UNIT'])
    labtest2['NEW_RESULT_NUM'] = np.where(labtest2['TARGET_UNIT'].notna(), labtest2['Multipliyer']*labtest2['RESULT_NUM'], labtest2['RESULT_NUM'])

    labtest3 = labtest2.copy()
    labtest3['RESULT_UNIT'] = labtest3['NEW_UNIT']
    labtest3['RESULT_NUM'] = labtest3['NEW_RESULT_NUM']
    # Drop pass-1 helper columns and the columns brought in by local_custom_convert.
    labtest3 = labtest3.drop(['NEW_UNIT', 'NEW_RESULT_NUM', 'SOURCE_UNIT', 'TARGET_UNIT', 'LONG_COMMON_NAME', 'Multipliyer'], axis=1)
    # Pass 2: UCUM unit harmonisation keyed on (LAB_LOINC, RESULT_UNIT).
    labtest4 = labtest3.merge(UCUMunitX, on = ['LAB_LOINC', 'RESULT_UNIT'], how='left').copy()

    # Rows with a UCUM match get the final unit, the scaled value and the grouped LOINC.
    filter1 = labtest4['FINAL_UNIT'].notna() 

    labtest4['NEW_UNIT'] = np.where(filter1, labtest4['FINAL_UNIT'], labtest4['RESULT_UNIT'])
    labtest4['NEW_RESULT_NUM'] = np.where(filter1, labtest4['factor_final']*labtest4['RESULT_NUM'], labtest4['RESULT_NUM'])
    labtest4['NEW_LAB_LOINC'] = np.where(filter1, labtest4['GroupId'], labtest4['LAB_LOINC'])
    labtest4['RESULT_UNIT'] = labtest4['NEW_UNIT']
    labtest4['RESULT_NUM'] = labtest4['NEW_RESULT_NUM']
    labtest4['LAB_LOINC'] = labtest4['NEW_LAB_LOINC']
    # Drop the UCUMunitX columns and pass-2 helpers.
    labtest4 = labtest4.drop(['GroupId', 'EXAMPLE_UCUM_UNITS',
                              'EXAMPLE_UCUM_UNITS_FINAL', 'RESULT_UNIT_CONSENSUS', 'FINAL_UNIT',
                              'FINAL_Multiplyer', 'RESULT_UNIT_API', 'FINAL_UNIT_API', 'factor_final',
                              'NEW_UNIT', 'NEW_RESULT_NUM', 'NEW_LAB_LOINC'], axis=1)
    
    # Pass 3: group LOINC codes listed in UCUMqualX (applies whether or not RESULT_NUM is present).
    labtest5 = labtest4.copy()
    labtest5 = labtest5.merge(UCUMqualX[['LAB_LOINC', 'GroupId']].drop_duplicates(), on='LAB_LOINC', how='left')
    filter2 = labtest5['GroupId'].notna() #& labtest5['RESULT_NUM'].isna()
    labtest5['NEW_LAB_LOINC'] = np.where(filter2, labtest5['GroupId'], labtest5['LAB_LOINC'])
    labtest5['LAB_LOINC'] = labtest5['NEW_LAB_LOINC']
    labtest5 = labtest5.drop(['GroupId','NEW_LAB_LOINC'],axis=1)
    labtest5 = labtest5.drop('site',axis=1)
    labtest5 = labtest5.drop_duplicates()

    labtest5.to_pickle(output_path + 'lab_unified.pkl')
    print(f"Finished unifying lab units from {file_paths[0].split('/')[-2]}.", flush = True)
    return labtest5

def manual_update_lab_units(file_paths, site_name=None):
    """Apply site-specific manual LOINC/unit conversions to the unified lab table.

    Reads ``lab_unified.pkl`` from the processed-data directory and
    ``map/group_conversion_custom.csv`` from the auxiliary directory. Only the
    conversion rules whose ``site`` equals ``site_name`` are used. Every lab row
    whose (LAB_LOINC, RESULT_UNIT) matches a rule's (LAB_LOINC, SOURCE_UNIT) is
    updated as follows:

    - LAB_LOINC is set to the rule's GroupId.
    - RESULT_UNIT is set to the rule's TARGET_UNIT.
    - RESULT_NUM is multiplied by the rule's ``factor``.

    The result is saved as ``lab_unified_updated.pkl``.

    Parameters
    ----------
    file_paths : list of str
        ``[input_path, output_path, aux_path]`` as built by ``get_data_file_path``
        (each ending with '/').
    site_name : str, optional
        Site whose rules apply. Defaults to the site directory name taken from
        ``file_paths[0]``, the same convention the other pipeline steps use.

    Returns
    -------
    pandas.DataFrame
        The updated lab table, including the merged conversion-rule columns.
    """
    aux_path = file_paths[2] + 'map/'
    output_path = file_paths[1]
    if site_name is None:
        # Previously this silently read the notebook-global ``site_name`` (the
        # driver loop variable); derive it from the paths so the step is self-contained.
        site_name = file_paths[0].split('/')[-2]

    lab_test = pd.read_pickle(output_path + 'lab_unified.pkl')
    group_convert_tbl = pd.read_csv(aux_path + 'group_conversion_custom.csv')
    # The CSV stores missing values as the literal text 'np.nan'.
    group_convert_tbl = group_convert_tbl.replace('np.nan', np.nan)
    site_rules = group_convert_tbl[group_convert_tbl['site'] == site_name]

    lab_test2 = lab_test.merge(site_rules,
                               left_on=['LAB_LOINC', 'RESULT_UNIT'],
                               right_on=['LAB_LOINC', 'SOURCE_UNIT'],
                               how='left')
    # Rows that matched a conversion rule; computed once (GroupId is not modified below).
    matched = lab_test2['GroupId'].notna()
    lab_test2.loc[matched, 'LAB_LOINC'] = lab_test2.loc[matched, 'GroupId']
    lab_test2.loc[matched, 'RESULT_UNIT'] = lab_test2.loc[matched, 'TARGET_UNIT']
    lab_test2.loc[matched, 'RESULT_NUM'] = (lab_test2.loc[matched, 'RESULT_NUM']
                                            * lab_test2.loc[matched, 'factor'])
    lab_test2.to_pickle(output_path + 'lab_unified_updated.pkl')
    return lab_test2

In [25]:
def process_numeric_lab(file_paths, trim = 0.05, lapsed_day = 1):
    """Build a wide table of the latest numeric lab values per encounter.

    Rows come from ``lab_unified_updated.pkl`` with ``DAYS_SINCE_EVENT < lapsed_day``
    and both a LOINC code and a numeric result. Outliers are handled by
    ``replace_outliers`` and sparse LOINC codes are removed by
    ``filter_sparse_features`` (threshold ``trim``). Values are then:

    1. averaged per (LOINC, patient, encounter, day);
    2. reduced to the value from the latest day per (patient, encounter, LOINC);
    3. pivoted to one ``LAB_<LOINC>`` column per code.

    The table is saved as ``lab_num_d{lapsed_day}.pkl`` and returned.
    """
    out_dir = file_paths[1]
    day_tag = '_d' + str(lapsed_day)

    onset = pd.read_pickle(out_dir + 'onset.pkl')
    labs = pd.read_pickle(out_dir + 'lab_unified_updated.pkl')

    labs = labs[labs['DAYS_SINCE_EVENT'] < lapsed_day]
    has_code_and_value = labs['LAB_LOINC'].notna() & labs['RESULT_NUM'].notna()
    labs = labs[has_code_and_value]
    labs = replace_outliers(labs, 'LAB_LOINC', 'RESULT_NUM')
    labs = filter_sparse_features(labs, onset, 'LAB_LOINC', trim)

    key_cols = ['PATID', 'ENCOUNTERID']
    # Daily mean per test.
    daily_mean = (labs[['LAB_LOINC', 'PATID', 'ENCOUNTERID', 'RESULT_NUM', 'DAYS_SINCE_EVENT']]
                  .groupby(['LAB_LOINC', 'PATID', 'ENCOUNTERID', 'DAYS_SINCE_EVENT'])
                  .mean()
                  .reset_index())
    # Value from the most recent day within the window.
    latest = (daily_mean
              .sort_values(key_cols + ['LAB_LOINC', 'DAYS_SINCE_EVENT'])
              .groupby(key_cols + ['LAB_LOINC'])
              .agg({'RESULT_NUM': 'last'})
              .reset_index())
    wide = latest.pivot(index=key_cols, columns='LAB_LOINC', values='RESULT_NUM').reset_index()
    wide.columns = [name if name in key_cols else 'LAB_' + name for name in wide.columns]
    df_lab = wide.groupby(key_cols).first().reset_index()

    df_lab.to_pickle(out_dir + 'lab_num' + day_tag + '.pkl')
    print(f"Finished processing numeric lab tests from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush = True)
    return df_lab


def process_categorical_lab(file_paths, trim = 0.05, lapsed_day = 1):
    """One-hot encode qualitative lab results observed before day ``lapsed_day``.

    Processing steps:

    1. Read ``onset.pkl`` and ``lab_unified_updated.pkl`` from the processed-data
       directory.
    2. Keep rows with ``DAYS_SINCE_EVENT < lapsed_day``, a LOINC code and an
       informative RESULT_QUAL.
    3. Drop sparse LOINC codes via ``filter_sparse_features`` (threshold ``trim``).
    4. Take the modal RESULT_QUAL per (patient, encounter, LOINC, day), then the
       value from the latest day.
    5. Pivot to boolean columns named ``LABCAT_<LOINC>(<value>)``. Tied modes are
       joined with '-'.

    Parameters
    ----------
    file_paths : list of str
        ``[input_path, output_path, aux_path]``; only ``output_path`` is read and written.
    trim : float
        Sparsity threshold forwarded to ``filter_sparse_features``.
    lapsed_day : int
        Keep rows with ``DAYS_SINCE_EVENT < lapsed_day``; also part of the output file name.

    Returns
    -------
    pandas.DataFrame
        PATID, ENCOUNTERID and one indicator column per LOINC/value pair, also
        saved to ``lab_qual_d{lapsed_day}.pkl``. When no rows survive filtering,
        an empty frame with only PATID/ENCOUNTERID.
    """
    output_path = file_paths[1]
    onset = pd.read_pickle(output_path + 'onset.pkl')
    lab = pd.read_pickle(output_path +'lab_unified_updated.pkl')
    lab = lab[lab['DAYS_SINCE_EVENT'] < lapsed_day]

    # Exclude placeholder / unusable qualitative results. 'OT' (other) and 'NI'
    # (no information) are matched as whole values: a substring test also
    # discarded genuine results such as 'NOT DETECTED' (contains 'OT') or
    # 'MINIMAL' (contains 'NI'). astype(str) keeps .str usable on non-string columns.
    qual_text = lab['RESULT_QUAL'].astype(str).str.strip()
    is_placeholder = qual_text.str.upper().isin(['OT', 'NI'])
    is_invalid = qual_text.str.contains('INVALID', case=False, regex=False)
    lab = lab[lab['LAB_LOINC'].notna() & lab['RESULT_QUAL'].notna() & ~(is_placeholder | is_invalid)]

    lab = filter_sparse_features(lab, onset, 'LAB_LOINC', trim)

    labcat = lab[['PATID', 'ENCOUNTERID', 'LAB_LOINC', 'DAYS_SINCE_EVENT', 'RESULT_QUAL']]

    if labcat.empty:
        # Persist a key-only table so the later merge step still finds the file.
        labcat[['PATID', 'ENCOUNTERID']].to_pickle(output_path + 'lab_qual' + '_d' + str(lapsed_day) +'.pkl')
        return pd.DataFrame(columns=['PATID', 'ENCOUNTERID'])

    lab_mode = labcat.groupby(['PATID', 'ENCOUNTERID', 'LAB_LOINC', 'DAYS_SINCE_EVENT']).agg(pd.Series.mode).reset_index()
    # A unique mode comes back as a str; ties come back as an array of values.
    # Collapse ties into a single 'A-B' label. ('-'.join avoids np.array2string,
    # which can line-wrap long arrays and needed bracket/quote stripping.)
    is_tie = lab_mode['RESULT_QUAL'].apply(type) != str
    lab_mode_nd = lab_mode.loc[is_tie].copy()
    lab_mode_nnd = lab_mode.loc[~is_tie].copy()
    lab_mode_nd['RESULT_QUAL'] = lab_mode_nd['RESULT_QUAL'].apply(
        lambda modes: '-'.join(str(v) for v in np.atleast_1d(modes)))
    lab_mode = pd.concat([lab_mode_nd, lab_mode_nnd], ignore_index=True)

    # Keep the value from the latest day per (patient, encounter, LOINC).
    labcat_t = lab_mode.sort_values(['PATID', 'ENCOUNTERID', 'LAB_LOINC', 'DAYS_SINCE_EVENT']).groupby(['PATID', 'ENCOUNTERID', 'LAB_LOINC']).agg({'RESULT_QUAL': 'last'}).reset_index()

    labcat_t['LAB_LOINC'] = 'LABCAT_' + labcat_t['LAB_LOINC'] + "(" + labcat_t['RESULT_QUAL'] + ")"
    labcat_t['dummy'] = True
    labcat_t = labcat_t[['PATID', 'ENCOUNTERID', 'LAB_LOINC', 'dummy']]
    labcat_t = labcat_t.pivot_table(index=['PATID', 'ENCOUNTERID'],
                                    columns='LAB_LOINC',
                                    values='dummy',
                                    fill_value=False).reset_index()

    labcat_t.to_pickle(output_path + 'lab_qual' + '_d' + str(lapsed_day) +'.pkl')
    print(f"Finished processing categorical lab tests from {file_paths[0].split('/')[-2]} for day {lapsed_day}.", flush=True)
    return labcat_t

In [26]:
def merge_feature_data(file_paths, lapsed_day):
    """Left-join every per-encounter feature table onto the onset cohort.

    Sources:

    - ``AKI_ONSETS_2.pkl`` (raw input directory ``file_paths[0]``) supplies ``BCCOVID``.
    - Every other table comes from the processed-data directory ``file_paths[1]``.
    - Day-specific tables carry the ``_d{lapsed_day}`` suffix.
    - Sites UTSW, UofU, UPITT and UNMC use the ``pmed_`` medication table;
      all other sites use ``amed_``.

    Qualitative-lab indicator columns that are true for fewer than 5% of onset
    rows are dropped before joining.

    Returns
    -------
    pandas.DataFrame
        The onset cohort (after the BCCOVID join) with demo, dx, px, medication,
        SCr, vital, numeric-lab and qualitative-lab features, joined left on
        PATID/ENCOUNTERID in that order.
    """
    in_dir = file_paths[0]
    out_dir = file_paths[1]
    site_name = in_dir.split('/')[-2]
    join_keys = ['PATID', 'ENCOUNTERID']
    day_suffix = 'd' + str(lapsed_day) + '.pkl'

    cohort = pd.read_pickle(out_dir + 'onset.pkl')
    covid = pd.read_pickle(in_dir + '/AKI_ONSETS_2' + '.pkl')
    cohort = cohort.merge(covid[['ENCOUNTERID', 'PATID', 'BCCOVID']],
                          on=['ENCOUNTERID', 'PATID'], how='left')

    med_prefix = 'pmed_' if site_name in ['UTSW', 'UofU', 'UPITT', 'UNMC'] else 'amed_'
    feature_tables = [
        pd.read_pickle(out_dir + 'demo.pkl'),
        pd.read_pickle(out_dir + 'dx.pkl'),
        pd.read_pickle(out_dir + 'px_' + day_suffix),
        pd.read_pickle(out_dir + med_prefix + day_suffix),
        pd.read_pickle(out_dir + 'scr_' + day_suffix),
        pd.read_pickle(out_dir + 'vital_' + day_suffix),
        pd.read_pickle(out_dir + 'lab_num_' + day_suffix),
    ]
    labcat = pd.read_pickle(out_dir + 'lab_qual_' + day_suffix)

    if not labcat.empty:
        # Share of onset rows flagged by each indicator column.
        prevalence = labcat.drop(join_keys, axis=1).sum() / cohort.shape[0]
        rare_cols = list(prevalence[prevalence < 0.05].index)
        labcat = labcat.drop(rare_cols, axis=1)
    feature_tables.append(labcat)

    merged = cohort
    for table in feature_tables:
        merged = merged.merge(table, on=join_keys, how='left')
    return merged

def process_feature_data(file_paths, impute = False, lapsed_day = 1):
    """Merge, type-clean and persist the per-encounter feature matrix for one day window.

    Processing steps:

    1. Merge all feature tables via ``merge_feature_data``.
    2. Turn the literal string 'NaN' into a real missing value.
    3. Indicator families (RX_, PX_, DX_, LABCAT_, PREADM_CKD_FLAG, SMOKING_):
       fill missing values with False and cast to bool.
    4. Numeric families (LAB_, SCR_, SYSTOLIC_, DIASTOLIC_, BMI, WT, HT):
       optionally mean-impute, then cast to float.

    Columns whose name contains a GroupId listed under 'Mass-Molar conversion'
    in ``map/GroupLoincTerms.csv`` are dropped from the saved
    ``data_d{lapsed_day}.pkl``.

    Parameters
    ----------
    file_paths : list of str
        ``[input_path, output_path, aux_path]``.
    impute : bool
        If True, fill missing numeric features with the column mean.
    lapsed_day : int
        Feature window; selects the day-specific inputs and names the output file.

    Returns
    -------
    pandas.DataFrame
        The cleaned frame. NOTE: unlike the saved pickle, it still contains the
        mass-molar group columns.
    """
    data = merge_feature_data(file_paths, lapsed_day)
    data = data.replace('NaN', np.nan)

    bool_prefixes = ('RX_', 'PX_', 'DX_', 'LABCAT_', 'PREADM_CKD_FLAG', 'SMOKING_')
    bool_col = [col for col in data.columns if col.startswith(bool_prefixes)]
    data[bool_col] = data[bool_col].fillna(False)
    for col in bool_col:
        data[col] = data[col].astype(bool)

    fl_prefixes = ('LAB_', 'SCR_', 'SYSTOLIC_', 'DIASTOLIC_', 'BMI', 'WT', 'HT')
    fl_col = [col for col in data.columns if col.startswith(fl_prefixes)]

    if impute:
        for col in fl_col:
            # Assign back rather than ``data[col].fillna(..., inplace=True)``: that is
            # chained assignment and does not update ``data`` under Copy-on-Write.
            data[col] = data[col].fillna(data[col].mean())

    for col in fl_col:
        data[col] = data[col].astype(float)

    site_name = file_paths[0].split('/')[-2]
    # Site-specific x10 rescaling of LAB_5902-2 (presumably a unit difference at
    # this site — confirm). Guarded: the column may have been trimmed as sparse.
    if site_name == 'UTHSCSA' and 'LAB_5902-2' in data.columns:
        data['LAB_5902-2'] = data['LAB_5902-2'] * 10

    loincmap3 = pd.read_csv(os.path.join(file_paths[2], 'map/GroupLoincTerms.csv'))
    mmc_ids = (loincmap3.loc[loincmap3['Category'] == 'Mass-Molar conversion', 'GroupId']
               .dropna()
               .astype(str)
               .tolist())
    contains_group_id = [col for col in data.columns if any(gid in col for gid in mmc_ids)]

    data.drop(columns=contains_group_id).to_pickle(file_paths[1] + 'data_'+ 'd'+str(lapsed_day) + '.pkl')
    print(f"Finish processing feature data from {site_name} for day {lapsed_day}.", flush = True)
    return data

In [None]:
base_path = './'
site_list = ['KUMC', 'UMHC', 'UNMC', 'UTHSCSA']
event_type = 'PROGNOSIS'
trim = 0.05
impute = False

# logging's FileHandler does not create missing directories, so make sure log/ exists.
log_dir = os.path.join(base_path, 'log')
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(filename=os.path.join(log_dir, 'processing_errors.log'), level=logging.ERROR)

for site_name in site_list:
    try:
        file_paths = get_data_file_path(base_path, site_name)
        # Site-level steps (independent of the feature window).
        process_onset_and_recovery(file_paths, aggfunc_7d = 'last', aggfunc_1y = 'mean', keep_ckd = False)
        process_demo(file_paths)
        process_dx(file_paths, trim, event_type)
        unify_lab_units(file_paths, event_type)
        manual_update_lab_units(file_paths)

        # Window-specific steps for lapsed days 1..7.
        for lapsed_day in range(1, 8):
            process_vital(file_paths, event_type, lapsed_day)
            process_px(file_paths, trim, event_type, lapsed_day)
            process_amed(file_paths, trim, event_type, lapsed_day)
            process_pmed(file_paths, trim, event_type, lapsed_day)
            process_lab_scr(file_paths, event_type, lapsed_day)
            process_numeric_lab(file_paths, trim, lapsed_day)
            process_categorical_lab(file_paths, trim, lapsed_day)
            process_feature_data(file_paths, impute, lapsed_day)
    except Exception as e:
        # Best-effort per site: record the full traceback, then continue with the next site.
        logging.exception(f"Error processing site {site_name}: {e}")
        print(f"Error processing site {site_name}. Check log for details.")

#### Step 4: Prepare Multi-state Transition Data

In [None]:
def generate_ft_list(aki_subgrp):
    """Return the ordered feature names used for one AKI subgroup.

    ``'aki1'`` selects the stage-1 feature set; any other value (e.g. ``'aki23'``)
    selects the stage-2/3 feature set. Each feature is listed with its display
    label for reference, but only the names are returned, in the order below.
    """
    # SCr trajectory and blood-pressure features shared by both subgroups.
    common = [
        ('SCR_BASELINE', 'Baseline SCr (mg/dL)'),
        ('SCR_FD', 'Most recent SCr (slope, mg/(dL*day))'),
        ('SCR_MEAN', 'Most recent SCr (level, mg/dL)'),
        ('SYSTOLIC_FD', 'Systolic pressure (slope, mmHg/day)'),
        ('SYSTOLIC_MEAN', 'Systolic pressure (level, mmHg)'),
        ('DIASTOLIC_FD', 'Diastolic pressure (slope, mmHg/day)'),
        ('DIASTOLIC_MEAN', 'Diastolic pressure (level, mmHg)'),
    ]
    stage1_only = [
        ('BMI', 'BMI'),
        ('DX_09_401', 'Essential Hypertension'),
        ('ONSET_SINCE_ADMIT', 'Days since admission'),
        ('LAB_LG13614-9', 'Anion Gap (mEq/L)'),
        ('LAB_LG5465-2', 'Albumin (g/dL)'),
        ('LAB_LG5665-7', 'ALP (IU/L)'),
        ('LAB_LG6033-7', 'AST (IU/L)'),
        ('LAB_LG6199-6', 'Bilirubin (mg/dL)'),
        ('LAB_LG12080-4', 'BNP (pg/mL)'),
        ('LAB_LG7247-2', 'Calcium (mg/dL)'),
        ('LAB_LG4454-7', 'CO2 (mEq/L)'),
        ('LAB_LG7967-5', 'Glucose (mg/dL)'),
        ('LAB_LG6039-4', 'Lactate (mmol/L)'),
        ('LAB_736-9', 'Lymphocytes/(100*Leukocytes)'),
        ('LAB_LG10990-6', 'Potassium (mEq/L)'),
        ('LAB_5902-2', 'PT (s)'),
    ]
    stage23_only = [
        ('AGE', 'Age'),
        ('ONSET_SINCE_ADMIT', 'Days since admission'),
        ('LAB_LG5465-2', 'Albumin (g/dL)'),
        ('LAB_LG5665-7', 'ALP (IU/L)'),
        ('LAB_LG6033-7', 'AST (IU/L)'),
        ('LAB_LG6199-6', 'Bilirubin (mg/dL)'),
        ('LAB_LG12080-4', 'BNP (pg/mL)'),
        ('LAB_LG7247-2', 'Calcium (mg/dL)'),
        ('LAB_LG4454-7', 'CO2 (mEq/L)'),
        ('LAB_4544-3', 'Hematocrit (%)'),
        ('LAB_LG10990-6', 'Potassium (mEq/L)'),
        ('LAB_LG32850-6', 'RBC (10^6 cells/µL)'),
    ]
    chosen = stage1_only if aki_subgrp == 'aki1' else stage23_only
    return [name for name, _label in common + chosen]

In [None]:
def impute_MICE(base_path, site_name, ft_list):
    """Impute the selected features with scikit-learn's IterativeImputer for days 1-7.

    For each day window t in 1..7:

    1. Read ``data_d{t}.pkl`` from the site's processed-data directory.
    2. Impute the ``ft_list`` columns present in it (``max_iter=50``,
       ``random_state=42``).
    3. Write ``data_mice_d{t}.pkl`` with the imputed features (in ``ft_list``
       order), ``ENCOUNTERID``, and an all-NaN column for each requested feature
       the day file lacks.

    Columns that exist but are entirely missing cannot be imputed. IterativeImputer
    drops them, which previously broke the DataFrame construction with a
    column-count mismatch. They are now carried through as NaN.

    NOTE(review): IterativeImputer is experimental in scikit-learn and requires
    ``from sklearn.experimental import enable_iterative_imputer`` before
    ``from sklearn.impute import IterativeImputer`` — confirm the import cell does this.
    """
    file_paths = get_data_file_path(base_path, site_name)

    for t in range(1, 8):
        df = pd.read_pickle(file_paths[1] + 'data_'+ 'd'+str(t) + '.pkl')

        # Select only the relevant columns for imputation
        cols_to_impute = [col for col in ft_list if col in df.columns]
        # Only columns with at least one observed value can be fitted.
        observed_cols = [col for col in cols_to_impute if df[col].notna().any()]

        if observed_cols:
            imputer = IterativeImputer(max_iter=50,
                                       random_state=42)
            df_imputed = pd.DataFrame(imputer.fit_transform(df[observed_cols]), columns=observed_cols)
        else:
            df_imputed = pd.DataFrame(index=pd.RangeIndex(len(df)))
        # Restore the requested column order; all-missing columns come back as NaN.
        df_imputed = df_imputed.reindex(columns=cols_to_impute)

        df_imputed['ENCOUNTERID'] = df['ENCOUNTERID'].values

        # Ensure all variables in ft_list are present, filling absent ones with NaN
        for var in ft_list:
            if var not in df.columns:
                df_imputed[var] = np.nan

        df_imputed.to_pickle(file_paths[1] + 'data_mice_'+ 'd'+str(t) + '.pkl')

In [None]:
def get_transition_data(base_path, site_name, imp_ft_lst, aki_subgrp): 
    """Build multi-state transition data for one site and AKI subgroup.

    State codes (column ``AKI_STATUS_CHANGE`` / FROM_STATUS / TO_STATUS):
    0 = discharged, 1 = no AKI, 2 = AKI stage 1, 3 = AKI stage 2 or 3, 4 = death.
    Daily SCr-based stages are tracked for days 0..7 after AKI onset.
    Consecutive observations in the same state are collapsed into one interval
    running from ``tstart`` to ``tstop``.

    Parameters
    ----------
    base_path : str
        Root directory passed to ``get_data_file_path``.
    site_name : str
        Site sub-directory; also written to the ``site`` column.
    imp_ft_lst : list of str
        Feature columns read from ``data_mice_d{t}.pkl`` and attached to each row.
    aki_subgrp : str
        'aki1' keeps AKI_INIT_STG == 1; 'aki23' keeps AKI_INIT_STG > 1; any
        other value keeps all encounters.

    Returns
    -------
    pandas.DataFrame
        One row per (interval, possible transition out of FROM_STATUS), with:

        - ENCOUNTERID (as str), FROM_STATUS, TO_STATUS, TRANS, tstart, tstop;
        - EVENT_STATUS: True for the observed transition and False for the others;
          0 for censored intervals, whose tstop is set to 7;
        - the ``imp_ft_lst`` features, taken from the day-(tstart+1) MICE file
          (the day-7 file for tstart == 7);
        - ``site``.

        Only rows with tstart in 0..7 are kept.
    """
    file_paths = get_data_file_path(base_path, site_name)
    onset0 = pd.read_pickle(file_paths[1]+'onset.pkl')
    demo = pd.read_pickle(file_paths[0] + 'AKI_DEMO'+'.pkl') 
    demo_deduplicated = demo[['PATID', 'ENCOUNTERID', 'DEATH_DATE']].drop_duplicates()

    # One death date per patient: the latest recorded, or NaT if none.
    demo_cleaned = (demo_deduplicated
                    .groupby(['PATID'], as_index=False)
                    .agg({'DEATH_DATE': lambda x: x.max() if x.notna().any() else pd.NaT}))


    onset = onset0.merge(demo_cleaned[['PATID',  'DEATH_DATE']], 
                         on = 'PATID', 
                         how = 'left')

    # Days from onset to death. No death, or a death dated before onset, becomes
    # the sentinel 1000000 so it never falls inside the 7-day window.
    onset['DEATH_SINCE_ONSET'] = (onset['DEATH_DATE'] - onset['ONSET_DATE']).dt.days
    onset['DEATH_SINCE_ONSET'] = onset['DEATH_SINCE_ONSET'].fillna(1000000)
    onset['DEATH_SINCE_ONSET'] = onset['DEATH_SINCE_ONSET'].astype(int)
    onset.loc[onset['DEATH_SINCE_ONSET'] < 0, 'DEATH_SINCE_ONSET'] = 1000000
    
    if aki_subgrp == 'aki1':
        onset = onset[onset['AKI_INIT_STG'] == 1]
    elif aki_subgrp == 'aki23':
        onset = onset[onset['AKI_INIT_STG'] >1]
    
    df_scr = load_and_filter_scr(onset, file_paths)
    df_scr0, df_admit = load_onset_data(file_paths)
    df_rrt = get_rrt(df_admit, file_paths)
    # SCr observations from the onset day onward, with the RRT start day attached.
    df_scr_filtered = df_scr[df_scr['DAYS_SINCE_ONSET'] >= 0]
    df_scr_filtered = df_scr_filtered.merge(df_rrt[['PATID', 'ENCOUNTERID', 'RRT_ONSET_DATE']], 
                                            on =['PATID', 'ENCOUNTERID'], 
                                            how = 'left')
    df_scr_filtered['RRT_SINCE_ONSET'] = (df_scr_filtered['RRT_ONSET_DATE'] - df_scr_filtered['ONSET_DATE']).dt.days

    # Per-observation KDIGO-style stage; default 1. The assignments below run in
    # order, so stage 3 overrides stage 2, which overrides no-AKI (0).
    # Stage 3 also covers RRT started on or before the observation day.
    df_scr_filtered['AKI_STG_TMP'] = 1
    mask_noaki = (df_scr_filtered['RESULT_NUM'] < 1.5 * df_scr_filtered['SCR_BASELINE']) & (df_scr_filtered['RESULT_NUM'] < (0.3 + df_scr_filtered['SCR_REFERENCE']))
    mask_aki2 = (df_scr_filtered['RESULT_NUM'] >= 2.0 * df_scr_filtered['SCR_BASELINE']) & (df_scr_filtered['RESULT_NUM'] < 3.0 * df_scr_filtered['SCR_BASELINE'])
    mask_aki3 = (
                 (df_scr_filtered['RESULT_NUM']>=3.0 * df_scr_filtered['SCR_BASELINE']) | (df_scr_filtered['RESULT_NUM']>=4)  # scr criteria
                 | ((df_scr_filtered['RRT_SINCE_ONSET'] <= df_scr_filtered['DAYS_SINCE_ONSET']) & (df_scr_filtered['RRT_SINCE_ONSET'] >= 0)) # rrt criteria
                )
    df_scr_filtered.loc[mask_noaki,'AKI_STG_TMP'] = 0
    df_scr_filtered.loc[mask_aki2,'AKI_STG_TMP'] = 2
    df_scr_filtered.loc[mask_aki3,'AKI_STG_TMP'] = 3
    
    # Map stage to state code: 1 = no AKI, 2 = stage 1, 3 = stage 2/3.
    df_scr_filtered['AKI_STATUS_CHANGE'] = np.nan
    df_scr_filtered.loc[df_scr_filtered['AKI_STG_TMP'] == 0,'AKI_STATUS_CHANGE'] = 1
    df_scr_filtered.loc[df_scr_filtered['AKI_STG_TMP'] == 1,'AKI_STATUS_CHANGE'] = 2
    df_scr_filtered.loc[df_scr_filtered['AKI_STG_TMP'] > 1,'AKI_STATUS_CHANGE'] = 3

    df_status = df_scr_filtered[(df_scr_filtered['DAYS_SINCE_ONSET'] <=7)][['PATID','ENCOUNTERID','AKI_INIT_STG', 'AKI_STG_TMP', 'AKI_STATUS_CHANGE', 'DAYS_SINCE_ONSET']]

    # Encounters not discharged before onset. The outer merge also keeps encounters
    # without any SCr row in the window (their DAYS_SINCE_ONSET is NaN).
    onset_filtered = onset[(onset['DISCHARGE_SINCE_ONSET'] >= 0)]
    
    df_msm = df_status[df_status['ENCOUNTERID'].isin(onset_filtered['ENCOUNTERID'])].merge(
                             onset_filtered[['PATID','ENCOUNTERID','DISCHARGE_SINCE_ONSET', 'DEATH_SINCE_ONSET']], 
                             on = ['PATID', 'ENCOUNTERID'], 
                             how = 'outer')

    # Drop SCr observations recorded after discharge or after death.
    mask_death_discharge = (
                            (df_msm['DISCHARGE_SINCE_ONSET'] < df_msm['DAYS_SINCE_ONSET']) | 
                            (df_msm['DEATH_SINCE_ONSET'] < df_msm['DAYS_SINCE_ONSET'])
                           ) 
    
    df_msm = df_msm[~mask_death_discharge]

    # Observation on the discharge day (or no SCr but discharged within 7 days) -> state 0.
    mask_discharge_same = (
                        (df_msm['DISCHARGE_SINCE_ONSET'] == df_msm['DAYS_SINCE_ONSET']) | 
                        (df_msm['DAYS_SINCE_ONSET'].isna() & (df_msm['DISCHARGE_SINCE_ONSET'] <= 7))
                      ) 
    df_msm.loc[mask_discharge_same, 'AKI_STATUS_CHANGE'] = 0
    
    # Same for death -> state 4; applied after discharge, so death wins when both match.
    mask_death_same = (
                        (df_msm['DEATH_SINCE_ONSET'] == df_msm['DAYS_SINCE_ONSET']) | 
                        (df_msm['DAYS_SINCE_ONSET'].isna() & (df_msm['DEATH_SINCE_ONSET'] <= 7))
                      ) 
    df_msm.loc[mask_death_same, 'AKI_STATUS_CHANGE'] = 4    

    # Any row still without a state defaults to 2 (AKI stage 1).
    df_msm['AKI_STATUS_CHANGE'] = df_msm['AKI_STATUS_CHANGE'].fillna(2)
    
    # Encounters discharged within 7 days (strictly before death) that have no
    # discharge/death row yet get a discharge row appended on the discharge day.
    discharge_list = onset[(onset['DISCHARGE_SINCE_ONSET'] <=7) & 
                           (onset['DISCHARGE_SINCE_ONSET'] < onset['DEATH_SINCE_ONSET'])].ENCOUNTERID.unique()
    
    mask_discharge  =     (
                            (~df_msm.ENCOUNTERID.isin(df_msm[df_msm['AKI_STATUS_CHANGE'].isin([0, 4])].ENCOUNTERID.unique())) & 
                            (df_msm.ENCOUNTERID.isin(discharge_list)) 
                           )
    
    df_disc =  df_msm[mask_discharge].groupby('ENCOUNTERID').DISCHARGE_SINCE_ONSET.last().reset_index()
    df_disc['AKI_STATUS_CHANGE'] = 0
    df_disc['DAYS_SINCE_ONSET'] = df_disc['DISCHARGE_SINCE_ONSET'].copy()

    # Likewise append a death row (death within 7 days, on or before discharge) where none exists.
    death_list = onset[(onset['DEATH_SINCE_ONSET'] <=7) & 
                       (onset['DEATH_SINCE_ONSET'] <= onset['DISCHARGE_SINCE_ONSET'])].ENCOUNTERID.unique()
    
    mask_death   =      (
                         (~df_msm.ENCOUNTERID.isin(df_msm[df_msm['AKI_STATUS_CHANGE'] == 4].ENCOUNTERID.unique())) & 
                         (df_msm.ENCOUNTERID.isin(death_list)) 
                        )
    
    df_death =  df_msm[mask_death].groupby('ENCOUNTERID').DEATH_SINCE_ONSET.last().reset_index()
    df_death['AKI_STATUS_CHANGE'] = 4
    df_death['DAYS_SINCE_ONSET'] = df_death['DEATH_SINCE_ONSET'].copy()
    
    df_final = pd.concat([df_msm[['ENCOUNTERID', 'DAYS_SINCE_ONSET', 'AKI_STATUS_CHANGE']], 
                          df_disc[['ENCOUNTERID', 'DAYS_SINCE_ONSET', 'AKI_STATUS_CHANGE']],
                          df_death[['ENCOUNTERID', 'DAYS_SINCE_ONSET', 'AKI_STATUS_CHANGE']],
                         ], axis=0)

   
    # NOTE(review): astype('int') raises if any DAYS_SINCE_ONSET is still NaN. That
    # happens for an encounter with no SCr row in the window whose discharge/death
    # row was set in place above (its day stays NaN) or that had neither event.
    # This assumes every encounter has an SCr row at day >= 0 — confirm.
    df_final['DAYS_SINCE_ONSET'] = df_final['DAYS_SINCE_ONSET'].astype('int')
    df_final['AKI_STATUS_CHANGE'] = df_final['AKI_STATUS_CHANGE'].astype('int')

    # Allowed transitions. States 0 (discharge) and 4 (death) never appear as FROM_STATUS.
    transitions = pd.DataFrame({
        'TRANS':      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        'FROM_STATUS': [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
        'TO_STATUS':   [0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4]
    })

    # NOTE(review): df0 is a column subset of df_final; the in-place sort and the
    # column assignments below may emit SettingWithCopyWarning — consider .copy().
    df0 = df_final[['ENCOUNTERID', 'DAYS_SINCE_ONSET', 'AKI_STATUS_CHANGE']]
    df0.sort_values(['ENCOUNTERID', 'DAYS_SINCE_ONSET'], inplace = True)

    # Each observation opens an interval that ends at the encounter's next observation.
    df0['tstart'] = df0.groupby('ENCOUNTERID')['DAYS_SINCE_ONSET'].shift(0)
    df0['tstop'] = df0.groupby('ENCOUNTERID')['DAYS_SINCE_ONSET'].shift(-1)
    df0['FROM_STATUS'] = df0['AKI_STATUS_CHANGE']
    df0['TO_STATUS'] = df0.groupby('ENCOUNTERID')['AKI_STATUS_CHANGE'].shift(-1)

    # Drop intervals starting on day 7, and trailing discharge/death rows (no next observation).
    df1 = df0[~(df0.tstart == 7)]
    df2 = df1[~(((df1.FROM_STATUS == 0) & (df1.tstop.isna())) | ((df1.FROM_STATUS == 4) & (df1.tstop.isna())))]

    # 1000 is a temporary "no next observation" sentinel, turned back into NaN below.
    df2['tstop'] = df2['tstop'].fillna(1000)
    df2['TO_STATUS'] = df2['TO_STATUS'].fillna(1000)

    # Collapse runs of consecutive rows (within an encounter) that share FROM_STATUS.
    df2['transition_start'] = (
        (df2['FROM_STATUS'] != df2['FROM_STATUS'].shift()) | 
        (df2['ENCOUNTERID'] != df2['ENCOUNTERID'].shift())
    )

    df2['transition_group'] = df2['transition_start'].cumsum()

    df3 = df2.groupby(['ENCOUNTERID', 'transition_group']).agg(
            tstart=('tstart', 'min'),                 
            tstop=('tstop', 'max'),                   
            FROM_STATUS=('FROM_STATUS', 'first'),     
            TO_STATUS=('TO_STATUS', 'last')           
        ).reset_index()

    # An interval whose next state equals its own state has no observed transition:
    # mark it censored (tstop/TO_STATUS NaN), as are those with the 1000 sentinel.
    df3.sort_values(['ENCOUNTERID', 'tstart'], inplace=True)
    df3.loc[(df3['FROM_STATUS'] == df3['TO_STATUS']), 'tstop'] = np.nan
    df3.loc[(df3['FROM_STATUS'] == df3['TO_STATUS']), 'TO_STATUS'] = np.nan
    df3['TO_STATUS'] = df3['TO_STATUS'].replace(1000, np.nan)
    df3['tstop'] = df3['tstop'].replace(1000, np.nan)
    df_trans = df3[['ENCOUNTERID', 'FROM_STATUS', 'TO_STATUS', 'tstart', 'tstop']]

    # Censored intervals: one row per possible transition out of FROM_STATUS,
    # EVENT_STATUS 0, censored at day 7.
    df_censored = df_trans[df_trans.tstop.isna()].drop('TO_STATUS', axis = 1)
    expanded_censored = df_censored[['ENCOUNTERID']].drop_duplicates().merge(transitions, how='cross')
    df_censored_expanded = df_censored.merge(expanded_censored, on=['ENCOUNTERID', 'FROM_STATUS'], how='right')
    df_censored_expanded = df_censored_expanded.dropna(subset=['tstart'])

    df_censored_expanded['EVENT_STATUS'] = 0
    df_censored_expanded['tstop'] = 7

    # Observed transitions: one row per possible transition out of FROM_STATUS;
    # EVENT_STATUS is True only for the transition that actually happened.
    # NOTE(review): EVENT_STATUS is bool here but int 0 for censored rows; the
    # concatenated column therefore has mixed types.
    df_unc = df_trans[~df_trans.tstop.isna()].drop_duplicates()
    df_unc['tstop'] = df_unc['tstop'].astype(int)
    df_unc['TO_STATUS']= df_unc['TO_STATUS'].astype(int)
    df_unc['EVENT_STATUS'] = False
    expanded_df = df_unc[['ENCOUNTERID']].drop_duplicates().merge(transitions, how='cross')

    expanded_df2 = expanded_df.merge(df_unc, 
                                    on=['ENCOUNTERID', 'FROM_STATUS'], 
                                    how='left', 
                                    suffixes=('', '_orig'))
    expanded_df2['EVENT_STATUS'] = expanded_df2['EVENT_STATUS'].fillna(False)
    expanded_df2.loc[expanded_df2['TO_STATUS_orig'] == expanded_df2['TO_STATUS'], 'EVENT_STATUS'] = True
    df_unc_expanded = expanded_df2.drop(columns=['TO_STATUS_orig']).dropna(subset = 'tstart').sort_values(['ENCOUNTERID', 'tstart', 'FROM_STATUS'])
    
    cols_to_keep = ['ENCOUNTERID', 'FROM_STATUS', 'TO_STATUS', 'TRANS', 'tstart', 'tstop', 'EVENT_STATUS']
    data_surv =pd.concat([df_unc_expanded[cols_to_keep], df_censored_expanded[cols_to_keep]]).sort_values(by = ['ENCOUNTERID', 'tstart'])    
    var_lst = ['ENCOUNTERID'] + imp_ft_lst 

    df_all = pd.DataFrame()

    # Attach features: intervals starting on day t use the day-(t+1) MICE file;
    # t == 7 reuses the day-7 file. Missing feature columns are added as NaN.
    for t in range(0, 8):
        if t < 7: 
            features = pd.read_pickle(file_paths[1] + 'data_mice_'+ 'd'+str(t+1) + '.pkl')
        else: 
            features = pd.read_pickle(file_paths[1] + 'data_mice_'+ 'd'+str(7) + '.pkl')
            
        for var in imp_ft_lst:
            if var not in features.columns:
                features[var] = np.nan

        features = features[var_lst]
        df_day = data_surv[data_surv['tstart'] == t].merge(features, 
                                                          on = 'ENCOUNTERID', 
                                                          how = 'left')
        df_all = pd.concat([df_all, df_day], axis = 0)

    df_all['site'] = site_name
    df_all['ENCOUNTERID'] = df_all['ENCOUNTERID'].astype(str)
    
    return df_all

In [None]:
# Runtime: Prepare data for multi-state transition analysis
aki_subgrps = ['aki1', 'aki23']
ft_by_subgrp = {grp: generate_ft_list(grp) for grp in aki_subgrps}
ft_aki1 = ft_by_subgrp['aki1']
ft_aki23 = ft_by_subgrp['aki23']
# MICE is run once per site on the union of both subgroups' features.
ft_all = sorted(set(ft_aki1).union(ft_aki23))

for site_name in site_list:
    impute_MICE(base_path, site_name, ft_all)

for aki_subgrp in aki_subgrps:
    ft_lst = ft_by_subgrp[aki_subgrp]
    site_frames = []
    for site_name in site_list:
        data_site = get_transition_data(base_path, site_name, ft_lst, aki_subgrp)
        site_frames.append(data_site)
    # Concatenate once instead of growing the frame inside the loop.
    data_all = pd.concat(site_frames, axis=0) if site_frames else pd.DataFrame()
    data_all.to_parquet(base_path + 'data_msm_'+ aki_subgrp + '.parquet')

#### Prepare data for the Sankey Plot

In [None]:
def get_sankey_data(base_path, site_name, aki_subgrp):
    """Build per-day state columns State_d0 .. State_d7 for a Sankey plot.

    For each day t in 0..7 after AKI onset, the state comes from the latest SCr
    row on or before day t:

    - 0 = no AKI, 1 = AKI stage 1, 2 = stage 2;
    - 3 = stage 3 (SCr >= 3x baseline, SCr >= 4, or RRT started on/before that row's day);
    - 4 = discharged by day t; 5 = died by day t (death overrides discharge).

    Encounters with no SCr row by day t get state 1.

    Parameters
    ----------
    base_path : str
        Root directory passed to ``get_data_file_path``.
    site_name : str
        Site sub-directory; also written to the ``site`` column.
    aki_subgrp : str
        'aki1' keeps AKI_INIT_STG == 1; 'aki23' keeps AKI_INIT_STG > 1; any other
        value (e.g. 'all') keeps all encounters.

    Returns
    -------
    pandas.DataFrame
        One row per onset encounter with PATID, ENCOUNTERID,
        DISCHARGE_SINCE_ONSET, DEATH_SINCE_ONSET, State_d0..State_d7 and ``site``.

    NOTE(review): the onset/death/SCr/RRT preparation duplicates the start of
    get_transition_data; consider a shared helper.
    """
    filepath_lst = get_data_file_path(base_path, site_name)
    onset0 = pd.read_pickle(filepath_lst[1]+'onset.pkl')
    demo = pd.read_pickle(filepath_lst[0] + 'AKI_DEMO'+'.pkl')

    # One death date per patient: the latest recorded, or NaT if none.
    demo_deduplicated = demo[['PATID', 'ENCOUNTERID', 'DEATH_DATE']].drop_duplicates()
    demo_cleaned = (demo_deduplicated
                    .groupby(['PATID'], as_index=False)
                    .agg({'DEATH_DATE': lambda x: x.max() if x.notna().any() else pd.NaT}))
    onset = onset0.merge(demo_cleaned[['PATID',  'DEATH_DATE']],
                         on = 'PATID',
                         how = 'left')

    # Days from onset to death. No death, or a death before onset, becomes the
    # sentinel 1000000 (never within the 7-day window).
    onset['DEATH_SINCE_ONSET'] = (onset['DEATH_DATE'] - onset['ONSET_DATE']).dt.days
    onset['DEATH_SINCE_ONSET'] = onset['DEATH_SINCE_ONSET'].fillna(1000000)
    onset['DEATH_SINCE_ONSET'] = onset['DEATH_SINCE_ONSET'].astype(int)
    onset.loc[onset['DEATH_SINCE_ONSET'] < 0, 'DEATH_SINCE_ONSET'] = 1000000

    if aki_subgrp == 'aki1':
        onset = onset[onset['AKI_INIT_STG'] == 1]
    elif aki_subgrp == 'aki23':
        onset = onset[onset['AKI_INIT_STG'] >1]

    df_scr = load_and_filter_scr(onset, filepath_lst)
    df_scr0, df_admit = load_onset_data(filepath_lst)
    df_rrt = get_rrt(df_admit, filepath_lst)
    # SCr observations from the onset day onward, with the RRT start day attached.
    df_scr_filtered = df_scr[df_scr['DAYS_SINCE_ONSET'] >= 0]
    df_scr_filtered = df_scr_filtered.merge(df_rrt[['PATID', 'ENCOUNTERID', 'RRT_ONSET_DATE']],
                                            on =['PATID', 'ENCOUNTERID'],
                                            how = 'left')
    df_scr_filtered['RRT_SINCE_ONSET'] = (df_scr_filtered['RRT_ONSET_DATE'] - df_scr_filtered['ONSET_DATE']).dt.days

    df_sankey = onset[['PATID','ENCOUNTERID','DISCHARGE_SINCE_ONSET', 'DEATH_SINCE_ONSET']].copy()

    for t in range(0, 8):
        # Latest SCr row on or before day t per encounter.
        # NOTE(review): groupby().last() takes the last non-null value per column,
        # so a NaN on the latest row is filled from an earlier row; use .tail(1) /
        # .nth(-1) if the literal last row is intended — confirm.
        df_t = df_scr_filtered[df_scr_filtered['DAYS_SINCE_ONSET'] <= t].sort_values(by = ['PATID', 'ENCOUNTERID', 'DAYS_SINCE_ONSET']).groupby(['PATID', 'ENCOUNTERID']).last().reset_index()

        # generate daily aki stage indicator (KDIGO definition)
        # Default 1; later assignments take precedence (3 over 2 over 0).
        status_label= 'State_d' + str(t)
        df_t[status_label] = 1
        mask_noaki = (df_t['RESULT_NUM'] < 1.5 * df_t['SCR_BASELINE']) & (df_t['RESULT_NUM'] < (0.3 + df_t['SCR_REFERENCE']))
        mask_aki2 = (df_t['RESULT_NUM'] >= 2.0 * df_t['SCR_BASELINE']) & (df_t['RESULT_NUM'] < 3.0 * df_t['SCR_BASELINE'])
        mask_aki3 = (
                (df_t['RESULT_NUM']>=3.0 * df_t['SCR_BASELINE']) | (df_t['RESULT_NUM']>=4)  # scr criteria
                | ((df_t['RRT_SINCE_ONSET'] <= df_t['DAYS_SINCE_ONSET']) & (df_t['RRT_SINCE_ONSET'] >= 0)) # rrt criteria
        )
        df_t.loc[mask_noaki,status_label] = 0
        df_t.loc[mask_aki2,status_label] = 2
        df_t.loc[mask_aki3,status_label] = 3

        # Discharge (4) and then death (5) by day t override the SCr stage;
        # encounters without any SCr row yet default to 1.
        df_sankey = df_sankey.merge(df_t[['PATID','ENCOUNTERID', status_label]], on = ['PATID','ENCOUNTERID'], how = 'left')
        df_sankey.loc[df_sankey['DISCHARGE_SINCE_ONSET'] <= t, status_label] = 4
        df_sankey.loc[df_sankey['DEATH_SINCE_ONSET'] <= t, status_label] = 5
        df_sankey[status_label] = df_sankey[status_label].fillna(1)

    df_sankey['site'] = site_name
    return df_sankey

In [None]:
# Build the Sankey data frame across all sites
sankey_frames = []
for site_name in site_list:
    data_site = get_sankey_data(base_path, site_name, 'all')
    sankey_frames.append(data_site)
# Concatenate once instead of growing the frame inside the loop.
data_all = pd.concat(sankey_frames, axis=0) if sankey_frames else pd.DataFrame()
data_all.to_pickle(base_path + 'data_sankey' + '.pkl')