# Zemel et al. pre-processing fairness intervention

Zemel et al. (2013) propose a clustering method that transforms the original data set by expressing each point as a linear combination of learnt cluster centres (prototypes). The transformed data set stays as close as possible to the original while containing as little information as possible about the sensitive attribute. Because the representation carries little information about group membership, predictions built on it are pushed towards demographic parity.
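The mapping from points to prototypes can be sketched in a few lines of NumPy. This is a minimal illustration of the forward pass described by Zemel et al., not AIF360's implementation. It assumes plain squared Euclidean distances (the paper also allows learnt per-feature weights), and the prototypes `v`, the per-prototype label weights `w` and the function name `lfr_forward` are hypothetical.

```python
import numpy as np


def lfr_forward(x, v, w):
    """Hypothetical LFR forward pass returning (M, x_hat, y_hat).

    x: (n, d) data, v: (k, d) prototypes, w: (k,) prototype label weights in [0, 1].
    """
    # Squared Euclidean distance from every point to every prototype: shape (n, k)
    dist = ((x[:, None, :] - v[None, :, :]) ** 2).sum(axis=2)
    # Softmax over negative distances: M[n, k] is the probability that x_n maps to prototype k
    logits = -dist
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    m = np.exp(logits)
    m /= m.sum(axis=1, keepdims=True)
    # Reconstruction: each point becomes a convex combination of the prototypes
    x_hat = m @ v
    # Label prediction: the same weights applied to the per-prototype label weights
    y_hat = m @ w
    return m, x_hat, y_hat


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
v = rng.normal(size=(2, 3))
w = np.array([0.2, 0.9])
m, x_hat, y_hat = lfr_forward(x, v, w)
print(m.sum(axis=1))  # each row sums to 1
print((y_hat > 0.5).astype(int))  # thresholded "fair" labels
```

Since the label predictions reuse the prototype probabilities, any group information removed from `M` is also absent from the predicted labels.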

Besides a fair data representation, their method also outputs fair label predictions, so the results can be compared with other interventions using the usual fairness metrics. We use the implementation in IBM's AIF360 fairness toolbox.

In [None]:
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from aif360.algorithms.preprocessing.lfr import LFR
from aif360.datasets import StandardDataset
from helpers.fairness_measures import *
from helpers.finance import preprocess
from helpers.plot import group_box_plots, group_roc_curves

In [None]:
from helpers import export_plot

## Load data

In [None]:
artifacts_dir = Path("../../../artifacts")

In [None]:
# override artifacts_dir in source notebook
# this is stripped out for the hosted notebooks
artifacts_dir = Path("../../../../artifacts")

In [None]:
data_dir = artifacts_dir / "data" / "adult"
preprocess(data_dir)

In [None]:
train = pd.read_csv(data_dir / "processed" / "train-one-hot.csv")
val = pd.read_csv(data_dir / "processed" / "val-one-hot.csv")
test = pd.read_csv(data_dir / "processed" / "test-one-hot.csv")

## Set up fairness intervention
AIF360 requires the original data sets to be wrapped in its `StandardDataset` class.
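As far as we understand AIF360's behaviour, `StandardDataset` recodes the label so that values listed in `favorable_classes` become 1.0 and all others 0.0, and recodes each protected attribute the same way using `privileged_classes`. The pandas sketch below mimics that recoding; `binarize_like_sds` is a hypothetical helper, not part of AIF360. Since `salary` and `sex` are already 0/1 in this data, the recoding here mainly turns them into floats, which is why the groups are written as `{"sex": 1.0}`.

```python
import pandas as pd


def binarize_like_sds(df, label_name, favorable_classes, protected_name, privileged_classes):
    """Hypothetical helper mimicking StandardDataset's label/protected-attribute recoding."""
    out = df.copy()
    # Favourable label values -> 1.0, everything else -> 0.0
    out[label_name] = out[label_name].isin(favorable_classes).astype(float)
    # Privileged group values -> 1.0, everything else -> 0.0
    out[protected_name] = out[protected_name].isin(privileged_classes).astype(float)
    return out


toy = pd.DataFrame({"sex": [0, 1, 1, 0], "salary": [1, 0, 1, 0], "age": [25, 40, 33, 51]})
coded = binarize_like_sds(toy, "salary", [1], "sex", [1])
print(coded)
```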

In [None]:
def to_sds(df):
    """Wrap a one-hot encoded DataFrame in an AIF360 StandardDataset."""
    return StandardDataset(
        df,
        label_name="salary",
        favorable_classes=[1],
        protected_attribute_names=["sex"],
        privileged_classes=[[1]],
    )


train_sds = to_sds(train)
test_sds = to_sds(test)
val_sds = to_sds(val)
# Column position of the protected attribute in the feature matrix
index = train_sds.feature_names.index("sex")

In [None]:
privileged_groups = [{"sex": 1.0}]
unprivileged_groups = [{"sex": 0.0}]

In [None]:
val.head()

## Learn fair representation
The hyperparameters $A_x$, $A_y$, $A_z$ and $k$ were set to the optimum of a grid search on the Adult data set.
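For reference, in Zemel et al. (2013) the weights $A_x$, $A_y$, $A_z$ scale the three terms of the training objective and $k$ is the number of prototypes. Writing $M_{n,k}$ for the probability that point $x_n$ maps to prototype $v_k$, $\hat{x}_n = \sum_k M_{n,k} v_k$ for the reconstruction, $\hat{y}_n = \sum_k M_{n,k} w_k$ for the predicted label, and $X^{+}$, $X^{-}$ for the members and non-members of the protected group, the objective is

$$
L = A_z L_z + A_x L_x + A_y L_y,
$$

$$
L_z = \sum_{k} \left| M_k^{+} - M_k^{-} \right|, \qquad M_k^{\pm} = \frac{1}{|X^{\pm}|} \sum_{n \in X^{\pm}} M_{n,k},
$$

$$
L_x = \sum_n \lVert x_n - \hat{x}_n \rVert_2^2, \qquad
L_y = \sum_n \left[ -y_n \log \hat{y}_n - (1 - y_n) \log (1 - \hat{y}_n) \right].
$$

$L_z$ is the demographic-parity term, $L_x$ the reconstruction error and $L_y$ the prediction loss; implementations may normalise these sums differently (for example, using means). The commented-out cell below weights the parity term heavily ($A_z = 25$) relative to reconstruction ($A_x = 0.01$).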

NB: We altered the maximum number of iterations compared to …

In [None]:
# Load the pre-fitted LFR model instead of refitting it
TR = joblib.load(artifacts_dir / "models" / "finance" / "zemel-sex.pkl")

In [None]:
# To refit the model instead of loading it, uncomment the following:
# TR = LFR(
#     unprivileged_groups=unprivileged_groups,
#     privileged_groups=privileged_groups,
#     k=5,
#     Ax=0.01,
#     Ay=1.0,
#     Az=25.0,
# )
# TR = TR.fit(train_sds)

### Apply transformation to validation data

In [None]:
transf_val_sds = TR.transform(val_sds)

In [None]:
acc = accuracy(transf_val_sds.labels.flatten(), val_sds.labels.flatten())
print("Accuracy after fairness intervention =", acc)

## Evaluate fairness

In [None]:
# Fair labels
val_fair_labels = transf_val_sds.labels.flatten()

In [None]:
print("Accuracy =", accuracy(val_fair_labels, val.salary))
print(
    "Female accuracy =",
    accuracy(val_fair_labels[val.sex == 0], val.salary[val.sex == 0]),
)
print(
    "Male accuracy =",
    accuracy(val_fair_labels[val.sex == 1], val.salary[val.sex == 1]),
)
print("Mean female score =", val_fair_labels[val.sex == 0].mean())
print("Mean male score =", val_fair_labels[val.sex == 1].mean())

### Demographic parity

In [None]:
dp_d = disparate_impact_d(val_fair_labels, val.sex)
print("Sex demographic parity =", dp_d)
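Demographic parity asks that the rate of favourable predictions be the same in both groups. The exact definition behind `disparate_impact_d` lives in `helpers.fairness_measures` and isn't shown here. The sketch below assumes the common difference form, $P(\hat{y}=1 \mid s=1) - P(\hat{y}=1 \mid s=0)$, which is the gap between the two bars in the plot that follows; `demographic_parity_difference` is a hypothetical name, and a ratio of the two rates (disparate impact) is the other frequent choice.

```python
import numpy as np


def demographic_parity_difference(y_pred, group):
    """Positive-prediction rate of group 1 minus that of group 0 (0 means parity)."""
    y_pred = np.asarray(y_pred, dtype=float)
    group = np.asarray(group)
    return y_pred[group == 1].mean() - y_pred[group == 0].mean()


preds = np.array([1, 0, 1, 1, 0, 0])
sex = np.array([1, 1, 1, 0, 0, 0])
print(demographic_parity_difference(preds, sex))
```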

In [None]:
dp_bar = go.Figure(
    data=[
        go.Bar(
            x=[sex],
            y=[val_fair_labels[val.sex == sex].mean()],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout={"yaxis": {"range": [0, 1]}},
)
dp_bar

In [None]:
export_plot(dp_bar, "zemel-sex-dp.json")

### Equalised odds

In [None]:
# go.Figure(
#     data=[
#         go.Bar(
#             x=[label],
#             y=[
#                 val_fair_labels[
#                     (val.sex == sex) & (val.salary == label)
#                 ].mean()
#             ],
#             name="Male" if sex else "Female",
#         )
#         for label in range(2)
#         for sex in range(2)
#     ]
# )
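Equalised odds requires the true-positive and false-positive rates to match across groups. The commented-out plot above compares exactly these quantities: the mean fair label among individuals with true label 1 is the true-positive rate, and among those with true label 0 the false-positive rate. A minimal NumPy sketch of the two gaps, using the hypothetical name `equalised_odds_gaps`:

```python
import numpy as np


def equalised_odds_gaps(y_true, y_pred, group):
    """Return (TPR gap, FPR gap) of group 1 minus group 0; both 0 means equalised odds."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    group = np.asarray(group)

    def rate(g, label):
        # Share of positive predictions among members of group g with true label `label`
        mask = (group == g) & (y_true == label)
        return y_pred[mask].mean()

    tpr_gap = rate(1, 1) - rate(0, 1)
    fpr_gap = rate(1, 0) - rate(0, 0)
    return tpr_gap, fpr_gap


y_true = np.array([1, 1, 0, 0, 1, 1, 0, 0])
y_pred = np.array([1, 0, 1, 0, 1, 1, 0, 0])
sex = np.array([1, 1, 1, 1, 0, 0, 0, 0])
print(equalised_odds_gaps(y_true, y_pred, sex))
```

Unlike demographic parity, this conditions on the true label, so a classifier can satisfy one criterion while violating the other.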