diff --git a/aif360/sklearn/datasets/__init__.py b/aif360/sklearn/datasets/__init__.py index cd475d14..525b1431 100644 --- a/aif360/sklearn/datasets/__init__.py +++ b/aif360/sklearn/datasets/__init__.py @@ -8,7 +8,8 @@ processing steps, when placed before an ``aif360.sklearn`` step in a Pipeline, will cause errors. """ -from aif360.sklearn.datasets.utils import * -from aif360.sklearn.datasets.openml_datasets import * +from aif360.sklearn.datasets.utils import standardize_dataset, NumericConversionWarning +from aif360.sklearn.datasets.openml_datasets import fetch_adult, fetch_german, fetch_bank from aif360.sklearn.datasets.compas_dataset import fetch_compas -from aif360.sklearn.datasets.tempeh_datasets import * +from aif360.sklearn.datasets.meps_datasets import fetch_meps +from aif360.sklearn.datasets.tempeh_datasets import fetch_lawschool_gpa diff --git a/aif360/sklearn/datasets/meps_datasets.py b/aif360/sklearn/datasets/meps_datasets.py new file mode 100644 index 00000000..12fc5cca --- /dev/null +++ b/aif360/sklearn/datasets/meps_datasets.py @@ -0,0 +1,124 @@ +from io import BytesIO +import os +from zipfile import ZipFile + +import pandas as pd +import requests + +from aif360.sklearn.datasets.utils import standardize_dataset + + +# cache location +DATA_HOME_DEFAULT = os.path.join(os.path.dirname(os.path.abspath(__file__)), + '..', 'data', 'raw') +MEPS_URL = "https://meps.ahrq.gov/mepsweb/data_files/pufs" +PROMPT = """ +By using this function you acknowledge the responsibility for reading and +abiding by any copyright/usage rules and restrictions as stated on the MEPS web +site (https://meps.ahrq.gov/data_stats/data_use.jsp). + +Continue [y/n]? > """ + +def fetch_meps(panel, *, accept_terms=None, data_home=None, cache=True, + usecols=['REGION', 'AGE', 'SEX', 'RACE', 'MARRY', 'FTSTU', + 'ACTDTY', 'HONRDC', 'RTHLTH', 'MNHLTH', 'HIBPDX', + 'CHDDX', 'ANGIDX', 'MIDX', 'OHRTDX', 'STRKDX', 'EMPHDX', + 'CHBRON', 'CHOLDX', 'CANCERDX', 'DIABDX', 'JTPAIN', + 'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX', 'PREGNT', + 'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42', + 'DFSEE42', 'ADSMOK42', 'PCS42', 'MCS42', 'K6SUM42', + 'PHQ242', 'EMPST', 'POVCAT', 'INSCOV'], + dropcols=None, numeric_only=False, dropna=True): + """Load the Medical Expenditure Panel Survey (MEPS) dataset. + + Args: + panel ({19, 20, 21}): Panel number (only 19, 20, and 21 are currently + supported). + accept_terms (bool, optional): Bypass terms prompt. Note: by setting + this to ``True``, you acknowledge responsibility for reading and + accepting the MEPS usage terms. + data_home (string, optional): Specify another download and cache folder + for the datasets. By default all AIF360 datasets are stored in + 'aif360/sklearn/data/raw' subfolders. + cache (bool): Whether to cache downloaded datasets. + usecols (single label or list-like, optional): Feature column(s) to + keep. All others are dropped. + dropcols (single label or list-like, optional): Feature column(s) to + drop. + numeric_only (bool): Drop all non-numeric feature columns. + dropna (bool): Drop rows with NAs. + + Returns: + namedtuple: Tuple containing X and y for the MEPS dataset accessible by + index or name. + """ + if panel not in {19, 20, 21}: + raise ValueError("only panels 19, 20, and 21 are currently supported.") + + fname = 'h192' if panel == 21 else 'h181' + data_url = os.path.join(MEPS_URL, fname + 'ssp.zip') + cache_path = os.path.join(data_home or DATA_HOME_DEFAULT, fname + '.csv') + if cache and os.path.isfile(cache_path): + df = pd.read_csv(cache_path) + else: + # skip prompt if user chooses + accept = accept_terms or input(PROMPT) + if accept != 'y' and accept != True: + raise PermissionError("Terms not agreed.") + rawz = requests.get(os.path.join(MEPS_URL, fname + 'ssp.zip')).content + with ZipFile(BytesIO(rawz)) as zf: + with zf.open(fname + '.ssp') as ssp: + df = pd.read_sas(ssp, format='xport') + # TODO: does this cause any differences? + # reduce storage size + df = df.apply(pd.to_numeric, errors='ignore', downcast='integer') + if cache: + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + df.to_csv(cache_path, index=None) + # restrict to correct panel + df = df[df['PANEL'] == panel] + # change all 15s to 16s if panel == 21 + yr = 16 if panel == 21 else 15 + + # non-Hispanic Whites are marked as WHITE; all others as NON-WHITE + df['RACEV2X'] = (df['HISPANX'] == 2) & (df['RACEV2X'] == 1) + + # rename all columns that are panel/round-specific + df = df.rename(columns={ + 'FTSTU53X': 'FTSTU', 'ACTDTY53': 'ACTDTY', 'HONRDC53': 'HONRDC', + 'RTHLTH53': 'RTHLTH', 'MNHLTH53': 'MNHLTH', 'CHBRON53': 'CHBRON', + 'JTPAIN53': 'JTPAIN', 'PREGNT53': 'PREGNT', 'WLKLIM53': 'WLKLIM', + 'ACTLIM53': 'ACTLIM', 'SOCLIM53': 'SOCLIM', 'COGLIM53': 'COGLIM', + 'EMPST53': 'EMPST', 'REGION53': 'REGION', 'MARRY53X': 'MARRY', + 'AGE53X': 'AGE', f'POVCAT{yr}': 'POVCAT', f'INSCOV{yr}': 'INSCOV', + f'PERWT{yr}F': 'PERWT', 'RACEV2X': 'RACE'}) + + df.loc[df.AGE < 0, 'AGE'] = None # set invalid ages to NaN + cat_cols = ['REGION', 'SEX', 'RACE', 'MARRY', 'FTSTU', 'ACTDTY', 'HONRDC', + 'RTHLTH', 'MNHLTH', 'HIBPDX', 'CHDDX', 'ANGIDX', 'MIDX', + 'OHRTDX', 'STRKDX', 'EMPHDX', 'CHBRON', 'CHOLDX', 'CANCERDX', + 'DIABDX', 'JTPAIN', 'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX', + 'PREGNT', 'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42', + 'DFSEE42', 'ADSMOK42', 'PHQ242', 'EMPST', 'POVCAT', 'INSCOV', + 'EDUCYR', 'HIDEG'] # TODO: why are these included here but not in usecols? + for col in cat_cols: + df[col] = df[col].astype('category') + thresh = 0 if col in ['REGION', 'MARRY', 'ASTHDX'] else -1 + na_cats = [c for c in df[col].cat.categories if c < thresh] + df[col] = df[col].cat.remove_categories(na_cats) # set NaN cols to NaN + + df['SEX'] = df['SEX'].cat.rename_categories({1: 'Male', 2: 'Female'}) + df['RACE'] = df['RACE'].cat.rename_categories({False: 'Non-White', True: 'White'}) + df['RACE'] = df['RACE'].cat.reorder_categories(['Non-White', 'White'], ordered=True) + + # Compute UTILIZATION, binarize it to 0 (< 10) and 1 (>= 10) + cols = [f'OBTOTV{yr}', f'OPTOTV{yr}', f'ERTOT{yr}', f'IPNGTD{yr}', f'HHTOTD{yr}'] + util = df[cols].sum(axis=1) + df['UTILIZATION'] = pd.cut(util, [min(util)-1, 10, max(util)+1], right=False, + labels=['< 10 Visits', '>= 10 Visits'])#['low', 'high']) + + # TODO: let standardize_dataset handle dropna (see above todo re: extra cols) + return standardize_dataset(df.dropna(), prot_attr='RACE', target='UTILIZATION', + sample_weight='PERWT', usecols=usecols, + dropcols=dropcols, numeric_only=numeric_only, + dropna=dropna) diff --git a/tests/sklearn/test_datasets.py b/tests/sklearn/test_datasets.py index 9f26130e..994a80a9 100644 --- a/tests/sklearn/test_datasets.py +++ b/tests/sklearn/test_datasets.py @@ -6,9 +6,10 @@ from sklearn.compose import make_column_transformer from sklearn.preprocessing import OneHotEncoder +from aif360.datasets import MEPSDataset19, MEPSDataset20, MEPSDataset21 from aif360.sklearn.datasets import ( standardize_dataset, NumericConversionWarning, fetch_adult, fetch_bank, - fetch_german, fetch_compas, fetch_lawschool_gpa) + fetch_german, fetch_compas, fetch_lawschool_gpa, fetch_meps) df = pd.DataFrame([[1, 2, 3, 'a'], [5, 6, 7, 'b'], [np.NaN, 10, 11, 'c']], @@ -155,15 +156,26 @@ def test_fetch_lawschool_gpa(): assert len(gpa) == 2 assert gpa.X.shape == (22342, 3) assert gpa.y.nunique() > 2 # regression - assert fetch_lawschool_gpa(numeric_only=True, dropna=True).X.shape == (22342, 3) + assert fetch_lawschool_gpa(numeric_only=True, dropna=False).X.shape == (22342, 3) + +@pytest.mark.parametrize("panel, cls", [(19, MEPSDataset19), (20, MEPSDataset20), (21, MEPSDataset21)]) +def test_fetch_meps(panel, cls): + """Tests MEPS datasets shapes with various options.""" + meps = fetch_meps(panel, cache=False, accept_terms=True) + assert len(meps) == 3 + meps.X.RACE = meps.X.RACE.factorize(sort=True)[0] + MEPS = cls() + assert all(pd.get_dummies(meps.X) == MEPS.features) + assert all(meps.y.factorize(sort=True)[0] == MEPS.labels.ravel()) + + # assert fetch_meps(panel, dropna=False).X.shape == () def test_onehot_transformer(): """Tests that categorical features can be correctly one-hot encoded.""" X, y = fetch_german() - assert pd.get_dummies(X).shape[1] == 64 - # XXX: 'purpose' col contains unused category 'vacation' - X.purpose = X.purpose.cat.remove_unused_categories() - assert pd.get_dummies(X).shape[1] == 63 - assert make_column_transformer((OneHotEncoder(), X.dtypes == 'category'), - remainder='passthrough').fit_transform(X).shape[1] == 63 - + ohe = make_column_transformer( + (OneHotEncoder(), X.dtypes == 'category'), + remainder='passthrough', verbose_feature_names_out=False) + dum = pd.get_dummies(X) + assert ohe.fit_transform(X).shape[1] == dum.shape[1] == 63 + assert dum.columns.symmetric_difference(ohe.get_feature_names_out()).empty