-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Partial dependence for multiclass #1554
Conversation
Codecov Report
@@ Coverage Diff @@
## main #1554 +/- ##
=========================================
+ Coverage 100.0% 100.0% +0.1%
=========================================
Files 236 236
Lines 16877 16933 +56
=========================================
+ Hits 16869 16925 +56
Misses 8 8
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@freddyaboulton Everything looks great!
evalml/model_understanding/graphs.py
Outdated
fig.add_trace(_go.Scatter(x=part_dep.loc[part_dep.class_label == label, 'feature_values'], | ||
y=part_dep.loc[part_dep.class_label == label, 'partial_dependence'], | ||
line=dict(width=3)), | ||
row=1, col=i + 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@freddyaboulton Great! Also if the dataset has a larger number of labels, it might be difficult for the user to see all the partial dependency plots in one row.
Maybe:
_subplots.make_subplots(rows=(len(class_labels)+1) // 2, cols=2 ...etc)
and then in line 523:
row=(i+2) // 2
and col=(i%2) + 1
.
Not really a big deal either way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good! I think the only thing missing was that we should only use two columns in the case where class_label=None
or else there would be an empty second column in the plot.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent!
|
||
data = pd.DataFrame({"feature_values": np.tile(values[0], avg_pred.shape[0]), | ||
"partial_dependence": np.concatenate([pred for pred in avg_pred])}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat!
f029bcb
to
35be5f4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left two nit-picky things, but otherwise LGTM!! This is great stuff 😁
data = pd.DataFrame({"feature_values": np.tile(values[0], avg_pred.shape[0]), | ||
"partial_dependence": np.concatenate([pred for pred in avg_pred])}) | ||
if classes is not None: | ||
data['class_label'] = np.repeat(classes, len(values[0])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit-pick: Since we're changing the output to return this new field in the DF, could be good to update this docstring too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion! Done!
evalml/model_understanding/graphs.py
Outdated
@@ -476,6 +486,10 @@ def graph_partial_dependence(pipeline, X, feature, grid_resolution=100): | |||
feature (int, string): The target feature for which to create the partial dependence plot for. | |||
If feature is an int, it must be the index of the feature to use. | |||
If feature is a string, it must be a valid column name in X. | |||
class_label (string, None): Name of class to plot for multiclass problems. If None, will plot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively:
class_label (string, optional): Name of class to plot for multiclass problems. If None, will plot...; Defaults to None.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
e17d49e
to
03fd463
Compare
@angela97lin good catch! I just pushed this up and updated the tests for multiclass/not multiclass. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! One thing I think would be cool to have is to allow users to pass in a list of classes to create these plots for in multiclass scenarios. For instance, I, as a user, decided to create a multiclass pipeline with 20 possible target classes, I might want to plot partial dependencies for a subset of the classes only, rather than for 1 or all. Certainly not blocking, but wanted to bring that suggestion up for possible discussion.
y=part_dep.loc[part_dep.class_label == label, 'partial_dependence'], | ||
line=dict(width=3), | ||
name=label), | ||
row=(i + 2) // 2, col=(i % 2) + 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Great suggestion @bchen1116 ! I'm on board but let's continue the discussion to #1565 since this issue/PR tracks returning the partial dependence for all the classes as opposed to just the first one in multiclass problems. |
Pull Request Description
Fixes #1404
Example on wine dataset
After creating the pull request: in order to pass the release_notes_updated check you will need to update the "Future Release" section of
docs/source/release_notes.rst
to include this pull request by adding :pr:123
.