In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json
from ast import literal_eval

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_XGBMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [2]:
# Experiment configuration: data-split seeds and training hyperparameters.
# One train/validation/test split (and one trained model) is produced per seed.
seeds = [
    999, 7, 42, 1995, 1303,
    2405, 1996, 200, 0, 777,
]
test_size = 0.3    # fraction held out from training; the hold-out is later halved into validation and test
batch_size = 64    # mini-batch size used for every DataLoader

# Optimisation settings
lr = 1e-3          # learning rate given to the Adam optimizer
l2_reg = 0.001     # regularisation strength forwarded to train()
max_epochs = 500   # epoch budget forwarded to train()


In [None]:
# Dataset
# Seed torch's global RNG once for the whole cell (it is not re-seeded per split).
torch.manual_seed(0)

# Load the survival table; 'index' and 'y' are dropped and not used below.
data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
# Keep the survival labels aside so they are neither one-hot encoded nor rescaled.
data_df_event_time = data_df[['event', 'time']]

# One-hot encode non-numeric feature columns, then rescale every feature to [0, 1].
# NOTE(review): the scaler (and the mean imputation below) is fit on all rows before the
# train/validation/test split, so hold-out statistics leak into the training features.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)

# fit_transform returns a bare array, so the rebuilt frame carries a fresh 0..n-1 index while
# data_df_event_time still carries the index read from the CSV. Attach both label columns
# by position: assigning the 'time' Series directly would align on index labels and could
# silently misplace times or turn them into NaN (which the fillna below would then mask).
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time'].to_numpy()

# Impute remaining missing values with the column mean.
# NOTE(review): this applies to every column, including 'time' if the CSV has missing times.
data_df = data_df.fillna(data_df.mean())


# Per-seed result accumulators; the training loop adds one entry per seed to each list.
train_ci_ls, valid_ci_ls, test_ci_ls = [], [], []   # concordance index on each split
epoch_ls, elapsed_time_ls = [], []                  # value returned by train(), wall-clock seconds
nconcepts_ls = []                                   # number of distinct feature groups

def _split_tensors(frame):
    """Turn one data split into tensors plus a shuffled DataLoader.

    Returns (X, events, times, dataset, loader); batches are (X, event, time) triples.
    """
    features = torch.tensor(frame.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    events = torch.tensor(frame['event'].to_numpy(), dtype=torch.long)
    # NOTE(review): casting to long truncates any fractional part of 'time' — confirm times are integral.
    times = torch.tensor(frame['time'].to_numpy(), dtype=torch.long)
    dataset = TensorDataset(features, events, times)
    # shuffle=True is kept for every split, as before; sample order does not matter for the
    # concordance index, which is computed over the whole split.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return features, events, times, dataset, loader


def _unique_feature_groups(feature_groups):
    """Collapse rules that use the same set of feature indices, keeping first-seen order.

    Each group is canonicalised with sorted(set(...)). Comparing list(set(...)) values is
    unreliable because a set's iteration order can depend on insertion order
    (e.g. {0, 8} vs {8, 0}), so identical feature sets could be kept twice.
    """
    seen = set()
    unique_groups = []
    for group in feature_groups:
        key = tuple(sorted(set(group)))
        if key not in seen:
            seen.add(key)
            unique_groups.append(list(key))
    return unique_groups


def _evaluate_ci(model, loader):
    """Score every sample in ``loader`` and return (CI, predictions, events, times).

    The model output is negated before being passed to lifelines' concordance_index,
    as in the original evaluation code.
    """
    predicted, observed_events, observed_times = [], [], []
    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, events, times in loader:
            predicted.extend(model(inputs).reshape(-1).tolist())
            observed_events.extend(events.tolist())
            observed_times.extend(times.tolist())
    ci = concordance_index(observed_times, [-score for score in predicted], observed_events)
    return ci, predicted, observed_events, observed_times


# NOTE(review): torch's RNG is seeded once before this loop, so model initialisation for a
# given seed depends on the seeds processed before it; re-seed per iteration if each seed
# must be reproducible in isolation.
for seed in seeds:
    print("*******************")
    print(seed)
    rules_df = pd.read_csv('./../results/XGBoost/rules_seed'+str(seed)+'_pruning_depth_5.csv',index_col=0)

    # Hold out `test_size` of the rows, then halve the hold-out into validation and test.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)

    X_train, e_train, t_train, train_dataset, train_loader = _split_tensors(data_train)
    X_val, e_val, t_val, val_dataset, val_loader = _split_tensors(data_val)
    X_test, e_test, t_test, test_dataset, test_loader = _split_tensors(data_test)

    # Each rule is a stringified list of conditions; the first character of every condition
    # is stripped and the rest parsed as a feature index
    # (presumably names of the form 'f<idx>' — confirm against the rules files).
    feature_groups = [[int(condition[1:]) for condition in literal_eval(conditions)] for conditions in rules_df['conditions']]
    # One group ("concept") per distinct set of feature indices, passed on to the MLP.
    feature_group_ls = _unique_feature_groups(feature_groups)

    print(len(feature_group_ls))
    nconcepts_ls.append(len(feature_group_ls))

    # parameters
    input_size = X_train.shape[1]  # Number of RNA expression features

    model = XGBMLP(input_size, feature_group_ls)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model and time it
    start = time.time()
    print("Starting Training")
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start

    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epoch_ls.append(epoch)
    elapsed_time_ls.append(elapsed_time)

    model.eval()

    ## Training
    train_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, train_loader)
    print("train CI lifelines: ", train_ci)
    train_ci_ls.append(train_ci)

    ## Validation
    valid_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, val_loader)
    print("valid CI lifelines: ", valid_ci)
    valid_ci_ls.append(valid_ci)

    ## test
    test_ci, predicted_ls, events_ls, times_ls = _evaluate_ci(model, test_loader)
    # Not used in this cell; kept in case a later cell reads it.
    sorted_indices = np.argsort(times_ls).tolist()
    print("test CI lifelines: ", test_ci)
    test_ci_ls.append(test_ci)

    # Persist the whole model plus the concept definitions and their fc2 weights.
    torch.save(model, './../models/XGBMLP/XGBMLP_seed'+str(seed)+'.pt')
    # Resolve indices against the feature-only columns, i.e. the column order of X_train.
    feature_names = data_train.drop(['event', 'time'], axis=1).columns
    dict_to_save = {'feature_groups_idx': feature_group_ls,
                'weights': model.fc2.weight[0].detach().numpy(),
                'feature_groups':[[feature_names[feature] for feature in feature_group] for feature_group in feature_group_ls]}

    # The with-block closes the file; no explicit close() is needed.
    with open('./../models/XGBMLP/concept_weights_seed'+str(seed)+'.pkl','wb') as f:
        pickle.dump(dict_to_save,f)

