In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn import metrics
from sksurv import metrics as skmetrics

import torch
import torch.nn as nn

from ictsurf.dataset import (
    get_metabric_dataset_onehot,
    get_support2_dataset_onehot,
    get_gaussian_dataset,
    get_synthetic_dataset_compet,
    get_loader
)
from ictsurf.preprocessing import cut_continuous_time,  CTCutEqualSpacing
from ictsurf.eval import *
from ictsurf.utils import *
from ictsurf.loss import nll_continuous_time_multi_loss_trapezoid
from ictsurf.model import MLPTimeEncode
from ictsurf.train_utils import test_step
from ictsurf import ICTSurF, ICTSurFMulti

## Load data

In [6]:
# Seed NumPy and PyTorch with the same value so reruns of the notebook start
# from the same random state.
random_state = 1234
np.random.seed(random_state)
_ = torch.manual_seed(random_state)

features, durations, events = get_support2_dataset_onehot()

# Both splits use the same hold-out fraction and seed; each one is stratified
# on the event labels that are being split.
split_options = dict(test_size=0.15, random_state=random_state)

# Split 1: hold out 15% of the full data as the validation set.
# `features` / `durations` / `events` are rebound to the remaining 85%.
(
    features, features_val,
    durations, durations_val,
    events, events_val,
) = train_test_split(features, durations, events, stratify=events, **split_options)

# Split 2: hold out 15% of that remainder as the test set; the rest is training data.
(
    features_train, features_test,
    durations_train, durations_test,
    events_train, events_test,
) = train_test_split(features, durations, events, stratify=events, **split_options)

def _drop_beyond_train_horizon(max_train_duration, features, durations, events):
    """Keep only the samples whose duration is strictly below ``max_train_duration``.

    Parameters
    ----------
    max_train_duration : scalar
        Largest duration observed in the training split.
    features, durations, events : array-like
        Row-aligned arrays of one held-out split (all indexed along axis 0).
        NOTE(review): assumed to be NumPy arrays as returned by
        ``train_test_split`` -- confirm against the dataset loader.

    Returns
    -------
    tuple
        ``(features, durations, events)`` restricted to the kept rows, in their
        original order. The result may be empty if no sample qualifies.
    """
    # One boolean mask replaces the repeated "delete the current maximum" loop:
    # the surviving rows are the same, but there is only a single pass and no
    # failure when every row has to be dropped.
    keep = np.asarray(durations) < max_train_duration
    return features[keep], durations[keep], events[keep]


# Remove held-out samples whose duration is greater than or equal to the
# maximum training duration, because the c-index from sksurv cannot handle
# that case.
max_train_duration = durations_train.max()

features_test, durations_test, events_test = _drop_beyond_train_horizon(
    max_train_duration, features_test, durations_test, events_test)

features_val, durations_val, events_val = _drop_beyond_train_horizon(
    max_train_duration, features_val, durations_val, events_val)

## Normalization

In [7]:
# Express every duration in units of the mean *training* duration; the same
# constant is applied to the validation and test splits.
mean_time = np.mean(durations_train)
durations_train, durations_val, durations_test = (
    split / mean_time
    for split in (durations_train, durations_val, durations_test)
)

# Standardise the covariates with statistics estimated on the training split
# only, then apply that fitted scaler to all three splits.
scaler = StandardScaler().fit(features_train)
features_train, features_val, features_test = (
    scaler.transform(split)
    for split in (features_train, features_val, features_test)
)

## Training

In [8]:
# ---- network hyperparameters ----
# One extra input on top of the covariates: the time feature.
in_features = features_train.shape[1] + 1
num_nodes = [64]
num_nodes_res = [64]
time_dim = 16
batch_norm = True
dropout = 0.0
activation = nn.ReLU
output_risk = 1

# ---- optimisation / training-loop hyperparameters ----
lr = 0.0002
batch_size = 256
epochs = 10000
patience = 10
n_discrete_time = 50
device = 'cpu'

# Build the provided network.
net = MLPTimeEncode(
    in_features,
    num_nodes,
    num_nodes_res,
    time_dim=time_dim,
    batch_norm=batch_norm,
    dropout=dropout,
    activation=activation,
    output_risk=output_risk,
)
net = net.float()

