In [2]:
import sklearn
# import shap

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

# Import base classifiers
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
from sklearn.neural_network import MLPClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from baselines import AdaFairClassifier
from imbens.ensemble import SMOTEBoostClassifier, SMOTEBaggingClassifier, RUSBoostClassifier, UnderBaggingClassifier, SelfPacedEnsembleClassifier
from fairlearn.postprocessing import ThresholdOptimizer
from fairens import FairAugEnsemble, FairEnsemble

# Import utilities
from data import FairDataset    # This is a custom class that we will use to load the datasets
from eval import evaluate_multi_split, verbose_print
from trainer import Trainer
from utils import seed_generator, dict_info, describe_data

In [4]:
def load_adult(
    path_prefix='./raw_data/adult/IBM_adult',
    response='Income',
    sensitive='Gender',
    val=True,
    return_df=False,
):
    """Assemble the Adult data from three tab-separated files into one frame.

    The files ``<path_prefix>_X.txt``, ``<path_prefix>_A.txt`` and
    ``<path_prefix>_Y.txt`` are each read with their first column as the
    index. The ``_X`` file is read with its header row; the ``_A`` and
    ``_Y`` files are read without one (``header=None``) and stored as the
    ``'gender'`` and ``'label'`` columns respectively.

    Parameters
    ----------
    path_prefix : str
        Common prefix of the three input files.
    response, sensitive, val, return_df
        Accepted for call compatibility; the body does not read them.

    Returns
    -------
    pandas.DataFrame
        The ``_X`` table with the ``'gender'`` and ``'label'`` columns
        appended, in that order.
    """

    def _read_table(suffix, **extra):
        # All three files share the separator and index-column layout.
        return pd.read_csv(path_prefix + suffix, sep='\t', index_col=0, **extra)

    table = _read_table('_X.txt')

    # Header-less side files, attached in the same order as they are read.
    # 'gender': Male = 1
    for column, suffix in (('gender', '_A.txt'), ('label', '_Y.txt')):
        table[column] = _read_table(suffix, header=None)

    return table

# Load the raw Adult tables with the default path prefix, print the frame's
# (rows, columns) shape, and display per-column summary statistics.
df = load_adult()
print(df.shape)
df.describe()

(45222, 99)


Unnamed: 0,capital-gain,race,age,education-num,capital-loss,hours-per-week,workclass=Federal-gov,workclass=Local-gov,workclass=Private,workclass=Self-emp-inc,workclass=Self-emp-not-inc,workclass=State-gov,workclass=Without-pay,education=10th,education=11th,education=12th,education=1st-4th,education=5th-6th,education=7th-8th,education=9th,education=Assoc-acdm,education=Assoc-voc,education=Bachelors,education=Doctorate,education=HS-grad,education=Masters,education=Preschool,education=Prof-school,education=Some-college,marital-status=Divorced,marital-status=Married-AF-spouse,marital-status=Married-civ-spouse,marital-status=Married-spouse-absent,marital-status=Never-married,marital-status=Separated,marital-status=Widowed,occupation=Adm-clerical,occupation=Armed-Forces,occupation=Craft-repair,occupation=Exec-managerial,...,native-country=Columbia,native-country=Cuba,native-country=Dominican-Republic,native-country=Ecuador,native-country=El-Salvador,native-country=England,native-country=France,native-country=Germany,native-country=Greece,native-country=Guatemala,native-country=Haiti,native-country=Holand-Netherlands,native-country=Honduras,native-country=Hong,native-country=Hungary,native-country=India,native-country=Iran,native-country=Ireland,native-country=Italy,native-country=Jamaica,native-country=Japan,native-country=Laos,native-country=Mexico,native-country=Nicaragua,native-country=Outlying-US(Guam-USVI-etc),native-country=Peru,native-country=Philippines,native-country=Poland,native-country=Portugal,native-country=Puerto-Rico,native-country=Scotland,native-country=South,native-country=Taiwan,native-country=Thailand,native-country=Trinadad&Tobago,native-country=United-States,native-country=Vietnam,native-country=Yugoslavia,gender,label
count,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,...,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0
mean,1101.430344,0.860267,38.547941,10.11846,88.595418,40.938017,0.031091,0.068551,0.736522,0.036398,0.083941,0.043032,0.000464,0.027044,0.035801,0.012759,0.004909,0.009929,0.018199,0.014948,0.033324,0.04332,0.167396,0.01203,0.326898,0.055592,0.001592,0.017359,0.218898,0.139246,0.000708,0.465592,0.012206,0.322807,0.031202,0.028238,0.122507,0.00031,0.133121,0.132325,...,0.001813,0.002941,0.002145,0.000951,0.003251,0.002631,0.000796,0.004268,0.001084,0.001902,0.001526,2.2e-05,0.00042,0.000619,0.000398,0.003251,0.001238,0.000796,0.002211,0.002278,0.001968,0.000464,0.019968,0.001061,0.000486,0.000995,0.006258,0.001791,0.001371,0.00387,0.000442,0.002233,0.001216,0.000641,0.000575,0.913095,0.001835,0.000509,0.675048,0.247844
std,7506.430084,0.346714,13.21787,2.552881,404.956092,12.007508,0.173566,0.252691,0.440524,0.187281,0.277303,0.202932,0.021545,0.162214,0.185796,0.112235,0.069894,0.099149,0.133672,0.121348,0.179484,0.203578,0.373334,0.109019,0.469085,0.229135,0.03987,0.130606,0.413504,0.346207,0.026592,0.49882,0.109808,0.467555,0.173864,0.165655,0.327874,0.017592,0.339709,0.338847,...,0.042544,0.054152,0.046265,0.030822,0.056922,0.051231,0.028204,0.06519,0.0329,0.043568,0.039032,0.004702,0.020493,0.024876,0.019947,0.056922,0.035169,0.028204,0.046973,0.047671,0.04432,0.021545,0.139892,0.032563,0.022051,0.03153,0.078861,0.042285,0.037002,0.062088,0.021026,0.047207,0.034854,0.025316,0.023971,0.281698,0.042803,0.022547,0.468362,0.431766
min,0.0,0.0,17.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,1.0,28.0,9.0,0.0,40.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
50%,0.0,1.0,37.0,10.0,0.0,40.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
75%,0.0,1.0,47.0,13.0,0.0,45.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
max,99999.0,1.0,90.0,16.0,4356.0,99.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [6]:
# Read the processed Adult CSV and display its summary statistics
# (presumably for comparison with the raw-data summary above the cell — confirm).
df_processed = pd.read_csv('./data/adult.csv')
df_processed.describe()

