Permalink
Fetching contributors…
Cannot retrieve contributors at this time
128 lines (102 sloc) 4.9 KB
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
from aif360.algorithms import Transformer
from aif360.metrics import utils
class Reweighing(Transformer):
    """Reweighing is a preprocessing technique that weights the examples in each
    (group, label) combination differently to ensure fairness before
    classification [4]_.

    The four factors learned by :meth:`fit` are stored on the instance and
    multiplied into the instance weights by :meth:`transform`:

    * ``w_p_fav``: privileged group, favorable label
    * ``w_p_unfav``: privileged group, unfavorable label
    * ``w_up_fav``: unprivileged group, favorable label
    * ``w_up_unfav``: unprivileged group, unfavorable label

    All four start at ``1.`` so that transforming before fitting leaves the
    instance weights unchanged.

    References:
        .. [4] F. Kamiran and T. Calders,  "Data Preprocessing Techniques for
           Classification without Discrimination," Knowledge and Information
           Systems, 2012.
    """

    def __init__(self, unprivileged_groups, privileged_groups):
        """
        Args:
            unprivileged_groups (list(dict)): Representation for unprivileged
                group.
            privileged_groups (list(dict)): Representation for privileged group.
        """
        super(Reweighing, self).__init__(
            unprivileged_groups=unprivileged_groups,
            privileged_groups=privileged_groups)

        self.unprivileged_groups = unprivileged_groups
        self.privileged_groups = privileged_groups

        # Neutral factors until fit() is called.
        self.w_p_fav = 1.
        self.w_p_unfav = 1.
        self.w_up_fav = 1.
        self.w_up_unfav = 1.

    def fit(self, dataset):
        """Compute the weights for reweighing the dataset.

        Each factor is ``(n_label * n_group) / (n * n_cell)`` where every
        ``n_*`` is a sum of ``dataset.instance_weights`` over the matching
        instances. A (group, label) cell whose denominator is zero gets the
        neutral factor ``1.`` instead of ``inf``/``nan``.

        Args:
            dataset (BinaryLabelDataset): Dataset containing true labels.

        Returns:
            Reweighing: Returns self.
        """
        (priv_cond, unpriv_cond, fav_cond, unfav_cond,
         cond_p_fav, cond_p_unfav, cond_up_fav, cond_up_unfav) =\
            self._obtain_conditionings(dataset)

        instance_weights = dataset.instance_weights

        def weighted_count(mask):
            # Sum of instance weights selected by a boolean mask.
            return np.sum(instance_weights[mask], dtype=np.float64)

        n = np.sum(instance_weights, dtype=np.float64)
        n_p = weighted_count(priv_cond)
        n_up = weighted_count(unpriv_cond)
        n_fav = weighted_count(fav_cond)
        n_unfav = weighted_count(unfav_cond)

        n_p_fav = weighted_count(cond_p_fav)
        n_p_unfav = weighted_count(cond_p_unfav)
        n_up_fav = weighted_count(cond_up_fav)
        n_up_unfav = weighted_count(cond_up_unfav)

        # reweighing weights
        self.w_p_fav = self._reweighing_factor(n_fav, n_p, n, n_p_fav)
        self.w_p_unfav = self._reweighing_factor(n_unfav, n_p, n, n_p_unfav)
        self.w_up_fav = self._reweighing_factor(n_fav, n_up, n, n_up_fav)
        self.w_up_unfav = self._reweighing_factor(n_unfav, n_up, n,
                                                  n_up_unfav)

        return self

    def transform(self, dataset):
        """Transform the dataset to a new dataset based on the estimated
        transformation.

        The input is deep-copied; only the copy's ``instance_weights`` are
        scaled, each (group, label) cell by its corresponding factor.

        Args:
            dataset (BinaryLabelDataset): Dataset that needs to be transformed.

        Returns:
            dataset (BinaryLabelDataset): Transformed dataset.
        """
        dataset_transformed = dataset.copy(deepcopy=True)

        # Masks are computed from the input dataset and applied to the copy.
        (_, _, _, _, cond_p_fav, cond_p_unfav, cond_up_fav, cond_up_unfav) =\
            self._obtain_conditionings(dataset)

        # apply reweighing
        dataset_transformed.instance_weights[cond_p_fav] *= self.w_p_fav
        dataset_transformed.instance_weights[cond_p_unfav] *= self.w_p_unfav
        dataset_transformed.instance_weights[cond_up_fav] *= self.w_up_fav
        dataset_transformed.instance_weights[cond_up_unfav] *= self.w_up_unfav

        return dataset_transformed

    ##############################
    #### Supporting functions ####
    ##############################
    @staticmethod
    def _reweighing_factor(n_label, n_group, n_total, n_cell):
        """Return ``n_label * n_group / (n_total * n_cell)``, or ``1.`` when
        the denominator is zero.

        A zero denominator means the (group, label) cell (or the whole
        dataset) carries no weight. Dividing anyway would store ``inf`` or
        ``nan``, which `transform` would then multiply into the weights of any
        dataset that does have instances in that cell.
        """
        denominator = n_total * n_cell
        if denominator == 0:
            return 1.
        return n_label * n_group / denominator

    def _obtain_conditionings(self, dataset):
        """Obtain the necessary conditioning boolean vectors to compute
        instance level weights.

        Args:
            dataset (BinaryLabelDataset): Dataset providing the protected
                attributes, labels and favorable/unfavorable label values.

        Returns:
            tuple: ``(priv_cond, unpriv_cond, fav_cond, unfav_cond,
            cond_p_fav, cond_p_unfav, cond_up_fav, cond_up_unfav)``. The
            first four select by group or by label alone; the last four are
            the pairwise ``logical_and`` of a label mask with a group mask.
        """
        # conditioning
        priv_cond = utils.compute_boolean_conditioning_vector(
            dataset.protected_attributes,
            dataset.protected_attribute_names,
            condition=self.privileged_groups)
        unpriv_cond = utils.compute_boolean_conditioning_vector(
            dataset.protected_attributes,
            dataset.protected_attribute_names,
            condition=self.unprivileged_groups)

        # Labels are flattened before the element-wise comparison.
        fav_cond = dataset.labels.ravel() == dataset.favorable_label
        unfav_cond = dataset.labels.ravel() == dataset.unfavorable_label

        # combination of label and privileged/unpriv. groups
        cond_p_fav = np.logical_and(fav_cond, priv_cond)
        cond_p_unfav = np.logical_and(unfav_cond, priv_cond)
        cond_up_fav = np.logical_and(fav_cond, unpriv_cond)
        cond_up_unfav = np.logical_and(unfav_cond, unpriv_cond)

        return (priv_cond, unpriv_cond, fav_cond, unfav_cond,
                cond_p_fav, cond_p_unfav, cond_up_fav, cond_up_unfav)