In [1]:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import re
import sys
import json
import pickle
import json
from ast import literal_eval

from tqdm import tqdm
import itertools

from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score

from torchsurv.loss import cox
from lifelines.utils import concordance_index

sys.path.append('./../src/')
from utils import *
from utils_RuleMLP import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset


# Best configs run

In [2]:
# Experiment configuration shared by every run in this notebook.

# One full train / validation / test repetition is executed per seed.
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]

# Fraction of rows held out from training, and the DataLoader mini-batch size.
test_size, batch_size = 0.3, 64

# Optimisation settings passed to the optimiser and to `train`.
# (A fixed `hidden_size = 128` was tried earlier and is currently disabled.)
lr = 0.001
l2_reg = 1e-3
max_epochs = 250



In [None]:
# Dataset
# Seed torch's global RNG once, before the sweep; every seed's model init and
# DataLoader shuffling draw from this single stream.
torch.manual_seed(0)

data_df = pd.read_csv('./../Data/1000_features_survival_3classes.csv',index_col=0).drop(['index', 'y'],axis=1)
data_df_event_time = data_df[['event', 'time']]

# One-hot encode the covariates, then min-max scale every feature column.
# NOTE(review): the scaler and the mean imputation below are fitted on ALL rows,
# i.e. before the train/val/test split, so held-out statistics leak into the
# preprocessing. Left unchanged to keep results comparable - confirm intent.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1),dtype='int')
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)

# `data_df` now carries a fresh RangeIndex while `data_df_event_time` still has
# the CSV index. Both label columns are therefore attached POSITIONALLY; assigning
# the Series directly would align on index labels and could turn times into NaN
# (which the fillna below would then silently replace with the mean time).
data_df['event'] = [int(e) for e in data_df_event_time['event']]
data_df['time'] = data_df_event_time['time'].to_numpy()

data_df = data_df.fillna(data_df.mean())

# Loop-invariant lookups. They are built from the feature columns only, so the
# 'event' / 'time' label columns can never be picked up as a rule condition.
feature_columns = data_df.drop(['event', 'time'], axis=1).columns.to_list()
clinical_feature_ls = [feature for feature in feature_columns if 'EN' not in feature]
feature_idx_dict = {feature: idx for idx, feature in enumerate(feature_columns)}


def _survival_tensors(frame):
    """Split a data frame into (features, event, time) tensors.

    Features are float32; event and time are cast to int64 exactly as before.
    NOTE(review): the int64 cast truncates fractional survival times - confirm
    that `time` is integer-valued in the source data.
    """
    X = torch.tensor(frame.drop(['event', 'time'], axis=1).to_numpy(), dtype=torch.float32)
    e = torch.tensor(frame['event'].to_numpy(), dtype=torch.long)
    t = torch.tensor(frame['time'].to_numpy(), dtype=torch.long)
    return X, e, t


def _loader_concordance(model, loader):
    """Run `model` over every batch of `loader` and score it with lifelines.

    Returns (concordance index, times, predictions, events), the last three as
    flat Python lists in the order the loader yielded them. Predictions are
    negated before scoring, as in the original code (presumably because a
    larger model output means higher risk, i.e. shorter survival).
    """
    events_ls, times_ls, predicted_ls = [], [], []
    # Gradients are not needed for scoring; the predicted values are unchanged.
    with torch.no_grad():
        for inputs, events, times in loader:
            predicted_ls.extend(model(inputs).reshape(-1).tolist())
            events_ls.extend(events.tolist())
            times_ls.extend(times.tolist())
    ci = concordance_index(times_ls, [-p for p in predicted_ls], events_ls)
    return ci, times_ls, predicted_ls, events_ls


# Per-seed results, one entry appended per iteration of the sweep.
train_ci_ls = []
valid_ci_ls = []
test_ci_ls = []
epoch_ls = []
elapsed_time_ls = []
nconcepts_ls = []

