In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_CPHMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [70]:
# Dataset
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]


data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time']

data_df = data_df.fillna(data_df.mean())

train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
elapsed_time_ls = []
epochs_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)
    
    X_train = torch.tensor(data_train.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_train = torch.tensor(data_train['event'].to_numpy(), dtype=torch.long)
    t_train = torch.tensor(data_train['time'].to_numpy(), dtype=torch.long)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    X_val = torch.tensor(data_val.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_val = torch.tensor(data_val['event'].to_numpy(), dtype=torch.long)
    t_val = torch.tensor(data_val['time'].to_numpy(), dtype=torch.long)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    
    X_test = torch.tensor(data_test.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_test = torch.tensor(data_test['event'].to_numpy(), dtype=torch.long)
    t_test = torch.tensor(data_test['time'].to_numpy(), dtype=torch.long)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
    
    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features
    
    model = CPHMLP(input_size, hidden_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Train the model
    start = time.time()
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    elapsed_time_ls = elapsed_time_ls + [elapsed_time]
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epochs_ls = epochs_ls + [epoch]
    
    model.eval()
    
    ## Training
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(train_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    
    print("train CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    train_ci_ls = train_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    
    ## Validation
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(val_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    print("valid CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    valid_ci_ls = valid_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    ## test
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(test_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    sorted_indices = np.argsort(times_ls).tolist()
    
    print("test CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    test_ci_ls = test_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]

*******************
999
Epoch [1/250], Loss: 3.4621
Validation Loss: 3.0474
Epoch [2/250], Loss: 3.4501
Validation Loss: 3.6065
EarlyStopping counter: 1 out of 50
Epoch [3/250], Loss: 3.1960
Validation Loss: 3.4176
EarlyStopping counter: 2 out of 50
Epoch [4/250], Loss: 3.2900
Validation Loss: 3.0134
Validation loss improved to 3.013421
Epoch [5/250], Loss: 3.6347
Validation Loss: 3.7795
EarlyStopping counter: 1 out of 50
Epoch [6/250], Loss: 3.3368
Validation Loss: 3.3128
EarlyStopping counter: 2 out of 50
Epoch [7/250], Loss: 3.3850
Validation Loss: 3.3232
EarlyStopping counter: 3 out of 50
Epoch [8/250], Loss: 3.5599
Validation Loss: 3.6003
EarlyStopping counter: 4 out of 50
Epoch [9/250], Loss: 3.2225
Validation Loss: 3.5644
EarlyStopping counter: 5 out of 50
Epoch [10/250], Loss: 3.2940
Validation Loss: 2.8833
Validation loss improved to 2.883288
Epoch [11/250], Loss: 3.4136
Validation Loss: 3.5437
EarlyStopping counter: 1 out of 50
Epoch [12/250], Loss: 3.4811
Validation Loss: 3.



Epoch [16/250], Loss: 2.9611
Validation Loss: 3.2061
EarlyStopping counter: 6 out of 50
Epoch [17/250], Loss: 3.3358
Validation Loss: 2.9232
EarlyStopping counter: 7 out of 50
Epoch [18/250], Loss: 3.3634
Validation Loss: 3.7039
EarlyStopping counter: 8 out of 50
Epoch [19/250], Loss: 3.5515
Validation Loss: 2.9290
EarlyStopping counter: 9 out of 50
Epoch [20/250], Loss: 3.4221
Validation Loss: 2.8739
Validation loss improved to 2.873859
Epoch [21/250], Loss: 3.5274
Validation Loss: 3.6518
EarlyStopping counter: 1 out of 50
Epoch [22/250], Loss: 3.1202
Validation Loss: 2.8966
EarlyStopping counter: 2 out of 50
Epoch [23/250], Loss: 3.1395
Validation Loss: 3.3598
EarlyStopping counter: 3 out of 50
Epoch [24/250], Loss: 3.1289
Validation Loss: 2.9673
EarlyStopping counter: 4 out of 50
Epoch [25/250], Loss: 3.2847
Validation Loss: 2.8146
Validation loss improved to 2.814599
Epoch [26/250], Loss: 2.9958
Validation Loss: 2.8114
Validation loss improved to 2.811363
Epoch [27/250], Loss: 3.12



Epoch [1/250], Loss: 3.1999
Validation Loss: 3.3099
Epoch [2/250], Loss: 3.3037
Validation Loss: 3.1796
Validation loss improved to 3.179571
Epoch [3/250], Loss: 3.3354
Validation Loss: 3.1692
Validation loss improved to 3.169164
Epoch [4/250], Loss: 3.0745
Validation Loss: 3.6569
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 3.5158
Validation Loss: 3.7839
EarlyStopping counter: 2 out of 50
Epoch [6/250], Loss: 3.1917
Validation Loss: 3.8678
EarlyStopping counter: 3 out of 50
Epoch [7/250], Loss: 3.2140
Validation Loss: 3.7774
EarlyStopping counter: 4 out of 50
Epoch [8/250], Loss: 3.4429
Validation Loss: 3.3648
EarlyStopping counter: 5 out of 50
Epoch [9/250], Loss: 3.1859
Validation Loss: 3.2792
EarlyStopping counter: 6 out of 50
Epoch [10/250], Loss: 3.0641
Validation Loss: 3.1558
Validation loss improved to 3.155788
Epoch [11/250], Loss: 3.1200
Validation Loss: 3.6894
EarlyStopping counter: 1 out of 50
Epoch [12/250], Loss: 3.1657
Validation Loss: 3.4416
EarlyStopping cou



Epoch [15/250], Loss: 2.9700
Validation Loss: 3.3027
EarlyStopping counter: 14 out of 50
Epoch [16/250], Loss: 3.1913
Validation Loss: 3.6996
EarlyStopping counter: 15 out of 50
Epoch [17/250], Loss: 3.0467
Validation Loss: 3.1808
EarlyStopping counter: 16 out of 50
Epoch [18/250], Loss: 3.3430
Validation Loss: 3.3192
EarlyStopping counter: 17 out of 50
Epoch [19/250], Loss: 2.9979
Validation Loss: 3.3418
EarlyStopping counter: 18 out of 50
Epoch [20/250], Loss: 3.1653
Validation Loss: 3.5630
EarlyStopping counter: 19 out of 50
Epoch [21/250], Loss: 3.1362
Validation Loss: 3.2828
EarlyStopping counter: 20 out of 50
Epoch [22/250], Loss: 3.0526
Validation Loss: 3.4995
EarlyStopping counter: 21 out of 50
Epoch [23/250], Loss: 3.0026
Validation Loss: 3.2051
EarlyStopping counter: 22 out of 50
Epoch [24/250], Loss: 2.8409
Validation Loss: 3.4044
EarlyStopping counter: 23 out of 50
Epoch [25/250], Loss: 2.9152
Validation Loss: 3.4517
EarlyStopping counter: 24 out of 50
Epoch [26/250], Loss:



Epoch [9/250], Loss: 2.8403
Validation Loss: 3.8167
EarlyStopping counter: 6 out of 50
Epoch [10/250], Loss: 3.3208
Validation Loss: 3.2167
EarlyStopping counter: 7 out of 50
Epoch [11/250], Loss: 3.3795
Validation Loss: 3.2816
EarlyStopping counter: 8 out of 50
Epoch [12/250], Loss: 2.9875
Validation Loss: 4.2285
EarlyStopping counter: 9 out of 50
Epoch [13/250], Loss: 3.2777
Validation Loss: 3.1585
EarlyStopping counter: 10 out of 50
Epoch [14/250], Loss: 3.3930
Validation Loss: 3.4554
EarlyStopping counter: 11 out of 50
Epoch [15/250], Loss: 3.0303
Validation Loss: 3.4343
EarlyStopping counter: 12 out of 50
Epoch [16/250], Loss: 3.0529
Validation Loss: 3.6994
EarlyStopping counter: 13 out of 50
Epoch [17/250], Loss: 3.1513
Validation Loss: 3.3895
EarlyStopping counter: 14 out of 50
Epoch [18/250], Loss: 2.8466
Validation Loss: 3.2815
EarlyStopping counter: 15 out of 50
Epoch [19/250], Loss: 3.1622
Validation Loss: 3.7555
EarlyStopping counter: 16 out of 50
Epoch [20/250], Loss: 2.91



Epoch [33/250], Loss: 2.6360
Validation Loss: 3.2869
EarlyStopping counter: 19 out of 50
Epoch [34/250], Loss: 3.2724
Validation Loss: 3.7391
EarlyStopping counter: 20 out of 50
Epoch [35/250], Loss: 3.1677
Validation Loss: 2.9642
EarlyStopping counter: 21 out of 50
Epoch [36/250], Loss: 3.0298
Validation Loss: 3.2671
EarlyStopping counter: 22 out of 50
Epoch [37/250], Loss: 3.1422
Validation Loss: 3.6576
EarlyStopping counter: 23 out of 50
Epoch [38/250], Loss: 2.8093
Validation Loss: 2.8741
EarlyStopping counter: 24 out of 50
Epoch [39/250], Loss: 2.8491
Validation Loss: 2.8713
EarlyStopping counter: 25 out of 50
Epoch [40/250], Loss: 3.1566
Validation Loss: 2.9072
EarlyStopping counter: 26 out of 50
Epoch [41/250], Loss: 2.8524
Validation Loss: 3.1731
EarlyStopping counter: 27 out of 50
Epoch [42/250], Loss: 2.8151
Validation Loss: 3.3662
EarlyStopping counter: 28 out of 50
Epoch [43/250], Loss: 2.9422
Validation Loss: 2.8332
Validation loss improved to 2.833169
Epoch [44/250], Loss



Epoch [3/250], Loss: 3.4337
Validation Loss: 3.0244
Validation loss improved to 3.024375
Epoch [4/250], Loss: 3.2706
Validation Loss: 2.8635
Validation loss improved to 2.863529
Epoch [5/250], Loss: 3.4234
Validation Loss: 3.6673
EarlyStopping counter: 1 out of 50
Epoch [6/250], Loss: 3.6313
Validation Loss: 3.6102
EarlyStopping counter: 2 out of 50
Epoch [7/250], Loss: 3.4793
Validation Loss: 4.2888
EarlyStopping counter: 3 out of 50
Epoch [8/250], Loss: 3.4158
Validation Loss: 3.6305
EarlyStopping counter: 4 out of 50
Epoch [9/250], Loss: 3.3765
Validation Loss: 2.9973
EarlyStopping counter: 5 out of 50
Epoch [10/250], Loss: 3.1771
Validation Loss: 3.5037
EarlyStopping counter: 6 out of 50
Epoch [11/250], Loss: 3.3443
Validation Loss: 3.6088
EarlyStopping counter: 7 out of 50
Epoch [12/250], Loss: 2.9632
Validation Loss: 3.4276
EarlyStopping counter: 8 out of 50
Epoch [13/250], Loss: 3.1523
Validation Loss: 4.3696
EarlyStopping counter: 9 out of 50
Epoch [14/250], Loss: 3.5562
Valida



Epoch [3/250], Loss: 3.1054
Validation Loss: 3.3156
EarlyStopping counter: 1 out of 50
Epoch [4/250], Loss: 3.2832
Validation Loss: 3.2019
EarlyStopping counter: 2 out of 50
Epoch [5/250], Loss: 3.3262
Validation Loss: 2.8692
Validation loss improved to 2.869197
Epoch [6/250], Loss: 3.3080
Validation Loss: 3.0039
EarlyStopping counter: 1 out of 50
Epoch [7/250], Loss: 3.5569
Validation Loss: 3.2023
EarlyStopping counter: 2 out of 50
Epoch [8/250], Loss: 3.3156
Validation Loss: 2.9578
EarlyStopping counter: 3 out of 50
Epoch [9/250], Loss: 3.3298
Validation Loss: 3.3368
EarlyStopping counter: 4 out of 50
Epoch [10/250], Loss: 3.3436
Validation Loss: 3.1838
EarlyStopping counter: 5 out of 50
Epoch [11/250], Loss: 3.1407
Validation Loss: 3.0792
EarlyStopping counter: 6 out of 50
Epoch [12/250], Loss: 3.6012
Validation Loss: 3.6575
EarlyStopping counter: 7 out of 50
Epoch [13/250], Loss: 3.1794
Validation Loss: 3.2715
EarlyStopping counter: 8 out of 50
Epoch [14/250], Loss: 3.2458
Validati



Epoch [3/250], Loss: 3.1389
Validation Loss: 3.2412
Validation loss improved to 3.241196
Epoch [4/250], Loss: 3.6313
Validation Loss: 2.9597
Validation loss improved to 2.959708
Epoch [5/250], Loss: 3.3736
Validation Loss: 4.2791
EarlyStopping counter: 1 out of 50
Epoch [6/250], Loss: 3.3711
Validation Loss: 3.3699
EarlyStopping counter: 2 out of 50
Epoch [7/250], Loss: 3.4020
Validation Loss: 3.4843
EarlyStopping counter: 3 out of 50
Epoch [8/250], Loss: 3.1364
Validation Loss: 4.1410
EarlyStopping counter: 4 out of 50
Epoch [9/250], Loss: 3.4068
Validation Loss: 3.2884
EarlyStopping counter: 5 out of 50
Epoch [10/250], Loss: 3.3083
Validation Loss: 3.5906
EarlyStopping counter: 6 out of 50
Epoch [11/250], Loss: 3.4766
Validation Loss: 3.5112
EarlyStopping counter: 7 out of 50
Epoch [12/250], Loss: 3.5211
Validation Loss: 3.5155
EarlyStopping counter: 8 out of 50
Epoch [13/250], Loss: 3.1606
Validation Loss: 3.2867
EarlyStopping counter: 9 out of 50
Epoch [14/250], Loss: 3.1984
Valida



Epoch [40/250], Loss: 2.8792
Validation Loss: 2.9297
EarlyStopping counter: 2 out of 50
Epoch [41/250], Loss: 2.8480
Validation Loss: 2.9060
EarlyStopping counter: 3 out of 50
Epoch [42/250], Loss: 2.9077
Validation Loss: 3.5989
EarlyStopping counter: 4 out of 50
Epoch [43/250], Loss: 2.6967
Validation Loss: 2.9605
EarlyStopping counter: 5 out of 50
Epoch [44/250], Loss: 2.8756
Validation Loss: 2.9807
EarlyStopping counter: 6 out of 50
Epoch [45/250], Loss: 2.7290
Validation Loss: 3.3726
EarlyStopping counter: 7 out of 50
Epoch [46/250], Loss: 2.6586
Validation Loss: 3.9450
EarlyStopping counter: 8 out of 50
Epoch [47/250], Loss: 2.9701
Validation Loss: 3.6905
EarlyStopping counter: 9 out of 50
Epoch [48/250], Loss: 2.6695
Validation Loss: 3.1602
EarlyStopping counter: 10 out of 50
Epoch [49/250], Loss: 2.7054
Validation Loss: 4.1474
EarlyStopping counter: 11 out of 50
Epoch [50/250], Loss: 2.8836
Validation Loss: 3.3960
EarlyStopping counter: 12 out of 50
Epoch [51/250], Loss: 2.9581




Epoch [17/250], Loss: 2.9748
Validation Loss: 3.0940
EarlyStopping counter: 3 out of 50
Epoch [18/250], Loss: 3.1081
Validation Loss: 3.0382
EarlyStopping counter: 4 out of 50
Epoch [19/250], Loss: 3.1348
Validation Loss: 2.9398
Validation loss improved to 2.939787
Epoch [20/250], Loss: 3.1354
Validation Loss: 2.9412
EarlyStopping counter: 1 out of 50
Epoch [21/250], Loss: 2.8937
Validation Loss: 2.8848
Validation loss improved to 2.884792
Epoch [22/250], Loss: 3.2959
Validation Loss: 3.0695
EarlyStopping counter: 1 out of 50
Epoch [23/250], Loss: 3.0443
Validation Loss: 3.5378
EarlyStopping counter: 2 out of 50
Epoch [24/250], Loss: 2.8094
Validation Loss: 3.4505
EarlyStopping counter: 3 out of 50
Epoch [25/250], Loss: 3.0893
Validation Loss: 3.2513
EarlyStopping counter: 4 out of 50
Epoch [26/250], Loss: 3.1621
Validation Loss: 2.9630
EarlyStopping counter: 5 out of 50
Epoch [27/250], Loss: 2.9240
Validation Loss: 2.8552
Validation loss improved to 2.855227
Epoch [28/250], Loss: 2.79

In [71]:
print(model)

CPHMLP(
  (fc1): Linear(in_features=1056, out_features=64, bias=True)
  (bc1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): SELU()
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (bc2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): SELU()
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)


In [72]:
# scaled and one hot
print("Train: ",np.mean(train_ci_ls), "Valid: ",np.mean(valid_ci_ls), 
      "Test: ",np.mean(test_ci_ls), "Elapsed time: ", np.mean(elapsed_time_ls),
      "epochs: ", np.mean(epochs_ls))

print("\nTrain: ", train_ci_ls)
print("\nvalid: ", valid_ci_ls)
print("\nTest: ", test_ci_ls)
print("\nElapsed time: ", elapsed_time_ls)
print("\nEpochs: ", epochs_ls)

Train:  0.919216835578251 Valid:  0.7545780608780869 Test:  0.7188251843954966 Elapsed time:  52.940882730484006 epochs:  144.9

Train:  [0.9111653967700033, 0.9385626698691613, 0.9225754377250714, 0.900326163811783, 0.9390952259815144, 0.9230141999124936, 0.9003164893052104, 0.8893006599904228, 0.944768472906404, 0.9230436395104463]

valid:  [0.7560975609756098, 0.7308031774051191, 0.7978989494747374, 0.7757889778615167, 0.7737991266375546, 0.8075033497096918, 0.7694128787878788, 0.635395874916833, 0.7158089939731108, 0.783271719038817]

Test:  [0.7555081734186212, 0.8008714596949891, 0.6813852813852814, 0.6407112201103617, 0.6609294320137694, 0.7615933412604042, 0.6948810963321241, 0.7450185748058088, 0.7512285012285013, 0.696124763705104]

Elapsed time:  [50.661893367767334, 67.56466341018677, 43.881601095199585, 33.866379499435425, 65.96821093559265, 54.37343406677246, 35.562692165374756, 36.36761474609375, 85.19462847709656, 55.9677095413208]

Epochs:  [140, 185, 119, 91, 179, 150