# Imports

In [2]:
from src.utilities import predict_and_save, split_data, get_method_name, score_method
from src.preprocess import process_missing_values, main_preprocess, create_entity
from sklearn.model_selection import train_test_split
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import GradientBoostingSurvivalAnalysis, RandomSurvivalForest
import lightgbm as lgb
import pandas as pd

import warnings
import logging

# Régler le logger de Featuretools au niveau ERROR
logging.getLogger('featuretools.entityset').setLevel(logging.ERROR)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore",message=".*Ill-conditioned matrix.*")

# Settings

In [3]:
# Per-model switches read by the model cells below:
#   "run"  -> fit and score the model,
#   "save" -> write predictions on X_eval via predict_and_save,
#   "shap" -> build a SHAP HTML report under report/shap/.
# NOTE(review): the LightGBM cell is commented out, so GLOBAL["lgbm"] and
# PARAMS["lgbm"] are not read by any active cell in this notebook.
GLOBAL = {
    "cox": {"run": True, "save": False, "shap": False},
    "xgb": {"run": True, "save": True, "shap": False},
    "lgbm": {"run": True, "save": False, "shap": False},
    "rsf": {"run": True, "save": False, "shap": False}
}

# PARAMS is passed to create_entity / main_preprocess and to get_method_name;
# the model sub-dicts are unpacked as keyword arguments of the estimators.
# Key order matters: the "xgb" and "rsf" dicts are serialized in order into
# output file names.
PARAMS = {
    # Fraction of the labelled data used for training (the rest is the test split).
    "size": 0.7,
    # Possible: ["CYTOGENETICS", "HB/PLT", "logMONOCYTES", "logWBC", "logANC"],
    #           ["BM_BLAST+WBC", "BM_BLAST/HB", "HB*PLT", "HB/num_trisomies"]
    "clinical": ["CYTOGENETICS"],
    # Possible: ["GENE", "EFFECT", "ALT", "REF", "END-START"]
    "molecular": [],
    # Possible: ["featuretools", "gpt"]
    "merge": [],
    # Extra annotation features, presumably key paths into CADD / SnpEff
    # annotations (confirm in src.preprocess). Uncomment an entry to enable it.
    "additional": [
        # ['cadd', 'phred'],
        # ['cadd', 'rawscore'],
        # # ['cadd', 'consequence'],
        # # ['cadd', 'bstatistic'],
        # # ['cadd', 'gerp', 'n'],
        # ['cadd', 'phast_cons', 'mammalian'],
        # ['cadd', 'phylop', 'mammalian'],
        # ['snpeff', 'putative_impact'],
        # # ['snpeff', 'rank'],
        # # ['snpeff', 'total'],
        # ['cadd', 'exon'],
        # # ['cadd', 'cds', 'rel_cds_pos']
    ],
    # Keyword arguments for GradientBoostingSurvivalAnalysis.
    "xgb": {
        'max_depth': 2,
        'learning_rate': 0.05,
        'n_estimators': 450,
        'subsample': 0.55,
        'max_features': 'sqrt',
        'random_state': 26
    },
    # Parameters for lgb.train (only used by the disabled LightGBM cell).
    "lgbm": {
        'max_depth': 2,
        'learning_rate': 0.05,
        'verbose': 0
    },
    # Keyword arguments for RandomSurvivalForest (random_state is passed separately).
    "rsf": {
        'n_estimators': 300,  # number of trees in the forest
        'max_depth': 2,
        # 'min_samples_split': 60,  # minimum number of samples required to split a node
        # 'min_samples_leaf': 40,   # minimum number of samples per leaf
        'max_features': None,  # None = consider all features at every split (no random feature subsampling)
        'n_jobs': -1,  # use all available cores
    }
}

##############################################
# Method-name tags for this experiment
##############################################

# One tag per experiment setting, built by get_method_name from PARAMS.
# They are embedded in the names of saved predictions and SHAP reports below.
size_method, clinical_method, molecular_method, merge_method = (
    get_method_name(setting, PARAMS)
    for setting in ("size", "clinical", "molecular", "merge")
)

# Preprocessing, missing-value handling and train/test split

