scikit-learn-compatible API #134

Merged 61 commits on Feb 19, 2020

Commits
8cfa9de
Initial sklearn-compatible datasets and metrics
May 16, 2019
1f4ae57
added initial dataset tests
May 16, 2019
2aef3fc
fixed to_list for older pandas versions
May 17, 2019
2b1799a
added metrics tests
hoffmansc May 17, 2019
9da5abd
added README and docs
hoffmansc May 21, 2019
025ecc1
simpler dataset loading and 'groups' for metrics
hoffmansc May 23, 2019
8e96177
fixes to categoricals
hoffmansc Jun 5, 2019
8abb897
fixes for tests, updated README
hoffmansc Jun 5, 2019
15a8eb2
added travis badge to README
hoffmansc Jun 6, 2019
3f594a4
updated todo with external blockers
hoffmansc Jun 13, 2019
7754b32
added reweighing workaround to example
hoffmansc Jun 13, 2019
17b0c95
added Reweighing algorithm
hoffmansc Jun 18, 2019
cc9246f
clean up comments
hoffmansc Jun 18, 2019
8c58f65
fixed package version in docs
hoffmansc Jun 18, 2019
1e7899c
adding hyperlinks to SLEPs
animeshsingh Jun 20, 2019
c1c1e40
added binary_age opt to german; fixed NAs in bank
hoffmansc Jun 24, 2019
93a7cdf
modified onehot_transformer to return DataFrame
hoffmansc Jun 24, 2019
8e52268
tweaks to reweighing to conform with sklearn
hoffmansc Jun 24, 2019
0183449
updated README
hoffmansc Jun 24, 2019
89b4a79
fixed docstring formatting
hoffmansc Jun 24, 2019
d57b6df
changed metrics to use prot_attr
hoffmansc Jun 24, 2019
d8958bb
added __all__ to __init__s
hoffmansc Jun 24, 2019
0bd3837
updated notebook with reweighing example
hoffmansc Jun 27, 2019
4107dd7
initial adversarial debiasing port
hoffmansc Jul 11, 2019
df85e42
multiclass/multigroup support for adv debiasing
hoffmansc Jul 16, 2019
d2d0ddc
fix build errors
hoffmansc Jul 30, 2019
7a2414a
Add ensure_binary option to check_groups
hoffmansc Aug 12, 2019
aac9954
`numeric_only` converts index and label as well
hoffmansc Oct 29, 2019
dc317cf
changed Reweighing to return X, sample_weight
hoffmansc Oct 29, 2019
0f184c3
made sample_weight optional in check_inputs
hoffmansc Oct 29, 2019
ec4a1de
matched tests to new numeric dataset format
hoffmansc Oct 29, 2019
f8c4fc5
added generalized_fnr/fpr metrics
hoffmansc Oct 29, 2019
7ce2f42
fixed dataset_processing
hoffmansc Oct 29, 2019
973a774
initial calibrated equalized odds port
hoffmansc Oct 29, 2019
40cad96
fixed adversarial debiasing reproducibility
hoffmansc Oct 30, 2019
dc410a2
updated Getting Started notebook
hoffmansc Oct 30, 2019
e0856e3
updated readme
hoffmansc Oct 31, 2019
8f8cd76
fixed tests and added additional tests
hoffmansc Oct 31, 2019
e01f23f
added COMPAS and other dataset fixes; fixed german dataset to match p…
hoffmansc Nov 11, 2019
e92f846
fix more edge cases in metrics
hoffmansc Nov 12, 2019
27aa55c
removed unused import
hoffmansc Nov 12, 2019
831775c
make cache dir if necessary
hoffmansc Dec 9, 2019
a0e56b0
docstring, formatting, and typo fixes
hoffmansc Dec 13, 2019
0e48ead
more gitignores
hoffmansc Dec 13, 2019
0cbc3f4
docstrings and add alpha=sqrt(global_step) option
hoffmansc Dec 13, 2019
8be6449
docstrings and input is now predict_proba output
hoffmansc Dec 13, 2019
994bdf0
moved tests to main test folder
hoffmansc Dec 18, 2019
372e111
more docs and formatting changes
hoffmansc Dec 19, 2019
8d10893
postprocessor takes DataFrame if use_proba
hoffmansc Dec 19, 2019
e0ff2b6
readme changes overwritten in the merge
hoffmansc Dec 19, 2019
a2cd77e
train, test were swapped for adult
hoffmansc Dec 19, 2019
ee7f23c
remove branch mentions
hoffmansc Dec 19, 2019
c8154ec
remove "attributes" line if none present
hoffmansc Dec 20, 2019
7ef94e7
moved example to main folder
hoffmansc Dec 28, 2019
c5af647
use_proba -> needs_proba
hoffmansc Jan 31, 2020
042bb12
fixed/renamed/reordered/added some attributes
hoffmansc Jan 31, 2020
ff9e70c
fixed sample_weight=None bug and classes_ typo
hoffmansc Feb 5, 2020
57b2ab5
improved specificity_score and added fpr/fnr error
hoffmansc Feb 6, 2020
8fdd6dc
made foreign_worker and education (bank) ordered
hoffmansc Feb 6, 2020
2cf455f
various fixes to address PR comments
hoffmansc Feb 19, 2020
789e96b
added comments to tests
hoffmansc Feb 19, 2020
17 changes: 16 additions & 1 deletion .gitignore
@@ -5,8 +5,23 @@
.cache/
.ipynb_checkpoints/
.pytest_cache/
__pycache__/

.idea/
.vscode/

.eggs/
aif360.egg-info
build/
dist/

.coverage*
coverage.txt

docs/build/
docs/source/modules/generated

aif360/version.py
aif360/data/raw/**
!aif360/data/raw/*/*.md
aif360/version.py
aif360/sklearn/data/
15 changes: 9 additions & 6 deletions aif360/algorithms/inprocessing/adversarial_debiasing.py
@@ -80,14 +80,14 @@ def _classifier_model(self, features, features_dim, keep_prob):
"""
with tf.variable_scope("classifier_model"):
W1 = tf.get_variable('W1', [features_dim, self.classifier_num_hidden_units],
initializer=tf.contrib.layers.xavier_initializer())
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed1))
b1 = tf.Variable(tf.zeros(shape=[self.classifier_num_hidden_units]), name='b1')

h1 = tf.nn.relu(tf.matmul(features, W1) + b1)
h1 = tf.nn.dropout(h1, keep_prob=keep_prob)
h1 = tf.nn.dropout(h1, keep_prob=keep_prob, seed=self.seed2)

W2 = tf.get_variable('W2', [self.classifier_num_hidden_units, 1],
initializer=tf.contrib.layers.xavier_initializer())
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed3))
b2 = tf.Variable(tf.zeros(shape=[1]), name='b2')

pred_logit = tf.matmul(h1, W2) + b2
@@ -103,7 +103,7 @@ def _adversary_model(self, pred_logits, true_labels):
s = tf.sigmoid((1 + tf.abs(c)) * pred_logits)

W2 = tf.get_variable('W2', [3, 1],
initializer=tf.contrib.layers.xavier_initializer())
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed4))
b2 = tf.Variable(tf.zeros(shape=[1]), name='b2')

pred_protected_attribute_logit = tf.matmul(tf.concat([s, s * true_labels, s * (1.0 - true_labels)], axis=1), W2) + b2
Expand All @@ -123,6 +123,8 @@ def fit(self, dataset):
"""
if self.seed is not None:
np.random.seed(self.seed)
ii32 = np.iinfo(np.int32)
self.seed1, self.seed2, self.seed3, self.seed4 = np.random.randint(ii32.min, ii32.max, size=4)

# Map the dataset labels to 0 and 1.
temp_labels = dataset.labels.copy()
@@ -177,14 +179,15 @@ def fit(self, dataset):

if self.debias:
# Update adversary parameters
adversary_minimizer = adversary_opt.minimize(pred_protected_attributes_loss, var_list=adversary_vars, global_step=global_step)
with tf.control_dependencies([classifier_minimizer]):
adversary_minimizer = adversary_opt.minimize(pred_protected_attributes_loss, var_list=adversary_vars)#, global_step=global_step)

self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.local_variables_initializer())

# Begin training
for epoch in range(self.num_epochs):
shuffled_ids = np.random.choice(num_train_samples, num_train_samples)
shuffled_ids = np.random.choice(num_train_samples, num_train_samples, replace=False)
for i in range(num_train_samples//self.batch_size):
batch_ids = shuffled_ids[self.batch_size*i: self.batch_size*(i+1)]
batch_features = dataset.features[batch_ids]
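The seeding changes in this file address a reproducibility gap: TF1's `xavier_initializer` and `tf.nn.dropout` draw from TensorFlow's own random state unless given an explicit `seed`, so calling `np.random.seed` alone did not make training repeatable. A minimal sketch of the derived-seed pattern, in plain NumPy rather than the PR's TensorFlow code:

```python
import numpy as np

def derive_op_seeds(master_seed, n_ops=4):
    """Derive deterministic per-operation seeds from one master seed,
    mirroring how the diff above produces seed1..seed4."""
    ii32 = np.iinfo(np.int32)
    np.random.seed(master_seed)
    return np.random.randint(ii32.min, ii32.max, size=n_ops)

# Same master seed -> identical per-op seeds -> reproducible initializers/dropout.
assert (derive_op_seeds(1234) == derive_op_seeds(1234)).all()
```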
aif360/algorithms/postprocessing/calibrated_eq_odds_postprocessing.py
@@ -171,16 +171,16 @@ def predict(self, dataset, threshold=0.5):
dataset.protected_attribute_names,
self.unprivileged_groups)

priv_indices = (np.random.random(sum(cond_vec_priv))
<= self.priv_mix_rate)
priv_new_pred = dataset.scores[cond_vec_priv].copy()
priv_new_pred[priv_indices] = self.base_rate_priv

unpriv_indices = (np.random.random(sum(cond_vec_unpriv))
<= self.unpriv_mix_rate)
unpriv_new_pred = dataset.scores[cond_vec_unpriv].copy()
unpriv_new_pred[unpriv_indices] = self.base_rate_unpriv

priv_indices = (np.random.random(sum(cond_vec_priv))
<= self.priv_mix_rate)
priv_new_pred = dataset.scores[cond_vec_priv].copy()
priv_new_pred[priv_indices] = self.base_rate_priv

dataset_new = dataset.copy(deepcopy=True)

dataset_new.scores = np.zeros_like(dataset.scores, dtype=np.float64)
@@ -208,4 +208,4 @@ def weighted_cost(fp_rate, fn_rate, cm, privileged):
* (1 - cm.base_rate(privileged=privileged))) +
(fn_rate / norm_const
* cm.generalized_false_negative_rate(privileged=privileged)
* (1 - cm.base_rate(privileged=privileged))))
* cm.base_rate(privileged=privileged)))
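The one-line change above fixes the false-negative term of the cost: it should be weighted by the positive-class base rate, not its complement. A hedged sketch of the corrected expression (the `norm_const` handling here is an assumption; the actual value is computed earlier in `weighted_cost`):

```python
def weighted_cost_sketch(fp_rate, fn_rate, gfpr, gfnr, base_rate):
    """FP term scaled by negative-class prevalence (1 - base_rate),
    FN term by positive-class prevalence (base_rate)."""
    # Assumed normalizer; falls back to 1 when either rate is zero.
    norm_const = float(fp_rate + fn_rate) if fp_rate and fn_rate else 1.0
    return ((fp_rate / norm_const) * gfpr * (1 - base_rate)
            + (fn_rate / norm_const) * gfnr * base_rate)
```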
2 changes: 1 addition & 1 deletion aif360/datasets/adult_dataset.py
@@ -99,7 +99,7 @@ def __init__(self, label_name='income-per-year',
import sys
sys.exit(1)

df = pd.concat([train, test], ignore_index=True)
df = pd.concat([test, train], ignore_index=True)

super(AdultDataset, self).__init__(df=df, label_name=label_name,
favorable_classes=favorable_classes,
23 changes: 17 additions & 6 deletions aif360/datasets/structured_dataset.py
@@ -411,14 +411,25 @@ def import_dataset(self, import_metadata=False):
return None

def split(self, num_or_size_splits, shuffle=False, seed=None):
"""Split the dataset into multiple datasets
"""Split this dataset into multiple partitions.

Args:
num_or_size_splits (list or int):
shuffle (bool):
seed (int or array_like): takes the same argument as `numpy.random.seed()`
function
num_or_size_splits (array or int): If `num_or_size_splits` is an
int, *k*, the value is the number of equal-sized folds to make
(if *k* does not evenly divide the dataset these folds are
approximately equal-sized). If `num_or_size_splits` is an array
of type int, the values are taken as the indices at which to
split the dataset. If the values are floats (< 1.), they are
considered to be fractional proportions of the dataset at which
to split.
shuffle (bool, optional): Randomly shuffle the dataset before
splitting.
seed (int or array_like): Takes the same argument as
:func:`numpy.random.seed()`.

Returns:
list: Each element of this list is a dataset obtained during the split
list: Splits. Contains *k* or `len(num_or_size_splits) + 1`
datasets depending on `num_or_size_splits`.
"""

# Set seed
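The rewritten docstring above distinguishes three forms of `num_or_size_splits`. A small NumPy illustration of the same semantics (the real method partitions a `StructuredDataset`, not an array):

```python
import numpy as np

data = np.arange(10)

np.array_split(data, 3)          # int k: 3 approximately equal-sized folds
np.split(data, [3, 7])           # list of ints: split at indices 3 and 7
fracs = np.array([0.5, 0.8])     # floats < 1: fractional proportions
np.split(data, (fracs * len(data)).astype(int))
```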
47 changes: 47 additions & 0 deletions aif360/sklearn/README.md
@@ -0,0 +1,47 @@
## `aif360.sklearn`

This is a wholly separate interface from the main AIF360 package for
interacting with data, viewing metrics, and running debiasing algorithms. The
purpose of this sub-package is to match scikit-learn paradigms/APIs for easier
integration into typical machine learning workflows.

See [Getting Started](examples/Getting%20Started.ipynb) to see `aif360.sklearn`
in action.

To do:

- [x] Reformat datasets as separate X and y (and sample_weight) DataFrame
objects with sample properties (protected attributes) as the index
- [ ] Load included datasets in the above format
- [x] Use `sklearn.datasets.fetch_openml` to load UCI datasets (#53)
- [ ] COMPAS
- [ ] MEPS
- [ ] Implement metrics as individual functions instead of instance methods
- [x] Make certain metrics compatible as sklearn scorers
- [x] Use "prot_attr" and "priv_group" keywords to specify protected attributes to
functions
- [x] Generalized confusion matrix
- [ ] Sample distortion metrics
- [ ] Make inprocessing algorithms compatible as sklearn `Estimator`s
- [x] Adversarial debiasing
- [ ] **[External]** `get_feature_names()` from data preprocessing
steps that would remove DataFrame formatting
- [ ] [SLEP007](https://github.com/scikit-learn/enhancement_proposals/pull/17)/[SLEP008](https://github.com/scikit-learn/enhancement_proposals/pull/18) - feature names
- [ ] Prejudice remover
- [ ] Meta-fair classifier
- [ ] Make preprocessing algorithms compatible as sklearn `Transformer`s
- [ ] **[External]** Add functionality to modify X and y
- [ ] [SLEP005](https://github.com/scikit-learn/enhancement_proposals/pull/15) - Resampler API (see discussion; meta-estimator workaround may be enough)
- [ ] Disparate impact remover
- [ ] Learning fair representations
- [ ] Optimized preprocessing
- [X] Reweighing
- [X] Meta-estimator workaround
- [ ] **[External]** [SLEP006](https://github.com/scikit-learn/enhancement_proposals/pull/16) - Sample properties (meta-estimator works but would be very nice to have)
- [ ] Make postprocessing algorithms compatible
- [x] Calibrated equalized odds postprocessing
- [x] Meta-estimator workaround again
- [ ] Equalized odds postprocessing
- [ ] Reject option classification
- [ ] Miscellaneous:
- [ ] Explainers
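A minimal sketch of the workflow this README targets. `fetch_compas` is added in this PR; the metric name `statistical_parity_difference` and its exact signature are assumptions based on the `prot_attr`/`priv_group` convention listed above.

```python
from sklearn.linear_model import LogisticRegression

from aif360.sklearn.datasets import fetch_compas
# Assumed metric name/signature -- not shown in this diff.
from aif360.sklearn.metrics import statistical_parity_difference

# Protected attributes ('sex', 'race') live in the index of X and y.
X, y = fetch_compas(numeric_only=True, binary_race=True)

y_pred = LogisticRegression(max_iter=1000).fit(X, y).predict(X)
print(statistical_parity_difference(y, y_pred, prot_attr='race'))
```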
Empty file added aif360/sklearn/__init__.py
13 changes: 13 additions & 0 deletions aif360/sklearn/datasets/__init__.py
@@ -0,0 +1,13 @@
"""
The dataset format for ``aif360.sklearn`` is a :class:`pandas.DataFrame` with
protected attributes in the index.

Warning:
Currently, while all scikit-learn classes will accept DataFrames as inputs,
most classes will return a :class:`numpy.ndarray`. Therefore, many pre-
processing steps, when placed before an ``aif360.sklearn`` step in a
Pipeline, will cause errors.
"""
from aif360.sklearn.datasets.utils import *
from aif360.sklearn.datasets.openml_datasets import *
from aif360.sklearn.datasets.compas_dataset import fetch_compas
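A toy illustration of the format this docstring describes, with made-up values: protected attributes sit in the (Multi)Index of both `X` and `y`.

```python
import pandas as pd

idx = pd.MultiIndex.from_arrays(
    [['Female', 'Male', 'Male'],
     ['Caucasian', 'African-American', 'Caucasian']],
    names=['sex', 'race'])
X = pd.DataFrame({'age': [25, 40, 33], 'priors_count': [0, 2, 1]}, index=idx)
y = pd.Series([0, 1, 0], index=idx, name='two_year_recid')

# Workaround for the warning above: re-attach the index after a scikit-learn
# step that returns an ndarray, e.g.
# X_scaled = pd.DataFrame(scaler.transform(X), index=X.index, columns=X.columns)
```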
83 changes: 83 additions & 0 deletions aif360/sklearn/datasets/compas_dataset.py
@@ -0,0 +1,83 @@
import os

import pandas as pd

from aif360.sklearn.datasets.utils import standardize_dataset


# cache location
DATA_HOME_DEFAULT = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'..', 'data', 'raw')
COMPAS_URL = 'https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv'

def fetch_compas(data_home=None, binary_race=False,
usecols=['sex', 'age', 'age_cat', 'race', 'juv_fel_count',
'juv_misd_count', 'juv_other_count', 'priors_count',
'c_charge_degree', 'c_charge_desc'],
dropcols=[], numeric_only=False, dropna=True):
"""Load the COMPAS Recidivism Risk Scores dataset.

Optionally binarizes 'race' to 'Caucasian' (privileged) or
'African-American' (unprivileged). The other protected attribute is 'sex'
('Male' is *unprivileged* and 'Female' is *privileged*). The outcome
variable is 'Survived' (favorable) if the person was not accused of a crime
within two years or 'Recidivated' (unfavorable) if they were.

Note:
The values for the 'sex' variable if numeric_only is ``True`` are 1 for
'Female' and 0 for 'Male' -- opposite the convention of other datasets.

Args:
data_home (string, optional): Specify another download and cache folder
for the datasets. By default all AIF360 datasets are stored in
'aif360/sklearn/data/raw' subfolders.
binary_race (bool, optional): Filter only White and Black defendants.
usecols (single label or list-like, optional): Feature column(s) to
keep. All others are dropped.
dropcols (single label or list-like, optional): Feature column(s) to
drop.
numeric_only (bool): Drop all non-numeric feature columns.
dropna (bool): Drop rows with NAs.

Returns:
namedtuple: Tuple containing X and y for the COMPAS dataset accessible
by index or name.
"""
cache_path = os.path.join(data_home or DATA_HOME_DEFAULT,
os.path.basename(COMPAS_URL))
if os.path.isfile(cache_path):
df = pd.read_csv(cache_path, index_col='id')
else:
df = pd.read_csv(COMPAS_URL, index_col='id')
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
df.to_csv(cache_path)

# Perform the same preprocessing as the original analysis:
# https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb
df = df[(df.days_b_screening_arrest <= 30)
& (df.days_b_screening_arrest >= -30)
& (df.is_recid != -1)
& (df.c_charge_degree != 'O')
& (df.score_text != 'N/A')]

for col in ['sex', 'age_cat', 'race', 'c_charge_degree', 'c_charge_desc']:
df[col] = df[col].astype('category')

# 'Survived' < 'Recidivated'
cats = ['Survived', 'Recidivated']
df.two_year_recid = df.two_year_recid.replace([0, 1], cats).astype('category')
df.two_year_recid = df.two_year_recid.cat.set_categories(cats, ordered=True)

if binary_race:
# 'African-American' < 'Caucasian'
df.race = df.race.cat.set_categories(['African-American', 'Caucasian'],
ordered=True)

# 'Male' < 'Female'
df.sex = df.sex.astype('category').cat.reorder_categories(
['Male', 'Female'], ordered=True)

return standardize_dataset(df, prot_attr=['sex', 'race'],
target='two_year_recid', usecols=usecols,
dropcols=dropcols, numeric_only=numeric_only,
dropna=dropna)
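A brief usage sketch for the loader above; the commented output is illustrative, not verified here.

```python
from aif360.sklearn.datasets import fetch_compas

X, y = fetch_compas(binary_race=True)  # downloads and caches on first call
print(X.index.names)                   # protected attributes, e.g. ['sex', 'race']
print(y.cat.categories.tolist())       # ['Survived', 'Recidivated']
```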