In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn import metrics
from sksurv import metrics as skmetrics

import torch
import torch.nn as nn

from ictsurf.dataset import (
    get_metabric_dataset_onehot,
    get_support2_dataset_onehot,
    get_gaussian_dataset,
    get_synthetic_dataset_compet,
    get_loader
)
from ictsurf.preprocessing import cut_continuous_time,  CTCutEqualSpacing
from ictsurf.eval import *
from ictsurf.utils import *
from ictsurf.loss import nll_continuous_time_multi_loss_trapezoid
from ictsurf.model import MLPTimeEncode
from ictsurf.train_utils import test_step
from ictsurf import ICTSurF, ICTSurFMulti

In [2]:
# Seed both NumPy and PyTorch so data generation and splitting are repeatable.
random_state = 1234
np.random.seed(random_state)
_ = torch.manual_seed(random_state)

(
    features, durations, events,
    true_duration, true_events,
    risk1_durations, risk2_durations,
) = get_synthetic_dataset_compet()

# Both splits share the same hold-out fraction and seed.
split_options = dict(test_size=0.15, random_state=random_state)

# Split 1: carve out the validation set, stratified on the observed events.
# train_test_split returns [kept_0, held_0, kept_1, held_1, ...], so the even
# positions are the retained part and the odd positions are the held-out part.
validation_split = train_test_split(
    features, durations, events, true_duration, true_events,
    risk1_durations, risk2_durations,
    stratify=events, **split_options)

# NOTE: the un-suffixed names are rebound to the retained (non-validation)
# part here; the second split below therefore never sees validation rows.
(
    features, durations, events,
    true_duration, true_events,
    risk1_durations, risk2_durations,
) = validation_split[0::2]
(
    features_val, durations_val, events_val,
    true_duration_val, true_events_val,
    risk1_durations_val, risk2_durations_val,
) = validation_split[1::2]

# Split 2: divide the retained part into train and test, again stratified.
test_split = train_test_split(
    features, durations, events, true_duration, true_events,
    risk1_durations, risk2_durations,
    stratify=events, **split_options)

(
    features_train, durations_train, events_train,
    true_duration_train, true_events_train,
    risk1_durations_train, risk2_durations_train,
) = test_split[0::2]
(
    features_test, durations_test, events_test,
    true_duration_test, true_events_test,
    risk1_durations_test, risk2_durations_test,
) = test_split[1::2]

In [3]:
# Rescale all durations by the mean observed training duration. The scale is
# estimated on the training split only and then reused for every other split.
mean_time = np.mean(durations_train)
durations_train, durations_val, durations_test = (
    split_durations / mean_time
    for split_durations in (durations_train, durations_val, durations_test)
)

# The test-only reference durations used at evaluation time must live on the
# same time scale as the observed durations.
true_duration_test, risk1_durations_test, risk2_durations_test = (
    reference_durations / mean_time
    for reference_durations in (
        true_duration_test, risk1_durations_test, risk2_durations_test)
)

# Standardise the covariates: statistics come from the training split only.
scaler = StandardScaler().fit(features_train)
features_train, features_val, features_test = (
    scaler.transform(split_features)
    for split_features in (features_train, features_val, features_test)
)

In [4]:
# ---- hyper-parameters ----
# The network takes the covariates plus one extra input for the time feature.
in_features = features_train.shape[1] + 1
num_nodes = [64]
num_nodes_res = [64]
time_dim = 16
batch_norm = True
dropout = 0.0
activation = nn.ReLU
output_risk = 2

# ---- optimisation / training settings ----
lr = 0.0002
batch_size = 256
epochs = 10000
n_discrete_time = 50
patience = 10
device = 'cpu'

# ---- network ----
network_options = dict(
    time_dim=time_dim,
    batch_norm=batch_norm,
    dropout=dropout,
    activation=activation,
    output_risk=output_risk,
)
net = MLPTimeEncode(
    in_features, num_nodes, num_nodes_res, **network_options
).float()

