# ART Classifier

In [1]:
# import relevant dependencies
import numpy as np
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

from aif360.datasets import StandardDataset
from aif360.algorithms.preprocessing.lfr import LFR
from aif360.metrics import ClassificationMetric

In [2]:
def create_dataset(
    X: pd.DataFrame,
    y,
    protected_attribute_name: str
) -> StandardDataset:
    if isinstance(y, np.ndarray):
        y = pd.Series(y.flatten(), index=X.index, name='class')
    return StandardDataset(
        df=pd.concat([X, y], axis=1),
        label_name="class",
        favorable_classes=[1],
        protected_attribute_names=[protected_attribute_name],
        privileged_classes=[[1]],
    )


In [3]:
# fetch raw-data from sklearn.datasets
raw_data = fetch_openml(data_id=1590, as_frame=True)

In [4]:
from sklearn.preprocessing import MinMaxScaler

X_raw = pd.get_dummies(raw_data.data)
X_raw = pd.DataFrame(MinMaxScaler().fit_transform(X_raw), columns=X_raw.columns)
y = 1 * (raw_data.target == ">50K")

X_train, X_test, y_train, y_test = train_test_split(X_raw, y, test_size=0.5, random_state=42)
X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size=0.5, random_state=42)

In [5]:
protected_attribute_name = "sex_Male"

privileged_groups = [{protected_attribute_name: 1.0}]
unprivileged_groups = [{protected_attribute_name: 0.0}]

In [6]:
dataset_train = create_dataset(X_train, y_train, protected_attribute_name)
dataset_test = create_dataset(X_test, y_test, protected_attribute_name)
dataset_val = create_dataset(X_val, y_val, protected_attribute_name)

In [9]:
TR = LFR(unprivileged_groups=unprivileged_groups,
         privileged_groups=privileged_groups,
         k=10, Ax=0.1, Ay=1.0, Az=2.0,
         verbose=1
        )
TR = TR.fit(dataset_train, maxiter=5000, maxfun=5000)

step: 0, loss: 0.7463518403694693, L_x: 0.4790356986487554,  L_y: 0.6835459874199005,  L_z: 0.0074511415423466785
step: 250, loss: 0.7463518253740046, L_x: 0.47903571839609604,  L_y: 0.6835459693945398,  L_z: 0.007451142069927597
step: 500, loss: 0.7463518660846189, L_x: 0.4790357191853597,  L_y: 0.6835460120271091,  L_z: 0.007451141069486922
step: 750, loss: 0.7463518433856515, L_x: 0.4790357096628256,  L_y: 0.6835459891270563,  L_z: 0.007451141646156345
step: 1000, loss: 0.7463518324276022, L_x: 0.479035727000465,  L_y: 0.6835459773348115,  L_z: 0.00745114119637208
step: 1250, loss: 0.6615224785849658, L_x: 0.47974851503706006,  L_y: 0.5989872404984589,  L_z: 0.007280193291400505
step: 1500, loss: 0.6615225139330358, L_x: 0.479748526237971,  L_y: 0.5989872716658331,  L_z: 0.00728019482170286
step: 1750, loss: 0.6615225306769555, L_x: 0.47974853210418245,  L_y: 0.5989872858202235,  L_z: 0.007280195823156849
step: 2000, loss: 0.661522495593003, L_x: 0.47974852929761086,  L_y: 0.5989872

In [12]:
dataset_train_transf = TR.transform(dataset_train)
dataset_test_transf = TR.transform(dataset_test)

