# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')
sns.set_palette('muted')

import optuna

from sklearn.base import clone

from sklearn.model_selection import KFold
from lifelines.utils import concordance_index
import xgboost as xgb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_PATH = '../data/equity-post-HCT-survival-predictions/'
RANDOM_STATE = 54321

# Data

In [3]:
X = pd.read_pickle(DATA_PATH + 'train_test_split/X_25-12-2024.pkl')
y = pd.read_pickle(DATA_PATH + 'train_test_split/y_25-12-2024.pkl')
efs_time = pd.read_pickle(DATA_PATH + 'train_test_split/efs_time_25-12-2024.pkl')
race_group = pd.read_pickle(DATA_PATH + 'train_test_split/race_group_25-12-2024.pkl')

In [4]:
for col in X.columns:
    X[col] = X[col].ffill().bfill()

In [5]:
X_cph = X[[
    'hla_match_c_high',
    'hla_high_res_8',
    'hla_low_res_6',
    # 'hla_high_res_6',
    'hla_high_res_10',
    # 'hla_match_dqb1_high',
    'hla_nmdp_6',
    'hla_match_c_low',
    'hla_match_drb1_low',
    'hla_match_dqb1_low',
    'year_hct',
    'hla_match_a_high',
    'donor_age',
    'hla_match_b_low',
    'age_at_hct',
    # 'hla_match_a_low',
    'hla_match_b_high',
    'comorbidity_score',
    'karnofsky_score',
    # 'hla_low_res_8',
    # 'hla_match_drb1_high',
    # 'hla_low_res_10',
    'dri_score',
    'psych_disturb',
    'cyto_score',
    'diabetes',
    'tbi_status',
    'arrhythmia',
    'graft_type',
    'vent_hist',
    'renal_issue',
    'pulm_severe',
    'cmv_status',
    'tce_imm_match',
    'rituximab',
    'prod_type',
    'cyto_score_detail',
    'conditioning_intensity',
    'obesity',
    'mrd_hct',
    'in_vivo_tcd',
    'tce_match',
    'hepatic_severe',
    'prior_tumor',
    'peptic_ulcer',
    'gvhd_proph',
    'rheum_issue',
    'sex_match',
    'hepatic_mild',
    'tce_div_match',
    'donor_related',
    'melphalan_dose',
    'cardiac',
    'pulm_moderate',
    'prim_disease_hct_AI',
    'prim_disease_hct_ALL',
    'prim_disease_hct_AML',
    'prim_disease_hct_CML',
    'prim_disease_hct_HD',
    'prim_disease_hct_HIS',
    'prim_disease_hct_IEA',
    'prim_disease_hct_IIS',
    'prim_disease_hct_IMD',
    'prim_disease_hct_IPA',
    'prim_disease_hct_MDS',
    'prim_disease_hct_MPN',
    'prim_disease_hct_NHL',
    'prim_disease_hct_Other acute leukemia',
    'prim_disease_hct_Other leukemia',
    'prim_disease_hct_PCD',
    'prim_disease_hct_SAA',
    # 'prim_disease_hct_Solid tumor',
    'ethnicity_Hispanic or Latino',
    'ethnicity_Non-resident of the U.S.',
    'ethnicity_Not Hispanic or Latino',
    'race_group_American Indian or Alaska Native',
    'race_group_Asian',
    'race_group_Black or African-American',
    'race_group_More than one race',
    'race_group_Native Hawaiian or other Pacific Islander',
    # 'race_group_White',
]]

In [6]:
dtrain = xgb.DMatrix(X_cph, label=efs_time)
dtrain.set_weight(y)

## Metric

