In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json
from ast import literal_eval

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_XGBMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [2]:
# split data into train and test sets
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]
test_size = 0.3
batch_size = 64

l2_reg = 0.001
lr = 1e-3
max_epochs = 250
ntopfeatures = 5

In [3]:
# Dataset
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]


data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time']

data_df = data_df.fillna(data_df.mean())


train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
epoch_ls = []
elapsed_time_ls = []
nconcepts_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    with open('./../models/XGBMLP/concept_weights_seed'+str(seed)+'.pkl','rb') as f:
        dict_to_save = pickle.load(f)
    f.close()
    
    concepts_weights_df = pd.DataFrame(dict_to_save)
    concepts_weights_df['abs_weights'] = [np.abs(i) for i in concepts_weights_df['weights']]
    concepts_weights_df['concepts'] = ['concept'+str(i) for i in range(len(concepts_weights_df))]
    concepts_weights_df = concepts_weights_df.sort_values('abs_weights',ascending=False)
    
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)
    
    X_train = torch.tensor(data_train.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_train = torch.tensor(data_train['event'].to_numpy(), dtype=torch.long)
    t_train = torch.tensor(data_train['time'].to_numpy(), dtype=torch.long)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    X_val = torch.tensor(data_val.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_val = torch.tensor(data_val['event'].to_numpy(), dtype=torch.long)
    t_val = torch.tensor(data_val['time'].to_numpy(), dtype=torch.long)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    
    X_test = torch.tensor(data_test.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_test = torch.tensor(data_test['event'].to_numpy(), dtype=torch.long)
    t_test = torch.tensor(data_test['time'].to_numpy(), dtype=torch.long)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
    # create dict of feature index so groups can be passed to MLP
    feature_group_ls = concepts_weights_df['feature_groups_idx'].to_list()[:ntopfeatures]
    nconcepts_ls = nconcepts_ls + [len(feature_group_ls)]
    
    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features
    
    model = XGBMLP(input_size, feature_group_ls)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Train the model, 
    start = time.time()
    print("Starting Training")
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epoch_ls = epoch_ls + [epoch]
    elapsed_time_ls = elapsed_time_ls + [elapsed_time]
    
    model.eval()
    
    ## Training
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(train_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    
    print("train CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    train_ci_ls = train_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    
    ## Validation
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(val_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    print("valid CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    valid_ci_ls = valid_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    ## test
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(test_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    sorted_indices = np.argsort(times_ls).tolist()
    
    print("test CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    test_ci_ls = test_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]

    torch.save(model, './../models/XGBMLP/ntopfeatures/XGBMLP_top'+str(ntopfeatures)+'_seed'+str(seed)+'.pt')
    
    dict_to_save = {'feature_groups_idx': feature_group_ls,
                    'weights': model.fc2.weight[0].detach().numpy(),
                    'feature_groups':[[data_train.columns[feature] for feature in feature_group] for feature_group in feature_group_ls]}

    with open('./../models/XGBMLP/ntopfeatures/concept_weights_top'+str(ntopfeatures)+'_seed'+str(seed)+'.pkl','wb') as f:
        pickle.dump(dict_to_save,f)
    f.close()

*******************
999
Starting Training
Epoch [1/250], Loss: 3.1878
Validation Loss: 3.2888
Epoch [2/250], Loss: 3.3668
Validation Loss: 3.4135
EarlyStopping counter: 1 out of 50
Epoch [3/250], Loss: 3.2278
Validation Loss: 2.9001
Validation loss improved to 2.900075
Epoch [4/250], Loss: 3.1038
Validation Loss: 2.9444
EarlyStopping counter: 1 out of 50
Epoch [5/250], Loss: 3.2670
Validation Loss: 2.8557
Validation loss improved to 2.855653
Epoch [6/250], Loss: 3.2053
Validation Loss: 3.4051
EarlyStopping counter: 1 out of 50
Epoch [7/250], Loss: 3.0103
Validation Loss: 3.2279
EarlyStopping counter: 2 out of 50




Epoch [8/250], Loss: 2.9432
Validation Loss: 3.0929
EarlyStopping counter: 3 out of 50
Epoch [9/250], Loss: 3.3235
Validation Loss: 3.2662
EarlyStopping counter: 4 out of 50
Epoch [10/250], Loss: 2.9674
Validation Loss: 2.8473
Validation loss improved to 2.847273
Epoch [11/250], Loss: 2.9881
Validation Loss: 3.2528
EarlyStopping counter: 1 out of 50
Epoch [12/250], Loss: 3.0736
Validation Loss: 2.7229
Validation loss improved to 2.722892
Epoch [13/250], Loss: 3.0902
Validation Loss: 2.8010
EarlyStopping counter: 1 out of 50
Epoch [14/250], Loss: 2.8465
Validation Loss: 2.7329
EarlyStopping counter: 2 out of 50
Epoch [15/250], Loss: 2.9676
Validation Loss: 3.1224
EarlyStopping counter: 3 out of 50
Epoch [16/250], Loss: 3.1253
Validation Loss: 3.1252
EarlyStopping counter: 4 out of 50
Epoch [17/250], Loss: 3.2234
Validation Loss: 3.1106
EarlyStopping counter: 5 out of 50
Epoch [18/250], Loss: 3.0562
Validation Loss: 2.7361
EarlyStopping counter: 6 out of 50
Epoch [19/250], Loss: 3.0364
V



Epoch [11/250], Loss: 3.2224
Validation Loss: 3.4495
EarlyStopping counter: 2 out of 50
Epoch [12/250], Loss: 2.9359
Validation Loss: 3.0775
Validation loss improved to 3.077488
Epoch [13/250], Loss: 3.3320
Validation Loss: 2.9880
Validation loss improved to 2.987986
Epoch [14/250], Loss: 3.1730
Validation Loss: 3.1039
EarlyStopping counter: 1 out of 50
Epoch [15/250], Loss: 3.2163
Validation Loss: 3.8879
EarlyStopping counter: 2 out of 50
Epoch [16/250], Loss: 3.1148
Validation Loss: 3.1126
EarlyStopping counter: 3 out of 50
Epoch [17/250], Loss: 3.0976
Validation Loss: 3.3574
EarlyStopping counter: 4 out of 50
Epoch [18/250], Loss: 3.2191
Validation Loss: 3.1406
EarlyStopping counter: 5 out of 50
Epoch [19/250], Loss: 2.9807
Validation Loss: 3.2429
EarlyStopping counter: 6 out of 50
Epoch [20/250], Loss: 3.4737
Validation Loss: 3.5109
EarlyStopping counter: 7 out of 50
Epoch [21/250], Loss: 2.9286
Validation Loss: 3.0417
EarlyStopping counter: 8 out of 50
Epoch [22/250], Loss: 2.8551



Epoch [8/250], Loss: 2.9062
Validation Loss: 3.4486
EarlyStopping counter: 5 out of 50
Epoch [9/250], Loss: 3.0315
Validation Loss: 3.6522
EarlyStopping counter: 6 out of 50
Epoch [10/250], Loss: 3.4762
Validation Loss: 3.2346
Validation loss improved to 3.234642
Epoch [11/250], Loss: 3.3136
Validation Loss: 3.6822
EarlyStopping counter: 1 out of 50
Epoch [12/250], Loss: 3.1408
Validation Loss: 3.6635
EarlyStopping counter: 2 out of 50
Epoch [13/250], Loss: 3.1672
Validation Loss: 3.6791
EarlyStopping counter: 3 out of 50
Epoch [14/250], Loss: 3.2297
Validation Loss: 3.6608
EarlyStopping counter: 4 out of 50
Epoch [15/250], Loss: 3.1383
Validation Loss: 3.1291
Validation loss improved to 3.129067
Epoch [16/250], Loss: 3.3396
Validation Loss: 3.2018
EarlyStopping counter: 1 out of 50
Epoch [17/250], Loss: 3.5199
Validation Loss: 3.3736
EarlyStopping counter: 2 out of 50
Epoch [18/250], Loss: 3.0897
Validation Loss: 3.5446
EarlyStopping counter: 3 out of 50
Epoch [19/250], Loss: 3.1901
V



Validation Loss: 3.2897
EarlyStopping counter: 9 out of 50
Epoch [12/250], Loss: 2.9545
Validation Loss: 3.3671
EarlyStopping counter: 10 out of 50
Epoch [13/250], Loss: 3.0154
Validation Loss: 3.2248
EarlyStopping counter: 11 out of 50
Epoch [14/250], Loss: 3.2945
Validation Loss: 3.6304
EarlyStopping counter: 12 out of 50
Epoch [15/250], Loss: 3.3106
Validation Loss: 3.3071
EarlyStopping counter: 13 out of 50
Epoch [16/250], Loss: 3.2704
Validation Loss: 3.4483
EarlyStopping counter: 14 out of 50
Epoch [17/250], Loss: 3.2388
Validation Loss: 3.6261
EarlyStopping counter: 15 out of 50
Epoch [18/250], Loss: 3.2437
Validation Loss: 3.4918
EarlyStopping counter: 16 out of 50
Epoch [19/250], Loss: 3.0212
Validation Loss: 3.3738
EarlyStopping counter: 17 out of 50
Epoch [20/250], Loss: 2.9695
Validation Loss: 3.6761
EarlyStopping counter: 18 out of 50
Epoch [21/250], Loss: 3.0412
Validation Loss: 3.5374
EarlyStopping counter: 19 out of 50
Epoch [22/250], Loss: 3.2451
Validation Loss: 3.158



Epoch [4/250], Loss: 3.2264
Validation Loss: 3.1850
EarlyStopping counter: 2 out of 50
Epoch [5/250], Loss: 3.3838
Validation Loss: 3.4961
EarlyStopping counter: 3 out of 50
Epoch [6/250], Loss: 3.4914
Validation Loss: 3.7626
EarlyStopping counter: 4 out of 50
Epoch [7/250], Loss: 3.5051
Validation Loss: 3.4264
EarlyStopping counter: 5 out of 50
Epoch [8/250], Loss: 3.4481
Validation Loss: 3.2900
EarlyStopping counter: 6 out of 50
Epoch [9/250], Loss: 3.3013
Validation Loss: 3.0944
EarlyStopping counter: 7 out of 50
Epoch [10/250], Loss: 3.1756
Validation Loss: 3.2030
EarlyStopping counter: 8 out of 50
Epoch [11/250], Loss: 3.6024
Validation Loss: 3.4974
EarlyStopping counter: 9 out of 50
Epoch [12/250], Loss: 3.2276
Validation Loss: 3.2825
EarlyStopping counter: 10 out of 50
Epoch [13/250], Loss: 3.2416
Validation Loss: 3.2052
EarlyStopping counter: 11 out of 50
Epoch [14/250], Loss: 3.2809
Validation Loss: 2.8493
EarlyStopping counter: 12 out of 50
Epoch [15/250], Loss: 3.2214
Valida



Epoch [8/250], Loss: 2.9054
Validation Loss: 4.0065
EarlyStopping counter: 6 out of 50
Epoch [9/250], Loss: 3.2018
Validation Loss: 4.4325
EarlyStopping counter: 7 out of 50
Epoch [10/250], Loss: 3.2070
Validation Loss: 4.2467
EarlyStopping counter: 8 out of 50
Epoch [11/250], Loss: 3.2652
Validation Loss: 3.9538
EarlyStopping counter: 9 out of 50
Epoch [12/250], Loss: 3.0332
Validation Loss: 4.0507
EarlyStopping counter: 10 out of 50
Epoch [13/250], Loss: 3.0868
Validation Loss: 4.0703
EarlyStopping counter: 11 out of 50
Epoch [14/250], Loss: 3.3319
Validation Loss: 3.4474
EarlyStopping counter: 12 out of 50
Epoch [15/250], Loss: 2.8490
Validation Loss: 3.5286
EarlyStopping counter: 13 out of 50
Epoch [16/250], Loss: 3.2429
Validation Loss: 3.4014
EarlyStopping counter: 14 out of 50
Epoch [17/250], Loss: 3.0273
Validation Loss: 3.5957
EarlyStopping counter: 15 out of 50
Epoch [18/250], Loss: 3.3821
Validation Loss: 3.1931
EarlyStopping counter: 16 out of 50
Epoch [19/250], Loss: 3.165



Epoch [4/250], Loss: 3.2127
Validation Loss: 3.0905
EarlyStopping counter: 3 out of 50
Epoch [5/250], Loss: 3.1195
Validation Loss: 2.9267
EarlyStopping counter: 4 out of 50
Epoch [6/250], Loss: 3.0624
Validation Loss: 2.9930
EarlyStopping counter: 5 out of 50
Epoch [7/250], Loss: 3.1027
Validation Loss: 3.8794
EarlyStopping counter: 6 out of 50
Epoch [8/250], Loss: 3.2243
Validation Loss: 3.2536
EarlyStopping counter: 7 out of 50
Epoch [9/250], Loss: 3.0398
Validation Loss: 2.8109
Validation loss improved to 2.810885
Epoch [10/250], Loss: 3.1686
Validation Loss: 3.2297
EarlyStopping counter: 1 out of 50
Epoch [11/250], Loss: 3.0965
Validation Loss: 2.8634
EarlyStopping counter: 2 out of 50
Epoch [12/250], Loss: 3.2055
Validation Loss: 2.9536
EarlyStopping counter: 3 out of 50
Epoch [13/250], Loss: 3.1801
Validation Loss: 2.7357
Validation loss improved to 2.735699
Epoch [14/250], Loss: 2.9544
Validation Loss: 2.9453
EarlyStopping counter: 1 out of 50
Epoch [15/250], Loss: 3.3479
Valid



Epoch [14/250], Loss: 3.3301
Validation Loss: 2.7170
EarlyStopping counter: 4 out of 50
Epoch [15/250], Loss: 2.9706
Validation Loss: 2.8952
EarlyStopping counter: 5 out of 50
Epoch [16/250], Loss: 3.1138
Validation Loss: 3.1255
EarlyStopping counter: 6 out of 50
Epoch [17/250], Loss: 3.1806
Validation Loss: 3.0325
EarlyStopping counter: 7 out of 50
Epoch [18/250], Loss: 3.1144
Validation Loss: 2.8803
EarlyStopping counter: 8 out of 50
Epoch [19/250], Loss: 3.0702
Validation Loss: 2.5695
Validation loss improved to 2.569535
Epoch [20/250], Loss: 3.3480
Validation Loss: 3.0780
EarlyStopping counter: 1 out of 50
Epoch [21/250], Loss: 3.1389
Validation Loss: 2.6645
EarlyStopping counter: 2 out of 50
Epoch [22/250], Loss: 3.0816
Validation Loss: 3.3649
EarlyStopping counter: 3 out of 50
Epoch [23/250], Loss: 3.0280
Validation Loss: 3.5242
EarlyStopping counter: 4 out of 50
Epoch [24/250], Loss: 3.2174
Validation Loss: 3.2704
EarlyStopping counter: 5 out of 50
Epoch [25/250], Loss: 2.8706
V



Epoch [16/250], Loss: 2.8261
Validation Loss: 3.3541
EarlyStopping counter: 14 out of 50
Epoch [17/250], Loss: 2.9951
Validation Loss: 3.0078
EarlyStopping counter: 15 out of 50
Epoch [18/250], Loss: 2.9273
Validation Loss: 3.4884
EarlyStopping counter: 16 out of 50
Epoch [19/250], Loss: 3.0245
Validation Loss: 3.7010
EarlyStopping counter: 17 out of 50
Epoch [20/250], Loss: 3.0107
Validation Loss: 3.1831
EarlyStopping counter: 18 out of 50
Epoch [21/250], Loss: 3.1192
Validation Loss: 3.4876
EarlyStopping counter: 19 out of 50
Epoch [22/250], Loss: 3.2180
Validation Loss: 2.8644
Validation loss improved to 2.864392
Epoch [23/250], Loss: 2.8424
Validation Loss: 2.9642
EarlyStopping counter: 1 out of 50
Epoch [24/250], Loss: 3.0831
Validation Loss: 2.8014
Validation loss improved to 2.801360
Epoch [25/250], Loss: 3.1203
Validation Loss: 3.2663
EarlyStopping counter: 1 out of 50
Epoch [26/250], Loss: 2.8460
Validation Loss: 3.6605
EarlyStopping counter: 2 out of 50
Epoch [27/250], Loss: 



Validation Loss: 3.4185
EarlyStopping counter: 5 out of 50
Epoch [21/250], Loss: 3.0227
Validation Loss: 2.8345
EarlyStopping counter: 6 out of 50
Epoch [22/250], Loss: 2.8597
Validation Loss: 2.8459
EarlyStopping counter: 7 out of 50
Epoch [23/250], Loss: 3.0316
Validation Loss: 2.7351
Validation loss improved to 2.735065
Epoch [24/250], Loss: 2.8565
Validation Loss: 2.7763
EarlyStopping counter: 1 out of 50
Epoch [25/250], Loss: 2.7227
Validation Loss: 3.4390
EarlyStopping counter: 2 out of 50
Epoch [26/250], Loss: 3.0206
Validation Loss: 2.8223
EarlyStopping counter: 3 out of 50
Epoch [27/250], Loss: 3.0659
Validation Loss: 2.7784
EarlyStopping counter: 4 out of 50
Epoch [28/250], Loss: 3.0456
Validation Loss: 2.8154
EarlyStopping counter: 5 out of 50
Epoch [29/250], Loss: 3.0177
Validation Loss: 2.8117
EarlyStopping counter: 6 out of 50
Epoch [30/250], Loss: 2.7272
Validation Loss: 2.8003
EarlyStopping counter: 7 out of 50
Epoch [31/250], Loss: 2.8514
Validation Loss: 3.5465
EarlyS

In [4]:
print(model)


XGBMLP(
  (intermitten_layers): ModuleList(
    (0-4): 5 x Linear(in_features=5, out_features=5, bias=True)
  )
  (bc1): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (group_layers): ModuleList(
    (0-4): 5 x Linear(in_features=5, out_features=1, bias=True)
  )
  (bc2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (fc2): Linear(in_features=5, out_features=1, bias=True)
)


In [5]:
print("Train:",np.mean(train_ci_ls), "Valid:",np.mean(valid_ci_ls), "Test:",np.mean(test_ci_ls),
      "Epochs:",np.mean(epoch_ls), "Elapsed time:", np.mean(elapsed_time_ls), "nconcepts:",np.mean(nconcepts_ls))

print("\nTrain: ",train_ci_ls)
print("\nValid: ",valid_ci_ls)
print("\nTest: ",test_ci_ls)
print("\nEpoch: ",epoch_ls)
print("\nTime: ",elapsed_time_ls)
print("\nnconcepts: ",nconcepts_ls)

Train: 0.7244263046585604 Valid: 0.6583443028966569 Test: 0.6394639951728377 Epochs: 118.9 Elapsed time: 22.709285163879393 nconcepts: 5.0

Train:  [0.7724461505651525, 0.7395444872848324, 0.7590545966306552, 0.6939534247200237, 0.6592894089714237, 0.718805934529255, 0.7157968520996623, 0.7319439528637754, 0.7257733990147783, 0.7276548399060453]

Valid:  [0.6517615176151762, 0.6129744042365401, 0.727863931965983, 0.6726330664154498, 0.6563318777292576, 0.6185797230906654, 0.7121212121212122, 0.614105123087159, 0.5920259619842374, 0.7250462107208873]

Test:  [0.6368159203980099, 0.7816993464052288, 0.5974025974025974, 0.6204782342121398, 0.7276247848537005, 0.7413793103448276, 0.5912938331318017, 0.6710570753123944, 0.507985257985258, 0.5189035916824196]

Epoch:  [84, 92, 186, 170, 70, 109, 204, 76, 126, 72]

Time:  [15.695184707641602, 17.50842833518982, 34.923418283462524, 34.511085987091064, 13.336848020553589, 20.643424034118652, 38.509018421173096, 14.375809669494629, 23.9325766563