In [1]:
import os
import sys
sys.path.append('../')
import hashlib
import json
from pathlib import Path
import copy

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_auc_score, auc, roc_curve
from sklearn import metrics
import warnings
# NB: Warnings occur when computing metrics for groups with
# low sample sizes. In our case, we don't use the metrics 
# for these groups.
warnings.filterwarnings("ignore")

from compute_metrics import Model
from lib.TabularDataset import dataset_params
from lib import TabularDataset
from result_latex_utils import meanPairDiff, agg_func, bold_max


# Root directory that is searched recursively for result CSVs; every CSV is
# expected to sit next to the `args.json` of the run that produced it.
# Set the EXPL_RES_DIR environment variable to point the notebook at another
# location; when it is unset the original cluster path is used.
res_dir = Path(os.environ.get(
    'EXPL_RES_DIR',
    '/scratch/ssd001/home/aparna/explanations-subpopulations/output_main'))

In [2]:
# retain column names
# This flag is forwarded unchanged as `retain_cols=` to `dataset.get_data(...)`
# when the data splits are reloaded for each run in the loading cell below.
retain_cols=False

In [3]:
# Build one record per (run, sensitive group): walk every result CSV under
# `res_dir`, keep the local-explanation runs on the datasets of paper table 1,
# attach the sensitive attributes of the test split, and compute fidelity
# metrics of the explanation model against the black-box model.
df_raw = []
for path in res_dir.glob('**/*.csv'):
    # The run configuration is stored next to the CSV; close the file promptly.
    with (path.parent/'args.json').open('r') as args_file:
        args = json.load(args_file)

    # all settings for paper table 1
    if not (args['explanation_type']=='local' and
            args['dataset'] in ['compas_balanced','lsac','adult_cleaned',
                                'mimic_tab_robust','mimic_tab_fair_ds']):
        continue

    # Runs sharing every setting except seed / output directory get the same id.
    hparams = [i for i in list(args.keys()) if i not in ['seed', 'output_dir']]
    args['hparam_id'] = hashlib.md5(str([args[i] for i in hparams]).encode('utf-8')).hexdigest()

    # get performance on test set (copy so that columns can be added safely)
    res = pd.read_csv(path)
    res = res[res['set']=='test'].copy()
    if res.shape[0]==0:
        continue

    dataset = TabularDataset.Dataset(args['dataset'])
    X_train, X_train_expl, X_val_expl, X_test, y_train, y_train_expl, y_val_expl, y_test, g_train, g_train_expl, g_val_expl, g_test = dataset.get_data(
        retain_cols=retain_cols)

    group_names = dataset_params[args['dataset']].sensitive_attributes

    # to check that the whole test set is present
    assert res.shape[0]==g_test.shape[0]
    res[group_names] = g_test.values
    assert res['blackbox_prob'].notna().all()
    assert res['expl_pred'].notna().all()

    if not np.isfinite(res['expl_pred'].values).all():
        raise ValueError('For dataset {}, model {}'.format(args['dataset'],
                                                          args['explanation_model']))

    # compute metrics over all items in test set
    args['accuracy_all'] = ((res['expl_pred'].values >= 0.5) == res[
                                  'blackbox_pred']).sum() / len(res)
    fpr, tpr, thresholds = roc_curve(res['blackbox_pred'].values,
                                     res['expl_pred'].values, pos_label=1)
    # Use the directly imported `auc`: the name `metrics` is rebound to a list
    # by a later cell, which would break `metrics.auc` on a re-run.
    args['auroc_all'] = auc(fpr, tpr)
    args['prevalence_all'] = res['blackbox_pred'].mean()

    # Groups to evaluate: every observed combination of the sensitive
    # attributes, plus every single-attribute value with the remaining
    # attributes set to NaN (NaN = "attribute not constrained").
    marginal_groups = pd.DataFrame(
        [{grp: val, **{i: np.nan for i in group_names if i != grp}}
         for grp in group_names
         for val in res[grp].unique()],
        columns=group_names)
    unique_groups = pd.concat(
        [res[group_names].drop_duplicates(), marginal_groups],
        ignore_index=True)

    for group in range(unique_groups.shape[0]):
        group_i = unique_groups.iloc[group].values

        # Only the non-NaN attributes constrain the row selection.
        mask = ~pd.isnull(group_i)
        curr_group_names = np.array(group_names)[mask]
        curr_group_vals = group_i[mask]
        if mask.sum()==0:
            # No attribute is constrained, so the comparison below selects
            # every row of the test set.
            print('Warning: Default setting!')
        sel_rows = res[(res[curr_group_names].values == curr_group_vals).all(1)]

        # computing performance on given group
        group_0_val_model = Model(sel_rows['expl_pred'].values,
                                  sel_rows['blackbox_pred'].values,
                                  sel_rows['blackbox_prob'].values)
        all_metrics = group_0_val_model.compute()
        args_group = {
            **copy.deepcopy(args),
            **all_metrics
        }
        # some metadata
        args_group['group'] = str(group_i)
        args_group['n'] = len(sel_rows)
        args_group['level'] = mask.sum()  # number of constrained attributes
        args_group['prevalence'] = sel_rows['blackbox_pred'].mean()
        args_group['pred_prevalence'] = np.mean(sel_rows['expl_pred']>=0.5)
        df_raw.append(args_group)

