In [134]:
from __future__ import print_function
import numpy as np
import sklearn.ensemble
from anchor import utils
from anchor import anchor_tabular
import xaibenchmark as xb

# Seed NumPy's global random number generator once, up front, so that code
# drawing from it produces the same numbers on every run of the notebook.
np.random.seed(1)

In [135]:
# Prerequisite: the file adult/adult.data must exist inside `dataset_folder`.
dataset_folder = '../data/'
dataset = utils.load_dataset(
    'adult',
    balance=True,
    dataset_folder=dataset_folder,
    discretize=True,
)

In [136]:
# Fit the black-box classifier that the explainer below will be asked to explain.
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=50, n_jobs=5)
rf.fit(dataset.train, dataset.labels_train)

# `score` on a classifier is the mean accuracy, i.e. the same value as
# accuracy_score(y, rf.predict(X)). Using it avoids depending on
# `sklearn.metrics`, which this notebook never imports explicitly and which is
# only reachable as a side effect of `import sklearn.ensemble`.
print('Train', rf.score(dataset.train, dataset.labels_train))
print('Test', rf.score(dataset.test, dataset.labels_test))

Train 0.9350338780390594
Test 0.8489483747609943


--------------
### Implementation of the Anchors explainer on top of the base class

In [137]:
class AnchorsExplainer(xb.Explainer):
    """Expose an Anchors tabular explainer through the ``xb.Explainer`` interface.

    Methods decorated with ``@xb.metric`` are the metrics of this explainer;
    methods decorated with ``@xb.utility`` are helper operations made available
    to the base class. Metrics that need an explanation return ``np.nan`` until
    ``explain_instance`` has been called successfully.
    """

    def __init__(self, predictor, dataset):
        """Build the underlying anchor explainer from the training data.

        :param predictor: fitted model; only its ``predict`` method is used.
        :param dataset: object providing ``class_names``, ``feature_names``,
            ``categorical_names`` and the ``train`` / ``validation`` / ``test``
            arrays together with ``labels_train`` / ``labels_validation`` /
            ``labels_test``.
        """
        self.explainer = anchor_tabular.AnchorTabularExplainer(
            dataset.class_names,
            dataset.feature_names,
            dataset.train,
            dataset.categorical_names)
        self.dataset = dataset
        self.predictor = predictor

    def get_subset(self, subset_name):
        """Return ``(instances, labels)`` for the subset called ``subset_name``.

        :param subset_name: one of ``"train"``, ``"dev"`` or ``"test"``.
        :raises NameError: if ``subset_name`` is none of the above.
        """
        if subset_name == "train":
            return self.dataset.train, self.dataset.labels_train
        elif subset_name == "dev":
            return self.dataset.validation, self.dataset.labels_validation
        elif subset_name == "test":
            return self.dataset.test, self.dataset.labels_test
        else:
            # Same exception type as before, but now it says what went wrong.
            raise NameError(
                "unknown subset {!r}; expected 'train', 'dev' or 'test'".format(subset_name))

    def explain_instance(self, instance, instance_set, threshold=0.95):
        """Compute an anchor explanation for ``instance`` and remember it.

        :param instance: the row to explain.
        :param instance_set: name of the subset (``"train"``, ``"dev"`` or
            ``"test"``) used as reference set by the neighbourhood-based metrics.
        :param threshold: forwarded to the anchor explainer as ``threshold``.
        :returns: the explanation object produced by the anchor explainer.
        :raises NameError: if ``instance_set`` is not a known subset name.
        """
        # Resolve the subset first: an invalid name must fail before the
        # (expensive) anchor search runs and before any state is overwritten.
        subset, subset_labels = self.get_subset(instance_set)
        explanation = self.explainer.explain_instance(
            instance, self.predictor.predict, threshold=threshold)

        # Commit all explanation-related state together so that the stored
        # explanation, instance and reference set always belong to each other.
        self.explanation = explanation
        self.instance = instance
        self.instance_set, self.instance_label_set = subset, subset_labels
        return self.explanation

    def _anchor_row_indices(self, data):
        """Indices of the rows of ``data`` that equal the explained instance on every anchor feature."""
        anchor_features = self.explanation.features()
        matches = np.all(data[:, anchor_features] == self.instance[anchor_features], axis=1)
        return np.where(matches)[0]

    def _anchor_label_balance(self, subset_name):
        """Mean label of the rows of the named subset that satisfy the current anchor.

        Returns ``np.nan`` when nothing has been explained yet or when no row of
        the subset satisfies the anchor.
        """
        if not hasattr(self, 'explanation'):
            return np.nan
        data, labels = self.get_subset(subset_name)
        rows = self._anchor_row_indices(data)
        if rows.size == 0:
            # np.mean of an empty selection is nan as well, but emits a warning.
            return np.nan
        return np.mean(labels[rows])

    @xb.metric
    def coverage(self):
        """Coverage reported by the current explanation, or ``np.nan`` without one."""
        if hasattr(self, 'explanation'):
            return self.explanation.coverage()
        return np.nan

    @xb.metric
    def precision(self):
        """Precision reported by the current explanation, or ``np.nan`` without one."""
        if hasattr(self, 'explanation'):
            return self.explanation.precision()
        return np.nan

    @xb.metric
    def balance_data_train(self):
        """Mean of the training labels."""
        return np.mean(self.dataset.labels_train)

    @xb.metric
    def balance_data_dev(self):
        """Mean of the validation labels."""
        return np.mean(self.dataset.labels_validation)

    @xb.metric
    def balance_data_test(self):
        """Mean of the test labels."""
        return np.mean(self.dataset.labels_test)

    @xb.metric
    def balance_explanation_train(self):
        """Mean label of the training rows that satisfy the current anchor."""
        return self._anchor_label_balance("train")

    @xb.metric
    def balance_explanation_dev(self):
        """Mean label of the validation rows that satisfy the current anchor."""
        return self._anchor_label_balance("dev")

    @xb.metric
    def balance_explanation_test(self):
        """Mean label of the test rows that satisfy the current anchor."""
        return self._anchor_label_balance("test")

    @xb.metric
    def balance_model_train(self):
        """Mean of the model's predictions on the training set."""
        return np.mean(self.predictor.predict(self.dataset.train))

    @xb.metric
    def balance_model_dev(self):
        """Mean of the model's predictions on the validation set."""
        return np.mean(self.predictor.predict(self.dataset.validation))

    @xb.metric
    def balance_model_test(self):
        """Mean of the model's predictions on the test set."""
        return np.mean(self.predictor.predict(self.dataset.test))

    @xb.metric
    def area(self):
        """Product of ``1 / (column maximum + 1)`` over the anchor features.

        The column maxima are taken over the training set. Returns ``np.nan``
        when nothing has been explained yet.
        """
        if hasattr(self, 'explanation'):
            # max + 1 presumably equals the number of distinct values of a
            # feature encoded as 0..max -- TODO confirm against the dataset encoding.
            value_counts = np.amax(self.dataset.train, axis=0)[self.explanation.features()] + 1

            # NOTE(review): open question from the original author -- should an
            # n-th root be applied, with n the number of anchor features or the
            # summed feature dimensions?
            return np.prod(1 / value_counts)
        return np.nan

    @xb.metric
    def accuracy(self):
        """Fraction of neighbourhood instances the model labels like the explanation.

        Returns ``np.nan`` when nothing has been explained yet or when the
        neighbourhood is empty (there is nothing to predict on or divide by).
        """
        if hasattr(self, 'explanation'):
            explanation_label = self.explanation.exp_map["prediction"]
            relevant_examples = self.get_neighborhood_instances()
            if len(relevant_examples) == 0:
                return np.nan
            ml_pred = self.predictor.predict(relevant_examples)
            return np.count_nonzero(ml_pred == explanation_label) / len(relevant_examples)
        return np.nan

    @xb.utility
    def get_neighborhood_instances(self):
        """Rows of the reference subset that satisfy the current anchor.

        Returns an empty list when nothing has been explained yet.
        """
        if hasattr(self, 'explanation'):
            return self.instance_set[self._anchor_row_indices(self.instance_set)]
        return []

    @xb.utility
    def get_explained_instance(self):
        """Return the instance passed to the most recent ``explain_instance`` call."""
        return self.instance

    @xb.utility
    def distance(self, x, y):
        """Return ``np.linalg.norm(x - y)``, the Euclidean distance for 1-D inputs."""
        return np.linalg.norm(x-y)

