Skip to content

Commit

Permalink
added comments to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hoffmansc committed Feb 19, 2020
1 parent 1b829d7 commit 1002610
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 12 deletions.
8 changes: 8 additions & 0 deletions tests/sklearn/test_adversarial_debiasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'hours-per-week'], features_to_drop=[])

def test_adv_debias_old_reproduce():
"""Test that the old AdversarialDebiasing is reproducible."""
sess = tf.Session()
old_adv_deb = OldAdversarialDebiasing(unprivileged_groups=[{'sex': 0}],
privileged_groups=[{'sex': 1}],
Expand All @@ -34,6 +35,8 @@ def test_adv_debias_old_reproduce():
assert np.allclose(old_preds.labels, old_preds2.labels)

def test_adv_debias_old():
"""Test that the predictions of the old and new AdversarialDebiasing match.
"""
tf.reset_default_graph()
sess = tf.Session()
old_adv_deb = OldAdversarialDebiasing(unprivileged_groups=[{'sex': 0}],
Expand All @@ -48,6 +51,7 @@ def test_adv_debias_old():
assert np.allclose(old_preds.labels.flatten(), new_preds)

def test_adv_debias_reproduce():
"""Test that the new AdversarialDebiasing is reproducible."""
adv_deb = AdversarialDebiasing('sex', num_epochs=5, random_state=123)
new_preds = adv_deb.fit(X, y).predict(X)
adv_deb.sess_.close()
Expand All @@ -60,12 +64,16 @@ def test_adv_debias_reproduce():
assert new_acc == accuracy_score(y, new_preds)

def test_adv_debias_intersection():
    """Fit AdversarialDebiasing without naming a single protected attribute.

    The estimator is built with only ``scope_name`` and ``num_epochs``; after
    fitting, the adversary logits are expected to have four columns.
    """
    model = AdversarialDebiasing(scope_name='intersect', num_epochs=5)
    model.fit(X, y)
    # Close the fitted estimator's session before checking the result.
    model.sess_.close()
    n_adversary_outputs = model.adversary_logits_.shape[1]
    assert n_adversary_outputs == 4

def test_adv_debias_grid():
"""Test that the new AdversarialDebiasing works in a grid search (and that
debiasing results in reduced accuracy).
"""
adv_deb = AdversarialDebiasing('sex', num_epochs=10, random_state=123)

params = {'debias': [True, False]}
Expand Down
12 changes: 11 additions & 1 deletion tests/sklearn/test_calibrated_equalized_odds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
'hours-per-week'], features_to_drop=[])

def test_calib_eq_odds_sex_weighted():
"""Test that the old and new CalibratedEqualizedOdds produce the same mix
rates.
"""
logreg = LogisticRegression(solver='lbfgs', max_iter=500)
y_pred = logreg.fit(X, y, sample_weight=sample_weight).predict_proba(X)
adult_pred = adult.copy()
Expand All @@ -28,6 +31,12 @@ def test_calib_eq_odds_sex_weighted():
assert np.isclose(orig_cal_eq_odds.unpriv_mix_rate, cal_eq_odds.mix_rates_[0])

def test_postprocessingmeta_fnr():
"""Test that the old and new CalibratedEqualizedOdds produce the same
probability predictions.
This tests the whole "pipeline": splitting the data the same way, training a
LogisticRegression classifier, and training the post-processor the same way.
"""
adult_train, adult_test = adult.split([0.9], shuffle=False)
X_tr, X_te, y_tr, _, sw_tr, _ = train_test_split(X, y, sample_weight,
train_size=0.9, shuffle=False)
Expand All @@ -52,7 +61,8 @@ def test_postprocessingmeta_fnr():
orig_cal_eq_odds.fit(adult_post, adult_pred)

cal_eq_odds = PostProcessingMeta(estimator=logreg,
postprocessor=CalibratedEqualizedOdds('sex', cost_constraint='fnr', random_state=0),
postprocessor=CalibratedEqualizedOdds('sex', cost_constraint='fnr',
random_state=0),
shuffle=False)
cal_eq_odds.fit(X_tr, y_tr, sample_weight=sw_tr)

Expand Down
10 changes: 10 additions & 0 deletions tests/sklearn/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
dropna=False)

def test_standardize_dataset_basic():
"""Tests standardize_dataset on a toy example."""
dataset = basic()
X, y = dataset
X, y = dataset.X, dataset.y
Expand All @@ -28,11 +29,13 @@ def test_standardize_dataset_basic():
assert X.shape == (3, 3)

def test_sample_weight_basic():
    """Check the toy dataset when column 'X2' is used as sample_weight."""
    weighted = basic(sample_weight='X2')
    # Three fields are expected once a sample_weight is requested.
    assert len(weighted) == 3
    assert weighted.X.shape == (3, 2)

def test_usecols_dropcols_basic():
"""Tests various combinations of usecols and dropcols on a toy example."""
assert basic(usecols='X1').X.columns.tolist() == ['X1']
assert basic(usecols=['X1', 'Z']).X.columns.tolist() == ['X1', 'Z']

