More sklearn-compatible algorithms (Trusted-AI#318)
* new ROC and LFR implementations

* allow prot_attr/target input to be Series

* LFR uses pytorch to calculate grad

* change defaults in metaestimators

* clean up tests

* infer proba behavior for PostProcessingMeta from estimator 'requires_proba' tag

* fix prefit, score PostProcessingMeta bugs

* propagate classes_ in metaestimators

Signed-off-by: Samuel Hoffman <hoffman.sc@gmail.com>
hoffmansc committed Jul 1, 2022
1 parent 951d462 commit db843a1
Showing 38 changed files with 3,389 additions and 201 deletions.
14 changes: 7 additions & 7 deletions aif360/algorithms/postprocessing/reject_option_classification.py
@@ -216,24 +216,24 @@ def fit_predict(self, dataset_true, dataset_pred):
         return self.fit(dataset_true, dataset_pred).predict(dataset_pred)

 # Function to obtain the pareto frontier
-def _get_pareto_frontier(costs, return_mask = True): # <- Fastest for many points
+def _get_pareto_frontier(scores, return_mask = True): # <- Fastest for many points
     """
-    :param costs: An (n_points, n_costs) array
+    :param scores: An (n_points, n_scores) array
     :param return_mask: True to return a mask, False to return integer indices of efficient points.
     :return: An array of indices of pareto-efficient points.
         If return_mask is True, this will be an (n_points, ) boolean array
         Otherwise it will be a (n_efficient_points, ) integer array of indices.
     adapted from: https://stackoverflow.com/questions/32791911/fast-calculation-of-pareto-front-in-python
     """
-    is_efficient = np.arange(costs.shape[0])
-    n_points = costs.shape[0]
+    is_efficient = np.arange(scores.shape[0])
+    n_points = scores.shape[0]
     next_point_index = 0  # Next index in the is_efficient array to search for

-    while next_point_index<len(costs):
-        nondominated_point_mask = np.any(costs<=costs[next_point_index], axis=1)
+    while next_point_index<len(scores):
+        nondominated_point_mask = np.any(scores>=scores[next_point_index], axis=1)
         is_efficient = is_efficient[nondominated_point_mask]  # Remove dominated points
-        costs = costs[nondominated_point_mask]
+        scores = scores[nondominated_point_mask]
         next_point_index = np.sum(nondominated_point_mask[:next_point_index])+1

     if return_mask:
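The comparison here flips from cost minimization (`costs<=costs[i]`) to score maximization (`scores>=scores[i]`): a point now survives only if no other point beats it on every objective. A minimal, self-contained sketch of the maximization variant on made-up data (`_get_pareto_frontier` is a private module helper, so the standalone name below is illustrative):

```python
import numpy as np

def pareto_frontier_mask(scores):
    """Boolean mask of Pareto-efficient rows, maximizing every column."""
    is_efficient = np.arange(scores.shape[0])
    n_points = scores.shape[0]
    next_point_index = 0
    while next_point_index < len(scores):
        # Keep points that match or beat the current point in >= 1 objective.
        nondominated = np.any(scores >= scores[next_point_index], axis=1)
        is_efficient = is_efficient[nondominated]
        scores = scores[nondominated]
        next_point_index = np.sum(nondominated[:next_point_index]) + 1
    mask = np.zeros(n_points, dtype=bool)
    mask[is_efficient] = True
    return mask

# Two objectives to maximize, e.g. balanced accuracy vs. a fairness score.
pts = np.array([[0.9, 0.2], [0.8, 0.8], [0.7, 0.9], [0.6, 0.5]])
print(pareto_frontier_mask(pts))  # [ True  True  True False]
```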
aif360/algorithms/preprocessing/disparate_impact_remover.py
@@ -60,6 +60,6 @@ def fit_transform(self, dataset):
         repaired_features = repairer.repair(features)
         repaired.features = np.array(repaired_features, dtype=np.float64)
         # protected attribute shouldn't change
-        repaired.features[:, index] = repaired.protected_attributes[:, 0]
+        repaired.features[:, index] = repaired.protected_attributes[:, repaired.protected_attribute_names.index(self.sensitive_attribute)]

         return repaired
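The one-line change above fixes an indexing bug: the repaired features previously always took column 0 of `protected_attributes`, which is only correct when the repaired attribute happens to be listed first. A tiny sketch of the difference, with made-up attribute names:

```python
# Illustrative values only -- not the DisparateImpactRemover API.
protected_attribute_names = ['sex', 'race']
sensitive_attribute = 'race'

col_before = 0                                                     # always 'sex'
col_after = protected_attribute_names.index(sensitive_attribute)  # 1 -> 'race'
```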
2 changes: 1 addition & 1 deletion aif360/algorithms/preprocessing/lfr.py
@@ -145,7 +145,7 @@ def transform(self, dataset, threshold=0.5):
         dataset_new.labels = transformed_bin_labels
         dataset_new.scores = np.array(transformed_labels)

-        return dataset_new
+        return dataset_new

     def fit_transform(self, dataset, maxiter=5000, maxfun=5000, threshold=0.5):
         """Fit and transform methods sequentially.
47 changes: 31 additions & 16 deletions aif360/sklearn/datasets/utils.py
@@ -15,7 +15,8 @@ def check_already_dropped(labels, dropped_cols, name, dropped_by='numeric_only',
     haven't.

     Args:
-        labels (single label or list-like): Column labels to check.
+        labels (label, pandas.Series, or list-like of labels/Series): Column
+            labels to check.
         dropped_cols (set or pandas.Index): Columns that were already dropped.
         name (str): Original arg that triggered the check (e.g. dropcols).
         dropped_by (str, optional): Original arg that caused ``dropped_cols``
@@ -27,28 +28,38 @@ def check_already_dropped(labels, dropped_cols, name, dropped_by='numeric_only',
     Returns:
         list: Columns in labels which are not in dropped_cols.
     """
-    if not is_list_like(labels):
+    if isinstance(labels, pd.Series) or not is_list_like(labels):
         labels = [labels]
-    str_labels = [c for c in labels if isinstance(c, str)]
-    already_dropped = dropped_cols.intersection(str_labels)
+    str_labels = [c for c in labels if not isinstance(c, pd.Series)]
+    try:
+        already_dropped = dropped_cols.intersection(str_labels)
+        if isinstance(already_dropped, pd.MultiIndex):
+            raise TypeError  # list of lists results in MultiIndex
+    except TypeError as e:
+        raise TypeError("Only labels or Series are allowed for {}. Got types:\n"
+                        "{}".format(name, [type(c) for c in labels]))
     if warn and any(already_dropped):
         warnings.warn("Some column labels from `{}` were already dropped by "
                       "`{}`:\n{}".format(name, dropped_by, already_dropped.tolist()),
                       ColumnAlreadyDroppedWarning, stacklevel=2)
-    return [c for c in labels if not isinstance(c, str) or c not in already_dropped]
+    return [c for c in labels if isinstance(c, pd.Series)
+            or c not in already_dropped]

-def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
-                        dropcols=[], numeric_only=False, dropna=True):
+def standardize_dataset(df, *, prot_attr, target, sample_weight=None,
+                        usecols=[], dropcols=[], numeric_only=False,
+                        dropna=True):
"""Separate data, targets, and possibly sample weights and populate
protected attributes as sample properties.
Args:
df (pandas.DataFrame): DataFrame with features and target together.
prot_attr (single label or list-like): Label or list of labels
corresponding to protected attribute columns. Even if these are
dropped from the features, they remain in the index.
target (single label or list-like): Column label of the target (outcome)
variable.
prot_attr (label, pandas.Series, or list-like of labels/Series): Single
label, Series, or list-like of labels/Series corresponding to
protected attribute columns. Even if these are dropped from the
features, they remain in the index. If a Series is provided, it will
be added to the index but not show up in the features.
target (label, pandas.Series, or list-like of labels/Series): Column
label(s) or values of the target (outcome) variable.
sample_weight (single label, optional): Name of the column containing
sample weights.
usecols (single label or list-like, optional): Column(s) to keep. All
@@ -77,9 +88,11 @@ def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
     >>> import pandas as pd
     >>> from sklearn.linear_model import LinearRegression

-    >>> df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=['X', 'y', 'Z'])
-    >>> train = standardize_dataset(df, prot_attr='Z', target='y')
-    >>> reg = LinearRegression().fit(*train)
+    >>> df = pd.DataFrame([[0.5, 1, 1, 0.75], [-0.5, 0, 0, 0.25]],
+    ...                   columns=['X', 'y', 'Z', 'w'])
+    >>> train = standardize_dataset(df, prot_attr='Z', target='y',
+    ...                             sample_weight='w')
+    >>> reg = LinearRegression().fit(**train._asdict())

     >>> import numpy as np
     >>> from sklearn.datasets import make_classification
@@ -105,7 +118,9 @@ def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
     target = check_already_dropped(target, nonnumeric, 'target')
     if len(target) == 0:
         raise ValueError("At least one target must be present.")
-    y = pd.concat([df.pop(t) for t in target], axis=1).squeeze()  # maybe Series
+    y = pd.concat([df.pop(t) if not isinstance(t, pd.Series) else
+                   t.set_axis(df.index, inplace=False) for t in target], axis=1)
+    y = y.squeeze()  # maybe Series

     # Column-wise drops
     orig_cols = df.columns
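Together with the new keyword-only signature, these changes let `prot_attr` and `target` be passed as Series (per the commit message) rather than only as column labels. A rough sketch of the new call pattern, with made-up data; the Series lands in the index of `X` rather than among its feature columns:

```python
import pandas as pd
from aif360.sklearn.datasets import standardize_dataset

df = pd.DataFrame({'X': [0.5, -0.5], 'y': [1, 0], 'Z': [1, 0]})

# External Series as the protected attribute: it becomes part of the
# index of X, but is not included as a feature column.
sex = pd.Series(['M', 'F'], name='sex')
X, y = standardize_dataset(df, prot_attr=sex, target='y')
```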
4 changes: 2 additions & 2 deletions aif360/sklearn/inprocessing/adversarial_debiasing.py
@@ -56,7 +56,7 @@ def __init__(self, prot_attr=None, scope_name='classifier',
                 entire model (classifier and adversary).
             adversary_loss_weight (float or ``None``, optional): If ``None``,
                 this will use the suggestion from the paper:
-                :math:`\alpha = \sqrt(global_step)` with inverse time decay on
+                :math:`\alpha = \sqrt{global\_step}` with inverse time decay on
                 the learning rate. Otherwise, it uses the provided coefficient
                 with exponential learning rate decay.
             num_epochs (int, optional): Number of epochs for which to train.
@@ -340,7 +340,7 @@ def predict(self, X):
"""
scores = self.decision_function(X)
if scores.ndim == 1:
indices = (scores > 0).astype(np.int)
indices = (scores > 0).astype(int)
else:
indices = scores.argmax(axis=1)
return self.classes_[indices]
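The `np.int` fix is worth flagging: `np.int` was only an alias for the builtin `int`, deprecated in NumPy 1.20 and removed in 1.24, so `astype(np.int)` breaks on current NumPy. A quick check of the replacement:

```python
import numpy as np

scores = np.array([-0.3, 0.7, 1.2])
indices = (scores > 0).astype(int)  # np.int is gone in NumPy >= 1.24
print(indices)  # [0 1 1]
```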
116 changes: 72 additions & 44 deletions aif360/sklearn/metrics/metrics.py
@@ -1,12 +1,12 @@
 import warnings

 import numpy as np
 import pandas as pd
 from sklearn.metrics import make_scorer as _make_scorer, recall_score
 from sklearn.metrics import multilabel_confusion_matrix
+from sklearn.metrics._classification import _prf_divide, _check_zero_division
 from sklearn.neighbors import NearestNeighbors
 from sklearn.utils import check_X_y
-from sklearn.exceptions import UndefinedMetricWarning, deprecated
+from sklearn.utils.validation import column_or_1d
+from sklearn.exceptions import deprecated

 from aif360.sklearn.utils import check_groups
 from aif360.detectors.mdss.ScoringFunctions import BerkJones, Bernoulli
@@ -77,7 +77,7 @@ def difference(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
     return func(*unpriv, **kwargs) - func(*priv, **kwargs)

 def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
-          **kwargs):
+          zero_division='warn', **kwargs):
     """Compute the ratio between unprivileged and privileged subsets for an
     arbitrary metric.
@@ -96,11 +96,15 @@ def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
         priv_group (scalar, optional): The label of the privileged group.
         sample_weight (array-like, optional): Sample weights passed through to
             func.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
         **kwargs: Additional keyword args to be passed through to func.

     Returns:
         scalar: Ratio of metric values for unprivileged and privileged groups.
     """
+    _check_zero_division(zero_division)
     groups, _ = check_groups(y, prot_attr)
     idx = (groups == priv_group)
     unpriv = map(lambda a: a[~idx], (y,) + args)
@@ -112,13 +116,14 @@ def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
     numerator = func(*unpriv, **kwargs)
     denominator = func(*priv, **kwargs)

-    if denominator == 0:
-        warnings.warn("The ratio is ill-defined and being set to 0.0 because "
-                      "'{}' for privileged samples is 0.".format(func.__name__),
-                      UndefinedMetricWarning)
-        return 0.
-
-    return numerator / denominator
+    if func == base_rate:
+        modifier = 'positive privileged'
+    elif func == selection_rate:
+        modifier = 'predicted privileged'
+    else:
+        modifier = f'value for {func.__name__} on privileged'
+    return _prf_divide(np.array([numerator]), np.array([denominator]), 'ratio',
+                       modifier, None, ('ratio',), zero_division).item()


# =========================== SCORER FACTORY =================================
@@ -151,24 +156,26 @@ def score(y, y_pred, **kwargs):
     return scorer

 # ================================ HELPERS =====================================
-def specificity_score(y_true, y_pred, pos_label=1, sample_weight=None):
+def specificity_score(y_true, y_pred, pos_label=1, sample_weight=None,
+                      zero_division='warn'):
     """Compute the specificity or true negative rate.

     Args:
         y_true (array-like): Ground truth (correct) target values.
         y_pred (array-like): Estimated targets as returned by a classifier.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
     """
+    _check_zero_division(zero_division)
     MCM = multilabel_confusion_matrix(y_true, y_pred, labels=[pos_label],
                                       sample_weight=sample_weight)
-    tn, fp, fn, tp = MCM.ravel()
+    tn, fp = MCM[:, 0, 0], MCM[:, 0, 1]
     negs = tn + fp
-    if negs == 0:
-        warnings.warn('specificity_score is ill-defined and being set to 0.0 '
-                      'due to no negative samples.', UndefinedMetricWarning)
-        return 0.
-    return tn / negs
+    return _prf_divide(tn, negs, 'specificity', 'negative', None,
+                       ('specificity',), zero_division).item()

 def base_rate(y_true, y_pred=None, pos_label=1, sample_weight=None):
     r"""Compute the base rate, :math:`Pr(Y = \text{pos_label}) = \frac{P}{P+N}`.
@@ -200,7 +207,8 @@ def selection_rate(y_true, y_pred, pos_label=1, sample_weight=None):
"""
return base_rate(y_pred, pos_label=pos_label, sample_weight=sample_weight)

def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None):
def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None,
zero_division='warn'):
r"""Return the ratio of generalized false positives to negative examples in
the dataset, :math:`GFPR = \tfrac{GFP}{N}`.
@@ -212,22 +220,29 @@ def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None,
         probas_pred (array-like): Probability estimates of the positive class.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.

     Returns:
-        float: Generalized false positive rate. If there are no negative samples
-        in y_true, this will raise an
-        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
+        float: Generalized false positive rate.
     """
+    _check_zero_division(zero_division)
+    y_true, probas_pred = column_or_1d(y_true), column_or_1d(probas_pred)
+
     idx = (y_true != pos_label)
-    if not np.any(idx):
-        warnings.warn("generalized_fpr is ill-defined because there are no "
-                      "negative samples in y_true.", UndefinedMetricWarning)
-        return 0.
+    gfps = probas_pred[idx]
     if sample_weight is None:
-        return probas_pred[idx].mean()
-    return np.average(probas_pred[idx], weights=sample_weight[idx])
+        gfp = np.array([gfps.sum()])
+        neg = np.array([len(gfps)])
+    else:
+        gfp = np.array([np.dot(gfps, sample_weight[idx])])
+        neg = np.array([sample_weight[idx].sum()])
+    return _prf_divide(gfp, neg, 'generalized FPR', 'negative', None,
+                       ('generalized FPR',), zero_division).item()

-def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None):
+def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None,
+                    zero_division='warn'):
     r"""Return the ratio of generalized false negatives to positive examples in
     the dataset, :math:`GFNR = \tfrac{GFN}{P}`.
@@ -239,20 +254,26 @@ def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None,
         probas_pred (array-like): Probability estimates of the positive class.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.

     Returns:
-        float: Generalized false negative rate. If there are no positive samples
-        in y_true, this will raise an
-        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
+        float: Generalized false negative rate.
     """
+    _check_zero_division(zero_division)
+    y_true, probas_pred = column_or_1d(y_true), column_or_1d(probas_pred)
+
     idx = (y_true == pos_label)
-    if not np.any(idx):
-        warnings.warn("generalized_fnr is ill-defined because there are no "
-                      "positive samples in y_true.", UndefinedMetricWarning)
-        return 0.
+    gfns = 1 - probas_pred[idx]
     if sample_weight is None:
-        return 1 - probas_pred[idx].mean()
-    return 1 - np.average(probas_pred[idx], weights=sample_weight[idx])
+        gfn = np.array([gfns.sum()])
+        pos = np.array([len(gfns)])
+    else:
+        gfn = np.array([np.dot(gfns, sample_weight[idx])])
+        pos = np.array([sample_weight[idx].sum()])
+    return _prf_divide(gfn, pos, 'generalized FNR', 'positive', None,
+                       ('generalized FNR',), zero_division).item()
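The generalized rates are just averages of scores over the relevant ground-truth group, which the refactor makes explicit by accumulating a numerator and denominator for `_prf_divide`. A hand-worked sketch of both definitions on made-up probabilities:

```python
import numpy as np

y_true = np.array([0, 0, 1, 1])
probas = np.array([0.1, 0.3, 0.8, 0.6])

neg = y_true != 1
gfpr = probas[neg].sum() / neg.sum()        # (0.1 + 0.3) / 2 = 0.2

pos = y_true == 1
gfnr = (1 - probas[pos]).sum() / pos.sum()  # (0.2 + 0.4) / 2 = 0.3
```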


# ============================ GROUP FAIRNESS ==================================
@@ -291,7 +312,7 @@ def statistical_parity_difference(*y, prot_attr=None, priv_group=1, pos_label=1,
                       pos_label=pos_label, sample_weight=sample_weight)

 def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
-                           sample_weight=None):
+                           sample_weight=None, zero_division='warn'):
     r"""Ratio of selection rates.

     .. math::
@@ -313,6 +334,9 @@ def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
         priv_group (scalar, optional): The label of the privileged group.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.

     Returns:
         float: Disparate impact.
@@ -322,7 +346,8 @@ def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
     """
     rate = base_rate if len(y) == 1 or y[1] is None else selection_rate
     return ratio(rate, *y, prot_attr=prot_attr, priv_group=priv_group,
-                 pos_label=pos_label, sample_weight=sample_weight)
+                 pos_label=pos_label, sample_weight=sample_weight,
+                 zero_division=zero_division)
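With `zero_division` now threaded through `ratio`, a privileged group with no positives no longer silently hard-codes the result. A sketch of the post-commit behavior with made-up data (protected attribute in the index, as the aif360.sklearn metrics expect):

```python
import pandas as pd
from aif360.sklearn.metrics import disparate_impact_ratio

idx = pd.Index([0, 0, 1, 1], name='prot')
y = pd.Series([1, 0, 0, 0], index=idx)  # privileged selection rate is 0

disparate_impact_ratio(y, prot_attr='prot', priv_group=1)
# default 'warn': returns 0.0 and raises an UndefinedMetricWarning
disparate_impact_ratio(y, prot_attr='prot', priv_group=1, zero_division=1)
# returns 1.0 instead of warning
```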

def equal_opportunity_difference(y_true, y_pred, prot_attr=None, priv_group=1,
pos_label=1, sample_weight=None):
@@ -384,8 +409,8 @@ def average_odds_difference(y_true, y_pred, prot_attr=None, priv_group=1,
                       sample_weight=sample_weight)
     return (tpr_diff + fpr_diff) / 2

-def average_odds_error(y_true, y_pred, prot_attr=None, pos_label=1,
-                       sample_weight=None):
+def average_odds_error(y_true, y_pred, prot_attr=None, priv_group=None,
+                       pos_label=1, sample_weight=None):
     r"""A relaxed version of equality of odds.

     Returns the average of the absolute difference in FPR and TPR for the
@@ -403,14 +428,17 @@ def average_odds_error(y_true, y_pred, prot_attr=None, pos_label=1,
         y_pred (array-like): Estimated targets as returned by a classifier.
         prot_attr (array-like, keyword-only): Protected attribute(s). If
             ``None``, all protected attributes in y_true are used.
-        priv_group (scalar, optional): The label of the privileged group.
+        priv_group (scalar, optional): The label of the privileged group. If
+            prot_attr is binary, this may be ``None``.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.

     Returns:
         float: Average odds error.
     """
-    priv_group = check_groups(y_true, prot_attr=prot_attr)[0][0]
+    if priv_group is None:
+        priv_group = check_groups(y_true, prot_attr=prot_attr,
+                                  ensure_binary=True)[0][0]
     fpr_diff = -difference(specificity_score, y_true, y_pred,
                            prot_attr=prot_attr, priv_group=priv_group,
                            pos_label=pos_label, sample_weight=sample_weight)
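`priv_group` is now an explicit parameter that defaults to inferring the (binary) group from the data, matching the old behavior when omitted. Since the error averages absolute TPR/FPR gaps, which group counts as privileged only fixes a sign convention, so the inference is safe for binary attributes. A sketch with made-up data:

```python
import pandas as pd
from aif360.sklearn.metrics import average_odds_error

idx = pd.Index([0, 1, 0, 1, 0, 1], name='prot')
y_true = pd.Series([1, 1, 0, 0, 1, 0], index=idx)
y_pred = pd.Series([1, 0, 0, 1, 1, 0], index=idx)

# priv_group may now be omitted for a binary protected attribute...
print(average_odds_error(y_true, y_pred, prot_attr='prot'))
# ...or passed explicitly, yielding the same value.
print(average_odds_error(y_true, y_pred, prot_attr='prot', priv_group=1))
```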