In [7]:
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import os.path

from sklearn.tree import plot_tree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored , concordance_index_ipcw
from sklearn.impute import SimpleImputer
from sksurv.util import Surv
from lifelines.utils import concordance_index


# For preprocessing
from sklearn.preprocessing import StandardScaler


import torch # For building the networks 
import torchtuples as tt # Some useful functions

from pycox.datasets import metabric
from pycox.models import DeepHitSingle
from pycox.evaluation import EvalSurv


In [8]:
np.random.seed(1234)
_ = torch.manual_seed(123)

In [12]:
df_train_scaled = pd.read_csv('data/df_train_scaled.csv')
df_eval_scaled = pd.read_csv('data/df_eval_scaled.csv')



# DeepHit


In [41]:
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
from pycox.models import DeepHitSingle
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler

# Préparation des données
features = df_train_scaled.drop(columns=['OS_YEARS', 'OS_STATUS'])
features_encoded = pd.get_dummies(features, drop_first=True)
X = features_encoded.values.astype('float32')
scaler = StandardScaler()
X = scaler.fit_transform(X)

durations = df_train_scaled['OS_YEARS'].values
events = df_train_scaled['OS_STATUS'].values.astype('bool')

# Discrétisation du temps et transformation des labels
num_durations = 50
labtrans = DeepHitSingle.label_transform(num_durations)
y_train = labtrans.fit_transform(durations, events)

# Réseau
in_features = X.shape[1]
num_nodes = [64, 32]
out_features = labtrans.out_features
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm=True, dropout=0.1, activation=torch.nn.ReLU)

# Modèle DeepHit
model = DeepHitSingle(net, tt.optim.Adam, alpha=0.2, sigma=0.1, duration_index=labtrans.cuts)
model.optimizer.set_lr(0.01)

# Early stopping
epochs = 100
callbacks = [tt.callbacks.EarlyStopping()]

# Entraînement
log = model.fit(
    X, (y_train[0], y_train[1]),
    batch_size=128,
    epochs=epochs,
    callbacks=callbacks,
    verbose=True,
    val_data=None
)

# Prédiction de survie sur le train
surv = model.predict_surv_df(X)

# Évaluation Antolini
ev = EvalSurv(surv, durations, events, censor_surv='km')
c_index = ev.concordance_td('antolini')
print(f"C-index (Antolini): {c_index:.4f}")

# Évaluation IPCW C-index
from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw

y_struct = Surv.from_arrays(event=events, time=durations)
median_pred = []
for surv_curve in surv.values.T:
    below_half = np.where(surv_curve <= 0.5)[0]
    if below_half.size > 0:
        median_pred.append(surv.index[below_half[0]])
    else:
        median_pred.append(surv.index[-1])
median_pred = np.array(median_pred)

result = concordance_index_ipcw(
    y_struct, y_struct, -median_pred
)
cindex_ipcw = result[0] if isinstance(result, (tuple, list, np.ndarray)) else result.cindex
print(f"IPCW C-index: {cindex_ipcw:.4f}")

0:	[0s / 0s],		train_loss: 0.5874
1:	[0s / 0s],		train_loss: 0.5185
2:	[0s / 0s],		train_loss: 0.4812
3:	[0s / 0s],		train_loss: 0.4521
4:	[0s / 0s],		train_loss: 0.4431
5:	[0s / 0s],		train_loss: 0.4201
6:	[0s / 0s],		train_loss: 0.4127
7:	[0s / 0s],		train_loss: 0.4062
8:	[0s / 0s],		train_loss: 0.3975
9:	[0s / 1s],		train_loss: 0.3849
C-index (Antolini): 0.4351
IPCW C-index: 0.5652


# N-MTLR 

In [None]:
df_test = df_train_scaled.sample(frac=0.2)
df_train_scaled = df_train_scaled.drop(df_test.index)
df_val = df_train_scaled.sample(frac=0.2)
df_train = df_train_scaled.drop(df_val.index)

In [None]:
num_durations = 10
labtrans = MTLR.label_transform(num_durations)
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))

train = (x_train, y_train)
val = (x_val, y_val)

# We don't need to transform the test labels
durations_test, events_test = get_target(df_test)