# Importation des librairies

In [19]:
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import os.path

from sklearn.tree import plot_tree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored , concordance_index_ipcw
from sklearn.impute import SimpleImputer
from sksurv.util import Surv
from lifelines.utils import concordance_index
from sklearn.preprocessing import StandardScaler


# Modèle MTLR

In [49]:
import numpy as np
import pandas as pd
import torch

from sklearn.model_selection import KFold

from torchmtlr import (
    MTLR,
    mtlr_neg_log_likelihood,
    mtlr_survival,
    mtlr_risk,
)
from torchmtlr.utils import make_time_bins, encode_survival

import optuna
from tqdm.auto import tqdm

# ====================
# 0. Device
# ====================
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device utilisé :", device)

# ====================
# 1. Data loading
# ====================
df_train = pd.read_csv("../data/train_pivot3.csv")
df_val = pd.read_csv("../data/eval_pivot3.csv")

# Rename the outcome columns to the "time"/"event" notation used below,
# then discard rows where either outcome value is missing.
df_train = (
    df_train
    .rename(columns={"OS_STATUS": "event", "OS_YEARS": "time"})
    .dropna(subset=["time", "event"])
)

# Negative survival times are floored at zero.
df_train["time"] = df_train["time"].clip(lower=0)

# The identifier is not a covariate: remove it when the column exists.
df_train = df_train.drop(columns=["ID"], errors="ignore")

# Every remaining column except the two outcome columns is a feature.
feature_cols = [c for c in df_train.columns if c != "time" and c != "event"]
print("Nombre de covariables :", len(feature_cols))

# ====================
# 2. Time bins
# ====================
time_bins = make_time_bins(
    df_train["time"].values,
    event=df_train["event"].values,
)
in_features = len(feature_cols)
num_time_bins = len(time_bins)

print("Nombre de time bins   :", num_time_bins)

# ====================
# 3. Helpers de normalisation
# ====================
def fit_normalizer(df, feature_cols):
    """Learn per-feature imputation and standardisation statistics.

    Infinite values are treated as missing. Missing values are imputed with
    the column median before the mean and standard deviation are computed.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing at least the columns listed in ``feature_cols``.
    feature_cols : list
        Names of the feature columns to summarise.

    Returns
    -------
    (med, mean, std) : tuple of pandas.Series
        Median (used for imputation), mean and standard deviation, each
        indexed by feature name. All three are guaranteed to be free of NaN,
        and ``std`` is strictly positive, so ``(X - mean) / std`` can never
        introduce NaN or inf on its own.
    """
    X = df[feature_cols].replace([np.inf, -np.inf], np.nan)

    # A column without a single finite value has a NaN median; fall back to 0
    # so that imputation cannot leave NaNs behind.
    med = X.median().fillna(0.0)
    X = X.fillna(med)

    # The mean is only NaN when there are no rows at all.
    mean = X.mean().fillna(0.0)

    # std is 0 for constant columns and NaN for a single-row frame (sample
    # std, ddof=1). Both would break the division, so use unit scale there.
    std = X.std()
    std = std.where(std > 0, 1.0)

    return med, mean, std

