In [1]:
import os
import time
import pickle
import logging
import warnings
import catboost as cb
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, train_test_split, RandomizedSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
warnings.filterwarnings("ignore")

#### Utility functions: Load & prepare data

In [2]:
def load_discrete_survival_data_for_site(base_path, site_name, pred_pt, pred_task, fs_type, aki_subgrp):
    """Build the stacked person-day table for discrete-time survival modelling at one site.

    For every day ``t`` in the prediction window (after ``pred_pt``, up to day 7 since
    onset) the encounters still at risk on day ``t`` are selected from ``outcome.pkl``,
    labelled with whether the event has occurred by day ``t``, tagged with one-hot
    ``POD_<t>`` indicator columns and left-joined to the feature file ``data_d<k>.pkl``.
    The per-day frames are then aligned to a common column set and stacked.

    Parameters
    ----------
    base_path : str
        Directory containing one sub-directory per site (with or without trailing separator).
    site_name : str
        Name of the site sub-directory. It must contain ``AKI_DEMO.pkl`` and a
        ``processed_data`` folder with ``outcome.pkl`` and the ``data_d<k>.pkl`` files.
    pred_pt : int
        Day since onset at which the prediction is made; must be smaller than 7.
    pred_task : str
        ``'rvsl'`` produces the label column ``AKI_RVRT``; ``'stgup'`` produces ``AKI_STGUP``.
    fs_type : str
        ``'rm_scr_bun'`` removes the SCr/BUN-related columns listed below; any other
        value keeps every column.
    aki_subgrp : str
        ``'aki1'`` keeps rows with ``AKI_INIT_STG == 1``; ``'aki23'`` keeps
        ``AKI_INIT_STG > 1`` for ``'rvsl'`` and ``AKI_INIT_STG == 2`` for ``'stgup'``;
        any other value applies no filter.

    Returns
    -------
    pandas.DataFrame
        One row per (encounter, day) with columns sorted alphabetically.

    Raises
    ------
    ValueError
        If ``pred_pt`` is not smaller than 7, if the prediction window is empty, or if
        ``pred_task`` is unknown.
    """
    pred_end = 7
    dat_lst = []

    # Build every path with os.path.join so base_path works with or without a trailing separator.
    site_dir = os.path.join(base_path, site_name)
    filepath = os.path.join(site_dir, 'processed_data')
    outcome0 = pd.read_pickle(os.path.join(filepath, 'outcome.pkl'))
    demo = pd.read_pickle(os.path.join(site_dir, 'AKI_DEMO.pkl'))

    # One death date per patient: the latest non-missing date, NaT when none is recorded.
    demo_deduplicated = demo[['PATID', 'ENCOUNTERID', 'DEATH_DATE']].drop_duplicates()
    demo_cleaned = (demo_deduplicated
                    .groupby(['PATID'], as_index=False)
                    .agg({'DEATH_DATE': lambda x: x.max() if x.notna().any() else pd.NaT}))

    outcome = outcome0.merge(demo_cleaned[['PATID', 'DEATH_DATE']],
                             on='PATID',
                             how='left')

    # Days from onset to death; 1000000 is the sentinel for "no death recorded" and is
    # also used when the recorded death precedes onset.
    outcome['DEATH_SINCE_ONSET'] = (outcome['DEATH_DATE'] - outcome['ONSET_DATE']).dt.days
    outcome['DEATH_SINCE_ONSET'] = outcome['DEATH_SINCE_ONSET'].fillna(1000000)
    outcome['DEATH_SINCE_ONSET'] = outcome['DEATH_SINCE_ONSET'].astype(int)
    outcome.loc[outcome['DEATH_SINCE_ONSET'] < 0, 'DEATH_SINCE_ONSET'] = 1000000
    # Censoring day: whichever of discharge or death comes first.
    outcome['DEATH_DISCHARGE_SINCE_ONSET'] = outcome[['DISCHARGE_SINCE_ONSET', 'DEATH_SINCE_ONSET']].min(axis=1)

    outcome['AKI2_SINCE_ONSET'] = outcome['AKI2_SINCE_ADMIT'] - outcome['ONSET_SINCE_ADMIT']
    outcome['AKI3_SINCE_ONSET'] = outcome['AKI3_SINCE_ADMIT'] - outcome['ONSET_SINCE_ADMIT']

    if pred_pt >= pred_end:
        raise ValueError(f"Prediction point pred_pt should be less than {pred_end}.")
    # The 'rvsl' window starts one day later than the window of the other task.
    first_day = pred_pt + 2 if pred_task == 'rvsl' else pred_pt + 1
    tw = list(range(first_day, pred_end + 1))
    if not tw:
        raise ValueError(
            f"Empty prediction window for pred_pt={pred_pt}, pred_task={pred_task}.")

    # Columns removed from every feature file before the merge (loop-invariant).
    cols = ['ADMIT_DATE', 'DISCHARGE_DATE', 'ONSET_DATE', 'AKI1_SINCE_ADMIT', 'AKI2_SINCE_ADMIT',
            'AKI3_SINCE_ADMIT',
            'DISCHARGE_SINCE_ONSET', 'SCR_ONSET', 'SCR_REFERENCE', 'AKI1_7D', 'AKI1_2D', 'FLAG', 'AKI_STAGE',
            'SCR_RANGE',
            'SYSTOLIC_RANGE', 'WT', 'PREADM_CKD_STAGE', 'DIASTOLIC_RANGE', 'LAB_LG50024-5', 'LAB_LG6657-3']

    for t in tw:
        still_in_hospital = outcome['DEATH_DISCHARGE_SINCE_ONSET'] >= t

        if pred_task == 'rvsl':
            # At risk on day t: no reversal before day t (rows with missing reversal day are kept).
            no_event_yet = ~((outcome['RVRT_SINCE_ONSET'] + 1) < t)
            # .copy() so the column assignments below act on an independent frame.
            tmp_out = outcome.loc[no_event_yet & still_in_hospital,
                                  ['PATID', 'ENCOUNTERID', 'RVRT_SINCE_ONSET']].copy()
            tmp_out['AKI_RVRT'] = ((tmp_out['RVRT_SINCE_ONSET'] + 1) <= t)
            tmp_out = tmp_out.drop('RVRT_SINCE_ONSET', axis=1)
        elif pred_task == 'stgup':
            mask_stg1up = ((outcome['AKI_INIT_STG'] == 1) &
                           ((outcome['AKI2_SINCE_ONSET'] >= t) | np.isnan(outcome['AKI2_SINCE_ONSET'])) &
                           ((outcome['AKI3_SINCE_ONSET'] >= t) | np.isnan(outcome['AKI3_SINCE_ONSET'])))

            mask_stg2up = ((outcome['AKI_INIT_STG'] == 2) &
                           ((outcome['AKI3_SINCE_ONSET'] >= t) | np.isnan(outcome['AKI3_SINCE_ONSET'])))

            tmp_out = outcome.loc[(mask_stg1up | mask_stg2up) & still_in_hospital,
                                  ['PATID', 'ENCOUNTERID',
                                   'AKI_INIT_STG',
                                   'AKI2_SINCE_ONSET',
                                   'AKI3_SINCE_ONSET']].copy()
            tmp_out['AKI_STGUP'] = (((tmp_out['AKI_INIT_STG'] == 1) & (
                        (tmp_out['AKI2_SINCE_ONSET'] <= t) | (tmp_out['AKI3_SINCE_ONSET'] <= t))) |
                                    ((tmp_out['AKI_INIT_STG'] == 2) & (tmp_out['AKI3_SINCE_ONSET'] <= t)))

            tmp_out = tmp_out.drop(['AKI_INIT_STG', 'AKI2_SINCE_ONSET', 'AKI3_SINCE_ONSET'], axis=1)
        else:
            raise ValueError(f"Unknown pred_task: {pred_task}")

        # One-hot day indicators: only POD_<t> is True for rows of day t.
        for s in tw:
            tmp_out['POD_' + str(s)] = (s == t)

        tmp_out['ID_PAT_ENC'] = tmp_out['PATID'].astype(str) + '_' + tmp_out['ENCOUNTERID'].astype(str)
        tmp_out['ID_POD'] = tmp_out['ID_PAT_ENC'] + '_' + str(t)

        # load features for day t
        t_dat = t - pred_pt if pred_task != 'rvsl' else (t - 1 - pred_pt)
        data = pd.read_pickle(os.path.join(filepath, 'data_' + 'd' + str(t_dat) + '.pkl'))
        data = data.drop([var for var in data.columns if var in cols], axis=1)

        # merge outcome and features
        dat_t = tmp_out.merge(data, on=['PATID', 'ENCOUNTERID'], how='left')

        # NOTE(review): AKI_INIT_STG is not carried over from the outcome slice, so these
        # filters assume the feature file provides it -- confirm.
        if aki_subgrp == 'aki1':
            dat_t = dat_t[dat_t['AKI_INIT_STG'] == 1].drop('AKI_INIT_STG', axis=1)
        elif aki_subgrp == 'aki23' and pred_task == 'rvsl':
            dat_t = dat_t[dat_t['AKI_INIT_STG'] > 1].drop('AKI_INIT_STG', axis=1)
        elif aki_subgrp == 'aki23' and pred_task == 'stgup':
            dat_t = dat_t[dat_t['AKI_INIT_STG'] == 2].drop('AKI_INIT_STG', axis=1)
        dat_lst.append(dat_t)

    # Target dtype per column: the last frame takes precedence, earlier frames only
    # contribute dtypes for columns the last frame does not have.
    column_dtypes = {col: dat_lst[-1][col].dtype for col in dat_lst[-1].columns}
    # Union over ALL frames, including the last one, so that columns which only exist
    # in the last day's frame are not silently dropped by the reindex below.
    all_columns = set(dat_lst[-1].columns)
    for df in dat_lst[:-1]:
        all_columns.update(df.columns)
        for col in df.columns:
            if col not in column_dtypes:
                column_dtypes[col] = df[col].dtype

    all_columns = sorted(all_columns)

    concatenated_df = pd.concat([df.reindex(columns=all_columns) for df in dat_lst],
                                axis=0, ignore_index=True)

    for col in concatenated_df.columns:
        target = column_dtypes[col]
        kind = getattr(target, 'kind', 'O')
        missing = concatenated_df[col].isna()

        if kind == 'b':
            # Casting NaN to bool yields True, so missing flags are reset to False
            # after the cast; a later fillna(False) would find nothing to fill.
            casted = concatenated_df[col].astype(target)
            casted[missing] = False
            concatenated_df[col] = casted
        elif kind in 'iu' and missing.any():
            # Integer dtypes cannot hold NaN; keep the column as produced by concat
            # instead of failing on the cast.
            continue
        else:
            concatenated_df[col] = concatenated_df[col].astype(target)

    if fs_type == 'rm_scr_bun':
        scr_bun_labs = ['2160-0', '38483-4', '14682-9', '21232-4', '35203-9', '44784-7', '59826-8',
                        '16188-5', '16189-3', '59826-8', '35591-7', '50380-5', '50381-3', '35592-5',
                        '44784-7', '11041-1', '51620-3', '72271-0', '11042-9', '51619-5', '35203-9', '14682-9',
                        '12966-8', '12965-0', '6299-2', '59570-2', '12964-3', '49071-4', '72270-2',
                        '11065-0', '3094-0', '35234-4', '14937-7',
                        '48642-3', '48643-1',  # eGFR
                        '3097-3', '44734-2',  # scr bun ratio
                        '12967-6', '13506-1', '20624-3', '2890-2', '33914-3', '34366-5', '62238-1', '88293-6',
                        '88294-4',  # additional keys from athena
                        'LG12083-8', 'LG1314-6', 'LG34710-0', 'LG34791-0', 'LG34808-2', 'LG35227-4', 'LG35814-9',
                        'LG49763-2',  # loinc group id
                        'LG49764-0', 'LG49776-4', 'LG50019-5', 'LG50024-5', 'LG50025-2', 'LG50986-5', 'LG6657-3',
                        'LG7133-4']

        scr_bun_cols = ['LAB_' + var for var in scr_bun_labs] + ['SCR_BASELINE', 'SCR_MEAN', 'SCR_FD', 'SCR_REFERENCE',
                                                                 'SCR_RANGE', 'SCR_ONSET']

        rm_cols = [var for var in concatenated_df.columns if var in scr_bun_cols]
        concatenated_df = concatenated_df.drop(rm_cols, axis=1)

    return concatenated_df

