Dataset Improvements (#278)
* allow explicit arrays for prot_attr, target
* add MEPS and violent recidivism datasets
* option to skip cache
* binary_race only affects protected attribute unless numeric_only
* remove unused categories after dropping
* minimum python version >= 3.7; scikit-learn >= 1.0
hoffmansc committed Aug 25, 2022
1 parent faa75ee commit 3df4fa9
Showing 19 changed files with 1,913 additions and 1,438 deletions.
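
In user-facing terms, the headline changes look roughly like the sketch below. This is illustrative only, not code from the commit: the toy df and race objects are hypothetical, and the array-valued prot_attr call is one reading of the first bullet above.

import pandas as pd
from aif360.sklearn.datasets import fetch_meps, standardize_dataset

# New MEPS fetcher (panels 19, 20, 21), with the new option to skip caching.
X, y, sample_weight = fetch_meps(20, accept_terms=True, cache=False)

# prot_attr may now be an explicit array rather than a column label.
df = pd.DataFrame({'feat': [1.0, 2.0], 'outcome': [0, 1]})
race = pd.Series(['b', 'w'], index=df.index, name='race')
X, y = standardize_dataset(df, prot_attr=race, target='outcome')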
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -57,6 +57,7 @@ jobs:
wget ${UCI_DB}/statlog/german/german.data -P aif360/data/raw/german/
wget ${UCI_DB}/statlog/german/german.doc -P aif360/data/raw/german/
wget ${PROPUBLICA_GH}/compas-scores-two-years.csv -P aif360/data/raw/compas/
+(cd aif360/data/raw/meps;Rscript generate_data.R <<< y)
- name: Lint with flake8
run: |
7 changes: 4 additions & 3 deletions aif360/sklearn/datasets/__init__.py
@@ -8,7 +8,8 @@
processing steps, when placed before an ``aif360.sklearn`` step in a
Pipeline, will cause errors.
"""
-from aif360.sklearn.datasets.utils import *
-from aif360.sklearn.datasets.openml_datasets import *
+from aif360.sklearn.datasets.utils import standardize_dataset, NumericConversionWarning
+from aif360.sklearn.datasets.openml_datasets import fetch_adult, fetch_german, fetch_bank
from aif360.sklearn.datasets.compas_dataset import fetch_compas
-from aif360.sklearn.datasets.tempeh_datasets import *
+from aif360.sklearn.datasets.meps_datasets import fetch_meps
+from aif360.sklearn.datasets.tempeh_datasets import fetch_lawschool_gpa
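
With the star imports replaced by explicit names, the module's public surface is now enumerable. A minimal sketch of what remains importable:

from aif360.sklearn.datasets import (
    fetch_adult, fetch_bank, fetch_compas, fetch_german,
    fetch_lawschool_gpa, fetch_meps, standardize_dataset)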
37 changes: 28 additions & 9 deletions aif360/sklearn/datasets/compas_dataset.py
@@ -8,13 +8,14 @@
# cache location
DATA_HOME_DEFAULT = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'..', 'data', 'raw')
-COMPAS_URL = 'https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv'
+COMPAS_URL = 'https://raw.githubusercontent.com/propublica/compas-analysis/bafff5da3f2e45eca6c2d5055faad269defd135a/compas-scores-two-years.csv'
+COMPAS_VIOLENT_URL = 'https://raw.githubusercontent.com/propublica/compas-analysis/bafff5da3f2e45eca6c2d5055faad269defd135a/compas-scores-two-years-violent.csv'

-def fetch_compas(data_home=None, binary_race=False,
+def fetch_compas(subset='all', *, data_home=None, cache=True, binary_race=False,
usecols=['sex', 'age', 'age_cat', 'race', 'juv_fel_count',
'juv_misd_count', 'juv_other_count', 'priors_count',
'c_charge_degree', 'c_charge_desc'],
-dropcols=[], numeric_only=False, dropna=True):
+dropcols=None, numeric_only=False, dropna=True):
"""Load the COMPAS Recidivism Risk Scores dataset.
Optionally binarizes 'race' to 'Caucasian' (privileged) or
@@ -28,9 +29,14 @@ def fetch_compas(data_home=None, binary_race=False,
'Female' and 0 for 'Male' -- opposite the convention of other datasets.
Args:
+subset ({'all' or 'violent'}): Use the violent recidivism or full
+    version of the dataset. Note: 'violent' is not a strict subset of
+    'all' -- there are four samples in 'violent' which do not show up in
+    'all'.
data_home (string, optional): Specify another download and cache folder
for the datasets. By default all AIF360 datasets are stored in
'aif360/sklearn/data/raw' subfolders.
+cache (bool): Whether to cache downloaded datasets.
binary_race (bool, optional): Filter only White and Black defendants.
usecols (single label or list-like, optional): Feature column(s) to
keep. All others are dropped.
@@ -43,26 +49,39 @@ def fetch_compas(data_home=None, binary_race=False,
namedtuple: Tuple containing X and y for the COMPAS dataset accessible
by index or name.
"""
+if subset not in {'violent', 'all'}:
+    raise ValueError("subset must be either 'violent' or 'all'; cannot be "
+                     f"{subset}")
+
+data_url = COMPAS_VIOLENT_URL if subset == 'violent' else COMPAS_URL
cache_path = os.path.join(data_home or DATA_HOME_DEFAULT,
-os.path.basename(COMPAS_URL))
-if os.path.isfile(cache_path):
+os.path.basename(data_url))
+if cache and os.path.isfile(cache_path):
df = pd.read_csv(cache_path, index_col='id')
else:
-df = pd.read_csv(COMPAS_URL, index_col='id')
-os.makedirs(os.path.dirname(cache_path), exist_ok=True)
-df.to_csv(cache_path)
+df = pd.read_csv(data_url, index_col='id')
+if cache:
+    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
+    df.to_csv(cache_path)

# Perform the same preprocessing as the original analysis:
# https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb
df = df[(df.days_b_screening_arrest <= 30)
& (df.days_b_screening_arrest >= -30)
& (df.is_recid != -1)
& (df.c_charge_degree != 'O')
-& (df.score_text != 'N/A')]
+& (df['score_text' if subset == 'all' else 'v_score_text'] != 'N/A')]

for col in ['sex', 'age_cat', 'race', 'c_charge_degree', 'c_charge_desc']:
df[col] = df[col].astype('category')

# Misdemeanor < Felony
df.c_charge_degree = df.c_charge_degree.cat.reorder_categories(
['M', 'F'], ordered=True)
# 'Less than 25' < '25 - 45' < 'Greater than 45'
df.age_cat = df.age_cat.cat.reorder_categories(
['Less than 25', '25 - 45', 'Greater than 45'], ordered=True)

# 'Survived' < 'Recidivated'
cats = ['Survived', 'Recidivated']
df.two_year_recid = df.two_year_recid.replace([0, 1], cats).astype('category')
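
A hedged usage sketch of the updated fetcher (everything after subset is keyword-only in the new signature):

from aif360.sklearn.datasets import fetch_compas

# Full dataset; cache=False forces a fresh download (new in this commit).
X, y = fetch_compas(subset='all', cache=False)

# Violent-recidivism variant; per the docstring, not a strict subset of 'all'.
Xv, yv = fetch_compas(subset='violent', binary_race=True)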
132 changes: 132 additions & 0 deletions aif360/sklearn/datasets/meps_datasets.py
@@ -0,0 +1,132 @@
from io import BytesIO
import os
from zipfile import ZipFile

import pandas as pd
import requests

from aif360.sklearn.datasets.utils import standardize_dataset


# cache location
DATA_HOME_DEFAULT = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'..', 'data', 'raw')
MEPS_URL = "https://meps.ahrq.gov/mepsweb/data_files/pufs"
PROMPT = """
By using this function you acknowledge the responsibility for reading and
abiding by any copyright/usage rules and restrictions as stated on the MEPS web
site (https://meps.ahrq.gov/data_stats/data_use.jsp).
Continue [y/n]? > """

def fetch_meps(panel, *, accept_terms=None, data_home=None, cache=True,
usecols=['REGION', 'AGE', 'SEX', 'RACE', 'MARRY', 'FTSTU',
'ACTDTY', 'HONRDC', 'RTHLTH', 'MNHLTH', 'HIBPDX',
'CHDDX', 'ANGIDX', 'MIDX', 'OHRTDX', 'STRKDX', 'EMPHDX',
'CHBRON', 'CHOLDX', 'CANCERDX', 'DIABDX', 'JTPAIN',
'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX', 'PREGNT',
'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42',
'DFSEE42', 'ADSMOK42', 'PCS42', 'MCS42', 'K6SUM42',
'PHQ242', 'EMPST', 'POVCAT', 'INSCOV'],
dropcols=None, numeric_only=False, dropna=True):
"""Load the Medical Expenditure Panel Survey (MEPS) dataset.
Note:
For descriptions of the dataset features, see the `data codebook
<https://meps.ahrq.gov/mepsweb/data_stats/download_data_files_codebook.jsp?PUFId=H181>`_.
Args:
panel ({19, 20, 21}): Panel number (only 19, 20, and 21 are currently
supported).
accept_terms (bool, optional): Bypass terms prompt. Note: by setting
this to ``True``, you acknowledge responsibility for reading and
accepting the MEPS usage terms.
data_home (string, optional): Specify another download and cache folder
for the datasets. By default all AIF360 datasets are stored in
'aif360/sklearn/data/raw' subfolders.
cache (bool): Whether to cache downloaded datasets.
usecols (single label or list-like, optional): Feature column(s) to
keep. All others are dropped.
dropcols (single label or list-like, optional): Feature column(s) to
drop.
numeric_only (bool): Drop all non-numeric feature columns.
dropna (bool): Drop rows with NAs.
Returns:
namedtuple: Tuple containing X and y for the MEPS dataset accessible by
index or name.
"""
if panel not in {19, 20, 21}:
raise ValueError("only panels 19, 20, and 21 are currently supported.")

fname = 'h192' if panel == 21 else 'h181'
cache_path = os.path.join(data_home or DATA_HOME_DEFAULT, fname + '.csv')
if cache and os.path.isfile(cache_path):
df = pd.read_csv(cache_path)
else:
# skip prompt if user chooses
accept = accept_terms or input(PROMPT)
if accept != 'y' and accept != True:
raise PermissionError("Terms not agreed.")
rawz = requests.get(os.path.join(MEPS_URL, fname + 'ssp.zip')).content
with ZipFile(BytesIO(rawz)) as zf:
with zf.open(fname + '.ssp') as ssp:
df = pd.read_sas(ssp, format='xport')
# TODO: does this cause any differences?
# reduce storage size
df = df.apply(pd.to_numeric, errors='ignore', downcast='integer')
if cache:
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
df.to_csv(cache_path, index=None)
# restrict to correct panel
df = df[df['PANEL'] == panel]
# change all 15s to 16s if panel == 21
yr = 16 if panel == 21 else 15

# non-Hispanic Whites are marked as WHITE; all others as NON-WHITE
df['RACEV2X'] = (df['HISPANX'] == 2) & (df['RACEV2X'] == 1)

# rename all columns that are panel/round-specific
df = df.rename(columns={
'FTSTU53X': 'FTSTU', 'ACTDTY53': 'ACTDTY', 'HONRDC53': 'HONRDC',
'RTHLTH53': 'RTHLTH', 'MNHLTH53': 'MNHLTH', 'CHBRON53': 'CHBRON',
'JTPAIN53': 'JTPAIN', 'PREGNT53': 'PREGNT', 'WLKLIM53': 'WLKLIM',
'ACTLIM53': 'ACTLIM', 'SOCLIM53': 'SOCLIM', 'COGLIM53': 'COGLIM',
'EMPST53': 'EMPST', 'REGION53': 'REGION', 'MARRY53X': 'MARRY',
'AGE53X': 'AGE', f'POVCAT{yr}': 'POVCAT', f'INSCOV{yr}': 'INSCOV',
f'PERWT{yr}F': 'PERWT', 'RACEV2X': 'RACE'})

df.loc[df.AGE < 0, 'AGE'] = None # set invalid ages to NaN
cat_cols = ['REGION', 'SEX', 'RACE', 'MARRY', 'FTSTU', 'ACTDTY', 'HONRDC',
'RTHLTH', 'MNHLTH', 'HIBPDX', 'CHDDX', 'ANGIDX', 'MIDX',
'OHRTDX', 'STRKDX', 'EMPHDX', 'CHBRON', 'CHOLDX', 'CANCERDX',
'DIABDX', 'JTPAIN', 'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX',
'PREGNT', 'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42',
'DFSEE42', 'ADSMOK42', 'PHQ242', 'EMPST', 'POVCAT', 'INSCOV',
# NOTE: education tracking seems to have changed between panels. 'EDUYRDG'
# was used for panel 19, 'EDUCYR' and 'HIDEG' were used for panels 20 & 21.
# User may change usecols to include these manually.
'EDUCYR', 'HIDEG']
if panel == 19:
cat_cols += ['EDUYRDG']

for col in cat_cols:
df[col] = df[col].astype('category')
thresh = 0 if col in ['REGION', 'MARRY', 'ASTHDX'] else -1
na_cats = [c for c in df[col].cat.categories if c < thresh]
df[col] = df[col].cat.remove_categories(na_cats) # set NaN cols to NaN

df['SEX'] = df['SEX'].cat.rename_categories({1: 'Male', 2: 'Female'})
df['RACE'] = df['RACE'].cat.rename_categories({False: 'Non-White', True: 'White'})
df['RACE'] = df['RACE'].cat.reorder_categories(['Non-White', 'White'], ordered=True)

# Compute UTILIZATION, binarize it to 0 (< 10) and 1 (>= 10)
cols = [f'OBTOTV{yr}', f'OPTOTV{yr}', f'ERTOT{yr}', f'IPNGTD{yr}', f'HHTOTD{yr}']
util = df[cols].sum(axis=1)
df['UTILIZATION'] = pd.cut(util, [min(util)-1, 10, max(util)+1], right=False,
labels=['< 10 Visits', '>= 10 Visits'])#['low', 'high'])

return standardize_dataset(df, prot_attr='RACE', target='UTILIZATION',
sample_weight='PERWT', usecols=usecols,
dropcols=dropcols, numeric_only=numeric_only,
dropna=dropna)
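
A minimal usage sketch (assumes network access on first call; accept_terms=True acknowledges the MEPS usage terms in place of the interactive prompt):

from aif360.sklearn.datasets import fetch_meps

# Panel 19: returns X, y ('UTILIZATION'), and the PERWT survey weights.
X, y, sample_weight = fetch_meps(19, accept_terms=True)
print(y.value_counts())  # '< 10 Visits' vs. '>= 10 Visits'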
