# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')
sns.set_palette('muted')

from datetime import datetime

from tqdm import tqdm

from sklearn.base import clone

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor

In [2]:
DATA_PATH = '../data/equity-post-HCT-survival-predictions/'
RANDOM_STATE = 54321

## Data

In [3]:
sample_df = pd.read_csv(DATA_PATH + 'sample_submission.csv')
test_df = pd.read_csv(DATA_PATH + 'test.csv')

In [4]:
X = pd.read_pickle(DATA_PATH + 'train_test_split/X_25-12-2024.pkl')
y = pd.read_pickle(DATA_PATH + 'train_test_split/y_25-12-2024.pkl')
efs_time = pd.read_pickle(DATA_PATH + 'train_test_split/efs_time_25-12-2024.pkl')
race_group = pd.read_pickle(DATA_PATH + 'train_test_split/race_group_25-12-2024.pkl')

In [5]:
print(X.shape, y.shape, efs_time.shape, race_group.shape)

(28800, 81) (28800,) (28800,) (28800,)


# Modeling

## Metric

In [6]:
import pandas as pd
import pandas.api.types
import numpy as np
from lifelines.utils import concordance_index

class ParticipantVisibleError(Exception):
    pass


def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """
    >>> import pandas as pd
    >>> row_id_column_name = "id"
    >>> y_pred = {'prediction': {0: 1.0, 1: 0.0, 2: 1.0}}
    >>> y_pred = pd.DataFrame(y_pred)
    >>> y_pred.insert(0, row_id_column_name, range(len(y_pred)))
    >>> y_true = { 'efs': {0: 1.0, 1: 0.0, 2: 0.0}, 'efs_time': {0: 25.1234,1: 250.1234,2: 2500.1234}, 'race_group': {0: 'race_group_1', 1: 'race_group_1', 2: 'race_group_1'}}
    >>> y_true = pd.DataFrame(y_true)
    >>> y_true.insert(0, row_id_column_name, range(len(y_true)))
    >>> score(y_true.copy(), y_pred.copy(), row_id_column_name)
    0.75
    """
    
    del solution[row_id_column_name]
    del submission[row_id_column_name]
    
    event_label = 'efs'
    interval_label = 'efs_time'
    prediction_label = 'prediction'
    for col in submission.columns:
        if not pandas.api.types.is_numeric_dtype(submission[col]):
            raise ParticipantVisibleError(f'Submission column {col} must be a number')
    # Merging solution and submission dfs on ID
    merged_df = pd.concat([solution, submission], axis=1)
    merged_df.reset_index(inplace=True)
    merged_df_race_dict = dict(merged_df.groupby(['race_group']).groups)
    metric_list = []
    for race in merged_df_race_dict.keys():
        # Retrieving values from y_test based on index
        indices = sorted(merged_df_race_dict[race])
        merged_df_race = merged_df.iloc[indices]
        # Calculate the concordance index
        c_index_race = concordance_index(
                        merged_df_race[interval_label],
                        -merged_df_race[prediction_label],
                        merged_df_race[event_label])
        metric_list.append(c_index_race)
    return float(np.mean(metric_list)-np.sqrt(np.var(metric_list)))

In [7]:
def score_wrapper(y_true, y_pred, efs_time, race_group):
    y_pred_dict = {
        'prediction': {i: pred for i, pred in enumerate(y_pred)}
    }        
    y_pred_dict = pd.DataFrame(y_pred_dict)
    y_pred_dict.insert(0, 'id', range(len(y_pred_dict)))
    
    y_true_dict = {
        'efs': {i: y for i, y in enumerate(y_true.values)},
        'efs_time': {i: t for i, t in enumerate(efs_time.values)},
        'race_group': {i: r for i, r in enumerate(race_group.values)},
    }
    y_true_dict = pd.DataFrame(y_true_dict)
    y_true_dict.insert(0, 'id', range(len(y_true_dict)))
    
    return score(y_true_dict.copy(), y_pred_dict.copy(), 'id')

def cross_validate(model, X, y, cv=10, scale=False):
    cv_scores = []
    
    for i in tqdm(range(cv)):
        test_idxs = list(range(int((len(X)*(i)/cv)), int((len(X)*(i+1)/cv))))
        
        X_train = X.drop(index=test_idxs)
        y_train = y.drop(index=test_idxs)
        
        X_test = X.iloc[test_idxs]
        y_test = y.iloc[test_idxs]
        
        if scale:
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        
        model_copy = clone(model)
        
        model_copy.fit(X_train, y_train)
        y_pred = model_copy.predict(X_test)
        
        cv_scores.append(score_wrapper(
            y_test, 
            y_pred, 
            efs_time.iloc[test_idxs], 
            race_group.iloc[test_idxs]
        ))
    
    return np.mean(cv_scores)

## Scikit learn model

