In [62]:
# ===============================
# 0. Imports & hypothèses de base
# ===============================
import numpy as np
import pandas as pd

from lifelines import CoxPHFitter
from sklearn.model_selection import KFold, StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.base import clone

from sksurv.util import Surv
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import (
    concordance_index_censored,
    integrated_brier_score,
    cumulative_dynamic_auc,
)

import matplotlib.pyplot as plt


In [63]:
df = pd.read_csv("../../data/train_enhanced.csv")
df = df.drop(columns=["ID"])

time_col = "OS_YEARS"
event_col = "OS_STATUS"

cytogenetics_features = [
    # statut global / qualité
    'is_cyto_missing_or_failed',
    'is_normal_karyotype',
    'is_abnormal_karyotype',

    # complexité / volume d’anomalies
    'has_any_abnormality',
    'n_events',
    'n_chromosomes_altered',
    'n_monosomies_total',
    'n_trisomies_total',
    'n_structural_events_total',

    # anomalies spécifiques défavorables / favorables
    'has_minus5_or_del5q',
    'has_minus7_or_del7q',
    'has_plus8',
    'has_t_8_21',
    'has_inv16_or_t_16_16',
    'has_t_15_17',
    'has_inv3_or_t3_3',
    'has_t_6_9',
    'has_t_9_22',
    'has_abn17p',

    # MK / complexe
    'is_monosomal_karyotype',
    'is_complex_karyotype',

    # résumé type ELN-like
    'eln_like_flag_adverse_cyto',
    'eln_like_flag_intermediate_cyto',
    'eln_like_risk_cyto',

    # ploidie
    'baseline_chr_count',
    'is_hypodiploid',
    'is_hyperdiploid',

    # clonalité
    'total_metaphases',
    'max_clone_size',
    'max_adverse_clone_size',

    # proportions clonales
    'prop_any_abnormal',
    'prop_adverse_5_7',
    'prop_plus8',

    # NEW
    'n_autosomal_monosomies',
    'n_autosomal_trisomies',
    'worst_clone_events',
    'worst_clone_is_adverse',
]

all_features = [c for c in df.columns if c not in [time_col, event_col]]
clinical_mutation_features = sorted(list(set(all_features) - set(cytogenetics_features)))


## Cox univariate analysis

In [64]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
from lifelines.statistics import logrank_test
import numpy as np
import pandas as pd

