New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Partial dependence for multiclass #1554
Changes from all commits
9988011
580d008
b9cf0c6
0ec7c9b
d2d11c9
03fd463
39c1cc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -439,7 +439,12 @@ def partial_dependence(pipeline, X, feature, grid_resolution=100): | |
|
||
Returns: | ||
pd.DataFrame: DataFrame with averaged predictions for all points in the grid averaged | ||
over all samples of X and the values used to calculate those predictions. | ||
over all samples of X and the values used to calculate those predictions. The dataframe will | ||
contain two columns: "feature_values" (grid points at which the partial dependence was calculated) and | ||
"partial_dependence" (the partial dependence at that feature value). For classification problems, there | ||
will be a third column called "class_label" (the class label for which the partial | ||
dependence was calculated). For binary classification, the partial dependence is only calculated for the | ||
"positive" class. | ||
|
||
""" | ||
X = _convert_to_woodwork_structure(X) | ||
|
@@ -462,11 +467,21 @@ def partial_dependence(pipeline, X, feature, grid_resolution=100): | |
# Delete scikit-learn attributes that were temporarily set | ||
del pipeline._estimator_type | ||
del pipeline.feature_importances_ | ||
return pd.DataFrame({"feature_values": values[0], | ||
"partial_dependence": avg_pred[0]}) | ||
classes = None | ||
if isinstance(pipeline, evalml.pipelines.BinaryClassificationPipeline): | ||
classes = [pipeline.classes_[1]] | ||
elif isinstance(pipeline, evalml.pipelines.MulticlassClassificationPipeline): | ||
classes = pipeline.classes_ | ||
|
||
data = pd.DataFrame({"feature_values": np.tile(values[0], avg_pred.shape[0]), | ||
"partial_dependence": np.concatenate([pred for pred in avg_pred])}) | ||
if classes is not None: | ||
data['class_label'] = np.repeat(classes, len(values[0])) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit-pick: Since we're changing the output to return this new field in the DF, could be good to update this docstring too? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good suggestion! Done! |
||
|
||
return data | ||
|
||
def graph_partial_dependence(pipeline, X, feature, grid_resolution=100): | ||
|
||
def graph_partial_dependence(pipeline, X, feature, class_label=None, grid_resolution=100): | ||
"""Create an one-way partial dependence plot. | ||
|
||
Arguments: | ||
|
@@ -476,6 +491,10 @@ def graph_partial_dependence(pipeline, X, feature, grid_resolution=100): | |
feature (int, string): The target feature for which to create the partial dependence plot for. | ||
If feature is an int, it must be the index of the feature to use. | ||
If feature is a string, it must be a valid column name in X. | ||
class_label (string, optional): Name of class to plot for multiclass problems. If None, will plot | ||
the partial dependence for each class. This argument does not change behavior for regression or binary | ||
classification pipelines. For binary classification, the partial dependence for the positive label will | ||
always be displayed. Defaults to None. | ||
|
||
Returns: | ||
pd.DataFrame: pd.DataFrame with averaged predictions for all points in the grid averaged | ||
|
@@ -485,19 +504,47 @@ def graph_partial_dependence(pipeline, X, feature, grid_resolution=100): | |
_go = import_or_raise("plotly.graph_objects", error_msg="Cannot find dependency plotly.graph_objects") | ||
if jupyter_check(): | ||
import_or_raise("ipywidgets", warning=True) | ||
if isinstance(pipeline, evalml.pipelines.MulticlassClassificationPipeline) and class_label is not None: | ||
if class_label not in pipeline.classes_: | ||
msg = f"Class {class_label} is not one of the classes the pipeline was fit on: {', '.join(list(pipeline.classes_))}" | ||
raise ValueError(msg) | ||
|
||
part_dep = partial_dependence(pipeline, X, feature=feature, grid_resolution=grid_resolution) | ||
feature_name = str(feature) | ||
title = f"Partial Dependence of '{feature_name}'" | ||
layout = _go.Layout(title={'text': title}, | ||
xaxis={'title': f'{feature_name}', 'range': _calculate_axis_range(part_dep['feature_values'])}, | ||
yaxis={'title': 'Partial Dependence', 'range': _calculate_axis_range(part_dep['partial_dependence'])}) | ||
data = [] | ||
data.append(_go.Scatter(x=part_dep['feature_values'], | ||
xaxis={'title': f'{feature_name}'}, | ||
yaxis={'title': 'Partial Dependence'}, | ||
showlegend=False) | ||
if isinstance(pipeline, evalml.pipelines.MulticlassClassificationPipeline): | ||
class_labels = [class_label] if class_label is not None else pipeline.classes_ | ||
_subplots = import_or_raise("plotly.subplots", error_msg="Cannot find dependency plotly.graph_objects") | ||
|
||
# If the user passes in a value for class_label, we want to create a 1 x 1 subplot or else there would | ||
# be an empty column in the plot and it would look awkward | ||
rows, cols = ((len(class_labels) + 1) // 2, 2) if len(class_labels) > 1 else (1, len(class_labels)) | ||
|
||
# Don't specify share_xaxis and share_yaxis so that we get tickmarks in each subplot | ||
fig = _subplots.make_subplots(rows=rows, cols=cols, subplot_titles=class_labels) | ||
for i, label in enumerate(class_labels): | ||
|
||
# Plotly trace indexing begins at 1 so we add 1 to i | ||
fig.add_trace(_go.Scatter(x=part_dep.loc[part_dep.class_label == label, 'feature_values'], | ||
y=part_dep.loc[part_dep.class_label == label, 'partial_dependence'], | ||
line=dict(width=3), | ||
name=label), | ||
row=(i + 2) // 2, col=(i % 2) + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice! |
||
fig.update_layout(layout) | ||
fig.update_xaxes(title=f'{feature_name}', range=_calculate_axis_range(part_dep['feature_values'])) | ||
fig.update_yaxes(range=_calculate_axis_range(part_dep['partial_dependence'])) | ||
else: | ||
trace = _go.Scatter(x=part_dep['feature_values'], | ||
y=part_dep['partial_dependence'], | ||
name='Partial Dependence', | ||
line=dict(width=3))) | ||
return _go.Figure(layout=layout, data=data) | ||
line=dict(width=3)) | ||
fig = _go.Figure(layout=layout, data=[trace]) | ||
|
||
return fig | ||
|
||
|
||
def _calculate_axis_range(arr): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat!