In [42]:
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
import os
os.environ['PYTHONWARNINGS'] = "ignore"
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
import math
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score,confusion_matrix
from sklearn.preprocessing import OneHotEncoder
import pickle
import xgboost as xgb
from shap import GPUTreeExplainer
from matplotlib.ticker import MaxNLocator,MultipleLocator
from scipy.stats import kendalltau
from scipy.special import expit

import mlresearch
mlresearch.utils.set_matplotlib_style()
from mlresearch.utils import set_matplotlib_style
set_matplotlib_style(font_size=27)

In [43]:
import numpy as np
print(np.__version__) #1.26.4
# print(shap.__version__) #0.46.1.dev86
print(sklearn.__version__) #1.6.0
print(xgb.__version__) #1.7.6

2.0.2
1.6.1
2.1.4


## Preprocessing

In [44]:
# Use White alone & African American only 
FEAT_CNT = 8
STATE = 'VA'
FOLDS = 5
seeds = [0,21,42,63,84]

In [45]:
categorical_cols =['Occupation', 'Marriage','Place of Birth','Sex', 'Race']

with open(file=f'dataset/ACS_Income_{STATE}.pickle', mode='rb') as f:
    df=pickle.load(f)
df = df[(df['Race']=='White alone') | (df['Race']== 'Black or African American alone')]
df.reset_index(drop=True, inplace=True)
columns = df.columns
with pd.option_context('future.no_silent_downcasting', True):
    df.replace([' <=50K',' >50K'],
                 [0,1], inplace = True)
    df['Sex'].replace( {'Female':0.0},inplace = True)
    df['Sex'].replace({'Male':1.0}, inplace = True)
X = df.iloc[:, 0:FEAT_CNT]
Y = df.iloc[:, FEAT_CNT]

category_col =['Occupation', 'Marriage','Place of Birth', 'Race']
X = pd.get_dummies(X, columns=category_col, drop_first=True)
for c in X.columns:
    X[c] = X[c].astype(float)

## Utils

In [72]:
def get_bins(strategy, bin_size, seed):
    ## compute bin boundary
    np.random.seed(seed)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    for train_val_idx, test_idx in splitter.split(X, Y):
        X_train_val, X_test = X.iloc[train_val_idx], X.iloc[test_idx]
        Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
    
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
    for train_idx, val_idx in splitter.split(X_train_val, Y_train_val):
        X_train, X_val = X_train_val.iloc[train_idx], X_train_val.iloc[val_idx]
        Y_train, Y_val = Y_train_val.iloc[train_idx], Y_train_val.iloc[val_idx] 
    kd = KBinsDiscretizer(n_bins=bin_size, encode='ordinal', strategy=strategy)

    kd.fit(X_train)
    bin_boundaries = kd.bin_edges_[0]

    return bin_boundaries
def assign_age(age,bin_edges):
    # Assign age to a median of the bin_edges
    for idx in range(len(bin_edges)-1):
        if age == bin_edges[-1]:
            median = (bin_edges[-1] + bin_edges[-2])/2
        elif bin_edges[idx] <= age and age < bin_edges[idx+1]:
            median = (bin_edges[idx] + bin_edges[idx+1])/2
    return median

In [73]:
def compute_shap(X_train,X_test,Y_train,Y_test, seed):

    print('**********START**********')
    # Train model on new data
    param_grid = {
        'classifier__n_estimators': [50, 100, 200],  # Number of boosting rounds
        'classifier__max_depth': [3, 5, 7,9,11],          # Maximum tree depth
        'classifier__learning_rate': [0.01, 0.1, 0.2],  # Step size shrinkage 
        'classifier__colsample_bytree': [0.8, 1.0],  # Subsample ratio of columns for each tree
        'classifier__gamma': [0, 0.1, 0.2],          # Minimum loss reduction for a split
    }
    model = xgb.XGBClassifier(random_state=seed)
    grid_search = GridSearchCV(
        model, 
        param_grid,              # 3-fold cross-validation
        scoring='f1',   # Evaluation metric
        n_jobs=-1,            # Use all processors
        verbose=1             # Print progress
    )

    grid_search.fit(X_train, Y_train)
        
    # Extract the best model
    model = grid_search.best_estimator_

    explainer = GPUTreeExplainer(model,X_train,feature_perturbation = 'interventional')
    shap_values = explainer(X_test)
    pred = best_model.predict(X_test)
    return shap_values, pred
    
def get_tfs(shap_vals,Y_true, pred):
    # Compute indices of errors
    TP_i = np.where((Y_true == 1.0) & (pred == 1.0))[0]  # True Positives
    FP_i = np.where((Y_true == 0.0) & (pred == 1.0))[0]  # False Positives
    TN_i = np.where((Y_true == 0.0) & (pred == 0.0))[0]  # True Negatives
    FN_i = np.where((Y_true == 1.0) & (pred == 0.0))[0]  # False Negatives
    return TP_i,FP_i,TN_i,FN_i

