Skip to content

Commit

Permalink
Change colors of confusion matrix to shades of blue and change the ax…
Browse files Browse the repository at this point in the history
…is order to match scikit-learn's (#426)

* changing colors of heatmap

* changelog

* update output labels

* update changelog

* actually fixing columns to make it the same order as sklearn's

* moving reversal of cols to data generation step instead
  • Loading branch information
angela97lin committed Mar 1, 2020
1 parent 5dba555 commit 6105522
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/changelog.rst
Expand Up @@ -8,6 +8,7 @@ Changelog
* Add CatBoost (gradient-boosted trees) classification and regression components and pipelines :pr:`247`
* Added Tuner abstract base class :pr:`351`
* Added n_jobs as parameter for AutoClassificationSearch and AutoRegressionSearch :pr:`403`
* Changed colors of confusion matrix to shades of blue and updated axis order to match scikit-learn's :pr:`426`
* Fixes
* Fixed ROC and confusion matrix plots not being calculated if user passed own additional_objectives :pr:`276`
* Changes
Expand Down
15 changes: 10 additions & 5 deletions evalml/automl/pipeline_search_plots.py
Expand Up @@ -186,7 +186,10 @@ def get_confusion_matrix_data(self, pipeline_id):

confusion_matrix_data = []
for fold in cv_data:
confusion_matrix_data.append(fold["all_objective_scores"]["Confusion Matrix"])
conf_mat = fold["all_objective_scores"]["Confusion Matrix"]
# reverse columns in confusion matrix to change axis order to match sklearn's
conf_mat = conf_mat.iloc[:, ::-1]
confusion_matrix_data.append(conf_mat)
return confusion_matrix_data

def generate_confusion_matrix(self, pipeline_id, fold_num=None):
Expand All @@ -205,15 +208,17 @@ def generate_confusion_matrix(self, pipeline_id, fold_num=None):

conf_mat = data[fold_num]
labels = conf_mat.columns
reversed_labels = labels[::-1]

layout = go.Layout(title={'text': 'Confusion matrix of<br>{} w/ ID={}'.format(pipeline_name, pipeline_id)},
xaxis={'title': 'Predicted Label', 'tickvals': labels},
yaxis={'title': 'True Label', 'tickvals': labels})
figure = go.Figure(data=go.Heatmap(x=labels, y=labels, z=conf_mat,
xaxis={'title': 'Predicted Label', 'type': 'category', 'tickvals': labels},
yaxis={'title': 'True Label', 'type': 'category', 'tickvals': reversed_labels})
figure = go.Figure(data=go.Heatmap(x=labels, y=reversed_labels, z=conf_mat,
hovertemplate='<b>True</b>: %{y}' +
'<br><b>Predicted</b>: %{x}' +
'<br><b>Number of times</b>: %{z}' +
'<extra></extra>'), # necessary to remove unwanted trace info
'<extra></extra>', # necessary to remove unwanted trace info
colorscale='Blues'),
layout=layout)
return figure

Expand Down

0 comments on commit 6105522

Please sign in to comment.