In [50]:
from __future__ import print_function
import numpy as np
import sklearn.ensemble
from anchor import utils
from anchor import anchor_tabular
import pandas as pd
import xaibenchmark as xb

np.random.seed(1)

In [51]:
# make sure you have adult/adult.data inside dataset_folder
dataset_folder = '../data/'
dataset = utils.load_dataset('adult', balance=True, dataset_folder=dataset_folder, discretize=True)

In [52]:
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=50, n_jobs=5)
rf.fit(dataset.train, dataset.labels_train)
print('Train', sklearn.metrics.accuracy_score(dataset.labels_train, rf.predict(dataset.train)))
print('Test', sklearn.metrics.accuracy_score(dataset.labels_test, rf.predict(dataset.test)))

Train 0.9350338780390594
Test 0.8489483747609943


--------------
### Implementation of Anchors Explainer onto base class

In [53]:
class AnchorsExplainer(xb.Explainer):
    
    def __init__(self, dataset, pathToData):
        
        self.explainer = anchor_tabular.AnchorTabularExplainer(
            dataset.class_names,
            dataset.feature_names,
            dataset.train,
            dataset.categorical_names)
        self.dataset = dataset   
        self.data = pd.read_csv(pathToData, sep=',')
        
    def explain_instance(self, instance, predictor, threshold=0.95):
        self.explanation = self.explainer.explain_instance(instance, predictor, threshold=threshold)
        self.instance = instance   
        return self.explanation
    
    @xb.metric
    def coverage(self):
        if hasattr(self, 'explanation'):
            return self.explanation.coverage()
    
    @xb.metric
    def precision(self):
        if hasattr(self, 'explanation'):
            return self.explanation.precision()
        
    @xb.metric
    def balance_data_train(self):
        return np.mean(self.dataset.labels_train)
    
    @xb.metric
    def balance_data_dev(self):
        return np.mean(self.dataset.labels_validation)    
    
    @xb.metric
    def balance_data_test(self):
        return np.mean(self.dataset.labels_test)
    
    @xb.metric
    def balance_Explanation(self):
        if hasattr(self, 'explanation'):
            
            # Use original labels from data (not reduced ones by Anchor)
            labels = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation", "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "country", "label"]
            examples_in_explanation_area = 0
            examples_positive_label = 0
            # print(self.explanation.names())
            
            # loop over all data elements
            for i in range(0,len(self.data)):
                correct_features = True
                
                # check whether the element is in the area of the explanation
                for feature in self.explanation.names():
                    
                    # separately handle non-categorical features
                    if "Age" in feature:
                        continue
                    elif "Capital Gain" in feature:
                        continue
                    elif "Hours per week" in feature:
                        continue
                        
                    # identify feature and categorical value in explanation and compare to the value of the element
                    else:
                        feature2 = feature.replace(" ", "-")
                        feature_name, value = feature2.split("-=-")
                        index = labels.index(feature_name.lower())
                        if str(self.data.iat[i,index]).strip() != str(value).strip(): 
                            # print("First index: ", str(self.data.iat[i,index]).strip(), "Second Index" , str(value).strip(), "Data ", self.data.iloc[[i]])
                            correct_features = False
                            break
                            
                # if the element is in the explanation's area, count it for balance calculation
                if correct_features:
                    examples_in_explanation_area += 1
                    if self.data.iat[i,14].strip() == ">50K":
                        examples_positive_label += 1
                                
            if examples_positive_label == 0: return 0   # TODO Capital Loss = 2
            return examples_positive_label / examples_in_explanation_area
            
    @xb.metric
    def balance_explanation_train(self):
        if hasattr(self, 'explanation'):
            fit_anchor = np.where(np.all(self.dataset.train[:, self.explanation.features()] == 
                                         self.instance[self.explanation.features()], axis=1))[0]
            return np.mean(self.dataset.labels_train[fit_anchor])
        
    @xb.metric
    def balance_explanation_dev(self):
        if hasattr(self, 'explanation'):
            fit_anchor = np.where(np.all(self.dataset.validation[:, self.explanation.features()] == 
                                         self.instance[self.explanation.features()], axis=1))[0]
            return np.mean(self.dataset.labels_validation[fit_anchor])
        
    @xb.metric
    def balance_explanation_test(self):
        if hasattr(self, 'explanation'):
            fit_anchor = np.where(np.all(self.dataset.test[:, self.explanation.features()] == 
                                         self.instance[self.explanation.features()], axis=1))[0]
            return np.mean(self.dataset.labels_test[fit_anchor])
        
    @xb.metric
    def balance_model_train(self):
        return
    
    @xb.metric
    def balance_model_train(self):
        return
    
    @xb.metric
    def balance_model_train(self):
        return
    
    @xb.utility
    def get_neighborhood_instances(self): 
        if hasattr(self, 'explanation'):
            fit_anchor = np.where(np.all(self.dataset.train[:, self.explanation.features()] == 
                                         self.instance[self.explanation.features()], axis=1))[0]
            return self.dataset.train[fit_anchor]
        return []
    
    @xb.utility
    def get_explained_instance(self):
        return self.instance
    
    @xb.utility
    def distance(self, x, y):
        return np.linalg.norm(x-y)

### Usage of implemented explainer

In [54]:
# instantiate anchors explainer
exp = AnchorsExplainer(dataset ,"../data/adult/adult.data")
explanation = exp.explain_instance(dataset.test[50], rf.predict, threshold=0.95)
print(explanation.names())

['Education = Bachelors', 'Relationship = Husband', 'Race = White', 'Country = United-States', '28.00 < Age <= 37.00', 'Marital Status = Married-civ-spouse', 'Workclass = Private']


In [55]:
# Test area
dataset.feature_names

['Age',
 'Workclass',
 'Education',
 'Marital Status',
 'Occupation',
 'Relationship',
 'Race',
 'Sex',
 'Capital Gain',
 'Capital Loss',
 'Hours per week',
 'Country']

In [56]:
# get all currently defined metrics
exp.metrics()

['balance_Explanation',
 'balance_data_dev',
 'balance_data_test',
 'balance_data_train',
 'balance_explanation_dev',
 'balance_explanation_test',
 'balance_explanation_train',
 'balance_model_train',
 'coverage',
 'precision']

In [57]:
# report all current metrics
exp.report()

{('balance_Explanation', 0.7256970610399397),
 ('balance_data_dev', 0.4968112244897959),
 ('balance_data_test', 0.49776927979604846),
 ('balance_data_train', 0.5006775607811877),
 ('balance_explanation_dev', 0.8529411764705882),
 ('balance_explanation_test', 0.8333333333333334),
 ('balance_explanation_train', 0.8980392156862745),
 ('balance_model_train', nan),
 ('coverage', 0.0191),
 ('precision', 0.9552238805970149)}

In [58]:
exp.get_neighborhood_instances()

array([[ 1.,  4.,  9., ...,  0.,  0., 39.],
       [ 1.,  4.,  9., ...,  0.,  2., 39.],
       [ 1.,  4.,  9., ...,  0.,  2., 39.],
       ...,
       [ 1.,  4.,  9., ...,  0.,  0., 39.],
       [ 1.,  4.,  9., ...,  0.,  2., 39.],
       [ 1.,  4.,  9., ...,  0.,  0., 39.]])

In [59]:
# infer other possible metrics
exp.infer_metrics()

inferred metrics: {'balance_data_test', 'balance_data_train', 'balance_model_train', 'inverse_coverage', 'balance_explanation_test', 'furthest_distance', 'coverage', 'precision', 'balance_data_dev', 'balance_explanation_dev', 'balance_explanation_train', 'balance_Explanation'}
