In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, ShuffleSplit
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.metrics import matthews_corrcoef
from sklearn.preprocessing import LabelEncoder

from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder

import matplotlib.pyplot as plt 
import seaborn as sns

import time 
import sys 
import os
from tqdm import tqdm
import itertools
import json
import pickle
import re

import xgboost as xgb
from xgboost import plot_tree
import lifelines

sys.path.append('./../src/')
from utils import *
from utils_xgboost import *

# Pruning

In [2]:
# Load the breast-cancer cohort: 1000 covariates plus survival targets ('event', 'time').
# The CSV's first column becomes the row index; the redundant 'index' column is dropped.
data_df = pd.read_csv('./../Data/breast_cancer/1000_features_survival_3classes.csv',
                      index_col=0).drop(['index'], axis=1)

# Set the survival targets aside so they are neither one-hot encoded nor scaled.
data_df_event_time = data_df[['event', 'time']]

# One-hot encode non-numeric covariates and min-max scale every feature column.
# NOTE(review): the scaler here and the mean imputation below are fit on the full
# dataset before the train/val/test split in the next cell, so information from the
# held-out rows leaks into the features.
data_df = pd.get_dummies(data_df.drop(['event', 'time'], axis=1), dtype=int)
scaler = MinMaxScaler()
data_df = pd.DataFrame(scaler.fit_transform(data_df), columns=data_df.columns)

# fit_transform returns a bare array, so the rebuilt frame has a fresh 0..n-1 RangeIndex.
# Attach the targets positionally: assigning the Series directly would align on the
# CSV's index and silently yield NaN times (then mean-filled below) whenever that
# index is not exactly 0..n-1.
data_df['event'] = data_df_event_time['event'].astype(int).to_numpy()
data_df['time'] = data_df_event_time['time'].to_numpy()

# Impute any remaining missing values with the column mean.
data_df = data_df.fillna(data_df.mean())

# Empty result containers; `pruning_result_df` collects one DataFrame per seed and is
# concatenated after the seed loop.
val_ls, test_ls, elapsed_time_ls = [], [], []

# Held-out fraction for BOTH train_test_split stages (train vs. rest, then validation
# vs. test within the rest), i.e. roughly a 70/21/9 train/val/test split.
test_size = 0.3
# Depth limit of every boosted tree before pruning.
max_depth = 10
# Upper bound on boosting rounds (early stopping can end training sooner).
num_boost_round = 500

# XGBoost accelerated-failure-time (AFT) survival model configuration.
params = {
    'verbosity': 0,
    'objective': 'survival:aft',
    'eval_metric': 'aft-nloglik',
    'tree_method': 'hist',
    'learning_rate': 0.01,
    'aft_loss_distribution': 'logistic',
    'aft_loss_distribution_scale': 1.2,
    'max_depth': max_depth,
    'lambda': 0.01,  # L2 regularisation on leaf weights
    'alpha': 0.1,    # L1 regularisation on leaf weights
}

# One independent train/val/test split per seed.
seeds = [999, 7, 42, 1995, 1303, 2405, 1996, 200, 0, 777]
pruning_result_df = []

# Train one AFT model per seed, prune it, export its rules and score the pruned model.
RESULTS_DIR = './../results/XGBoost'
os.makedirs(RESULTS_DIR, exist_ok=True)  # to_csv does not create missing directories

# Only depth 5 is evaluated; use list(range(1, max_depth + 1)) to sweep every depth.
pruning_depths = [5]


def _aft_dmatrix(frame):
    """Build features and interval-censored labels for XGBoost's survival:aft objective.

    The target columns ('event', 'time') and the class column 'y' are dropped from the
    features. Rows with an observed event get the interval [time, time]; rows with
    event == 0 get [time, +inf], i.e. right-censoring.

    Returns (X, y_lower, y_upper, dmatrix).
    """
    X = frame.drop(['event', 'time', 'y'], axis=1)
    y_lower = frame['time']
    y_upper = np.where(frame['event'].to_numpy().astype(bool),
                       frame['time'].to_numpy(), np.inf)
    dmatrix = xgb.DMatrix(X.values)
    dmatrix.set_float_info('label_lower_bound', y_lower)
    dmatrix.set_float_info('label_upper_bound', y_upper)
    return X, y_lower, y_upper, dmatrix


def _rule_conditions(rule):
    """Return the feature identifiers tested by one textual if-then rule.

    '<' is rewritten to '=' and '>' is removed, so every comparison token contains '='.
    For each such whitespace-separated token, the text before the first '=' is kept
    without its first character.
    NOTE(review): assumes parse_tree emits tokens like 'f12<0.5', whose leading 'f' is
    XGBoost's default feature-name prefix (the DMatrix is built from .values) — confirm.
    """
    tokens = re.sub('>', '', re.sub('<', '=', rule)).split()
    return [token[1:].split('=')[0] for token in tokens if '=' in token]