In [4]:
# Build and preprocess the dataset, then separate it into labelled features X,
# evaluation features X_eval (only used for predictions below) and targets y.
data = create_entity(PARAMS)
data = main_preprocess(data, PARAMS)
X, X_eval, y = split_data(data)

# Pass the training fraction directly as train_size. Writing
# test_size=1 - PARAMS['size'] gives 0.30000000000000004 for size=0.7, and
# sklearn rounds a float test_size up (ceil). So whenever 0.3 * n_samples is a
# whole number, the test set would get one sample too many.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=PARAMS['size'], random_state=42)

# Median imputation, applied after the split (presumably fit on X_train only —
# confirm in src.preprocess.process_missing_values).
X_train, X_test, X_eval = process_missing_values(X_train, X_test, X_eval, method="impute", strategy="median")

# EDA

In [None]:
from src.report import EDAReport

# One frame with the training features and the survival targets for the EDA report.
# Both parts are re-indexed 0..n-1 before the column-wise concat. pd.concat(axis=1)
# aligns rows by index label, not by position. If X_train is a DataFrame, it still
# carries the shuffled index left by train_test_split, while the target frame gets a
# fresh RangeIndex. Without the reset, rows would be mismatched and padded with NaN.
features_train = pd.DataFrame(X_train, columns=X.columns).reset_index(drop=True)
targets_train = pd.DataFrame(y_train, columns=["event", "time"]).reset_index(drop=True)
df_analyze = pd.concat([features_train, targets_train], axis=1)

# Cast boolean columns to 0/1 integers for the report.
bool_cols = df_analyze.select_dtypes(include=['bool']).columns
df_analyze[bool_cols] = df_analyze[bool_cols].astype(int)

ede = EDAReport(df_analyze, target_variables=["event", "time"])
ede.generate_report()
ede.display()

# Models

## Fit a CoxPH model

In [6]:
if GLOBAL["cox"]["run"]:
    # Initialize and train the Cox Proportional Hazards model.
    cox = CoxPHSurvivalAnalysis()
    cox.fit(X_train, y_train)
    # score_method prints the train/test IPCW concordance (see the output below)
    # and returns a tag that is embedded in the output file names.
    cox_score_method = score_method(cox, X_train, X_test, y_train, y_test)
    cox_method = f"{size_method}-{cox_score_method}-{clinical_method}-{molecular_method}-{merge_method}"

    # Predict on the evaluation set and save the results.
    if GLOBAL["cox"]["save"]:
        predict_and_save(X_eval, cox, method=cox_method)

    if GLOBAL["cox"]["shap"]:
        from src.report import ShapReport

        # Build a DataFrame copy just for the report instead of rebinding X_train.
        # Rebinding it would silently change the input type of every model fitted
        # in later cells, depending on whether this branch ran.
        if isinstance(X_train, pd.DataFrame):
            X_train_shap = X_train
        else:
            X_train_shap = pd.DataFrame(X_train, columns=X.columns)

        report_cox = ShapReport(model=cox, X_train=X_train_shap, predict_function=cox.predict)
        report_cox.generate_report(output_html=f"report/shap/shap_{cox.__class__.__name__}_{cox_method}.html")


CoxPHSurvivalAnalysis Model Concordance Index IPCW on train: 0.684
CoxPHSurvivalAnalysis Model Concordance Index IPCW on test: 0.671


## Fit a LightGBM model

In [6]:
# NOTE: this LightGBM experiment is disabled and kept for reference only.
# Before re-enabling it:
#   - import `concordance_index_ipcw` (sksurv.metrics), which is not imported above;
#   - be aware that regressing on `time` alone ignores the censoring indicator
#     (hence the negated prediction used as a risk score below);
#   - gate it on GLOBAL["lgbm"]["run"], like the other model cells.
#
# X_train_lgb = X_train  # Features for training
# y_train_transformed = y_train['time']

# # Create LightGBM dataset
# train_dataset = lgb.Dataset(X_train_lgb, label=y_train_transformed)

# # Train the LightGBM model
# model = lgb.train(params=PARAMS['lgbm'], train_set=train_dataset)

