# Comparing Explainers with Astrapia

This notebook walks through the Astrapia workflow: instantiate custom explainers, generate explanations for representative
instances in the data, and derive comparable metrics from those explanations.

### Imports

In [None]:
import sklearn.ensemble
import sklearn.metrics

import astrapia as xb
from astrapia import explainers, dataset
from astrapia.comparator import ExplainerComparator
from astrapia.visualization import print_metrics, load_metrics_from_json, print_properties

### Initialization
Load the dataset. The "adult" and "breast" datasets are already implemented in this framework.

In [None]:
data = dataset.load_csv_data('breast', root_path='../data')

Train the machine learning classifier that the explainers are supposed to explain.

In [None]:
# Fit a random forest on the one-hot encoded training split.
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=50, n_jobs=5)
rf.fit(xb.utils.onehot_encode(data.data, data), data.target.to_numpy().reshape(-1))

# Report accuracy on the train, dev, and test splits.
print('Train', sklearn.metrics.accuracy_score(data.target, rf.predict(xb.utils.onehot_encode(data.data, data))))
print('Dev', sklearn.metrics.accuracy_score(data.target_dev, rf.predict(xb.utils.onehot_encode(data.data_dev, data))))
print('Test', sklearn.metrics.accuracy_score(data.target_test, rf.predict(xb.utils.onehot_encode(data.data_test, data))))

Define a prediction function that returns the classifier's class probabilities for raw (not yet one-hot encoded) instances.

In [None]:
def pred_fn(x):
    return rf.predict_proba(xb.utils.onehot_encode(x, data))

Initialize the LIME and DLIME explainers.

In [None]:
ex_lime = explainers.LimeExplainer(data, pred_fn, discretize_continuous=False)
ex_dlime = explainers.DLimeExplainer(data, pred_fn, discretize_continuous=False)

Initialize Anchors explainers with different values for the precision of the explanation.

In [None]:
ex_anchors1 = explainers.AnchorsExplainer(data, pred_fn, 0.9)
ex_anchors2 = explainers.AnchorsExplainer(data, pred_fn, 0.75)
ex_anchors3 = explainers.AnchorsExplainer(data, pred_fn, 0.6)
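To make the precision values above more concrete, here is a minimal sketch of how the precision of a single anchor rule can be estimated: among samples that satisfy the rule, count the fraction the model labels with the class being explained. The toy model, the samples, and the helper `anchor_precision` are hypothetical illustrations and are not part of Astrapia or of the Anchors implementation it wraps.

```python
def anchor_precision(rule, samples, model, target_label):
    """Fraction of samples satisfying `rule` that `model` labels as `target_label`."""
    covered = [s for s in samples if rule(s)]
    if not covered:
        return 0.0
    hits = sum(1 for s in covered if model(s) == target_label)
    return hits / len(covered)


# Hypothetical stand-in for a trained classifier.
def toy_model(sample):
    return "high" if sample["age"] > 40 or sample["hours"] > 50 else "low"


# Hypothetical candidate anchor: "hours > 40".
def works_long_hours(sample):
    return sample["hours"] > 40


samples = [
    {"age": 45, "hours": 30},
    {"age": 50, "hours": 60},
    {"age": 30, "hours": 55},
    {"age": 25, "hours": 20},
    {"age": 35, "hours": 45},
]

precision = anchor_precision(works_long_hours, samples, toy_model, "high")

# Check the rule against the three thresholds used in this notebook.
meets = {t: precision >= t for t in (0.9, 0.75, 0.6)}
print(round(precision, 3), meets)
```

In this toy setting the rule reaches a precision of 2/3, so it would satisfy the 0.6 threshold but not 0.75 or 0.9. A stricter threshold generally forces longer, more specific rules.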

Initialize the comparator and register each explainer under a display name.

In [None]:
comp = ExplainerComparator()
comp.add_explainer(ex_anchors1, 'ANCHORS 0.9')
comp.add_explainer(ex_anchors2, 'ANCHORS 0.75')
comp.add_explainer(ex_anchors3, 'ANCHORS 0.6')
comp.add_explainer(ex_lime, 'LIME')
comp.add_explainer(ex_dlime, 'DLIME')

### Execution
Let the comparator sample representative instances from the data and have each explainer explain them.

In [None]:
comp.explain_representative(data, sampler='splime', count=5, pred_fn=pred_fn)
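Assuming `sampler='splime'` refers to the submodular pick procedure from the LIME paper (SP-LIME), the idea behind choosing representative instances can be sketched as a greedy coverage maximization over explanation weights. This is an illustrative pure-Python sketch under that assumption, not Astrapia's sampler. The function `submodular_pick` and the toy weight matrix are hypothetical.

```python
import math


def submodular_pick(weights, budget):
    """Greedily choose up to `budget` row indices that maximize weighted feature coverage.

    `weights[i][j]` is the importance of feature j in the explanation of instance i.
    """
    n_features = len(weights[0])
    # Global importance of each feature: square root of its summed absolute weights.
    importance = [
        math.sqrt(sum(abs(row[j]) for row in weights)) for j in range(n_features)
    ]

    def coverage(selected):
        # A feature counts as covered if any selected explanation uses it.
        return sum(
            importance[j]
            for j in range(n_features)
            if any(weights[i][j] != 0 for i in selected)
        )

    chosen = []
    for _ in range(min(budget, len(weights))):
        remaining = [i for i in range(len(weights)) if i not in chosen]
        # Add the instance that yields the highest coverage together with those already chosen.
        best = max(remaining, key=lambda i: coverage(chosen + [i]))
        chosen.append(best)
    return chosen


# Toy explanation weights: 4 instances, 3 features.
toy_weights = [
    [0.9, 0.0, 0.0],
    [0.8, 0.1, 0.0],
    [0.0, 0.0, 0.5],
    [0.0, 0.2, 0.0],
]
picked = submodular_pick(toy_weights, budget=2)
print(picked)  # [1, 2]
```

Instance 1 is picked first because it covers the two features with the highest combined importance. Instance 2 follows because it adds the only feature not yet covered.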

Store the metric data as JSON and assert that storing and reloading the data does not modify it.

In [None]:
metric_data = comp.get_metric_data()
comp.store_metrics()
assert load_metrics_from_json('metrics.json') == metric_data
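The round-trip assertion only holds if the metric data consists of types that JSON preserves. The following sketch uses the standard `json` module to show one case that survives unchanged and one that does not. The dictionary layout is hypothetical and is not Astrapia's actual metric format.

```python
import json

# Hypothetical metric layout for illustration only.
metric_data = {
    "LIME": {"fidelity": 0.5, "features": ["age", "sex"]},
    "DLIME": {"fidelity": 0.25, "features": ["age"]},
}

# Strings, floats, lists, and string-keyed dicts survive a round trip unchanged.
restored = json.loads(json.dumps(metric_data))
assert restored == metric_data

# Tuples come back as lists and integer keys come back as strings,
# so an equality check fails for data like this.
lossy = {1: (0.5, 0.25)}
lossy_restored = json.loads(json.dumps(lossy))
assert lossy_restored == {"1": [0.5, 0.25]}
assert lossy_restored != lossy
```

If a check like the one above ever fails, tuple values and non-string keys in the stored data are a likely place to look first.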

### Visualization
Output properties and metrics as tables or bar charts.

In [None]:
print_metrics(metric_data, plot='table', show_metric_with_one_value=True)
print_metrics(metric_data, plot='bar', show_metric_with_one_value=True)
print_metrics(metric_data, explainer='ANCHORS 0.9')
print_metrics(metric_data, plot='bar', explainer='ANCHORS 0.9')

In [None]:
print_properties(metric_data)