In [2]:
%matplotlib inline
# Load all necessary packages
import sys
sys.path.insert(1, "../")  

import numpy as np
np.random.seed(0)
from tqdm import tqdm
from collections import OrderedDict

from aif360.datasets import CompasDataset
from aif360.metrics import BinaryLabelDatasetMetric
from aif360.metrics import ClassificationMetric
from aif360.algorithms.preprocessing import Reweighing

# Scalers
from sklearn.preprocessing import StandardScaler

# Classifiers
from sklearn.linear_model import LogisticRegression

from IPython.display import Markdown, display

# Utilities
from aif360.utils.general_utils import compute_metrics
from aif360.utils.classifier_metrics import ClassifierMetricUtils

import matplotlib.pyplot as plt

# Explainers
from aif360.explainers import MetricTextExplainer

In [4]:
dataset_orig = CompasDataset(
    protected_attribute_names=['sex'],
    privileged_classes=[['Female']],
    features_to_drop=['race', 'age']
)

dataset_orig_train, dataset_orig_val, dataset_orig_test = dataset_orig.split([0.5,0.8], shuffle=True)

privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

scale_orig = StandardScaler()
X_train = scale_orig.fit_transform(dataset_orig_train.features)
y_train = dataset_orig_train.labels.ravel()
w_train = dataset_orig_train.instance_weights.ravel()

lmod = LogisticRegression()
lmod.fit(X_train, y_train, 
         sample_weight=dataset_orig_train.instance_weights)
y_train_pred = lmod.predict(X_train)

# positive class index
pos_ind = np.where(lmod.classes_ == dataset_orig_train.favorable_label)[0][0]

dataset_orig_train_pred = dataset_orig_train.copy()
dataset_orig_train_pred.labels = y_train_pred

dataset_orig_valid_pred = dataset_orig_val.copy(deepcopy=True)
X_valid = scale_orig.transform(dataset_orig_valid_pred.features)
y_valid = dataset_orig_valid_pred.labels
dataset_orig_valid_pred.scores = lmod.predict_proba(X_valid)[:,pos_ind].reshape(-1,1)

dataset_orig_test_pred = dataset_orig_test.copy(deepcopy=True)
X_test = scale_orig.transform(dataset_orig_test_pred.features)
y_test = dataset_orig_test_pred.labels
dataset_orig_test_pred.scores = lmod.predict_proba(X_test)[:,pos_ind].reshape(-1,1)



In [5]:
num_thresh = 100
ba_arr = np.zeros(num_thresh)
class_thresh_arr = np.linspace(0.01, 0.99, num_thresh)
for idx, class_thresh in enumerate(class_thresh_arr):
    
    fav_inds = dataset_orig_valid_pred.scores > class_thresh
    dataset_orig_valid_pred.labels[fav_inds] = dataset_orig_valid_pred.favorable_label
    dataset_orig_valid_pred.labels[~fav_inds] = dataset_orig_valid_pred.unfavorable_label
    
    classified_metric_orig_valid = ClassificationMetric(dataset_orig_val,
                                             dataset_orig_valid_pred, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)
    print(classified_metric_orig_valid.binary_confusion_matrix())
    
    ba_arr[idx] = 0.5*(classified_metric_orig_valid.true_positive_rate()\
                       +classified_metric_orig_valid.true_negative_rate())

best_ind = np.where(ba_arr == np.max(ba_arr))[0][0]
best_class_thresh = class_thresh_arr[best_ind]

print("Best balanced accuracy (no reweighing) = %.4f" % np.max(ba_arr))
print("Optimal classification threshold (no reweighing) = %.4f" % best_class_thresh)

{'TP': 975.0, 'FP': 826.0, 'TN': 34.0, 'FN': 15.0}
{'TP': 974.0, 'FP': 822.0, 'TN': 38.0, 'FN': 16.0}
{'TP': 974.0, 'FP': 820.0, 'TN': 40.0, 'FN': 16.0}
{'TP': 974.0, 'FP': 817.0, 'TN': 43.0, 'FN': 16.0}
{'TP': 973.0, 'FP': 811.0, 'TN': 49.0, 'FN': 17.0}
{'TP': 972.0, 'FP': 807.0, 'TN': 53.0, 'FN': 18.0}
{'TP': 971.0, 'FP': 801.0, 'TN': 59.0, 'FN': 19.0}
{'TP': 971.0, 'FP': 796.0, 'TN': 64.0, 'FN': 19.0}
{'TP': 970.0, 'FP': 787.0, 'TN': 73.0, 'FN': 20.0}
{'TP': 969.0, 'FP': 779.0, 'TN': 81.0, 'FN': 21.0}
{'TP': 967.0, 'FP': 774.0, 'TN': 86.0, 'FN': 23.0}
{'TP': 965.0, 'FP': 769.0, 'TN': 91.0, 'FN': 25.0}
{'TP': 963.0, 'FP': 764.0, 'TN': 96.0, 'FN': 27.0}
{'TP': 962.0, 'FP': 757.0, 'TN': 103.0, 'FN': 28.0}
{'TP': 960.0, 'FP': 748.0, 'TN': 112.0, 'FN': 30.0}
{'TP': 958.0, 'FP': 747.0, 'TN': 113.0, 'FN': 32.0}
{'TP': 957.0, 'FP': 741.0, 'TN': 119.0, 'FN': 33.0}
{'TP': 955.0, 'FP': 733.0, 'TN': 127.0, 'FN': 35.0}
{'TP': 953.0, 'FP': 726.0, 'TN': 134.0, 'FN': 37.0}
{'TP': 952.0, 'FP': 721.0