# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')
sns.set_palette('muted')

import optuna

from sklearn.model_selection import KFold

from lifelines import CoxPHFitter
from lifelines.utils import concordance_index

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_PATH = '../data/equity-post-HCT-survival-predictions/'
RANDOM_STATE = 54321

## Data

In [3]:
X = pd.read_pickle(DATA_PATH + 'train_test_split/X_25-12-2024.pkl')
y = pd.read_pickle(DATA_PATH + 'train_test_split/y_25-12-2024.pkl')
efs_time = pd.read_pickle(DATA_PATH + 'train_test_split/efs_time_25-12-2024.pkl')
race_group = pd.read_pickle(DATA_PATH + 'train_test_split/race_group_25-12-2024.pkl')

In [4]:
for col in X.columns:
    X[col] = X[col].ffill().bfill()
    

In [5]:
X_cph = pd.concat([X[[
    'hla_match_c_high',
    'hla_high_res_8',
    'hla_low_res_6',
    # 'hla_high_res_6',
    'hla_high_res_10',
    # 'hla_match_dqb1_high',
    'hla_nmdp_6',
    'hla_match_c_low',
    'hla_match_drb1_low',
    'hla_match_dqb1_low',
    'year_hct',
    'hla_match_a_high',
    'donor_age',
    'hla_match_b_low',
    'age_at_hct',
    # 'hla_match_a_low',
    'hla_match_b_high',
    'comorbidity_score',
    'karnofsky_score',
    # 'hla_low_res_8',
    # 'hla_match_drb1_high',
    # 'hla_low_res_10',
    'dri_score',
    'psych_disturb',
    'cyto_score',
    'diabetes',
    'tbi_status',
    'arrhythmia',
    'graft_type',
    'vent_hist',
    'renal_issue',
    'pulm_severe',
    'cmv_status',
    'tce_imm_match',
    'rituximab',
    'prod_type',
    'cyto_score_detail',
    'conditioning_intensity',
    'obesity',
    'mrd_hct',
    'in_vivo_tcd',
    'tce_match',
    'hepatic_severe',
    'prior_tumor',
    'peptic_ulcer',
    'gvhd_proph',
    'rheum_issue',
    'sex_match',
    'hepatic_mild',
    'tce_div_match',
    'donor_related',
    'melphalan_dose',
    'cardiac',
    'pulm_moderate',
    'prim_disease_hct_AI',
    'prim_disease_hct_ALL',
    'prim_disease_hct_AML',
    'prim_disease_hct_CML',
    'prim_disease_hct_HD',
    'prim_disease_hct_HIS',
    'prim_disease_hct_IEA',
    'prim_disease_hct_IIS',
    'prim_disease_hct_IMD',
    'prim_disease_hct_IPA',
    'prim_disease_hct_MDS',
    'prim_disease_hct_MPN',
    'prim_disease_hct_NHL',
    'prim_disease_hct_Other acute leukemia',
    'prim_disease_hct_Other leukemia',
    'prim_disease_hct_PCD',
    'prim_disease_hct_SAA',
    # 'prim_disease_hct_Solid tumor',
    'ethnicity_Hispanic or Latino',
    'ethnicity_Non-resident of the U.S.',
    'ethnicity_Not Hispanic or Latino',
    'race_group_American Indian or Alaska Native',
    'race_group_Asian',
    'race_group_Black or African-American',
    'race_group_More than one race',
    'race_group_Native Hawaiian or other Pacific Islander',
    # 'race_group_White',
]], y, efs_time], axis=1)

## Metric

In [6]:
from tqdm import tqdm

def cross_validate_cox(model, X, duration_col, event_col, cv=5):
    kf = KFold(n_splits=cv, shuffle=True, random_state=RANDOM_STATE)
    c_index_scores = []
    
    for train_index, test_index in kf.split(X):
        train_data = X.iloc[train_index]
        test_data = X.iloc[test_index]
        
        model.fit(train_data, duration_col=duration_col, event_col=event_col)
        predictions = model.predict_partial_hazard(test_data)
        
        c_index = concordance_index(
            test_data[duration_col],
            -predictions,
            event_observed=test_data[event_col]
        )
        
        c_index_scores.append(c_index)
    
    return np.mean(c_index_scores)

# Search

In [7]:
def objective(trial):
    penalizer = trial.suggest_float('penalizer', 1e-4, 1e-1, log=True)
    l1_ratio = trial.suggest_float('l1_ratio', 0.0, 1.0)
    
    cph = CoxPHFitter(
        penalizer=penalizer,
        l1_ratio=l1_ratio,
    )

    c_index_scores = cross_validate_cox(cph, X_cph, duration_col='efs_time', event_col='efs', cv=5)
    
    return round(c_index_scores, 4)

In [8]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

[I 2025-01-12 17:59:57,342] A new study created in memory with name: no-name-2f2e5d40-9420-42a9-b4ed-bb61f5a4b6ef
[I 2025-01-12 18:01:57,259] Trial 0 finished with value: 0.6402 and parameters: {'penalizer': 0.00017551940918715197, 'l1_ratio': 0.14555558029464366}. Best is trial 0 with value: 0.6402.
[I 2025-01-12 18:04:22,322] Trial 1 finished with value: 0.6402 and parameters: {'penalizer': 0.0003589180564101607, 'l1_ratio': 0.7656504486072426}. Best is trial 0 with value: 0.6402.
[I 2025-01-12 18:07:06,616] Trial 2 finished with value: 0.6402 and parameters: {'penalizer': 0.0008930933580368345, 'l1_ratio': 0.6545004379039904}. Best is trial 0 with value: 0.6402.
[I 2025-01-12 18:09:35,536] Trial 3 finished with value: 0.6402 and parameters: {'penalizer': 0.0007540054172331479, 'l1_ratio': 0.45404601296196234}. Best is trial 0 with value: 0.6402.
[I 2025-01-12 18:11:53,083] Trial 4 finished with value: 0.6402 and parameters: {'penalizer': 0.0010968863297445871, 'l1_ratio': 0.14219790

KeyboardInterrupt: 

In [9]:
print("Mejores hiperparámetros:", study.best_params)
print("Mejor índice de concordancia:", study.best_value)

Mejores hiperparámetros: {'penalizer': 0.0007403068172116651, 'l1_ratio': 0.9820846109393846}
Mejor índice de concordancia: 0.6398
