<a href="https://colab.research.google.com/github/JosselinPerret/ENSxQRT-Data-Challenge/blob/main/QRT_CHALLENGE_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<p align="center">
  <img src="https://upload.wikimedia.org/wikipedia/fr/8/86/Logo_CentraleSup%C3%A9lec.svg" alt="Logo 1" width="250"/>
  <img src="https://www.qube-rt.com/img/qrt.svg" alt="Logo 2" width="400" style="margin: 20px;"/>
</p>

# Data Challenge : Leukemia Risk Prediction


*GOAL OF THE CHALLENGE and WHY IT IS IMPORTANT:*

The goal of the challenge is to **predict disease risk for patients with blood cancer**, in the context of specific subtypes of adult myeloid leukemias.

The risk is measured through the **overall survival** of patients, i.e. the duration of survival from the diagnosis of the blood cancer to the time of death or last follow-up.

Estimating the prognosis of patients is critical for an optimal clinical management.
For exemple, patients with low risk-disease will be offered supportive care to improve blood counts and quality of life, while patients with high-risk disease will be considered for hematopoietic stem cell transplantion.

The performance metric used in the challenge is the **IPCW-C-Index**.

*THE DATASETS*

The **training set is made of 3,323 patients**.

The **test set is made of 1,193 patients**.

For each patient, you have acces to CLINICAL data and MOLECULAR data.

The details of the data are as follows:

- OUTCOME:
  * OS_YEARS = Overall survival time in years
  * OS_STATUS = 1 (death) , 0 (alive at the last follow-up)

- CLINICAL DATA, with one line per patient:
  
  * ID = unique identifier per patient
  * CENTER = clinical center
  * BM_BLAST = Bone marrow blasts in % (blasts are abnormal blood cells)
  * WBC = White Blood Cell count in Giga/L
  * ANC = Absolute Neutrophil count in Giga/L
  * MONOCYTES = Monocyte count in Giga/L
  * HB = Hemoglobin in g/dL
  * PLT = Platelets coutn in Giga/L
  * CYTOGENETICS = A description of the karyotype observed in the blood cells of the patients, measured by a cytogeneticist. Cytogenetics is the science of chromosomes. A karyotype is performed from the blood tumoral cells. The convention for notation is ISCN (https://en.wikipedia.org/wiki/International_System_for_Human_Cytogenomic_Nomenclature). Cytogenetic notation are: https://en.wikipedia.org/wiki/Cytogenetic_notation. Note that a karyotype can be normal or abnornal. The notation 46,XX denotes a normal karyotype in females (23 pairs of chromosomes including 2 chromosomes X) and 46,XY in males (23 pairs of chromosomes inclusing 1 chromosme X and 1 chromsome Y). A common abnormality in the blood cancerous cells might be for exemple a loss of chromosome 7 (monosomy 7, or -7), which is typically asssociated with higher risk disease

- GENE MOLECULAR DATA, with one line per patient per somatic mutation. Mutations are detected from the sequencing of the blood tumoral cells.
We call somatic (= acquired) mutations the mutations that are found in the tumoral cells but not in other cells of the body.

  * ID = unique identifier per patient
  * CHR START END = position of the mutation on the human genome
  * REF ALT = reference and alternate (=mutant) nucleotide
  * GENE = the affected gene
  * PROTEIN_CHANGE = the consequence of the mutation on the protei that is expressed by a given gene
  * EFFECT = a broad categorization of the mutation consequences on a given gene.
  * VAF = Variant Allele Fraction = it represents the **proportion** of cells with the deleterious mutations.

In [1]:
pip install scikit-survival

Collecting scikit-survival
  Downloading scikit_survival-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Collecting ecos (from scikit-survival)
  Downloading ecos-2.0.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.0 kB)
Collecting osqp<1.0.0,>=0.6.3 (from scikit-survival)
  Downloading osqp-0.6.7.post3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting qdldl (from osqp<1.0.0,>=0.6.3->scikit-survival)
  Downloading qdldl-0.1.7.post5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Downloading scikit_survival-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.9/3.9 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading osqp-0.6.7.post3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64