def get_ranks(shap_vals):

    # avg_shap = np.mean(np.abs(shap_vals), axis=0)
    # Compute rankings row-wise
    sorted_indices = np.argsort(-np.abs(shap_vals), axis=1)  # Indices of absolute values in descending order
    rank = np.empty_like(sorted_indices)             # Create an empty array of the same shape
    rows, cols = shap_vals.shape                            # Get the shape of the array
    rank[np.arange(rows)[:, None], sorted_indices] = np.arange(1, cols + 1)  # Assign ranks row-wise

    target_rank = rank[:,0]

    return rank,target_rank

def compute_fidelity(pred, sv, base):

    sv_sums = expit(np.sum(sv, axis=1)+base)
    binary_predictions = (sv_sums > 0.5).astype(float)
    fidelity = np.mean(binary_predictions == pred)
    match_idx = np.where(binary_predictions == pred)[0]
        
    return fidelity,match_idx

## Train the model with plain Age

In [83]:
# base_models = list()
base_shap_vals = list()
base_preds = list()
base_accs = list() 
base_f1s = list()
base_ranks = list()
base_age_ranks = list()

base_tp_idx = list()
base_fp_idx = list()
base_tn_idx = list()
base_fn_idx = list()

base_tp_age_ranks = list()
base_fp_age_ranks = list()
base_tn_age_ranks = list()
base_fn_age_ranks = list()

base_firsts = list()
base_percentages = list()
for seed in tqdm(seeds):
    np.random.seed(seed)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    for train_val_idx, test_idx in splitter.split(X, Y):
        X_train_val, X_test = X.iloc[train_val_idx], X.iloc[test_idx]
        Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
    
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
    for train_idx, val_idx in splitter.split(X_train_val, Y_train_val):
        X_train, X_val = X_train_val.iloc[train_idx], X_train_val.iloc[val_idx]
        Y_train, Y_val = Y_train_val.iloc[train_idx], Y_train_val.iloc[val_idx]
    param_grid = {
        'classifier__n_estimators': [50, 100, 200],  # Number of boosting rounds
        'classifier__max_depth': [3, 5, 7,9,11],          # Maximum tree depth
        'classifier__learning_rate': [0.01, 0.1, 0.2],  # Step size shrinkage 
        'classifier__colsample_bytree': [0.8, 1.0],  # Subsample ratio of columns for each tree
        'classifier__gamma': [0, 0.1, 0.2],          # Minimum loss reduction for a split
    }
    model = xgb.XGBClassifier(random_state=seed)
    grid_search = GridSearchCV(
        model, 
        param_grid,              # 3-fold cross-validation
        scoring='f1',   # Evaluation metric
        n_jobs=-1,            # Use all processors
        verbose=1             # Print progress
    )

    grid_search.fit(X_train, Y_train)
    
    # Extract the best model
    best_model = grid_search.best_estimator_
    explainer = GPUTreeExplainer(best_model,X_train, feature_perturbation='interventional') 
    shap_values = explainer(X_test)
    
    sv = shap_values.values
    base_rank,age_rank= get_ranks(sv)
    base_ranks.append(base_rank)
    
    pred = best_model.predict(X_test)
    # base_models.append(best_model)
    base_shap_vals.append(shap_values)
    base_preds.append(pred)
    base_accs.append(accuracy_score(Y_test,pred)*100)
    base_f1s.append(f1_score(Y_test,pred)*100)
    base_ranks.append(base_rank)
    base_age_ranks.append(age_rank)

    # Errors
    tp,fp,tn,fn = get_tfs(sv, Y_test,pred)
    tp_rank, tp_age_rank = get_ranks(sv[tp])
    fp_rank, fp_age_rank = get_ranks(sv[fp])
    tn_rank, tn_age_rank = get_ranks(sv[tn])
    fn_rank, fn_age_rank = get_ranks(sv[fn])
    base_tp_idx.append(tp)
    base_fp_idx.append(fp)
    base_tn_idx.append(tn)
    base_fn_idx.append(fn)

    base_tp_age_ranks.append(tp_age_rank)
    base_fp_age_ranks.append(fp_age_rank)
    base_tn_age_ranks.append(tn_age_rank)
    base_fn_age_ranks.append(fn_age_rank)
    
    # First

    ## Indices of first
    first = [int(j) for j,v in enumerate(age_rank) if v == 1]
    base_firsts.append(first)
    base_percentages.append(len(first)/len(X_test) )

print(f'Overall average acc: {sum(base_accs)/len(base_accs):.2f} average f1s : {sum(base_f1s)/len(base_f1s):.2f}')


  0%|                                                       | 0/5 [00:00<?, ?it/s]

