Skip to content

Commit

Permalink
Anchor (#3)
Browse files Browse the repository at this point in the history
* add tabular anchors

* add type annotations

* add adult anchor example

* expose anchor at explainer level and fix setup file

* fix type annotation

* add tests for anchor

* update style and annotations
  • Loading branch information
arnaudvl committed Mar 5, 2019
1 parent f6e8c50 commit f5e09ea
Show file tree
Hide file tree
Showing 11 changed files with 1,772 additions and 1 deletion.
102 changes: 102 additions & 0 deletions alibi/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from typing import Tuple


def adult(features_drop: list = ["fnlwgt", "Education-Num"]) -> Tuple[np.ndarray, np.ndarray, list, dict]:
"""
Downloads and pre-processes 'adult' dataset.
More info: http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/
Parameters
----------
features_drop
List of features to be dropped from dataset
Returns
-------
Dataset, labels, a list of features and a dictionary containing a list with the potential categories
for each categorical feature where the key refers to the feature column.
"""
# download data
dataset_url = 'http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data'
raw_features = ['Age', 'Workclass', 'fnlwgt', 'Education', 'Education-Num', 'Marital Status',
'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
'Hours per week', 'Country', 'Target']
raw_data = pd.read_csv(dataset_url, names=raw_features, delimiter=', ').fillna('?')

# get labels, features and drop unnecessary features
labels = (raw_data['Target'] == '>50K').astype(int).values
features_drop += ['Target']
data = raw_data.drop(features_drop, axis=1)
features = list(data.columns)

# map categorical features
education_map = {
'10th': 'Dropout', '11th': 'Dropout', '12th': 'Dropout', '1st-4th':
'Dropout', '5th-6th': 'Dropout', '7th-8th': 'Dropout', '9th':
'Dropout', 'Preschool': 'Dropout', 'HS-grad': 'High School grad',
'Some-college': 'High School grad', 'Masters': 'Masters',
'Prof-school': 'Prof-School', 'Assoc-acdm': 'Associates',
'Assoc-voc': 'Associates'
}
occupation_map = {
"Adm-clerical": "Admin", "Armed-Forces": "Military",
"Craft-repair": "Blue-Collar", "Exec-managerial": "White-Collar",
"Farming-fishing": "Blue-Collar", "Handlers-cleaners":
"Blue-Collar", "Machine-op-inspct": "Blue-Collar", "Other-service":
"Service", "Priv-house-serv": "Service", "Prof-specialty":
"Professional", "Protective-serv": "Other", "Sales":
"Sales", "Tech-support": "Other", "Transport-moving":
"Blue-Collar"
}
country_map = {
'Cambodia': 'SE-Asia', 'Canada': 'British-Commonwealth', 'China':
'China', 'Columbia': 'South-America', 'Cuba': 'Other',
'Dominican-Republic': 'Latin-America', 'Ecuador': 'South-America',
'El-Salvador': 'South-America', 'England': 'British-Commonwealth',
'France': 'Euro_1', 'Germany': 'Euro_1', 'Greece': 'Euro_2',
'Guatemala': 'Latin-America', 'Haiti': 'Latin-America',
'Holand-Netherlands': 'Euro_1', 'Honduras': 'Latin-America',
'Hong': 'China', 'Hungary': 'Euro_2', 'India':
'British-Commonwealth', 'Iran': 'Other', 'Ireland':
'British-Commonwealth', 'Italy': 'Euro_1', 'Jamaica':
'Latin-America', 'Japan': 'Other', 'Laos': 'SE-Asia', 'Mexico':
'Latin-America', 'Nicaragua': 'Latin-America',
'Outlying-US(Guam-USVI-etc)': 'Latin-America', 'Peru':
'South-America', 'Philippines': 'SE-Asia', 'Poland': 'Euro_2',
'Portugal': 'Euro_2', 'Puerto-Rico': 'Latin-America', 'Scotland':
'British-Commonwealth', 'South': 'Euro_2', 'Taiwan': 'China',
'Thailand': 'SE-Asia', 'Trinadad&Tobago': 'Latin-America',
'United-States': 'United-States', 'Vietnam': 'SE-Asia'
}
married_map = {
'Never-married': 'Never-Married', 'Married-AF-spouse': 'Married',
'Married-civ-spouse': 'Married', 'Married-spouse-absent':
'Separated', 'Separated': 'Separated', 'Divorced':
'Separated', 'Widowed': 'Widowed'
}
mapping = {'Education': education_map, 'Occupation': occupation_map, 'Country': country_map,
'Marital Status': married_map}

data_copy = data.copy()
for f, f_map in mapping.items():
data_tmp = data_copy[f].values
for key, value in f_map.items():
data_tmp[data_tmp == key] = value
data[f] = data_tmp

# get categorical features and apply labelencoding
categorical_features = [f for f in features if data[f].dtype == 'O']
category_map = {}
for f in categorical_features:
le = LabelEncoder()
data_tmp = le.fit_transform(data[f].values)
data[f] = data_tmp
category_map[features.index(f)] = list(le.classes_)

# only return data values
data = data.values

return data, labels, features, category_map
7 changes: 7 additions & 0 deletions alibi/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
The 'alibi.explainers' module includes feature importance, counterfactual and anchor-based explainers.
"""

from .anchor.anchor_tabular import AnchorTabular

__all__ = ["AnchorTabular"]
Empty file.
Loading

0 comments on commit f5e09ea

Please sign in to comment.