In [None]:
def catRandomSearch(data, outcome_label, validation_type, test_size,  n_folds = 10, n_iters = 50, rnd_split_seed = 42):
    """Tune a CatBoost classifier by randomized search and score the held-out test set.

    The data are split at encounter level (``ID_PAT_ENC``) so that all rows of one
    encounter fall into the same partition.

    Parameters
    ----------
    data : pandas.DataFrame
        Person-day table containing ``ID_POD``, ``ID_PAT_ENC``, ``PATID``, ``ENCOUNTERID``,
        the outcome column and, for ``validation_type='tmpr'``, ``BCCOVID``.
    outcome_label : str
        Name of the binary outcome column.
    validation_type : str
        ``'rnd'``: random train/validation/test split of all encounters.
        ``'tmpr'``: rows with ``BCCOVID == 1`` are split into train/validation and rows
        with ``BCCOVID == 0`` form the test set.
    test_size : float
        Fraction passed to ``train_test_split`` for the test split and again for the
        validation split.
    n_folds : int
        Number of encounter-level folds handed to the randomized search.
    n_iters : int
        Number of parameter settings sampled.
    rnd_split_seed : int
        Seed of the first encounter-level split.

    Returns
    -------
    dict
        ``best_model`` (fitted classifier), ``data_test``, ``data_val`` and ``y_pred``
        (positive-class probabilities for ``data_test``).

    Raises
    ------
    ValueError
        If ``validation_type`` is unknown or the training labels contain a single class.
    """
    # a global seed for reproducibility
    np.random.seed(1234)

    if validation_type == 'rnd':
        enc_tmp = data['ID_PAT_ENC'].unique()
        enc_tr_tmp, enc_ts = train_test_split(enc_tmp, test_size=test_size, random_state=rnd_split_seed)
        enc_tr, enc_val = train_test_split(enc_tr_tmp, test_size=test_size, random_state=123)
        data_train = data[data['ID_PAT_ENC'].isin(enc_tr)].reset_index(drop=True)
        data_val = data[data['ID_PAT_ENC'].isin(enc_val)].reset_index(drop=True)
        data_test = data[data['ID_PAT_ENC'].isin(enc_ts)].reset_index(drop=True)
    elif validation_type == 'tmpr':
        data_before_covid = data[data['BCCOVID'] == 1]
        data_after_covid = data[data['BCCOVID'] == 0]
        enc_before = data_before_covid['ID_PAT_ENC'].unique()
        enc_tr, enc_val = train_test_split(enc_before, test_size=test_size, random_state=rnd_split_seed)
        data_train = data_before_covid[data_before_covid['ID_PAT_ENC'].isin(enc_tr)].reset_index(drop=True)
        data_val = data_before_covid[data_before_covid['ID_PAT_ENC'].isin(enc_val)].reset_index(drop=True)
        data_test = data_after_covid.reset_index(drop=True)
    else:
        raise ValueError(f"Unknown validation type: {validation_type}")

    # Identifier and label columns are not model features.
    # NOTE(review): unlike logisticRandomSearch, 'BCCOVID' stays in the feature matrix
    # here -- confirm that this difference is intended.
    non_feature_cols = ['ID_POD', 'ID_PAT_ENC', 'PATID', 'ENCOUNTERID', outcome_label]
    x_train = data_train.drop(non_feature_cols, axis=1)
    y_train = data_train[outcome_label]
    x_val = data_val.drop(non_feature_cols, axis=1)
    y_val = data_val[outcome_label]
    x_test = data_test.drop(non_feature_cols, axis=1)

    # Encounter-level folds: split the encounter ids, then map each id set back to the
    # row positions of data_train (its index was reset above, so labels == positions).
    np.random.shuffle(enc_tr)
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)

    folds = []
    for train_idx, val_idx in kf.split(enc_tr):
        train_enc_ids = enc_tr[train_idx]
        val_enc_ids = enc_tr[val_idx]

        train_indices = data_train.index[data_train['ID_PAT_ENC'].isin(train_enc_ids)].tolist()
        val_indices = data_train.index[data_train['ID_PAT_ENC'].isin(val_enc_ids)].tolist()

        folds.append([train_indices, val_indices])

    # Class weight = negatives / positives, counted by label. Indexing value_counts()
    # with 0/1 can be resolved by position, i.e. majority/minority, which inverts the
    # weight whenever the positive class is the larger one.
    is_positive = y_train.astype(bool)
    n_pos = int(is_positive.sum())
    n_neg = int(len(is_positive)) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError(
            f"Training labels for '{outcome_label}' must contain both classes "
            f"(positives={n_pos}, negatives={n_neg}).")
    scale_pos_weight = n_neg / n_pos

    cat_features = list(x_train.select_dtypes(include=['bool']).columns)

    # Settings shared by the search model and the final model.
    shared_kwargs = dict(
        scale_pos_weight=scale_pos_weight,
        objective='Logloss',
        eval_metric='AUC:hints=skip_train~false',
        early_stopping_rounds=500,
        cat_features=cat_features,
        random_seed=42
    )

    cvmodel = cb.CatBoostClassifier(verbose=False, **shared_kwargs)

    params = {
        'subsample': [0.5, 0.8, 1.0],
        'colsample_bylevel': [0.5, 0.8, 1.0],
        'max_depth': [6, 8, 12],
        'learning_rate': [0.01, 0.05, 0.1],
        'n_estimators': [100, 200, 400],
        'l2_leaf_reg': [3, 5],
        'random_strength': [1, 3, 5],
        'bagging_temperature': [0.1, 0.5, 1.0]
    }

    # NOTE(review): CatBoost's randomized_search has a `search_by_train_test_split`
    # option that, per its documentation, defaults to True; if so the candidates are
    # compared on an internal train/test split and `cv` is not what drives the search.
    # Confirm whether search_by_train_test_split=False is wanted here.
    randomized_search_result = cvmodel.randomized_search(
        params,
        X=x_train,
        y=y_train,
        cv=folds,
        n_iter=n_iters
    )

    bestmodel = cb.CatBoostClassifier(
        verbose=250, # type: ignore[arg-type]
        **shared_kwargs,
        **randomized_search_result['params']
    )

    bestmodel.fit(x_train, y_train,
                  cat_features=cat_features,
                  eval_set=[(x_train, y_train), (x_val, y_val)],
                  early_stopping_rounds=500
                  )

    y_pred = bestmodel.predict_proba(x_test)[:, 1]

    return {
        'best_model': bestmodel,
        'data_test': data_test,
        'data_val': data_val,
        'y_pred': y_pred
    }

