Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anchor #3

Merged
merged 7 commits into from
Mar 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions alibi/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from typing import Tuple


def adult(features_drop: list = ["fnlwgt", "Education-Num"]) -> Tuple[np.ndarray, np.ndarray, list, dict]:
"""
Downloads and pre-processes 'adult' dataset.
More info: http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/

Parameters
----------
features_drop
List of features to be dropped from dataset

Returns
-------
Dataset, labels, a list of features and a dictionary containing a list with the potential categories
for each categorical feature where the key refers to the feature column.
"""
# download data
dataset_url = 'http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data'
raw_features = ['Age', 'Workclass', 'fnlwgt', 'Education', 'Education-Num', 'Marital Status',
'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
'Hours per week', 'Country', 'Target']
raw_data = pd.read_csv(dataset_url, names=raw_features, delimiter=', ').fillna('?')

# get labels, features and drop unnecessary features
labels = (raw_data['Target'] == '>50K').astype(int).values
features_drop += ['Target']
data = raw_data.drop(features_drop, axis=1)
features = list(data.columns)

# map categorical features
education_map = {
'10th': 'Dropout', '11th': 'Dropout', '12th': 'Dropout', '1st-4th':
'Dropout', '5th-6th': 'Dropout', '7th-8th': 'Dropout', '9th':
'Dropout', 'Preschool': 'Dropout', 'HS-grad': 'High School grad',
'Some-college': 'High School grad', 'Masters': 'Masters',
'Prof-school': 'Prof-School', 'Assoc-acdm': 'Associates',
'Assoc-voc': 'Associates'
}
occupation_map = {
"Adm-clerical": "Admin", "Armed-Forces": "Military",
"Craft-repair": "Blue-Collar", "Exec-managerial": "White-Collar",
"Farming-fishing": "Blue-Collar", "Handlers-cleaners":
"Blue-Collar", "Machine-op-inspct": "Blue-Collar", "Other-service":
"Service", "Priv-house-serv": "Service", "Prof-specialty":
"Professional", "Protective-serv": "Other", "Sales":
"Sales", "Tech-support": "Other", "Transport-moving":
"Blue-Collar"
}
country_map = {
'Cambodia': 'SE-Asia', 'Canada': 'British-Commonwealth', 'China':
'China', 'Columbia': 'South-America', 'Cuba': 'Other',
'Dominican-Republic': 'Latin-America', 'Ecuador': 'South-America',
'El-Salvador': 'South-America', 'England': 'British-Commonwealth',
'France': 'Euro_1', 'Germany': 'Euro_1', 'Greece': 'Euro_2',
'Guatemala': 'Latin-America', 'Haiti': 'Latin-America',
'Holand-Netherlands': 'Euro_1', 'Honduras': 'Latin-America',
'Hong': 'China', 'Hungary': 'Euro_2', 'India':
'British-Commonwealth', 'Iran': 'Other', 'Ireland':
'British-Commonwealth', 'Italy': 'Euro_1', 'Jamaica':
'Latin-America', 'Japan': 'Other', 'Laos': 'SE-Asia', 'Mexico':
'Latin-America', 'Nicaragua': 'Latin-America',
'Outlying-US(Guam-USVI-etc)': 'Latin-America', 'Peru':
'South-America', 'Philippines': 'SE-Asia', 'Poland': 'Euro_2',
'Portugal': 'Euro_2', 'Puerto-Rico': 'Latin-America', 'Scotland':
'British-Commonwealth', 'South': 'Euro_2', 'Taiwan': 'China',
'Thailand': 'SE-Asia', 'Trinadad&Tobago': 'Latin-America',
'United-States': 'United-States', 'Vietnam': 'SE-Asia'
}
married_map = {
'Never-married': 'Never-Married', 'Married-AF-spouse': 'Married',
'Married-civ-spouse': 'Married', 'Married-spouse-absent':
'Separated', 'Separated': 'Separated', 'Divorced':
'Separated', 'Widowed': 'Widowed'
}
mapping = {'Education': education_map, 'Occupation': occupation_map, 'Country': country_map,
'Marital Status': married_map}

data_copy = data.copy()
for f, f_map in mapping.items():
data_tmp = data_copy[f].values
for key, value in f_map.items():
data_tmp[data_tmp == key] = value
data[f] = data_tmp

# get categorical features and apply labelencoding
categorical_features = [f for f in features if data[f].dtype == 'O']
category_map = {}
for f in categorical_features:
le = LabelEncoder()
data_tmp = le.fit_transform(data[f].values)
data[f] = data_tmp
category_map[features.index(f)] = list(le.classes_)

# only return data values
data = data.values

return data, labels, features, category_map
7 changes: 7 additions & 0 deletions alibi/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
The 'alibi.explainers' module includes feature importance, counterfactual and anchor-based explainers.
"""

from .anchor.anchor_tabular import AnchorTabular

__all__ = ["AnchorTabular"]
Empty file.
Loading