In [None]:
import numpy as np
import pandas as pd
import joblib
import plotly.express as px
import plotly.graph_objs as go

from cdei_helpers.plot import group_box_plots, group_roc_curves
from cdei_helpers.fairness_measures import *

from aif360.datasets import StandardDataset
from aif360.algorithms.postprocessing.eq_odds_postprocessing import (
    EqOddsPostprocessing,
)

## Load data

In [None]:
# Draw fixed-size sub-samples of the pre-processed one-hot splits.
# A fixed seed is passed to every `.sample` call so that the same rows are
# selected on each Restart & Run All; without it every run reports different
# accuracy / fairness numbers.
SAMPLE_SEED = 42

train = pd.read_csv("/project/data/adult/processed/train-one-hot.csv").sample(
    2000, random_state=SAMPLE_SEED
)
test = pd.read_csv("/project/data/adult/processed/test-one-hot.csv").sample(
    2000, random_state=SAMPLE_SEED
)
# The validation split is the one the post-processing below is fitted and
# evaluated on, hence the larger sample.
val = pd.read_csv("/project/data/adult/processed/val-one-hot.csv").sample(
    6000, random_state=SAMPLE_SEED
)

In [None]:
def _as_standard_dataset(frame):
    """Wrap a one-hot encoded frame in an aif360 ``StandardDataset``.

    Every split uses the same schema: ``salary`` is the label with ``1`` as
    the favourable class, and ``sex`` is the protected attribute with ``1``
    as the privileged value.
    """
    return StandardDataset(
        frame,
        label_name="salary",
        favorable_classes=[1],
        protected_attribute_names=["sex"],
        privileged_classes=[[1]],
    )


train_sds, test_sds, val_sds = map(_as_standard_dataset, (train, test, val))

# Position of the protected attribute among the training feature columns.
index = train_sds.feature_names.index("sex")

In [None]:
# Group definitions in the list-of-dicts form passed to EqOddsPostprocessing:
# sex == 1.0 is the privileged group, sex == 0.0 the unprivileged one.
privileged_groups, unprivileged_groups = [dict(sex=1.0)], [dict(sex=0.0)]

## Load original model

In [None]:
# Restore the previously saved baseline model.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load artefacts from a trusted location.
model = joblib.load("/project/data/adult/baseline.pkl")

# Predict on the validation features, i.e. every column except the label.
val_features = val.drop(columns="salary")
val_pred_labels = model.predict(val_features)

In [None]:
# Build the "predicted" counterpart of the validation dataset: an independent
# deep copy of val_sds whose labels are replaced by the model's predictions,
# reshaped into a single column.
predicted_column = val_pred_labels.reshape(-1, 1)
val_sds_pred = val_sds.copy(deepcopy=True)
val_sds_pred.labels = predicted_column

## Perform intervention

In [None]:
# Learn parameters to equalize odds and apply to create a new dataset.
#
# The seed must be an explicit integer: `np.random.seed()` returns None (and
# reseeds NumPy's global RNG from OS entropy as a side effect), so passing its
# result would hand the post-processor no seed at all. A fixed value keeps
# any randomness used by the post-processor reproducible across runs.
EOPP_SEED = 42

eopp = EqOddsPostprocessing(
    privileged_groups=privileged_groups,
    unprivileged_groups=unprivileged_groups,
    seed=EOPP_SEED,
)
# Fit on (true validation dataset, model-predicted validation dataset), then
# transform the model's predictions on that same validation sample.
eopp = eopp.fit(val_sds, val_sds_pred)
val_sds_pred_tranf = eopp.predict(val_sds_pred)

In [None]:
# Expose the post-processed labels under `.scores` as well, because the
# analysis cells read `.scores`. This is a plain assignment, so both
# attributes refer to the same object -- no copy is made.
val_sds_pred_tranf.scores = val_sds_pred_tranf.labels

## Analyse accuracy and fairness

In [None]:
def _prediction_gap_between_sexes(true_label):
    """Absolute male/female gap in mean post-processed score.

    Only validation rows whose true ``salary`` equals ``true_label`` are
    considered; the mean score is taken separately for ``sex == 1`` and
    ``sex == 0`` and the absolute difference is returned.
    """
    has_label = val.salary == true_label
    male_mean = val_sds_pred_tranf.scores[has_label & (val.sex == 1)].mean()
    female_mean = val_sds_pred_tranf.scores[has_label & (val.sex == 0)].mean()
    return np.abs(male_mean - female_mean)


# Gap among true positives (salary == 1).
# NOTE(review): with 0/1 scores this is the between-group difference in
# TPR (equivalently FNR), not an FNR itself -- confirm the intended reading.
fnr = _prediction_gap_between_sexes(1)
# Gap among true negatives (salary == 0); likewise a between-group difference.
fpr = _prediction_gap_between_sexes(0)

In [None]:
# Report accuracy overall and per sex, followed by the two rate gaps.
flat_predictions = val_sds_pred_tranf.scores.flatten()

print("Accuracy =", accuracy(flat_predictions, val.salary))

# (printed label, value of the `sex` column) for each per-group report.
for report_label, sex_value in (("Female accuracy =", 0), ("Male accuracy =", 1)):
    in_group = val.sex == sex_value
    print(
        report_label,
        accuracy(flat_predictions[in_group], val.salary[in_group]),
    )

print("FNR =", fnr)
print("FPR =", fpr)

In [None]:
# Separation measure from cdei_helpers.fairness_measures, evaluated on the
# flattened post-processed scores, the sex column and the true salary labels.
equalised_odds = separation_p(
    val_sds_pred_tranf.scores.flatten(), val.sex, val.salary
)
print("Equalised odds = ", equalised_odds)

### Plots

In [None]:
# Call the cdei_helpers plotting helper with the true salary labels, the
# post-processed labels and the sex column; presumably it draws one ROC curve
# per sex -- verify against the helper.
# NOTE(review): the second argument is the hard post-processed labels (column
# shaped), not continuous scores -- confirm this is what the helper expects.
group_roc_curves(val.salary, val_sds_pred_tranf.labels, val.sex)

In [None]:
# Mean post-processed score for every (true salary label, sex) combination.
#
# One trace per sex (rather than one trace per bar) so that each sex gets a
# single legend entry and a single colour across both salary labels; an
# explicit layout gives the figure a title and axis labels so it can be read
# on its own.
salary_labels = [0, 1]

go.Figure(
    data=[
        go.Bar(
            x=salary_labels,
            y=[
                val_sds_pred_tranf.scores[
                    (val.sex == sex) & (val.salary == label)
                ].mean()
                for label in salary_labels
            ],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout=go.Layout(
        title=dict(text="Mean post-processed prediction by true salary label and sex"),
        barmode="group",
        xaxis=dict(title=dict(text="True salary label"), tickvals=salary_labels),
        yaxis=dict(title=dict(text="Mean post-processed prediction")),
    ),
)