Fitting 5 folds for each of 270 candidates, totalling 1350 fits


 20%|█████████▍                                     | 1/5 [00:15<01:03, 15.89s/it]

Fitting 5 folds for each of 270 candidates, totalling 1350 fits


 40%|██████████████████▊                            | 2/5 [00:26<00:38, 12.82s/it]

Fitting 5 folds for each of 270 candidates, totalling 1350 fits


 60%|████████████████████████████▏                  | 3/5 [00:37<00:23, 11.89s/it]

Fitting 5 folds for each of 270 candidates, totalling 1350 fits


 80%|█████████████████████████████████████▌         | 4/5 [00:47<00:11, 11.23s/it]

Fitting 5 folds for each of 270 candidates, totalling 1350 fits


100%|███████████████████████████████████████████████| 5/5 [00:57<00:00, 11.53s/it]

Overall average acc: 79.50 average f1s : 76.63





In [49]:
len(first)

2407

In [84]:
import pickle
from pathlib import Path
path = './results'
if not os.path.exists(path):
   # Create a new directory because it does not exist
   os.makedirs(path)

# # save
with open(path + '/Sens_Income_base_shap_vals_cv.pickle', 'wb') as f:
    pickle.dump(base_shap_vals, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_accs_cv.pickle', 'wb') as f:
    pickle.dump(base_accs, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_f1s_cv.pickle', 'wb') as f:
    pickle.dump(base_f1s, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(base_age_ranks, f, pickle.HIGHEST_PROTOCOL)


with open(path + '/Sens_Income_base_tp_idx_cv.pickle', 'wb') as f:
    pickle.dump(base_tp_idx, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_fp_idx_cv.pickle', 'wb') as f:
    pickle.dump(base_fp_idx, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_tn_idx_cv.pickle', 'wb') as f:
    pickle.dump(base_tn_idx, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_fn_idx_cv.pickle', 'wb') as f:
    pickle.dump(base_fn_idx, f, pickle.HIGHEST_PROTOCOL)

    
with open(path + '/Sens_Income_base_tp_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(base_tp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_fp_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(base_fp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_tn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(base_tn_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_fn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(base_fn_age_ranks, f, pickle.HIGHEST_PROTOCOL)

with open(path + '/Sens_Income_base_firsts_cv.pickle', 'wb') as f:
    pickle.dump(base_firsts, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_base_percentages_cv.pickle', 'wb') as f:
    pickle.dump(base_percentages, f, pickle.HIGHEST_PROTOCOL)

In [51]:
# def get_first_rank_shift(base_ranks, rank_difs):
#     first_rank_difs = list()
#     first_indices = list()
#     for idx,base_rank in enumerate(base_ranks):
#         age_rank = base_rank[:,0]
#         first = [int(j) for j in range(len(age_rank)) if int(age_rank[j]) == 1]
#         first_rank_dif = [int(rank_difs[idx][k]) for k in first]
#         first_rank_difs.append(first_rank_dif)
#         first_indices.append(first)
#     return first_rank_difs, first_indices



# def compute_average_counts(l,decrease, etc=False):
#     if not etc:
#         s = sum([1 for r in l if r ==decrease])
#     else:
#         s = sum([1 for r in l if r <= decrease])
#     percent = s/len(l)
#     return percent
    
# def down_percent(rank_shifts):
#     k_downs = list()
#     for shifts in rank_shifts:
#         downs = list()
#         for decrease in range(-1,-5,-1):
#             down_percent = compute_average_counts(shifts,decrease, decrease == -4)
#             downs.append(down_percent)
#         k_downs.append(downs)
#         # print(f'number of shifts {len(shifts)}')
    
#     return np.mean(k_downs,axis = 0),np.std(k_downs,axis=0)

# def get_bucket_idx(age, bin_edges):
#     for idx in range(len(bin_edges)-1):
#         if age == bin_edges[-1]:
#             bin_idx = len(bin_edges) - 1
#         elif bin_edges[idx] <= age and age < bin_edges[idx+1]:
#             bin_idx = idx
#             break
#     return bin_idx
        

# def get_all_bucket_idx(X, edges):
#     X2 = X.copy()
#     X2['Age'] = X2['Age'].apply(lambda age: assign_age(age, edges)) 
 
#     np.random.seed(0)
#     _, X_test, _, _ = train_test_split(X2, Y, test_size=0.3, random_state=0)
#     bucket_indices = list()
#     for age in X_test['Age']:
#         cur_bucket_idx = get_bucket_idx(age, edges)
#         bucket_indices.append(cur_bucket_idx)
#     return bucket_indices
        

# Equi Depth

In [53]:
ed_fids = list()
ed_preds = list()
ed_ranks = list()
ed_age_ranks = list()
ed_shap_vals = list()
ed_rank_difs = list()

ed_tp_age_ranks = list()
ed_fp_age_ranks = list()
ed_tn_age_ranks = list()
ed_fn_age_ranks = list()

ed_firsts = list()
ed_percentages = list()
ed_first_rank_difs = list()

for bucket in range(2,21):
    b_fids = list()
    b_preds = list()
    b_ranks = list()
    b_age_ranks = list()
    b_shap_vals = list()
    b_rank_difs = list()

    b_tp_age_ranks = list()
    b_fp_age_ranks = list()
    b_tn_age_ranks = list()
    b_fn_age_ranks = list()

    b_firsts = list()
    b_percentages = list()
    b_first_rank_difs = list()
    for i, seed in enumerate(seeds):
        
        X2 = X.copy()
        bucket_edge = get_bins('quantile',bucket,seed)
        X2['Age'] = X2['Age'].apply(lambda age: assign_age(age, bucket_edge)) 
        np.random.seed(seed)
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
        for train_val_idx, test_idx in splitter.split(X2, Y):
            X_train_val, X_test = X2.iloc[train_val_idx], X2.iloc[test_idx]
            Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
        
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
        for train_idx, val_idx in splitter.split(X_train_val, Y_train_val):
            X_train, X_val = X_train_val.iloc[train_idx], X_train_val.iloc[val_idx]
            Y_train, Y_val = Y_train_val.iloc[train_idx], Y_train_val.iloc[val_idx] 

        
        # compute bin edges as shap values for bucketized data
        
        bucket_shap,bucket_pred = compute_shap(X_train,X_test,Y_train,Y_test,seed)
        sv = bucket_shap.values
        ed_rank, ed_age_rank = get_ranks(sv)
        # Compute fidelity and indices where explanation is the same.

        preds = base_preds[i]
        sv = bucket_shap.values
        base = bucket_shap.base_values
        ed_fid, ed_agreed = compute_fidelity(preds,sv,base)

        b_fids.append(ed_fid)
        b_preds.append(bucket_pred)
        b_ranks.append(ed_rank)
        b_age_ranks.append(ed_age_rank)
        b_shap_vals.append(bucket_shap)

        
        # Compuute rank shift
        base_age_rank = base_age_ranks[i]
        rank_dif = [r1-r2 for r1,r2 in zip(base_age_rank,ed_age_rank)]
        b_rank_difs.append(rank_dif)

        # Errors
        tp,fp,tn,fn = get_tfs(sv, Y_test,pred)
        tp_rank, tp_age_rank = get_ranks(sv[tp])
        fp_rank, fp_age_rank = get_ranks(sv[fp])
        tn_rank, tn_age_rank = get_ranks(sv[tn])
        fn_rank, fn_age_rank = get_ranks(sv[fn])
        b_tp_age_ranks.append(tp_rank)
        b_fp_age_ranks.append(fp_rank)
        b_tn_age_ranks.append(tn_rank)
        b_fn_age_ranks.append(fn_rank)
        
        # First
    
        ## Indices of first
        first = [int(j) for j,v in enumerate(ed_age_rank) if v == 1]
        first_rank_dif = [rank_dif[idx] for idx in first]
        
        b_firsts.append(first)
        b_percentages.append(len(first)/len(X_test)*100 )
        b_first_rank_difs.append(first_rank_dif)
        
    ed_fids.append(b_fids)
    ed_preds.append(b_preds)
    ed_ranks.append(b_ranks)
    ed_age_ranks.append(b_age_ranks)
    ed_shap_vals.append(b_shap_vals)
    ed_rank_difs.append(b_rank_difs)

    ed_tp_age_ranks.append(b_tp_age_ranks)
    ed_fp_age_ranks.append(b_fp_age_ranks)
    ed_tn_age_ranks.append(b_tn_age_ranks)
    ed_fn_age_ranks.append(b_fn_age_ranks)

    ed_firsts.append(b_firsts)
    ed_percentages.append(b_percentages)
    ed_first_rank_difs.append(b_first_rank_difs)
    # print(f'Average fidelity {sum(cut5_fidelities)/len(cut5_fidelities)*100:.3f}')


**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********

In [54]:
# # save
with open(path + '/Sens_Income_ed_fids_cv.pickle', 'wb') as f:
    pickle.dump(ed_fids, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_shap_vals_cv.pickle', 'wb') as f:
    pickle.dump(ed_shap_vals, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_rank_difs_cv.pickle', 'wb') as f:
    pickle.dump(ed_rank_difs, f, pickle.HIGHEST_PROTOCOL)

with open(path + '/Sens_Income_ed_tp_ranks_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_tp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_fp_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_fp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_tn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_tn_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_fn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ed_fn_age_ranks, f, pickle.HIGHEST_PROTOCOL)

with open(path + '/Sens_Income_ed_firsts_cv.pickle', 'wb') as f:
    pickle.dump(ed_firsts, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_percentages_cv.pickle', 'wb') as f:
    pickle.dump(ed_percentages, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ed_first_rank_difss_cv.pickle', 'wb') as f:
    pickle.dump(ed_first_rank_difs, f, pickle.HIGHEST_PROTOCOL)

 # Equi Width

In [77]:
ew_fids = list()
ew_preds = list()
ew_ranks = list()
ew_age_ranks = list()
ew_shap_vals = list()
ew_rank_difs = list()

ew_tp_age_ranks = list()
ew_fp_age_ranks = list()
ew_tn_age_ranks = list()
ew_fn_age_ranks = list()

ew_firsts = list()
ew_percentages = list()
ew_first_rank_difs = list()

for bucket in range(2,21):
    b_fids = list()
    b_preds = list()
    b_ranks = list()
    b_age_ranks = list()
    b_shap_vals = list()
    b_rank_difs = list()

    b_tp_age_ranks = list()
    b_fp_age_ranks = list()
    b_tn_age_ranks = list()
    b_fn_age_ranks = list()

    b_firsts = list()
    b_percentages = list()
    b_first_rank_difs = list()
    for i, seed in enumerate(seeds):
        
        X2 = X.copy()
        bucket_edge = get_bins('uniform',bucket,seed)
        X2['Age'] = X2['Age'].apply(lambda age: assign_age(age, bucket_edge)) 
        np.random.seed(seed)
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
        for train_val_idx, test_idx in splitter.split(X2, Y):
            X_train_val, X_test = X2.iloc[train_val_idx], X2.iloc[test_idx]
            Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
        
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
        for train_idx, val_idx in splitter.split(X_train_val, Y_train_val):
            X_train, X_val = X_train_val.iloc[train_idx], X_train_val.iloc[val_idx]
            Y_train, Y_val = Y_train_val.iloc[train_idx], Y_train_val.iloc[val_idx] 

        
        # compute bin edges as shap values for bucketized data
        
        bucket_shap,bucket_pred = compute_shap(X_train,X_test,Y_train,Y_test,seed)
        sv = bucket_shap.values
        ew_rank, ew_age_rank = get_ranks(sv)
        # Compute fidelity and indices where explanation is the same.

        preds = base_preds[i]
        sv = bucket_shap.values
        base = bucket_shap.base_values
        ew_fid, ew_agreed = compute_fidelity(preds,sv,base)

        b_fids.append(ew_fid)
        b_preds.append(bucket_pred)
        b_ranks.append(ew_rank)
        b_age_ranks.append(ew_age_rank)
        b_shap_vals.append(bucket_shap)

        
        # Compuute rank shift
        base_age_rank = base_age_ranks[i]
        rank_dif = [r1-r2 for r1,r2 in zip(base_age_rank,ew_age_rank)]
        b_rank_difs.append(rank_dif)

        # Errors
        tp,fp,tn,fn = get_tfs(sv, Y_test,pred)
        tp_rank, tp_age_rank = get_ranks(sv[tp])
        fp_rank, fp_age_rank = get_ranks(sv[fp])
        tn_rank, tn_age_rank = get_ranks(sv[tn])
        fn_rank, fn_age_rank = get_ranks(sv[fn])
        b_tp_age_ranks.append(tp_rank)
        b_fp_age_ranks.append(fp_rank)
        b_tn_age_ranks.append(tn_rank)
        b_fn_age_ranks.append(fn_rank)

        # First
    
        ## Indices of first
        first = [int(j) for j,v in enumerate(ew_age_rank) if v == 1]
        first_rank_dif = [rank_dif[idx] for idx in first]
        
        b_firsts.append(first)
        b_percentages.append(len(first)/len(X_test)*100 )
        b_first_rank_difs.append(first_rank_dif)
        
    ew_fids.append(b_fids)
    ew_preds.append(b_preds)
    ew_ranks.append(b_ranks)
    ew_age_ranks.append(b_age_ranks)
    ew_shap_vals.append(b_shap_vals)
    ew_rank_difs.append(b_rank_difs)

    ew_tp_age_ranks.append(b_tp_age_ranks)
    ew_fp_age_ranks.append(b_fp_age_ranks)
    ew_tn_age_ranks.append(b_tn_age_ranks)
    ew_fn_age_ranks.append(b_fn_age_ranks)

    ew_firsts.append(b_firsts)
    ew_percentages.append(b_percentages)
    ew_first_rank_difs.append(b_first_rank_difs)
    # print(f'Average fidelity {sum(cut5_fidelities)/len(cut5_fidelities)*100:.3f}')


**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********START**********
Fitting 5 folds for each of 270 candidates, totalling 1350 fits
**********

In [78]:
# # save
with open(path + '/Sens_Income_ew_fids_cv.pickle', 'wb') as f:
    pickle.dump(ew_fids, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_shap_vals_cv.pickle', 'wb') as f:
    pickle.dump(ew_shap_vals, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_rank_difs_cv.pickle', 'wb') as f:
    pickle.dump(ew_rank_difs, f, pickle.HIGHEST_PROTOCOL)

with open(path + '/Sens_Income_ew_tp_age_ranks_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_tp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_fp_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_fp_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_tn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_tn_age_ranks, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_fn_age_ranks_cv.pickle', 'wb') as f:
    pickle.dump(ew_fn_age_ranks, f, pickle.HIGHEST_PROTOCOL)

with open(path + '/Sens_Income_ew_firsts_cv.pickle', 'wb') as f:
    pickle.dump(ew_firsts, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_percentages_cv.pickle', 'wb') as f:
    pickle.dump(ew_percentages, f, pickle.HIGHEST_PROTOCOL)
with open(path + '/Sens_Income_ew_first_rank_difs_cv.pickle', 'wb') as f:
    pickle.dump(ew_first_rank_difs, f, pickle.HIGHEST_PROTOCOL)

# Bucket Boundaries

In [65]:
import math
bucket = 10
seed = seeds[0]  
X2 = X.copy()
bucket_edge = get_bins('uniform',bucket,seed)
bucket_edge = [math.ceil(e) for e in bucket_edge]

with open(path + '/Sens_Income_bucket_edge_cv.pickle', 'wb') as f:
    pickle.dump(bucket_edge, f, pickle.HIGHEST_PROTOCOL)

In [67]:
np.random.seed(seed)
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
for train_val_idx, test_idx in splitter.split(X, Y):
    X_train_val, X_test = X.iloc[train_val_idx], X.iloc[test_idx]
    Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
    
with open(path + '/Sens_Income_X_test_cv.pickle', 'wb') as f:
    pickle.dump(X_test, f, pickle.HIGHEST_PROTOCOL)


In [18]:
# base_first_freq = dict()
# for i in bf:
#     b_num = X_test_bucket_indices[i]
#     if b_num not in base_first_freq:
#         base_first_freq[b_num] = 1
#     else:
#         base_first_freq[b_num] += 1

In [19]:
# down_all_freq_dict = dict()
# up_all_freq_dict = dict()
# same_all_freq_dict = dict()
# # for i,rank_dif in enumerate(first_rank_shifts):
# for i,rank_dif in enumerate(first_rank_shifts):
#     for j, dif in enumerate(rank_dif):
#         real_idx = bf[j]
#         bucket_idx = X_test_bucket_indices[real_idx]
#         if dif < 0:
            
#             if bucket_idx not in down_all_freq_dict:
#                 down_all_freq_dict[bucket_idx] = 1
#             else:
#                 down_all_freq_dict[bucket_idx] += 1
#         elif dif > 0:
#             if bucket_idx not in up_all_freq_dict:
#                 up_all_freq_dict[bucket_idx] = 1
#             else:
#                 up_all_freq_dict[bucket_idx] += 1
#         else:
#             if bucket_idx not in same_all_freq_dict:
#                 same_all_freq_dict[bucket_idx] = 1
#             else:
#                 same_all_freq_dict[bucket_idx] += 1
#     break

In [20]:
# # plot the total distribution
# plt.bar(list(base_first_freq.keys()), list(base_first_freq.values()))
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# plt.title('Distribution of Age Buckets for First ')

In [21]:
# with open('Attack_Income_first_down_all_freq_dict.pickle', 'wb') as f:
#     pickle.dump(down_all_freq_dict, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_first_same_all_freq_dict.pickle', 'wb') as f:
#     pickle.dump(same_all_freq_dict, f, pickle.HIGHEST_PROTOCOL)

In [22]:
# plt.bar(list(down_all_freq_dict.keys()), list(down_all_freq_dict.values()))
# down_cnt = sum([v for v in down_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('Rank Demotion for buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(down_cnt)
# print(f'precent {down_cnt/len(bf)*100:.2f}')

In [23]:
# plt.bar(list(up_all_freq_dict.keys()), list(up_all_freq_dict.values()))
# up_cnt = sum([v for v in up_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('Rank Promotion for buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(up_cnt)
# print(f'precent {up_cnt/len(X_test)*100:.2f}')

In [24]:
# plt.bar(list(same_all_freq_dict.keys()), list(same_all_freq_dict.values()))
# same_cnt = sum([v for v in same_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('No Rank Shift for Buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(same_cnt)
# print(f'precent {same_cnt/len(bf)*100:.2f}')

In [25]:
# base_freq = dict()
# for i in X_test_bucket_indices:
#     if i not in base_freq:
#         base_freq[i] = 1
#     else:
#         base_freq[i] += 1

In [26]:
# # plot the total distribution
# plt.bar(list(base_freq.keys()), list(base_freq.values()))
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# plt.title('Distribution of Age Buckets')

In [27]:
# down_all_freq_dict = dict()
# up_all_freq_dict = dict()
# same_all_freq_dict = dict()
# for i,rank_dif in enumerate(rank_difs):
#     for j, dif in enumerate(rank_dif):
#         bucket_idx = X_test_bucket_indices[j]
#         if dif < 0:
            
#             if bucket_idx not in down_all_freq_dict:
#                 down_all_freq_dict[bucket_idx] = 1
#             else:
#                 down_all_freq_dict[bucket_idx] += 1
#         elif dif > 0:
#             if bucket_idx not in up_all_freq_dict:
#                 up_all_freq_dict[bucket_idx] = 1
#             else:
#                 up_all_freq_dict[bucket_idx] += 1
#         else:
#             if bucket_idx not in same_all_freq_dict:
#                 same_all_freq_dict[bucket_idx] = 1
#             else:
#                 same_all_freq_dict[bucket_idx] += 1
            
            
# with open('Attack_Income_cut5_bins.pickle', 'wb') as f:
#     pickle.dump(cut10_bin, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_down_all_freq_dict.pickle', 'wb') as f:
#     pickle.dump(down_all_freq_dict, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_up_all_freq_dict.pickle', 'wb') as f:
#     pickle.dump(up_all_freq_dict, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_same_all_freq_dict.pickle', 'wb') as f:
#     pickle.dump(same_all_freq_dict, f, pickle.HIGHEST_PROTOCOL)

In [28]:
# plt.bar(list(down_all_freq_dict.keys()), list(down_all_freq_dict.values()))
# down_cnt = sum([v for v in down_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('Rank Demotion for buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(down_cnt)
# print(f'precent {down_cnt/len(X_test)*100:.2f}')

In [29]:
# plt.bar(list(up_all_freq_dict.keys()), list(up_all_freq_dict.values()))
# up_cnt = sum([v for v in up_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('Rank Promotion for buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(up_cnt)
# print(f'precent {up_cnt/len(X_test)*100:.2f}')

In [30]:
# plt.bar(list(same_all_freq_dict.keys()), list(same_all_freq_dict.values()))
# same_cnt = sum([v for v in same_all_freq_dict.values()])
# plt.xticks(list(range(len(cut10_bin))), cut10_bin)
# plt.title('No Rank Shift for Buckets')
# plt.xlabel('Buckets')
# plt.xlabel('Counts')
# print(same_cnt)
# print(f'precent {same_cnt/len(X_test)*100:.2f}')

# examples

In [31]:
# # examples
# first_dict = dict()
# for idx,shift in enumerate(examples[0][0][0]):
#     if shift not in first_dict:
#         first_dict[shift] = [examples[0][1][0][idx]]
#     else:
#         first_dict[shift].append(examples[0][1][0][idx])
# smallest = min(list(first_dict.keys()))
# print(smallest)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(base_result['shap_vals'][0][i],max_display=22)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(example_results[0]['shap_vals'][0][i],max_display=22)   

In [32]:
# # TPS
# first_dict = dict()
# for idx,shift in enumerate(examples[1][0][0]):
#     if shift not in first_dict:
#         first_dict[shift] = [examples[1][1][0][idx]]
#     else:
#         first_dict[shift].append(examples[1][1][0][idx])
# smallest = min(list(first_dict.keys()))
# print(smallest)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(base_result['shap_vals'][0][tps[0][i]],max_display=22)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(example_results[0]['shap_vals'][0][tps[0][i]],max_display=22)  

In [33]:
# # fPS
# first_dict = dict()
# for idx,shift in enumerate(examples[2][0][0]):
#     if shift not in first_dict:
#         first_dict[shift] = [examples[2][1][0][idx]]
#     else:
#         first_dict[shift].append(examples[2][1][0][idx])
# smallest = min(list(first_dict.keys()))
# print(smallest)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(base_result['shap_vals'][0][fps[0][i]],max_display=22)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(example_results[0]['shap_vals'][0][fps[0][i]],max_display=22)  

In [34]:
# # tnS
# first_dict = dict()
# for idx,shift in enumerate(examples[3][0][0]):
#     if shift not in first_dict:
#         first_dict[shift] = [examples[3][1][0][idx]]
#     else:
#         first_dict[shift].append(examples[3][1][0][idx])
# smallest = min(list(first_dict.keys()))
# print(smallest)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(base_result['shap_vals'][0][tns[0][i]],max_display=22)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(example_results[0]['shap_vals'][0][tns[0][i]],max_display=22)  

In [35]:
# # fnS
# first_dict = dict()
# for idx,shift in enumerate(examples[4][0][0]):
#     if shift not in first_dict:
#         first_dict[shift] = [examples[4][1][0][idx]]
#     else:
#         first_dict[shift].append(examples[4][1][0][idx])
# smallest = min(list(first_dict.keys()))
# print(smallest)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(base_result['shap_vals'][0][fns[0][i]],max_display=22)
# for it,i in enumerate(first_dict[smallest]):
#     shap.plots.bar(example_results[0]['shap_vals'][0][fns[0][i]],max_display=22)  

In [36]:
# means = [m for m,s in ew_bucket_downs]
# stds = [s for m,s in ew_bucket_downs]

# decrease_bin = list()
# for decrease_idx in range(4):
#     m = [mean[decrease_idx] for mean in means] 
#     decrease_bin.append(m)

# colors = ['r','b','g','c','m']
# p_labels = ['Rank 2','Rank 3','Rank 4', 'Rank 5 or below']

# for i in range(len(decrease_bin)):
#     plt.plot(list(range(2,21)), decrease_bin[i],color=colors[i],label=p_labels[i],ls='-')

# plt.xticks(range(2,21))
# plt.title('Rank shifts where Rank of Age is 1')
# plt.xlabel('Number of Buckets')
# plt.xlabel('Proportion')
# plt.legend()

In [37]:


# with open('Attack_Income_ew_downs_tp.pickle', 'wb') as f:
#     pickle.dump(ew_bucket_tp_downs, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_ew_downs_fp.pickle', 'wb') as f:
#     pickle.dump(ew_bucket_fp_downs, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_ew_downs_tn.pickle', 'wb') as f:
#     pickle.dump(ew_bucket_tn_downs, f, pickle.HIGHEST_PROTOCOL)
# with open('Attack_Income_ew_downs_fn.pickle', 'wb') as f:
#     pickle.dump(ew_bucket_fn_downs, f, pickle.HIGHEST_PROTOCOL)

In [38]:
# means = [m for m,s in ew_bucket_tp_downs]
# stds = [s for m,s in ew_bucket_tp_downs]

# decrease_bin = list()
# for decrease_idx in range(4):
#     m = [mean[decrease_idx] for mean in means] 
#     decrease_bin.append(m)

# colors = ['r','b','g','c','m']
# p_labels = ['Rank 2','Rank 3','Rank 4', 'Rank 5 or below']

# for i in range(len(decrease_bin)):
#     plt.plot(list(range(2,21)), decrease_bin[i],color=colors[i],label=p_labels[i],ls='-')

# plt.xticks(range(2,21))
# plt.title('TP Rank shifts where Rank of Age is 1')
# plt.xlabel('Number of Buckets')
# plt.xlabel('Proportion')
# plt.legend()

In [39]:
# means = [m for m,s in ew_bucket_fp_downs]
# stds = [s for m,s in ew_bucket_fp_downs]

# decrease_bin = list()
# for decrease_idx in range(4):
#     m = [mean[decrease_idx] for mean in means] 
#     decrease_bin.append(m)

# colors = ['r','b','g','c','m']
# p_labels = ['Rank 2','Rank 3','Rank 4', 'Rank 5 or below']

# for i in range(len(decrease_bin)):
#     plt.plot(list(range(2,21)), decrease_bin[i],color=colors[i],label=p_labels[i],ls='-')

# plt.xticks(range(2,21))
# plt.title('FP Rank shifts where Rank of Age is 1')
# plt.xlabel('Number of Buckets')
# plt.xlabel('Proportion')
# plt.legend()

In [40]:
# means = [m for m,s in ew_bucket_tn_downs]
# stds = [s for m,s in ew_bucket_tn_downs]

# decrease_bin = list()
# for decrease_idx in range(4):
#     m = [mean[decrease_idx] for mean in means] 
#     decrease_bin.append(m)

# colors = ['r','b','g','c','m']
# p_labels = ['Rank 2','Rank 3','Rank 4', 'Rank 5 or below']

# for i in range(len(decrease_bin)):
#     plt.plot(list(range(2,21)), decrease_bin[i],color=colors[i],label=p_labels[i],ls='-')

# plt.xticks(range(2,21))
# plt.title('TN Rank shifts where Rank of Age is 1')
# plt.xlabel('Number of Buckets')
# plt.xlabel('Proportion')
# plt.legend()

In [41]:
# means = [m for m,s in ew_bucket_fn_downs]
# stds = [s for m,s in ew_bucket_fn_downs]

# decrease_bin = list()
# for decrease_idx in range(4):
#     m = [mean[decrease_idx] for mean in means] 
#     decrease_bin.append(m)

# colors = ['r','b','g','c','m']
# p_labels = ['Rank 2','Rank 3','Rank 4', 'Rank 5 or below']

# for i in range(len(decrease_bin)):
#     plt.plot(list(range(2,21)), decrease_bin[i],color=colors[i],label=p_labels[i],ls='-')

# plt.xticks(range(2,21))
# plt.title('FN Rank shifts where Rank of Age is 1')
# plt.xlabel('Number of Buckets')
# plt.xlabel('Proportion')
# plt.legend()