### Imports

In [1]:
import sklearn.ensemble
import sklearn.metrics

import xaibenchmark as xb
from xaibenchmark import dataset
from xaibenchmark import explainers
from xaibenchmark.comparator import ExplainerComparator
from xaibenchmark.visualize_metrics import print_metrics, load_metrics_from_json

### Initialization

In [2]:
# Load the 'adult' dataset from the CSV files under ../data. The returned
# object is what later cells read the train/dev/test features and targets from
# (data.data, data.target, data.data_dev, data.target_dev, ...).
data = dataset.load_csv_data(
    'adult',
    root_path='../data',
)

Train the machine learning classifier (a random forest) whose predictions the explainers are supposed to explain.

In [3]:
# Fixed seed so that the fitted forest, and therefore the accuracies printed
# below, are identical on every Restart & Run All.
RANDOM_STATE = 0

rf = sklearn.ensemble.RandomForestClassifier(
    n_estimators=50, n_jobs=5, random_state=RANDOM_STATE
)

# One-hot encode the training features once and reuse the result for both
# fitting and scoring instead of encoding the same frame twice.
train_features = xb.utils.onehot_encode(data.data, data)
# reshape(-1) flattens the training target into a 1-D label array for fit().
rf.fit(train_features, data.target.to_numpy().reshape(-1))

# Accuracy on each split; the dev and test frames go through the same encoder
# before prediction.
print('Train', sklearn.metrics.accuracy_score(data.target, rf.predict(train_features)))
print('Dev', sklearn.metrics.accuracy_score(data.target_dev, rf.predict(xb.utils.onehot_encode(data.data_dev, data))))
print('Test', sklearn.metrics.accuracy_score(data.target_test, rf.predict(xb.utils.onehot_encode(data.data_test, data))))

Train 0.9996490150484798
Dev 0.8524774774774775
Test 0.8509305325225723


Define the prediction function that the explainers will query, then initialize the LIME explainer.

In [4]:
def pred_fn(x):
    """Return the random forest's class probabilities for un-encoded instances.

    The explainers are given this function rather than the model itself. It
    one-hot encodes ``x`` with ``xb.utils.onehot_encode`` (the same encoding
    that is applied to the training data before ``rf.fit``) and forwards the
    encoded features to ``rf.predict_proba``.

    Args:
        x: Instances in the un-encoded layout of ``data.data`` -- presumably a
            DataFrame or something ``onehot_encode`` accepts; TODO confirm.

    Returns:
        Whatever ``rf.predict_proba`` returns for the encoded instances.
    """
    return rf.predict_proba(xb.utils.onehot_encode(x, data))

In [5]:
# LIME explainer built from the dataset and the shared prediction function,
# with the discretize_continuous option disabled.
ex_lime = explainers.LimeExplainer(
    data,
    pred_fn,
    discretize_continuous=False,
)

Initialize three Anchors explainers, each with a different value for the precision of the explanation (0.9, 0.75 and 0.6).

In [6]:
# One Anchors explainer per precision value, built in the order listed and
# bound to ex_anchors1, ex_anchors2 and ex_anchors3 respectively.
ex_anchors1, ex_anchors2, ex_anchors3 = (
    explainers.AnchorsExplainer(data, pred_fn, precision)
    for precision in (0.9, 0.75, 0.6)
)

Initialize the comparator and add each explainer to it under a descriptive name.

In [7]:
# Register every explainer with the comparator under the label given here,
# in the same order as listed.
comp = ExplainerComparator()
for label, explainer in (
    ('ANCHORS 0.9', ex_anchors1),
    ('ANCHORS 0.75', ex_anchors2),
    ('ANCHORS 0.6', ex_anchors3),
    ('LIME 1', ex_lime),
):
    comp.add_explainer(explainer, label)


### Execution
Provide the comparator with different instances that the explainers will explain.

In [8]:
# Positional row indices (into the training frame data.data) of the instances
# that every registered explainer is asked to explain.
instance_rows = [123, 234, 345, 456]
comp.explain_instances(data.data.iloc[instance_rows])

Store the metric data as JSON and assert that storing and reloading the data does not modify it.

In [9]:
# Take the metrics collected by the comparator, write them out, then read the
# JSON file back and check that the round trip left the data unchanged.
metric_data = comp.get_metric_data()
comp.store_metrics()
# NOTE(review): store_metrics() is called without a path, so 'metrics.json'
# is presumably its default output file -- confirm.
reloaded_metrics = load_metrics_from_json('metrics.json')
assert reloaded_metrics == metric_data

Visualize metrics as tables or bar charts.

In [10]:
# Explainer whose metrics are shown on their own in the last two views.
focus_explainer = 'ANCHORS 0.9'

# All explainers: table view including metrics with a single value, then bar
# charts without them.
print_metrics(
    metric_data,
    plot='table',
    show_metric_with_one_value=True,
)
print_metrics(
    metric_data,
    plot='bar',
    show_metric_with_one_value=False,
)

# Single explainer: first with the default plot type, then as bar charts.
print_metrics(metric_data, explainer=focus_explainer)
print_metrics(metric_data, plot="bar", explainer=focus_explainer)