df_local = pd.DataFrame(df_raw)

In [4]:
# group settings
# Group identifiers are the string form of a two-element array; per the
# original notes the first slot is sex and the second is race, with `nan`
# marking the attribute that is left unconstrained.
_SEX_GROUPS = ['[ 0. nan]', '[ 1. nan]']
_RACE_GROUPS = ['[nan  0.]', '[nan  1.]', '[nan  2.]', '[nan  3.]', '[nan  4.]']

# Every single-attribute group that can appear in any dataset.
list_groups = _SEX_GROUPS + _RACE_GROUPS

# the two sensitive groups of interest for all 4 datasets
# NB: sex is considered the sensitive attribute for Adult and MIMIC,
# and race for COMPAS and LSAC.
# Each entry gets its own list object so that entries never alias each other.
list_groups_dict = {
    'adult_cleaned': list(_SEX_GROUPS),
    'mimic_tab_robust': list(_SEX_GROUPS),
    'lsac': list(_RACE_GROUPS),
    'lsac_cat': list(_RACE_GROUPS),
    'compas_balanced': _RACE_GROUPS[:2],
}



In [5]:
# NOTE(review): `df_alg` is initialised here but is not referenced by any of
# the cells visible in this part of the notebook — confirm it is still needed.
df_alg = []

In [6]:
# fetching hyperparams
# The hyperparameter names are taken from the `args.json` of the first result
# file found; every key except the seed and the output directory counts as a
# hyperparameter. Only one file is read, hence the `break`.
for path in res_dir.glob('**/*.csv'):
    # Use a context manager so the file handle is closed right after reading.
    with (path.parent/'args.json').open('r') as args_file:
        args = json.load(args_file)
    hparams = [i for i in list(args.keys()) if i not in  ['seed', 'output_dir']]
    break

In [7]:
def get_max_gap_row(row, disp='auroc',kind='mean',return_type='str_sum',
                    gap_or_not='_gap',
                    list_groups=list_groups):
    """Pick, for one table row, the entry of the group with the largest value.

    For every group in ``list_groups`` two column labels are built: the plain
    label ``f'{disp}_{group}{gap_or_not}'`` and the statistic label
    ``(plain_label, kind)``.

    Parameters
    ----------
    row : indexable by column label (used with ``DataFrame.apply(axis=1)``).
    disp : metric prefix of the column labels.
    kind : second element of the statistic label, e.g. ``'mean'``.
    return_type :
        ``'str_sum'`` - rank groups by the statistic columns, return the value
        stored under the winning plain label;
        ``'n'`` - rank groups by the statistic columns, return the value stored
        under the winning statistic label;
        anything else - rank and return using the plain labels only.
    gap_or_not : suffix appended to every column label.
    list_groups : group identifiers to compare.

    Returns
    -------
    The selected entry of ``row``. Ranking uses ``np.nanargmax``, so NaN
    entries are ignored.
    """
    plain_cols = [f'{disp}_{group}{gap_or_not}' for group in list_groups]
    stat_cols = [(label, kind) for label in plain_cols]

    if return_type=='n':
        winner = np.nanargmax(row[stat_cols].values)
        return row[stat_cols[winner]]

    # Both remaining modes return the plain column; they differ only in which
    # set of columns decides the winner.
    ranking_cols = stat_cols if return_type=='str_sum' else plain_cols
    winner = np.nanargmax(row[ranking_cols].values)
    return row[plain_cols[winner]]

