In [None]:
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

In [None]:
import os
import argparse
from datetime import datetime
from pathlib import Path
import json
import glob

import numpy as np
import pandas as pd

from scipy.stats import wilcoxon
from statsmodels.stats.weightstats import ttost_paired

from dataloader import Dataset

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk', rc={"font.size":18,"axes.titlesize":18,"axes.labelsize":18,"xtick.labelsize":18})
sns.set_palette('colorblind')
import dataframe_image as dfi
from sklearn import tree
from matplotlib.gridspec import GridSpec
from sklearn.calibration import CalibrationDisplay

# Base location for stored results (the cells below read from `directories` instead).
STORAGE_PATH = "./results"

In [None]:
# Experiment configuration. The notebook parses an empty command line, so every
# option takes the default declared below; edit a default to change it.
parser = argparse.ArgumentParser()
parser.add_argument('--dir', default='alldata')  # output directory
parser.add_argument('--dataset_id', default=-1, type=int)  # 1 indexed
parser.add_argument('--model_id', default=-1, type=int)  # 1 indexed
parser.add_argument('--data_max_size', default=1000000000, type=int)  # max number of data points in each dataset
parser.add_argument('--exclude_group_feature', action='store_true')  # do not give group indicator as a feature
parser.add_argument('--to_categorical', action='store_true')  # use categorical features
parser.add_argument('--steps_alg3', default=10000, type=int)  # param for minimax algorithm of Abernethy et al. 2022
parser.add_argument('--steps_minimax_fair', default=10000, type=int)  # param for minimax algorithm of Diana et al. 2021
parser.add_argument('--error_minimax_fair', default='Log-Loss', type=str)  # Options 'MSE', '0/1 Loss', 'FP', 'FN', 'Log-Loss', 'FP-Log-Loss', 'FN-Log-Loss'
parser.add_argument('--warm_start_minimax_fair', action='store_true')  # start from a high weight on worst-off group

# Pass an explicit empty list: argparse converts its argument with list(), so a
# string such as '--dir x' would be split into single characters.
args = parser.parse_args([])
print(args)

## Tables: models to plot and helpers for tabulating worst-off group performance

In [None]:
# Models to plot. 'LogisticRegressionSGDdim16' is excluded from the list.
MODEL_TYPES_TO_PLOT = [
    'TabularMLPClassifier',
    'LogisticRegressionSGD',
    *[f'LogisticRegressionSGDdim{k}' for k in (2, 4, 8)],  # dim16 excluded
    *[f'DecisionTree{k}' for k in (2, 4, 8)],
    'RandomForest',
    'LinearSVC',
]

# Helper functions to read saved files,
# compute worst-case performance, and tabulate
def findworseoff(df):
    """Annotate one result group with worst-off-group summary statistics.

    Within the group, the worst-off value of a training regime is the minimum
    of ``value`` (lower metric values are treated as worse) over the rows with
    that ``trained_on`` label:

    - ``erm``: rows trained on ``'full'``
    - ``minmax``: rows trained on ``'minmax'``
    - ``group_optimal``: every other ``trained_on`` value

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of one group; must contain ``trained_on`` and ``value`` columns.

    Returns
    -------
    pandas.DataFrame
        A new frame equal to ``df`` plus the scalar summaries broadcast into the
        columns ``difference_minmax``, ``difference_minmax_percentage_full``,
        ``minmax``, ``erm``, ``group_optimal`` and ``difference_group_optimal``.
        ``df`` itself is not modified. A summary is NaN when the group has no
        rows of the corresponding regime.
    """
    trained_on = df['trained_on']
    values = df['value']
    erm = values[trained_on == 'full'].min()
    minmax = values[trained_on == 'minmax'].min()
    group_optimal = values[~trained_on.isin(['full', 'minmax'])].min()
    diff_mm = erm - minmax
    # assign() returns a copy: mutating the frame handed over by groupby.apply
    # is unsupported by pandas and can raise SettingWithCopyWarning.
    return df.assign(
        difference_minmax=diff_mm,
        # the 1e-5 offset avoids division by zero when erm == 0
        difference_minmax_percentage_full=diff_mm / (erm + 1e-5) * 100,
        minmax=minmax,
        erm=erm,
        group_optimal=group_optimal,
        difference_group_optimal=erm - group_optimal,
    )