In [17]:
dataset_train_transf.convert_to_dataframe()[0]

Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass_Private,workclass_Self-emp-not-inc,workclass_Self-emp-inc,workclass_Federal-gov,...,native-country_Nicaragua,native-country_Scotland,native-country_Thailand,native-country_Yugoslavia,native-country_El-Salvador,native-country_Trinadad&Tobago,native-country_Peru,native-country_Hong,native-country_Holand-Netherlands,class
35110,0.588871,0.472571,0.498318,0.322660,0.515189,0.440614,0.602254,0.606695,0.440063,0.586791,...,0.389656,0.311638,0.439499,0.526011,0.419255,0.433437,0.510094,0.566111,0.448939,0.0
22406,0.592583,0.487855,0.498975,0.321447,0.516790,0.441821,0.607525,0.603487,0.420691,0.585368,...,0.386230,0.306513,0.428119,0.527973,0.425134,0.435370,0.530946,0.554109,0.444946,0.0
28007,0.578469,0.463661,0.512717,0.329355,0.507759,0.431395,0.608804,0.603814,0.452122,0.575963,...,0.401303,0.313804,0.443288,0.533835,0.425066,0.426386,0.505161,0.571653,0.448908,0.0
18893,0.593325,0.443734,0.510743,0.330452,0.497694,0.446240,0.580193,0.615678,0.467387,0.600949,...,0.392475,0.334155,0.442105,0.527855,0.417797,0.425466,0.492280,0.576532,0.468018,0.0
16324,0.614607,0.471106,0.522748,0.333858,0.510034,0.445810,0.598607,0.623753,0.453634,0.607022,...,0.371321,0.326516,0.436336,0.529775,0.407984,0.425780,0.506895,0.574242,0.467781,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11284,0.579486,0.469301,0.497168,0.316157,0.510224,0.439281,0.595268,0.616710,0.437781,0.596658,...,0.392773,0.314221,0.435989,0.524738,0.431707,0.442390,0.505861,0.558863,0.432475,0.0
44732,0.591793,0.464620,0.508310,0.328589,0.507615,0.441136,0.599061,0.615188,0.456989,0.588236,...,0.393999,0.317542,0.442868,0.524749,0.414248,0.432488,0.500653,0.574487,0.450666,0.0
38158,0.593812,0.484083,0.521117,0.339692,0.505689,0.438370,0.609475,0.613837,0.439416,0.593748,...,0.384861,0.313575,0.440852,0.527879,0.422461,0.420673,0.511186,0.578823,0.449563,0.0
860,0.593705,0.465676,0.518515,0.337605,0.489935,0.457668,0.576577,0.625663,0.448972,0.609695,...,0.393089,0.332959,0.436590,0.523387,0.425151,0.429008,0.508453,0.574684,0.451624,0.0


In [18]:
dataset_train.convert_to_dataframe()[0]

Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass_Private,workclass_Self-emp-not-inc,workclass_Self-emp-inc,workclass_Federal-gov,...,native-country_Nicaragua,native-country_Scotland,native-country_Thailand,native-country_Yugoslavia,native-country_El-Salvador,native-country_Trinadad&Tobago,native-country_Peru,native-country_Hong,native-country_Holand-Netherlands,class
35110,0.136986,0.118706,0.600000,0.0,0.000000,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
22406,0.315068,0.186520,0.800000,0.0,0.518365,0.479592,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
28007,0.315068,0.055337,0.733333,0.0,0.000000,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
18893,0.191781,0.070625,0.533333,0.0,0.000000,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
16324,0.438356,0.055715,0.533333,0.0,0.000000,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11284,0.260274,0.421908,0.533333,0.0,0.000000,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
44732,0.232877,0.064500,0.600000,0.0,0.000000,0.500000,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
38158,0.150685,0.160914,0.600000,0.0,0.000000,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
860,0.397260,0.057077,0.666667,0.0,0.000000,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
from sklearn.linear_model import LogisticRegression

LR = LogisticRegression(solver="liblinear", random_state=42)
LR.fit(X_train, y_train)
y_val_pred_unmit = LR.predict(X_val).reshape(-1,1)

In [120]:
metric_unmit = ClassificationMetric(
    dataset_test,
    dataset_test_pred_unmit,
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups
)

metric_mit = ClassificationMetric(
    dataset_test,
    dataset_test_pred_mit,
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups
)

In [121]:
metric_unmit.disparate_impact()

0.28604007717463587

In [122]:
metric_mit.disparate_impact()

0.2864563199261373

In [123]:
metric_unmit.accuracy()

0.846928746928747

In [124]:
metric_mit.accuracy()

0.8504504504504504