for seed in seeds:
    print("*******************")
    print(seed)
    rules_df = pd.read_csv(f'./../results/RuleKit/rules_class1_2_seed{seed}.csv', index_col=0)
    conditions_ls_ls_ = rules_df.reset_index(drop=True)['conditions'].drop_duplicates().to_list()

    # Map each rule's conditions onto model input columns: conditions containing
    # 'EN' are used as-is, any other condition expands to every clinical
    # (non-'EN') column whose name contains it, e.g. its one-hot dummies.
    conditions_ls_ls = []
    for condition_ls in conditions_ls_ls_:
        _ls = []
        for condition in literal_eval(condition_ls):
            if 'EN' in condition:
                _ls.append(condition)
            else:
                _ls.extend(clinical_feature for clinical_feature in clinical_feature_ls
                           if condition in clinical_feature)
        conditions_ls_ls.append(_ls)

    # One group of column indices per rule; the group count is logged as "nconcepts".
    feature_groups = [[feature_idx_dict[condition] for condition in condition_ls]
                      for condition_ls in conditions_ls_ls]
    nconcepts_ls.append(len(feature_groups))

    # Hold out `test_size` of the rows, then halve the hold-out into val / test.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=0.5, random_state=seed)

    # NOTE(review): shuffle=True is kept on the val/test loaders as well so the
    # loader configuration (and hence the runs) stay exactly as before.
    X_train, e_train, t_train = _survival_tensors(data_train)
    train_dataset = TensorDataset(X_train, e_train, t_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    X_val, e_val, t_val = _survival_tensors(data_val)
    val_dataset = TensorDataset(X_val, e_val, t_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    X_test, e_test, t_test = _survival_tensors(data_test)
    test_dataset = TensorDataset(X_test, e_test, t_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # parameters
    input_size = X_train.shape[1]  # number of model input columns

    model = RuleMLP(input_size, feature_groups)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model and time it
    start = time.time()
    epoch = train(model, optimizer, train_loader, val_loader, max_epochs, l2_reg, patience=50)
    end = time.time()
    elapsed_time = end - start

    print("Time elapsed: ", elapsed_time)
    print("Number of epoch: ", epoch)
    epoch_ls.append(epoch)
    elapsed_time_ls.append(elapsed_time)

    model.eval()

    ## Training
    train_ci, times_ls, predicted_ls, events_ls = _loader_concordance(model, train_loader)
    print("train CI lifelines: ", train_ci)
    train_ci_ls.append(train_ci)

    ## Validation
    valid_ci, times_ls, predicted_ls, events_ls = _loader_concordance(model, val_loader)
    print("valid CI lifelines: ", valid_ci)
    valid_ci_ls.append(valid_ci)

    ## test
    test_ci, times_ls, predicted_ls, events_ls = _loader_concordance(model, test_loader)
    # Not used in this cell; retained in case a later cell reads it.
    sorted_indices = np.argsort(times_ls).tolist()
    print("test CI lifelines: ", test_ci)
    test_ci_ls.append(test_ci)

    # Persist the whole model plus the rule groups and the first row of fc2's weights.
    torch.save(model, f'./../models/RuleMLP/catrulekit_1hot_seed{seed}.pt')
    dict_to_save = {'feature_groups_idx': feature_groups,
                    'weights': model.fc2.weight[0].detach().numpy(),
                    'feature_groups': [list(condition_ls) for condition_ls in conditions_ls_ls]}

    # The `with` block closes the file; no explicit close() is needed.
    with open(f'./../models/RuleMLP/catrulekit_1hot_concept_weights_seed{seed}.pkl', 'wb') as f:
        pickle.dump(dict_to_save, f)

In [4]:
# Show the architecture of `model` as left behind by the training loop, i.e. the
# network built for the last entry of `seeds` (the loop rebinds `model` per seed).
print(model)

RuleMLP(
  (intermitten_layers): ModuleList(
    (0-1): 2 x Linear(in_features=20, out_features=20, bias=True)
    (2): Linear(in_features=21, out_features=21, bias=True)
    (3): Linear(in_features=20, out_features=20, bias=True)
    (4): Linear(in_features=21, out_features=21, bias=True)
    (5): Linear(in_features=23, out_features=23, bias=True)
    (6): Linear(in_features=19, out_features=19, bias=True)
    (7-9): 3 x Linear(in_features=21, out_features=21, bias=True)
    (10-11): 2 x Linear(in_features=22, out_features=22, bias=True)
    (12): Linear(in_features=23, out_features=23, bias=True)
    (13): Linear(in_features=21, out_features=21, bias=True)
    (14): Linear(in_features=20, out_features=20, bias=True)
    (15): Linear(in_features=21, out_features=21, bias=True)
    (16-17): 2 x Linear(in_features=19, out_features=19, bias=True)
    (18): Linear(in_features=21, out_features=21, bias=True)
    (19): Linear(in_features=23, out_features=23, bias=True)
    (20): Linear(in_f

In [5]:
# Report of the sweep: one line with the mean of every tracked quantity,
# followed by the raw per-seed values of each quantity.
report_rows = [
    # (label on the mean line, header of the per-seed line, per-seed values)
    ("Train:", "\nTrain: ", train_ci_ls),
    ("Valid:", "\nValid: ", valid_ci_ls),
    ("Test:", "\nTest: ", test_ci_ls),
    ("Epochs:", "\nEpoch: ", epoch_ls),
    ("Elapsed time", "\nElapsed time: ", elapsed_time_ls),
    ("nconcepts", "\nnconcepts: ", nconcepts_ls),
]

# Mean line: label and mean alternate, separated by print's default single space.
mean_line = []
for mean_label, run_header, run_values in report_rows:
    mean_line.append(mean_label)
    mean_line.append(np.mean(run_values))
print(*mean_line)

# Per-seed lines, each preceded by a blank line via the leading "\n" in its header.
for mean_label, run_header, run_values in report_rows:
    print(run_header, run_values)

Train: 0.9253888645443332 Valid: 0.7245054932932629 Test: 0.6823901922974325 Epochs: 114.9 Elapsed time 114.6811257839203 nconcepts 67.9

Train:  [0.9217123247833421, 0.9389419126477466, 0.9223477792955006, 0.938849987455238, 0.9386712456542017, 0.9084960820969731, 0.9239788440706047, 0.9223626407944869, 0.9196256157635468, 0.9189022128816912]

Valid:  [0.6517615176151762, 0.6796116504854369, 0.7603801900950475, 0.7951012717852096, 0.7122270742358079, 0.7615006699419383, 0.7258522727272727, 0.7218895542248835, 0.7264719517848864, 0.7102587800369686]

Test:  [0.7604832977967306, 0.6836601307189543, 0.680952380952381, 0.7096873083997548, 0.6794320137693631, 0.7164090368608799, 0.6440951229343007, 0.6936845660249915, 0.6652334152334153, 0.5902646502835539]

Epoch:  [106, 149, 166, 134, 83, 102, 110, 62, 119, 118]

Elapsed time:  [95.34932351112366, 170.12869024276733, 160.9313440322876, 122.75166344642639, 78.18098211288452, 94.35079145431519, 102.73288178443909, 59.227784633636475, 123.1