In [1]:
import os
import sys
sys.path.append('../')
import hashlib
import json
from pathlib import Path
import copy

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_auc_score, auc, roc_curve
from sklearn import metrics
import warnings
# NB: Warnings occur when computing metrics for groups with
# low sample sizes. In our case, we don't use the metrics 
# for these groups.
warnings.filterwarnings("ignore")

from compute_metrics import Model
from lib.TabularDataset import dataset_params
from lib import TabularDataset
from result_latex_utils import maxPairDiff, agg_func, bold_max


# Root directory searched recursively for result CSVs (each read together with
# the args.json next to it).  Set EXPL_RES_DIR to point elsewhere; the original
# cluster location remains the default.
res_dir = Path(os.environ.get('EXPL_RES_DIR',
                              '/scratch/ssd001/home/aparna/explanations-subpopulations/output'))

In [2]:
# Passed as `retain_cols` to Dataset.get_data(): whether the returned data
# keeps its column names.
retain_cols = bool(0)

In [3]:
def _load_run_args(results_csv):
    """Return the parsed ``args.json`` stored next to a results CSV."""
    with (results_csv.parent / 'args.json').open('r') as fh:
        return json.load(fh)


def _is_table1_run(run_args):
    """True for the runs that feed paper Table 1.

    Keeps local LIME / SHAP explanations on the four Table 1 datasets whose
    args contain a ``balance_groups`` key.
    """
    return (run_args['explanation_type'] == 'local'
            and run_args['dataset'] in ('adult', 'compas_balanced', 'lsac', 'mimic_tab')
            and 'balance_groups' in run_args
            and run_args['explanation_model'] in ('lime', 'shap_blackbox_preprocessed'))


def _hparam_id(run_args):
    """MD5 of every argument value except ``seed`` and ``output_dir``.

    Runs that differ only in seed / output directory therefore share an id.
    Kept inside a function so the notebook-level ``hparams`` list is never
    overwritten when this cell is re-run.
    """
    hparam_keys = [k for k in run_args if k not in ('seed', 'output_dir')]
    return hashlib.md5(str([run_args[k] for k in hparam_keys]).encode('utf-8')).hexdigest()


def _enumerate_groups(res, group_names):
    """Return the groups to evaluate, one per row (columns = ``group_names``).

    First every observed combination of all sensitive attributes, then one row
    per value of each single attribute with the remaining attributes set to
    NaN (NaN = "not conditioned on").
    """
    joint = res[group_names].drop_duplicates()
    marginal = pd.DataFrame(
        [{name: val, **{other: np.nan for other in group_names if other != name}}
         for name in group_names
         for val in res[name].unique()],
        columns=group_names)
    # one concat instead of row-by-row DataFrame.append (removed in pandas 2.0)
    return pd.concat([joint, marginal], ignore_index=True)


def _select_group_rows(res, names, values):
    """Rows of ``res`` where each column in ``names`` equals the matching entry
    of ``values``; all rows when ``names`` is empty.

    Works for any number of sensitive attributes.
    """
    if len(names) == 0:
        print('Warning: Default setting!')
        return res
    keep = np.ones(len(res), dtype=bool)
    for name, value in zip(names, values):
        keep &= (res[name] == value).to_numpy()
    selected = res[keep]
    # cross-check the row selection against a single vectorised comparison
    assert selected.equals(res[(res[list(names)].values == values).all(1)])
    return selected