def run_univariate_cox(
    df: pd.DataFrame,
    features: list,
    time_col: str,
    event_col: str,
    min_events: int = 10,
    penalizer: float = 0.1,
) -> pd.DataFrame:
    """
    Pour chaque feature dans `features`, fit un Cox PH univarié.
    Retourne HR, IC 95 %, p-value, C-index.
    - utilise un penalizer (ridge) pour stabiliser les fit
    - fallback log-rank pour les variables binaires qui ne convergent pas
    """
    results = []

    for feat in features:
        sub = df[[time_col, event_col, feat]].copy().dropna()

        # nb d'événements total sur ce sous-échantillon
        n_events = int(sub[event_col].sum())
        if n_events < min_events:
            # trop peu d'événements pour quelque chose de stable
            continue

        # détecter si la feature est binaire (0/1 ou deux valeurs distinctes)
        unique_vals = sub[feat].unique()
        is_binary = (len(unique_vals) == 2)

        # essayer d'abord un CoxPH avec pénalisation
        try:
            cph = CoxPHFitter(penalizer=penalizer)
            cph.fit(sub, duration_col=time_col, event_col=event_col)

            summary = cph.summary.loc[feat]

            # lifelines donne directement les HR et IC exponentiés
            hr = summary["exp(coef)"]
            ci_lower = summary["exp(coef) lower 95%"]
            ci_upper = summary["exp(coef) upper 95%"]
            p_value = summary["p"]

            c_index = cph.concordance_index_

            results.append(
                {
                    "feature": feat,
                    "HR": float(hr),
                    "CI_lower_95": float(ci_lower),
                    "CI_upper_95": float(ci_upper),
                    "p_value": float(p_value),
                    "c_index": float(c_index),
                    "n": len(sub),
                    "n_events": n_events,
                    "method": "cox",
                }
            )
            continue  # on passe à la feature suivante si tout s'est bien passé

        except ConvergenceError as e:
            print(f"[WARN] Cox convergence failed for {feat}: {e}")
            # on tente un fallback seulement pour les variables binaires
            if not is_binary:
                continue
        except Exception as e:
            print(f"[WARN] Could not fit Cox for {feat}: {e}")
            if not is_binary:
                continue

        # ================================
        # Fallback : test de log-rank binaire
        # ================================
        # On suppose ici que les valeurs sont 0 / 1 (sinon on mappe)
        vals = sorted(unique_vals)
        v0, v1 = vals[0], vals[1]

        grp0 = sub[sub[feat] == v0]
        grp1 = sub[sub[feat] == v1]

        # il faut des événements dans *les deux* groupes pour que ce soit informatif
        if grp0[event_col].sum() == 0 or grp1[event_col].sum() == 0:
            # séparation complète : HR non estimable proprement
            # mais tu peux quand même garder l'info : direction suggérée mais instable
            results.append(
                {
                    "feature": feat,
                    "HR": np.nan,
                    "CI_lower_95": np.nan,
                    "CI_upper_95": np.nan,
                    "p_value": np.nan,
                    "c_index": np.nan,
                    "n": len(sub),
                    "n_events": n_events,
                    "method": "failed_separation",
                }
            )
            continue

        # log-rank test
        lr = logrank_test(
            grp0[time_col],
            grp1[time_col],
            event_observed_A=grp0[event_col],
            event_observed_B=grp1[event_col],
        )

        # HR approximatif = ratio des taux d'événements par unité de temps
        rate0 = grp0[event_col].sum() / grp0[time_col].sum()
        rate1 = grp1[event_col].sum() / grp1[time_col].sum()
        hr_approx = rate1 / rate0 if rate0 > 0 else np.nan

        # C-index simple basé sur le score binaire (0/1)
        # -> plus 1 = plus à risque
        # on utilise la définition classique : fraction de paires concordantes
        # lifelines a une fonction utilitaire, mais on peut faire simple :
        from lifelines.utils import concordance_index

        c_index_bin = concordance_index(
            sub[time_col].values,
            -sub[feat].values,  # plus grand score = plus faible survie
            sub[event_col].values,
        )

        results.append(
            {
                "feature": feat,
                "HR": float(hr_approx),
                "CI_lower_95": np.nan,  # pas trivial en fallback → on laisse NaN
                "CI_upper_95": np.nan,
                "p_value": float(lr.p_value),
                "c_index": float(c_index_bin),
                "n": len(sub),
                "n_events": n_events,
                "method": "logrank_fallback",
            }
        )

    res_df = pd.DataFrame(results)
    if not res_df.empty:
        # trier par p-value, en mettant les NaN à la fin
        res_df = res_df.sort_values("p_value", na_position="last")
    return res_df


univ_results = run_univariate_cox(
    df=df,
    features=cytogenetics_features,
    time_col=time_col,
    event_col=event_col,
    min_events=10,
    penalizer=0.1,   # tu peux jouer avec 0.01 / 0.1 / 1
)

print(univ_results.head(10))



  return (X - mean) / std

  return (X - mean) / std


[WARN] Cox convergence failed for has_t_8_21: Convergence halted due to matrix inversion problems. Suspicion is high collinearity. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-modelMatrix is singular.
[WARN] Cox convergence failed for has_inv16_or_t_16_16: Convergence halted due to matrix inversion problems. Suspicion is high collinearity. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-modelMatrix is singular.



  return (X - mean) / std


[WARN] Cox convergence failed for has_t_9_22: Convergence halted due to matrix inversion problems. Suspicion is high collinearity. Please see the following tips in the lifelines documentation: https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-modelMatrix is singular.
                       feature        HR  CI_lower_95  CI_upper_95  \
32          worst_clone_events  1.166563     1.146919     1.186542   
5        n_chromosomes_altered  1.217920     1.191417     1.245012   
4                     n_events  1.132965     1.117225     1.148927   
30      n_autosomal_monosomies  1.286198     1.244992     1.328768   
17        is_complex_karyotype  3.147919     2.712936     3.652646   
8    n_structural_events_total  1.179770     1.153909     1.206210   
16      is_monosomal_karyotype  3.269894     2.787992     3.835093   
6           n_monosomies_total  1.269493     1.227778     1.312624   
10         has_minus7_or_del7q  2.6987