### Usage of the implemented explainer

In [138]:
# Wrap the trained forest in the explainer, then explain one row of the test
# set, using the test set as the reference set for the neighbourhood metrics.
exp = AnchorsExplainer(rf, dataset)
explanation = exp.explain_instance(
    instance=dataset.test[5],
    instance_set="test",
    threshold=0.6,
)
anchor_rules = explanation.names()
print("Current explanation:", anchor_rules)

['Education = Doctorate']


In [141]:
# List the names of all metrics currently defined on the explainer
# (the output shows the methods decorated with @xb.metric in AnchorsExplainer).
exp.metrics()

['accuracy',
 'area',
 'balance_data_dev',
 'balance_data_test',
 'balance_data_train',
 'balance_explanation_dev',
 'balance_explanation_test',
 'balance_explanation_train',
 'balance_model_dev',
 'balance_model_test',
 'balance_model_train',
 'coverage',
 'precision']

In [142]:
# Evaluate every currently defined metric for the explanation computed above
# and show the resulting (metric name, value) pairs.
exp.report()

Neighborhood:  [[ 3.  4. 10.  2. 10.  0.  4.  1.  0.  0.  2. 39.]
 [ 3.  1. 10.  2. 10.  0.  4.  1.  2.  0.  2. 11.]
 [ 2.  4. 10.  2. 10.  0.  4.  1.  0.  0.  2. 39.]
 [ 3.  4. 10.  2. 10.  0.  4.  1.  0.  0.  0. 39.]
 [ 1.  1. 10.  2. 10.  0.  4.  1.  0.  0.  2.  0.]
 [ 2.  4. 10.  2. 10.  5.  4.  0.  0.  0.  0. 39.]
 [ 2.  4. 10.  3.  8.  1.  4.  1.  2.  0.  2.  0.]
 [ 3.  7. 10.  2. 10.  0.  4.  1.  0.  0.  0. 39.]
 [ 1.  4. 10.  0. 10.  1.  4.  0.  0.  0.  2. 39.]
 [ 2.  5. 10.  2. 10.  0.  4.  1.  0.  0.  1. 39.]
 [ 2.  4. 10.  2. 10.  0.  4.  1.  0.  0.  2. 39.]
 [ 2.  4. 10.  4. 10.  1.  4.  1.  0.  0.  1. 39.]
 [ 2.  4. 10.  2. 10.  0.  4.  1.  0.  0.  1. 39.]
 [ 2.  6. 10.  4. 10.  1.  4.  0.  0.  0.  1. 39.]
 [ 1.  4. 10.  2. 10.  0.  4.  1.  0.  0.  2. 39.]
 [ 3.  7. 10.  4.  4.  1.  4.  0.  0.  0.  1. 39.]
 [ 3.  6. 10.  2. 12.  0.  4.  1.  0.  2.  0. 39.]
 [ 3.  7. 10.  2.  4.  0.  4.  1.  0.  2.  2. 39.]
 [ 1.  6. 10.  2. 10.  0.  4.  1.  0.  0.  2. 39.]
 [ 3.  4. 10.  2

{('accuracy', 0.9166666666666666),
 ('area', 0.0625),
 ('balance_data_dev', 0.4968112244897959),
 ('balance_data_test', 0.49776927979604846),
 ('balance_data_train', 0.5006775607811877),
 ('balance_explanation_dev', 0.896551724137931),
 ('balance_explanation_test', 0.8333333333333334),
 ('balance_explanation_train', 0.9026217228464419),
 ('balance_model_dev', 0.5280612244897959),
 ('balance_model_test', 0.5264499681325685),
 ('balance_model_train', 0.5195695496213631),
 ('coverage', 0.0222),
 ('precision', 0.6305705955851728)}

In [144]:
# Ask the base class which further metrics it can infer for this explainer;
# the result lists the explicitly defined metrics plus any newly inferred ones.
exp.infer_metrics()

inferred metrics: {'balance_data_test', 'accuracy', 'balance_data_train', 'balance_model_dev', 'balance_model_train', 'inverse_coverage', 'area', 'balance_explanation_test', 'furthest_distance', 'coverage', 'precision', 'balance_model_test', 'balance_data_dev', 'balance_explanation_dev', 'balance_explanation_train'}
