In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json
from ast import literal_eval

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_XGBMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [2]:
# split data into train and test sets
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]
test_size = 0.3
batch_size = 64

l2_reg = 0.001
lr = 1e-3
max_epochs = 500


In [3]:
# Dataset
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]


data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time']

data_df = data_df.fillna(data_df.mean())


train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
epoch_ls = []
elapsed_time_ls = []
nconcepts_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    rules_df = pd.read_csv('./../results/XGBoost/rules_seed'+str(seed)+'_pruning_depth_5.csv',index_col=0)
    
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)
    
    X_train = torch.tensor(data_train.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_train = torch.tensor(data_train['event'].to_numpy(), dtype=torch.long)
    t_train = torch.tensor(data_train['time'].to_numpy(), dtype=torch.long)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    X_val = torch.tensor(data_val.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_val = torch.tensor(data_val['event'].to_numpy(), dtype=torch.long)
    t_val = torch.tensor(data_val['time'].to_numpy(), dtype=torch.long)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    
    X_test = torch.tensor(data_test.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e_test = torch.tensor(data_test['event'].to_numpy(), dtype=torch.long)
    t_test = torch.tensor(data_test['time'].to_numpy(), dtype=torch.long)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
    # create dict of feature index so groups can be passed to MLP
    feature_groups = [[int(condition[1:]) for condition in literal_eval(conditions)] for conditions in rules_df['conditions']]
    feature_group_ls = []
    for feature in feature_groups:
        if list(set(feature)) not in feature_group_ls:
            feature_group_ls = feature_group_ls + [list(set(feature))]

    
    print(len(feature_group_ls))

    nconcepts_ls = nconcepts_ls + [len(feature_group_ls)]
    
    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features
    
    model = XGBMLP(input_size, feature_group_ls)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Train the model, 
    start = time.time()
    print("Starting Training")
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start
    
    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epoch_ls = epoch_ls + [epoch]
    elapsed_time_ls = elapsed_time_ls + [elapsed_time]
    
    model.eval()
    
    ## Training
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(train_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    
    print("train CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    train_ci_ls = train_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    
    ## Validation
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(val_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    print("valid CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    valid_ci_ls = valid_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]
    ## test
    events_ls = []
    times_ls = []
    predicted_ls = []
    
    for i, (inputs, events, times) in enumerate(test_loader):
        predicted_ls = predicted_ls + model(inputs).reshape(-1).tolist()
        events_ls = events_ls + events.tolist()
        times_ls = times_ls + times.tolist()
    
    sorted_indices = np.argsort(times_ls).tolist()
    
    print("test CI lifelines: ", concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls))
    test_ci_ls = test_ci_ls + [concordance_index(times_ls, 
                                                     [-i for i in predicted_ls],  
                                                     events_ls)]

    torch.save(model, './../models/XGBMLP/XGBMLP_seed'+str(seed)+'.pt')
    dict_to_save = {'feature_groups_idx': feature_group_ls,
                'weights': model.fc2.weight[0].detach().numpy(),
                'feature_groups':[[data_train.columns[feature] for feature in feature_group] for feature_group in feature_group_ls]}

    with open('./../models/XGBMLP/concept_weights_seed'+str(seed)+'.pkl','wb') as f:
        pickle.dump(dict_to_save,f)
    f.close()

*******************
999
552
Starting Training
Epoch [1/500], Loss: 6.7802
Validation Loss: 6.7239
Epoch [2/500], Loss: 6.1388
Validation Loss: 6.1853
Validation loss improved to 6.185280
Epoch [3/500], Loss: 6.4144
Validation Loss: 6.1579
Validation loss improved to 6.157859
Epoch [4/500], Loss: 6.0962
Validation Loss: 6.3477
EarlyStopping counter: 1 out of 50
Epoch [5/500], Loss: 6.1680
Validation Loss: 6.8256
EarlyStopping counter: 2 out of 50
Epoch [6/500], Loss: 5.9345
Validation Loss: 5.8793
Validation loss improved to 5.879303
Epoch [7/500], Loss: 5.6269
Validation Loss: 5.6436
Validation loss improved to 5.643602
Epoch [8/500], Loss: 5.6174
Validation Loss: 5.5381
Validation loss improved to 5.538050
Epoch [9/500], Loss: 5.7856
Validation Loss: 6.1839
EarlyStopping counter: 1 out of 50
Epoch [10/500], Loss: 5.6843
Validation Loss: 6.5749
EarlyStopping counter: 2 out of 50
Epoch [11/500], Loss: 5.4780
Validation Loss: 5.5108
Validation loss improved to 5.510816
Epoch [12/500], Lo



Epoch [14/500], Loss: 5.2056
Validation Loss: 5.3864
EarlyStopping counter: 1 out of 50
Epoch [15/500], Loss: 5.3043
Validation Loss: 6.0510
EarlyStopping counter: 2 out of 50
Epoch [16/500], Loss: 5.4437
Validation Loss: 5.2093
Validation loss improved to 5.209267
Epoch [17/500], Loss: 5.5011
Validation Loss: 5.4859
EarlyStopping counter: 1 out of 50
Epoch [18/500], Loss: 5.7115
Validation Loss: 5.1033
Validation loss improved to 5.103319
Epoch [19/500], Loss: 5.1411
Validation Loss: 5.0803
Validation loss improved to 5.080275
Epoch [20/500], Loss: 5.0028
Validation Loss: 5.7228
EarlyStopping counter: 1 out of 50
Epoch [21/500], Loss: 4.8583
Validation Loss: 5.0278
Validation loss improved to 5.027834
Epoch [22/500], Loss: 4.9019
Validation Loss: 4.8005
Validation loss improved to 4.800531
Epoch [23/500], Loss: 4.7477
Validation Loss: 5.0996
EarlyStopping counter: 1 out of 50
Epoch [24/500], Loss: 5.0758
Validation Loss: 6.0325
EarlyStopping counter: 2 out of 50
Epoch [25/500], Loss: 



Epoch [9/500], Loss: 6.2645
Validation Loss: 6.9343
EarlyStopping counter: 4 out of 50
Epoch [10/500], Loss: 6.1880
Validation Loss: 6.8588
EarlyStopping counter: 5 out of 50
Epoch [11/500], Loss: 6.1837
Validation Loss: 6.9576
EarlyStopping counter: 6 out of 50
Epoch [12/500], Loss: 5.9699
Validation Loss: 6.2436
Validation loss improved to 6.243619
Epoch [13/500], Loss: 6.1110
Validation Loss: 6.7083
EarlyStopping counter: 1 out of 50
Epoch [14/500], Loss: 6.0332
Validation Loss: 6.6829
EarlyStopping counter: 2 out of 50
Epoch [15/500], Loss: 5.7764
Validation Loss: 6.1881
Validation loss improved to 6.188141
Epoch [16/500], Loss: 5.7117
Validation Loss: 6.2541
EarlyStopping counter: 1 out of 50
Epoch [17/500], Loss: 5.8446
Validation Loss: 6.0253
Validation loss improved to 6.025305
Epoch [18/500], Loss: 5.4859
Validation Loss: 6.1298
EarlyStopping counter: 1 out of 50
Epoch [19/500], Loss: 5.3222
Validation Loss: 6.4240
EarlyStopping counter: 2 out of 50
Epoch [20/500], Loss: 6.142



Epoch [30/500], Loss: 4.2457
Validation Loss: 4.8182
EarlyStopping counter: 5 out of 50
Epoch [31/500], Loss: 4.9399
Validation Loss: 5.2627
EarlyStopping counter: 6 out of 50
Epoch [32/500], Loss: 4.7983
Validation Loss: 4.7485
Validation loss improved to 4.748529
Epoch [33/500], Loss: 4.6855
Validation Loss: 5.0034
EarlyStopping counter: 1 out of 50
Epoch [34/500], Loss: 4.4635
Validation Loss: 4.7176
Validation loss improved to 4.717603
Epoch [35/500], Loss: 4.7031
Validation Loss: 4.8353
EarlyStopping counter: 1 out of 50
Epoch [36/500], Loss: 4.3372
Validation Loss: 4.8987
EarlyStopping counter: 2 out of 50
Epoch [37/500], Loss: 4.1017
Validation Loss: 4.7974
EarlyStopping counter: 3 out of 50
Epoch [38/500], Loss: 4.0887
Validation Loss: 4.8481
EarlyStopping counter: 4 out of 50
Epoch [39/500], Loss: 4.0712
Validation Loss: 5.5565
EarlyStopping counter: 5 out of 50
Epoch [40/500], Loss: 4.5619
Validation Loss: 4.7728
EarlyStopping counter: 6 out of 50
Epoch [41/500], Loss: 4.2664



Epoch [3/500], Loss: 7.2817
Validation Loss: 7.9189
EarlyStopping counter: 2 out of 50
Epoch [4/500], Loss: 7.1348
Validation Loss: 7.2835
Validation loss improved to 7.283546
Epoch [5/500], Loss: 6.9023
Validation Loss: 7.3881
EarlyStopping counter: 1 out of 50
Epoch [6/500], Loss: 6.4304
Validation Loss: 6.5548
Validation loss improved to 6.554784
Epoch [7/500], Loss: 6.3848
Validation Loss: 7.8577
EarlyStopping counter: 1 out of 50
Epoch [8/500], Loss: 6.2687
Validation Loss: 7.7026
EarlyStopping counter: 2 out of 50
Epoch [9/500], Loss: 6.9028
Validation Loss: 7.4370
EarlyStopping counter: 3 out of 50
Epoch [10/500], Loss: 6.2396
Validation Loss: 7.3221
EarlyStopping counter: 4 out of 50
Epoch [11/500], Loss: 6.4553
Validation Loss: 6.8431
EarlyStopping counter: 5 out of 50
Epoch [12/500], Loss: 6.0844
Validation Loss: 8.1595
EarlyStopping counter: 6 out of 50
Epoch [13/500], Loss: 6.1267
Validation Loss: 6.8943
EarlyStopping counter: 7 out of 50
Epoch [14/500], Loss: 5.9930
Valida



Epoch [9/500], Loss: 5.7804
Validation Loss: 6.4575
EarlyStopping counter: 4 out of 50
Epoch [10/500], Loss: 5.5822
Validation Loss: 5.7984
Validation loss improved to 5.798364
Epoch [11/500], Loss: 5.6551
Validation Loss: 6.1430
EarlyStopping counter: 1 out of 50
Epoch [12/500], Loss: 5.4490
Validation Loss: 6.5222
EarlyStopping counter: 2 out of 50
Epoch [13/500], Loss: 5.6629
Validation Loss: 6.8206
EarlyStopping counter: 3 out of 50
Epoch [14/500], Loss: 5.8418
Validation Loss: 5.6629
Validation loss improved to 5.662866
Epoch [15/500], Loss: 5.3024
Validation Loss: 5.8099
EarlyStopping counter: 1 out of 50
Epoch [16/500], Loss: 5.1869
Validation Loss: 5.6110
Validation loss improved to 5.611028
Epoch [17/500], Loss: 5.4512
Validation Loss: 6.3830
EarlyStopping counter: 1 out of 50
Epoch [18/500], Loss: 5.1364
Validation Loss: 6.4151
EarlyStopping counter: 2 out of 50
Epoch [19/500], Loss: 5.3182
Validation Loss: 5.5685
Validation loss improved to 5.568455
Epoch [20/500], Loss: 5.4



Epoch [13/500], Loss: 5.4301
Validation Loss: 5.9467
Validation loss improved to 5.946710
Epoch [14/500], Loss: 5.2762
Validation Loss: 6.9465
EarlyStopping counter: 1 out of 50
Epoch [15/500], Loss: 5.3830
Validation Loss: 6.2845
EarlyStopping counter: 2 out of 50
Epoch [16/500], Loss: 5.8519
Validation Loss: 6.5387
EarlyStopping counter: 3 out of 50
Epoch [17/500], Loss: 5.5435
Validation Loss: 6.6206
EarlyStopping counter: 4 out of 50
Epoch [18/500], Loss: 5.4614
Validation Loss: 6.1808
EarlyStopping counter: 5 out of 50
Epoch [19/500], Loss: 5.5011
Validation Loss: 6.0098
EarlyStopping counter: 6 out of 50
Epoch [20/500], Loss: 5.1130
Validation Loss: 6.3261
EarlyStopping counter: 7 out of 50
Epoch [21/500], Loss: 5.1261
Validation Loss: 6.5401
EarlyStopping counter: 8 out of 50
Epoch [22/500], Loss: 4.9188
Validation Loss: 5.8948
Validation loss improved to 5.894810
Epoch [23/500], Loss: 5.3724
Validation Loss: 6.3087
EarlyStopping counter: 1 out of 50
Epoch [24/500], Loss: 5.4014



Epoch [50/500], Loss: 4.3675
Validation Loss: 5.4591
EarlyStopping counter: 2 out of 50
Epoch [51/500], Loss: 4.4487
Validation Loss: 5.0610
Validation loss improved to 5.060998
Epoch [52/500], Loss: 4.2362
Validation Loss: 5.2077
EarlyStopping counter: 1 out of 50
Epoch [53/500], Loss: 4.3546
Validation Loss: 5.8030
EarlyStopping counter: 2 out of 50
Epoch [54/500], Loss: 4.4421
Validation Loss: 4.5974
Validation loss improved to 4.597355
Epoch [55/500], Loss: 4.0782
Validation Loss: 5.0959
EarlyStopping counter: 1 out of 50
Epoch [56/500], Loss: 4.3500
Validation Loss: 5.1073
EarlyStopping counter: 2 out of 50
Epoch [57/500], Loss: 4.4521
Validation Loss: 5.0722
EarlyStopping counter: 3 out of 50
Epoch [58/500], Loss: 4.3890
Validation Loss: 4.9894
EarlyStopping counter: 4 out of 50
Epoch [59/500], Loss: 4.2265
Validation Loss: 5.0556
EarlyStopping counter: 5 out of 50
Epoch [60/500], Loss: 3.9276
Validation Loss: 4.7398
EarlyStopping counter: 6 out of 50
Epoch [61/500], Loss: 4.2408



Epoch [11/500], Loss: 6.1032
Validation Loss: 7.4373
EarlyStopping counter: 1 out of 50
Epoch [12/500], Loss: 6.2194
Validation Loss: 7.2567
EarlyStopping counter: 2 out of 50
Epoch [13/500], Loss: 6.1795
Validation Loss: 6.8526
EarlyStopping counter: 3 out of 50
Epoch [14/500], Loss: 6.3266
Validation Loss: 6.5774
EarlyStopping counter: 4 out of 50
Epoch [15/500], Loss: 5.8266
Validation Loss: 5.9697
Validation loss improved to 5.969720
Epoch [16/500], Loss: 5.8485
Validation Loss: 6.6823
EarlyStopping counter: 1 out of 50
Epoch [17/500], Loss: 6.1832
Validation Loss: 6.7999
EarlyStopping counter: 2 out of 50
Epoch [18/500], Loss: 5.7737
Validation Loss: 5.8718
Validation loss improved to 5.871764
Epoch [19/500], Loss: 5.9387
Validation Loss: 6.6788
EarlyStopping counter: 1 out of 50
Epoch [20/500], Loss: 5.5662
Validation Loss: 6.8193
EarlyStopping counter: 2 out of 50
Epoch [21/500], Loss: 5.9530
Validation Loss: 5.5347
Validation loss improved to 5.534741
Epoch [22/500], Loss: 5.70



Epoch [20/500], Loss: 5.3384
Validation Loss: 5.5261
Validation loss improved to 5.526057
Epoch [21/500], Loss: 5.2795
Validation Loss: 6.1581
EarlyStopping counter: 1 out of 50
Epoch [22/500], Loss: 5.6815
Validation Loss: 5.5390
EarlyStopping counter: 2 out of 50
Epoch [23/500], Loss: 5.1662
Validation Loss: 6.4181
EarlyStopping counter: 3 out of 50
Epoch [24/500], Loss: 5.2440
Validation Loss: 5.8391
EarlyStopping counter: 4 out of 50
Epoch [25/500], Loss: 5.0549
Validation Loss: 6.3781
EarlyStopping counter: 5 out of 50
Epoch [26/500], Loss: 5.2053
Validation Loss: 6.1907
EarlyStopping counter: 6 out of 50
Epoch [27/500], Loss: 5.0844
Validation Loss: 6.8138
EarlyStopping counter: 7 out of 50
Epoch [28/500], Loss: 5.1721
Validation Loss: 6.7305
EarlyStopping counter: 8 out of 50
Epoch [29/500], Loss: 4.9055
Validation Loss: 6.6476
EarlyStopping counter: 9 out of 50
Epoch [30/500], Loss: 4.8653
Validation Loss: 6.7971
EarlyStopping counter: 10 out of 50
Epoch [31/500], Loss: 4.5985




Epoch [2/500], Loss: 7.0592
Validation Loss: 7.2300
EarlyStopping counter: 1 out of 50
Epoch [3/500], Loss: 6.8744
Validation Loss: 6.5767
Validation loss improved to 6.576748
Epoch [4/500], Loss: 6.5470
Validation Loss: 6.4929
Validation loss improved to 6.492934
Epoch [5/500], Loss: 6.4576
Validation Loss: 6.2261
Validation loss improved to 6.226100
Epoch [6/500], Loss: 6.7032
Validation Loss: 6.7409
EarlyStopping counter: 1 out of 50
Epoch [7/500], Loss: 6.1852
Validation Loss: 6.4955
EarlyStopping counter: 2 out of 50
Epoch [8/500], Loss: 6.0871
Validation Loss: 6.2460
EarlyStopping counter: 3 out of 50
Epoch [9/500], Loss: 6.0407
Validation Loss: 6.8067
EarlyStopping counter: 4 out of 50
Epoch [10/500], Loss: 6.0233
Validation Loss: 7.1663
EarlyStopping counter: 5 out of 50
Epoch [11/500], Loss: 5.9410
Validation Loss: 6.2926
EarlyStopping counter: 6 out of 50
Epoch [12/500], Loss: 5.9240
Validation Loss: 6.3769
EarlyStopping counter: 7 out of 50
Epoch [13/500], Loss: 5.7604
Valid

In [4]:
print("Train:",np.mean(train_ci_ls), "Valid:",np.mean(valid_ci_ls), "Test:",np.mean(test_ci_ls),
      "Epochs:",np.mean(epoch_ls), "Elapsed time:", np.mean(elapsed_time_ls), "nconcepts:",np.mean(nconcepts_ls))

print("\nTrain: ",train_ci_ls)
print("\nValid: ",valid_ci_ls)
print("\nTest: ",test_ci_ls)
print("\nEpoch: ",epoch_ls)
print("\nTime: ",elapsed_time_ls)
print("\nnconcepts: ",nconcepts_ls)

Train: 0.9000068678593971 Valid: 0.6878435659875906 Test: 0.6879466165347004 Epochs: 184.2 Elapsed time: 1529.5388562202454 nconcepts: 656.1

Train:  [0.8835766494115822, 0.8899563870804627, 0.8883438884059771, 0.9246630020755879, 0.9128932417535827, 0.9108229585139811, 0.9033539369995115, 0.9137848472861277, 0.8953300492610837, 0.8773437178060741]

Valid:  [0.5934959349593496, 0.5834068843777581, 0.7593796898449224, 0.6905322656617994, 0.7786026200873363, 0.6721750781598929, 0.6827651515151515, 0.6254158349966733, 0.7315716272600834, 0.761090573012939]

Test:  [0.697228144989339, 0.6684095860566449, 0.645021645021645, 0.7121397915389331, 0.6437177280550774, 0.8317479191438764, 0.7230955259975816, 0.7443431273218507, 0.6953316953316954, 0.5184310018903592]

Epoch:  [228, 194, 136, 133, 227, 196, 145, 166, 273, 144]

Time:  [1445.7396335601807, 1548.035896062851, 886.264410495758, 1223.123381137848, 1725.9595630168915, 1709.6599242687225, 1532.4133903980255, 1629.3712539672852, 2419.709