Expand All @@ -44,30 +47,35 @@ def test_usecols_dropcols_basic():
pd.DataFrame)

def test_dropna_basic():
    """Check feature shapes on the toy example with ``dropna=True``.

    Also checks the shape returned by ``basic`` when column 'X1' is dropped.
    """
    dropna_loader = partial(standardize_dataset, df=df, prot_attr='Z',
                            target='y', dropna=True)
    with_dropna = dropna_loader()
    assert with_dropna.X.shape == (2, 3)

    without_x1 = basic(dropcols='X1')
    assert without_x1.X.shape == (3, 2)

def test_numeric_only_basic():
    """Check ``numeric_only=True`` on the toy example.

    Both calls use 'X2' as the protected attribute; the second one also drops
    column 'Z'. The same feature shape is expected in both cases.
    """
    expected_shape = (3, 2)

    numeric = basic(prot_attr='X2', numeric_only=True)
    assert numeric.X.shape == expected_shape

    numeric_without_z = basic(prot_attr='X2', dropcols='Z', numeric_only=True)
    assert numeric_without_z.X.shape == expected_shape

def test_fetch_adult():
    """Check Adult Income dataset shapes for several loader options."""
    default = fetch_adult()
    # The default result is expected to hold three fields.
    assert len(default) == 3
    assert default.X.shape == (45222, 13)

    with_missing = fetch_adult(dropna=False)
    assert with_missing.X.shape == (48842, 13)

    numeric = fetch_adult(numeric_only=True)
    assert numeric.X.shape == (48842, 7)

def test_fetch_german():
    """Check German Credit dataset shapes for several loader options."""
    default = fetch_german()
    # The default result is expected to hold two fields.
    assert len(default) == 2
    assert default.X.shape == (1000, 21)

    numeric = fetch_german(numeric_only=True)
    assert numeric.X.shape == (1000, 9)

def test_fetch_bank():
"""Tests Bank Marketing dataset shapes with various options."""
bank = fetch_bank()
assert len(bank) == 2
assert bank.X.shape == (45211, 15)
Expand All @@ -76,6 +84,7 @@ def test_fetch_bank():

@pytest.mark.filterwarnings('error', category=ColumnAlreadyDroppedWarning)
def test_fetch_compas():
"""Tests COMPAS Recidivism dataset shapes with various options."""
compas = fetch_compas()
assert len(compas) == 2
assert compas.X.shape == (6167, 10)
Expand All @@ -84,5 +93,6 @@ def test_fetch_compas():
assert fetch_compas(numeric_only=True).X.shape == (6172, 6)

def test_onehot_transformer():
    """Check the column count after one-hot encoding the German features."""
    features, _ = fetch_german()
    encoded = pd.get_dummies(features)
    assert len(encoded.columns) == 63
14 changes: 14 additions & 0 deletions tests/sklearn/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,61 +29,75 @@
privileged_groups=[{'sex': 1}])

def test_dataset_equality():
    """Check that `adult` and (X, y) hold element-wise equal data."""
    features_match = (adult.features == X.values).all()
    assert features_match
    labels_match = (adult.labels.ravel() == y).all()
    assert labels_match

def test_consistency():
    """Check consistency_score against cm.consistency() within tolerance."""
    computed = consistency_score(X, y)
    reference = cm.consistency()
    assert np.isclose(computed, reference)

def test_specificity():
    """Check that weighted specificity_score equals cm.specificity()."""
    computed = specificity_score(y, y_pred, sample_weight=sample_weight)
    reference = cm.specificity()
    assert computed == reference

def test_base_rate():
    """Check that weighted base_rate equals cm.base_rate()."""
    computed = base_rate(y, y_pred, sample_weight=sample_weight)
    reference = cm.base_rate()
    assert computed == reference

def test_selection_rate():
    """Check that weighted selection_rate equals cm.selection_rate()."""
    computed = selection_rate(y, y_pred, sample_weight=sample_weight)
    reference = cm.selection_rate()
    assert computed == reference

def test_generalized_fpr():
    """Check weighted generalized_fpr against `cm` within tolerance.

    The reference value is cm.generalized_false_positive_rate().
    """
    computed = generalized_fpr(y, y_proba, sample_weight=sample_weight)
    reference = cm.generalized_false_positive_rate()
    assert np.isclose(computed, reference)

def test_generalized_fnr():
    """Check weighted generalized_fnr against `cm` within tolerance.

    The reference value is cm.generalized_false_negative_rate().
    """
    computed = generalized_fnr(y, y_proba, sample_weight=sample_weight)
    reference = cm.generalized_false_negative_rate()
    assert np.isclose(computed, reference)

def test_disparate_impact():
    """Check that disparate_impact_ratio on 'sex' equals the `cm` value."""
    computed = disparate_impact_ratio(
        y, y_pred, prot_attr='sex', sample_weight=sample_weight)
    reference = cm.disparate_impact()
    assert computed == reference

