In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import os
import sys
sys.path.append('../')
from compute_metrics import Model
import json
from lib.TabularDataset import dataset_params
import copy
from sklearn.metrics import roc_auc_score
import hashlib

# Root directory of the experiment outputs. It is searched recursively for
# result CSVs, each of which is read together with the `args.json` stored in
# the same folder.
# NOTE(review): absolute, user-specific path -- point it at your own results
# directory before re-running the notebook.
res_dir = Path('/scratch/hdd001/home/haoran/explanations-subpopulations/')

In [21]:
# Build one record per (run, subgroup). Every results CSV found under
# `res_dir` is paired with the `args.json` sitting in the same directory; the
# run arguments are copied into each record next to the subgroup metrics.
df_raw = []
for path in res_dir.glob('**/*.csv'):
    # Use a context manager so the file handle is closed for every run
    # instead of being left to the garbage collector.
    with (path.parent/'args.json').open('r') as args_file:
        args = json.load(args_file)

    # Every argument except the seed and the output location identifies a
    # hyper-parameter configuration; runs sharing them get the same id.
    # NOTE: `hparams` deliberately stays a notebook-level name because the
    # aggregation cell further down reads it after this loop has finished.
    hparams = [i for i in list(args.keys()) if i not in  ['seed', 'output_dir']]
    args['hparam_id'] = hashlib.md5(str([args[i] for i in hparams]).encode('utf-8')).hexdigest()

    res = pd.read_csv(path)
    # Fidelity of the explanation model to the black box over all rows:
    # explanation scores are thresholded at 0.5 and compared to `blackbox_pred`.
    args['accuracy_all'] = ((res['expl_pred'].values > 0.5) == res[
                                  'blackbox_pred']).sum() / len(res)
    try:
        args['auroc_all'] = roc_auc_score(res['blackbox_pred'].values, res['expl_pred'].values)
    except ValueError:
        # roc_auc_score could not be computed (sklearn raises ValueError, for
        # instance, when the labels contain a single class); fall back to 0.5.
        args['auroc_all'] = 0.5

    group_names = dataset_params[args['dataset']].sensitive_attributes

    # Subgroups to evaluate:
    #   * every combination of sensitive-attribute values present in the data;
    #   * every single-attribute (marginal) group, encoded as a row whose other
    #     attributes are NaN.
    # The marginal rows are collected first and concatenated once:
    # `DataFrame.append` was removed in pandas 2.0 and re-allocated the frame
    # on every call.
    marginal_rows = [
        {grp: val, **{other: np.nan for other in group_names if other != grp}}
        for grp in group_names
        for val in res[grp].unique()
    ]
    unique_groups = res[group_names].drop_duplicates()
    if marginal_rows:
        unique_groups = pd.concat(
            [unique_groups, pd.DataFrame(marginal_rows, columns=list(group_names))],
            ignore_index=True,
        )

    for group in range(unique_groups.shape[0]):
        group_i = unique_groups.iloc[group].values
        # Attributes that are NaN in this row are "don't care"; only the
        # specified ones take part in the row selection.
        mask = ~pd.isnull(group_i)
        sel_rows = res[(res[np.array(group_names)[mask]].values ==
                       group_i[mask]).all(1)]

        # `Model` comes from compute_metrics; its `compute()` result is merged
        # into the record as additional key/value pairs.
        group_0_val_model = Model(sel_rows['expl_pred'].values, sel_rows[
                                  'blackbox_pred'].values)
        all_metrics = group_0_val_model.compute()
        args_group = {
            **copy.deepcopy(args),
            **all_metrics
        }

        args_group['group'] = str(group_i)
        args_group['n'] = len(sel_rows)
        # Number of sensitive attributes that define this subgroup.
        args_group['level'] = mask.sum()

        df_raw.append(args_group)

df = pd.DataFrame(df_raw)