In [65]:
# ===============================
# 2. Préparation données survie
# ===============================

y = Surv.from_dataframe(event=event_col, time=time_col, data=df)

X_clinical = df[clinical_mutation_features].copy()
X_clinical_plus_cyto = df[clinical_mutation_features + cytogenetics_features].copy()

In [None]:
from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold, KFold, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import (
    concordance_index_censored,
    integrated_brier_score,
    cumulative_dynamic_auc,
)
import numpy as np
import pandas as pd


def make_cindex_scorer(event_field, time_field):
    """
    Scorer sklearn pour C-index, compatible avec scikit-survival.
    event_field et time_field viennent de y.dtype.names (ex: 'OS_STATUS', 'OS_YEARS').
    """
    def cindex_scorer(estimator, X, y):
        est = clone(estimator)
        est.fit(X, y)
        risk_scores = est.predict(X)
        event = y[event_field].astype(bool)
        time = y[time_field]
        cindex = concordance_index_censored(event, time, risk_scores)[0]
        return cindex

    return cindex_scorer


def nested_cv_survival(
    X: pd.DataFrame,
    y,
    times_auc=(1.0, 2.0, 3.0),   # OS_YEARS en années
    n_splits_outer=5,
    n_splits_inner=3,
    random_state=42,
):
    """
    Nested CV pour CoxnetSurvivalAnalysis, avec:
    - standardisation
    - chemin d'alphas "safe"
    - fit_baseline_model=True pour pouvoir appeler predict_survival_function.
    """

    # noms de champs dans y, par ex. ('OS_STATUS', 'OS_YEARS')
    event_field, time_field = y.dtype.names

    # chemin d'alphas "safe" : de 0.1 à 100, suffisamment régularisé
    safe_alphas = np.logspace(-1, 2, 40)  # 0.1 ... 100

    # pipeline : standardisation + Coxnet
    pipe = Pipeline(
        [
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            ("coxnet", CoxnetSurvivalAnalysis(
                alphas=safe_alphas,
                l1_ratio=0.5,
                fit_baseline_model=True,   # IMPORTANT pour predict_survival_function
            )),
        ]
    )

    param_grid = {
        "coxnet__l1_ratio": [0.1, 0.5, 0.9],
    }

    # Stratification sur l'indicateur d'événement réel (OS_STATUS)
    event_indicator = y[event_field].astype(int)

    outer_cv = StratifiedKFold(
        n_splits=n_splits_outer,
        shuffle=True,
        random_state=random_state,
    )

    fold_results = {
        "c_index": [],
        "ibs": [],
        "auc_per_time": [],
        "times_auc": np.array(times_auc, dtype=float),
    }

    scorer = make_cindex_scorer(event_field, time_field)

    for fold_idx, (train_idx, test_idx) in enumerate(
        outer_cv.split(X, event_indicator)
    ):
        print(f"Outer fold {fold_idx + 1}/{n_splits_outer}")

        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        inner_cv = KFold(
            n_splits=n_splits_inner,
            shuffle=True,
            random_state=random_state,
        )

        gs = GridSearchCV(
            estimator=pipe,
            param_grid=param_grid,
            scoring=scorer,
            cv=inner_cv,
            n_jobs=-1,
            # pendant le debug tu peux mettre error_score="raise"
        )
        gs.fit(X_train, y_train)

        best_model = gs.best_estimator_
        best_model.fit(X_train, y_train)

        # --- C-index sur le fold test ---
        risk_scores_test = best_model.predict(X_test)
        cindex_test = concordance_index_censored(
            y_test[event_field].astype(bool),
            y_test[time_field],
            risk_scores_test,
        )[0]

         # --- IBS ---
        times_grid = np.linspace(
            np.percentile(y_train[time_field], 5),
            np.percentile(y_train[time_field], 95),
            50,
        )

        surv_funcs_test = best_model.predict_survival_function(X_test)
        pred_surv_test = np.vstack(
            [fn(times_grid) for fn in surv_funcs_test]
        )

        # contrainte sksurv : les temps du test doivent être <= max des temps du train
        max_time_train = y_train[time_field].max()

        # on coupe la grille de temps pour l'IBS
        times_grid_ib = times_grid[times_grid <= max_time_train]

        # on filtre les patients de test avec un temps > max_time_train
        mask_ib = y_test[time_field] <= max_time_train
        y_test_ib = y_test[mask_ib]
        pred_surv_test_ib = pred_surv_test[mask_ib, :]

        # si trop peu de points, on met NaN
        if (len(times_grid_ib) < 2) or (len(y_test_ib) == 0):
            ibs = np.nan
        else:
            ibs = integrated_brier_score(
                y_train,
                y_test_ib,
                pred_surv_test_ib,
                times_grid_ib,
            )

        # --- AUC(t) ---
        _, auc_values = cumulative_dynamic_auc(
            y_train,
            y_test,
            risk_scores_test,
            np.array(times_auc, dtype=float),
        )

        fold_results["c_index"].append(cindex_test)
        fold_results["ibs"].append(ibs)
        fold_results["auc_per_time"].append(auc_values)

    return fold_results