def test_statistical_parity():
    """Check statistical_parity_difference on 'sex' equals the `cm` value."""
    computed = statistical_parity_difference(
        y, y_pred, prot_attr='sex', sample_weight=sample_weight)
    reference = cm.statistical_parity_difference()
    assert computed == reference

def test_equal_opportunity():
    """Check equal_opportunity_difference on 'sex' equals the `cm` value."""
    computed = equal_opportunity_difference(
        y, y_pred, prot_attr='sex', sample_weight=sample_weight)
    reference = cm.equal_opportunity_difference()
    assert computed == reference

def test_average_odds_difference():
    """Check average_odds_difference on 'sex' against `cm` within tolerance."""
    computed = average_odds_difference(
        y, y_pred, prot_attr='sex', sample_weight=sample_weight)
    reference = cm.average_odds_difference()
    assert np.isclose(computed, reference)

def test_average_odds_error():
    """Check average_odds_error on 'sex' against `cm` within tolerance.

    The reference value is cm.average_abs_odds_difference().
    """
    computed = average_odds_error(
        y, y_pred, prot_attr='sex', sample_weight=sample_weight)
    reference = cm.average_abs_odds_difference()
    assert np.isclose(computed, reference)

def test_generalized_entropy_index():
    """Check generalized_entropy_error against `cm` within tolerance.

    The reference value is cm.generalized_entropy_index().
    """
    computed = generalized_entropy_error(y, y_pred)
    reference = cm.generalized_entropy_index()
    assert np.isclose(computed, reference)

def test_between_group_generalized_entropy_index():
    """Check between-group generalized entropy on 'sex' equals `cm`.

    The reference value is cm.between_group_generalized_entropy_index().
    """
    computed = between_group_generalized_entropy_error(
        y, y_pred, prot_attr='sex')
    reference = cm.between_group_generalized_entropy_index()
    assert computed == reference
17 changes: 6 additions & 11 deletions tests/sklearn/test_reweighing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,32 @@
from aif360.sklearn.preprocessing import Reweighing, ReweighingMeta


# X, y = fetch_german(numeric_only=True, dropcols='duration')
# X.age = (X.age >= 25).astype('int')
# german = GermanDataset(categorical_features=[], features_to_keep=[
# 'credit_amount', 'investment_as_income_percentage', 'residence_since',
# 'age', 'number_of_credits', 'people_liable_for', 'sex'])
# Module-level fixtures shared by the tests below: the numeric-only Adult data
# as (X, y, sample_weight), and an AdultDataset restricted to five numeric
# features with 'fnlwgt' as the instance weights.
X, y, sample_weight = fetch_adult(numeric_only=True)
adult = AdultDataset(
    instance_weights_name='fnlwgt',
    categorical_features=[],
    features_to_keep=['age', 'education-num', 'capital-gain', 'capital-loss',
                      'hours-per-week'],
    features_to_drop=[],
)

def test_reweighing_sex():
    """Check that OrigReweighing and Reweighing agree on the 'sex' attribute.

    Both the four reweighing factors and the resulting per-sample weights are
    compared within floating-point tolerance.
    """
    reference = OrigReweighing(unprivileged_groups=[{'sex': 0}],
                               privileged_groups=[{'sex': 1}])
    reweighed_adult = reference.fit_transform(adult)
    reweigher = Reweighing('sex')
    _, new_weights = reweigher.fit_transform(X, y, sample_weight=sample_weight)

    # Rows: unprivileged, privileged; columns: unfavorable, favorable.
    reference_factors = [[reference.w_up_unfav, reference.w_up_fav],
                         [reference.w_p_unfav, reference.w_p_fav]]
    assert np.allclose(reference_factors, reweigher.reweigh_factors_)
    assert np.allclose(reweighed_adult.instance_weights, new_weights)

def test_reweighing_intersection():
    """Fit Reweighing with no explicit protected attribute.

    The fitted reweighing factors are expected to have shape (4, 2).
    """
    reweigher = Reweighing()
    reweigher.fit_transform(X, y)
    factors_shape = reweigher.reweigh_factors_.shape
    assert factors_shape == (4, 2)

def test_gridsearch():
# logreg = LogisticRegression(solver='lbfgs', max_iter=500)
# rew = ReweighingMeta(estimator=logreg, reweigher=Reweighing('sex'))
"""Test that ReweighingMeta works in a grid search."""
rew = ReweighingMeta(estimator=LogisticRegression(solver='liblinear'))

# UGLY workaround for sklearn issue: https://stackoverflow.com/a/49598597
Expand All @@ -51,4 +47,3 @@ def score_func(y_true, y_pred, sample_weight):

clf = GridSearchCV(rew, params, scoring=scoring, cv=5, iid=False)
clf.fit(X, y, **{'sample_weight': sample_weight})
# print(clf.best_score_)

0 comments on commit 1002610

Please sign in to comment.