Skip to content

Commit

Permalink
Adding the threshold to the ROC curve hovertext (#1161)
Browse files Browse the repository at this point in the history
* Adding the threshold of the roc curve to the hovertext. Adding True, False positive rate as well.

* Adding PR 1161 to release notes.
  • Loading branch information
freddyaboulton committed Sep 14, 2020
1 parent 6d4e98e commit bab52e2
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Release Notes
* Added label encoder to lightGBM for binary classification :pr:`1152`
* Added labels for the row index of confusion matrix :pr:`1154`
* Added AutoMLSearch object as another parameter in search callbacks :pr:`1156`
* Added the corresponding probability threshold for each point displayed in `graph_roc_curve` :pr:`1161`
* Fixes
* Fixed XGBoost column names for partial dependence methods :pr:`1104`
* Removed dead code validating column type from `TextFeaturizer` :pr:`1122`
Expand Down
8 changes: 5 additions & 3 deletions evalml/model_understanding/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ def graph_roc_curve(y_true, y_pred_proba, custom_class_names=None, title_additio
data = []
for i in range(n_classes):
roc_curve_data = roc_curve(y_one_hot_true[:, i], y_pred_proba[:, i])
name = i + 1 if custom_class_names is None else custom_class_names[i]
data.append(_go.Scatter(x=roc_curve_data['fpr_rates'], y=roc_curve_data['tpr_rates'],
name='Class {name} (AUC {:06f})'
.format(roc_curve_data['auc_score'],
name=i + 1 if custom_class_names is None else custom_class_names[i]),
hovertemplate="(False Postive Rate: %{x}, True Positive Rate: %{y})<br>" +
"Threshold: %{text}",
name=f"Class {name} (AUC {roc_curve_data['auc_score']:.06f})",
text=roc_curve_data["thresholds"],
line=dict(width=3)))
data.append(_go.Scatter(x=[0, 1], y=[0, 1],
name='Trivial Model (AUC 0.5)',
Expand Down
2 changes: 2 additions & 0 deletions evalml/tests/model_understanding_tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def test_graph_roc_curve_binary(X_y_binary, data_type):
roc_curve_data = roc_curve(y_true, y_pred_proba)
assert np.array_equal(fig_dict['data'][0]['x'], roc_curve_data['fpr_rates'])
assert np.array_equal(fig_dict['data'][0]['y'], roc_curve_data['tpr_rates'])
assert np.array_equal(fig_dict['data'][0]['text'], roc_curve_data['thresholds'])
assert fig_dict['data'][0]['name'] == 'Class 1 (AUC {:06f})'.format(roc_curve_data['auc_score'])
assert np.array_equal(fig_dict['data'][1]['x'], np.array([0, 1]))
assert np.array_equal(fig_dict['data'][1]['y'], np.array([0, 1]))
Expand Down Expand Up @@ -325,6 +326,7 @@ def test_graph_roc_curve_multiclass(binarized_ys):
roc_curve_data = roc_curve(y_tr[:, i], y_pred_proba[:, i])
assert np.array_equal(fig_dict['data'][i]['x'], roc_curve_data['fpr_rates'])
assert np.array_equal(fig_dict['data'][i]['y'], roc_curve_data['tpr_rates'])
assert np.array_equal(fig_dict['data'][i]['text'], roc_curve_data['thresholds'])
assert fig_dict['data'][i]['name'] == 'Class {name} (AUC {:06f})'.format(roc_curve_data['auc_score'], name=i + 1)
assert np.array_equal(fig_dict['data'][3]['x'], np.array([0, 1]))
assert np.array_equal(fig_dict['data'][3]['y'], np.array([0, 1]))
Expand Down

0 comments on commit bab52e2

Please sign in to comment.