df_raw = []
# sorted() makes the row order of df_local independent of filesystem order
for path in sorted(res_dir.glob('**/*.csv')):
    args = _load_run_args(path)
    # all settings for paper table 1
    if not _is_table1_run(args):
        continue
    args['hparam_id'] = _hparam_id(args)

    # performance on the test split only; .copy() so the columns added below
    # are set on an independent frame rather than a filtered view
    res = pd.read_csv(path)
    res = res[res.set == 'test'].copy()
    if res.shape[0] == 0:
        continue

    dataset = TabularDataset.Dataset(args['dataset'])
    (X_train, X_train_expl, X_val_expl, X_test,
     y_train, y_train_expl, y_val_expl, y_test,
     g_train, g_train_expl, g_val_expl, g_test) = dataset.get_data(retain_cols=retain_cols)

    group_names = dataset_params[args['dataset']].sensitive_attributes
    # the whole test set must be present so rows line up with g_test
    assert res.shape[0] == g_test.shape[0]
    res[group_names] = g_test.values

    if not np.isfinite(res['expl_pred'].values).all():
        raise ValueError('Non-finite expl_pred for dataset {}, model {}'.format(
            args['dataset'], args['explanation_model']))

    # Whole-test-set metrics: agreement of thresholded explanation predictions
    # with the black-box predictions, and AUROC of the explanation scores
    # against the black-box predictions.
    expl_pred = res['expl_pred'].values
    blackbox_pred = res['blackbox_pred'].values
    args['accuracy_all'] = ((expl_pred >= 0.5) == blackbox_pred).sum() / len(res)
    fpr, tpr, thresholds = roc_curve(blackbox_pred, expl_pred, pos_label=1)
    # the directly imported `auc` is used because the name `metrics` is
    # rebound to a list of column names in a later cell
    args['auroc_all'] = auc(fpr, tpr)

    unique_groups = _enumerate_groups(res, group_names)
    for i in range(unique_groups.shape[0]):
        group_i = unique_groups.iloc[i].values
        mask = ~pd.isnull(group_i)  # attributes this group conditions on
        sel_rows = _select_group_rows(res, np.array(group_names)[mask], group_i[mask])

        # explanation-vs-black-box metrics restricted to this group
        group_model = Model(sel_rows['expl_pred'].values,
                            sel_rows['blackbox_pred'].values,
                            sel_rows['blackbox_prob'].values)
        all_metrics = group_model.compute()
        args_group = {**copy.deepcopy(args), **all_metrics}
        # metadata: group label, group size, number of conditioned attributes
        args_group['group'] = str(group_i)
        args_group['n'] = len(sel_rows)
        args_group['level'] = mask.sum()
        df_raw.append(args_group)

df_local = pd.DataFrame(df_raw)

In [4]:
# Persist the per-group metrics; the aggregation cell below reads this file
# back, so it can be re-run without recomputing the metrics.
out_dir = Path('csv_pred_outputs') / 'table1'
# parents=True also creates table1/ when csv_pred_outputs/ already exists;
# exist_ok=True makes re-running this cell a no-op instead of an error.
out_dir.mkdir(parents=True, exist_ok=True)
df_local.to_csv(out_dir / 'local_models_data_balanced.csv')

Folder exists!


In [5]:
# Group labels are str() of a group's attribute vector: position 0 is sex,
# position 1 is race, and NaN marks the attribute not conditioned on.
_sex_labels = ['[ 0. nan]', '[ 1. nan]']
_race_labels = ['[nan  {}.]'.format(k) for k in range(5)]

# every single-attribute group label used across the datasets
list_groups = _sex_labels + _race_labels

# Groups of the reported sensitive attribute per dataset: sex for Adult and
# MIMIC, race for COMPAS (two groups) and LSAC (five groups).
list_groups_dict = {
    'adult': list(_sex_labels),
    'mimic_tab': list(_sex_labels),
    'lsac': list(_race_labels),
    'lsac_cat': list(_race_labels),
    'compas_balanced': _race_labels[:2],
}



In [6]:
df_alg = list()  # NOTE(review): not referenced by any cell shown here

In [7]:
# Run-configuration columns that identify a hyper-parameter setting; joined
# onto the per-run and per-setting aggregates below.
hparams = [
    'dataset', 'blackbox_model', 'explanation_type', 'explanation_model',
    'n_features', 'model_type', 'experiment', 'ignore_lime_weights',
    'evaluate_val', 'max_epochs', 'perturb_sigma', 'train_grp_clf',
    'grp_clf_attr', 'lr', 'C', 'batch_size', 'debug',
    'jtt_lambda', 'jtt_thres', 'joint_dro_alpha', 'groupdro_eta',
    'reductionist_type', 'reductionist_difference_bound',
    'balance_group_idx', 'reductionist_thres', 'expl_reductionist_type',
]

