In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import json
import pickle
import json
from ast import literal_eval

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_XGBMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


In [2]:
# Experiment configuration shared by the cells below.

# Data splitting: one complete run per seed. `test_size` is the fraction held out
# from training; that hold-out is later halved into validation and test sets.
test_size = 0.3
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]

# Optimisation settings handed to the DataLoaders, the Adam optimizer and train().
batch_size = 64
lr = 1e-3
l2_reg = 0.001
max_epochs = 250

# How many of the highest-|weight| concepts (feature groups) are given to the model.
ntopfeatures = 5

In [None]:
# Dataset preparation and per-seed training / evaluation of the XGBMLP model.
#
# For every seed this cell: loads the concept weights saved by a previous run,
# keeps the `ntopfeatures` concepts with the largest absolute weight, trains an
# XGBMLP restricted to those concepts, records the concordance index (CI) on the
# train / validation / test splits, and saves the model and its concept weights.
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]

# One-hot encode the feature columns and scale them to [0, 1].
# NOTE(review): the scaler is fitted on the full dataset before splitting, so
# validation/test statistics leak into training — consider fitting per split.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)

# The scaled frame has a fresh 0..n-1 index, so the targets are re-attached
# positionally. Assigning the `time` Series directly would align on index labels
# and silently misplace values (or produce NaN) if the CSV index is not 0..n-1.
data_df['event'] = data_df_event_time['event'].astype('int64').to_numpy()
data_df['time'] = data_df_event_time['time'].to_numpy()

# NOTE(review): the column means used here include `event` and `time`, so missing
# targets would be mean-imputed as well — confirm that is intended.
data_df = data_df.fillna(data_df.mean())