In [2]:
import pandas as pd
import numpy as np
from sksurv.util import Surv
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored , concordance_index_ipcw
from sklearn.impute import SimpleImputer
import xgboost as xgb
from sklearn.metrics import accuracy_score

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Clinical Data
clin_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_train/clinical_train.csv")
clin_eval = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_test/clinical_test.csv")

# Molecular Data
mol_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_train/molecular_train.csv")
mol_eval = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_test/molecular_test.csv")

y_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/target_train.csv")
# y_eval = pd.read_csv("/content/drive/My Drive/target_test.csv")

## Engineer cytogenetic features


In [5]:
# Define a list of common cytogenetic abnormalities and patterns based on previous analysis
common_abnormalities = ['Normal', '+8', '-7', 'del(5q)', 'del(7q)', '-Y', 'complex', '>3abnormalities']

# Create binary columns for these abnormalities in clin_tr and clin_eval
for abn in common_abnormalities:
    # Use string contains with case=False and na=False to handle missing values and case variations
    # Also handle variations in notation like '+8', ' +8', '+ 8' by searching for patterns
    if abn == 'Normal':
        clin_tr[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_tr['CYTOGENETICS'].str.contains(abn, case=False, na=False).astype(int)
        clin_eval[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_eval['CYTOGENETICS'].str.contains(abn, case=False, na=False).astype(int)
    else:
        clin_tr[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_tr['CYTOGENETICS'].str.contains(abn.replace("(", "\(").replace(")", "\)").replace("+", "\+"), case=False, na=False).astype(int)
        clin_eval[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_eval['CYTOGENETICS'].str.contains(abn.replace("(", "\(").replace(")", "\)").replace("+", "\+"), case=False, na=False).astype(int)


# Drop the original CYTOGENETICS column
clin_tr = clin_tr.drop('CYTOGENETICS', axis=1)
clin_eval = clin_eval.drop('CYTOGENETICS', axis=1)

# Align columns after creating new features
train_cols = clin_tr.columns.tolist()
eval_cols = clin_eval.columns.tolist()

missing_in_eval = list(set(train_cols) - set(eval_cols))
for col in missing_in_eval:
    clin_eval[col] = 0

missing_in_train = list(set(eval_cols) - set(train_cols))
for col in missing_in_train:
    clin_tr[col] = 0

clin_eval = clin_eval[train_cols]

# Display the head of the updated dataframes
display(clin_tr.head())
display(clin_eval.head())

  clin_tr[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_tr['CYTOGENETICS'].str.contains(abn.replace("(", "\(").replace(")", "\)").replace("+", "\+"), case=False, na=False).astype(int)
  clin_tr[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_tr['CYTOGENETICS'].str.contains(abn.replace("(", "\(").replace(")", "\)").replace("+", "\+"), case=False, na=False).astype(int)
  clin_tr[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "greater_than_").replace("(", "").replace(")", "").replace(",", "").replace(" ", "_")}'] = clin_tr['CYTOGENETICS'].str.contains(abn.replace("(", "\(").replace(")", "\)").replace("+", "\+"), case=False, na=False).astype(int)
  clin_eval[f'CYTO_{abn.replace("+", "plus_").replace("-", "minus_").replace(">", "great

Unnamed: 0,ID,CENTER,BM_BLAST,WBC,ANC,MONOCYTES,HB,PLT,CYTO_Normal,CYTO_plus_8,CYTO_minus_7,CYTO_del5q,CYTO_del7q,CYTO_minus_Y,CYTO_complex,CYTO_greater_than_3abnormalities
0,P132697,MSK,14.0,2.8,0.2,0.7,7.6,119.0,0,0,0,0,0,0,0,0
1,P132698,MSK,1.0,7.4,2.4,0.1,11.6,42.0,0,0,0,0,0,0,0,0
2,P116889,MSK,15.0,3.7,2.1,0.1,14.2,81.0,0,0,0,0,0,0,0,0
3,P132699,MSK,1.0,3.9,1.9,0.1,8.9,77.0,0,0,0,0,0,0,0,0
4,P132700,MSK,6.0,128.0,9.7,0.9,11.1,195.0,0,0,0,0,0,0,0,0


Unnamed: 0,ID,CENTER,BM_BLAST,WBC,ANC,MONOCYTES,HB,PLT,CYTO_Normal,CYTO_plus_8,CYTO_minus_7,CYTO_del5q,CYTO_del7q,CYTO_minus_Y,CYTO_complex,CYTO_greater_than_3abnormalities
0,KYW1,KYW,68.0,3.45,0.5865,,7.6,48.0,0,0,0,0,0,0,0,0
1,KYW2,KYW,35.0,3.18,1.2402,,10.0,32.0,0,0,0,0,0,0,0,0
2,KYW3,KYW,,12.4,8.68,,12.3,25.0,0,1,0,0,0,0,0,0
3,KYW4,KYW,61.0,5.55,2.0535,,8.0,44.0,1,0,0,0,0,0,0,0
4,KYW5,KYW,2.0,1.21,0.7381,,8.6,27.0,0,0,1,0,0,0,0,0


## Refine feature combination

In [6]:
# Step 1: Identify the top N most frequently mutated genes
N = 50 # Choose a reasonable number for top genes
top_genes = mol_tr['GENE'].value_counts().head(N).index.tolist()

# Step 2: Create binary columns for the top N genes
for gene in top_genes:
    mol_tr[f'MUT_{gene}'] = mol_tr['GENE'].apply(lambda x: 1 if x == gene else 0)
    mol_eval[f'MUT_{gene}'] = mol_eval['GENE'].apply(lambda x: 1 if x == gene else 0)

# Step 3: Aggregate these gene presence/absence features by patient ID
mol_tr_gene_agg = mol_tr.groupby('ID')[[f'MUT_{gene}' for gene in top_genes]].sum().reset_index()
mol_eval_gene_agg = mol_eval.groupby('ID')[[f'MUT_{gene}' for gene in top_genes]].sum().reset_index()

# Step 1: Identify the unique values in the 'EFFECT' column
effect_types = pd.concat([mol_tr['EFFECT'], mol_eval['EFFECT']]).dropna().unique()

# Step 2: Create binary columns for each unique effect type
for effect in effect_types:
    mol_tr[f'EFFECT_{effect}'] = mol_tr['EFFECT'].apply(lambda x: 1 if x == effect else 0)
    mol_eval[f'EFFECT_{effect}'] = mol_eval['EFFECT'].apply(lambda x: 1 if x == effect else 0)

# Step 3: Aggregate these effect-specific features by patient ID
mol_tr_effect_agg = mol_tr.groupby('ID')[[f'EFFECT_{effect}' for effect in effect_types]].sum().reset_index()
mol_eval_effect_agg = mol_eval.groupby('ID')[[f'EFFECT_{effect}' for effect in effect_types]].sum().reset_index()

# Calculate mean, max, and standard deviation of VAF for each patient
mean_vaf_tr = mol_tr.groupby('ID')['VAF'].mean().reset_index(name='mean_vaf_per_patient')
max_vaf_tr = mol_tr.groupby('ID')['VAF'].max().reset_index(name='max_vaf_per_patient')
std_vaf_tr = mol_tr.groupby('ID')['VAF'].std().reset_index(name='std_vaf_per_patient')

mean_vaf_eval = mol_eval.groupby('ID')['VAF'].mean().reset_index(name='mean_vaf_per_patient')
max_vaf_eval = mol_eval.groupby('ID')['VAF'].max().reset_index(name='max_vaf_per_patient')
std_vaf_eval = mol_eval.groupby('ID')['VAF'].std().reset_index(name='std_vaf_per_patient')

# Merge the preprocessed clinical and engineered molecular features for training
train_combined_updated = pd.merge(clin_tr, mol_tr_gene_agg, on='ID', how='left')
train_combined_updated = pd.merge(train_combined_updated, mol_tr_effect_agg, on='ID', how='left')
train_combined_updated = pd.merge(train_combined_updated, mean_vaf_tr, on='ID', how='left')
train_combined_updated = pd.merge(train_combined_updated, max_vaf_tr, on='ID', how='left')
train_combined_updated = pd.merge(train_combined_updated, std_vaf_tr, on='ID', how='left')

# Merge the preprocessed clinical and engineered molecular features for evaluation
eval_combined_updated = pd.merge(clin_eval, mol_eval_gene_agg, on='ID', how='left')
eval_combined_updated = pd.merge(eval_combined_updated, mol_eval_effect_agg, on='ID', how='left')
eval_combined_updated = pd.merge(eval_combined_updated, mean_vaf_eval, on='ID', how='left')
eval_combined_updated = pd.merge(eval_combined_updated, max_vaf_eval, on='ID', how='left')
eval_combined_updated = pd.merge(eval_combined_updated, std_vaf_eval, on='ID', how='left')

train_combined_updated.fillna(0, inplace=True)
eval_combined_updated.fillna(0, inplace=True)

display(train_combined_updated.head())
print("\nShape of the updated combined training data:", train_combined_updated.shape)
display(eval_combined_updated.head())
print("\nShape of the updated combined evaluation data:", eval_combined_updated.shape)

Unnamed: 0,ID,CENTER,BM_BLAST,WBC,ANC,MONOCYTES,HB,PLT,CYTO_Normal,CYTO_plus_8,...,EFFECT_3_prime_UTR_variant,EFFECT_stop_lost,EFFECT_inframe_variant,EFFECT_synonymous_codon,EFFECT_stop_retained_variant,EFFECT_ITD,EFFECT_PTD,mean_vaf_per_patient,max_vaf_per_patient,std_vaf_per_patient
0,P132697,MSK,14.0,2.8,0.2,0.7,7.6,119.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.251578,0.422,0.147784
1,P132698,MSK,1.0,7.4,2.4,0.1,11.6,42.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.272867,0.2825,0.008568
2,P116889,MSK,15.0,3.7,2.1,0.1,14.2,81.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.039333,0.048,0.007506
3,P132699,MSK,1.0,3.9,1.9,0.1,8.9,77.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.209227,0.477,0.136217
4,P132700,MSK,6.0,128.0,9.7,0.9,11.1,195.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.4721,0.4721,0.0



Shape of the updated combined training data: (3323, 85)


Unnamed: 0,ID,CENTER,BM_BLAST,WBC,ANC,MONOCYTES,HB,PLT,CYTO_Normal,CYTO_plus_8,...,EFFECT_3_prime_UTR_variant,EFFECT_stop_lost,EFFECT_inframe_variant,EFFECT_synonymous_codon,EFFECT_stop_retained_variant,EFFECT_ITD,EFFECT_PTD,mean_vaf_per_patient,max_vaf_per_patient,std_vaf_per_patient
0,KYW1,KYW,68.0,3.45,0.5865,0.0,7.6,48.0,0,0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.246225,0.384,0.165531
1,KYW2,KYW,35.0,3.18,1.2402,0.0,10.0,32.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.281,0.713,0.374188
2,KYW3,KYW,0.0,12.4,8.68,0.0,12.3,25.0,0,1,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.162667,0.327,0.148816
3,KYW4,KYW,61.0,5.55,2.0535,0.0,8.0,44.0,1,0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.351367,0.428,0.129284
4,KYW5,KYW,2.0,1.21,0.7381,0.0,8.6,27.0,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.263667,0.407,0.200143



Shape of the updated combined evaluation data: (1193, 85)


## Train and evaluate model


In [7]:
random_state = 42

# Merge the target variable y_tr with train_combined_updated to ensure OS_STATUS and OS_YEARS are present
train_combined_updated = pd.merge(train_combined_updated, y_tr, on='ID', how='left', suffixes=('', '_y'))

# Drop duplicate OS_YEARS and OS_STATUS columns if they exist after merge
train_combined_updated.drop(columns=['OS_YEARS_y', 'OS_STATUS_y'], inplace=True, errors='ignore')

# Clean the training data by dropping rows with missing target values
train_merged_cleaned_updated = train_combined_updated.dropna(subset=['OS_STATUS', 'OS_YEARS'])

# Define the training features and target
X_updated = train_merged_cleaned_updated.drop(['ID', 'OS_YEARS', 'OS_STATUS'], axis=1)
y_updated = Surv.from_dataframe("OS_STATUS", "OS_YEARS", train_merged_cleaned_updated)

# Apply one-hot encoding to the 'CENTER' column on the full training feature set
X_updated = pd.get_dummies(X_updated, columns=['CENTER'], prefix='CENTER', dummy_na=False)

# Split data into training and validation sets
X_train_updated, X_val_updated, y_train_updated, y_val_updated = train_test_split(X_updated, y_updated, test_size=0.2, random_state=random_state)


# Align columns after one-hot encoding
train_cols_updated = X_train_updated.columns.tolist()
val_cols_updated = X_val_updated.columns.tolist()

missing_in_val_updated = list(set(train_cols_updated) - set(val_cols_updated))
for col in missing_in_val_updated:
    X_val_updated[col] = 0

missing_in_train_updated = list(set(val_cols_updated) - set(train_cols_updated))
for col in missing_in_train_updated:
    X_train_updated[col] = 0

X_val_updated = X_val_updated[train_cols_updated]

# --- build Cox labels for XGBoost: +time if event, -time if censored ---
t_tr = y_train_updated["OS_YEARS"].astype(float)
e_tr = y_train_updated["OS_STATUS"].astype(bool)
cox_label_tr = np.where(e_tr, t_tr, -t_tr)

t_va = y_val_updated["OS_YEARS"].astype(float)
e_va = y_val_updated["OS_STATUS"].astype(bool)
cox_label_va = np.where(e_va, t_va, -t_va)

In [8]:
# --- build Cox labels for XGBoost: +time if event, -time if censored ---
t_tr = y_train_updated["OS_YEARS"].astype(float)
e_tr = y_train_updated["OS_STATUS"].astype(bool)
cox_label_tr = np.where(e_tr, t_tr, -t_tr)

t_va = y_val_updated["OS_YEARS"].astype(float)
e_va = y_val_updated["OS_STATUS"].astype(bool)
cox_label_va = np.where(e_va, t_va, -t_va)

In [9]:
# --- DMatrix ---
dtrain = xgb.DMatrix(X_train_updated, label=cox_label_tr)
dvalid = xgb.DMatrix(X_val_updated,   label=cox_label_va)

In [10]:
params = {
    "objective": "survival:cox",
    "eval_metric": "cox-nloglik",
    "tree_method": "hist",
    "max_depth": 3, # Reduced for faster training
    "eta": 0.1, # Increased learning rate for faster training
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "lambda": 1.0,
    "alpha": 0.0,
    "seed": random_state,
}

In [11]:
bst = xgb.train(
    params,
    dtrain,
    num_boost_round=1000, # Reduced number of boosting rounds for faster training
    evals=[(dtrain, "train"), (dvalid, "valid")],
    early_stopping_rounds=50, # Reduced early stopping rounds for faster training
    verbose_eval=False,
)

In [12]:
risk_scores_val = bst.predict(dvalid)  # higher = higher hazard (worse prognosis)

In [13]:
eval_times_updated = np.unique(t_va[e_va])  # distinct observed event times in validation
if len(eval_times_updated) == 0:
    c_index_updated = np.nan
else:
    tau = eval_times_updated[-1]
    c_index_updated, _, _, _, _ = concordance_index_ipcw(
        y_train_updated, y_val_updated, risk_scores_val, tau=tau
    )

print(f"IPCW-C-Index (XGBoost Cox) on the updated validation set: {c_index_updated:.4f}")

IPCW-C-Index (XGBoost Cox) on the updated validation set: 0.6995


## Generate submission file


In [14]:
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [None, 20],
    'min_samples_split': [5, 20],
    'min_samples_leaf': [5, 15]
}

In [15]:
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.metrics import make_scorer
from sksurv.metrics import concordance_index_censored

def c_index_scorer(estimator, X, y):
    return concordance_index_censored(y["OS_STATUS"], y["OS_YEARS"], estimator.predict(X))[0]

kf = KFold(n_splits=5, shuffle=True, random_state=random_state)

rsf = RandomSurvivalForest(random_state=random_state)

grid_search = GridSearchCV(estimator=rsf, param_grid=param_grid, cv=kf, scoring=make_scorer(c_index_scorer, greater_is_better=True), n_jobs=-1)

grid_search.fit(X_train_updated, y_train_updated)

print("Best hyperparameters found:", grid_search.best_params_)



Best hyperparameters found: {'max_depth': None, 'min_samples_leaf': 5, 'min_samples_split': 5, 'n_estimators': 100}


In [19]:
# Merge the target variable y_tr with train_combined_updated to ensure OS_STATUS and OS_YEARS are present
train_combined_updated = pd.merge(train_combined_updated, y_tr, on='ID', how='left', suffixes=('', '_y'))

# Drop duplicate OS_YEARS and OS_STATUS columns if they exist after merge
train_combined_updated.drop(columns=['OS_YEARS_y', 'OS_STATUS_y'], inplace=True, errors='ignore')

# Clean the training data by dropping rows with missing target values
train_merged_cleaned_updated = train_combined_updated.dropna(subset=['OS_STATUS', 'OS_YEARS'])

# Define the training features and target
X_updated = train_merged_cleaned_updated.drop(['ID', 'OS_YEARS', 'OS_STATUS'], axis=1)
y_updated = Surv.from_dataframe("OS_STATUS", "OS_YEARS", train_merged_cleaned_updated)

# Apply one-hot encoding to the 'CENTER' column on the full training feature set
X_updated = pd.get_dummies(X_updated, columns=['CENTER'], prefix='CENTER', dummy_na=False)

# Apply one-hot encoding to the 'CENTER' column on the evaluation feature set
X_eval_updated = pd.get_dummies(eval_combined_updated.drop('ID', axis=1), columns=['CENTER'], prefix='CENTER', dummy_na=False)

# Align the columns of the evaluation features with the training features (X_updated)
train_cols_updated = X_updated.columns.tolist()
eval_cols_updated = X_eval_updated.columns.tolist()

# Add missing columns in eval_combined_updated and set to 0
missing_in_eval_updated = list(set(train_cols_updated) - set(eval_cols_updated))
for col in missing_in_eval_updated:
    X_eval_updated[col] = 0

# Ensure the order of columns in X_eval_updated matches X_updated
X_eval_updated = X_eval_updated[train_cols_updated]

# Define the best hyperparameters found from the grid search
best_params = {'max_depth': None, 'min_samples_leaf': 5, 'min_samples_split': 5, 'n_estimators': 100}

# Retrain the model with the correctly encoded and aligned training data and best hyperparameters
rsf_updated = RandomSurvivalForest(
    n_estimators=best_params['n_estimators'],
    min_samples_split=best_params['min_samples_split'],
    min_samples_leaf=best_params['min_samples_leaf'],
    max_depth=best_params['max_depth'],
    n_jobs=-1,
    random_state=random_state
)
rsf_updated.fit(X_updated, y_updated)

# Predict the risk scores using the retrained model
predictions_updated = rsf_updated.predict(X_eval_updated)

# Create a pandas DataFrame for the submission file
submission_df_updated = pd.DataFrame({'ID': clin_eval['ID'], 'risk_score': predictions_updated})

# Save the submission DataFrame to a CSV file
submission_df_updated.to_csv("submission_improved_cytogenetics.csv", index=False)

display(submission_df_updated.head())

Unnamed: 0,ID,risk_score
0,KYW1,1048.591981
1,KYW2,1078.031844
2,KYW3,680.346974
3,KYW4,1075.803663
4,KYW5,953.334841


In [20]:
# Calculate C-index for Random Survival Forest on the validation set
risk_scores_rsf_val = rsf_updated.predict(X_val_updated)

# Ensure event times and statuses are aligned with the validation data
y_val_rsf = y_val_updated

eval_times_rsf_val = np.unique(y_val_rsf["OS_YEARS"][y_val_rsf["OS_STATUS"]])
if len(eval_times_rsf_val) == 0:
    c_index_rsf_val = np.nan
else:
    tau_rsf = eval_times_rsf_val[-1]
    c_index_rsf_val, _, _, _, _ = concordance_index_ipcw(
        y_train_updated, y_val_rsf, risk_scores_rsf_val, tau=tau_rsf
    )

print(f"IPCW-C-Index (Random Survival Forest) on the validation set: {c_index_rsf_val:.4f}")
print(f"IPCW-C-Index (XGBoost Cox) on the validation set: {c_index_updated:.4f}")

# Compare the C-indices
if c_index_rsf_val > c_index_updated:
    print("\nRandom Survival Forest performed better on the validation set.")
elif c_index_updated > c_index_rsf_val:
    print("\nXGBoost Cox performed better on the validation set.")
else:
    print("\nBoth models performed equally on the validation set.")

IPCW-C-Index (Random Survival Forest) on the validation set: 0.8006
IPCW-C-Index (XGBoost Cox) on the validation set: 0.6995

Random Survival Forest performed better on the validation set.