In [None]:
def logisticRandomSearch(data, outcome_label, validation_type,  test_size, n_folds = 10, n_iters = 50, rnd_split_seed = 42):
    """Tune an L1-penalised logistic regression by randomized search and score the test set.

    The data are split at encounter level (``ID_PAT_ENC``). Features are mean-imputed
    and standardised with statistics fitted on the training partition; the
    regularisation strength ``C`` is selected by encounter-grouped cross-validation on ROC AUC.

    Parameters
    ----------
    data : pandas.DataFrame
        Person-day table containing ``ID_POD``, ``ID_PAT_ENC``, ``PATID``, ``ENCOUNTERID``,
        the outcome column and, for ``validation_type='tmpr'``, ``BCCOVID``.
    outcome_label : str
        Name of the binary outcome column.
    validation_type : str
        ``'rnd'``: random train/validation/test split of all encounters.
        ``'tmpr'``: rows with ``BCCOVID == 1`` are split into train/validation and rows
        with ``BCCOVID == 0`` form the test set.
    test_size : float
        Fraction passed to ``train_test_split`` for the test split and again for the
        validation split.
    n_folds : int
        Number of encounter-level cross-validation folds.
    n_iters : int
        Number of parameter settings requested from the randomized search.
    rnd_split_seed : int
        Seed of the first encounter-level split and of the randomized search.

    Returns
    -------
    dict
        ``best_model`` (refitted estimator), ``data_test_raw`` (untransformed test rows),
        ``data_test`` (imputed and scaled test features plus the outcome column),
        ``data_val`` and ``y_pred`` (positive-class probabilities for the test rows).

    Raises
    ------
    ValueError
        If ``validation_type`` is unknown.
    """
    # a global seed for reproducibility
    np.random.seed(1234)

    if validation_type == 'rnd':
        enc_tmp = data['ID_PAT_ENC'].unique()
        enc_tr_tmp, enc_ts = train_test_split(enc_tmp, test_size=test_size, random_state=rnd_split_seed)
        enc_tr, enc_val = train_test_split(enc_tr_tmp, test_size=test_size, random_state=123)
        data_train = data[data['ID_PAT_ENC'].isin(enc_tr)].reset_index(drop=True)
        data_val = data[data['ID_PAT_ENC'].isin(enc_val)].reset_index(drop=True)
        data_test = data[data['ID_PAT_ENC'].isin(enc_ts)].reset_index(drop=True)
    elif validation_type == 'tmpr':
        data_before_covid = data[data['BCCOVID'] == 1]
        data_after_covid = data[data['BCCOVID'] == 0]
        enc_before = data_before_covid['ID_PAT_ENC'].unique()
        enc_tr, enc_val = train_test_split(enc_before, test_size=test_size, random_state=rnd_split_seed)
        data_train = data_before_covid[data_before_covid['ID_PAT_ENC'].isin(enc_tr)].reset_index(drop=True)
        data_val = data_before_covid[data_before_covid['ID_PAT_ENC'].isin(enc_val)].reset_index(drop=True)
        data_test = data_after_covid.reset_index(drop=True)
    else:
        raise ValueError(f"Unknown validation type: {validation_type}")

    data_test_raw = data_test.copy()

    # Identifier, split-indicator and label columns are not model features. 'BCCOVID'
    # is only required by the 'tmpr' split, so drop it only when it is present.
    non_feature_cols = ['ID_POD', 'ID_PAT_ENC', 'PATID', 'ENCOUNTERID', outcome_label]
    if 'BCCOVID' in data_train.columns:
        non_feature_cols.append('BCCOVID')

    x_train = data_train.drop(non_feature_cols, axis=1)
    y_train = data_train[outcome_label]

    # Keep only features with at least one observed training value. Mean imputation
    # has nothing to fit on an all-missing column, and (per the scikit-learn docs, to
    # be confirmed for the installed version) SimpleImputer discards such columns,
    # which would break the column re-labelling of the transformed arrays below.
    feature_cols = x_train.columns[x_train.notna().any(axis=0)]
    x_train = x_train[feature_cols]

    x_test = data_test.drop(non_feature_cols, axis=1)[feature_cols]
    y_test = data_test[outcome_label]

    preprocessor = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler())
    ])

    x_train_processed = pd.DataFrame(preprocessor.fit_transform(x_train), columns=x_train.columns)

    model = LogisticRegression(penalty='l1',
                               solver='liblinear',
                               max_iter=1000)

    param_distributions = {
        'C': np.logspace(-3, 1, 20)
    }

    # Encounter-level folds: split the encounter ids, then map each id set back to the
    # row positions of data_train (its index was reset above, so labels == positions
    # and they also match the rows of x_train_processed).
    np.random.shuffle(enc_tr)
    kf = KFold(n_splits=n_folds, shuffle=True)

    folds = []
    for train_idx, val_idx in kf.split(enc_tr):
        train_enc_ids = enc_tr[train_idx]
        val_enc_ids = enc_tr[val_idx]

        train_indices = data_train.index[data_train['ID_PAT_ENC'].isin(train_enc_ids)].tolist()
        val_indices = data_train.index[data_train['ID_PAT_ENC'].isin(val_enc_ids)].tolist()

        folds.append((train_indices, val_indices))

    randomized_search = RandomizedSearchCV(
        model,
        param_distributions,
        n_iter=n_iters,
        # Use the encounter-grouped folds. Passing the KFold object itself would split
        # individual rows and let rows of one encounter land in both train and validation.
        cv=folds,
        random_state=rnd_split_seed,
        scoring='roc_auc',
        verbose=1,
        n_jobs=-1
    )
    randomized_search.fit(x_train_processed, y_train)

    bestmodel = randomized_search.best_estimator_

    x_test_processed = pd.DataFrame(preprocessor.transform(x_test), columns=x_test.columns)
    y_pred = bestmodel.predict_proba(x_test_processed)[:, 1]

    # The outcome is appended only after prediction, so it is never seen as a feature.
    data_test_imputed_scaled = x_test_processed
    data_test_imputed_scaled[outcome_label] = y_test.reset_index(drop=True)

    return {
        'best_model': bestmodel,
        'data_test_raw': data_test_raw,
        'data_test': data_test_imputed_scaled,
        'data_val': data_val,
        'y_pred': y_pred
    }