In [24]:
def agg_func(x):
    """Summarise the subgroup rows of one run into a single Series.

    Parameters
    ----------
    x : pandas.DataFrame
        Rows to summarise. Must contain the columns 'ACC', 'AUROC',
        'accuracy_all', 'auroc_all', 'n' and 'group'.

    Returns
    -------
    pandas.Series
        For each of 'accuracy' (read from 'ACC') and 'auroc' (read from
        'AUROC'):
          * '<name>_all'         -- the '<name>_all' value of the first row,
          * '<name>_min'         -- the smallest metric value among the rows,
          * '<name>_minority'    -- the metric of the row with the smallest 'n',
          * 'worst_group_<name>' -- the 'group' label of the row with the
                                    smallest metric value.
    """
    # Row label of the smallest subgroup; identical for both metrics.
    smallest_group = x['n'].idxmin()

    summary = {}
    for disp, met in (('accuracy', 'ACC'), ('auroc', 'AUROC')):
        metric_values = x[met]
        worst_row = metric_values.idxmin()
        summary[f'{disp}_all'] = x[f'{disp}_all'].iloc[0]
        summary[f'{disp}_min'] = metric_values.min()
        summary[f'{disp}_minority'] = x.loc[smallest_group, met]
        summary[f'worst_group_{disp}'] = x.loc[worst_row, 'group']

    return pd.Series(summary)

def bold_max(s):
    """Return CSS strings that embolden the largest float in each top-level group.

    Parameters
    ----------
    s : pandas.Series
        Series whose index has a first level that partitions the entries
        into groups (selected with ``s[key]``).

    Returns
    -------
    list of str
        One entry per value visited: 'font-weight: bold' where the value
        equals its group's maximum, '' otherwise. Only float values
        (``float``, ``np.float32``, ``np.float64``) take part in the maximum;
        a group without any float is compared against 0.

    Notes
    -----
    Styles are emitted group by group, in order of each key's first
    appearance, so they line up with ``s`` only when the rows of every group
    are contiguous.
    Presumably intended for a pandas Styler -- TODO confirm against callers.
    """
    float_types = (float, np.float32, np.float64)
    styles = []
    for top_key in s.index.get_level_values(0).unique():
        members = s[top_key]
        numeric = [v for v in members if isinstance(v, float_types)]
        best = max(numeric) if numeric else 0
        styles.extend('font-weight: bold' if v == best else '' for v in members)
    return styles

In [35]:
# Only subgroups defined by exactly LEVEL sensitive attributes are summarised
# (the 'level' column counts the non-NaN attributes of each subgroup).
LEVEL = 2

# One row per run (`output_dir`) holding the summary produced by `agg_func`.
res_df = df.query(f'level == {LEVEL}').groupby('output_dir').apply(agg_func)
# The 'worst_group_*' columns are labels, not numbers, so they are excluded
# from the mean/std aggregation below.
metrics = [i for i in res_df.columns if not i.startswith('worst_group')]

# NOTE: `hparams` is the list left behind by the results-loading loop (the
# args.json keys minus 'seed' and 'output_dir'); that cell must have been run
# before this one.
out1 = (res_df
     .reset_index()
     .merge(df[hparams + ['output_dir', 'seed', 'hparam_id']].drop_duplicates()))

# Mean and standard deviation over the runs sharing a hyper-parameter id.
seed_stats = out1.groupby('hparam_id').agg({i: ['mean', 'std'] for i in metrics})
# The aggregation yields two-level (metric, statistic) columns. Merging such a
# frame with the single-level hyper-parameter table is rejected by recent
# pandas ("Not allowed to merge between different levels"), so flatten the
# columns to plain tuple labels first; `(metric, 'mean')` lookups still work.
seed_stats.columns = seed_stats.columns.to_flat_index()

out2 = (seed_stats
       .merge(df[hparams + ['hparam_id']].drop_duplicates().set_index('hparam_id'), left_index = True, right_index = True)
       .reset_index())

# Human-readable "mean ± std" strings, stored under the bare metric name.
as_percent = '{0:.1%}'.format
for col in metrics:
    out2[col] = out2[(col, 'mean')].map(as_percent) + ' ± ' + out2[(col, 'std')].map(as_percent)