In [7]:
def cross_validate_xgboost_survival(params, X, efs, efs_time, cv=5):
    kf = KFold(n_splits=cv, shuffle=True, random_state=RANDOM_STATE)
    c_index_scores = []
    
    for train_index, test_index in kf.split(X):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        efs_time_train, efs_time_test = efs_time.iloc[train_index], efs_time.iloc[test_index]
        efs_train, efs_test = efs.iloc[train_index], efs.iloc[test_index]
        
        dtrain = xgb.DMatrix(X_train, label=efs_time_train)
        dtrain.set_weight(efs_train)
        dtest = xgb.DMatrix(X_test, label=efs_time_test)
        dtest.set_weight(efs_test)
        
        model = xgb.train(params, dtrain)
        
        pred_risks = model.predict(dtest)
        
        c_index = concordance_index(
            efs_time_test,
            -pred_risks,
            event_observed=efs_test
        )
        
        c_index_scores.append(c_index)
    
    return np.mean(c_index_scores)

# Search

In [8]:
efs = pd.Series(y, name='efs')

def objective(trial):
    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'eta': trial.suggest_float('eta', 0.0001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 3, 100),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        # 'lambda_': trial.suggest_float('lambda', 0.0, 1.0)
    }
    
    cv_index = cross_validate_xgboost_survival(
        params, X, efs, efs_time, cv=5
    )
    
    return round(cv_index, 4)

In [9]:
params = {'eta': 0.05884790339411519, 'max_depth': 50, 'subsample': 0.9909449377442469, 'colsample_bytree': 0.513728660762146}

In [10]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

[I 2025-01-15 21:26:43,087] A new study created in memory with name: no-name-8330b77b-b4da-4632-874a-3faa66cc9f61
[I 2025-01-15 21:26:43,634] Trial 0 finished with value: 0.5464 and parameters: {'eta': 0.06242337880404521, 'max_depth': 86, 'subsample': 0.5604799580596477, 'colsample_bytree': 0.6075622665574724}. Best is trial 0 with value: 0.5464.
[I 2025-01-15 21:26:44,082] Trial 1 finished with value: 0.5417 and parameters: {'eta': 0.18913644588960707, 'max_depth': 10, 'subsample': 0.5874985401239307, 'colsample_bytree': 0.9063893598494692}. Best is trial 0 with value: 0.5464.
[I 2025-01-15 21:26:44,591] Trial 2 finished with value: 0.5463 and parameters: {'eta': 0.15469880571717626, 'max_depth': 75, 'subsample': 0.9680625452788817, 'colsample_bytree': 0.5693293398867161}. Best is trial 0 with value: 0.5464.
[I 2025-01-15 21:26:45,187] Trial 3 finished with value: 0.5487 and parameters: {'eta': 0.049635608447403, 'max_depth': 68, 'subsample': 0.8630783639245649, 'colsample_bytree': 0

In [11]:
print("Mejores hiperparámetros:", study.best_params)
print("Mejor índice de concordancia:", study.best_value)

Mejores hiperparámetros: {'eta': 0.0460479714782899, 'max_depth': 12, 'subsample': 0.9434730553514691, 'colsample_bytree': 0.501892394882468}
Mejor índice de concordancia: 0.5505


In [12]:
dtrain = xgb.DMatrix(X, label=efs_time)
dtrain.set_weight(efs)

model = xgb.train(study.best_params, dtrain)

pred_risks = model.predict(dtrain)

c_index = concordance_index(
    efs_time,
    pred_risks,
    event_observed=efs
)

print(c_index)

0.6143717834380109


In [13]:
from scipy.stats import spearmanr, pearsonr

# Correlación de Spearman y Pearson
spearman_corr, _ = spearmanr(pred_risks, efs_time)
pearson_corr, _ = pearsonr(pred_risks, efs_time)

print("Spearman Correlation:", spearman_corr)
print("Pearson Correlation:", pearson_corr)

Spearman Correlation: 0.31909882362237957
Pearson Correlation: 0.15675500542645296


In [14]:
from scipy.stats import spearmanr, pearsonr

# Correlación de Spearman y Pearson
spearman_corr, _ = spearmanr(-pred_risks, efs_time)
pearson_corr, _ = pearsonr(-pred_risks, efs_time)

print("Spearman Correlation:", spearman_corr)
print("Pearson Correlation:", pearson_corr)

Spearman Correlation: -0.31909882362237957
Pearson Correlation: -0.15675500542645296