In [67]:
# ===============================
# 4. Lancer les évaluations
# ===============================

import warnings
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="`row_stack` alias is deprecated"
)

# OS_YEARS est en années -> 1, 2, 3 ans
times_auc = (1.0, 2.0, 3.0)

print("=== Nested CV: modèle SANS cytogénétique ===")
results_no_cyto = nested_cv_survival(
    X=X_clinical,
    y=y,
    times_auc=times_auc,
    n_splits_outer=5,
    n_splits_inner=3,
    random_state=42,
)

print("\n=== Nested CV: modèle AVEC cytogénétique ===")
results_with_cyto = nested_cv_survival(
    X=X_clinical_plus_cyto,
    y=y,
    times_auc=times_auc,
    n_splits_outer=5,
    n_splits_inner=3,
    random_state=42,
)


=== Nested CV: modèle SANS cytogénétique ===
Outer fold 1/5
Outer fold 2/5
Outer fold 3/5
Outer fold 4/5
Outer fold 5/5

=== Nested CV: modèle AVEC cytogénétique ===
Outer fold 1/5
Outer fold 2/5
Outer fold 3/5
Outer fold 4/5
Outer fold 5/5


In [70]:
# ===============================
# 5. Résumé des performances + 6. Deltas
# ===============================

def summarize_nested_results(results, label="model"):
    cidx = np.array(results["c_index"])
    ibs = np.array(results["ibs"])
    auc_list = results["auc_per_time"]  # liste d'objets (arrays / scalaires)
    times_all = np.array(results["times_auc"], dtype=float)

    # C-index et IBS (en gérant les NaN potentiels pour IBS)
    summary = {
        "label": label,
        "c_index_mean": float(np.nanmean(cidx)),
        "c_index_std": float(np.nanstd(cidx)),
        "ibs_mean": float(np.nanmean(ibs)),
        "ibs_std": float(np.nanstd(ibs)),
    }

    # Si on n'a pas d'AUC du tout, on s'arrête là
    if len(auc_list) == 0:
        return pd.Series(summary)

    # Normaliser chaque entrée en array 1D au minimum
    auc_arrays = []
    for a in auc_list:
        arr = np.atleast_1d(np.array(a, dtype=float))  # scalaires -> shape (1,)
        auc_arrays.append(arr)

    # Longueur minimale commune des vecteurs AUC(t)
    min_len = min(arr.shape[0] for arr in auc_arrays)

    # S'il n'y a aucune dimension exploitable, on retourne juste C-index/IBS
    if min_len == 0:
        return pd.Series(summary)

    # Tronquer chaque vecteur à min_len
    auc_per_time = np.vstack([arr[:min_len] for arr in auc_arrays])

    # Tronquer aussi les temps pour garder la cohérence
    times_used = times_all[:min_len]

    # AUC(t) : moyenne et std par temps (sur les temps effectivement dispo)
    for i, t in enumerate(times_used):
        summary[f"auc_{t:.2f}_mean"] = float(np.nanmean(auc_per_time[:, i]))
        summary[f"auc_{t:.2f}_std"] = float(np.nanstd(auc_per_time[:, i]))

    return pd.Series(summary)


