Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multiclass classification to roc_curve #1164

Merged
merged 8 commits into from
Sep 16, 2020

Conversation

christopherbunn
Copy link
Contributor

@christopherbunn christopherbunn commented Sep 14, 2020

Moved LabelBinarization code from graph_roc_curve to roc_curve to enable support for multiclass classification. Also updated API docs and model understanding section to include multiclass example.

There's currently a breaking API change where data from roc_curve will now be returned as a list of dicts (with each class represented as a dict with corresponding ROC data). Previously, we were returning a dict with ROC data for a binary class.

Resolves #1063

@codecov
Copy link

codecov bot commented Sep 14, 2020

Codecov Report

Merging #1164 into main will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1164   +/-   ##
=======================================
  Coverage   99.92%   99.92%           
=======================================
  Files         196      196           
  Lines       11729    11780   +51     
=======================================
+ Hits        11720    11771   +51     
  Misses          9        9           
Impacted Files Coverage Δ
evalml/model_understanding/graphs.py 100.00% <100.00%> (ø)
...lml/tests/model_understanding_tests/test_graphs.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ccc7e05...db65ea9. Read the comment docs.

@christopherbunn christopherbunn marked this pull request as ready for review September 14, 2020 17:32
Copy link
Contributor

@freddyaboulton freddyaboulton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@christopherbunn This looks good to me! My one comment is that I think it would be nice if we could give users the option of passing in the predict_proba dataframe rather than forcing them to pick out the column for the positive class for binary problems.


Arguments:
y_true (pd.Series or np.array): true labels.
y_pred_proba (pd.Series or np.array): predictions from a classifier, before thresholding has been applied. Note that 1 dimensional input is expected.

y_pred_proba (pd.Series or np.array): predictions from a classifier, before thresholding has been applied.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit-pick: I think we need to update the docstring because y_pred_proba can now be a dataframe

if isinstance(y_pred_proba, (pd.Series, pd.DataFrame)):
y_pred_proba = y_pred_proba.to_numpy()

if y_pred_proba.ndim == 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should also check for the binary case case like so:

if y_pred_proba.shape[1] == 2:
    y_pred_proba = y_pred_proba.iloc[:, 1].reshape(-1, 1)

My thought for doing this is that it would be nice if the api for binary and multiclass classification would be the same. As it stands now, a user has to manually pick out the column for the positive class from the predict_proba dataframe but for multiclass they pass in the entire dataframe.

To be clear, this wouldn't be a breaking change because a user could still pass the probabilities for the dominant class and the y_pred_proba.ndim == 1 case would catch that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I intended for the binary and multiclass API to be the same, but I don't think I caught the fact that the binary case for predict_proba was calculated incorrectly in my original implementation. I've added your code snippet in.

@christopherbunn christopherbunn merged commit 9d1303c into main Sep 16, 2020
@christopherbunn christopherbunn deleted the 1063_roc_multiclass branch September 16, 2020 21:46
This was referenced Sep 17, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ROC curve for multiclass classification
2 participants