for seed in tqdm(seeds):
    # Two-stage split: train vs. rest, then the rest into validation vs. test.
    # NOTE(review): both stages use test_size, giving ~70/21/9 — confirm this is intended.
    data_train, data_tmp = train_test_split(data_df, test_size=test_size, random_state=seed)
    data_val, data_test = train_test_split(data_tmp, test_size=test_size, random_state=seed)

    X_train, y_lower_train, y_upper_train, dtrain = _aft_dmatrix(data_train)
    X_val, y_lower_val, y_upper_val, dvalid = _aft_dmatrix(data_val)
    X_test, y_lower_test, y_upper_test, dtest = _aft_dmatrix(data_test)

    # Early stopping monitors the last entry of `evals` (the validation set).
    # NOTE(review): xgb.train returns the model from the last iteration, not from
    # best_iteration, so the rounds trained after the best one stay in bst (and in the
    # pruned copies below) and are used by predict().
    bst = xgb.train(params, dtrain, num_boost_round=num_boost_round,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=50, verbose_eval=False)

    # Size of the unpruned ensemble, for reference.
    bst_dump = bst.get_dump()
    xbgboost_rules = [rule for tree_str in bst_dump for rule in parse_tree(tree_str)]
    ntree = len(bst_dump)
    # tqdm.write keeps the progress bar intact while printing the same text as print().
    tqdm.write(f"Number of trees:  {ntree}")
    tqdm.write(f"Number of rules:  {len(xbgboost_rules)}")

    ntree_ls = []
    nxbgboost_rules_ls = []
    c_index_val = []
    c_index_test = []

    for pruning_depth in pruning_depths:
        # process_type='update' with updater='prune' runs the prune updater over each
        # existing tree (one tree per round) instead of growing new trees; the prune
        # updater removes nodes deeper than max_depth and splits below min_split_loss.
        pruned = xgb.train(
            {"process_type": "update", "updater": "prune", "max_depth": pruning_depth},
            dtrain,
            num_boost_round=len(bst_dump),
            xgb_model=bst,
            evals=[(dtrain, 'train'), (dvalid, 'valid')],
            verbose_eval=False,
        )

        # Extract the if-then rules of every pruned tree, remembering each rule's tree.
        xbgboost_rules, tree_idx_ls = [], []
        pruned_dump = pruned.get_dump()
        for tree_idx, tree_str in enumerate(pruned_dump):
            if_then_rules = parse_tree(tree_str)
            xbgboost_rules.extend(if_then_rules)
            tree_idx_ls.extend([tree_idx] * len(if_then_rules))
        ntree = len(pruned_dump)

        ntree_ls.append(ntree)
        nxbgboost_rules_ls.append(len(xbgboost_rules))

        n_rules = len(xbgboost_rules)
        rules_df = pd.DataFrame({'rules': xbgboost_rules,
                                 'seed': [seed] * n_rules,
                                 'pruning_depth': [pruning_depth] * n_rules})
        rules_df['conditions'] = [_rule_conditions(rule) for rule in rules_df['rules']]
        rules_df['nconditions'] = [len(conditions) for conditions in rules_df['conditions']]
        rules_df['tree_idx'] = tree_idx_ls
        rules_df.to_csv(f'{RESULTS_DIR}/rules_seed{seed}_pruning_depth_{pruning_depth}.csv')

        # Concordance index of the pruned model. survival:aft predicts survival times,
        # so higher scores mean longer survival, which is how lifelines scores them.
        df = pd.DataFrame({'Label (lower bound)': y_lower_val,
                           'Label (upper bound)': y_upper_val,
                           'Predicted label': pruned.predict(dvalid)})
        c_index_val.append(lifelines.utils.concordance_index(
            event_times=data_val['time'],
            predicted_scores=df['Predicted label'],
            event_observed=data_val['event']))

        df = pd.DataFrame({'Label (lower bound)': y_lower_test,
                           'Label (upper bound)': y_upper_test,
                           'Predicted label': pruned.predict(dtest)})
        c_index_test.append(lifelines.utils.concordance_index(
            event_times=data_test['time'],
            predicted_scores=df['Predicted label'],
            event_observed=data_test['event']))

    # One summary row per pruning depth for this seed.
    pruning_result_df.append(pd.DataFrame({'pruning_depths': pruning_depths,
                                           'ntree': ntree_ls,
                                           'xbgboost_rules': nxbgboost_rules_ls,
                                           'c_index_val': c_index_val,
                                           'c_index_test': c_index_test,
                                           'seed': [seed] * len(pruning_depths)}))


pruning_result_df = pd.concat(pruning_result_df)
pruning_result_df.to_csv(f'{RESULTS_DIR}/survival_tree_prunning.csv')

  0%|                                                                                                                                                                                | 0/10 [00:00<?, ?it/s]

Number of trees:  353
Number of rules:  12243


 10%|████████████████▊                                                                                                                                                       | 1/10 [01:20<12:00, 80.06s/it]

Number of trees:  360
Number of rules:  12580


 20%|█████████████████████████████████▌                                                                                                                                      | 2/10 [02:41<10:48, 81.12s/it]

Number of trees:  381
Number of rules:  13457


 30%|██████████████████████████████████████████████████▍                                                                                                                     | 3/10 [04:11<09:55, 85.06s/it]

Number of trees:  347
Number of rules:  11913


 40%|███████████████████████████████████████████████████████████████████▏                                                                                                    | 4/10 [05:30<08:16, 82.77s/it]

Number of trees:  332
Number of rules:  12165


 50%|████████████████████████████████████████████████████████████████████████████████████                                                                                    | 5/10 [06:52<06:52, 82.50s/it]

Number of trees:  352
Number of rules:  12096


 60%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                   | 6/10 [08:11<05:24, 81.05s/it]

Number of trees:  387
Number of rules:  13737


 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                  | 7/10 [09:41<04:12, 84.19s/it]

Number of trees:  385
Number of rules:  13222


 80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                 | 8/10 [11:06<02:48, 84.31s/it]

Number of trees:  387
Number of rules:  12800


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                | 9/10 [12:28<01:23, 83.61s/it]

Number of trees:  363
Number of rules:  13142


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [13:53<00:00, 83.40s/it]
