In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_deepsurv import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [26]:
# split data into train and test sets
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]
test_size = 0.3
batch_size = 64

hidden_size = 128  # Number of neurons in the hidden layers
l2_reg = 0.01
lr = 1e-5
max_epochs = 250
dropout = 0.2


In [27]:
# Dataset
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]


data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time']

data_df = data_df.fillna(data_df.mean())

train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
elapsed_time_ls = []
epochs_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)
    
    X_train = torch.tensor(data_train.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_train = torch.tensor(data_train['event'].to_numpy(), dtype=torch.long)
    t_train = torch.tensor(data_train['time'].to_numpy(), dtype=torch.long)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    X_val = torch.tensor(data_val.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_val = torch.tensor(data_val['event'].to_numpy(), dtype=torch.long)
    t_val = torch.tensor(data_val['time'].to_numpy(), dtype=torch.long)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    
    X_test = torch.tensor(data_test.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_test = torch.tensor(data_test['event'].to_numpy(), dtype=torch.long)
    t_test = torch.tensor(data_test['time'].to_numpy(), dtype=torch.long)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
    
    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features
    
    model = DeepSurv(input_size, hidden_size ,dropout=dropout)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Train the model
    start = time.time()
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    elapsed_time_ls = elapsed_time_ls + [elapsed_time]
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epochs_ls = epochs_ls + [epoch]
    
    model.eval()
    
    ## Training
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(train_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    
    print("train CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    train_ci_ls = train_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    
    ## Validation
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(val_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    print("valid CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    valid_ci_ls = valid_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    ## test
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(test_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    sorted_indices = np.argsort(times_ls).tolist()
    
    print("test CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    test_ci_ls = test_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]

*******************
999
Epoch [1/250], Loss: 7.0359
Validation Loss: 6.8087
Epoch [2/250], Loss: 6.7032
Validation Loss: 7.0165
EarlyStopping counter: 1 out of 50
Epoch [3/250], Loss: 6.8092
Validation Loss: 6.4735
Validation loss improved to 6.473457
Epoch [4/250], Loss: 6.9154
Validation Loss: 6.6786
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 6.9416
Validation Loss: 6.7735
EarlyStopping counter: 2 out of 50
Epoch [6/250], Loss: 6.8372
Validation Loss: 7.2108
EarlyStopping counter: 3 out of 50
Epoch [7/250], Loss: 6.6089
Validation Loss: 6.5251
EarlyStopping counter: 4 out of 50
Epoch [8/250], Loss: 6.6796
Validation Loss: 6.4650
Validation loss improved to 6.465045
Epoch [9/250], Loss: 6.7855
Validation Loss: 6.4572
Validation loss improved to 6.457167
Epoch [10/250], Loss: 6.6022
Validation Loss: 6.9076
EarlyStopping counter: 1 out of 50
Epoch [11/250], Loss: 6.7785
Validation Loss: 6.4444
Validation loss improved to 6.444410
Epoch [12/250], Loss: 6.6785
Validation Loss



Epoch [94/250], Loss: 6.2545
Validation Loss: 6.0871
EarlyStopping counter: 5 out of 50
Epoch [95/250], Loss: 5.6012
Validation Loss: 6.5672
EarlyStopping counter: 6 out of 50
Epoch [96/250], Loss: 5.7222
Validation Loss: 5.9623
EarlyStopping counter: 7 out of 50
Epoch [97/250], Loss: 5.6866
Validation Loss: 5.9606
EarlyStopping counter: 8 out of 50
Epoch [98/250], Loss: 5.6530
Validation Loss: 6.4896
EarlyStopping counter: 9 out of 50
Epoch [99/250], Loss: 5.6833
Validation Loss: 6.4280
EarlyStopping counter: 10 out of 50
Epoch [100/250], Loss: 5.9572
Validation Loss: 5.9476
EarlyStopping counter: 11 out of 50
Epoch [101/250], Loss: 5.8144
Validation Loss: 6.4658
EarlyStopping counter: 12 out of 50
Epoch [102/250], Loss: 5.9797
Validation Loss: 6.4899
EarlyStopping counter: 13 out of 50
Epoch [103/250], Loss: 5.9739
Validation Loss: 7.0418
EarlyStopping counter: 14 out of 50
Epoch [104/250], Loss: 5.7011
Validation Loss: 6.2767
EarlyStopping counter: 15 out of 50
Epoch [105/250], Loss



Epoch [20/250], Loss: 6.5592
Validation Loss: 6.2491
Validation loss improved to 6.249078
Epoch [21/250], Loss: 6.4620
Validation Loss: 6.2039
Validation loss improved to 6.203926
Epoch [22/250], Loss: 6.3541
Validation Loss: 6.3400
EarlyStopping counter: 1 out of 50
Epoch [23/250], Loss: 6.3590
Validation Loss: 6.1511
Validation loss improved to 6.151050
Epoch [24/250], Loss: 6.2246
Validation Loss: 6.3100
EarlyStopping counter: 1 out of 50
Epoch [25/250], Loss: 6.4112
Validation Loss: 6.2425
EarlyStopping counter: 2 out of 50
Epoch [26/250], Loss: 6.2818
Validation Loss: 6.9276
EarlyStopping counter: 3 out of 50
Epoch [27/250], Loss: 6.3615
Validation Loss: 6.8943
EarlyStopping counter: 4 out of 50
Epoch [28/250], Loss: 6.1492
Validation Loss: 6.1832
EarlyStopping counter: 5 out of 50
Epoch [29/250], Loss: 5.8286
Validation Loss: 6.2890
EarlyStopping counter: 6 out of 50
Epoch [30/250], Loss: 6.0466
Validation Loss: 6.2980
EarlyStopping counter: 7 out of 50
Epoch [31/250], Loss: 5.98



Epoch [9/250], Loss: 6.2431
Validation Loss: 6.4919
EarlyStopping counter: 8 out of 50
Epoch [10/250], Loss: 6.3617
Validation Loss: 6.6374
EarlyStopping counter: 9 out of 50
Epoch [11/250], Loss: 6.3707
Validation Loss: 6.7423
EarlyStopping counter: 10 out of 50
Epoch [12/250], Loss: 6.3360
Validation Loss: 7.1840
EarlyStopping counter: 11 out of 50
Epoch [13/250], Loss: 6.1588
Validation Loss: 6.3547
Validation loss improved to 6.354661
Epoch [14/250], Loss: 6.6277
Validation Loss: 6.2666
Validation loss improved to 6.266647
Epoch [15/250], Loss: 6.3874
Validation Loss: 7.0617
EarlyStopping counter: 1 out of 50
Epoch [16/250], Loss: 6.4425
Validation Loss: 6.7472
EarlyStopping counter: 2 out of 50
Epoch [17/250], Loss: 6.7321
Validation Loss: 6.8508
EarlyStopping counter: 3 out of 50
Epoch [18/250], Loss: 6.2210
Validation Loss: 6.3944
EarlyStopping counter: 4 out of 50
Epoch [19/250], Loss: 6.1137
Validation Loss: 6.5046
EarlyStopping counter: 5 out of 50
Epoch [20/250], Loss: 6.365



Epoch [3/250], Loss: 6.6373
Validation Loss: 6.4524
Validation loss improved to 6.452419
Epoch [4/250], Loss: 6.5565
Validation Loss: 6.6843
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 6.6460
Validation Loss: 7.2517
EarlyStopping counter: 2 out of 50
Epoch [6/250], Loss: 6.2883
Validation Loss: 6.8082
EarlyStopping counter: 3 out of 50
Epoch [7/250], Loss: 6.2616
Validation Loss: 7.3723
EarlyStopping counter: 4 out of 50
Epoch [8/250], Loss: 6.7048
Validation Loss: 7.4535
EarlyStopping counter: 5 out of 50
Epoch [9/250], Loss: 6.4405
Validation Loss: 6.5232
EarlyStopping counter: 6 out of 50
Epoch [10/250], Loss: 6.4353
Validation Loss: 6.5977
EarlyStopping counter: 7 out of 50
Epoch [11/250], Loss: 6.4240
Validation Loss: 7.2653
EarlyStopping counter: 8 out of 50
Epoch [12/250], Loss: 6.7909
Validation Loss: 6.9398
EarlyStopping counter: 9 out of 50
Epoch [13/250], Loss: 6.5624
Validation Loss: 7.1491
EarlyStopping counter: 10 out of 50
Epoch [14/250], Loss: 6.2020
Validat



Epoch [40/250], Loss: 6.3133
Validation Loss: 6.3314
EarlyStopping counter: 4 out of 50
Epoch [41/250], Loss: 6.2657
Validation Loss: 6.7210
EarlyStopping counter: 5 out of 50
Epoch [42/250], Loss: 6.1319
Validation Loss: 6.6827
EarlyStopping counter: 6 out of 50
Epoch [43/250], Loss: 5.9868
Validation Loss: 6.8728
EarlyStopping counter: 7 out of 50
Epoch [44/250], Loss: 6.7420
Validation Loss: 6.3060
EarlyStopping counter: 8 out of 50
Epoch [45/250], Loss: 6.2519
Validation Loss: 6.0647
EarlyStopping counter: 9 out of 50
Epoch [46/250], Loss: 5.8987
Validation Loss: 6.4821
EarlyStopping counter: 10 out of 50
Epoch [47/250], Loss: 6.1825
Validation Loss: 6.1741
EarlyStopping counter: 11 out of 50
Epoch [48/250], Loss: 5.9121
Validation Loss: 6.7429
EarlyStopping counter: 12 out of 50
Epoch [49/250], Loss: 6.0497
Validation Loss: 6.1194
EarlyStopping counter: 13 out of 50
Epoch [50/250], Loss: 5.7233
Validation Loss: 6.2466
EarlyStopping counter: 14 out of 50
Epoch [51/250], Loss: 6.026



Epoch [19/250], Loss: 6.3633
Validation Loss: 6.6012
EarlyStopping counter: 2 out of 50
Epoch [20/250], Loss: 6.2463
Validation Loss: 6.5559
EarlyStopping counter: 3 out of 50
Epoch [21/250], Loss: 6.3090
Validation Loss: 6.4947
EarlyStopping counter: 4 out of 50
Epoch [22/250], Loss: 6.0294
Validation Loss: 6.5083
EarlyStopping counter: 5 out of 50
Epoch [23/250], Loss: 6.3623
Validation Loss: 6.6817
EarlyStopping counter: 6 out of 50
Epoch [24/250], Loss: 6.4036
Validation Loss: 6.1397
Validation loss improved to 6.139725
Epoch [25/250], Loss: 6.4719
Validation Loss: 6.5343
EarlyStopping counter: 1 out of 50
Epoch [26/250], Loss: 6.1268
Validation Loss: 6.7366
EarlyStopping counter: 2 out of 50
Epoch [27/250], Loss: 6.2665
Validation Loss: 6.2116
EarlyStopping counter: 3 out of 50
Epoch [28/250], Loss: 6.4421
Validation Loss: 6.8607
EarlyStopping counter: 4 out of 50
Epoch [29/250], Loss: 6.3304
Validation Loss: 6.2715
EarlyStopping counter: 5 out of 50
Epoch [30/250], Loss: 6.3835
V



Validation Loss: 6.6422
EarlyStopping counter: 1 out of 50
Epoch [11/250], Loss: 6.5600
Validation Loss: 6.7134
EarlyStopping counter: 2 out of 50
Epoch [12/250], Loss: 6.6893
Validation Loss: 6.6253
EarlyStopping counter: 3 out of 50
Epoch [13/250], Loss: 6.6357
Validation Loss: 6.6981
EarlyStopping counter: 4 out of 50
Epoch [14/250], Loss: 6.6808
Validation Loss: 6.5829
EarlyStopping counter: 5 out of 50
Epoch [15/250], Loss: 6.4928
Validation Loss: 6.4887
EarlyStopping counter: 6 out of 50
Epoch [16/250], Loss: 6.6480
Validation Loss: 6.3822
Validation loss improved to 6.382240
Epoch [17/250], Loss: 6.3973
Validation Loss: 6.7823
EarlyStopping counter: 1 out of 50
Epoch [18/250], Loss: 6.3441
Validation Loss: 6.3654
Validation loss improved to 6.365428
Epoch [19/250], Loss: 6.5617
Validation Loss: 6.7622
EarlyStopping counter: 1 out of 50
Epoch [20/250], Loss: 6.5879
Validation Loss: 6.6235
EarlyStopping counter: 2 out of 50
Epoch [21/250], Loss: 6.3796
Validation Loss: 6.6889
Earl



Epoch [3/250], Loss: 6.5631
Validation Loss: 6.0946
Validation loss improved to 6.094580
Epoch [4/250], Loss: 6.6892
Validation Loss: 6.4010
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 6.5228
Validation Loss: 6.2183
EarlyStopping counter: 2 out of 50
Epoch [6/250], Loss: 6.5923
Validation Loss: 6.3924
EarlyStopping counter: 3 out of 50
Epoch [7/250], Loss: 6.7344
Validation Loss: 7.0312
EarlyStopping counter: 4 out of 50
Epoch [8/250], Loss: 6.5835
Validation Loss: 7.2737
EarlyStopping counter: 5 out of 50
Epoch [9/250], Loss: 6.6700
Validation Loss: 6.7675
EarlyStopping counter: 6 out of 50
Epoch [10/250], Loss: 6.4869
Validation Loss: 6.3262
EarlyStopping counter: 7 out of 50
Epoch [11/250], Loss: 6.5336
Validation Loss: 6.2666
EarlyStopping counter: 8 out of 50
Epoch [12/250], Loss: 6.6821
Validation Loss: 6.8024
EarlyStopping counter: 9 out of 50
Epoch [13/250], Loss: 6.7386
Validation Loss: 7.2842
EarlyStopping counter: 10 out of 50
Epoch [14/250], Loss: 6.3809
Validat



Epoch [3/250], Loss: 6.6505
Validation Loss: 6.6949
Validation loss improved to 6.694914
Epoch [4/250], Loss: 7.0315
Validation Loss: 6.9029
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 6.7342
Validation Loss: 6.7169
EarlyStopping counter: 2 out of 50
Epoch [6/250], Loss: 6.5384
Validation Loss: 6.5143
Validation loss improved to 6.514259
Epoch [7/250], Loss: 6.4860
Validation Loss: 6.4211
Validation loss improved to 6.421089
Epoch [8/250], Loss: 6.7936
Validation Loss: 6.6439
EarlyStopping counter: 1 out of 50
Epoch [9/250], Loss: 6.5553
Validation Loss: 6.7577
EarlyStopping counter: 2 out of 50
Epoch [10/250], Loss: 6.3233
Validation Loss: 6.8052
EarlyStopping counter: 3 out of 50
Epoch [11/250], Loss: 6.8743
Validation Loss: 6.3094
Validation loss improved to 6.309425
Epoch [12/250], Loss: 6.3933
Validation Loss: 6.6877
EarlyStopping counter: 1 out of 50
Epoch [13/250], Loss: 6.3730
Validation Loss: 7.0889
EarlyStopping counter: 2 out of 50
Epoch [14/250], Loss: 6.3799
Va



Validation Loss: 6.2488
EarlyStopping counter: 2 out of 50
Epoch [13/250], Loss: 6.4130
Validation Loss: 6.6778
EarlyStopping counter: 3 out of 50
Epoch [14/250], Loss: 6.2181
Validation Loss: 6.1438
Validation loss improved to 6.143814
Epoch [15/250], Loss: 6.4737
Validation Loss: 6.1323
Validation loss improved to 6.132263
Epoch [16/250], Loss: 6.3078
Validation Loss: 6.1877
EarlyStopping counter: 1 out of 50
Epoch [17/250], Loss: 6.6262
Validation Loss: 6.1929
EarlyStopping counter: 2 out of 50
Epoch [18/250], Loss: 6.2942
Validation Loss: 7.3901
EarlyStopping counter: 3 out of 50
Epoch [19/250], Loss: 6.4067
Validation Loss: 6.7266
EarlyStopping counter: 4 out of 50
Epoch [20/250], Loss: 6.3074
Validation Loss: 6.4682
EarlyStopping counter: 5 out of 50
Epoch [21/250], Loss: 6.4498
Validation Loss: 7.1504
EarlyStopping counter: 6 out of 50
Epoch [22/250], Loss: 6.3345
Validation Loss: 6.3450
EarlyStopping counter: 7 out of 50
Epoch [23/250], Loss: 6.2287
Validation Loss: 6.1687
Earl

In [28]:
print(model)

DeepSurv(
  (fc1): Linear(in_features=1056, out_features=128, bias=True)
  (bc1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (selu1): SELU()
  (droupout1): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bc2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (selu2): SELU()
  (droupout2): Dropout(p=0.2, inplace=False)
  (fc3): Linear(in_features=128, out_features=1, bias=True)
)


In [29]:
# scaled and one hot
print("Train: ",np.mean(train_ci_ls), "Valid: ",np.mean(valid_ci_ls), 
      "Test: ",np.mean(test_ci_ls), "Elapsed time: ", np.mean(elapsed_time_ls),
      "epochs: ", np.mean(epochs_ls))

print("\nTrain: ", train_ci_ls)
print("\nvalid: ", valid_ci_ls)
print("\nTest: ", test_ci_ls)
print("\nElapsed time: ", elapsed_time_ls)
print("\nEpochs: ", epochs_ls)

Train:  0.9241653255662987 Valid:  0.7587159051928758 Test:  0.7157762352560988 Elapsed time:  28.747127532958984 epochs:  167.4

Train:  [0.9476531146396789, 0.9164191054084234, 0.9156008112918581, 0.9170449102479301, 0.9125116594590011, 0.9224971162642696, 0.9315406019669067, 0.9367283629322729, 0.9440394088669951, 0.8976181645856512]

valid:  [0.7317073170731707, 0.7312444836716682, 0.8039019509754878, 0.7498822421102214, 0.7917030567685589, 0.7539079946404645, 0.7694128787878788, 0.7618097139055223, 0.7324988409828466, 0.761090573012939]

Test:  [0.7590618336886994, 0.7686274509803922, 0.7471861471861472, 0.6633966891477621, 0.628657487091222, 0.7633769322235434, 0.6569931479242241, 0.746031746031746, 0.7604422604422605, 0.6639886578449905]

Elapsed time:  [42.032065629959106, 21.597612619400024, 23.52042269706726, 20.63997197151184, 23.593926191329956, 27.201050758361816, 33.16026282310486, 32.10476064682007, 41.66835808753967, 21.95284390449524]

Epochs:  [249, 125, 137, 119, 137