Unnamed: 0,capital-gain,race,age,education-num,capital-loss,hours-per-week,workclass=Federal-gov,workclass=Local-gov,workclass=Private,workclass=Self-emp-inc,workclass=Self-emp-not-inc,workclass=State-gov,workclass=Without-pay,education=10th,education=11th,education=12th,education=1st-4th,education=5th-6th,education=7th-8th,education=9th,education=Assoc-acdm,education=Assoc-voc,education=Bachelors,education=Doctorate,education=HS-grad,education=Masters,education=Preschool,education=Prof-school,education=Some-college,marital-status=Divorced,marital-status=Married-AF-spouse,marital-status=Married-civ-spouse,marital-status=Married-spouse-absent,marital-status=Never-married,marital-status=Separated,marital-status=Widowed,occupation=Adm-clerical,occupation=Armed-Forces,occupation=Craft-repair,occupation=Exec-managerial,...,native-country=Columbia,native-country=Cuba,native-country=Dominican-Republic,native-country=Ecuador,native-country=El-Salvador,native-country=England,native-country=France,native-country=Germany,native-country=Greece,native-country=Guatemala,native-country=Haiti,native-country=Holand-Netherlands,native-country=Honduras,native-country=Hong,native-country=Hungary,native-country=India,native-country=Iran,native-country=Ireland,native-country=Italy,native-country=Jamaica,native-country=Japan,native-country=Laos,native-country=Mexico,native-country=Nicaragua,native-country=Outlying-US(Guam-USVI-etc),native-country=Peru,native-country=Philippines,native-country=Poland,native-country=Portugal,native-country=Puerto-Rico,native-country=Scotland,native-country=South,native-country=Taiwan,native-country=Thailand,native-country=Trinadad&Tobago,native-country=United-States,native-country=Vietnam,native-country=Yugoslavia,gender,label
count,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,...,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0,45222.0
mean,1101.430344,0.860267,38.547941,10.11846,88.595418,40.938017,0.031091,0.068551,0.736522,0.036398,0.083941,0.043032,0.000464,0.027044,0.035801,0.012759,0.004909,0.009929,0.018199,0.014948,0.033324,0.04332,0.167396,0.01203,0.326898,0.055592,0.001592,0.017359,0.218898,0.139246,0.000708,0.465592,0.012206,0.322807,0.031202,0.028238,0.122507,0.00031,0.133121,0.132325,...,0.001813,0.002941,0.002145,0.000951,0.003251,0.002631,0.000796,0.004268,0.001084,0.001902,0.001526,2.2e-05,0.00042,0.000619,0.000398,0.003251,0.001238,0.000796,0.002211,0.002278,0.001968,0.000464,0.019968,0.001061,0.000486,0.000995,0.006258,0.001791,0.001371,0.00387,0.000442,0.002233,0.001216,0.000641,0.000575,0.913095,0.001835,0.000509,0.675048,0.247844
std,7506.430084,0.346714,13.21787,2.552881,404.956092,12.007508,0.173566,0.252691,0.440524,0.187281,0.277303,0.202932,0.021545,0.162214,0.185796,0.112235,0.069894,0.099149,0.133672,0.121348,0.179484,0.203578,0.373334,0.109019,0.469085,0.229135,0.03987,0.130606,0.413504,0.346207,0.026592,0.49882,0.109808,0.467555,0.173864,0.165655,0.327874,0.017592,0.339709,0.338847,...,0.042544,0.054152,0.046265,0.030822,0.056922,0.051231,0.028204,0.06519,0.0329,0.043568,0.039032,0.004702,0.020493,0.024876,0.019947,0.056922,0.035169,0.028204,0.046973,0.047671,0.04432,0.021545,0.139892,0.032563,0.022051,0.03153,0.078861,0.042285,0.037002,0.062088,0.021026,0.047207,0.034854,0.025316,0.023971,0.281698,0.042803,0.022547,0.468362,0.431766
min,0.0,0.0,17.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,1.0,28.0,9.0,0.0,40.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
50%,0.0,1.0,37.0,10.0,0.0,40.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
75%,0.0,1.0,47.0,13.0,0.0,45.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
max,99999.0,1.0,90.0,16.0,4356.0,99.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [3]:
"""Load Datasets"""