In [8]:
def get_max_gap_row(row, disp='auroc',kind='mean',return_type='str_sum',
                    gap_or_not='_gap',
                    list_groups=list_groups):
    """Return the entry of ``row`` for the group with the largest ``disp`` value.

    ``row`` is one row of the seed-aggregated results table, in which each
    metric column ``{disp}_{group}{gap_or_not}`` is present both as a
    formatted "mean ± std" string and as numeric ``(name, kind)`` columns.
    The maximum is the signed maximum (no absolute value), ignoring NaNs.

    Parameters
    ----------
    row : pd.Series
        One row of the aggregated table.
    disp : str
        Metric prefix, e.g. 'auroc', 'accuracy', 'epsilon'.
    kind : str
        Statistic used for ranking when ``return_type`` is 'str_sum' or 'n'.
    return_type : str
        'str_sum' - rank by the ``kind`` statistic, return the plain column
                    (the formatted string);
        'n'       - rank by the ``kind`` statistic, return that statistic;
        other     - rank by and return the plain column.
    gap_or_not : str
        Column-name suffix (default '_gap').
    list_groups : list of str
        Group labels to compare.

    Returns
    -------
    The selected entry, or NaN when every candidate is NaN or
    ``list_groups`` is empty (np.nanargmax would raise ValueError there).
    """
    col_names = [f'{disp}_{group}{gap_or_not}' for group in list_groups]
    stat_names = [(name, kind) for name in col_names]

    ranked_by_stat = return_type in ('str_sum', 'n')
    candidates = row[stat_names].values if ranked_by_stat else row[col_names].values
    # guard: nanargmax raises on empty or all-NaN input, which would abort
    # the whole DataFrame.apply that calls this function
    if pd.isnull(candidates).all():
        return np.nan
    best = np.nanargmax(candidates)

    if return_type == 'n':
        return row[stat_names[best]]
    return row[col_names[best]]

In [9]:
# Suffix N of the `{metric}_sens{N}_gap` column reported per dataset.
# Presumably '1' is the first sensitive attribute (sex) and '2' the second
# (race), matching the group conventions above — confirm for recidivism/synthetic.
sens_feature_dict = dict(
    adult='1',
    lsac='2',
    lsac_cat='2',
    mimic_tab='1',
    recidivism='2',
    compas_balanced='2',
    synthetic='1',
)


In [10]:
# Aggregate the level-LEVEL group metrics of the Table 1 replication runs
# over seeds, then render every metric as a "mean ± std" string.
LEVEL = 1

df = pd.read_csv('csv_pred_outputs/table1/local_models_data_balanced.csv')
df = df[df.seed.isin([1, 2, 3, 4, 5])]
print(df.experiment.unique())
df = df[df.experiment.isin(['lime_all_correct_replication'])]
df = df[df.balance_groups == False]

# one row of summary metrics per run directory
res_df = df.query(f'level == {LEVEL}').groupby('output_dir').apply(agg_func)

# NOTE(review): this rebinds `metrics`, shadowing the `sklearn.metrics`
# module imported at the top of the notebook.
metrics = [name for name in res_df.columns if not name.startswith('worst_group')]

run_meta = df[hparams + ['output_dir', 'seed', 'hparam_id']].drop_duplicates()
out1 = res_df.reset_index().merge(run_meta)

# mean/std across seeds for each hyper-parameter setting
seed_stats = out1.groupby('hparam_id').agg({name: ('mean', 'std') for name in metrics})
hparam_meta = df[hparams + ['hparam_id']].drop_duplicates().set_index('hparam_id')
out2 = seed_stats.merge(hparam_meta, left_index=True, right_index=True).reset_index()

as_pct = '{0:.001%}'.format
for col in metrics:
    out2[col] = out2[(col, 'mean')].apply(as_pct) + ' ± ' + out2[(col, 'std')].apply(as_pct)

