From 49556bb32a5a73fa60ed9cf0a37aebee83ff3a53 Mon Sep 17 00:00:00 2001 From: Freddy Boulton <41651716+freddyaboulton@users.noreply.github.com> Date: Mon, 10 Aug 2020 15:17:40 -0400 Subject: [PATCH] Explain Predictions (#1016) * Working implementation of explain_predictions and explain_predictions_best_worst. * Refactoring explainers so that differences between report types are more modular. * Updating release notes for PR 1016. * Adding tests for error metrics. * Moving release note for PR 1016 to upcoming release. * Adding test for classification pipeline _classes property. * Adding section to model understanding user guide about explain_predictions and explain_predictions_best_worst. * Adding test for custom metric. Replacing names of default error metrics with more user-friendly output. * Adding explain_predictions and explain_predictions_best_worst to api reference and fixing docs. * Adding Predicted Value as a field in the classification reports. * Making metrics public. * Reducing the text output in the prediction explanations section of the model understanding user guide. * If a user tries to access pipeline._classes before fitting the pipeline, a helpful ValueError will be raised. * Fixing typo in docs. * Moving the ReportSectionMakers to _user_interface and adding some base classes to make the structure more clear. * Adding a test for pipeline._classes for problems where the labels are ints. * Modifying changelog for 1016. * Adding docstring for _make_single_prediction_shap_table. * Making _TableMaker private. * Fixing lint. * Not mocking shap values in test_explain_prediction_value_error. --- docs/source/api_reference.rst | 2 + docs/source/release_notes.rst | 1 + .../user_guide/model_understanding.ipynb | 79 +++- evalml/pipelines/classification_pipeline.py | 7 + .../prediction_explanations/__init__.py | 2 +- .../_user_interface.py | 260 ++++++++++-- .../prediction_explanations/explainers.py | 158 ++++++- .../test_classification.py | 31 ++ .../explanations_tests/test_explainers.py | 390 +++++++++++++++++- .../explanations_tests/test_user_interface.py | 43 +- 10 files changed, 877 insertions(+), 96 deletions(-) diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 7460b393cc..2b0994362d 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -237,6 +237,8 @@ Prediction Explanations :nosignatures: explain_prediction + explain_predictions + explain_predictions_best_worst .. 
currentmodule:: evalml.objectives diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 6bd9aeb382..6f91eb71e9 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -4,6 +4,7 @@ Release Notes **Future Releases** * Enhancements * Split `fill_value` into `categorical_fill_value` and `numeric_fill_value` for Imputer :pr:`1019` + * Added `explain_predictions` and `explain_predictions_best_worst` for explaining multiple predictions with SHAP :pr:`1016` * Fixes * Changes * Documentation Changes diff --git a/docs/source/user_guide/model_understanding.ipynb b/docs/source/user_guide/model_understanding.ipynb index 498792bce8..925a65d78c 100644 --- a/docs/source/user_guide/model_understanding.ipynb +++ b/docs/source/user_guide/model_understanding.ipynb @@ -132,9 +132,9 @@ "outputs": [], "source": [ "# get the predicted probabilities associated with the \"true\" label\n", - "y = y.map({'malignant': 0, 'benign': 1})\n", + "y_encoded = y.map({'malignant': 0, 'benign': 1})\n", "y_pred_proba = pipeline.predict_proba(X)[\"benign\"]\n", - "evalml.pipelines.graph_utils.graph_precision_recall_curve(y, y_pred_proba)" + "evalml.pipelines.graph_utils.graph_precision_recall_curve(y_encoded, y_pred_proba)" ] }, { @@ -154,7 +154,7 @@ "source": [ "# get the predicted probabilities associated with the \"benign\" label\n", "y_pred_proba = pipeline.predict_proba(X)[\"benign\"]\n", - "evalml.pipelines.graph_utils.graph_roc_curve(y, y_pred_proba)" + "evalml.pipelines.graph_utils.graph_roc_curve(y_encoded, y_pred_proba)" ] }, { @@ -163,7 +163,7 @@ "source": [ "## Explaining Individual Predictions\n", "\n", - "We can explain why the model made an individual prediction with the `explain_prediction` function. This will use the [Shapley Additive Explanations (SHAP)](https://github.com/slundberg/shap) algorithms to identify the top features that explain the predicted value. \n", + "We can explain why the model made an individual prediction with the [explain_prediction](../generated/evalml.pipelines.prediction_explanations.explain_prediction.ipynb) function. This will use the [Shapley Additive Explanations (SHAP)](https://github.com/slundberg/shap) algorithms to identify the top features that explain the predicted value. \n", "\n", "This function can explain both classification and regression models - all you need to do is provide the pipeline, the input features (must correspond to one row of the input data) and the training data. The function will return a table that you can print summarizing the top 3 most positive and negative contributing features to the predicted value.\n", "\n", @@ -191,6 +191,77 @@ "\n", "This functionality is currently **not supported** for **XGBoost** models or **CatBoost multiclass** classifiers." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Explaining Multiple Predictions\n", + "\n", + "When debugging machine learning models, it is often useful to analyze the best and worst predictions the model made. The [explain_predictions_best_worst](../generated/evalml.pipelines.prediction_explanations.explain_predictions_best_worst.ipynb) function can help us with this.\n", + "\n", + "This function will display the output of [explain_prediction](../generated/evalml.pipelines.prediction_explanations.explain_prediction.ipynb) for the best 2 and worst 2 predictions. 
By default, the best and worst predictions are determined by the absolute error for regression problems and [cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) for classification problems.\n", + "\n", + "We can specify our own ranking function by passing in a function to the `metric` parameter. This function will be called on `y_true` and `y_pred`. By convention, lower scores are better.\n", + "\n", + "At the top of each table, we can see the predicted probabilities, target value, and error on that prediction. For a regression problem, we would see the predicted value instead of predicted probabilities.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from evalml.pipelines.prediction_explanations import explain_predictions_best_worst\n", + "\n", + "report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n", + " include_shap_values=True, num_to_explain=2)\n", + "\n", + "print(report)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use a custom metric ([hinge loss](https://en.wikipedia.org/wiki/Hinge_loss)) for selecting the best and worst predictions. See this example:\n", + "\n", + "```python\n", + "import numpy as np\n", + "\n", + "def hinge_loss(y_true, y_pred_proba):\n", + " \n", + " probabilities = np.clip(y_pred_proba.iloc[:, 1], 0.001, 0.999)\n", + " y_true[y_true == 0] = -1\n", + " \n", + " return np.clip(1 - y_true * np.log(probabilities / (1 - probabilities)), a_min=0, a_max=None)\n", + "\n", + "report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n", + " include_shap_values=True, num_to_explain=5, metric=hinge_loss)\n", + "\n", + "print(report)\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also manually explain predictions on any subset of the training data with the [explain_predictions](../generated/evalml.pipelines.prediction_explanations.explain_predictions.ipynb) function. Below, we explain the predictions on the first, fifth, and tenth row of the data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from evalml.pipelines.prediction_explanations import explain_predictions\n", + "\n", + "report = explain_predictions(pipeline=pipeline, input_features=X.iloc[[0, 4, 9]], include_shap_values=True)\n", + "print(report)" + ] } ], "metadata": { diff --git a/evalml/pipelines/classification_pipeline.py b/evalml/pipelines/classification_pipeline.py index 29ca95917b..ef600068eb 100644 --- a/evalml/pipelines/classification_pipeline.py +++ b/evalml/pipelines/classification_pipeline.py @@ -59,6 +59,13 @@ def _decode_targets(self, y): originally had integer targets.""" return self._encoder.inverse_transform(y.astype(int)) + @property + def _classes(self): + """Gets the class names for the problem.""" + if not hasattr(self._encoder, "classes_"): + raise AttributeError("Cannot access class names before fitting the pipeline.") + return self._encoder.classes_ + def _predict(self, X, objective=None): """Make predictions using selected features. 
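The `_classes` property added above is only usable once the pipeline's label encoder has been fit. A minimal sketch of the intended behavior, assuming an illustrative custom pipeline built from evalml's built-in components (the `IllustrativePipeline` class below is hypothetical and not part of this patch):

```python
# Minimal sketch (not part of this patch) of the new `_classes` property.
from evalml.demos import load_breast_cancer
from evalml.pipelines import BinaryClassificationPipeline


class IllustrativePipeline(BinaryClassificationPipeline):
    """Hypothetical pipeline used only to demonstrate `_classes`."""
    component_graph = ["Simple Imputer", "Logistic Regression Classifier"]


X, y = load_breast_cancer()
pipeline = IllustrativePipeline(parameters={})

try:
    pipeline._classes  # the label encoder has no classes_ before fit
except AttributeError as e:
    print(e)  # "Cannot access class names before fitting the pipeline."

pipeline.fit(X, y)
print(pipeline._classes)  # the fitted encoder's classes, e.g. ['benign' 'malignant']
```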
diff --git a/evalml/pipelines/prediction_explanations/__init__.py b/evalml/pipelines/prediction_explanations/__init__.py index 6d92603eee..4b675c84d5 100644 --- a/evalml/pipelines/prediction_explanations/__init__.py +++ b/evalml/pipelines/prediction_explanations/__init__.py @@ -1,2 +1,2 @@ # flake8:noqa -from .explainers import explain_prediction +from .explainers import explain_prediction, explain_predictions_best_worst, explain_predictions diff --git a/evalml/pipelines/prediction_explanations/_user_interface.py b/evalml/pipelines/prediction_explanations/_user_interface.py index 72d156f955..8fcc44d4f3 100644 --- a/evalml/pipelines/prediction_explanations/_user_interface.py +++ b/evalml/pipelines/prediction_explanations/_user_interface.py @@ -1,5 +1,14 @@ +import abc + +import pandas as pd from texttable import Texttable +from evalml.pipelines.prediction_explanations._algorithms import ( + _compute_shap_values, + _normalize_shap_values +) +from evalml.problem_types import ProblemTypes + def _make_rows(shap_values, normalized_values, top_k, include_shap_values=False): """Makes the rows (one row for each feature) for the SHAP table. @@ -34,12 +43,10 @@ def _make_rows(shap_values, normalized_values, top_k, include_shap_values=False) return rows -def _make_table(dtypes, alignment, shap_values, normalized_values, top_k, include_shap_values=False): +def _make_table(shap_values, normalized_values, top_k, include_shap_values=False): """Make a table displaying the SHAP values for a prediction. Arguments: - dtypes (list(str)): The dtypes of each column in the table. - alignment (list(str)): How the data in each column in the table should be aligned. shap_values (dict): Dictionary mapping the feature names to their SHAP values. In a multiclass setting, this dictionary will correspond to the SHAP values for a single class. normalized_values (dict): Normalized SHAP values. Same structure as shap_values parameter. @@ -49,6 +56,9 @@ def _make_table(dtypes, alignment, shap_values, normalized_values, top_k, includ Returns: str """ + dtypes = ["t", "t", "f"] if include_shap_values else ["t", "t"] + alignment = ["c", "c", "c"] if include_shap_values else ["c", "c"] + table = Texttable() table.set_deco(Texttable.HEADER) table.set_cols_dtype(dtypes) @@ -64,46 +74,218 @@ def _make_table(dtypes, alignment, shap_values, normalized_values, top_k, includ return table.draw() -def _make_single_prediction_table(shap_values, normalized_values, top_k=3, include_shap_values=False): - """Makes a table from the SHAP values for a single prediction. +class _TableMaker(abc.ABC): + """Makes a SHAP table for a regression, binary, or multiclass classification problem.""" + + @abc.abstractmethod + def __call__(self, shap_values, normalized_values, top_k, include_shap_values=False): + """Creates a table given shap values.""" + + +class _SHAPRegressionTableMaker(_TableMaker): + """Makes a SHAP table explaining a prediction for a regression problem.""" + + def __call__(self, shap_values, normalized_values, top_k, include_shap_values=False): + return _make_table(shap_values, normalized_values, top_k, include_shap_values) + + +class _SHAPBinaryTableMaker(_TableMaker): + """Makes a SHAP table explaining a prediction for a binary classification problem.""" + + def __call__(self, shap_values, normalized_values, top_k, include_shap_values=False): + # The SHAP algorithm will return a two-element list for binary problems. + # By convention, we display the explanation for the positive class. 
+ return _make_table(shap_values[1], normalized_values[1], top_k, include_shap_values) + + +class _SHAPMultiClassTableMaker(_TableMaker): + """Makes a SHAP table explaining a prediction for a multiclass classification problem.""" + + def __init__(self, class_names): + self.class_names = class_names + + def __call__(self, shap_values, normalized_values, top_k, include_shap_values=False): + strings = [] + for class_name, class_values, normalized_class_values in zip(self.class_names, shap_values, normalized_values): + strings.append(f"Class: {class_name}\n") + table = _make_table(class_values, normalized_class_values, top_k, include_shap_values) + strings += table.splitlines() + strings.append("\n") + return "\n".join(strings) + + +def _make_single_prediction_shap_table(pipeline, input_features, top_k=3, training_data=None, + include_shap_values=False): + """Creates table summarizing the top_k positive and top_k negative contributing features to the prediction of a single datapoint. Arguments: - shap_values (list(dict) or dict): Dictionary mapping a feature name to a one-element list - containing its scaled value. In classification problems, the input will be a list storing a dict for - each class. - normalized_values (list(dict) or dict): Dictionary mapping a feature name to a one-element list - containing its scaled value. In classification problems, the input will be a list storing a dict for - each class. - top_k (int): Will include the top_k highest and lowest features in the table. - include_shap_values (bool): Whether the SHAP values should be included as an extra column in the table. + pipeline (PipelineBase): Fitted pipeline whose predictions we want to explain with SHAP. + input_features (pd.DataFrame): Dataframe of features - needs to correspond to data the pipeline was fit on. + top_k (int): How many of the highest/lowest features to include in the table. + training_data (pd.DataFrame): Training data the pipeline was fit on. + This is required for non-tree estimators because we need a sample of training data for the KernelSHAP algorithm. + include_shap_values (bool): Whether the SHAP values should be included in an extra column in the output. + Default is False. 
Returns: - str + str: Table """ - dtypes = ["t", "t"] - alignment = ["c", "c"] + if not (isinstance(input_features, pd.DataFrame) and input_features.shape[0] == 1): + raise ValueError("features must be stored in a dataframe of one row.") - if include_shap_values: - dtypes.append("f") - alignment.append("c") - - # Classification - if isinstance(shap_values, list): - # Binary Classification - if len(shap_values) == 2: - strings = ["Positive Label\n"] - table = _make_table(dtypes, alignment, shap_values[1], normalized_values[1], top_k, include_shap_values) - strings += table.splitlines() - return "\n".join(strings) - # Multiclass - else: - strings = [] - for class_index, (class_values, normalized_class_values) in enumerate(zip(shap_values, normalized_values)): - strings.append(f"Class {class_index}\n") - table = _make_table(dtypes, alignment, class_values, normalized_class_values, top_k, include_shap_values) - strings += table.splitlines() - strings.append("\n") - return "\n".join(strings) - # Regression + shap_values = _compute_shap_values(pipeline, input_features, training_data) + normalized_shap_values = _normalize_shap_values(shap_values) + + if pipeline.problem_type == ProblemTypes.REGRESSION: + table_maker = _SHAPRegressionTableMaker() + elif pipeline.problem_type == ProblemTypes.BINARY: + table_maker = _SHAPBinaryTableMaker() else: - return _make_table(dtypes, alignment, shap_values, normalized_values, top_k, include_shap_values) + table_maker = _SHAPMultiClassTableMaker(pipeline._classes) + + return table_maker(shap_values, normalized_shap_values, top_k, include_shap_values) + + +class _ReportSectionMaker: + """Make a prediction explanation report. + + A report is made up of three parts: the header, the predicted values (if any), and the table. + + There are two kinds of reports we make: Reports where we explain the best and worst predictions and + reports where we explain predictions for features the user has manually selected. + + Each of these reports is slightly different depending on the type of problem (regression, binary, multiclass). + + Rather than addressing all cases in one function/class, we write individual classes for formatting each part + of the report depending on the type of problem and report. + + This class creates the report given callables for creating the header, predicted values, and table. + """ + + def __init__(self, heading_maker, predicted_values_maker, table_maker): + self.heading_maker = heading_maker + self.make_predicted_values_maker = predicted_values_maker + self.table_maker = table_maker + + def make_report_section(self, pipeline, input_features, indices, y_pred, y_true, errors): + """Make a report for a subset of input features to a fitted pipeline. + + Arguments: + pipeline (PipelineBase): Fitted pipeline. + input_features (pd.DataFrame): Features where the pipeline predictions will be explained. + indices (list(int)): List of indices specifying the subset of input features whose predictions + we want to explain. + y_pred (pd.Series): Predicted values of the input_features. + y_true (pd.Series): True labels of the input_features. 
+ errors (pd.Series): Error between y_pred and y_true + + Returns: + str + """ + report = [] + for rank, index in enumerate(indices): + report.extend(self.heading_maker(rank, index)) + report.extend(self.make_predicted_values_maker(index, y_pred, y_true, errors)) + report.extend(self.table_maker(index, pipeline, input_features)) + return report + + +class _SectionMaker(abc.ABC): + """Makes a section for a prediction explanations report. + + A report is made up of three parts: the header, the predicted values (if any), and the table. + + Each subclass of this class will be responsible for creating one of these sections. + """ + + @abc.abstractmethod + def __call__(self, *args, **kwargs): + """Makes the report section. + + Returns: + list(str): A list containing the lines of report section. + """ + + +class _HeadingMaker(_SectionMaker): + """Makes the heading section for reports. + + Differences between best/worst reports and reports where user manually specifies the input features subset + are handled by formatting the value of the prefix parameter in the initialization. + """ + + def __init__(self, prefix, n_indices): + self.prefix = prefix + self.n_indices = n_indices + + def __call__(self, rank, index): + return [f"\t{self.prefix}{rank + 1} of {self.n_indices}\n\n"] + + +class _EmptyPredictedValuesMaker(_SectionMaker): + """Omits the predicted values section for reports where the user specifies the subset of the input features.""" + + def __call__(self, index, y_pred, y_true, scores): + return [""] + + +class _ClassificationPredictedValuesMaker(_SectionMaker): + """Makes the predicted values section for classification problem best/worst reports.""" + + def __init__(self, error_name, y_pred_values): + # Replace the default name with something more user-friendly + if error_name == "cross_entropy": + error_name = "Cross Entropy" + self.error_name = error_name + self.predicted_values = y_pred_values + + def __call__(self, index, y_pred, y_true, scores): + pred_value = [f"{col_name}: {pred}" for col_name, pred in + zip(y_pred.columns, round(y_pred.iloc[index], 3).tolist())] + pred_value = "[" + ", ".join(pred_value) + "]" + true_value = y_true[index] + + return [f"\t\tPredicted Probabilities: {pred_value}\n", + f"\t\tPredicted Value: {self.predicted_values[index]}\n", + f"\t\tTarget Value: {true_value}\n", + f"\t\t{self.error_name}: {round(scores[index], 3)}\n\n"] + + +class _RegressionPredictedValuesMaker(_SectionMaker): + """Makes the predicted values section for regression problem best/worst reports.""" + + def __init__(self, error_name): + # Replace the default name with something more user-friendly + if error_name == "abs_error": + error_name = "Absolute Difference" + self.error_name = error_name + + def __call__(self, index, y_pred, y_true, scores): + + return [f"\t\tPredicted Value: {round(y_pred.iloc[index], 3)}\n", + f"\t\tTarget Value: {round(y_true[index], 3)}\n", + f"\t\t{self.error_name}: {round(scores[index], 3)}\n\n"] + + +class _SHAPTableMaker(_SectionMaker): + """Makes the SHAP table section for reports. + + The table is the same whether the user requests a best/worst report or they manually specified the + subset of the input features. + + Handling the differences in how the table is formatted between regression and classification problems + is delegated to the explain_prediction function. 
+ """ + + def __init__(self, top_k_features, include_shap_values, training_data): + self.top_k_features = top_k_features + self.include_shap_values = include_shap_values + self.training_data = training_data + + def __call__(self, index, pipeline, input_features): + table = _make_single_prediction_shap_table(pipeline, input_features.iloc[index:(index + 1)], + training_data=self.training_data, top_k=self.top_k_features, + include_shap_values=self.include_shap_values) + table = table.splitlines() + # Indent the rows of the table to match the indentation of the entire report. + return ["\t\t" + line + "\n" for line in table] + ["\n\n"] diff --git a/evalml/pipelines/prediction_explanations/explainers.py b/evalml/pipelines/prediction_explanations/explainers.py index 01058272c3..454d1335a0 100644 --- a/evalml/pipelines/prediction_explanations/explainers.py +++ b/evalml/pipelines/prediction_explanations/explainers.py @@ -1,12 +1,20 @@ +import sys +import traceback + +import numpy as np import pandas as pd -from evalml.pipelines.prediction_explanations._algorithms import ( - _compute_shap_values, - _normalize_shap_values -) +from evalml.exceptions import PipelineScoreError from evalml.pipelines.prediction_explanations._user_interface import ( - _make_single_prediction_table + _ClassificationPredictedValuesMaker, + _EmptyPredictedValuesMaker, + _HeadingMaker, + _make_single_prediction_shap_table, + _RegressionPredictedValuesMaker, + _ReportSectionMaker, + _SHAPTableMaker ) +from evalml.problem_types import ProblemTypes def explain_prediction(pipeline, input_features, top_k=3, training_data=None, include_shap_values=False): @@ -21,12 +29,142 @@ def explain_prediction(pipeline, input_features, top_k=3, training_data=None, in training_data (pd.DataFrame): Training data the pipeline was fit on. This is required for non-tree estimators because we need a sample of training data for the KernelSHAP algorithm. include_shap_values (bool): Whether the SHAP values should be included in an extra column in the output. + Default is False. Returns: str: Table """ - if not (isinstance(input_features, pd.DataFrame) and input_features.shape[0] == 1): - raise ValueError("features must be stored in a dataframe of one row.") - shap_values = _compute_shap_values(pipeline, input_features, training_data) - normalized_shap_values = _normalize_shap_values(shap_values) - return _make_single_prediction_table(shap_values, normalized_shap_values, top_k, include_shap_values) + return _make_single_prediction_shap_table(pipeline, input_features, top_k, training_data, include_shap_values) + + +def abs_error(y_true, y_pred): + """Computes the absolute error per data point for regression problems. + + Arguments: + y_true (pd.Series): True labels. + y_pred (pd.Series): Predicted values. + + Returns: + pd.Series + """ + return np.abs(y_true - y_pred) + + +def cross_entropy(y_true, y_pred_proba): + """Computes Cross Entropy Loss per data point for classification problems. + + Arguments: + y_true (pd.Series): True labels encoded as ints. + y_pred_proba (pd.DataFrame): Predicted probabilities. One column per class. 
+ + Returns: + pd.Series + """ + n_data_points = y_pred_proba.shape[0] + log_likelihood = -np.log(y_pred_proba.values[range(n_data_points), y_true.values.astype("int")]) + return pd.Series(log_likelihood) + + +DEFAULT_METRICS = {ProblemTypes.BINARY: cross_entropy, + ProblemTypes.MULTICLASS: cross_entropy, + ProblemTypes.REGRESSION: abs_error} + + +def explain_predictions(pipeline, input_features, training_data=None, top_k_features=3, include_shap_values=False): + """Creates a report summarizing the top contributing features for each data point in the input features. + + XGBoost models and CatBoost multiclass classifiers are not currently supported. + + Arguments: + pipeline (PipelineBase): Fitted pipeline whose predictions we want to explain with SHAP. + input_features (pd.DataFrame): Dataframe of input data to evaluate the pipeline on. + training_data (pd.DataFrame): Dataframe of data the pipeline was fit on. This can be omitted for pipelines + with tree-based estimators. + top_k_features (int): How many of the highest/lowest contributing features to include in the table for each + data point. + include_shap_values (bool): Whether SHAP values should be included in the table. Default is False. + + Returns: + str - A report with the pipeline name and parameters and a table for each row of input_features. + The table will have the following columns: Feature Name, Contribution to Prediction, SHAP Value (optional), + and each row of the table will be a feature. + """ + if not (isinstance(input_features, pd.DataFrame) and not input_features.empty): + raise ValueError("Parameter input_features must be a non-empty dataframe.") + report = [pipeline.name + "\n\n", str(pipeline.parameters) + "\n\n"] + header_maker = _HeadingMaker(prefix="", n_indices=input_features.shape[0]) + prediction_results_maker = _EmptyPredictedValuesMaker() + table_maker = _SHAPTableMaker(top_k_features, include_shap_values, training_data) + section_maker = _ReportSectionMaker(header_maker, prediction_results_maker, table_maker) + report.extend(section_maker.make_report_section(pipeline, input_features, indices=range(input_features.shape[0]), + y_true=None, y_pred=None, errors=None)) + return "".join(report) + + +def explain_predictions_best_worst(pipeline, input_features, y_true, num_to_explain=5, top_k_features=3, + include_shap_values=False, metric=None): + """Creates a report summarizing the top contributing features for the best and worst points in the dataset as measured by error to true labels. + + XGBoost models and CatBoost multiclass classifiers are not currently supported. + + Arguments: + pipeline (PipelineBase): Fitted pipeline whose predictions we want to explain with SHAP. + input_features (pd.DataFrame): Dataframe of input data to evaluate the pipeline on. + y_true (pd.Series): True labels for the input data. + num_to_explain (int): How many of the best and worst data points to explain. + top_k_features (int): How many of the highest/lowest contributing features to include in the table for each + data point. + include_shap_values (bool): Whether SHAP values should be included in the table. Default is False. + metric (callable): The metric used to identify the best and worst points in the dataset. Function must accept + the true labels and predicted value or probabilities as the only arguments and lower values + must be better. By default, this will be the absolute error for regression problems and cross entropy loss + for classification problems. 
+ + Returns: + str - A report with the pipeline name and parameters. For each of the best/worst rows of input_features, the + predicted values, true labels, and metric value will be listed along with a table. The table will have the + following columns: Feature Name, Contribution to Prediction, SHAP Value (optional), and each row of the + table will correspond to a feature. + """ + if not (isinstance(input_features, pd.DataFrame) and input_features.shape[0] >= num_to_explain * 2): + raise ValueError(f"Input features must be a dataframe with more than {num_to_explain * 2} rows! " + "Convert to a dataframe and select a smaller value for num_to_explain if you do not have " + "enough data.") + if not isinstance(y_true, pd.Series): + raise ValueError("Parameter y_true must be a pd.Series.") + if y_true.shape[0] != input_features.shape[0]: + raise ValueError("Parameters y_true and input_features must have the same number of data points. Received: " + f"true labels: {y_true.shape[0]} and input features: {input_features.shape[0]}") + if not metric: + metric = DEFAULT_METRICS[pipeline.problem_type] + + table_maker = _SHAPTableMaker(top_k_features, include_shap_values, training_data=input_features) + + try: + if pipeline.problem_type == ProblemTypes.REGRESSION: + y_pred = pipeline.predict(input_features) + errors = metric(y_true, y_pred) + prediction_results_maker = _RegressionPredictedValuesMaker(metric.__name__) + else: + y_pred = pipeline.predict_proba(input_features) + y_pred_values = pipeline.predict(input_features) + errors = metric(pipeline._encode_targets(y_true), y_pred) + prediction_results_maker = _ClassificationPredictedValuesMaker(metric.__name__, y_pred_values) + except Exception as e: + tb = traceback.format_tb(sys.exc_info()[2]) + raise PipelineScoreError(exceptions={metric.__name__: (e, tb)}, scored_successfully={}) + + sorted_scores = errors.sort_values() + best = sorted_scores.index[:num_to_explain] + worst = sorted_scores.index[-num_to_explain:] + report = [pipeline.name + "\n\n", str(pipeline.parameters) + "\n\n"] + + # The trailing space after Best and Worst is intentional. 
It makes sure there is a space + # between the prefix and rank for the _HeadingMaker + for index_list, prefix in zip([best, worst], ["Best ", "Worst "]): + header_maker = _HeadingMaker(prefix, n_indices=num_to_explain) + report_section_maker = _ReportSectionMaker(header_maker, prediction_results_maker, table_maker) + section = report_section_maker.make_report_section(pipeline, input_features, index_list, y_pred, + y_true, errors) + report.extend(section) + return "".join(report) diff --git a/evalml/tests/pipeline_tests/classification_pipeline_tests/test_classification.py b/evalml/tests/pipeline_tests/classification_pipeline_tests/test_classification.py index ef1593328e..45e78253f9 100644 --- a/evalml/tests/pipeline_tests/classification_pipeline_tests/test_classification.py +++ b/evalml/tests/pipeline_tests/classification_pipeline_tests/test_classification.py @@ -1,6 +1,10 @@ +from itertools import product + import pandas as pd import pytest +from evalml.demos import load_breast_cancer, load_wine + @pytest.mark.parametrize("problem_type", ["binary", "multi"]) def test_new_unique_targets_in_score(X_y_binary, logistic_regression_binary_pipeline_class, @@ -16,3 +20,30 @@ def test_new_unique_targets_in_score(X_y_binary, logistic_regression_binary_pipe pipeline.fit(X, y) with pytest.raises(ValueError, match="y contains previously unseen labels"): pipeline.score(X, pd.Series([4] * len(y)), [objective]) + + +@pytest.mark.parametrize("problem_type,use_ints", product(["binary", "multi"], [True, False])) +def test_pipeline_has_classes_property(logistic_regression_binary_pipeline_class, + logistic_regression_multiclass_pipeline_class, problem_type, use_ints): + if problem_type == "binary": + X, y = load_breast_cancer() + pipeline = logistic_regression_binary_pipeline_class(parameters={}) + if use_ints: + y = y.map({'malignant': 0, 'benign': 1}) + answer = [0, 1] + else: + answer = ["benign", "malignant"] + elif problem_type == "multi": + X, y = load_wine() + pipeline = logistic_regression_multiclass_pipeline_class(parameters={}) + if use_ints: + y = y.map({"class_0": 0, "class_1": 1, "class_2": 2}) + answer = [0, 1, 2] + else: + answer = ["class_0", "class_1", "class_2"] + + with pytest.raises(AttributeError, match="Cannot access class names before fitting the pipeline."): + pipeline._classes + + pipeline.fit(X, y) + pd.testing.assert_series_equal(pd.Series(pipeline._classes), pd.Series(answer)) diff --git a/evalml/tests/pipeline_tests/explanations_tests/test_explainers.py b/evalml/tests/pipeline_tests/explanations_tests/test_explainers.py index 451e755bed..01e60c425b 100644 --- a/evalml/tests/pipeline_tests/explanations_tests/test_explainers.py +++ b/evalml/tests/pipeline_tests/explanations_tests/test_explainers.py @@ -4,41 +4,391 @@ import pandas as pd import pytest +from evalml.exceptions import PipelineScoreError from evalml.pipelines.prediction_explanations.explainers import ( - explain_prediction + abs_error, + cross_entropy, + explain_prediction, + explain_predictions, + explain_predictions_best_worst ) +from evalml.problem_types import ProblemTypes + + +def compare_two_tables(table_1, table_2): + assert len(table_1) == len(table_2) + for row, row_answer in zip(table_1, table_2): + assert row.strip().split() == row_answer.strip().split() + test_features = [5, [1], np.ones((1, 15)), pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).iloc[0], pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), pd.DataFrame()] @pytest.mark.parametrize("test_features", test_features) 
-@patch("evalml.pipelines.prediction_explanations.explainers._compute_shap_values") -@patch("evalml.pipelines.prediction_explanations.explainers._normalize_shap_values") -def test_explain_prediction_value_error(mock_normalize_shap_values, mock_compute_shap_values, test_features): +def test_explain_prediction_value_error(test_features): with pytest.raises(ValueError, match="features must be stored in a dataframe of one row."): explain_prediction(None, input_features=test_features, training_data=None) -@patch("evalml.pipelines.prediction_explanations.explainers._compute_shap_values", return_value={"a": [1], "b": [-2], - "c": [-0.25], "d": [2]}) -@patch("evalml.pipelines.prediction_explanations.explainers._normalize_shap_values", return_value={"a": [0.5], - "b": [-0.75], - "c": [-0.25], - "d": [0.75]}) -def test_explain_prediction_runs(mock_normalize_shap_values, mock_compute_shap_values): +explain_prediction_answer = """Feature Name Contribution to Prediction + ========================================= + d ++++ + a +++ + c -- + b ----""".splitlines() + + +explain_prediction_multiclass_answer = """Class: class_0 + + Feature Name Contribution to Prediction + ========================================= + a + + b + + c - + d - + + + Class: class_1 + + Feature Name Contribution to Prediction + ========================================= + a +++ + b ++ + c - + d -- + + + Class: class_2 - answer = """Feature Name Contribution to Prediction + Feature Name Contribution to Prediction ========================================= - d ++++ - a +++ - c -- - b ----""".splitlines() + a + + b + + c --- + d --- + """.splitlines() + +@pytest.mark.parametrize("problem_type,shap_values,normalized_shap_values,answer", + [(ProblemTypes.REGRESSION, + {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}, + {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}, + explain_prediction_answer), + (ProblemTypes.BINARY, + [{}, {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}], + [{}, {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}], + explain_prediction_answer), + (ProblemTypes.MULTICLASS, + [{}, {}, {}], + [{"a": [0.1], "b": [0.09], "c": [-0.04], "d": [-0.06]}, + {"a": [0.53], "b": [0.24], "c": [-0.15], "d": [-0.22]}, + {"a": [0.03], "b": [0.02], "c": [-0.42], "d": [-0.47]}], + explain_prediction_multiclass_answer) + ]) +@patch("evalml.pipelines.prediction_explanations._user_interface._compute_shap_values") +@patch("evalml.pipelines.prediction_explanations._user_interface._normalize_shap_values") +def test_explain_prediction(mock_normalize_shap_values, + mock_compute_shap_values, + problem_type, shap_values, normalized_shap_values, answer): + mock_compute_shap_values.return_value = shap_values + mock_normalize_shap_values.return_value = normalized_shap_values pipeline = MagicMock() + pipeline.problem_type = problem_type + pipeline._classes = ["class_0", "class_1", "class_2"] features = pd.DataFrame({"a": [1], "b": [2]}) - table = explain_prediction(pipeline, features).splitlines() + table = explain_prediction(pipeline, features, top_k=2).splitlines() - assert len(table) == len(answer) - for row, row_answer in zip(table, answer): - assert row.strip().split() == row_answer.strip().split() + compare_two_tables(table, answer) + + +def test_error_metrics(): + + pd.testing.assert_series_equal(abs_error(pd.Series([1, 2, 3]), pd.Series([4, 1, 0])), pd.Series([3, 1, 3])) + pd.testing.assert_series_equal(cross_entropy(pd.Series([1, 0]), + pd.DataFrame({"a": [0.1, 0.2], "b": [0.9, 0.8]})), + pd.Series([-np.log(0.9), -np.log(0.2)])) + + 
+input_features_and_y_true = [([1], None, "^Input features must be a dataframe with more than 10 rows!"), + (pd.DataFrame({"a": [1]}), None, "^Input features must be a dataframe with more than 10 rows!"), + (pd.DataFrame({"a": range(15)}), [1], "^Parameter y_true must be a pd.Series."), + (pd.DataFrame({"a": range(15)}), pd.Series(range(12)), "^Parameters y_true and input_features must have the same number of data points.") + ] + + +@pytest.mark.parametrize("input_features,y_true,error_message", input_features_and_y_true) +def test_explain_predictions_best_worst_value_errors(input_features, y_true, error_message): + with pytest.raises(ValueError, match=error_message): + explain_predictions_best_worst(None, input_features, y_true) + + +def test_explain_predictions_raises_pipeline_score_error(): + with pytest.raises(PipelineScoreError, match="Division by zero!"): + + def raise_zero_division(input_features): + raise ZeroDivisionError("Division by zero!") + + pipeline = MagicMock() + pipeline.problem_type = ProblemTypes.BINARY + pipeline.predict_proba.side_effect = raise_zero_division + explain_predictions_best_worst(pipeline, pd.DataFrame({"a": range(15)}), pd.Series(range(15))) + + +@pytest.mark.parametrize("input_features", [1, [1], "foo", pd.DataFrame()]) +def test_explain_predictions_value_errors(input_features): + with pytest.raises(ValueError, match="Parameter input_features must be a non-empty dataframe."): + explain_predictions(None, input_features) + + +regression_best_worst_answer = """Test Pipeline Name + + Parameters go here + + Best 1 of 1 + + Predicted Value: 1 + Target Value: 2 + Absolute Difference: 1 + + table goes here + + + Worst 1 of 1 + + Predicted Value: 2 + Target Value: 3 + Absolute Difference: 4 + + table goes here + + +""" + +no_best_worst_answer = """Test Pipeline Name + + Parameters go here + + 1 of 2 + + table goes here + + + 2 of 2 + + table goes here + + +""" + + +binary_best_worst_answer = """Test Pipeline Name + + Parameters go here + + Best 1 of 1 + + Predicted Probabilities: [benign: 0.05, malignant: 0.95] + Predicted Value: malignant + Target Value: malignant + Cross Entropy: 0.2 + + table goes here + + + Worst 1 of 1 + + Predicted Probabilities: [benign: 0.1, malignant: 0.9] + Predicted Value: malignant + Target Value: benign + Cross Entropy: 0.78 + + table goes here + + +""" + +multiclass_table = """Class: setosa + + table goes here + + + Class: versicolor + + table goes here + + + Class: virginica + + table goes here""" + +multiclass_best_worst_answer = """Test Pipeline Name + + Parameters go here + + Best 1 of 1 + + Predicted Probabilities: [setosa: 0.8, versicolor: 0.1, virginica: 0.1] + Predicted Value: setosa + Target Value: setosa + Cross Entropy: 0.15 + + {multiclass_table} + + + Worst 1 of 1 + + Predicted Probabilities: [setosa: 0.2, versicolor: 0.75, virginica: 0.05] + Predicted Value: versicolor + Target Value: versicolor + Cross Entropy: 0.34 + + {multiclass_table} + + +""".format(multiclass_table=multiclass_table) + +multiclass_no_best_worst_answer = """Test Pipeline Name + + Parameters go here + + 1 of 2 + + {multiclass_table} + + + 2 of 2 + + {multiclass_table} + + +""".format(multiclass_table=multiclass_table) + + +@pytest.mark.parametrize("problem_type,answer,explain_predictions_answer", + [(ProblemTypes.REGRESSION, regression_best_worst_answer, no_best_worst_answer), + (ProblemTypes.BINARY, binary_best_worst_answer, no_best_worst_answer), + (ProblemTypes.MULTICLASS, multiclass_best_worst_answer, multiclass_no_best_worst_answer)]) 
+@patch("evalml.pipelines.prediction_explanations.explainers.DEFAULT_METRICS") +@patch("evalml.pipelines.prediction_explanations._user_interface._make_single_prediction_shap_table") +def test_explain_predictions_best_worst_and_explain_predictions(mock_make_table, mock_default_metrics, + problem_type, answer, explain_predictions_answer): + + mock_make_table.return_value = "table goes here" + pipeline = MagicMock() + pipeline.parameters = "Parameters go here" + input_features = pd.DataFrame({"a": [3, 4]}) + pipeline.problem_type = problem_type + pipeline.name = "Test Pipeline Name" + + if problem_type == ProblemTypes.REGRESSION: + abs_error_mock = MagicMock(__name__="abs_error") + abs_error_mock.return_value = pd.Series([4, 1], dtype="int") + mock_default_metrics.__getitem__.return_value = abs_error_mock + pipeline.predict.return_value = pd.Series([2, 1]) + y_true = pd.Series([3, 2]) + elif problem_type == ProblemTypes.BINARY: + pipeline._classes.return_value = ["benign", "malignant"] + cross_entropy_mock = MagicMock(__name__="cross_entropy") + mock_default_metrics.__getitem__.return_value = cross_entropy_mock + cross_entropy_mock.return_value = pd.Series([0.2, 0.78]) + pipeline.predict_proba.return_value = pd.DataFrame({"benign": [0.05, 0.1], "malignant": [0.95, 0.9]}) + pipeline.predict.return_value = pd.Series(["malignant"] * 2) + y_true = pd.Series(["malignant", "benign"]) + else: + mock_make_table.return_value = multiclass_table + pipeline._classes.return_value = ["setosa", "versicolor", "virginica"] + cross_entropy_mock = MagicMock(__name__="cross_entropy") + mock_default_metrics.__getitem__.return_value = cross_entropy_mock + cross_entropy_mock.return_value = pd.Series([0.15, 0.34]) + pipeline.predict_proba.return_value = pd.DataFrame({"setosa": [0.8, 0.2], "versicolor": [0.1, 0.75], + "virginica": [0.1, 0.05]}) + pipeline.predict.return_value = ["setosa", "versicolor"] + y_true = pd.Series(["setosa", "versicolor"]) + + best_worst_report = explain_predictions_best_worst(pipeline, input_features, y_true=y_true, + num_to_explain=1) + + compare_two_tables(best_worst_report.splitlines(), answer.splitlines()) + + report = explain_predictions(pipeline, input_features) + compare_two_tables(report.splitlines(), explain_predictions_answer.splitlines()) + + +@pytest.mark.parametrize("problem_type,answer", + [(ProblemTypes.REGRESSION, no_best_worst_answer), + (ProblemTypes.BINARY, no_best_worst_answer), + (ProblemTypes.MULTICLASS, multiclass_no_best_worst_answer)]) +@patch("evalml.pipelines.prediction_explanations._user_interface._make_single_prediction_shap_table") +def test_explain_predictions_custom_index(mock_make_table, problem_type, answer): + + mock_make_table.return_value = "table goes here" + pipeline = MagicMock() + pipeline.parameters = "Parameters go here" + input_features = pd.DataFrame({"a": [3, 4]}, index=["first", "second"]) + pipeline.problem_type = problem_type + pipeline.name = "Test Pipeline Name" + + if problem_type == ProblemTypes.REGRESSION: + pipeline.predict.return_value = pd.Series([2, 1]) + elif problem_type == ProblemTypes.BINARY: + pipeline._classes.return_value = ["benign", "malignant"] + pipeline.predict.return_value = pd.Series(["malignant"] * 2) + pipeline.predict_proba.return_value = pd.DataFrame({"benign": [0.05, 0.1], "malignant": [0.95, 0.9]}) + else: + mock_make_table.return_value = multiclass_table + pipeline._classes.return_value = ["setosa", "versicolor", "virginica"] + pipeline.predict.return_value = pd.Series(["setosa", "versicolor"]) + 
pipeline.predict_proba.return_value = pd.DataFrame({"setosa": [0.8, 0.2], "versicolor": [0.1, 0.75], + "virginica": [0.1, 0.05]}) + + report = explain_predictions(pipeline, input_features, training_data=input_features) + + compare_two_tables(report.splitlines(), answer.splitlines()) + + +regression_custom_metric_answer = """Test Pipeline Name + + Parameters go here + + Best 1 of 1 + + Predicted Value: 1 + Target Value: 2 + sum: 3 + + table goes here + + + Worst 1 of 1 + + Predicted Value: 2 + Target Value: 3 + sum: 5 + + table goes here + + +""" + + +@patch("evalml.pipelines.prediction_explanations._user_interface._make_single_prediction_shap_table") +def test_explain_predictions_best_worst_custom_metric(mock_make_table): + + mock_make_table.return_value = "table goes here" + pipeline = MagicMock() + pipeline.parameters = "Parameters go here" + input_features = pd.DataFrame({"a": [5, 6]}) + pipeline.problem_type = ProblemTypes.REGRESSION + pipeline.name = "Test Pipeline Name" + + pipeline.predict.return_value = pd.Series([2, 1]) + y_true = pd.Series([3, 2]) + + def sum(y_true, y_pred): + return y_pred + y_true + + best_worst_report = explain_predictions_best_worst(pipeline, input_features, y_true=y_true, + num_to_explain=1, metric=sum) + + compare_two_tables(best_worst_report.splitlines(), regression_custom_metric_answer.splitlines()) diff --git a/evalml/tests/pipeline_tests/explanations_tests/test_user_interface.py b/evalml/tests/pipeline_tests/explanations_tests/test_user_interface.py index ea7b83cf1f..a91e67065b 100644 --- a/evalml/tests/pipeline_tests/explanations_tests/test_user_interface.py +++ b/evalml/tests/pipeline_tests/explanations_tests/test_user_interface.py @@ -5,8 +5,10 @@ from evalml.pipelines.prediction_explanations._user_interface import ( _make_rows, - _make_single_prediction_table, - _make_table + _make_table, + _SHAPBinaryTableMaker, + _SHAPMultiClassTableMaker, + _SHAPRegressionTableMaker ) make_rows_test_cases = [({"a": [0.2], "b": [0.1]}, 3, [["a", "++"], ["b", "+"]]), @@ -32,13 +34,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): assert _make_rows(values, values, top_k, include_shap_values) == new_answer - dtypes = ["t", "t"] - alignment = ["c", "c"] - if include_shap_values: - dtypes.append("f") - alignment.append("c") - - table = _make_table(dtypes, alignment, values, values, top_k, include_shap_values).splitlines() + table = _make_table(values, values, top_k, include_shap_values).splitlines() if include_shap_values: assert "SHAP Value" in table[0] # Subtracting two because a header and a line under the header are included in the table. 
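The hunk above exercises the new `_make_table` signature: column dtypes and alignment are now derived inside `_make_table` from `include_shap_values`, so callers only pass the SHAP value dictionaries. A minimal sketch of a direct call, reusing the one-element-list-per-feature structure these tests use:

```python
from evalml.pipelines.prediction_explanations._user_interface import _make_table

# One-element list per feature, mirroring the SHAP value structures used in these tests.
values = {"a": [0.2], "b": [0.1]}

# dtypes/alignment no longer need to be passed in; _make_table chooses them
# based on include_shap_values, so callers only supply the value dictionaries.
print(_make_table(values, values, top_k=3, include_shap_values=True))
```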
@@ -80,9 +76,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): {'a': [0.102], 'b': [0.097], 'c': [0.0], 'd': [-0.225], 'e': [-0.2422], 'f': [-0.251], 'foo': [-0.087]}] -binary_table = """Positive Label - - Feature Name Contribution to Prediction +binary_table = """Feature Name Contribution to Prediction ========================================= a + b + @@ -91,9 +85,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): e -- f --""".splitlines() -binary_table_shap = """Positive Label - - Feature Name Contribution to Prediction SHAP Value +binary_table_shap = """Feature Name Contribution to Prediction SHAP Value ====================================================== a + 1.180 b + 1.120 @@ -113,7 +105,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): {'a': [0.102], 'b': [0.097], 'c': [0.0], 'd': [-0.221], 'e': [-0.242], 'f': [-0.251], 'foo': [-0.0865]}, {'a': [0.0825], 'b': [0.0], 'c': [0.0], 'd': [-0.223], 'e': [-0.247], 'f': [-0.325], 'foo': [-0.121]}] -multiclass_table = """Class 0 +multiclass_table = """Class: 0 Feature Name Contribution to Prediction ========================================= @@ -125,7 +117,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): foo ----- - Class 1 + Class: 1 Feature Name Contribution to Prediction ========================================= @@ -137,7 +129,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): f -- - Class 2 + Class: 2 Feature Name Contribution to Prediction ========================================= @@ -148,7 +140,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): e -- f --""".splitlines() -multiclass_table_shap = """Class 0 +multiclass_table_shap = """Class: 0 Feature Name Contribution to Prediction SHAP Value ====================================================== @@ -160,7 +152,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): foo ----- -1.000 - Class 1 + Class: 1 Feature Name Contribution to Prediction SHAP Value ====================================================== @@ -172,7 +164,7 @@ def test_make_rows_and_make_table(test_case, include_shap_values): f -- -2.900 - Class 2 + Class: 2 Feature Name Contribution to Prediction SHAP Value ====================================================== @@ -192,7 +184,14 @@ def test_make_rows_and_make_table(test_case, include_shap_values): (multiclass, multiclass_normalized, False, multiclass_table), (multiclass, multiclass_normalized, True, multiclass_table_shap)]) def test_make_single_prediction_table(values, normalized_values, include_shap, answer): - table = _make_single_prediction_table(values, normalized_values, include_shap_values=include_shap) + if isinstance(values, list): + if len(values) > 2: + table_maker = _SHAPMultiClassTableMaker(class_names=["0", "1", "2"]) + else: + table_maker = _SHAPBinaryTableMaker() + else: + table_maker = _SHAPRegressionTableMaker() + table = table_maker(values, normalized_values, top_k=3, include_shap_values=include_shap) # Making sure the content is the same, regardless of formatting. for row_table, row_answer in zip(table.splitlines(), answer):