In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_CPHMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [None]:
# Dataset
# Seed torch's global RNG once for the whole cell.
torch.manual_seed(0)

# Load the survival table: the first CSV column becomes the index and the
# 'index' and 'y' columns are discarded.
data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
# Set the survival targets aside so they are neither one-hot encoded nor scaled.
data_df_event_time = data_df[['event', 'time']]


# One-hot encode the categorical feature columns, then min-max scale every column.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
# NOTE(review): the scaler (and the mean imputation below) is fitted on every row,
# before the train/validation/test split performed later in this cell, so held-out
# rows influence the preprocessing statistics -- consider fitting on the training
# split only.
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)

# The frame rebuilt above has a fresh 0..n-1 index, while data_df_event_time still
# carries the index read from the CSV. Both targets are therefore attached
# positionally: assigning the 'time' Series directly would align on index labels
# and could misplace values (or introduce NaNs that the fillna below would then
# silently replace with the mean time).
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time'].to_numpy()

# Replace any remaining NaN with its column mean (this applies to every column,
# including 'event' and 'time').
data_df = data_df.fillna(data_df.mean())

# Per-seed results, one entry appended per iteration of the loop below.
train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
elapsed_time_ls = []
epochs_ls = []


def _build_split(frame, batch_size):
    """Turn one data split into tensors, a TensorDataset and a shuffling DataLoader.

    Parameters
    ----------
    frame : pandas.DataFrame
        Split containing the feature columns plus 'event' and 'time'.
    batch_size : int
        Batch size handed to the DataLoader.

    Returns
    -------
    tuple
        (X, e, t, dataset, loader): float32 features, long event flags,
        long times, the TensorDataset over them and its DataLoader.
    """
    X = torch.tensor(frame.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e = torch.tensor(frame['event'].to_numpy(), dtype=torch.long)
    # NOTE(review): casting 'time' to long drops any fractional part -- confirm the
    # times in the CSV are integer-valued.
    t = torch.tensor(frame['time'].to_numpy(), dtype=torch.long)
    dataset = TensorDataset(X, e, t)
    # NOTE(review): shuffle=True is kept for every split to preserve the original
    # behaviour; it is not needed for the concordance index, which does not depend
    # on sample order.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return X, e, t, dataset, loader


def _evaluate_ci(model, loader):
    """Score `model` on every batch of `loader` and compute the concordance index.

    The model outputs are negated before scoring, as in the original code:
    lifelines' concordance_index rewards scores ranked in the same order as the
    survival times, so this presumably treats the output as a risk score
    (higher = earlier event) -- confirm against CPHMLP.

    Returns
    -------
    tuple
        (ci, predicted, events, times) where the last three are flat Python
        lists in the order the loader yielded them.
    """
    predicted, observed_events, observed_times = [], [], []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, events, times in loader:
            predicted.extend(model(inputs).reshape(-1).tolist())
            observed_events.extend(events.tolist())
            observed_times.extend(times.tolist())
    ci = concordance_index(observed_times,
                           [-score for score in predicted],
                           observed_events)
    return ci, predicted, observed_events, observed_times


# NOTE(review): seeds, test_size, batch_size, hidden_size, lr, max_epochs and l2_reg
# are not defined in the visible cells; presumably they come from the star imports
# of utils / utils_CPHMLP -- confirm.
for seed in seeds:
    print("*******************")
    print(seed)
    # Hold out `test_size` of the rows, then halve the hold-out into validation and test.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)

    X_train, e_train, t_train, train_dataset, train_loader = _build_split(data_train, batch_size)
    X_val, e_val, t_val, val_dataset, val_loader = _build_split(data_val, batch_size)
    X_test, e_test, t_test, test_dataset, test_loader = _build_split(data_test, batch_size)


    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features

    model = CPHMLP(input_size, hidden_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model
    start = time.time()
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    elapsed_time_ls.append(elapsed_time)
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epochs_ls.append(epoch)

    # Switch to inference mode before scoring.
    model.eval()

    ## Training
    train_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, train_loader)
    print("train CI lifelines: ", train_ci)
    train_ci_ls.append(train_ci)

    ## Validation
    valid_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, val_loader)
    print("valid CI lifelines: ", valid_ci)
    valid_ci_ls.append(valid_ci)

    ## test
    test_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, test_loader)

    # Not read anywhere in this cell; kept in case a later cell relies on it.
    sorted_indices = np.argsort(times_ls).tolist()

    print("test CI lifelines: ", test_ci)
    test_ci_ls.append(test_ci)

In [71]:
# Show the layer layout of the most recently built model (the `model` variable
# left behind by the last iteration of the per-seed training loop).
print(model)

CPHMLP(
  (fc1): Linear(in_features=1056, out_features=64, bias=True)
  (bc1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): SELU()
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (bc2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): SELU()
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)


In [72]:
# scaled and one hot
# Headline row: the mean of each per-seed metric, all printed on a single line.
headline = []
for caption, per_seed in (
    ("Train: ", train_ci_ls),
    ("Valid: ", valid_ci_ls),
    ("Test: ", test_ci_ls),
    ("Elapsed time: ", elapsed_time_ls),
    ("epochs: ", epochs_ls),
):
    headline.extend((caption, np.mean(per_seed)))
print(*headline)

# Raw per-seed values, one metric per row, each preceded by a blank line.
for caption, per_seed in (
    ("\nTrain: ", train_ci_ls),
    ("\nvalid: ", valid_ci_ls),
    ("\nTest: ", test_ci_ls),
    ("\nElapsed time: ", elapsed_time_ls),
    ("\nEpochs: ", epochs_ls),
):
    print(caption, per_seed)

Train:  0.919216835578251 Valid:  0.7545780608780869 Test:  0.7188251843954966 Elapsed time:  52.940882730484006 epochs:  144.9

Train:  [0.9111653967700033, 0.9385626698691613, 0.9225754377250714, 0.900326163811783, 0.9390952259815144, 0.9230141999124936, 0.9003164893052104, 0.8893006599904228, 0.944768472906404, 0.9230436395104463]

valid:  [0.7560975609756098, 0.7308031774051191, 0.7978989494747374, 0.7757889778615167, 0.7737991266375546, 0.8075033497096918, 0.7694128787878788, 0.635395874916833, 0.7158089939731108, 0.783271719038817]

Test:  [0.7555081734186212, 0.8008714596949891, 0.6813852813852814, 0.6407112201103617, 0.6609294320137694, 0.7615933412604042, 0.6948810963321241, 0.7450185748058088, 0.7512285012285013, 0.696124763705104]

Elapsed time:  [50.661893367767334, 67.56466341018677, 43.881601095199585, 33.866379499435425, 65.96821093559265, 54.37343406677246, 35.562692165374756, 36.36761474609375, 85.19462847709656, 55.9677095413208]

Epochs:  [140, 185, 119, 91, 179, 150