In [151]:
%load_ext autoreload
%autoreload 2
import os
import pickle as pkl
from functools import partial
from os.path import join as oj
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn import metrics, model_selection

import imodels
from imodels.util import data_util
from imodels.discretization import discretizer, simple

import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['figure.dpi'] = 250

# change working directory to project root
while os.getcwd().split('/')[-1] != 'imodels-experiments':
    os.chdir('..')

import viz
import validate
# from local_models.stable import StableLinearClassifier as stbl_local
# from experiments.util import get_comparison_result

np.random.seed(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [152]:
class TransferTree:

    def __init__(self, num_groups: int, *args, **kwargs):
        self.trees = [
            DecisionTreeClassifier(*args, **kwargs) for _ in range(num_groups)
        ]

    def fit(self, X, y, sample_weights):
        for tree, weight in zip(self.trees, sample_weights):
            tree.fit(X, y, weight)
        return self

    def predict(self, X, subgroups):
        preds = np.zeros(X.shape[0])
        for tree, subgroup in zip(self.trees, subgroups):
            preds[subgroup] = tree.predict(X[subgroup])
        return preds
    
    def predict_proba(self, X, subgroups):
        preds_proba = np.zeros((X.shape[0], 2))
        for tree, subgroup in zip(self.trees, subgroups):
            preds_proba[subgroup] = tree.predict_proba(X[subgroup])
        return preds_proba


def print_results(model, X_test, y_test, test_subgroups):
    if isinstance(model, TransferTree):
        pred_proba_args = (X_test, test_subgroups)
        tree_0 = model.trees[0]
        tree_1 = model.trees[1]
    else:
        pred_proba_args = (X_test,)
        tree_0, tree_1 = model, model

    spec_scorer = validate.make_best_spec_high_sens_scorer(0.95)
    print('spec: ', spec_scorer(y_test, model.predict_proba(*pred_proba_args)[:, 1]))
    print('APC: ', metrics.average_precision_score(y_test, model.predict_proba(*pred_proba_args)[:, 1]))
    print('AUC: ', metrics.roc_auc_score(y_test, model.predict_proba(*pred_proba_args)[:, 1]))

    y_test_0, y_test_1 = y_test[~test_subgroups[1]], y_test[test_subgroups[1]]
    X_test_0, X_test_1 = X_test[~test_subgroups[1]], X_test[test_subgroups[1]]

    y_score_0 = tree_0.predict_proba(X_test_0)[:, 1]
    print('group 0 spec: ', spec_scorer(y_test_0, y_score_0))
    print('group 0 APC: ', metrics.average_precision_score(y_test_0, y_score_0))
    print('group 0 AUC: ', metrics.roc_auc_score(y_test_0, y_score_0))

    y_score_1 = tree_1.predict_proba(X_test_1)[:, 1]
    print('group 1 spec: ', spec_scorer(y_test_1, y_score_1))
    print('group 1 APC: ', metrics.average_precision_score(y_test_1, y_score_1))
    print('group 1 AUC: ', metrics.roc_auc_score(y_test_1, y_score_1))

In [178]:
X, y, feature_names = data_util.get_clean_dataset('csi_with_meta_keys.csv', data_source='imodels')
X_df = pd.DataFrame(X, columns=feature_names)

In [182]:
X_df['SITE'].value_counts()

8.0     351
16.0    329
15.0    285
7.0     267
11.0    240
10.0    230
17.0    216
13.0    197
3.0     194
4.0     179
1.0     147
6.0     147
9.0     136
2.0     121
14.0    121
12.0     90
5.0      63
Name: SITE, dtype: int64

In [154]:
cutoff = 4
max = X_df['AgeInYears'].max()

In [155]:
is_group_1 = X_df['AgeInYears'] > cutoff

In [156]:
p_group_1 = 0.5 / cutoff * X_df['AgeInYears']
p_group_1[is_group_1] = 0.5 / (max - cutoff) * (X_df.loc[is_group_1, 'AgeInYears'] - cutoff) + 0.5

In [157]:
X_df_clean = X_df.drop(columns=['SITE', 'AgeInYears'])

In [158]:
X, feature_names = X_df_clean.values, X_df_clean.columns.values
X_train, X_test, y_train, y_test, is_group_1_train, is_group_1_test, p_group_1_train, p_group_1_test = (
    model_selection.train_test_split(X, y, is_group_1, p_group_1, random_state=2))

test_subgroups = [~is_group_1_test, is_group_1_test]
dtree_args = {'max_leaf_nodes': 8, 'class_weight': {0: 1, 1: 6}}

In [159]:
def all_stats_curve(y_test, preds_proba, plot=False, thresholds=None):
    '''preds_proba should be 1d
    '''
    if thresholds is None:
        thresholds = sorted(np.unique(preds_proba))
    all_stats = {
        s: [] for s in ['sens', 'spec', 'ppv', 'npv', 'lr+', 'lr-', 'f1']
    }
    for threshold in tqdm(thresholds):
        preds = preds_proba > threshold
        tn, fp, fn, tp = metrics.confusion_matrix(y_test, preds).ravel()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sens = tp / (tp + fn)
            spec = tn / (tn + fp)
            all_stats['sens'].append(sens)
            all_stats['spec'].append(spec)
            all_stats['ppv'].append(tp / (tp + fp))
            all_stats['npv'].append(tn / (tn + fn))
            all_stats['lr+'].append(sens / (1 - spec))
            all_stats['lr-'].append((1 - sens) / spec)
            all_stats['f1'].append(tp / (tp + 0.5 * (fp + fn)))

    if plot:
        plt.plot(all_stats['sens'], all_stats['spec'], '.-')
        plt.xlabel('sensitivity')
        plt.ylabel('specificity')
        plt.grid()
    return all_stats, thresholds

In [160]:
def predict_and_save(model, model_name='decision_tree'):
    '''Plots cv and returns cv, saves all stats
    '''
    results = {'model': model}
    for x, y, suffix in [[X_test, y_test, '_test']]:
        stats, threshes = all_stats_curve(y, model.predict_proba(x, test_subgroups)[:, 1],
                                                     plot=suffix == '_test')
        for stat in stats.keys():
            results[stat + suffix] = stats[stat]
        results['threshes' + suffix] = threshes
    # pkl.dump(results, open(oj(MODELS_DIR, model_name + '.pkl'), 'wb'))
    return stats, threshes

### plain cart

In [161]:
cart = DecisionTreeClassifier(**dtree_args)
cart.fit(X_train, y_train)
print_results(cart, X_test, y_test, test_subgroups)

spec:  0
APC:  0.37684220439698934
AUC:  0.7866985465299189
group 0 spec:  0.5533980582524272
group 0 APC:  0.5424594870886431
group 0 AUC:  0.864288729421697
group 1 spec:  0
group 1 APC:  0.3573469744460341
group 1 AUC:  0.7700014788524105


In [None]:
predict_and_save(cart)
plt.show()

### two trees, no transfer

In [166]:
subcart_no_transfer = TransferTree(2, **dtree_args)
subcart_no_transfer.fit(
    X_train, y_train, [(~is_group_1_train).astype(int), is_group_1_train.astype(int)])

print_results(subcart_no_transfer, X_test, y_test, test_subgroups)

spec:  0.09117221418234443
APC:  0.35900999628774966
AUC:  0.76155645042891
group 0 spec:  0
group 0 APC:  0.3690793603478702
group 0 AUC:  0.7847192908400169
group 1 spec:  0
group 1 APC:  0.3555925261645406
group 1 AUC:  0.7629177758059745


In [113]:
# plot_tree(subcart_no_transfer.trees[0], feature_names=feature_names)

In [114]:
# plot_tree(subcart_no_transfer.trees[1], feature_names=feature_names)

### linear transfer

In [167]:
transfer = TransferTree(2, **dtree_args)
transfer.fit(X_train, y_train, [1 - p_group_1_train, p_group_1_train])

print_results(transfer, X_test, y_test, test_subgroups)

spec:  0.0824891461649783
APC:  0.3656475390967175
AUC:  0.7846850814824136
group 0 spec:  0.6310679611650486
group 0 APC:  0.4305700735578784
group 0 AUC:  0.8330519206416209
group 1 spec:  0
group 1 APC:  0.3573469744460341
group 1 AUC:  0.7700014788524105


### linear one-way transfer (higher -> lower)

In [168]:
transfer = TransferTree(2, **dtree_args)
transfer.fit(X_train, y_train, [1 - p_group_1_train, is_group_1_train.astype(int)])

print_results(transfer, X_test, y_test, test_subgroups)

spec:  0.0824891461649783
APC:  0.36441896738686214
AUC:  0.7794941169068144
group 0 spec:  0.6310679611650486
group 0 APC:  0.4305700735578784
group 0 AUC:  0.8330519206416209
group 1 spec:  0
group 1 APC:  0.3555925261645406
group 1 AUC:  0.7629177758059745


### sigmoidal transfer

In [169]:
p_group_1_sig = 1 / (1 + np.exp(-1 * (X_df['AgeInYears'] - 4)))
p_group_1_sig_train, p_group_1_sig_test = model_selection.train_test_split(p_group_1_sig, random_state=0)

In [170]:
transfer = TransferTree(2, **dtree_args)
transfer.fit(X_train, y_train, [1 - p_group_1_sig_train, p_group_1_sig_train])

print_results(transfer, X_test, y_test, test_subgroups)

spec:  0.34008683068017365
APC:  0.35969449385571667
AUC:  0.7766888986765661
group 0 spec:  0.07766990291262135
group 0 APC:  0.39347242810765287
group 0 AUC:  0.787674124102997
group 1 spec:  0
group 1 APC:  0.3573469744460341
group 1 AUC:  0.7700014788524105


### sigmoidal one-way transfer (higher -> lower)\

In [171]:
transfer = TransferTree(2, **dtree_args)
transfer.fit(X_train, y_train, [1 - p_group_1_sig_train, is_group_1_train.astype(int)])

print_results(transfer, X_test, y_test, test_subgroups)

spec:  0.34008683068017365
APC:  0.36555831346500395
AUC:  0.7726934289729231
group 0 spec:  0.07766990291262135
group 0 APC:  0.39347242810765287
group 0 AUC:  0.787674124102997
group 1 spec:  0
group 1 APC:  0.3555925261645406
group 1 AUC:  0.7629177758059745


### step transfer

In [176]:
p_group_1_sig_train = is_group_1_train.astype(int) * 0.6
p_group_1_sig_train[p_group_1_sig_train == 0] = 0.4

In [177]:
transfer = TransferTree(2, **dtree_args)
transfer.fit(X_train, y_train, [1 - p_group_1_sig_train, p_group_1_sig_train])

print_results(transfer, X_test, y_test, test_subgroups)

spec:  0.0824891461649783
APC:  0.3852164600642098
AUC:  0.7870655844292035
group 0 spec:  0.5533980582524272
group 0 APC:  0.5424594870886431
group 0 AUC:  0.864288729421697
group 1 spec:  0
group 1 APC:  0.3573469744460341
group 1 AUC:  0.7700014788524105


In [75]:
# predict_and_save(transfer)
# plt.show()

In [22]:
# fpr, tpr, thres = metrics.roc_curve(y_test, subcart_no_transfer.predict_proba(X_test, test_subgroups)[:, 1])

# plt.plot(fpr, tpr)

In [None]:
# remove meta keys, split into train and test and fit. 

# try regular cart on the whole thing, subgroup cart w/o transfer, and subgroup cart with transfer

# try linear vs sigmoidal transfer function

# site split