out_all = out2.copy()
dat_cols = ['dataset', 'blackbox_model', 'explanation_model', 'n_features']
temp1 = out2[out2.explanation_model.isin(['lime', 'shap_blackbox_preprocessed'])]
temp1[dat_cols + metrics].sort_values(dat_cols)


['lime_vary_sigma_correct_replication'
 'lime_vary_features_correct_replication' 'lime_all_correct_replication'
 'JTT' 'lime_balanced_correct_replication']


Unnamed: 0,dataset,blackbox_model,explanation_model,n_features,accuracy_all,accuracy_min,accuracy_minority,accuracy_majority,accuracy_gap,accuracy_sens1_gap,...,epsilon_gap,epsilon_sens1_gap,epsilon_sens2_gap,epsilon_[ 0. nan]_gap,epsilon_[ 1. nan]_gap,epsilon_[nan 0.]_gap,epsilon_[nan 1.]_gap,epsilon_[nan 2.]_gap,epsilon_[nan 3.]_gap,epsilon_[nan 4.]_gap
11,adult,lr,lime,100,98.5% ± 0.5%,96.7% ± 1.0%,98.3% ± 2.4%,98.0% ± 0.2%,1.8% ± 0.6%,1.7% ± 0.4%,...,2.0% ± 0.1%,1.9% ± 0.1%,1.9% ± 0.1%,0.8% ± 0.1%,-1.2% ± 0.1%,1.8% ± 0.3%,1.6% ± 0.3%,-0.7% ± 0.2%,-1.7% ± 0.4%,-0.6% ± 0.1%
2,adult,lr,shap_blackbox_preprocessed,100,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,...,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%
10,adult,nn,lime,100,88.3% ± 4.0%,81.6% ± 5.5%,88.7% ± 7.3%,84.4% ± 4.6%,6.7% ± 2.0%,10.7% ± 3.3%,...,4.8% ± 0.9%,1.3% ± 0.7%,4.2% ± 0.8%,1.4% ± 0.7%,0.1% ± 0.5%,4.8% ± 0.9%,-1.4% ± 2.0%,-2.3% ± 0.8%,-3.5% ± 1.3%,1.0% ± 0.5%
1,adult,nn,shap_blackbox_preprocessed,100,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,...,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
13,compas_balanced,lr,lime,100,99.9% ± 0.2%,99.8% ± 0.3%,99.8% ± 0.3%,100.0% ± 0.1%,0.1% ± 0.2%,0.1% ± 0.3%,...,-0.1% ± 0.0%,0.2% ± 0.0%,0.3% ± 0.0%,0.1% ± 0.0%,-0.1% ± 0.0%,-0.1% ± 0.0%,0.1% ± 0.0%,nan% ± nan%,nan% ± nan%,nan% ± nan%
4,compas_balanced,lr,shap_blackbox_preprocessed,100,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,...,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,nan% ± nan%,nan% ± nan%,nan% ± nan%
7,compas_balanced,nn,lime,100,94.1% ± 1.5%,92.6% ± 2.1%,92.8% ± 2.0%,94.6% ± 1.4%,1.4% ± 0.6%,1.8% ± 0.9%,...,-0.6% ± 0.1%,0.7% ± 0.2%,1.3% ± 0.2%,-0.5% ± 0.1%,0.2% ± 0.1%,-0.4% ± 0.1%,0.8% ± 0.1%,nan% ± nan%,nan% ± nan%,nan% ± nan%
3,compas_balanced,nn,shap_blackbox_preprocessed,100,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,...,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,nan% ± nan%,nan% ± nan%,nan% ± nan%
15,lsac,lr,lime,100,97.4% ± 0.0%,91.4% ± 0.0%,91.4% ± 0.0%,99.9% ± 0.0%,6.0% ± 0.0%,0.0% ± 0.0%,...,3.5% ± 0.0%,0.1% ± 0.0%,2.8% ± 0.0%,-0.9% ± 0.0%,-1.0% ± 0.0%,-1.4% ± 0.0%,-0.4% ± 0.1%,-1.8% ± 0.1%,3.5% ± 0.0%,2.0% ± 0.1%
5,lsac,lr,shap_blackbox_preprocessed,100,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,...,0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,-0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,-0.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%