In [8]:
# Per-dataset suffix used to look up the `{metric}_sens{suffix}_gap` column
# when the sensitive-attribute gap is reported.
# NOTE(review): presumably '1' refers to the first sensitive attribute and '2'
# to the second — confirm against the code that produces those columns.
sens_feature_dict = {
    'adult_cleaned': '1',
    'lsac': '2',
    'mimic_tab_robust': '1',
    'compas_balanced': '2',
}


In [9]:
# balance sensitive attribute groups by oversampling during training?
# The run-selection cell keeps only runs whose `balance_groups` column equals
# this flag.
balance_group_ind = False

# balance labels by oversampling during training?
# NB: We will always do this for LSAC, see Appendix
# in paper
# The run-selection cell compares this flag with the `balance_labels` column
# for every dataset except 'lsac'; 'lsac' runs are always selected with
# balance_labels == True, whatever the value set here.
balance_labels_ind = False

In [10]:
# Select the runs matching the balancing configuration chosen above.
df_sel = pd.DataFrame(df_raw)

is_lsac = df_sel.dataset == 'lsac'
# LSAC runs must have been trained with balanced labels; every other dataset
# follows `balance_labels_ind`.
labels_match = (
    (is_lsac & (df_sel.balance_labels == True))
    | (~is_lsac & (df_sel.balance_labels == balance_labels_ind))
)
groups_match = df_sel.balance_groups == balance_group_ind

df_sel = df_sel[labels_match & groups_match]

In [11]:
# Aggregate the per-group records: first per run (`output_dir`), then across
# seeds per hyperparameter configuration (`hparam_id`), and finally render
# every metric as a "mean ± std" percentage string.
LEVEL = 1  # keep groups defined by exactly this many sensitive attributes

df = df_sel[df_sel.seed.isin([1,2,3,4,5])]
df = df[df.experiment.isin(['lime_balanced_correct_replication'])]

# One row of aggregated metrics per run.
res_df = df[df['level'] == LEVEL].groupby('output_dir').apply(agg_func)

# NOTE: this rebinds the name `metrics`, which the import cell bound to the
# `sklearn.metrics` module; code that needs the module after this point has to
# reach it under a different name.
metrics = [name for name in res_df.columns if not name.startswith('worst_group')]

run_settings = df[hparams + ['output_dir', 'seed', 'hparam_id']].drop_duplicates()
out1 = res_df.reset_index().merge(run_settings)

# Mean and standard deviation over seeds for every configuration.
seed_stats = out1.groupby('hparam_id').agg({name: ('mean', 'std') for name in metrics})
config_settings = df[hparams + ['hparam_id']].drop_duplicates().set_index('hparam_id')
out2 = seed_stats.merge(config_settings, left_index = True, right_index = True).reset_index()


def _as_percent(x):
    """Format a fraction as a percentage string with one decimal place."""
    return '{0:.001%}'.format(x)


for col in metrics:
    mean_text = out2[(col, 'mean')].apply(_as_percent)
    std_text = out2[(col, 'std')].apply(_as_percent)
    out2[col] = mean_text + ' ± ' + std_text


out_all = out2.copy()

In [14]:
# Build the table rows for LIME and SHAP on the four paper datasets and attach
# the per-metric gap columns.
# NOTE: `temp1` is rebuilt on every iteration, so after the loop it holds the
# rows for the last (largest) `perturb_sigma` only.
auroc_gaps=[]
acc_gaps=[]
# Raw explanation-model names that are reported, and their display names.
explainer_labels = {'shap_blackbox_preprocessed': 'SHAP', 'lime': 'LIME'}
for sig in np.sort(out_all.perturb_sigma.unique()):
    keep = (
        (out_all.explanation_type=='local')
        & out_all.blackbox_model.isin(['lr','nn'])
        & (out_all.perturb_sigma==sig)
        & out_all.n_features.isin([100])
        & out_all.explanation_model.isin(list(explainer_labels))
        & out_all.dataset.isin(['mimic_tab_robust','lsac','adult_cleaned','compas_balanced'])
    )
    # Explicit copy: columns are assigned below, which must not happen on a
    # filtered view of `out_all`.
    temp1 = out_all[keep].copy()
    for raw_name, label in explainer_labels.items():
        temp1.loc[temp1.explanation_model==raw_name,'explanation_model'] = label

    for perf_metric in ['auroc','accuracy','epsilon']:
        # Largest gap over the groups of interest for the row's dataset.
        temp1[f'{perf_metric}_max_gap']=temp1.apply(lambda row: get_max_gap_row(row,perf_metric,
                                                                       list_groups=list_groups_dict[row['dataset']]), axis=1)
        # Gap for the dataset's designated sensitive attribute.
        temp1[f'{perf_metric}_sens_gap']=temp1.apply(lambda row: row[f'{perf_metric}_sens{sens_feature_dict[row.dataset]}_gap'],axis=1)