def apply_normalizer(df, feature_cols, med, mean, std):
    """Impute and standardise the feature columns of ``df``.

    Infinite values are first turned into NaN, missing values are filled with
    ``med``, and each feature is then centred with ``mean`` and scaled by
    ``std``. The input frame is left untouched; a copy with the transformed
    feature columns (and all other columns unchanged) is returned.
    """
    cleaned = df[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(med)
    centred = cleaned - mean

    result = df.copy()
    result[feature_cols] = centred / std
    return result

# ====================
# 4. Reset des paramètres du modèle
# ====================
def reset_parameters(model):
    """Re-initialise every parameter of ``model`` in place.

    Parameters with more than one dimension (weight matrices) receive a
    Xavier-uniform initialisation; all others (bias vectors) are zeroed.
    """
    for tensor in model.parameters():
        is_matrix = tensor.dim() > 1
        init_fn = torch.nn.init.xavier_uniform_ if is_matrix else torch.nn.init.zeros_
        init_fn(tensor)

def make_model():
    """Create a new MTLR model on ``device`` with re-initialised parameters.

    The model size comes from the module-level ``in_features`` and
    ``num_time_bins`` values.
    """
    net = MTLR(in_features=in_features, num_time_bins=num_time_bins)
    net = net.to(device)
    reset_parameters(net)
    return net

# ====================
# 5. Fonction d'entraînement MTLR (un dataset)
# ====================
def train_mtlr(
    model,
    df,
    time_bins,
    C1=1.0,
    num_epochs=100,
    lr=1e-3,
    batch_size=512,
    device=device,
    verbose=False
):
    """Train ``model`` on ``df`` with Adam and the MTLR negative log-likelihood.

    Parameters
    ----------
    model : the MTLR model to optimise (updated in place and returned).
    df : frame holding "time", "event" and the module-level ``feature_cols``.
    time_bins : bin edges forwarded to ``encode_survival``.
    C1 : regularisation strength forwarded to ``mtlr_neg_log_likelihood``.
    num_epochs : number of passes over the data.
    lr : Adam learning rate.
    batch_size : mini-batch size; the last batch of an epoch may be smaller.
    device : device on which the feature and target tensors are created.
    verbose : show a tqdm progress bar and report non-finite losses.

    Returns
    -------
    The same ``model`` object. If a batch loss is NaN/inf, training stops
    immediately (before that batch's optimisation step) and the model is
    returned in its current state.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Encoded survival targets and the feature matrix, both on `device`.
    targets = encode_survival(
        df["time"].values, df["event"].values, time_bins
    ).to(device)
    features = torch.tensor(
        df[feature_cols].values, dtype=torch.float32, device=device
    )

    n_samples = df.shape[0]
    n_batches = int(np.ceil(n_samples / batch_size))

    model.train()

    # Wrap the epoch range in a progress bar only when asked to be verbose.
    epochs = range(num_epochs)
    if verbose:
        epochs = tqdm(epochs, desc="Entraînement MTLR", leave=False)

    for epoch in epochs:
        # Fresh shuffle of the row order for every epoch.
        order = np.random.permutation(n_samples)
        running_loss = 0.0

        for start in range(0, n_batches * batch_size, batch_size):
            rows = order[start:start + batch_size]
            loss = mtlr_neg_log_likelihood(
                model(features[rows]), targets[rows], model, C1=C1
            )

            if not torch.isfinite(loss):
                # Report (when verbose) and give up on this model.
                if verbose:
                    print("⚠️ Loss non finie (NaN/inf) à l'epoch", epoch)
                return model

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean batch loss of the epoch, shown on the progress bar.
        running_loss /= n_batches
        if verbose:
            epochs.set_postfix(loss=f"{running_loss:.4f}")

    return model

# ====================
# 6. C-index de Harrell
# ====================
def concordance_index(event_times, event_observed, predicted_risk):
    """Compute Harrell's concordance index for risk scores.

    A pair of subjects is comparable when their times differ and the subject
    with the earlier time experienced the event. A comparable pair counts as
    concordant (1) when the earlier subject has the strictly higher risk, and
    as half-concordant (0.5) when both risks are equal. Pairs with identical
    times are ignored.

    NOTE(review): ``concordance_index`` is also imported from
    ``lifelines.utils`` at the top of the file; this definition rebinds that
    name when both cells run in the same session.

    Parameters
    ----------
    event_times : array-like, shape (n,)
        Observed times.
    event_observed : array-like, shape (n,)
        Event indicator, non-zero = event, 0 = censored.
    predicted_risk : array-like, shape (n,)
        Risk scores (higher = more at risk).

    Returns
    -------
    float
        Concordant weight divided by the number of comparable pairs, or NaN
        when no pair is comparable.

    Raises
    ------
    ValueError
        If the three inputs do not have the same length.
    """
    t = np.asarray(event_times, dtype=float).ravel()
    e = np.asarray(event_observed, dtype=int).ravel()
    r = np.asarray(predicted_risk, dtype=float).ravel()

    n = len(t)
    if len(e) != n or len(r) != n:
        # Explicit exception: an `assert` would vanish under `python -O`.
        raise ValueError(
            "event_times, event_observed and predicted_risk must have the "
            f"same length (got {n}, {len(e)}, {len(r)})"
        )

    num = 0.0
    den = 0.0

    # Each comparable pair has exactly one "earlier" member that had the
    # event, so iterating over events and comparing against all strictly
    # later subjects visits every comparable pair exactly once.
    for i in np.flatnonzero(e != 0):
        later = t > t[i]
        n_later = int(np.count_nonzero(later))
        if n_later == 0:
            continue

        later_risk = r[later]
        den += n_later
        num += int(np.count_nonzero(r[i] > later_risk))
        num += 0.5 * int(np.count_nonzero(r[i] == later_risk))

    return float(num / den) if den > 0 else np.nan

# ====================
# 7. Cross-validation générique (utile pour Optuna)
# ====================
def crossval_cindex(
    data,
    n_splits,
    C1,
    num_epochs,
    lr,
    batch_size,
    device=device,
    verbose=False
):
    """Estimate the C-index of an MTLR configuration by K-fold cross-validation.

    For each fold (shuffled KFold, ``random_state=42``) the normaliser is
    fitted on the training part only, a new model is trained on it, and the
    concordance index of the predicted risk is computed on the held-out part.

    Parameters
    ----------
    data : frame with "time", "event" and the module-level ``feature_cols``.
    n_splits : number of folds.
    C1, num_epochs, lr, batch_size : forwarded to ``train_mtlr``.
    device : device used for training and for the validation tensor.
    verbose : show a per-fold progress bar and print the final summary.

    Returns
    -------
    (mean_c, std_c) : mean and standard deviation of the fold C-indices.
    """
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Materialise the splits so the progress bar iterates a concrete list.
    folds = list(splitter.split(data))
    fold_iter = (
        tqdm(folds, desc="Cross-validation", total=n_splits) if verbose else folds
    )

    scores = []
    for train_rows, val_rows in fold_iter:
        fit_part = data.iloc[train_rows].reset_index(drop=True)
        held_out = data.iloc[val_rows].reset_index(drop=True)

        # Normalisation statistics come from the training part of the fold only.
        stats = fit_normalizer(fit_part, feature_cols)
        fit_part_norm = apply_normalizer(fit_part, feature_cols, *stats)
        held_out_norm = apply_normalizer(held_out, feature_cols, *stats)

        # New model for every fold; per-epoch output is always silenced here.
        model = train_mtlr(
            make_model(),
            fit_part_norm,
            time_bins,
            C1=C1,
            num_epochs=num_epochs,
            lr=lr,
            batch_size=batch_size,
            device=device,
            verbose=False,
        )

        # Risk scores on the held-out part, computed without gradients.
        held_out_x = torch.tensor(
            held_out_norm[feature_cols].values,
            dtype=torch.float32,
            device=device,
        )
        with torch.no_grad():
            risk = mtlr_risk(model(held_out_x)).cpu().numpy().reshape(-1)

        score = concordance_index(
            held_out_norm["time"].values,
            held_out_norm["event"].values,
            risk,
        )
        scores.append(score)

        if verbose:
            fold_iter.set_postfix(c_index=f"{score:.4f}")

    mean_c = float(np.mean(scores))
    std_c = float(np.std(scores))
    if verbose:
        print(f"--> C-index moyen = {mean_c:.4f} ± {std_c:.4f}")

    return mean_c, std_c

# ====================
# 8. Optuna : optimisation hyperparamètres
# ====================
def objective(trial):
    """Optuna objective: mean 5-fold cross-validated C-index on ``df_train``.

    Samples ``C1``, ``lr``, ``batch_size`` and ``num_epochs`` from the trial,
    runs ``crossval_cindex`` with them, stores the fold standard deviation as
    the ``"std_cindex"`` user attribute and returns the mean C-index (the
    study is expected to maximise it).
    """
    # Search space (sampling the number of epochs too makes trials costlier).
    hyperparams = {
        "C1": trial.suggest_float("C1", 1e-3, 10.0, log=True),
        "lr": trial.suggest_float("lr", 1e-4, 5e-3, log=True),
        "batch_size": trial.suggest_categorical(
            "batch_size", [128, 256, 512, 1024]
        ),
        "num_epochs": trial.suggest_int("num_epochs", 50, 150, step=25),
    }

    mean_c, std_c = crossval_cindex(
        data=df_train,
        n_splits=5,
        device=device,
        verbose=False,
        **hyperparams,
    )

    # Keep the spread across folds alongside the trial for later inspection.
    trial.set_user_attr("std_cindex", std_c)

    return mean_c

# Create the study (higher mean C-index is better) and run the search.
study = optuna.create_study(study_name="mtlr_survival", direction="maximize")
print("=== Début optimisation Optuna ===")
study.optimize(objective, n_trials=20, show_progress_bar=True)  # Optuna progress bar
print("=== Fin optimisation Optuna ===")

best_params = study.best_params
print("Meilleurs hyperparamètres Optuna :")
print(best_params)
print("Meilleur C-index moyen :", study.best_value)

# Unpack the winning configuration into the names used by the final training.
best_C1, best_lr, best_batchsize, best_epochs = (
    best_params[key] for key in ("C1", "lr", "batch_size", "num_epochs")
)

# ====================
# 9. Final training on all of df_train with the best hyperparameters
# ====================
print("\n=== Entraînement final avec hyperparamètres Optuna ===")
print(
    f"C1 = {best_C1:.4g}, lr = {best_lr:.4g}, "
    f"batch_size = {best_batchsize}, num_epochs = {best_epochs}"
)

# Normalisation statistics learned on the whole training set; they are kept
# in `norm_params` so the same transform can be applied to new data.
med_full, mean_full, std_full = fit_normalizer(df_train, feature_cols)
norm_params = {"med": med_full, "mean": mean_full, "std": std_full}
df_train_norm = apply_normalizer(
    df_train, feature_cols, med_full, mean_full, std_full
)

# Train a new model; verbose=True shows the progress bar for this final run.
mtlr_model = train_mtlr(
    make_model(),
    df_train_norm,
    time_bins,
    C1=best_C1,
    num_epochs=best_epochs,
    lr=best_lr,
    batch_size=best_batchsize,
    device=device,
    verbose=True,
)

print("Entraînement final terminé.")

# ====================
# 10. Predictions on df_val with the tuned model
# ====================

# Drop the identifier when present, then keep exactly the training features
# in the training column order.
df_val_feats = df_val.drop(columns=["ID"], errors="ignore")[feature_cols]

# Apply the normaliser learned on the full training set
# (norm_params keys match the med/mean/std parameter names).
df_val_norm = apply_normalizer(df_val_feats, feature_cols, **norm_params)

x_val_tensor = torch.tensor(
    df_val_norm[feature_cols].values,
    dtype=torch.float32,
    device=device,
)

# Inference only: no gradients are needed.
with torch.no_grad():
    logits_val = mtlr_model(x_val_tensor)
    # Survival curves, one row per sample — TODO confirm the column count
    # against torchmtlr (original note said (n_samples, num_time_bins)).
    surv_val = mtlr_survival(logits_val).cpu().numpy()
    # One flattened risk score per sample.
    risk_val = mtlr_risk(logits_val).cpu().numpy().reshape(-1)

print("Shape des prédictions de survie sur df_val :", surv_val.shape)
print("Exemple de risk scores (5 premiers) :", risk_val[:5])


[I 2025-12-10 00:36:35,682] A new study created in memory with name: mtlr_survival_improved


Device utilisé : cpu
Nombre de covariables : 164
Nombre de time bins (bords) : 40
Dimension logits/target     : 41
=== Début optimisation Optuna ===


  0%|          | 0/20 [00:00<?, ?it/s]

[I 2025-12-10 00:40:00,674] Trial 0 finished with value: 0.5 and parameters: {'n_layers': 3, 'hidden_size': 256, 'dropout': 0.2683595842828394, 'C1': 1.0924356736442253, 'C2': 6.496271281628262e-06, 'lr': 0.0062552696262983265, 'batch_size': 64, 'num_epochs': 175}. Best is trial 0 with value: 0.5.
[I 2025-12-10 00:40:33,685] Trial 1 finished with value: 0.5 and parameters: {'n_layers': 2, 'hidden_size': 32, 'dropout': 0.3550257041979083, 'C1': 0.0006957609770187876, 'C2': 1.9716367787104583e-06, 'lr': 0.0037439784696579304, 'batch_size': 64, 'num_epochs': 50}. Best is trial 0 with value: 0.5.
[I 2025-12-10 00:41:25,000] Trial 2 finished with value: 0.5 and parameters: {'n_layers': 3, 'hidden_size': 128, 'dropout': 0.3407811116676015, 'C1': 0.044233442046629906, 'C2': 0.0014399673781958634, 'lr': 3.181901418364637e-05, 'batch_size': 256, 'num_epochs': 150}. Best is trial 0 with value: 0.5.
[W 2025-12-10 00:41:53,299] Trial 3 failed with parameters: {'n_layers': 3, 'hidden_size': 128, 'd

KeyboardInterrupt: 