In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

from sklearn.neural_network import MLPClassifier

from cdei_helpers.plot import group_box_plots, group_roc_curves
from cdei_helpers.fairness_measures import *

from aif360.datasets import StandardDataset
from aif360.algorithms.postprocessing.reject_option_classification import (
    RejectOptionClassification,
)

## Load data

In [None]:
# Fixed seed for the row subsampling below, so the train/test/val subsets (and
# every downstream number) are reproducible under Restart Kernel -> Run All.
SEED = 42

# Subsample a fixed number of rows from each pre-processed one-hot split.
train = pd.read_csv("/project/data/adult/processed/train-one-hot.csv").sample(
    6000, random_state=SEED
)
test = pd.read_csv("/project/data/adult/processed/test-one-hot.csv").sample(
    2000, random_state=SEED
)
val = pd.read_csv("/project/data/adult/processed/val-one-hot.csv").sample(
    2000, random_state=SEED
)

In [None]:
def to_standard_dataset(df):
    """Wrap a one-hot Adult DataFrame as an aif360 ``StandardDataset``.

    ``salary`` is the label (class 1 is favourable) and ``sex`` is the single
    protected attribute, with value 1 treated as the privileged class.

    Args:
        df: DataFrame containing a ``salary`` label column and a ``sex`` column.

    Returns:
        The corresponding ``StandardDataset``.
    """
    return StandardDataset(
        df,
        label_name="salary",
        favorable_classes=[1],
        protected_attribute_names=["sex"],
        privileged_classes=[[1]],
    )


# All three splits share the same label / protected-attribute configuration.
train_sds = to_standard_dataset(train)
test_sds = to_standard_dataset(test)
val_sds = to_standard_dataset(val)

In [None]:
# Group specs for aif360 metrics / post-processors: each is a list of
# {protected_attribute: value} dicts. `sex` is 1.0 for the privileged class
# (StandardDataset was built with privileged_classes=[[1]]) and 0.0 otherwise.
unprivileged_groups = [{"sex": 0.0}]
privileged_groups = [{"sex": 1.0}]

## Train original model

In [None]:
# Unconstrained baseline: an MLP with two hidden layers of 100 units.
# random_state fixes the stochastic parts of training (weight initialisation
# and the early-stopping hold-out split) so reruns give the same model.
# NOTE(review): the feature matrix still contains the protected `sex` column;
# confirm that is intended for this experiment.
model = MLPClassifier(
    hidden_layer_sizes=(100, 100), early_stopping=True, random_state=42
)
model.fit(train.drop("salary", axis=1), train.salary)

In [None]:
# Column 1 of predict_proba, i.e. the probability of model.classes_[1]
# (presumably salary == 1, the favourable label), for each validation row.
val_scores = model.predict_proba(val.drop(columns="salary"))[:, 1]

In [None]:
# Baseline (pre-intervention) performance on the validation set, overall and
# per sex (0 = female, 1 = male).
# NOTE(review): `accuracy` comes from the cdei_helpers star import and is given
# raw probabilities here; presumably it thresholds internally. Verify.
female_mask = val.sex == 0
male_mask = val.sex == 1

print("Original model accuracy =", accuracy(val_scores, val.salary))
for caption, mask in (
    ("Female accuracy =", female_mask),
    ("Male accuracy =", male_mask),
):
    print(caption, accuracy(val_scores[mask], val.salary[mask]))
for caption, mask in (
    ("Mean female score =", female_mask),
    ("Mean male score =", male_mask),
):
    print(caption, val_scores[mask].mean())

In [None]:
# Prediction copy of the validation dataset. The deep copy keeps the true
# labels in `val_sds` untouched; the model probabilities are attached as an
# (n, 1) scores column.
val_sds_pred = val_sds.copy(deepcopy=True)
val_sds_pred.scores = val_scores[:, np.newaxis]

## Perform intervention

### Find the best classification threshold (no fairness constraints)

In [None]:
from aif360.metrics import ClassificationMetric

# Threshold grid. It uses the same low/high/num values that are passed to
# RejectOptionClassification further down (0.01, 0.99, 100).
num_thresh = 100
ba_arr = np.zeros(num_thresh)
class_thresh_arr = np.linspace(0.01, 0.99, num_thresh)


def _apply_threshold(dataset, threshold):
    """Overwrite ``dataset.labels`` in place from ``dataset.scores > threshold``."""
    fav_inds = dataset.scores > threshold
    dataset.labels[fav_inds] = dataset.favorable_label
    dataset.labels[~fav_inds] = dataset.unfavorable_label


# Sweep thresholds and record the balanced accuracy, 0.5 * (TPR + TNR), of the
# thresholded predictions against the true validation labels.
for idx, class_thresh in enumerate(class_thresh_arr):
    _apply_threshold(val_sds_pred, class_thresh)
    classified_metric_orig_valid = ClassificationMetric(
        val_sds,
        val_sds_pred,
        unprivileged_groups=unprivileged_groups,
        privileged_groups=privileged_groups,
    )
    ba_arr[idx] = 0.5 * (
        classified_metric_orig_valid.true_positive_rate()
        + classified_metric_orig_valid.true_negative_rate()
    )