# 5) Résumés
summary_no_cyto = summarize_nested_results(results_no_cyto, label="no_cyto")
summary_with_cyto = summarize_nested_results(results_with_cyto, label="with_cyto")

comparison_df = pd.concat([summary_no_cyto, summary_with_cyto], axis=1)
print("\n=== Résumé performances (nested CV) ===")
print(comparison_df)


# 6) Δ C-index / Δ IBS / Δ AUC
delta = pd.Series(
    {
        "delta_c_index": summary_with_cyto["c_index_mean"] - summary_no_cyto["c_index_mean"],
        "delta_ibs": summary_with_cyto["ibs_mean"] - summary_no_cyto["ibs_mean"],
    }
)

# Colonnes AUC_*_mean communes aux deux modèles
auc_cols_no   = {c for c in summary_no_cyto.index   if c.startswith("auc_") and c.endswith("_mean")}
auc_cols_with = {c for c in summary_with_cyto.index if c.startswith("auc_") and c.endswith("_mean")}
common_auc_cols = sorted(auc_cols_no & auc_cols_with)

for col in common_auc_cols:
    # ex: col = "auc_1.00_mean" -> nom de delta = "delta_auc_1.00"
    time_label = col.replace("auc_", "").replace("_mean", "")
    delta_name = f"delta_auc_{time_label}"
    delta[delta_name] = summary_with_cyto[col] - summary_no_cyto[col]

print("\n=== Deltas (with_cyto - no_cyto) ===")
print(delta)



=== Résumé performances (nested CV) ===
                      0          1
label           no_cyto  with_cyto
c_index_mean   0.741184   0.742249
c_index_std    0.004154   0.004272
ibs_mean       0.160939   0.160592
ibs_std        0.009092   0.009285
auc_1.00_mean  0.794707    0.79562
auc_1.00_std    0.01011    0.01037

=== Deltas (with_cyto - no_cyto) ===
delta_c_index     0.001066
delta_ibs        -0.000347
delta_auc_1.00    0.000913
dtype: float64


In [74]:
# ===============================
# 7. Calibration plot (Plotly)
# ===============================

import numpy as np
from sksurv.nonparametric import kaplan_meier_estimator
import plotly.graph_objects as go

calib_time = 2.0  # OS_YEARS --> années
event_field, time_field = y.dtype.names

# Probabilité de survie prédite
surv_funcs_all = best_final.predict_survival_function(X_clinical_plus_cyto)
surv_at_calib = np.array([fn(calib_time) for fn in surv_funcs_all])

# Binning en déciles
n_bins = 10
quantiles = np.quantile(surv_at_calib, np.linspace(0, 1, n_bins + 1))
bin_ids = np.digitize(surv_at_calib, quantiles[1:-1], right=True)

bin_pred_surv = []
bin_obs_surv = []

for b in range(n_bins):
    mask = bin_ids == b
    if mask.sum() < 5:
        continue

    # survie prédite moyenne
    bin_pred_surv.append(surv_at_calib[mask].mean())

    # survie observée (Kaplan–Meier)
    t, s = kaplan_meier_estimator(
        y[event_field][mask],
        y[time_field][mask],
    )

    obs = s[t <= calib_time][-1] if np.any(t <= calib_time) else 1.0
    bin_obs_surv.append(obs)

bin_pred_surv = np.array(bin_pred_surv)
bin_obs_surv = np.array(bin_obs_surv)

# === Plotly ===
fig = go.Figure()

# Points calibration
fig.add_trace(go.Scatter(
    x=bin_pred_surv,
    y=bin_obs_surv,
    mode="markers+lines",
    name="Calibration",
    marker=dict(size=8),
))

# Ligne parfaite
fig.add_trace(go.Scatter(
    x=[0, 1],
    y=[0, 1],
    mode="lines",
    name="Perfect calibration",
    line=dict(dash="dash")
))

fig.update_layout(
    width=650,
    height=600,
    title=f"Calibration curve at t = {calib_time:.1f} years",
    xaxis_title="Predicted survival probability",
    yaxis_title="Observed survival (KM)",
)

fig.show()