In [11]:
# Explainer names as they appear in Table 1.
EXPLAINER_LABELS = {'shap_blackbox_preprocessed': 'SHAP', 'lime': 'LIME'}


def _table1_rows(results, sigma):
    """Local LIME/SHAP rows (LR/NN black box, 100 features) for one
    perturbation sigma, with explainer names relabelled and, for auroc,
    accuracy and epsilon, a max-over-groups gap column and a
    sensitive-attribute gap column added.

    Returns an independent copy, so adding columns never touches ``results``.
    """
    selected = results[(results.explanation_type == 'local')
                       & results.blackbox_model.isin(['lr', 'nn'])
                       & (results.perturb_sigma == sigma)
                       & results.n_features.isin([100])
                       & results.explanation_model.isin(list(EXPLAINER_LABELS))].copy()
    selected['explanation_model'] = selected['explanation_model'].replace(EXPLAINER_LABELS)
    for perf_metric in ['auroc', 'accuracy', 'epsilon']:
        selected[f'{perf_metric}_max_gap'] = selected.apply(
            lambda row: get_max_gap_row(row, perf_metric,
                                        list_groups=list_groups_dict[row['dataset']]),
            axis=1)
        selected[f'{perf_metric}_sens_gap'] = selected.apply(
            lambda row: row[f'{perf_metric}_sens{sens_feature_dict[row.dataset]}_gap'],
            axis=1)
    return selected


auroc_gaps = []  # NOTE(review): never filled in the cells shown here
acc_gaps = []    # NOTE(review): never filled in the cells shown here

# Keep every per-sigma table; `temp1` ends up holding the table for the last
# sigma in sorted order, which is what the Table 1 pivot below reads.
tables_by_sigma = {}
for sig in np.sort(out_all.perturb_sigma.unique()):
    tables_by_sigma[sig] = _table1_rows(out_all, sig)
    temp1 = tables_by_sigma[sig]


In [12]:
# Table 1: one row per (dataset, black box, explainer); cells are the
# "mean ± std" strings built above.
metrics_reported_table1 = [
    'auroc_all',
    'accuracy_max_gap',
    'accuracy_sens_gap',
    'auroc_sens_gap',
    'epsilon_sens_gap',
]
table1_index = ['dataset', 'blackbox_model', 'explanation_model']
# Values are preformatted strings, so aggregation can only pick the single
# entry per index combination; make a duplicated combination an explicit error
# instead of relying on an identity aggfunc.
assert not temp1.duplicated(table1_index).any(), 'duplicate Table 1 rows'
pd.pivot_table(data=temp1[table1_index + metrics_reported_table1],
               values=metrics_reported_table1,
               index=table1_index,
               aggfunc='first')

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy_max_gap,accuracy_sens_gap,auroc_all,auroc_sens_gap,epsilon_sens_gap
dataset,blackbox_model,explanation_model,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
adult,lr,LIME,1.0% ± 0.5%,1.7% ± 0.4%,99.9% ± 0.0%,0.0% ± 0.0%,1.9% ± 0.1%
adult,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
adult,nn,LIME,6.7% ± 2.0%,10.7% ± 3.3%,96.3% ± 1.7%,2.9% ± 0.6%,1.3% ± 0.7%
adult,nn,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
compas_balanced,lr,LIME,0.0% ± 0.1%,0.1% ± 0.2%,100.0% ± 0.0%,0.0% ± 0.0%,0.3% ± 0.0%
compas_balanced,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
compas_balanced,nn,LIME,0.8% ± 0.7%,2.4% ± 1.6%,98.9% ± 0.3%,0.6% ± 0.2%,1.3% ± 0.2%
compas_balanced,nn,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
lsac,lr,LIME,6.0% ± 0.0%,4.6% ± 0.0%,100.0% ± 0.0%,0.1% ± 0.0%,2.8% ± 0.0%
lsac,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