def prepare_dataframe_worseoff(directories, example_dataset='adult_income_NY',
                               example_model_type='RandomForest'):
    """Load saved metrics files and summarise worst-off-group performance of ERM.

    Reads every ``*/metrics_did*_dataset.csv`` found one level below each
    directory and applies ``findworseoff`` per
    (dataset, group_type, metric, model_type, eval_data_type) group. It then
    keeps the rows of models trained on ``'full'``.

    Parameters
    ----------
    directories : iterable of str
        Results directories to search.
    example_dataset, example_model_type : str, optional
        Dataset and model type of the slice (train split, accuracy) printed as
        a sanity check of the raw results.

    Returns
    -------
    pandas.DataFrame
        The ``trained_on == 'full'`` rows with ``evaluated_on`` and ``value``
        dropped and duplicate rows removed. These rows carry the ``erm``,
        ``minmax``, ``group_optimal`` and ``difference_*`` columns added by
        ``findworseoff``.

    Raises
    ------
    FileNotFoundError
        If no metrics file is found under any of ``directories``.
    """
    directories = list(directories)  # iterated more than once below
    files = []
    for d in directories:
        files += glob.glob(os.path.join(d, "*/metrics_did*_dataset.csv"))

    print(f'total results files {len(files)}')
    print(files)
    if not files:
        # Fail clearly instead of with pd.concat's "No objects to concatenate".
        raise FileNotFoundError(
            f'no */metrics_did*_dataset.csv files found under {directories}')

    df = pd.concat([pd.read_csv(f) for f in files], ignore_index=True)

    # Rows without a group type are labelled 'custom'.
    df['group_type'] = df['group_type'].fillna(value='custom')

    # Print one dataset/model slice as a sanity check of the raw results
    df_one_dataset = df[(df['dataset'] == example_dataset)
                        & (df['model_type'] == example_model_type)
                        & (df['eval_data_type'] == 'train')
                        & (df['metric'].isin(['accuracy']))
    ]
    df_one_dataset_pivot = df_one_dataset.pivot_table(
        index=['dataset', 'model_type', 'group_type', 'eval_data_type', 'trained_on', 'metric'],
        columns=['evaluated_on'], values=['value'])
    print(df_one_dataset_pivot)
    print("\n\n")

    # group_keys=False keeps the original row index. With pandas >= 2.0 the
    # default (group_keys=True) prepends the group keys as index levels that
    # share names with existing columns. A later groupby on e.g. 'metric' is
    # then ambiguous, and pandas raises ValueError. pandas 1.x did not prepend
    # keys for this like-indexed apply either.
    # NOTE(review): the returned frame must keep the grouping columns (they are
    # used downstream), so do not switch to include_groups=False without
    # re-adding them.
    df_diff = df.groupby(['dataset', 'group_type', 'metric', 'model_type', 'eval_data_type'],
                         group_keys=False).apply(findworseoff)

    df_full = df_diff[df_diff['trained_on'] == 'full']

    # Lower values of metrics are worse. Dropping the per-evaluation columns and
    # de-duplicating collapses rows that differed only in those columns.
    df_worse_performing_group = df_full.drop(columns=['evaluated_on', 'value']).drop_duplicates()

    return df_worse_performing_group

## Comparing ERM with minimax-fair and group-optimal models

In [None]:
# Results directories to read; metrics CSVs are looked up one level below each.
directories = [
    '../results/allmodel',
]

# Gather every metrics file, preserving directory order.
files = [
    path
    for results_dir in directories
    for path in glob.glob(os.path.join(results_dir, "*/metrics_did*_dataset.csv"))
]
print(len(files))

# Stack all metrics tables into one frame with a fresh RangeIndex.
dfs = [pd.read_csv(path) for path in files]
df = pd.concat(dfs, ignore_index=True)

In [None]:
# ERM worst-off-group summary across all results directories listed above.
df_worseoff = prepare_dataframe_worseoff(directories=directories)
df_worseoff

## Worst-off group accuracy plots for Figure 2 (use the test split for Figure 3)

In [None]:
FIGURE_EVAL_DATA_TYPE = 'train'  # 'train' -> Figure 2; change to 'test' for Figure 3

# Accuracy rows of the selected evaluation split.
df_plot = df_worseoff.loc[
    df_worseoff['metric'].eq('accuracy')
    & df_worseoff['eval_data_type'].eq(FIGURE_EVAL_DATA_TYPE)
]
df_plot.sort_values(['dataset', 'group_type'])

In [None]:
# Display names for the headline models (whole-value mapping, used in the legend).
MODEL_DISPLAY_NAMES = {
    'TabularMLPClassifier': 'MLP',
    'LogisticRegressionSGD': 'Logistic Regression',
    'DecisionTree8': 'Decision Tree depth 8',
    'RandomForest': 'Random Forest',
    'LinearSVC': 'Linear SVC',
}