# Keyword arguments forwarded unchanged to every FairDataset constructed below.
dataset_kwargs = dict(
    y_col='label',
    train_size=0.6,
    val_size=0.2,
    test_size=0.2,
    concat_train_val=True,
    normalize=True,
    random_state=42,
)

# Dataset name -> sensitive-attribute columns to build a FairDataset for.
all_datasets = {
    'compas': ['sex', 'race'],
    'adult': ['gender', 'race'],
    'bank': ['age', 'marital=married'],
    # 'lsa_unfair_gender_race': ['gender', 'race'],
}

# Create a dictionary of datasets: dataset_zoo
#   key:   dataset name (taken from FairDataset.fullname)
#   value: FairDataset object
dataset_zoo = {}
for dataname, s_attrs in all_datasets.items():
    for s_attr in s_attrs:
        # One FairDataset per (dataset, sensitive attribute) pair.
        dataset = FairDataset(
            dataname=dataname,
            csv_path=f'./data/{dataname}.csv',
            s_col=s_attr,
            **dataset_kwargs,
        )
        dataset_zoo[dataset.fullname] = dataset

        # dataset.describe()
        dataset.brief()

# Print the information of the datasets and models
print(f"////// Dataset ZOO //////\n{dict_info(dataset_zoo)}\n")

# Only the two COMPAS entries are kept in the subset.
dataset_zoo_subset = {
    zoo_key: dataset_zoo[zoo_key] for zoo_key in ('compas_sex', 'compas_race')
}

      sex  MarriageStatus       age  race  juv_fel_count  juv_misd_count  \
5485    1        0.166667  0.138462   1.0            0.0        0.000000   
502     1        0.000000  0.107692   1.0            0.0        0.000000   
4233    1        0.000000  0.030769   0.0            0.0        0.230769   
5795    1        0.000000  0.446154   1.0            0.0        0.000000   
1892    1        0.000000  0.169231   0.0            0.0        0.000000   
...   ...             ...       ...   ...            ...             ...   
872     0        1.000000  0.707692   0.0            0.0        0.000000   
5495    1        0.000000  0.092308   1.0            0.0        0.000000   
511     1        1.000000  0.569231   1.0            0.0        0.000000   
2674    1        0.000000  0.123077   0.0            0.0        0.000000   
4763    1        0.000000  0.476923   1.0            0.0        0.000000   

      juv_other_count  priors_count  days_b_screening_arrest  \
5485              0.0  

       gender  capital-gain  race       age  education-num  capital-loss  \
19818     1.0      0.000000   1.0  0.205479       0.800000      0.000000   
30549     1.0      0.000000   1.0  0.109589       0.600000      0.000000   
40845     1.0      0.064181   1.0  0.698630       0.933333      0.000000   
26919     0.0      0.000000   1.0  0.054795       0.533333      0.000000   
1560      1.0      0.000000   1.0  0.369863       0.533333      0.000000   
...       ...           ...   ...       ...            ...           ...   
29848     1.0      0.000000   1.0  0.273973       0.533333      0.436639   
42494     0.0      0.000000   1.0  0.178082       0.866667      0.000000   
25963     1.0      0.000000   1.0  0.342466       0.066667      0.000000   
17805     0.0      0.000000   0.0  0.397260       0.600000      0.000000   
12170     0.0      0.000000   1.0  0.205479       0.533333      0.000000   

       hours-per-week  workclass=Federal-gov  workclass=Local-gov  \
19818        0.397