In [15]:
# Metrics shown in table 1, laid out with one row per
# (dataset, black-box model, explanation model).
metrics_reported_table1 = [
    'auroc_all',
    'accuracy_max_gap',
    'auroc_sens_gap',
    'accuracy_sens_gap',
    'epsilon_sens_gap',
]
table1_index = ['dataset', 'blackbox_model', 'explanation_model']

# The identity aggregation passes the already formatted "mean ± std" strings
# through unchanged.
pd.pivot_table(
    data=temp1[table1_index + metrics_reported_table1],
    values=metrics_reported_table1,
    index=table1_index,
    aggfunc=lambda x: x,
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy_max_gap,accuracy_sens_gap,auroc_all,auroc_sens_gap,epsilon_sens_gap
dataset,blackbox_model,explanation_model,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
adult_cleaned,lr,LIME,0.8% ± 0.0%,2.4% ± 0.1%,99.9% ± 0.0%,0.1% ± 0.0%,1.9% ± 0.0%
adult_cleaned,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
adult_cleaned,nn,LIME,6.9% ± 0.7%,20.6% ± 2.0%,95.7% ± 1.2%,3.0% ± 1.2%,0.8% ± 0.5%
adult_cleaned,nn,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
compas_balanced,lr,LIME,0.1% ± 0.1%,0.3% ± 0.2%,100.0% ± 0.0%,0.0% ± 0.0%,0.3% ± 0.0%
compas_balanced,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
compas_balanced,nn,LIME,0.9% ± 0.3%,2.4% ± 0.7%,99.0% ± 0.2%,0.7% ± 0.3%,1.1% ± 0.1%
compas_balanced,nn,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%
lsac,lr,LIME,2.0% ± 1.0%,1.5% ± 0.5%,100.0% ± 0.0%,0.0% ± 0.0%,1.5% ± 0.1%
lsac,lr,SHAP,0.0% ± 0.0%,0.0% ± 0.0%,100.0% ± 0.0%,0.0% ± 0.0%,0.0% ± 0.0%


In [18]:
# Export the LIME rows as a LaTeX table, sorted by the identifying columns.
temp1 = temp1[temp1.explanation_model.isin(['LIME'])]

latex_id_cols = ['dataset', 'blackbox_model', 'explanation_model']
latex_metric_cols = [
    'accuracy_max_gap',
    'auroc_sens_gap',
    'accuracy_sens_gap',
    'epsilon_sens_gap',
]
latex_rows = temp1[latex_id_cols + latex_metric_cols].sort_values(latex_id_cols)
print(latex_rows.to_latex())


\begin{tabular}{llllllll}
\toprule
{} &           dataset & blackbox\_model & explanation\_model & accuracy\_max\_gap & auroc\_sens\_gap & accuracy\_sens\_gap & epsilon\_sens\_gap \\
\midrule
5  &     adult\_cleaned &             lr &              LIME &      0.8\% ± 0.0\% &    0.1\% ± 0.0\% &       2.4\% ± 0.1\% &      1.9\% ± 0.0\% \\
15 &     adult\_cleaned &             nn &              LIME &      6.9\% ± 0.7\% &    3.0\% ± 1.2\% &      20.6\% ± 2.0\% &      0.8\% ± 0.5\% \\
21 &   compas\_balanced &             lr &              LIME &      0.1\% ± 0.1\% &    0.0\% ± 0.0\% &       0.3\% ± 0.2\% &      0.3\% ± 0.0\% \\
24 &   compas\_balanced &             nn &              LIME &      0.9\% ± 0.3\% &    0.7\% ± 0.3\% &       2.4\% ± 0.7\% &      1.1\% ± 0.1\% \\
14 &              lsac &             lr &              LIME &      2.0\% ± 1.0\% &    0.0\% ± 0.0\% &       1.5\% ± 0.5\% &      1.5\% ± 0.1\% \\
29 &              lsac &             nn &              LIME &     21.4\% ±