In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

from sklearn.neural_network import MLPClassifier

from cdei_helpers.plot import group_box_plots, group_roc_curves
from cdei_helpers.fairness_measures import *

from aif360.datasets import StandardDataset
from aif360.algorithms.preprocessing import DisparateImpactRemover

## Load data

In [None]:
# Fixed seed so the three subsamples are identical on every Restart & Run All;
# an unseeded `.sample(n)` draws different rows (and hence different metrics)
# each time the notebook is executed.
SEED = 42

# Directory holding the one-hot encoded train / test / validation CSV files.
DATA_DIR = "/project/data/adult/processed"

# Work on random subsamples of each split: 2000 train, 2000 test and 6000
# validation rows (presumably to keep the later repair step fast — confirm).
train = pd.read_csv(DATA_DIR + "/train-one-hot.csv").sample(
    2000, random_state=SEED
)
test = pd.read_csv(DATA_DIR + "/test-one-hot.csv").sample(
    2000, random_state=SEED
)
val = pd.read_csv(DATA_DIR + "/val-one-hot.csv").sample(
    6000, random_state=SEED
)

In [None]:
# Wrap each split in a StandardDataset using one shared configuration:
# "salary" is the label (value 1 is the favourable outcome) and "sex" is the
# protected attribute (value 1 is the privileged group).
train_sds, test_sds, val_sds = (
    StandardDataset(
        split,
        label_name="salary",
        favorable_classes=[1],
        protected_attribute_names=["sex"],
        privileged_classes=[[1]],
    )
    for split in (train, test, val)
)

# Column position of "sex" within the training feature names; later cells use
# it to remove that column from the feature matrices.
index = train_sds.feature_names.index("sex")

## Perform intervention

In [None]:
# Pre-processing transformer applied to the train and validation datasets in
# the following cells via `fit_transform`.
# NOTE(review): repair_level=1.0 is presumably the strongest repair setting —
# confirm against the aif360 documentation.
di = DisparateImpactRemover(repair_level=1.0)

In [None]:
# Run the repair on the training dataset, then split the result into model
# inputs and targets.
train_repd = di.fit_transform(train_sds)
train_repd_X, train_repd_y = (
    # Feature matrix without the "sex" column (located at position `index`).
    np.delete(train_repd.features, index, axis=1),
    # Labels collapsed to a one-dimensional array.
    train_repd.labels.flatten(),
)

In [None]:
# Run the repair on the validation dataset, then split the result into model
# inputs and targets.
val_repd = di.fit_transform(val_sds)
val_repd_X, val_repd_y = (
    # Feature matrix without the column at position `index`.
    # NOTE(review): `index` was looked up on the training dataset; this assumes
    # the validation features share the same column order — confirm.
    np.delete(val_repd.features, index, axis=1),
    # Labels collapsed to a one-dimensional array.
    val_repd.labels.flatten(),
)

## Train model on fair data

In [None]:
# Neural network with two hidden layers of 100 units each, trained on the
# repaired features (which no longer contain the "sex" column).
# `random_state` is pinned so that repeated runs of the notebook fit the same
# model; without it every run yields different scores. Per the scikit-learn
# documentation the seed governs weight initialisation and the internal
# train/validation split used when `early_stopping=True`.
model = MLPClassifier(
    hidden_layer_sizes=(100, 100),
    early_stopping=True,
    random_state=42,
)
model.fit(train_repd_X, train_repd_y)

In [None]:
# Predict class probabilities for the repaired validation features and keep the
# second column (position 1) as each row's score.
# NOTE(review): column 1 is presumably the favourable class (salary == 1) —
# confirm via `model.classes_`.
val_proba = model.predict_proba(val_repd_X)
val_scores = val_proba[:, 1]

## Analyse unfairness and accuracy

In [None]:
# Boolean row masks for the two groups encoded in the validation "sex" column
# (0 is reported as female, 1 as male).
female_rows = val.sex == 0
male_rows = val.sex == 1

# Accuracy over all validation rows, then restricted to each group.
# `accuracy` is presumably supplied by the star import from
# cdei_helpers.fairness_measures — confirm.
print("Accuracy =", accuracy(val_scores, val.salary))
print(
    "Female accuracy =",
    accuracy(val_scores[female_rows], val.salary[female_rows]),
)
print(
    "Male accuracy =",
    accuracy(val_scores[male_rows], val.salary[male_rows]),
)

# Average score within each group.
print("Mean female score =", val_scores[female_rows].mean())
print("Mean male score =", val_scores[male_rows].mean())

### Plots

In [None]:
# Box plot of the validation scores, one box per value of `sex`.
# The layout supplies a title, axis labels and Female/Male tick labels so the
# figure can be read on its own instead of showing bare 0/1 codes.
go.Figure(
    data=[
        go.Box(
            # Same x value repeated for every row of the group, so all of the
            # group's scores fall into a single box.
            x=[sex] * (val.sex == sex).sum(),
            y=val_scores[val.sex == sex],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout=dict(
        title="Validation scores by sex",
        xaxis=dict(
            title="Sex",
            tickmode="array",
            tickvals=[0, 1],
            ticktext=["Female", "Male"],
        ),
        yaxis=dict(title="Model score"),
    ),
)

In [None]:
# Bar chart of the mean validation score, one bar per value of `sex`.
# The layout supplies a title, axis labels and Female/Male tick labels so the
# figure can be read on its own instead of showing bare 0/1 codes.
go.Figure(
    data=[
        go.Bar(
            x=[sex],
            y=[val_scores[val.sex == sex].mean()],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout=dict(
        title="Mean validation score by sex",
        xaxis=dict(
            title="Sex",
            tickmode="array",
            tickvals=[0, 1],
            ticktext=["Female", "Male"],
        ),
        yaxis=dict(title="Mean model score"),
    ),
)