In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_deepsurv import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [26]:
# Experiment configuration shared by every run below.

# -- data splitting / batching --
# One complete train/evaluate run is performed per seed.
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]
batch_size = 64
test_size = 0.3  # fraction of rows held out from the training split

# -- network / optimisation --
hidden_size = 128  # Number of neurons in the hidden layers
dropout = 0.2
lr = 1e-5
l2_reg = 0.01
max_epochs = 250


In [None]:
# Dataset: load, one-hot encode and scale the covariates, then train and
# evaluate a DeepSurv model once per seed, recording the concordance index
# on the train / validation / test splits.
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]

# One-hot encode the covariates, then min-max scale every resulting column.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
# NOTE(review): the scaler and the mean imputation below are fitted on the
# full data set *before* it is split, so validation/test statistics leak into
# training. Consider fitting them on the training split only.
# Keep the original row index so the scaled features stay aligned with the
# event/time columns (a fresh RangeIndex would misalign any Series assignment).
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns, index=data_df.index)
data_df['event'] = [int(e) for e in data_df_event_time['event']]
# Assign positionally, exactly like 'event' above, so both targets are
# guaranteed to refer to the same rows regardless of the index labels.
data_df['time'] = data_df_event_time['time'].to_numpy()

data_df = data_df.fillna(data_df.mean())


def _split_to_tensors(split_df):
    """Convert one data split into ``(features, events, times)`` tensors.

    Parameters
    ----------
    split_df : pandas.DataFrame
        Rows of ``data_df``; must contain the 'event' and 'time' columns.

    Returns
    -------
    tuple of torch.Tensor
        float32 feature matrix, int64 event vector and int64 time vector.
    """
    features = torch.tensor(split_df.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    events = torch.tensor(split_df['event'].to_numpy(), dtype=torch.long)
    # NOTE(review): the long cast drops any fractional part of 'time' --
    # TODO confirm that times are stored as whole numbers.
    times = torch.tensor(split_df['time'].to_numpy(), dtype=torch.long)
    return features, events, times


def _collect_predictions(net, loader):
    """Run ``net`` over every batch of ``loader`` without tracking gradients.

    Returns
    -------
    tuple of list
        ``(predicted, events, times)`` as flat Python lists, in the order the
        loader yielded the samples.
    """
    predicted, events_out, times_out = [], [], []
    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, events, times in loader:
            predicted.extend(net(inputs).reshape(-1).tolist())
            events_out.extend(events.tolist())
            times_out.extend(times.tolist())
    return predicted, events_out, times_out


train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
elapsed_time_ls = []
epochs_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    # Hold out `test_size` of the rows, then halve the hold-out into
    # validation and test sets.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)

    X_train, e_train, t_train = _split_to_tensors(data_train)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    X_val, e_val, t_val = _split_to_tensors(data_val)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    X_test, e_test, t_test = _split_to_tensors(data_test)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # parameters
    input_size = X_train.shape[1]  # number of feature columns after encoding

    model = DeepSurv(input_size, hidden_size ,dropout=dropout)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model
    start = time.time()
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    elapsed_time_ls.append(elapsed_time)
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epochs_ls.append(epoch)

    model.eval()

    # The model output is negated before scoring so that a larger output
    # counts as a shorter predicted survival time in concordance_index.

    ## Training
    predicted_ls, events_ls, times_ls = _collect_predictions(model, train_loader)
    train_ci = concordance_index(times_ls, [-p for p in predicted_ls], events_ls)
    print("train CI lifelines: ", train_ci)
    train_ci_ls.append(train_ci)

    ## Validation
    predicted_ls, events_ls, times_ls = _collect_predictions(model, val_loader)
    valid_ci = concordance_index(times_ls, [-p for p in predicted_ls], events_ls)
    print("valid CI lifelines: ", valid_ci)
    valid_ci_ls.append(valid_ci)

    ## test
    predicted_ls, events_ls, times_ls = _collect_predictions(model, test_loader)

    # Not used in this cell; kept in case later cells rely on it.
    sorted_indices = np.argsort(times_ls).tolist()

    test_ci = concordance_index(times_ls, [-p for p in predicted_ls], events_ls)
    print("test CI lifelines: ", test_ci)
    test_ci_ls.append(test_ci)

In [28]:
# Show the layer structure of `model`. The variable is reassigned for every
# seed in the training loop, so this is the network built for the last seed.
print(model)

DeepSurv(
  (fc1): Linear(in_features=1056, out_features=128, bias=True)
  (bc1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (selu1): SELU()
  (droupout1): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bc2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (selu2): SELU()
  (droupout2): Dropout(p=0.2, inplace=False)
  (fc3): Linear(in_features=128, out_features=1, bias=True)
)


In [29]:
# Results for the min-max scaled, one-hot encoded feature set:
# first the mean of every metric across seeds, then the per-seed values.
mean_report = []
for label, values in (
    ("Train: ", train_ci_ls),
    ("Valid: ", valid_ci_ls),
    ("Test: ", test_ci_ls),
    ("Elapsed time: ", elapsed_time_ls),
    ("epochs: ", epochs_ls),
):
    mean_report.extend([label, np.mean(values)])
print(*mean_report)

for label, values in (
    ("\nTrain: ", train_ci_ls),
    ("\nvalid: ", valid_ci_ls),
    ("\nTest: ", test_ci_ls),
    ("\nElapsed time: ", elapsed_time_ls),
    ("\nEpochs: ", epochs_ls),
):
    print(label, values)

Train:  0.9241653255662987 Valid:  0.7587159051928758 Test:  0.7157762352560988 Elapsed time:  28.747127532958984 epochs:  167.4

Train:  [0.9476531146396789, 0.9164191054084234, 0.9156008112918581, 0.9170449102479301, 0.9125116594590011, 0.9224971162642696, 0.9315406019669067, 0.9367283629322729, 0.9440394088669951, 0.8976181645856512]

valid:  [0.7317073170731707, 0.7312444836716682, 0.8039019509754878, 0.7498822421102214, 0.7917030567685589, 0.7539079946404645, 0.7694128787878788, 0.7618097139055223, 0.7324988409828466, 0.761090573012939]

Test:  [0.7590618336886994, 0.7686274509803922, 0.7471861471861472, 0.6633966891477621, 0.628657487091222, 0.7633769322235434, 0.6569931479242241, 0.746031746031746, 0.7604422604422605, 0.6639886578449905]

Elapsed time:  [42.032065629959106, 21.597612619400024, 23.52042269706726, 20.63997197151184, 23.593926191329956, 27.201050758361816, 33.16026282310486, 32.10476064682007, 41.66835808753967, 21.95284390449524]

Epochs:  [249, 125, 137, 119, 137