model = ICTSurFMulti(net).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# ---- training ----
training_options = dict(
    n_discrete_time=n_discrete_time,
    patience=patience,
    device=device,
    batch_size=batch_size,
    epochs=epochs,
    shuffle=True,
)
# Kept as the final expression so the value returned by fit() is displayed.
model.fit(
    optimizer,
    features_train, durations_train, events_train,
    features_val, durations_val, events_val,
    **training_options)

epoch 0 val_loss: 1.3102756737087449 train_loss: 2.2931591445179142
epoch 1 val_loss: 1.0283979083556376 train_loss: 1.186303976828263
epoch 2 val_loss: 1.0208575996201985 train_loss: 1.0389776042981058
epoch 3 val_loss: 0.9740348705756805 train_loss: 0.9855602054985892
epoch 4 val_loss: 0.9564489199805402 train_loss: 0.9719841271240279
epoch 5 val_loss: 0.940814183694074 train_loss: 0.9599332992614383
epoch 6 val_loss: 0.9255820336922354 train_loss: 0.9380448581702131
epoch 7 val_loss: 0.9150425334251916 train_loss: 0.9243039664029737
epoch 8 val_loss: 0.8943305007624682 train_loss: 0.9004748818580458
epoch 9 val_loss: 0.8875684042669245 train_loss: 0.8849469881604531
epoch 10 val_loss: 0.861579078241132 train_loss: 0.8655559096170004
epoch 11 val_loss: 0.8443714793316354 train_loss: 0.8449284095985881
epoch 12 val_loss: 0.8260100898676207 train_loss: 0.8245356187809092
epoch 13 val_loss: 0.8084564382099144 train_loss: 0.8067963843835951
epoch 14 val_loss: 0.7971647696381313 train_los

0.7064238025244549

In [5]:

# Evaluation horizon: the first quartile of the observed test durations among
# subjects whose observed event label is 1.
eval_time = np.quantile(durations_test[events_test == 1], 0.25)

# Every test subject is queried at the same horizon; the event column is a
# placeholder that only exists to satisfy the loader.
n_test = len(features_test)
time_of_interests = np.full(n_test, eval_time)
fake_events = np.ones(n_test, dtype=int)

# Build the evaluation loader with the data processor already fitted by the
# model (fit_y=False, so nothing is re-fitted here).
test_loader = get_loader(
    features_test, time_of_interests, fake_events, model.processor,
    batch_size=256, fit_y=False)

preds = test_step(model, test_loader, device=device)

# Hazard derived from the raw predictions.
hazard = model.pred_to_hazard(preds)

# Survival probability: obtained by integrating the hazard, which needs the
# discretised times carried by the loader
# (test_loader.dataset.extended_data['continuous_times']).
surv = model.pred_to_surv(preds, test_loader)

# Per-risk reference durations used to build the Brier-score labels.
risk_durations_test = {1: risk1_durations_test, 2: risk2_durations_test}

for risk, risk_durations in risk_durations_test.items():
    # Event indicator for the C-index: True where the true event is this risk.
    event_risk = np.asarray(true_events_test == risk, dtype=bool)

    # Brier label: 0 where this risk's duration exceeds the horizon, else 1.
    y_true = np.where(risk_durations > eval_time, 0, 1).astype(
        np.asarray(true_events_test).dtype)

    # Last column along axis 1 of this risk's slice of the survival output.
    surv_at_risk = surv[..., risk - 1]
    surv_at_time = surv_at_risk[:, -1]
    event_prob = 1 - surv_at_time

    c = skmetrics.concordance_index_censored(
        event_risk, true_duration_test, event_prob)[0]
    brier = metrics.brier_score_loss(y_true, event_prob)
    print(c, brier)

0.7206997184291418 0.1143647230445533
0.7700174706552777 0.14487319313290456
