Skip to content

Commit

Permalink
check german and lawschool match old
Browse files Browse the repository at this point in the history
Signed-off-by: Samuel Hoffman <hoffman.sc@gmail.com>
  • Loading branch information
hoffmansc committed Jan 5, 2022
1 parent cf0c6c3 commit 6a7bbef
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions tests/sklearn/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import numpy as np
from numpy.testing import assert_array_equal
import pandas as pd
from pandas.api.types import is_numeric_dtype
from pandas.testing import assert_frame_equal
import pytest
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OneHotEncoder, minmax_scale

from aif360.datasets import (
AdultDataset, CompasDataset, MEPSDataset19, MEPSDataset20, MEPSDataset21)
AdultDataset, GermanDataset, CompasDataset, LawSchoolGPADataset,
MEPSDataset19, MEPSDataset20, MEPSDataset21)
from aif360.sklearn.datasets import (
standardize_dataset, NumericConversionWarning, fetch_adult, fetch_bank,
fetch_german, fetch_compas, fetch_lawschool_gpa, fetch_meps)
Expand Down Expand Up @@ -146,6 +148,37 @@ def test_fetch_german():
assert german.X.shape == (1000, 21)
assert fetch_german(numeric_only=True).X.shape == (1000, 9)

def test_german_matches_old():
    """Checks that the new German Credit features agree with the legacy dataset.

    The frame returned by ``fetch_german`` is reshaped to look like the
    legacy ``GermanDataset`` frame (columns dropped and renamed, non-numeric
    values turned into integer codes) and the two are then compared with
    ``assert_frame_equal``.
    """
    # new column name -> legacy column name
    new_to_old = dict(
        checking_status='status',
        duration='month',
        savings_status='savings',
        installment_commitment='investment_as_income_percentage',
        other_parties='other_debtors',
        property_magnitude='property',
        other_payment_plans='installment_plans',
        existing_credits='number_of_credits',
        job='skill_level',
        num_dependents='people_liable_for',
        own_telephone='telephone',
    )

    def as_codes(col):
        # Category labels differ between the two versions, so non-numeric
        # columns are replaced by integer codes; numeric columns pass through.
        return col if is_numeric_dtype(col) else col.factorize()[0]

    features, _ = fetch_german()
    # The legacy data had no marital status and used a binary age, so drop
    # both columns and pull the 'age' index level in as a column instead.
    features = features.drop(columns=['marital_status', 'age'])
    features = features.reset_index('age')
    # bring the column names in line with the legacy naming
    features = features.rename(columns=new_to_old)

    legacy = GermanDataset()
    legacy_df = legacy.convert_to_dataframe(de_dummy_code=True)[0]
    legacy_df = legacy_df.drop(columns=legacy.label_names)

    features = features.apply(as_codes)
    legacy_df = legacy_df.apply(as_codes)

    # check_like=True: ordering of columns/index is ignored in the comparison
    assert_frame_equal(
        features.reset_index(drop=True),
        legacy_df.reset_index(drop=True),
        check_like=True,
    )

def test_fetch_bank():
"""Tests Bank Marketing dataset shapes with various options."""
bank = fetch_bank()
Expand All @@ -154,6 +187,8 @@ def test_fetch_bank():
assert fetch_bank(dropcols=None).X.shape == (45211, 16)
assert fetch_bank(numeric_only=True).X.shape == (45211, 7)

# TODO: bank doesn't match old

@pytest.mark.filterwarnings('ignore', category=NumericConversionWarning)
def test_fetch_compas():
"""Tests COMPAS Recidivism dataset shapes with various options."""
Expand Down Expand Up @@ -183,6 +218,15 @@ def test_fetch_lawschool_gpa():
assert gpa.y.nunique() > 2 # regression
assert fetch_lawschool_gpa(numeric_only=True, dropna=False).X.shape == (22342, 3)

def test_lawschool_matches_old():
    """Checks that the new Law School GPA features agree with the legacy dataset.

    The numeric-only features from ``fetch_lawschool_gpa`` are min-max scaled
    and compared element-wise against the legacy ``LawSchoolGPADataset``
    frame with its label column(s) removed.
    """
    features, _ = fetch_lawschool_gpa(numeric_only=True)

    legacy = LawSchoolGPADataset()
    legacy_df = legacy.convert_to_dataframe()[0]
    legacy_features = legacy_df.drop(columns=legacy.label_names)

    # NOTE(review): presumably the legacy dataset is already min-max scaled,
    # which is why only the new features are scaled here -- confirm.
    scaled = minmax_scale(features)
    assert_array_equal(scaled, legacy_features)

@pytest.mark.parametrize("panel", [19, 20, 21])
def test_cache_meps(panel):
"""Tests if cached MEPS matches raw."""
Expand Down

0 comments on commit 6a7bbef

Please sign in to comment.