# Alternatively, plug in your own network (left commented out as an example):
#
# class CustomNet(nn.Module):
#     def __init__(self, in_features, output_risk = 1):
#         super().__init__()
#         self.output_risk = output_risk
#         self.linear = nn.Linear(in_features, 1)
#
#     def forward(self, input):
#         # the input has shape (batch, time_step, in_features);
#         # the last feature is the time feature
#         time_step = input.shape[1]
#         input = input.view(-1, input.shape[-1])
#         out = self.linear(input)
#
#         if self.output_risk == 1:
#             return out.view(-1, time_step)
#         return out.reshape(-1, time_step, 1)
#
# net = CustomNet(in_features)

# Wrap the network in the survival model and optimise the network's parameters.
model = ICTSurF(net).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# The validation split is passed alongside the training split; the value
# returned by `fit` is the cell's displayed output.
model.fit(
    optimizer,
    features_train, durations_train, events_train,
    features_val, durations_val, events_val,
    n_discrete_time=n_discrete_time,
    patience=patience,
    device=device,
    batch_size=batch_size,
    epochs=epochs,
    shuffle=True,
)

epoch 5 val_loss: 0.5890185779173982 train_loss: 0.6010431903096392
epoch 6 val_loss: 0.555299671404275 train_loss: 0.5621553739853233
epoch 7 val_loss: 0.5302227970668731 train_loss: 0.5339830898330713
epoch 8 val_loss: 0.5080168172993881 train_loss: 0.5044601295172714
epoch 9 val_loss: 0.4859957410820004 train_loss: 0.481226466366328
epoch 10 val_loss: 0.4683984379812615 train_loss: 0.45970528679939804
epoch 11 val_loss: 0.4523657756157259 train_loss: 0.43712887861902894
epoch 12 val_loss: 0.4387495931680033 train_loss: 0.423800971262455
epoch 13 val_loss: 0.42761736930733213 train_loss: 0.40689455899073684
epoch 14 val_loss: 0.416024483478575 train_loss: 0.39290452062150516
epoch 15 val_loss: 0.40664274824950375 train_loss: 0.38289311330796805
epoch 16 val_loss: 0.3985128894679264 train_loss: 0.3690381046117367
epoch 17 val_loss: 0.3880805564715657 train_loss: 0.35997805692680374
epoch 18 val_loss: 0.3803891124037417 train_loss: 0.3509233159741795
epoch 19 val_loss: 0.37485780277258

0.2883737283459465

## Evaluate using our method

In [9]:
# Quantiles passed to the built-in evaluation; the resulting table reports
# one row per value (column `timepoint`) with `c_index` and `brier`.
quantile_evals = [0.25, 0.5, 0.75]

tmp_df = model.evaluate(
    features_test,
    durations_test,
    events_test,
    quantile_evals=quantile_evals,
    interpolation=True,
    device=device,
)
tmp_df

Unnamed: 0,model,timepoint,c_index,brier
0,ICTSurf_loss.pth,0.25,0.833771,0.115548
0,ICTSurf_loss.pth,0.5,0.783476,0.158197
0,ICTSurf_loss.pth,0.75,0.749687,0.181603


## Custom Evaluation

In [24]:

# Choose the time of interest: here the 0.25 quantile of the durations of the
# test samples whose event indicator equals 1.
eval_time = np.quantile(durations_test[events_test == 1], 0.25)

# Every test sample is queried at that same time, paired with a placeholder
# event label of 1.
n_test = len(features_test)
time_of_interests = np.full(n_test, eval_time)
fake_events = np.ones(n_test, dtype=int)

# Build the evaluation dataloader with the data processor already held by the
# model (`fit_y=False` is passed to `get_loader`).
test_loader = get_loader(
    features_test,
    time_of_interests,
    fake_events,
    model.processor,
    batch_size=256,
    fit_y=False,
)

preds = test_step(model, test_loader, device=device)

# Hazard derived from the raw predictions.
hazard = model.pred_to_hazard(preds)

# Survival probability: obtained by integrating the hazard function, which
# needs the discretized times from the dataloader. Those times can be accessed
# at test_loader.dataset.extended_data['continuous_times'].
surv = model.pred_to_surv(preds, test_loader)

# `surv` can now be used in your own evaluation.