#### Runtime: Train predictive models

In [None]:
# Run configuration: where the site folders live and where fitted results are written.
base_path = "./"
model_path = os.path.join(base_path, "model")
os.makedirs(model_path, exist_ok=True)

site_labels = ['Site1', 'Site2', 'Site3', 'Site4']
# Maps a site label to its folder name; labels without an entry are used as-is.
site_mapping = {
    "... MASKED_FOR_ANONYMITY": "... MASKED_FOR_ANONYMITY",
    # ...
}

# Experiment grid: every combination of the lists below is fitted for every site.
test_size = 0.15
prediction_tasks = ["rvsl", "stgup"]
model_types = ["cb", "rlr"]
validation_types = ["rnd", "tmpr"]
aki_subgroups = ["aki1", "aki23"]
fs_types = ["no_fs", "rm_scr_bun"]

for model_type in model_types:
    for val_type in validation_types:
        for aki_subgrp in aki_subgroups:
            for fs_type in fs_types:
                # Results of the current configuration, keyed by task and then by site.
                result_dict = {}

                for pred_task in prediction_tasks:
                    if pred_task == "rvsl":
                        outcome_label = "AKI_RVRT"
                    elif pred_task == "stgup":
                        outcome_label = "AKI_STGUP"
                    else:
                        raise ValueError(f"Unknown pred_task: {pred_task}")

                    result_dict[pred_task] = {}

                    for site_label in site_labels:
                        # A failure at one site is logged and the loop moves on to the next site.
                        try:
                            start_time = time.time()

                            # Get model data
                            # NOTE(review): the arguments do not depend on model_type or
                            # val_type, so the same table is re-read for each of them.
                            model_data = load_discrete_survival_data_for_site(
                                base_path=base_path,
                                site_name=site_mapping.get(site_label, site_label),
                                pred_pt=0,
                                pred_task=pred_task,
                                fs_type=fs_type,
                                aki_subgrp=aki_subgrp
                            )

                            # Model fitting
                            if model_type == "cb":
                                site_result = catRandomSearch(
                                    model_data,
                                    outcome_label,
                                    val_type,
                                    test_size
                                )
                            elif model_type == "rlr":
                                site_result = logisticRandomSearch(
                                    model_data,
                                    outcome_label,
                                    val_type,
                                    test_size
                                )
                            else:
                                raise ValueError(f"Unsupported model_type: {model_type}")

                            elapsed_time = time.time() - start_time
                            site_result["elapsed_time"] = elapsed_time

                            result_dict[pred_task][site_label] = site_result

                            # Create output filename
                            if val_type == "rnd":
                                file_name = f"{model_type}_{aki_subgrp}_{pred_task}_1d_{fs_type}.pkl"
                            elif val_type == "tmpr": 
                                file_name = f"{model_type}_tmpr_{aki_subgrp}_{pred_task}_1d_{fs_type}.pkl"
                            else:
                                raise ValueError(f"Unknown validation type: {val_type}")

                            # The file is rewritten after every site, so it always holds
                            # the sites completed so far for this configuration.
                            file_path = os.path.join(model_path, file_name)
                            with open(file_path, "wb") as file:
                                pickle.dump(result_dict[pred_task], file)

                            print(
                                f"Completed: {model_type}, {val_type}, {aki_subgrp}, {pred_task}, {fs_type} for site: {site_label}")

                        except Exception as e:
                            # Include the AKI subgroup so failures of different
                            # configurations can be told apart, and log with
                            # logging.exception so the traceback is kept.
                            err_msg = (f"Error running model={model_type}, val={val_type}, "
                                       f"subgrp={aki_subgrp}, fs={fs_type}, "
                                       f"site={site_label}, task={pred_task}: {str(e)}")
                            logging.exception(err_msg)
                            print(err_msg)