Add a bias detector based on optimal transport #434

Merged: 26 commits, Jul 23, 2023
Changes from 12 commits
Commits
8e42e67
Updated all required files
Illia-Kryvoviaz Jun 7, 2023
7a85647
Some minor changes in files and updated tests
Illia-Kryvoviaz Jun 7, 2023
0d276aa
Deleted extra prints
Illia-Kryvoviaz Jun 7, 2023
a6faacd
Simplifying the ot notebook and correcting some mistypes
Illia-Kryvoviaz Jun 10, 2023
2db5f83
Added more examples to ot notebook
Illia-Kryvoviaz Jun 12, 2023
6d77b72
Update __init__.py
Illia-Kryvoviaz Jun 12, 2023
a9e4b55
Update detectors.py
Illia-Kryvoviaz Jun 12, 2023
4921f5b
Update requirements.txt
Illia-Kryvoviaz Jun 12, 2023
881ea08
Improving the notebook and adding a new feature
Illia-Kryvoviaz Jun 18, 2023
90c538f
Dev ot detector (#4)
Illia-Kryvoviaz Jun 24, 2023
7029b20
Added outputs to the notebook
Illia-Kryvoviaz Jun 24, 2023
425452b
Minor docstrings changes and update detectors.py
Illia-Kryvoviaz Jun 25, 2023
0be802e
changed demo_ot_detector to use load_preproc_data_adult
Illia-Kryvoviaz Jul 10, 2023
f96a451
ot_detector: renamed sensitive_attribute to prot_attr, minor changes
Illia-Kryvoviaz Jul 10, 2023
9c46f9a
updated comments, demo_ot_detector.ipynb
Illia-Kryvoviaz Jul 11, 2023
872f112
ot_detector: removed str arguments
Illia-Kryvoviaz Jul 11, 2023
d280b4e
ot_detector: added cost_matrix as a named parameter, minor changes
Illia-Kryvoviaz Jul 11, 2023
aa24084
ot_detector: minor changes
Illia-Kryvoviaz Jul 11, 2023
f587762
added outputs to demo_ot_detector
Illia-Kryvoviaz Jul 11, 2023
2094b47
ot_detector: changed default scoring to Wasserstein1
Illia-Kryvoviaz Jul 11, 2023
47a2678
moved OT from detector to metric
Illia-Kryvoviaz Jul 14, 2023
6acd270
renamed ot_detector to ot_metric
Illia-Kryvoviaz Jul 14, 2023
6155f99
reworked demo_ot_metric to use aif360.sklearn definition
Illia-Kryvoviaz Jul 14, 2023
3d3f039
renamed ot_bias_scan to ot_distance, minor changes
Illia-Kryvoviaz Jul 21, 2023
bb2ec68
detectors.py: reset changes
Illia-Kryvoviaz Jul 21, 2023
10c7ec8
test_ot_metric: minor changes
Illia-Kryvoviaz Jul 21, 2023
3 changes: 2 additions & 1 deletion aif360/detectors/__init__.py
@@ -1,2 +1,3 @@
from aif360.detectors.mdss.MDSS import MDSS
from aif360.detectors.mdss_detector import bias_scan
from aif360.detectors.ot_detector import ot_bias_scan
257 changes: 257 additions & 0 deletions aif360/detectors/ot_detector.py
@@ -0,0 +1,257 @@
from typing import Union
import pandas as pd
import numpy as np
import ot
from sklearn.preprocessing import LabelEncoder

def _normalize(distribution1, distribution2):
"""
Transform distributions to pleasure form, that is their sums are equal to 1,
and in case if there is negative values, increase all values with absolute value of smallest number.

Args:
distribution1 (numpy array): nontreated distribution
distribution2 (numpy array): nontreated distribution
"""
if np.minimum(np.min(distribution1), np.min(distribution2)) < 0:
extra = -np.minimum(np.min(distribution1), np.min(distribution2))
distribution1 += extra
distribution2 += extra

total_of_distribution1 = np.sum(distribution1)
if total_of_distribution1 != 0:
distribution1 /= total_of_distribution1
total_of_distribution2 = np.sum(distribution2)
if total_of_distribution2 != 0:
distribution2 /= total_of_distribution2
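# Behavior sketch (illustrative): normalization happens in place on float arrays.
#   a = np.array([1., -1., 2.]); b = np.array([0., 1., 1.])
#   _normalize(a, b)  # shifts both up by 1, then rescales each to sum to 1
#   a -> [0.4, 0. , 0.6];  b -> [0.2, 0.4, 0.4]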

def _transform(ground_truth, classifier, data, cost_matrix=None):
"""
Transform given distributions from pandas type to numpy arrays, and _normalize them.
Rearanges distributions, with totall data allocated of one.
Generates matrix distance with respect to (ground_truth[i] - classifier[j])^2.

Args:
ground_truth (series): ground truth (correct) target values
classifier (series, dataframe, optional): pandas series estimated targets
as returned by a model for binary, continuous and ordinal modes.
data (dataframe): the dataset (containing the features) the model was trained on

Returns:
initial_distribution, which is an processed ground_truth (numpy array)
required_distribution, which is an processed classifier (numpy array)
matrix_distance, which stores the distances between the cells of distributions (2d numpy array)
"""
initial_distribution = (pd.Series.to_numpy(ground_truth)).astype(float)
required_distribution = (pd.Series.to_numpy(classifier)).astype(float)

_normalize(initial_distribution, required_distribution)

if cost_matrix is not None:
matrix_distance = cost_matrix
else:
matrix_distance = np.array([abs(i - required_distribution) for i in initial_distribution], dtype=float)
return initial_distribution, required_distribution, matrix_distance
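# Quick illustration of the default cost matrix (assuming no cost_matrix is
# passed): with ground_truth = pd.Series([0., 1.]) and classifier =
# pd.Series([1., 1.]), normalization yields [0., 1.] and [0.5, 0.5], and
# matrix_distance = |g_i - c_j| = [[0.5, 0.5], [0.5, 0.5]].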

def _evaluate(
ground_truth: pd.Series,
classifier: pd.Series,
sensitive_attribute: pd.Series=None,
data: pd.DataFrame=None,
num_iters=100000,
**kwargs):
"""If the given golden_standart and classifier are distributions, it returns the Wasserstein distance between them,
otherwise it extract all neccessary information from data, makes logistic regression and
compute optimal transport for all possible options for the given classifier.

Args:
ground_truth (pd.Series): ground truth (correct) target values
classifier (pd.Series): estimated target values
sensitive_attribute (pd.Series, optional): sensitive attribute values
data (dataframe, optional): the dataset (containing the features) the model was trained on; \
None if `ground_truth`, `classifier` and `sensitive_attribute` are `pd.Series`
num_iters (int, optional): number of iterations (random restarts). Should be positive.

Returns:
ot.emd2 (float, dict): Earth mover's distance, or a dictionary of per-group distances keyed by sensitive-attribute value
"""

# Calculate just the EMD between ground_truth and classifier
if sensitive_attribute is None:
initial_distribution, required_distribution, matrix_distance = _transform(ground_truth, classifier, data, kwargs.get("cost_matrix"))
return ot.emd2(a=initial_distribution, b=required_distribution, M=matrix_distance, numItermax=num_iters)

if ground_truth.nunique() != 2:
raise ValueError(f"Expected exactly 2 unique target values, got {ground_truth.nunique()}.")

# Calculate EMD between ground truth distribution and distribution of each group
emds = {}
for sa_val in sorted(sensitive_attribute.unique()):
initial_distribution = ground_truth[sensitive_attribute == sa_val]
required_distribution = classifier[sensitive_attribute == sa_val]
initial_distribution, required_distribution, matrix_distance = _transform(initial_distribution, required_distribution, data, kwargs.get("cost_matrix"))
emds[sa_val] = ot.emd2(a=initial_distribution, b=required_distribution, M=matrix_distance, numItermax=num_iters)

return emds
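# Shape of the grouped result (illustrative): for sensitive attribute values
# {"F", "M"} this returns something like {"F": 0.12, "M": 0.03}: one Earth
# mover's distance per group, so callers can compare groups directly.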


# Function called by the user
def ot_bias_scan(
ground_truth: Union[pd.Series, str],
classifier: Union[pd.Series, str],
sensitive_attribute: Union[pd.Series, str] = None,
data: pd.DataFrame = None,
favorable_value: Union[str, float] = None,
overpredicted: bool = True,
scoring: str = "Optimal Transport",
num_iters: int = 100000,
penalty: float = 1e-17,
mode: str = "binary",
**kwargs,
):
"""Calculated the Wasserstein distance for two given distributions.
Transforms pandas Series into numpy arrays, transofrms and normalize them.
After all, solves the optimal transport problem.

Args:
ground_truth (pd.Series, str): ground truth (correct) target values.
If `str`, denotes the column in `data` in which the ground truth target values are stored.
classifier (pd.Series, pd.DataFrame, str): estimated target values.
If `str`, must denote the column or columns in `data` in which the estimated target values are stored.
If `mode` is nominal, must be a dataframe with columns containing predictions for each nominal class,
or list of corresponding column names in `data`.
If `None`, model is assumed to be a dummy model that predicts the mean of the targets
or 1/(number of categories) for nominal mode.
sensitive_attribute (pd.Series, str): sensitive attribute values.
If `str`, must denote the column in `data` in which the sensitive attribute values are stored.
If `None`, assume all samples belong to the same protected group.
data (dataframe, optional): the dataset (containing the features) the model was trained on.
favorable_value (str, float, optional): either "high", "low", or a float value if `mode` is in [binary, ordinal, continuous].
If float, value has to be the minimum or the maximum in the ground_truth column.
Defaults to high if None for these modes.
Support for float left in to keep the intuition clear in binary classification tasks.
If `mode` is nominal, favorable values should be one of the unique categories in the ground_truth.
Defaults to a one-vs-all scan if None for nominal mode.
overpredicted (bool, optional): flag for group to scan for.
`True` scans for overprediction, `False` scans for underprediction.
scoring (str): only 'Optimal Transport' is supported.
num_iters (int, optional): number of iterations (random restarts) for EMD. Should be positive.
penalty (float, optional): penalty term. Should be positive. As with any regularization
parameter, the penalty may need to be tuned for a particular use case. The higher the penalty, the higher the influence of the entropy regularizer.
mode: one of ['binary', 'continuous', 'nominal', 'ordinal']. Defaults to binary.
In nominal mode, up to 10 categories are supported by default.
To increase this, pass in keyword argument max_nominal = integer value.

Returns:
ot.emd2 (float, dict): Earth mover's distance, or a dictionary of per-group (or per-class) distances

Raises:
ValueError: if `mode` is 'binary' but `ground_truth` contains more than 2 unique values.
"""

# Assert correct mode passed
if mode not in ['binary', 'continuous', 'nominal', 'ordinal']:
raise ValueError(f"Expected one of {['binary', 'continuous', 'nominal', 'ordinal']}, got {mode}.")

# Assert correct types passed to ground_truth, classifier and sensitive_attribute
if not isinstance(ground_truth, (pd.Series, str)):
raise TypeError(f"ground_truth: expected pd.Series or str, got {type(ground_truth)}")
if classifier is not None:
if mode in ["binary", "continuous"] and not isinstance(classifier, pd.Series):
raise TypeError(f"classifier: expected pd.Series for {mode} mode, got {type(classifier)}")
if mode in ["nominal", "ordinal"] and not isinstance(classifier, pd.DataFrame):
raise TypeError(f"classifier: expected pd.DataFrame for {mode} mode, got {type(classifier)}")
if sensitive_attribute is not None and not isinstance(sensitive_attribute, (pd.Series, str)):
raise TypeError(f"sensitive_attribute: expected pd.Series or str, got {type(sensitive_attribute)}")

# Assert correct type passed to cost_matrix
if kwargs.get("cost_matrix") is not None:
if not isinstance(kwargs.get("cost_matrix"), np.ndarray):
raise TypeError(f"cost_matrix: expected numpy.ndarray, got {type(kwargs.get('cost_matrix'))}")

# Assert scoring is "Optimal Transport"
if not scoring == "Optimal Transport":
raise ValueError(f"Scoring mode can only be \"Optimal Transport\", got {scoring}")

# If any of input data arguments passed as str, retrieve the values from data
if isinstance(ground_truth, str): # ground truth
if not isinstance(data, pd.DataFrame):
raise TypeError(f"if ground_truth is a string, data must be pd.DataFrame; got {type(data)}")
grt = data[ground_truth].copy()
else:
grt = ground_truth.copy()

if isinstance(classifier, str): # classifier
if not isinstance(data, pd.DataFrame):
raise TypeError(f"if classifier is a string, data must be pd.DataFrame; got {type(data)}")
cls = data[classifier].copy()
elif classifier is not None:
cls = classifier.copy()
if sensitive_attribute is not None:
cls.index = grt.index
else:
cls = None

if isinstance(sensitive_attribute, str): # sensitive attribute
if not isinstance(data, pd.DataFrame):
raise TypeError(f"if sensitive_attribute is a string, data must be pd.DataFrame; got {type(data)}")
sat = data[sensitive_attribute].copy()
sat.index = grt.index
elif sensitive_attribute is not None:
sat = sensitive_attribute.copy()
sat.index = grt.index
else:
sat = None

uniques = list(grt.unique())
if mode == "binary":
if len(uniques) > 2:
raise ValueError(f"Only 2 unique values allowed in ground_truth for binary mode, got {uniques}")

# Encode variables
if not np.issubdtype(grt.dtype, np.number):
grt_encoder = LabelEncoder().fit(grt)
grt = pd.Series(grt_encoder.transform(grt), index=grt.index)

# Set correct favorable value (this tells us if higher or lower is better)
min_val, max_val = grt.min(), grt.max()

if favorable_value == 'high':
favorable_value = max_val
elif favorable_value == 'low':
favorable_value = min_val
elif favorable_value is None:
if mode in ["binary", "ordinal", "continuous"]:
favorable_value = max_val # Default to higher is better
elif mode == "nominal":
favorable_value = "flag-all" # Default to scan through all categories

if favorable_value not in [min_val, max_val, "flag-all", *uniques,]:
raise ValueError(f"Favorable_value should be high, low, or one of categories {uniques}, got {favorable_value}.")

if mode == "binary": # Flip ground truth if favorable_value is 0 in binary mode.
grt = pd.Series(grt == favorable_value, dtype=int)
if cls is None:
cls = pd.Series(grt.mean(), index=grt.index)
emds = _evaluate(grt, cls, sat, data, num_iters, **kwargs)

elif mode == "continuous":
if cls is None:
cls = pd.Series(grt.mean(), index=grt.index)
emds = _evaluate(grt, cls, sat, data, num_iters, **kwargs)

## TODO: rework ordinal mode to take into account distance between pred and true
elif mode in ["nominal", "ordinal"]:
if cls is None: # Dummy classifier: predict 1/(num of categories) in a column per class
cls = pd.DataFrame(1 / grt.nunique(), index=grt.index, columns=uniques)
if grt.nunique() != cls.shape[-1]:
raise ValueError(
f"classifier must have a column for each class. Expected shape [:, {grt.nunique()}], got {cls.shape}")
emds = {}
for class_label in uniques:
grt_cl = grt.map({class_label: 1}).fillna(0)
cls_cl = cls[class_label]
emds[class_label] = _evaluate(grt_cl, cls_cl, sat, data, num_iters, **kwargs)

return emds
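For orientation, a minimal end-to-end sketch of the API this file adds (toy data and column names are invented; exact output values will vary):

import pandas as pd
from aif360.detectors.ot_detector import ot_bias_scan

df = pd.DataFrame({
    "label": [1, 0, 1, 1, 0, 0],
    "score": [0.9, 0.2, 0.7, 0.6, 0.4, 0.1],
    "sex":   ["M", "F", "M", "F", "M", "F"],
})

# Per-group Earth mover's distance between true labels and model scores,
# keyed by each value of the sensitive attribute.
emds = ot_bias_scan(
    ground_truth=df["label"],
    classifier=df["score"],
    sensitive_attribute=df["sex"],
    mode="binary",
)
print(emds)  # e.g. {'F': ..., 'M': ...}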
2 changes: 2 additions & 0 deletions aif360/sklearn/detectors/__init__.py
@@ -2,7 +2,9 @@
Methods for detecting subsets for which a model or dataset is biased.
"""
from aif360.sklearn.detectors.detectors import bias_scan
from aif360.sklearn.detectors.detectors import ot_bias_scan

__all__ = [
'bias_scan',
'ot_bias_scan',
]
70 changes: 68 additions & 2 deletions aif360/sklearn/detectors/detectors.py
@@ -1,10 +1,76 @@
from typing import Union

import pandas as pd
import numpy as np

# Alias the core implementation so the wrapper below can reuse its name
# without shadowing (and recursively calling) itself.
from aif360.detectors.ot_detector import ot_bias_scan as _ot_bias_scan
from aif360.detectors import bias_scan
from aif360.detectors.mdss.ScoringFunctions import ScoringFunction

def ot_bias_scan(
y_true: Union[pd.Series, str],
y_pred: Union[pd.Series, pd.DataFrame, str],
sensitive_attribute: Union[pd.Series, str] = None,
X: pd.DataFrame = None,
pos_label: Union[str, float] = None,
overpredicted: bool = True,
scoring: str = "Optimal Transport",
num_iters: int = 100,
penalty: float = 1e-17,
mode: str = "ordinal",
**kwargs,
):
"""Calculated the Wasserstein distance for two given distributions.
Transforms pandas Series into numpy arrays, transofrms and normalize them.
After all, solves the optimal transport problem.

Args:
y_true (pd.Series, str): ground truth (correct) target values.
If `str`, denotes the column in `X` in which the ground truth target values are stored.
y_pred (pd.Series, pd.DataFrame, str): estimated target values.
If `str`, must denote the column or columns in `X` in which the estimated target values are stored.
If `mode` is nominal, must be a dataframe with columns containing predictions for each nominal class,
or list of corresponding column names in `X`.
If `None`, model is assumed to be a dummy model that predicts the mean of the targets
or 1/(number of categories) for nominal mode.
sensitive_attribute (pd.Series, str): sensitive attribute values.
If `str`, must denote the column in `X` in which the sensitive attribute values are stored.
If `None`, assume all samples belong to the same protected group.
X (dataframe, optional): the dataset (containing the features) the model was trained on.
pos_label (str, float, optional): either "high", "low", or a float value if `mode` is in [binary, ordinal, continuous].
If float, value has to be the minimum or the maximum in the `y_true` column.
Defaults to high if None for these modes.
Support for float left in to keep the intuition clear in binary classification tasks.
If `mode` is nominal, favorable values should be one of the unique categories in `y_true`.
Defaults to a one-vs-all scan if None for nominal mode.
overpredicted (bool, optional): flag for group to scan for.
`True` scans for overprediction, `False` scans for underprediction.
scoring (str): only 'Optimal Transport' is supported.
num_iters (int, optional): number of iterations (random restarts) for EMD. Should be positive.
penalty (float, optional): penalty term. Should be positive. As with any regularization
parameter, the penalty may need to be tuned for a particular use case. The higher the penalty, the higher the influence of the entropy regularizer.
mode: one of ['binary', 'continuous', 'nominal', 'ordinal']. Defaults to ordinal.
In nominal mode, up to 10 categories are supported by default.
To increase this, pass in keyword argument max_nominal = integer value.

Returns:
ot.emd2 (float, dict): Earth mover's distance, or a dictionary of per-group (or per-class) distances

Raises:
ValueError: if `mode` is 'binary' but `y_true` contains more than 2 unique values.
"""
return _ot_bias_scan(
ground_truth=y_true,
classifier=y_pred,
sensitive_attribute=sensitive_attribute,
data=X,
favorable_value=pos_label,
overpredicted=overpredicted,
scoring=scoring,
num_iters=num_iters,
penalty=penalty,
mode=mode,
**kwargs,
)
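# A minimal sketch of calling this wrapper (toy data, illustrative only):
#   df = pd.DataFrame({"label": [1, 0, 1, 1, 0, 0],
#                      "score": [0.9, 0.2, 0.7, 0.6, 0.4, 0.1],
#                      "sex":   ["M", "F", "M", "F", "M", "F"]})
#   emds = ot_bias_scan(y_true=df["label"], y_pred=df["score"],
#                       sensitive_attribute=df["sex"], mode="binary")
#   # -> one Earth mover's distance per value of df["sex"]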

def bias_scan(
X: pd.DataFrame,