In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from cdei_helpers.plot import group_box_plots, group_roc_curves
from cdei_helpers.fairness_measures import *

from aif360.datasets import StandardDataset
from aif360.algorithms.preprocessing.reweighing import Reweighing

## Load data

In [None]:
# Draw fixed-size random subsamples of each split.
# The seed is pinned so that a fresh "Restart & Run All" selects the same rows
# every time; without it, each run would train and evaluate on different data
# and the metrics printed below would not be reproducible.
SEED = 42

train = pd.read_csv("/project/data/adult/processed/train-one-hot.csv").sample(
    6000, random_state=SEED
)
test = pd.read_csv("/project/data/adult/processed/test-one-hot.csv").sample(
    2000, random_state=SEED
)
val = pd.read_csv("/project/data/adult/processed/val-one-hot.csv").sample(
    2000, random_state=SEED
)

In [None]:
def _to_standard_dataset(frame):
    """Wrap one data split in an aif360 ``StandardDataset``.

    Every split uses the same configuration: ``salary`` is the label with
    ``1`` as the favourable class, and ``sex`` is the protected attribute
    with ``1`` as the privileged class.
    """
    return StandardDataset(
        frame,
        label_name="salary",
        favorable_classes=[1],
        protected_attribute_names=["sex"],
        privileged_classes=[[1]],
    )


# Built in the same order as the splits were loaded: train, test, val.
train_sds, test_sds, val_sds = (
    _to_standard_dataset(split) for split in (train, test, val)
)

In [None]:
# Group selectors passed to the reweighing step: each is a list of
# {protected attribute -> value} dicts. sex == 1 is the privileged group
# (printed as "Male" in the analysis cells), sex == 0 the unprivileged one.
unprivileged_groups, privileged_groups = [{"sex": 0.0}], [{"sex": 1.0}]

## Train original model

In [None]:
# Baseline classifier: plain logistic regression fitted on the unweighted
# training features and labels taken from the aif360 dataset wrapper.
X_train, y_train = train_sds.features, train_sds.labels.flatten()

model = LogisticRegression(max_iter=10000)
model.fit(X_train, y_train)

In [None]:
# Score the validation split with the baseline model.
# The features come from `val_sds` rather than from the raw DataFrame so that
# the prediction input has the same form as the data the model was fitted on
# (`train_sds.features`), and so that this cell matches how the fair model is
# scored further down. Predicting on the DataFrame would rely on its column
# order happening to match and makes scikit-learn warn about feature names.
# Column 1 of `predict_proba` is the second entry of `model.classes_`
# (presumably the favourable label 1 -- labels are assumed to be binary 0/1).
val_scores = model.predict_proba(val_sds.features)[:, 1]

In [None]:
# Row masks for the two sex groups of the validation split.
female_rows = val.sex == 0
male_rows = val.sex == 1

# Baseline model: overall and per-group accuracy, then mean score per group.
# `accuracy` is not defined in this notebook; presumably it comes from the
# cdei_helpers.fairness_measures star import.
print("Original model accuracy =", accuracy(val_scores, val.salary))
print(
    "Female accuracy =",
    accuracy(val_scores[female_rows], val.salary[female_rows]),
)
print(
    "Male accuracy =",
    accuracy(val_scores[male_rows], val.salary[male_rows]),
)
print("Mean female score =", val_scores[female_rows].mean())
print("Mean male score =", val_scores[male_rows].mean())

## Perform intervention

### Fit the reweighing transformer on the original training data and transform it

In [None]:
# Fit the reweighing pre-processor on the training set, then apply it to that
# same set; the transformed dataset is used below for its instance weights.
RW = Reweighing(
    privileged_groups=privileged_groups,
    unprivileged_groups=unprivileged_groups,
).fit(train_sds)
train_sds_transf = RW.transform(train_sds)

### Train model with transformed training data

In [None]:
# Same model family as the baseline, but each training row is weighted by the
# instance weight attached to the reweighed training set.
X_train, y_train = (
    train_sds_transf.features,
    train_sds_transf.labels.flatten(),
)
train_weights = train_sds_transf.instance_weights

model_fair = LogisticRegression(max_iter=10000)
model_fair.fit(X_train, y_train, sample_weight=train_weights)

### Predict fairly on validation set
Note that the validation data itself is not transformed. Reweighing only changes the instance weights used when fitting the model, so the intervention takes effect at prediction time, through the model that was trained on the reweighted training data.

In [None]:
# Work on a deep copy so the original validation dataset stays untouched,
# then attach the fair model's scores to the copy as a single column.
val_sds_pred = val_sds.copy(deepcopy=True)
X_val, y_val = val_sds_pred.features, val_sds.labels

fair_proba = model_fair.predict_proba(X_val)
val_sds_pred.scores = fair_proba[:, 1].reshape(-1, 1)

## Analyse fairness and accuracy

In [None]:
# Flatten the (n, 1) score column once and reuse it for every metric.
fair_scores = val_sds_pred.scores.flatten()

# Row masks for the two sex groups of the validation split.
female_rows = val.sex == 0
male_rows = val.sex == 1

# Reweighed model: overall and per-group accuracy, then mean score per group.
print("Accuracy =", accuracy(fair_scores, val.salary))

print(
    "Female accuracy =",
    accuracy(fair_scores[female_rows], val.salary[female_rows]),
)
print(
    "Male accuracy =",
    accuracy(fair_scores[male_rows], val.salary[male_rows]),
)
print("Mean female score =", fair_scores[female_rows].mean())
print("Mean male score =", fair_scores[male_rows].mean())

### Plots

In [None]:
# Box plot of the reweighed model's validation scores, one box per sex group.
# The scores are flattened once up front instead of once per group, and the
# figure gets a title and axis labels so it can be read on its own.
box_scores = val_sds_pred.scores.flatten()

go.Figure(
    data=[
        go.Box(
            # Every point of a group shares the same x position (0 or 1).
            x=[sex] * (val.sex == sex).sum(),
            y=box_scores[val.sex == sex],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout=dict(
        title="Validation score distribution by sex (reweighed model)",
        xaxis=dict(
            title="Sex",
            # Show group names instead of the raw 0/1 encoding.
            tickvals=[0, 1],
            ticktext=["Female", "Male"],
        ),
        yaxis=dict(title="Model score"),
    ),
)

In [None]:
# Bar chart of the reweighed model's mean validation score per sex group.
# The scores are flattened once up front instead of once per group, and the
# figure gets a title and axis labels so it can be read on its own.
bar_scores = val_sds_pred.scores.flatten()

go.Figure(
    data=[
        go.Bar(
            x=[sex],
            y=[bar_scores[val.sex == sex].mean()],
            name="Male" if sex else "Female",
        )
        for sex in range(2)
    ],
    layout=dict(
        title="Mean validation score by sex (reweighed model)",
        xaxis=dict(
            title="Sex",
            # Show group names instead of the raw 0/1 encoding.
            tickvals=[0, 1],
            ticktext=["Female", "Male"],
        ),
        yaxis=dict(title="Mean model score"),
    ),
)