Skip to content

Commit

Permalink
made foreign_worker and education (bank) ordered
Browse files Browse the repository at this point in the history
  • Loading branch information
hoffmansc committed Feb 6, 2020
1 parent 57b2ab5 commit 8fdd6dc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions aif360/sklearn/datasets/openml_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def fetch_german(data_home=None, binary_age=True, usecols=[], dropcols=[],
df = df.join(personal_status.astype('category'))
df.sex = df.sex.cat.as_ordered() # 'female' < 'male'

+ # 'no' < 'yes'
+ df.foreign_worker = df.foreign_worker.astype('category').cat.as_ordered()
+
return standardize_dataset(df, prot_attr=['sex', age, 'foreign_worker'],
target='credit-risk', usecols=usecols,
dropcols=dropcols, numeric_only=numeric_only,
Expand Down Expand Up @@ -215,6 +218,9 @@ def fetch_bank(data_home=None, percent10=False, usecols=[], dropcols='duration',
# replace 'unknown' marker with NaN
df.apply(lambda s: s.cat.remove_categories('unknown', inplace=True)
if hasattr(s, 'cat') and 'unknown' in s.cat.categories else s)
+ # 'primary' < 'secondary' < 'tertiary'
+ df.education = df.education.astype('category').cat.as_ordered()
+
return standardize_dataset(df, prot_attr='age', target='deposit',
usecols=usecols, dropcols=dropcols,
numeric_only=numeric_only, dropna=dropna)
4 changes: 2 additions & 2 deletions tests/sklearn/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def test_fetch_german():
german = fetch_german()
assert len(german) == 2
assert german.X.shape == (1000, 21)
- assert fetch_german(numeric_only=True).X.shape == (1000, 8)
+ assert fetch_german(numeric_only=True).X.shape == (1000, 9)

def test_fetch_bank():
bank = fetch_bank()
assert len(bank) == 2
assert bank.X.shape == (45211, 15)
assert fetch_bank(dropcols=[]).X.shape == (45211, 16)
- assert fetch_bank(numeric_only=True).X.shape == (45211, 6)
+ assert fetch_bank(numeric_only=True).X.shape == (45211, 7)

@pytest.mark.filterwarnings('error', category=ColumnAlreadyDroppedWarning)
def test_fetch_compas():
Expand Down

0 comments on commit 8fdd6dc

Please sign in to comment.