docs: Add interpret tutorial with Transformers (#1728)
* docs(explain): add model interpret tutorials
* docs(explain): polish tuto
* docs(explain): polish tuto 2
* docs(explain): add explain to explanation
* docs(explain): add deps
* revert files
* revert file

Closes #1729 (cherry picked from commit 47645d9)
1 parent 5bd7586 · commit 82fe331
Showing 4 changed files with 279 additions and 0 deletions.
Binary file added (+1.52 MB, not shown): docs/_static/tutorials/nlp-model-explainability/model-explainability.mp4
{
"cells": [
{
"cell_type": "markdown",
"id": "cb4af95f-9940-4e8c-832f-446f6b2d50c5",
"metadata": {},
"source": [
"# 🕵️‍♀️ Analyzing predictions with model explainability methods\n",
"\n",
"In this tutorial you will learn to log and explore NLP model explanations using Transformers and the following libraries:\n",
"\n",
"* Transformers Interpret\n",
"* Shap\n",
"\n",
"Interpretability and explanation information in Rubrix is not limited to these two libraries. You can populate this information using your method of choice to highlight specific tokens.\n",
"\n",
"This tutorial is useful to get started and understand the underlying structure of explanation information in Rubrix records.\n",
"\n",
"<video width=\"100%\" controls><source src=\"../_static/tutorials/nlp-model-explainability/model-explainability.mp4\" type=\"video/mp4\"></video>\n",
"\n",
"Beyond browsing examples during model development and evaluation, storing explainability information in Rubrix can be really useful for monitoring and assessing production models (more tutorials on this soon!).\n",
"\n",
"Let's get started!"
]
},
{
"cell_type": "markdown",
"id": "495af015",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Rubrix is a free and open-source framework for data-centric NLP. If you are new to Rubrix, check out the [Github repository](https://github.com/recognai/rubrix) ⭐.\n",
"\n",
"If you have not installed and launched Rubrix yet, check the [Setup and Installation guide](../getting_started/setup&installation.rst).\n"
]
},
{
"cell_type": "markdown",
"id": "55a20408",
"metadata": {},
"source": [
"## Token attributions and what the highlight colors mean\n",
"\n",
"Rubrix enables you to register token attributions as part of the dataset records. To get token attributions, you can use methods such as Integrated Gradients or SHAP. These methods try to provide a mechanism to interpret model predictions. The attributions work as follows:\n",
"\n",
"* **[0,1] Positive attributions (in blue)** reflect those tokens that are making the model predict the specific predicted label.\n",
"\n",
"* **[-1, 0] Negative attributions (in red)** reflect those tokens that can influence the model to predict a label other than the specific predicted label."
]
},
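The [-1, 1] convention above can be illustrated with a small, self-contained sketch. The `normalize_attributions` helper, tokens, and raw scores below are hypothetical (not part of Rubrix or any explainability library): raw attribution scores of any magnitude can be mapped into [-1, 1] by dividing by the largest absolute score, so the most influential token gets ±1 and the sign of each score keeps its meaning.

```python
def normalize_attributions(scores):
    """Scale raw attribution scores into [-1, 1] by max absolute value."""
    max_abs = max(abs(s) for s in scores)
    if max_abs == 0:
        # No signal at all: every token is neutral
        return [0.0 for _ in scores]
    return [s / max_abs for s in scores]

# Hypothetical raw scores for the tokens of a short sentence
tokens = ["the", "movie", "was", "great"]
raw_scores = [0.02, -0.4, 0.05, 0.8]

for token, score in zip(tokens, normalize_attributions(raw_scores)):
    print(f"{token}: {score:+.3f}")
```

Under this scheme "great" (the strongest positive score) maps to exactly +1, while "movie" stays negative, matching the blue/red highlighting described above.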
{
"cell_type": "markdown",
"id": "bb5b7960-34b6-45a3-9ffe-3812cd75069a",
"metadata": {},
"source": [
"## Using `Transformers Interpret`\n",
"\n",
"In this example, we will use the `sst` sentiment dataset and a distilbert-based sentiment classifier. To get model explanation information, we will use the excellent [Transformers Interpret](https://github.com/cdpierse/transformers-interpret) library by [Charles Pierse](https://github.com/cdpierse)."
]
},
{
"cell_type": "markdown",
"id": "69d1ed13-3cbe-41d2-93e1-fa5cfffe65d5",
"metadata": {},
"source": [
"### Install dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a01dfa90-6745-4f2f-84e3-4417c068db22",
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers-interpret==0.5.2 datasets transformers"
]
},
{
"cell_type": "markdown",
"id": "e79e0021-62a5-4fec-b133-a9cceaabe365",
"metadata": {},
"source": [
"### Create a fully searchable dataset with model predictions and explanations"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0c79a20-59f1-499e-a10b-fd721dec232a",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
"from transformers_interpret import SequenceClassificationExplainer\n",
"from datasets import load_dataset\n",
"\n",
"import rubrix as rb\n",
"from rubrix import TokenAttributions\n",
"\n",
"# Load the Stanford sentiment treebank test set\n",
"dataset = load_dataset(\"sst\", \"default\", split=\"test\")\n",
"\n",
"# Let's use a sentiment classifier fine-tuned on sst\n",
"model_name = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"# Define the explainer using transformers_interpret\n",
"cls_explainer = SequenceClassificationExplainer(model, tokenizer)\n",
"\n",
"records = []\n",
"for example in dataset:\n",
"\n",
"    # Build TokenAttributions objects\n",
"    word_attributions = cls_explainer(example[\"sentence\"])\n",
"    token_attributions = [\n",
"        TokenAttributions(\n",
"            token=token,\n",
"            attributions={cls_explainer.predicted_class_name: score}\n",
"        )  # ignore first (CLS) and last (SEP) tokens\n",
"        for token, score in word_attributions[1:-1]\n",
"    ]\n",
"\n",
"    # Build a TextClassificationRecord\n",
"    record = rb.TextClassificationRecord(\n",
"        text=example[\"sentence\"],\n",
"        prediction=[(cls_explainer.predicted_class_name, cls_explainer.pred_probs)],\n",
"        explanation={\"text\": token_attributions},\n",
"    )\n",
"    records.append(record)\n",
"\n",
"# Build a Rubrix dataset with interpretations for each record\n",
"rb.log(records, name=\"transformers_interpret_example\")"
]
},
{
"cell_type": "markdown",
"id": "ed8d8d9a-dc80-4059-99a0-899bc02c6b6f",
"metadata": {},
"source": [
"### Example: *Predicted as negative* sorted by *descending score*\n",
"\n",
"If you go to [http://localhost:6900/datasets/rubrix/transformers_interpret_example](http://localhost:6900/datasets/rubrix/transformers_interpret_example) (assuming you are running Rubrix on your local machine), you get a fully searchable dataset.\n",
"For example, let's drill down to look at examples predicted as negative with a low score:\n",
"\n",
"![Predicted as negative with low score](../_static/tutorials/nlp-model-explainability/interpret.png)"
]
},
{
"cell_type": "markdown",
"id": "cb8ffd33-851c-4339-a29d-4bb36b66adfb",
"metadata": {},
"source": [
"## Using `Shap`\n",
"\n",
"In this example, we will use the widely used [Shap](https://github.com/slundberg/shap) library by [Scott Lundberg](https://github.com/slundberg)."
]
},
{
"cell_type": "markdown",
"id": "8a16a7a1-258a-48d6-b623-e6ec28c350b1",
"metadata": {},
"source": [
"### Install dependencies\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c0049af-d696-474b-97f8-adb34733f666",
"metadata": {},
"outputs": [],
"source": [
"!pip install shap==0.40.0 numba==0.53.1"
]
},
{
"cell_type": "markdown",
"id": "0650a144-c331-4b5d-b013-3ae41eb92418",
"metadata": {},
"source": [
"### Create a fully searchable dataset with model predictions and explanations\n",
"\n",
"This example is very similar to the one above. The main difference is that we need to scale the values from Shap to match the [-1, 1] range required by the Rubrix UI. This restriction is for visualization purposes. If you are more interested in monitoring use cases, you might not need to rescale."
]
},
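The rescaling step can be sketched in isolation. The `minmax_scale` helper below is a hypothetical stand-in that mirrors what `sklearn.preprocessing.MinMaxScaler(feature_range=[-1, 1])` does for a single feature column: the smallest value maps to -1, the largest to 1, and everything in between is interpolated linearly. The SHAP values here are made up for illustration.

```python
def minmax_scale(values, lo=-1.0, hi=1.0):
    """Linearly rescale values so that min(values) -> lo and max(values) -> hi."""
    vmin, vmax = min(values), max(values)
    if vmax == vmin:
        # Degenerate column: all values identical, map everything to the lower bound
        return [lo for _ in values]
    return [lo + (v - vmin) * (hi - lo) / (vmax - vmin) for v in values]

# Hypothetical per-token SHAP values for one predicted label
shap_column = [0.12, -0.30, 0.45, 0.02]
scaled = minmax_scale(shap_column)
print(scaled)
```

Note that min-max scaling preserves the ordering of the tokens but not the zero point: a token with a raw SHAP value of 0 will generally not land at 0 after rescaling, which is acceptable here because the goal is visualization rather than faithful magnitudes.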
{
"cell_type": "code",
"execution_count": null,
"id": "7586d2b9-aad8-46c0-b40c-765c398f9946",
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"from datasets import load_dataset\n",
"\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"import shap\n",
"\n",
"from rubrix import TextClassificationRecord, TokenAttributions\n",
"import rubrix as rb\n",
"\n",
"# Transformers pipeline model\n",
"model = transformers.pipeline(\"sentiment-analysis\", return_all_scores=True)\n",
"\n",
"# Load the Stanford treebank dataset, only the first 5 records for testing\n",
"sst = load_dataset(\"sst\", split=\"test[0:5]\")\n",
"\n",
"# Use shap's text explainer\n",
"explainer = shap.Explainer(model)\n",
"shap_values = explainer(sst[\"sentence\"])\n",
"\n",
"# Instantiate the scaler\n",
"scaler = MinMaxScaler(feature_range=[-1, 1])\n",
"\n",
"predictions = model(sst[\"sentence\"])\n",
"\n",
"for i in range(0, len(shap_values.values)):\n",
"\n",
"    # Scale shap values between -1 and 1 (using scikit-learn's MinMaxScaler)\n",
"    scaled = scaler.fit_transform(shap_values.values[i])\n",
"\n",
"    # Get the predicted label idx for indexing attributions and shap_values;\n",
"    # sort by score to get the max-score prediction\n",
"    sorted_predictions = sorted(predictions[i], key=lambda d: d[\"score\"], reverse=True)\n",
"    label_idx = 0 if sorted_predictions[0][\"label\"] == \"NEGATIVE\" else 1\n",
"\n",
"    # Build token attributions\n",
"    token_attributions = [\n",
"        TokenAttributions(\n",
"            token=token, attributions={shap_values.output_names[label_idx]: score}\n",
"        )\n",
"        for token, score in zip(shap_values.data[i], [row[label_idx] for row in scaled])\n",
"    ]\n",
"\n",
"    # Build a Rubrix record\n",
"    record = TextClassificationRecord(\n",
"        inputs=sst[\"sentence\"][i],\n",
"        prediction=[(pred[\"label\"], pred[\"score\"]) for pred in predictions[i]],\n",
"        explanation={\"text\": token_attributions},\n",
"    )\n",
"    # Log the record\n",
"    rb.log(record, name=\"rubrix_shap_example\")"
]
}
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |