------
# Set up a Classifier

In [1]:
from __future__ import print_function
import numpy as np
np.random.seed(1)
import sys
import sklearn
import sklearn.ensemble
from anchor import utils
from anchor import anchor_tabular

In [2]:
# Expects the file adult/adult.data under `dataset_folder`
# (a path relative to the notebook's working directory).
dataset_folder = '../data/'
dataset = utils.load_dataset(
    'adult',
    dataset_folder=dataset_folder,
    balance=True,
    discretize=True,
)

In [3]:
# sklearn.metrics is used below; import it explicitly rather than relying on
# sklearn.ensemble having imported it as a side effect.
import sklearn.metrics

# Global np.random.seed(1) (import cell) makes the forest reproducible since
# random_state is left as None.
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=50, n_jobs=5)
rf.fit(dataset.train, dataset.labels_train)

# Accuracy on the training and held-out splits, printed as "<split> <accuracy>".
print('Train', sklearn.metrics.accuracy_score(dataset.labels_train, rf.predict(dataset.train)))
print('Test', sklearn.metrics.accuracy_score(dataset.labels_test, rf.predict(dataset.test)))

Train 0.9350338780390594
Test 0.8489483747609943


---------
# Set up the Explainer Wrapper

In [5]:
def metric(fn):
    """Tag ``fn`` as an explainer metric and return the same object.

    ``Explainer`` discovers metrics by scanning its attributes for a ``tag``
    attribute equal to ``'metric'``; this decorator sets exactly that tag.
    """
    setattr(fn, 'tag', 'metric')
    return fn

In [6]:
class Explainer:
    """Base class for explainer wrappers that expose quality metrics.

    A metric is any zero-argument callable attribute tagged by the ``metric``
    decorator. A metric counts as *implemented* when calling it does not raise
    ``NotImplementedError``; subclasses override the metrics they can compute.
    ``infer_metrics`` attaches extra metrics derived from implemented ones.
    """

    def __init__(self):
        # Abstract: concrete explainers supply their own constructor.
        raise NotImplementedError

    @metric
    def area(self):
        """Area metric; not implemented by the base class."""
        raise NotImplementedError

    @metric
    def coverage(self):
        """Coverage metric; not implemented by the base class."""
        raise NotImplementedError

    def _metric_callables(self):
        """Return ``{name: callable}`` for every attribute tagged as a metric."""
        found = {}
        for name in dir(self):
            candidate = getattr(self, name)
            if getattr(candidate, 'tag', None) == 'metric':
                found[name] = candidate
        return found

    def _implemented_metric_values(self):
        """Call each metric once; return ``{name: value}`` for implemented ones."""
        values = {}
        for name, fn in self._metric_callables().items():
            try:
                values[name] = fn()
            except NotImplementedError:
                # Metric not available for this explainer (or not yet).
                continue
        return values

    def metrics(self):
        """Return the set of names of the currently implemented metrics."""
        return set(self._implemented_metric_values())

    def infer_metrics(self):
        """Attach metrics derivable from the implemented ones, then print all names.

        Derivation rules are applied repeatedly until no new metric becomes
        implemented, so derived metrics may feed further rules. A rule never
        overwrites a metric that is already implemented.
        """
        def inverse_coverage():
            value = self.coverage()
            # Zero coverage has unbounded inverse; return inf instead of raising
            # ZeroDivisionError, which report()/metrics() would not catch.
            return float('inf') if value == 0 else 1 / value

        # Each rule: (metric names required, derived metric name, zero-arg implementation).
        transfer = [
            ({'coverage'}, 'inverse_coverage', inverse_coverage),
        ]

        known = self.metrics()
        while True:
            for required, name, implementation in transfer:
                if required <= known and name not in known:
                    setattr(self, name, metric(implementation))
            updated = self.metrics()
            if updated == known:
                break
            known = updated

        print('inferred metrics:', known)

    def report(self):
        """Return a set of ``(name, value)`` pairs for every implemented metric."""
        return set(self._implemented_metric_values().items())
                

--------------
# Set up specific Explainers

In [7]:
class LimeExplainer(Explainer):
    """Explainer wrapper around ``lime.lime_tabular.LimeTabularExplainer``.

    NOTE(review): this class is not exercised anywhere in this notebook.
    """

    def __init__(self, data, feature_names, class_names):
        # `lime` is never imported in the import cell; importing it here keeps
        # the rest of the notebook runnable when lime is not installed.
        import lime.lime_tabular

        self.explainer = lime.lime_tabular.LimeTabularExplainer(
            data,
            feature_names=feature_names,
            class_names=class_names,
            discretize_continuous=False)
        self.training_data = data

    @metric
    def coverage(self):
        """Placeholder coverage value.

        NOTE(review): hard-coded constant, not computed from any explanation — TODO replace.
        """
        return 6

    def distance(self):
        """Return ``0.75 * sqrt(training_data.shape[1])``.

        NOTE(review): presumably mirrors LIME's default kernel width — confirm.
        Not tagged with ``@metric``, so ``metrics``/``report`` do not see it.
        """
        # Previously read an undefined global `training_data` (NameError).
        return np.sqrt(self.training_data.shape[1]) * .75

In [8]:
class AnchorsExplainer(Explainer):
    """Explainer wrapper around ``anchor_tabular.AnchorTabularExplainer``.

    The ``coverage`` and ``precision`` metrics are read from the most recent
    explanation, so they raise ``NotImplementedError`` until
    ``explain_instance`` has been called at least once.
    """

    def __init__(self, class_names, feature_names, train, categorical_names):
        # Explainer.__init__ is abstract (it raises), so it is not called here.
        self.explainer = anchor_tabular.AnchorTabularExplainer(
            class_names, feature_names, train, categorical_names)

    def explain_instance(self, instance, predictor, threshold=0.95):
        """Explain ``instance`` under ``predictor`` and return the explanation.

        The explanation is also stored on ``self.explanation`` so the metric
        methods can read it.
        """
        result = self.explainer.explain_instance(
            instance, predictor, threshold=threshold)
        self.explanation = result
        return result

    @metric
    def coverage(self):
        """Coverage of the last explanation."""
        if not hasattr(self, 'explanation'):
            raise NotImplementedError
        return self.explanation.coverage()

    @metric
    def precision(self):
        """Precision of the last explanation."""
        if not hasattr(self, 'explanation'):
            raise NotImplementedError
        return self.explanation.precision()

----------
# Use the Explainer

In [9]:
# Wrap an anchors tabular explainer around the discretized adult dataset.
exp = AnchorsExplainer(
    class_names=dataset.class_names,
    feature_names=dataset.feature_names,
    train=dataset.train,
    categorical_names=dataset.categorical_names,
)

In [10]:
# names of the currently *implemented* metrics — none yet, since no explanation exists
exp.metrics()

set()

In [11]:
# (name, value) pairs for every implemented metric — empty at this point
exp.report()

set()

In [12]:
# try to derive extra metrics (e.g. inverse_coverage needs coverage); nothing is derivable yet
exp.infer_metrics()

inferred metrics: set()


In [13]:
# still no implemented metrics: infer_metrics had nothing to derive from
exp.metrics()

set()

In [14]:
# report is correspondingly still empty
exp.report()

set()

In [15]:
# Explain the first test row with the random forest. The explanation is stored
# on `exp`, which is what makes its coverage/precision metrics available.
test_instance = dataset.test[0]
explanation = exp.explain_instance(test_instance, rf.predict, threshold=0.95)

In [16]:
# coverage and precision are now implemented because an explanation exists
exp.metrics()

{'coverage', 'precision'}

In [17]:
# values of coverage and precision for the explained instance
exp.report()

{('coverage', 0.0161), ('precision', 0.9833333333333333)}

In [18]:
# coverage is implemented now, so inverse_coverage gets derived and attached to `exp`
exp.infer_metrics()

inferred metrics: {'inverse_coverage', 'coverage', 'precision'}


In [19]:
# the derived inverse_coverage now appears alongside the native metrics
exp.metrics()

{'coverage', 'inverse_coverage', 'precision'}

In [20]:
# all metric values, including inverse_coverage (= 1 / coverage)
exp.report()

{('coverage', 0.0161),
 ('inverse_coverage', 62.11180124223603),
 ('precision', 0.9833333333333333)}