# # Evaluate the model using Concordance Index IPCW
# train_ci_ipcw = concordance_index_ipcw(y_train, y_train, -model.predict(X_train), tau=7)[0]
# test_ci_ipcw = concordance_index_ipcw(y_train, y_test, -model.predict(X_test), tau=7)[0]
# print(f"LightGBM Survival Model Concordance Index IPCW on train: {train_ci_ipcw:.3f}")
# print(f"LightGBM Survival Model Concordance Index IPCW on test: {test_ci_ipcw:.3f}")
# lightgbm_score_method = f"score_{train_ci_ipcw:.3f}_{test_ci_ipcw:.3f}"

# # Predict and save the results
# if GLOBAL["lgbm"]["save"]:
#     predict_and_save(X_eval, model, method=f"{size_method}-{lightgbm_score_method}-{clinical_method}-{molecular_method}-{merge_method}-{PARAMS['lgbm']['max_depth']}_lr{PARAMS['lgbm']['learning_rate']}")


## Fit a Gradient Boosting Survival Analysis Model

In [7]:
if GLOBAL["xgb"]["run"]:
    # Hyper-parameters rendered as "key=value" pairs joined with "_" (in dict order)
    # so they can be embedded in output file names.
    xgb_params_method = "_".join(str(key) + "=" + str(value) for key, value in PARAMS['xgb'].items())

    xgb = GradientBoostingSurvivalAnalysis(**PARAMS['xgb'])
    xgb.fit(X_train, y_train)
    xgboost_score_method = score_method(xgb, X_train, X_test, y_train, y_test)
    # Same naming scheme as the Cox and RSF cells, including the clinical-feature tag.
    xgb_method = f"{size_method}-{xgboost_score_method}-{clinical_method}-{molecular_method}-{merge_method}-{xgb_params_method}"

    # Predict on the evaluation set and save the results.
    if GLOBAL["xgb"]["save"]:
        predict_and_save(X_eval, xgb, method=xgb_method)

    if GLOBAL["xgb"]["shap"]:
        from src.report import ShapReport

        # Build a DataFrame copy just for the report instead of rebinding X_train.
        # Rebinding would change the input type of models fitted in later cells.
        if isinstance(X_train, pd.DataFrame):
            X_train_shap = X_train
        else:
            X_train_shap = pd.DataFrame(X_train, columns=X.columns)

        report_xgb = ShapReport(model=xgb, X_train=X_train_shap, predict_function=xgb.predict)
        report_xgb.generate_report(output_html=f"report/shap/shap_{xgb.__class__.__name__}_{xgb_method}.html")


GradientBoostingSurvivalAnalysis Model Concordance Index IPCW on train: 0.723
GradientBoostingSurvivalAnalysis Model Concordance Index IPCW on test: 0.680


## Fit a Random Survival Forest model

In [8]:

if GLOBAL["rsf"]["run"]:
    # Hyper-parameters rendered as "key=value" pairs joined with "_" (in dict order)
    # so they can be embedded in output file names.
    rsf_params_method = "_".join(str(key) + "=" + str(value) for key, value in PARAMS['rsf'].items())

    rsf = RandomSurvivalForest(**PARAMS["rsf"], random_state=42)
    rsf.fit(X_train, y_train)
    rsf_score_method = score_method(rsf, X_train, X_test, y_train, y_test)
    rsf_method = f"{size_method}-{rsf_score_method}-{clinical_method}-{molecular_method}-{merge_method}-{rsf_params_method}"

    # Predict on the evaluation set and save the results.
    if GLOBAL["rsf"]["save"]:
        predict_and_save(X_eval, rsf, method=rsf_method)

    if GLOBAL["rsf"]["shap"]:
        from src.report import ShapReport

        # Build a DataFrame copy just for the report instead of rebinding X_train.
        # Rebinding would change the input type of models fitted in later cells.
        if isinstance(X_train, pd.DataFrame):
            X_train_shap = X_train
        else:
            X_train_shap = pd.DataFrame(X_train, columns=X.columns)

        report_rsf = ShapReport(model=rsf, X_train=X_train_shap, predict_function=rsf.predict)
        report_rsf.generate_report(output_html=f"report/shap/shap_{rsf.__class__.__name__}_{rsf_method}.html")


RandomSurvivalForest Model Concordance Index IPCW on train: 0.668
RandomSurvivalForest Model Concordance Index IPCW on test: 0.661