In [4]:
# Report the mean of every tracked metric across seeds, then the raw per-seed values.
# Each row: (label used in the one-line summary, label used for the per-seed dump, values).
report_rows = [
    ("Train:", "\nTrain: ", train_ci_ls),
    ("Valid:", "\nValid: ", valid_ci_ls),
    ("Test:", "\nTest: ", test_ci_ls),
    ("Epochs:", "\nEpoch: ", epoch_ls),
    ("Elapsed time:", "\nTime: ", elapsed_time_ls),
    ("nconcepts:", "\nnconcepts: ", nconcepts_ls),
]

# Single summary line: label, mean, label, mean, ...
summary_parts = []
for summary_label, _, values in report_rows:
    summary_parts.extend([summary_label, np.mean(values)])
print(*summary_parts)

# One blank-line-separated list per metric
for _, detail_label, values in report_rows:
    print(detail_label, values)

Train: 0.9000068678593971 Valid: 0.6878435659875906 Test: 0.6879466165347004 Epochs: 184.2 Elapsed time: 1529.5388562202454 nconcepts: 656.1

Train:  [0.8835766494115822, 0.8899563870804627, 0.8883438884059771, 0.9246630020755879, 0.9128932417535827, 0.9108229585139811, 0.9033539369995115, 0.9137848472861277, 0.8953300492610837, 0.8773437178060741]

Valid:  [0.5934959349593496, 0.5834068843777581, 0.7593796898449224, 0.6905322656617994, 0.7786026200873363, 0.6721750781598929, 0.6827651515151515, 0.6254158349966733, 0.7315716272600834, 0.761090573012939]

Test:  [0.697228144989339, 0.6684095860566449, 0.645021645021645, 0.7121397915389331, 0.6437177280550774, 0.8317479191438764, 0.7230955259975816, 0.7443431273218507, 0.6953316953316954, 0.5184310018903592]

Epoch:  [228, 194, 136, 133, 227, 196, 145, 166, 273, 144]

Time:  [1445.7396335601807, 1548.035896062851, 886.264410495758, 1223.123381137848, 1725.9595630168915, 1709.6599242687225, 1532.4133903980255, 1629.3712539672852, 2419.709