# ExplainerDashboard

This package makes it convenient to quickly deploy a dashboard web app that explains the workings of a (scikit-learn compatible) machine learning model. The dashboard provides interactive plots on model performance, feature importances, feature contributions to individual predictions, "what if" analysis, partial dependence plots, SHAP (interaction) values, visualisation of individual decision trees, etc.

*Works with scikit-learn, xgboost, catboost, lightgbm, skorch (an sklearn wrapper for tabular PyTorch models), and others.*

In [4]:
from sklearn.ensemble import RandomForestClassifier
from explainerdashboard import ClassifierExplainer, ExplainerDashboard
from explainerdashboard.datasets import titanic_survive, titanic_names

feature_descriptions = {
    "Sex": "Gender of passenger",
    "Gender": "Gender of passenger",
    "Deck": "The deck the passenger had their cabin on",
    "PassengerClass": "The class of the ticket: 1st, 2nd or 3rd class",
    "Fare": "The amount of money people paid",
    "Embarked": "The port where the passenger boarded the Titanic. Either Southampton, Cherbourg or Queenstown",
    "Age": "Age of the passenger",
    "No_of_siblings_plus_spouses_on_board": "The sum of the number of siblings plus the number of spouses on board",
    "No_of_parents_plus_children_on_board": "The sum of the number of parents plus the number of children on board",
}

In [5]:
X_train, y_train, X_test, y_test = titanic_survive()
train_names, test_names = titanic_names()
model = RandomForestClassifier(n_estimators=50, max_depth=5)
model.fit(X_train, y_train)

In [6]:
explainer = ClassifierExplainer(model, X_test, y_test,
                                cats=['Deck', 'Embarked',
                                      {'Gender': ['Sex_male', 'Sex_female', 'Sex_nan']}],
                                cats_notencoded={'Embarked': 'Stowaway'}, # defaults to 'NOT_ENCODED'
                                descriptions=feature_descriptions, # adds a table and hover labels to dashboard
                                labels=['Not survived', 'Survived'], # defaults to ['0', '1', etc]
                                idxs=test_names, # defaults to X.index
                                index_name="Passenger", # defaults to X.index.name
                                target="Survival", # defaults to y.name
                                )

Detected RandomForestClassifier model: Changing class type to RandomForestClassifierExplainer...
Note: model_output=='probability', so assuming that raw shap output of RandomForestClassifier is in probability space...
Generating self.shap_explainer = shap.TreeExplainer(model)
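
The `cats` and `cats_notencoded` arguments above describe how one-hot encoded columns are grouped back into a single categorical feature. The following is a minimal pure-Python sketch of that idea, not the library's actual implementation: the helper names `resolve_cats` and `decode_row`, the column names in `columns`, and the exact label rule (strip the `name_` prefix when present, otherwise keep the column name) are assumptions made for illustration.

```python
def resolve_cats(columns, cats, sep="_"):
    """Map each categorical feature name to its one-hot columns (hypothetical helper)."""
    groups = {}
    for spec in cats:
        if isinstance(spec, dict):
            # Explicit form: {'Gender': ['Sex_male', 'Sex_female', 'Sex_nan']}
            for name, cols in spec.items():
                groups[name] = list(cols)
        else:
            # Prefix form: 'Deck' collects every column starting with 'Deck_'
            groups[spec] = [c for c in columns if c.startswith(spec + sep)]
    return groups


def decode_row(row, groups, not_encoded=None, default="NOT_ENCODED", sep="_"):
    """Collapse the one-hot columns of a single row into category labels."""
    not_encoded = not_encoded or {}
    decoded = {}
    for name, cols in groups.items():
        active = [c for c in cols if row.get(c, 0) == 1]
        if len(active) == 1:
            col = active[0]
            prefix = name + sep
            # Assumed label rule: drop the prefix if the column carries it
            decoded[name] = col[len(prefix):] if col.startswith(prefix) else col
        else:
            # No column is switched on: fall back to the placeholder label
            decoded[name] = not_encoded.get(name, default)
    return decoded


# Hypothetical column names, chosen only for this illustration
columns = ["Age", "Sex_male", "Sex_female", "Sex_nan",
           "Deck_A", "Deck_B", "Embarked_Southampton", "Embarked_Cherbourg"]
cats = ["Deck", "Embarked", {"Gender": ["Sex_male", "Sex_female", "Sex_nan"]}]

groups = resolve_cats(columns, cats)
row = {"Age": 30, "Sex_male": 0, "Sex_female": 1, "Sex_nan": 0,
       "Deck_A": 0, "Deck_B": 1,
       "Embarked_Southampton": 0, "Embarked_Cherbourg": 0}

decoded = decode_row(row, groups, not_encoded={"Embarked": "Stowaway"})
print(decoded)
# {'Deck': 'B', 'Embarked': 'Stowaway', 'Gender': 'Sex_female'}
```

In this sketch a row with no `Embarked_*` column set is shown as `'Stowaway'`, which mirrors the role of `cats_notencoded={'Embarked': 'Stowaway'}` in the cell above; without that entry the placeholder would be `'NOT_ENCODED'`.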


In [7]:
db = ExplainerDashboard(explainer, 
                        title="Titanic Explainer", # defaults to "Model Explainer"
                        shap_interaction=False, # you can switch off tabs with bools
                        )
db.run(port=8050)

Building ExplainerDashboard..
Detected notebook environment, consider setting mode='external', mode='inline' or mode='jupyterlab' to keep the notebook interactive while the dashboard is running...
Generating layout...
Calculating shap values...
Calculating prediction probabilities...
Calculating metrics...
Calculating confusion matrices...
Calculating classification_dfs...
Calculating roc auc curves...
Calculating pr auc curves...
Calculating liftcurve_dfs...
Calculating dependencies...
Calculating permutation importances (if slow, try setting n_jobs parameter)...
Calculating predictions...
Calculating pred_percentiles...
Calculating ShadowDecTree for each individual decision tree...
Reminder: you can store the explainer (including calculated dependencies) with explainer.dump('explainer.joblib') and reload with e.g. ClassifierExplainer.from_file('explainer.joblib')
Registering callbacks...
Starting ExplainerDashboard on http://172.25.133.51:8050
Dash is running on http://0.0.0.0:8050/


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8050
 * Running on http://172.25.133.51:8050
Press CTRL+C to quit

For a regression model you can also pass the units of the target variable (e.g. dollars):

In [None]:
from explainerdashboard.datasets import titanic_fare
from sklearn.ensemble import RandomForestRegressor
from explainerdashboard import RegressionExplainer

X_train, y_train, X_test, y_test = titanic_fare()
model = RandomForestRegressor().fit(X_train, y_train)

explainer = RegressionExplainer(model, X_test, y_test,
                                cats=['Deck', 'Embarked', 'Sex'],
                                descriptions=feature_descriptions,
                                units="$", # defaults to ""
                                )

In [11]:
ExplainerDashboard(explainer).run(port=8060)

Building ExplainerDashboard..
Detected notebook environment, consider setting mode='external', mode='inline' or mode='jupyterlab' to keep the notebook interactive while the dashboard is running...
Generating layout...
Calculating dependencies...
Reminder: you can store the explainer (including calculated dependencies) with explainer.dump('explainer.joblib') and reload with e.g. ClassifierExplainer.from_file('explainer.joblib')
Registering callbacks...
Starting ExplainerDashboard on http://172.25.133.51:8060
Dash is running on http://0.0.0.0:8060/

 * Serving Flask app 'explainerdashboard.dashboards'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8060
 * Running on http://172.25.133.51:8060
Press CTRL+C to quit

https://github.com/oegedijk/explainerdashboard