# Keep only the headline models and rename them. Building a new frame with
# .assign() avoids writing into a slice of df_plot (SettingWithCopyWarning).
# Series.replace(dict) maps whole values rather than substrings.
df_plot_ = (
    df_plot[df_plot['model_type'].isin(list(MODEL_DISPLAY_NAMES))]
    .assign(model_type=lambda d: d['model_type'].replace(MODEL_DISPLAY_NAMES))
)

# Top panel: ERM vs group-optimal; bottom panel: ERM vs minimax. Points above
# the dashed y = x line have a higher worst-off group accuracy than ERM.
fig, (ax1, ax2) = plt.subplots(2, figsize=(10, 18))
panels = [(ax1, 'group_optimal', 'GROUP-OPTIMAL'), (ax2, 'minmax', 'MINIMAX')]
for ax, y_col, y_name in panels:
    sns.scatterplot(data=df_plot_,
                    x='erm', y=y_col,
                    hue='model_type', style='model_type', s=250, alpha=0.7,
                    ax=ax)
    ax.set_xlabel('ERM,\n worst-off group accuracy', fontsize=22)
    ax.set_ylabel(f'{y_name},\n worst-off group accuracy', fontsize=22)
    # y = x reference over this panel's x-range; scalex/scaley=False keep the
    # line from changing the autoscaled limits.
    diag = ax.get_xlim()
    ax.plot(diag, diag, linestyle='--', color='k', lw=1, scalex=False, scaley=False)

# Both panels share hue/style, so one legend below the bottom panel suffices.
ax1.get_legend().remove()
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=3)
fig.tight_layout(pad=2.0)

# Name the file after the split actually plotted, so rerunning on the test
# split (Figure 3) does not overwrite accuracy_train.pdf.
split_tag = '_'.join(sorted(df_plot_['eval_data_type'].unique())) or 'empty'
fig.savefig(f'accuracy_{split_tag}.pdf', bbox_inches='tight')

## Equivalence and non-inferiority tests of ERM vs. minimax and group-optimal models (Tables 2-6)

In [None]:
# Margins for the paired TOST on the per-row difference erm - comparison.
lower_threshold, upper_threshold = -0.01, 0.01


def _paired_tost_pvalues(group, comparison):
    """Return (equivalence p-value, lower one-sided p-value) of ERM vs ``comparison``.

    Per the statsmodels docs, ``ttost_paired`` returns
    ``(pvalue, (t1, pv1, df1), (t2, pv2, df2))``. The first tuple is the
    lower-threshold test, whose p-value is reported here as non-inferiority.
    The test is run once and both p-values are taken from the same result.
    """
    pvalue_equivalence, lower_test, _ = ttost_paired(
        group["erm"], group[comparison], low=lower_threshold, upp=upper_threshold)
    return pvalue_equivalence, lower_test[1]


def _pvalue_row(group):
    """P-values of one (metric, model_type, eval_data_type) group."""
    eq_minmax, noninf_minmax = _paired_tost_pvalues(group, "minmax")
    eq_group_opt, noninf_group_opt = _paired_tost_pvalues(group, "group_optimal")
    return pd.Series({
        "pvalue_equivalence_minmax": eq_minmax,
        "pvalue_noninferior_minmax": noninf_minmax,
        "pvalue_equivalence_group_optimal": eq_group_opt,
        "pvalue_noninferior_group_optimal": noninf_group_opt,
    })


# Select only the value columns so apply() does not operate on the grouping columns.
df_pvalues = (
    df_worseoff
    .groupby(['metric', 'model_type', 'eval_data_type'])[['erm', 'minmax', 'group_optimal']]
    .apply(_pvalue_row)
    .reset_index()
)

# Print one LaTeX table per (metric, evaluation split) for the reported metrics.
TABLE_METRICS = ['neglogloss', 'accuracy']
TABLE_COLUMNS = ['model_type', 'pvalue_equivalence_group_optimal', 'pvalue_noninferior_group_optimal',
                 'pvalue_equivalence_minmax', 'pvalue_noninferior_minmax']
for metric in df_pvalues['metric'].unique():
    if metric not in TABLE_METRICS:
        continue
    for eval_data_type in df_pvalues['eval_data_type'].unique():
        print(metric, eval_data_type)
        tab = df_pvalues.loc[
            (df_pvalues['metric'] == metric) & (df_pvalues['eval_data_type'] == eval_data_type),
            TABLE_COLUMNS,
        ]
        print(tab.round(decimals=4).to_latex(index=False))