## SHAP report (gradient boosting as an example)

In [None]:
from src.report import ShapReport

# SHAP report for the gradient-boosting model. It needs `xgb`, `xgboost_score_method`
# and `xgb_params_method` from the gradient-boosting cell, so it is gated on the same switch.
# NOTE(review): the name `shap` would shadow the `shap` package if it were imported here.
if GLOBAL["xgb"]["run"]:
    shap = ShapReport(model=xgb, X_train=pd.DataFrame(data=X_train, columns=X.columns), predict_function=xgb.predict)
    # `output_html` is the keyword used by every other ShapReport.generate_report call in
    # this notebook. The name includes clinical_method, like the other output names.
    shap.generate_report(output_html=f"report/shap/shap_{xgb.__class__.__name__}_{size_method}-{xgboost_score_method}-{clinical_method}-{molecular_method}-{merge_method}-{xgb_params_method}.html")
    shap.display()

# Deep

## Second preprocessing pass for the deep models

In [11]:
from src.deep import convert_float32, convert_survival_data, score_method_deep

# Feature matrices for the torch model; convert_float32 presumably casts them to
# float32 arrays (confirm in src.deep).
X_train_deep, X_test_deep, X_eval_deep = convert_float32(X_train, X_test, X_eval)

# Survival targets in the format the pycox model is fit on
# (see src.deep.convert_survival_data); converted train first, then test.
y_train_deep, y_test_deep = map(convert_survival_data, (y_train, y_test))

## Parameters

In [12]:
# Architecture of the MLP built by tt.practical.MLPVanilla.
params = dict(
    num_nodes=[128, 128, 23],           # hidden-layer widths (increased model capacity)
    out_features=1,
    batch_norm=True,
    dropout=0.0,                        # can be tested with or without dropout
    output_bias=False,
    in_features=X_train_deep.shape[1],  # number of input feature columns
)

# Training settings for the deep Cox model.
params_CoxHP = dict(
    batch_size=256,
    lr=0.0005,     # slightly reduced learning rate
    epochs=512,    # increased number of epochs
    verbose=False,
)

## Fit a deep Cox proportional-hazards model (pycox CoxPH)

In [None]:
from sklearn.preprocessing import StandardScaler
import torch
import torchtuples as tt
from pycox.models import CoxPH
import numpy as np

# Seed numpy and torch for reproducible weight initialisation and batch shuffling.
np.random.seed(42)
_ = torch.manual_seed(4)

# The network input width must match the number of feature columns.
params['in_features'] = X_train_deep.shape[1]

net = tt.practical.MLPVanilla(**params)
model = CoxPH(net, tt.optim.Adam)

batch_size = params_CoxHP['batch_size']

# Learning-rate range test, kept for inspection (plot only). The finder sweeps the
# optimizer's learning rate, so a rate set before it does not survive. The
# configured rate is therefore applied after the sweep, as in the pycox examples.
lrfinder = model.lr_finder(X_train_deep, y_train_deep, batch_size, tolerance=10)
_ = lrfinder.plot()

model.optimizer.set_lr(params_CoxHP['lr'])

epochs = params_CoxHP['epochs']
callbacks = [tt.callbacks.EarlyStopping()]
verbose = params_CoxHP['verbose']

# NOTE(review): the test split is used both as val_data (early stopping) and for the
# scores below, so those test scores are optimistic.
log = model.fit(X_train_deep, y_train_deep, batch_size, epochs, callbacks, verbose,
                val_data=(X_test_deep, y_test_deep), val_batch_size=batch_size)

_ = log.plot()

print("LogLikehood: ", model.partial_log_likelihood(X_test_deep, y_test_deep).mean())

score_method_deep(model, X_train_deep, X_test_deep, y_train, y_test, reverse=False)

## Make a prediction

In [None]:
from src.deep import predict_and_save_deep

# Predict on the evaluation set with the deep Cox model trained above and save the
# results under the "deepCoxHP" method tag.
# NOTE(review): if predict_and_save_deep needs survival curves (predict_surv_df),
# pycox requires model.compute_baseline_hazards() to be called first — confirm.
predict_and_save_deep(
    X_eval_deep,
    model,
    method="deepCoxHP",
)