# The identity aggfunc passes the pre-formatted strings through unchanged.
# NOTE(review): this presumably relies on exactly one row per
# (dataset, blackbox_model, n_features, model_type) cell -- confirm.
out2.pivot_table(values = metrics, 
                  index = ['dataset', 'blackbox_model', 'n_features'], columns = ['model_type'], aggfunc = lambda x: x)



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy_all,accuracy_all,accuracy_all,accuracy_min,accuracy_min,accuracy_min,accuracy_minority,accuracy_minority,accuracy_minority,auroc_all,auroc_all,auroc_all,auroc_min,auroc_min,auroc_min,auroc_minority,auroc_minority,auroc_minority
Unnamed: 0_level_1,Unnamed: 1_level_1,model_type,ARL,ERM,sklearn,ARL,ERM,sklearn,ARL,ERM,sklearn,ARL,ERM,sklearn,ARL,ERM,sklearn,ARL,ERM,sklearn
dataset,blackbox_model,n_features,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
adult,lr,5,95.5% ± 0.1%,93.1% ± 0.5%,100.0% ± 0.0%,88.6% ± 2.6%,77.5% ± 6.1%,100.0% ± 0.0%,97.1% ± 6.4%,80.0% ± 7.8%,100.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,lr,10,80.0% ± 0.9%,54.7% ± 0.6%,100.0% ± 0.0%,54.7% ± 20.2%,19.7% ± 7.3%,100.0% ± 0.0%,82.9% ± 18.6%,48.6% ± 21.7%,100.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,nn,5,69.8% ± 7.1%,72.4% ± 5.0%,75.1% ± 7.3%,58.2% ± 9.6%,61.6% ± 4.7%,64.8% ± 10.5%,85.7% ± 10.1%,91.4% ± 7.8%,88.6% ± 6.4%,56.0% ± 11.0%,60.7% ± 7.6%,64.4% ± 12.7%,33.5% ± 16.1%,42.1% ± 4.5%,42.4% ± 9.7%,50.0% ± 0.0%,60.0% ± 22.4%,50.0% ± 0.0%
adult,nn,10,62.2% ± 4.0%,64.3% ± 2.0%,72.7% ± 6.8%,47.4% ± 5.0%,51.6% ± 3.5%,59.2% ± 10.3%,82.9% ± 15.6%,85.7% ± 10.1%,91.4% ± 7.8%,41.7% ± 6.4%,48.1% ± 2.4%,60.5% ± 10.3%,23.4% ± 5.3%,25.8% ± 6.7%,38.5% ± 9.8%,50.0% ± 0.0%,60.0% ± 22.4%,50.0% ± 0.0%
adult,rf,5,92.8% ± 0.5%,93.1% ± 0.5%,95.8% ± 0.1%,90.0% ± 1.4%,89.8% ± 2.4%,93.0% ± 0.0%,100.0% ± 0.0%,97.1% ± 6.4%,100.0% ± 0.0%,84.8% ± 0.9%,85.3% ± 0.9%,89.4% ± 1.0%,0.0% ± 0.0%,2.7% ± 3.7%,0.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,rf,10,79.5% ± 0.4%,71.9% ± 0.3%,95.4% ± 0.1%,64.1% ± 5.3%,39.5% ± 6.3%,93.0% ± 0.0%,82.9% ± 18.6%,48.6% ± 16.3%,100.0% ± 0.0%,61.1% ± 0.9%,53.3% ± 0.9%,89.0% ± 0.9%,10.4% ± 19.8%,1.3% ± 3.0%,0.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,svm_rbf,5,90.7% ± 0.3%,91.4% ± 0.2%,93.6% ± 0.1%,86.5% ± 0.6%,85.7% ± 2.8%,91.0% ± 0.2%,97.1% ± 6.4%,94.3% ± 7.8%,100.0% ± 0.0%,85.9% ± 0.3%,86.4% ± 0.4%,88.8% ± 0.1%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,svm_rbf,10,84.5% ± 0.1%,80.8% ± 0.4%,91.1% ± 0.1%,77.3% ± 1.3%,63.1% ± 2.4%,86.9% ± 0.2%,94.3% ± 7.8%,85.7% ± 0.0%,100.0% ± 0.0%,74.7% ± 0.4%,68.9% ± 0.8%,87.0% ± 0.1%,50.0% ± 0.0%,42.6% ± 6.2%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,xgb,5,94.0% ± 0.1%,93.5% ± 0.1%,95.5% ± 0.0%,90.7% ± 2.8%,90.3% ± 2.6%,93.0% ± 0.0%,97.1% ± 6.4%,97.1% ± 6.4%,100.0% ± 0.0%,87.3% ± 0.3%,86.9% ± 0.2%,89.1% ± 0.1%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
adult,xgb,10,92.6% ± 0.1%,86.8% ± 0.5%,95.3% ± 0.1%,87.6% ± 1.1%,67.3% ± 4.7%,93.0% ± 0.0%,100.0% ± 0.0%,88.6% ± 6.4%,100.0% ± 0.0%,86.9% ± 0.4%,78.7% ± 0.6%,89.2% ± 0.1%,43.6% ± 7.7%,37.1% ± 17.6%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%,50.0% ± 0.0%