# np.argmax returns the first index of the maximum, as np.where(...)[0][0] did.
best_ind = int(np.argmax(ba_arr))
best_class_thresh = class_thresh_arr[best_ind]

print(
    "Best balanced accuracy (no fairness constraints) = %.4f" % np.max(ba_arr)
)
print(
    "Optimal classification threshold (no fairness constraints) = %.4f"
    % best_class_thresh
)

# The sweep leaves val_sds_pred labelled at the last grid threshold (0.99).
# Relabel at the conventional 0.5 threshold and report plain accuracy. Labels
# are (n, 1), so flatten them to match the 1-D salary column.
_apply_threshold(val_sds_pred, 0.5)
print(
    "Accuracy (threshold 0.5) =",
    accuracy(val_sds_pred.labels.flatten(), val.salary),
)

In [None]:
# Metric used (should be one of allowed_metrics)
# Fairness constraint for RejectOptionClassification. metric_name must be one
# of the metric names aif360 accepts (its `allowed_metrics`). The post-processor
# searches for parameters that bring this metric into [metric_lb, metric_ub].
metric_name = "Statistical parity difference"
metric_lb, metric_ub = -0.05, 0.05  # symmetric band around parity (0)

### Estimate optimal Reject Option Classification (ROC) parameters

In [None]:
# Fit aif360's RejectOptionClassification on the validation data. It is given
# a grid of classification thresholds (0.01..0.99, 100 points), 50 candidate
# ROC margins, and the fairness metric with its bounds. It is fitted on the
# true labels (val_sds) and the model's scores (val_sds_pred).
ROC = RejectOptionClassification(
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups,
    low_class_thresh=0.01,
    high_class_thresh=0.99,
    num_class_thresh=100,
    num_ROC_margin=50,
    metric_name=metric_name,
    metric_lb=metric_lb,
    metric_ub=metric_ub,
).fit(val_sds, val_sds_pred)

In [None]:
# Report the threshold and margin selected by the fitted ROC post-processor.
for template, value in (
    (
        "Optimal classification threshold (with fairness constraints) = %.4f",
        ROC.classification_threshold,
    ),
    ("Optimal ROC margin = %.4f", ROC.ROC_margin),
):
    print(template % value)

### Label validation-set predictions with the unconstrained best threshold

In [None]:
# Metrics for the test set
# Label the validation-set predictions (this is not the test set) using the
# unconstrained best threshold found by the balanced-accuracy sweep.
# NOTE(review): if ROC.predict re-derives labels from `scores`, these labels do
# not affect the transformed output. Verify whether this cell is needed.
fav_inds = val_sds_pred.scores > best_class_thresh
val_sds_pred.labels[fav_inds] = val_sds_pred.favorable_label
val_sds_pred.labels[~fav_inds] = val_sds_pred.unfavorable_label

## Apply intervention

In [None]:
# Transform the validation set
# Apply the fitted ROC post-processor to the validation predictions, then take
# a deep copy so later in-place edits cannot alias what ROC.predict returned.
val_sds_pred_transf = ROC.predict(val_sds_pred)
val_sds_pred_transf = val_sds_pred_transf.copy(deepcopy=True)

## Analyse fairness and accuracy

In [None]:
# Post-intervention performance on the validation set, overall and per sex.
# NOTE(review): the "score" lines average the ROC-adjusted labels (the
# favourable-prediction rate), whereas the baseline cell averaged predict_proba
# probabilities. The two are not directly comparable.
transf_labels = val_sds_pred_transf.labels.flatten()
female_mask = val.sex == 0
male_mask = val.sex == 1

print("Accuracy =", accuracy(transf_labels, val.salary))
for caption, mask in (
    ("Female accuracy =", female_mask),
    ("Male accuracy =", male_mask),
):
    print(caption, accuracy(transf_labels[mask], val.salary[mask]))
for caption, mask in (
    ("Mean female score =", female_mask),
    ("Mean male score =", male_mask),
):
    print(caption, transf_labels[mask].mean())

### Plots

In [None]:
# Mean of the ROC-adjusted validation labels for each sex (0 = female,
# 1 = male). This is the favourable-prediction rate, assuming the favourable
# label is 1 and the unfavourable label is 0. update_layout returns the
# figure, so the cell still displays it.
go.Figure(
    data=[
        go.Bar(
            x=[sex],
            y=[val_sds_pred_transf.labels.flatten()[val.sex == sex].mean()],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ]
).update_layout(
    title="Positive prediction rate by sex after Reject Option Classification",
    xaxis_title="sex",
    xaxis_tickvals=[0, 1],
    xaxis_ticktext=["Female", "Male"],
    yaxis_title="Mean predicted label",
)