# Zemel et al. pre-processing fairness intervention

Zemel et al. (2013) propose a clustering method that transforms the original data set by expressing each point as a linear combination of learnt cluster centres (prototypes). The transformed data set is kept as close as possible to the original while retaining as little information as possible about the sensitive attribute. In this way the method targets demographic parity.
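
To make the idea of "linear combinations of cluster centres" concrete, here is a minimal NumPy sketch of the mapping step only. It assumes plain squared Euclidean distance (the paper uses a learnt per-feature weighting) and uses random toy prototypes and label weights instead of learnt ones, so all names and values below are illustrative and not part of AIF360.

```python
import numpy as np


def soft_assignments(X, prototypes):
    """Softmax over negative squared distances from each point to each prototype."""
    # Pairwise squared Euclidean distances, shape (n_samples, n_prototypes)
    d2 = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=2)
    logits = -d2
    logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))           # toy data: 6 points, 3 features
prototypes = rng.normal(size=(2, 3))  # toy prototypes (learnt in the real method)
w = np.array([0.2, 0.9])              # toy per-prototype label probabilities (also learnt)

M = soft_assignments(X, prototypes)   # each row sums to 1
X_hat = M @ prototypes                # reconstruction: convex combination of prototypes
y_hat = M @ w                         # predicted probability of the favourable label
```

In the actual method the prototypes and label weights are optimised jointly so that `X_hat` stays close to `X`, `y_hat` stays close to the true labels, and the average assignment `M` is similar across the sensitive groups.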

Besides a fair data representation, the method also outputs fair label predictions, which lets us evaluate it with the usual fairness metrics. We use the implementation in IBM's AIF360 fairness toolkit.

In [None]:
import numpy as np
import pandas as pd
import joblib

from aif360.algorithms.preprocessing.lfr import LFR
from aif360.datasets import StandardDataset

from helpers.fairness_measures import *

import plotly.express as px
import plotly.graph_objs as go
from cdei_helpers.plot import group_box_plots, group_roc_curves

## Load data

In [None]:
train = pd.read_csv(
    "/project/data/synthetic/processed/train_processed.csv"
).sample(6000)
test = pd.read_csv(
    "/project/data/synthetic/processed/test_processed.csv"
).sample(2000)
val = pd.read_csv(
    "/project/data/synthetic/processed/val_processed.csv"
).sample(2000)

## Set up fairness intervention
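AIF360 requires the data sets to be wrapped in its "StandardDataset" class. Here the label is "employed_yes" (favourable class 1) and the protected attribute is "race_white" (privileged class 1).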

In [None]:
train_sds = StandardDataset(
    train,
    label_name="employed_yes",
    favorable_classes=[1],
    protected_attribute_names=["race_white"],
    privileged_classes=[[1]],
)
test_sds = StandardDataset(
    test,
    label_name="employed_yes",
    favorable_classes=[1],
    protected_attribute_names=["race_white"],
    privileged_classes=[[1]],
)
val_sds = StandardDataset(
    val,
    label_name="employed_yes",
    favorable_classes=[1],
    protected_attribute_names=["race_white"],
    privileged_classes=[[1]],
)
index = train_sds.feature_names.index("race_white")

In [None]:
privileged_groups = [{"race_white": 1.0}]
unprivileged_groups = [{"race_white": 0.0}]

## Learn fair representation
The hyperparameters $A_x$, $A_y$, $A_z$ and $k$ are set to the optimum of a grid search on the Adult data set.
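
For reference, in the formulation of Zemel et al. (2013) these weights trade off the three terms of the training objective, where $M_{n,k}$ is the probability that point $n$ is assigned to prototype $k$, and $M_k^{+}$ and $M_k^{-}$ are its averages over the two sensitive groups. The AIF360 implementation may differ in details such as normalisation, so treat this as a summary of the paper and not of the code:

$$
L = A_z L_z + A_x L_x + A_y L_y,
$$

$$
L_z = \sum_{k=1}^{K} \left| M_k^{+} - M_k^{-} \right|, \qquad
L_x = \sum_{n} \lVert x_n - \hat{x}_n \rVert^2, \qquad
L_y = \sum_{n} -y_n \log \hat{y}_n - (1 - y_n) \log (1 - \hat{y}_n).
$$

$L_z$ penalises differences in prototype assignment between the groups (demographic parity), $L_x$ penalises reconstruction error, and $L_y$ penalises label prediction error. $k$ is the number of prototypes $K$.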

In [None]:
TR = LFR(
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups,
    k=5,
    Ax=0.01,
    Ay=1.0,
    Az=1500.0,
)
# Optionally cap the optimiser, e.g. TR.fit(train_sds, maxiter=500, maxfun=500)
TR = TR.fit(train_sds)

### Apply transformation to validation data

In [None]:
transf_val_sds = TR.transform(val_sds)

In [None]:
acc = accuracy(transf_val_sds.labels.flatten(), val.employed_yes)
print("Accuracy after fairness intervention =", acc)

## Evaluate fairness and accuracy

In [None]:
# Fair labels
val_fair_labels = transf_val_sds.labels.flatten()

In [None]:
print("Accuracy =", accuracy(val_fair_labels, val.employed_yes))
print(
    "Black accuracy =",
    accuracy(
        val_fair_labels[val.race_white == 0],
        val.employed_yes[val.race_white == 0],
    ),
)
print(
    "White accuracy =",
    accuracy(
        val_fair_labels[val.race_white == 1],
        val.employed_yes[val.race_white == 1],
    ),
)
print(
    "Mean black score =", val_fair_labels[val.race_white == 0].mean(),
)
print(
    "Mean white score =", val_fair_labels[val.race_white == 1].mean(),
)

### Demographic parity

In [None]:
dp_d = disparate_impact_d(val_fair_labels, val.race_white)
print("Race demographic parity =", dp_d)

In [None]:
go.Figure(
    data=[
        go.Bar(
            x=[race],
            y=[val_fair_labels[val.race_white == race].mean()],
            name="White" if race else "Black",
        )
        for race in range(2)
    ]
)
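
The helper `disparate_impact_d` used above is not shown in this notebook, so its exact definition and sign convention are unknown here. As a rough illustration of what a demographic parity gap measures, the following sketch computes the difference in positive-prediction rates between the two groups; the function name and the toy arrays are hypothetical.

```python
import numpy as np


def demographic_parity_difference(predictions, group):
    """Positive-prediction rate of group 1 minus that of group 0 (hypothetical helper)."""
    predictions = np.asarray(predictions, dtype=float)
    group = np.asarray(group)
    return predictions[group == 1].mean() - predictions[group == 0].mean()


# Toy example: 4 individuals per group
preds = np.array([1, 0, 1, 1, 0, 0, 1, 0])
race_white = np.array([1, 1, 1, 1, 0, 0, 0, 0])

gap = demographic_parity_difference(preds, race_white)
print("Demographic parity difference =", gap)  # 0.75 - 0.25 = 0.5
```

A value of 0 means both groups receive the favourable prediction at the same rate; the bar chart above shows the two rates whose difference this computes.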

### Equalised Odds

In [None]:
# eo_d = equalised_odds_d(val_fair_labels, val.race_white, val.employed_yes)
# print("Race equalised odds =", eo_d)

In [None]:
# go.Figure(
#     data=[
#         go.Bar(
#             x=[label],
#             y=[
#                 val_fair_labels[
#                     (val.race_white == race) & (val.employed_yes == label)
#                 ].mean()
#             ],
#             name="White" if race else "Black",
#         )
#         for label in range(2)
#         for race in range(2)
#     ]
# )
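
Since the equalised odds cells are commented out and `equalised_odds_d` is not shown, here is a minimal stand-alone sketch of the quantity involved: the gap between groups in the positive-prediction rate, computed separately for true label 0 (false positive rate) and true label 1 (true positive rate). The function name and toy arrays are hypothetical and need not match the helper's definition.

```python
import numpy as np


def equalised_odds_differences(predictions, group, labels):
    """Group 1 minus group 0 gap in FPR and TPR (hypothetical helper)."""
    predictions = np.asarray(predictions, dtype=float)
    group = np.asarray(group)
    labels = np.asarray(labels)
    gaps = {}
    for name, label in (("fpr_gap", 0), ("tpr_gap", 1)):
        # Mean prediction among individuals with this true label, per group
        rate_1 = predictions[(group == 1) & (labels == label)].mean()
        rate_0 = predictions[(group == 0) & (labels == label)].mean()
        gaps[name] = rate_1 - rate_0
    return gaps


# Toy example: 4 individuals per group, 2 of each true label
preds = np.array([1, 1, 0, 1, 1, 0, 0, 0])
labels = np.array([1, 1, 0, 0, 1, 1, 0, 0])
group = np.array([1, 1, 1, 1, 0, 0, 0, 0])

gaps = equalised_odds_differences(preds, group, labels)
print(gaps)
```

Equalised odds holds when both gaps are 0. This sketch assumes every group and label combination is non-empty; otherwise the corresponding mean is undefined.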