Added support for multiclass classification to roc_curve (#1164)
Conversation
Codecov Report

@@            Coverage Diff            @@
##             main    #1164    +/-   ##
========================================
  Coverage   99.92%   99.92%
========================================
  Files         196      196
  Lines       11729    11780    +51
========================================
+ Hits        11720    11771    +51
  Misses          9        9
@christopherbunn This looks good to me! My one comment is that I think it would be nice if we could give users the option of passing in the predict_proba dataframe rather than forcing them to pick out the column for the positive class for binary problems.
evalml/model_understanding/graphs.py
(outdated diff)

  Arguments:
      y_true (pd.Series or np.array): true labels.
-     y_pred_proba (pd.Series or np.array): predictions from a classifier, before thresholding has been applied. Note that 1 dimensional input is expected.
+     y_pred_proba (pd.Series or np.array): predictions from a classifier, before thresholding has been applied.
Nit-pick: I think we need to update the docstring because y_pred_proba can now be a dataframe
if isinstance(y_pred_proba, (pd.Series, pd.DataFrame)):
    y_pred_proba = y_pred_proba.to_numpy()

if y_pred_proba.ndim == 1:
Maybe we should also check for the binary case like so:

if y_pred_proba.shape[1] == 2:
    y_pred_proba = y_pred_proba[:, 1].reshape(-1, 1)
My thought for doing this is that it would be nice if the API for binary and multiclass classification were the same. As it stands now, a user has to manually pick out the positive-class column from the predict_proba dataframe for binary problems, but for multiclass they pass in the entire dataframe. To be clear, this wouldn't be a breaking change, because a user could still pass the probabilities for the dominant class and the y_pred_proba.ndim == 1 case would catch that.
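The shape-normalization the thread describes (accept a 1-D array of positive-class probabilities, a two-column binary predict_proba table, or an n-column multiclass table) can be sketched as a standalone helper; the function name below is hypothetical and not part of evalml:

```python
import numpy as np
import pandas as pd

def normalize_pred_proba(y_pred_proba):
    """Normalize classifier probabilities to a 2-D numpy array.

    Hypothetical helper illustrating the review thread: binary inputs
    collapse to a single positive-class column so the binary and
    multiclass call signatures look the same to the caller.
    """
    if isinstance(y_pred_proba, (pd.Series, pd.DataFrame)):
        y_pred_proba = y_pred_proba.to_numpy()
    if y_pred_proba.ndim == 1:
        # Already just the positive-class probabilities.
        return y_pred_proba.reshape(-1, 1)
    if y_pred_proba.shape[1] == 2:
        # Binary predict_proba: keep only the positive-class column.
        return y_pred_proba[:, 1].reshape(-1, 1)
    # Multiclass: keep one column per class.
    return y_pred_proba

# Both binary forms normalize to the same single column:
probs = pd.DataFrame({0: [0.8, 0.3], 1: [0.2, 0.7]})
assert np.array_equal(normalize_pred_proba(probs),
                      normalize_pred_proba(probs[1]))
```

This is why the reviewer notes the change is non-breaking: a 1-D positive-class input still takes the `ndim == 1` path.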
Good catch! I intended for the binary and multiclass API to be the same, but I don't think I caught the fact that the binary case for predict_proba was calculated incorrectly in my original implementation. I've added your code snippet in.
Force-pushed 87c881a to e1e904e, then e1e904e to 9e4b66d.
Moved LabelBinarization code from graph_roc_curve to roc_curve to enable support for multiclass classification. Also updated the API docs and model understanding section to include a multiclass example.

There's currently a breaking API change: data from roc_curve will now be returned as a list of dicts, with each class represented as a dict containing the corresponding ROC data. Previously, we returned a single dict with ROC data for a binary class.

Resolves #1063
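The new return shape described above (one ROC dict per class, built by binarizing the labels) can be approximated with scikit-learn; the dict keys below ("fpr", "tpr", "thresholds") are illustrative assumptions, not necessarily evalml's exact field names:

```python
import numpy as np
from sklearn.metrics import roc_curve
from sklearn.preprocessing import label_binarize

def multiclass_roc(y_true, y_pred_proba):
    """Return a list with one ROC dict per class.

    Sketch of the PR's one-vs-rest approach: binarize the labels,
    then compute an ROC curve per class column.
    """
    classes = np.unique(y_true)
    y_bin = label_binarize(y_true, classes=classes)
    if y_bin.shape[1] == 1:
        # label_binarize collapses binary labels to one column;
        # expand back so indexing per class still works.
        y_bin = np.hstack([1 - y_bin, y_bin])
    results = []
    for i, cls in enumerate(classes):
        fpr, tpr, thresholds = roc_curve(y_bin[:, i], y_pred_proba[:, i])
        results.append({"class": cls, "fpr": fpr, "tpr": tpr,
                        "thresholds": thresholds})
    return results

y_true = np.array([0, 1, 2, 1, 0, 2])
proba = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.2, 0.6],
                  [0.3, 0.5, 0.2],
                  [0.6, 0.3, 0.1],
                  [0.1, 0.3, 0.6]])
curves = multiclass_roc(y_true, proba)
assert len(curves) == 3  # one dict per class
```

A caller migrating from the old binary API would now index into the list (e.g. take the entry for the positive class) instead of reading a single dict.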