def split_to_tensors(frame):
    """Convert one split into (features, event, time) tensors.

    Parameters
    ----------
    frame : pandas.DataFrame
        A split of `data_df`; must contain the `event` and `time` columns.

    Returns
    -------
    tuple of torch.Tensor
        Features as float32, events as long, times as long.
    """
    features = torch.tensor(frame.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    events = torch.tensor(frame['event'].to_numpy(), dtype=torch.long)
    # NOTE(review): casting to long truncates fractional survival times — confirm
    # that `time` is integral in the source data.
    times = torch.tensor(frame['time'].to_numpy(), dtype=torch.long)
    return features, events, times


def evaluate_ci(model, loader):
    """Return the lifelines concordance index of `model` over all batches of `loader`.

    The model output is negated before scoring, exactly as in the original
    evaluation code (presumably because the model emits a risk score where higher
    means shorter survival — confirm against the training loss).
    """
    predicted, observed_events, observed_times = [], [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, events, times in loader:
            predicted.extend(model(inputs).reshape(-1).tolist())
            observed_events.extend(events.tolist())
            observed_times.extend(times.tolist())
    return concordance_index(observed_times, [-p for p in predicted], observed_events)


train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
epoch_ls = []
elapsed_time_ls = []
nconcepts_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    # NOTE(review): pickle.load can execute arbitrary code; only load files that
    # were produced by this project.
    with open('./../models/XGBMLP/concept_weights_seed'+str(seed)+'.pkl','rb') as f:
        loaded_concepts = pickle.load(f)

    # Rank the previously learned concepts by absolute weight, largest first.
    concepts_weights_df = pd.DataFrame(loaded_concepts)
    concepts_weights_df['abs_weights'] = [np.abs(i) for i in concepts_weights_df['weights']]
    concepts_weights_df['concepts'] = ['concept'+str(i) for i in range(len(concepts_weights_df))]
    concepts_weights_df = concepts_weights_df.sort_values('abs_weights',ascending=False)

    # Hold out `test_size` of the rows, then halve the hold-out into validation and test.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)

    # shuffle=True is kept on all three loaders to leave the behaviour of earlier
    # runs untouched; the concordance index itself does not depend on sample order.
    X_train, e_train, t_train = split_to_tensors(data_train)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    X_val, e_val, t_val = split_to_tensors(data_val)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    X_test, e_test, t_test = split_to_tensors(data_test)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # Feature-index groups of the top-ranked concepts, passed to the MLP.
    feature_group_ls = concepts_weights_df['feature_groups_idx'].to_list()[:ntopfeatures]
    nconcepts_ls.append(len(feature_group_ls))

    # parameters
    input_size = X_train.shape[1]  # Number of input feature columns

    model = XGBMLP(input_size, feature_group_ls)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model and time it.
    start = time.time()
    print("Starting Training")
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg)
    end = time.time()
    elapsed_time = end-start

    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epoch_ls.append(epoch)
    elapsed_time_ls.append(elapsed_time)

    model.eval()

    # Each CI is computed once, then both printed and recorded.
    train_ci = evaluate_ci(model, train_loader)
    print("train CI lifelines: ", train_ci)
    train_ci_ls.append(train_ci)

    valid_ci = evaluate_ci(model, val_loader)
    print("valid CI lifelines: ", valid_ci)
    valid_ci_ls.append(valid_ci)

    test_ci = evaluate_ci(model, test_loader)
    print("test CI lifelines: ", test_ci)
    test_ci_ls.append(test_ci)

    torch.save(model, './../models/XGBMLP/ntopfeatures/XGBMLP_top'+str(ntopfeatures)+'_seed'+str(seed)+'.pt')

    # Persist the retained concepts together with the weights of the final layer.
    # Column positions in `data_train` match the feature tensor because `event`
    # and `time` are the last two columns.
    dict_to_save = {'feature_groups_idx': feature_group_ls,
                    'weights': model.fc2.weight[0].detach().numpy(),
                    'feature_groups':[[data_train.columns[feature] for feature in feature_group] for feature_group in feature_group_ls]}

    with open('./../models/XGBMLP/ntopfeatures/concept_weights_top'+str(ntopfeatures)+'_seed'+str(seed)+'.pkl','wb') as f:
        pickle.dump(dict_to_save,f)

In [4]:
# Show the architecture of `model`. The name is (re)assigned inside the per-seed
# training loop, so when cells are run in order this is the network built for the
# last entry of `seeds`.
print(model)


XGBMLP(
  (intermitten_layers): ModuleList(
    (0-4): 5 x Linear(in_features=5, out_features=5, bias=True)
  )
  (bc1): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (group_layers): ModuleList(
    (0-4): 5 x Linear(in_features=5, out_features=1, bias=True)
  )
  (bc2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (fc2): Linear(in_features=5, out_features=1, bias=True)
)


In [5]:
# Headline: the mean of every tracked quantity across seeds, on a single line.
mean_report = (
    ("Train:", train_ci_ls),
    ("Valid:", valid_ci_ls),
    ("Test:", test_ci_ls),
    ("Epochs:", epoch_ls),
    ("Elapsed time:", elapsed_time_ls),
    ("nconcepts:", nconcepts_ls),
)
headline = []
for label, values in mean_report:
    headline.extend((label, np.mean(values)))
print(*headline)

# Per-seed values, one blank-line-separated row per tracked quantity.
per_seed_report = (
    ("\nTrain: ", train_ci_ls),
    ("\nValid: ", valid_ci_ls),
    ("\nTest: ", test_ci_ls),
    ("\nEpoch: ", epoch_ls),
    ("\nTime: ", elapsed_time_ls),
    ("\nnconcepts: ", nconcepts_ls),
)
for label, values in per_seed_report:
    print(label, values)

Train: 0.7244263046585604 Valid: 0.6583443028966569 Test: 0.6394639951728377 Epochs: 118.9 Elapsed time: 22.709285163879393 nconcepts: 5.0

Train:  [0.7724461505651525, 0.7395444872848324, 0.7590545966306552, 0.6939534247200237, 0.6592894089714237, 0.718805934529255, 0.7157968520996623, 0.7319439528637754, 0.7257733990147783, 0.7276548399060453]

Valid:  [0.6517615176151762, 0.6129744042365401, 0.727863931965983, 0.6726330664154498, 0.6563318777292576, 0.6185797230906654, 0.7121212121212122, 0.614105123087159, 0.5920259619842374, 0.7250462107208873]

Test:  [0.6368159203980099, 0.7816993464052288, 0.5974025974025974, 0.6204782342121398, 0.7276247848537005, 0.7413793103448276, 0.5912938331318017, 0.6710570753123944, 0.507985257985258, 0.5189035916824196]

Epoch:  [84, 92, 186, 170, 70, 109, 204, 76, 126, 72]

Time:  [15.695184707641602, 17.50842833518982, 34.923418283462524, 34.511085987091064, 13.336848020553589, 20.643424034118652, 38.509018421173096, 14.375809669494629, 23.9325766563