Best CVs:
- numeric: 0.5901
- numeric + dri_score: 0.6023
- numeric + dri_score + psych_disturb: 0.6033
- numeric + dri_score + psych_disturb + cyto_score: 0.6067
- numeric + dri_score + psych_disturb + cyto_score + diabetes: 0.6069
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status: 0.6075
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia: 0.6074
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type: 0.6104
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist: 0.6109
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue: 0.6104
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe: 0.6122
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct: 0.6151
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status: 0.6162
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match: 0.6172
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab: 0.6167
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type: 0.6177
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail: 0.6194
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity: 0.6240
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity: 0.6230
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity: 0.6229
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct: 0.6237
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd: 0.6225
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match: 0.6223
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe: 0.6228
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor: 0.6238
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer: 0.6237
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph: 0.6226
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue: 0.6222
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match: 0.6266
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild: 0.6269
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild + tce_div_match: 0.6271
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild + tce_div_match + donor_related: 0.6275
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild + tce_div_match + donor_related + melphalan_dose: 0.6269
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild + tce_div_match + donor_related + melphalan_dose + cardiac: 0.6287
- numeric + dri_score + psych_disturb + cyto_score + diabetes + tbi_status + arrhythmia + graft_type + vent_hist + renal_issue + pulm_severe + prim_disease_hct + cmv_status + tce_imm_match + rituximab + prod_type + cyto_score_detail + conditioning_intensity + ethnicity + obesity + mrd_hct + in_vivo_tcd + tce_match + hepatic_severe + prior_tumor + peptic_ulcer + gvhd_proph + rheum_issue + sex_match + hepatic_mild + tce_div_match + donor_related + melphalan_dose + cardiac + pulm_moderate: 0.6280






In [8]:
X = X.dropna()
y = y[X.index].copy()

X.reset_index(drop=True, inplace=True)
y.reset_index(drop=True, inplace=True)

In [9]:
print(X.shape, y.shape)

(19280, 81) (19280,)


- Linear Regression

In [10]:
# linreg = LinearRegression()
# cv_score = cross_validate(linreg, X, y, cv=5, scale=True)
# print(f'CV Score: {cv_score:.4f}')

cv: 0.5777

- Random Forest

In [11]:
# rfreg = RandomForestRegressor(random_state=RANDOM_STATE)
# cv_score = cross_validate(rfreg, X, y, cv=5)
# print(f'CV Score: {cv_score:.4f}')

cv: 0.5789

- XGBoost

In [12]:
# xgbreg = XGBRegressor(random_state=RANDOM_STATE)
# cv_score = cross_validate(xgbreg, X, y, cv=5)
# print(f'CV Score: {cv_score:.4f}')

0.5759

- LightGBM

In [13]:
# lgbmreg = LGBMRegressor(verbose=-1, random_state=RANDOM_STATE)
# cv_score = cross_validate(lgbmreg, X, y, cv=5)
# print(f'CV Score: {cv_score:.4f}')

0.5862

- Voting Regressor

In [14]:
from sklearn.base import BaseEstimator, RegressorMixin

class VotingRegressor(BaseEstimator, RegressorMixin):
    def __init__(self, estimators, weights=None):
        self.estimators = estimators
        self.weights = weights if weights else np.ones(len(estimators)) / len(estimators)
    
    def fit(self, X, y):
        self.models_ = []
        for name, model in self.estimators:
            model.fit(X, y)
            self.models_.append((name, model))
        return self
    
    def predict(self, X):
        predictions = np.array([model.predict(X) for _, model in self.models_])
        return np.average(predictions, axis=0, weights=self.weights)

estimators = [
    ('linreg', Pipeline([
        # ('scaler', StandardScaler()), 
        ('linreg', LinearRegression(**{'fit_intercept': False}))
    ])),
    ('rfreg', RandomForestRegressor(**{
        'n_estimators': 214,
        'max_depth': 22,
        'min_samples_split': 7,
        'min_samples_leaf': 1,
        'max_features': 'sqrt',
        'bootstrap': True,
        'random_state': RANDOM_STATE,
    })),
    ('xgbreg', XGBRegressor(**{
        'n_estimators': 1951,
        'max_depth': 2,
        'learning_rate': 0.026214891441095647,
        'min_child_weight': 2,
        'subsample': 0.5161925118818808,
        'colsample_bytree': 0.8799771355893139,
        'gamma': 0.6079832570503964,
        'reg_alpha': 0.0005737571071016235,
        'reg_lambda': 0.0277732573625255,
        'random_state': RANDOM_STATE
    })),
    ('lgbmreg', LGBMRegressor(**{
        'n_estimators': 976,
        'max_depth': 2,
        'learning_rate': 0.07642404569065746,
        'num_leaves': 113,
        'min_child_samples': 74,
        'subsample': 0.7080421704267752,
        'colsample_bytree': 0.6623639577905086,
        'reg_alpha': 0.007141897055432138,
        'reg_lambda': 0.013485715695541321,
        'random_state': RANDOM_STATE,
        'verbose': -1,
    }))
]

In [15]:
voting = VotingRegressor(estimators=estimators, weights=[1, 1, 1, 1])
cv_score = cross_validate(voting, X, y, cv=5)
print(f'CV Score: {cv_score:.4f}')

100%|██████████| 5/5 [00:19<00:00,  3.88s/it]

CV Score: 0.5915



