Skip to content

Commit

Permalink
Laying out the partial dependence plots in multiple rows for multiclass.
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Dec 15, 2020
1 parent c07a548 commit 35be5f4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
19 changes: 13 additions & 6 deletions evalml/model_understanding/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,22 +508,29 @@ def graph_partial_dependence(pipeline, X, feature, class_label=None, grid_resolu
feature_name = str(feature)
title = f"Partial Dependence of '{feature_name}'"
layout = _go.Layout(title={'text': title},
xaxis={'title': f'{feature_name}', 'range': _calculate_axis_range(part_dep['feature_values'])},
yaxis={'title': 'Partial Dependence', 'range': _calculate_axis_range(part_dep['partial_dependence'])},
xaxis={'title': f'{feature_name}'},
yaxis={'title': 'Partial Dependence'},
showlegend=False)
if isinstance(pipeline, evalml.pipelines.MulticlassClassificationPipeline):
class_labels = [class_label] if class_label is not None else pipeline.classes_
_subplots = import_or_raise("plotly.subplots", error_msg="Cannot find dependency plotly.graph_objects")
fig = _subplots.make_subplots(rows=1, cols=len(class_labels), subplot_titles=class_labels,
shared_xaxes=True, shared_yaxes=True)

# If the user passes in a value for class_label, we want to create a 1 x 1 subplot or else there would
# be an empty column in the plot and it would look awkward
rows, cols = ((len(class_labels) + 1) // 2, 2) if len(class_labels) > 1 else (1, len(class_labels))

# Don't specify share_xaxis and share_yaxis so that we get tickmarks in each subplot
fig = _subplots.make_subplots(rows=rows, cols=cols, subplot_titles=class_labels)
for i, label in enumerate(class_labels):

# Plotly trace indexing begins at 1 so we add 1 to i
fig.add_trace(_go.Scatter(x=part_dep.loc[part_dep.class_label == label, 'feature_values'],
y=part_dep.loc[part_dep.class_label == label, 'partial_dependence'],
line=dict(width=3)),
row=1, col=i + 1)
row=(i+2)//2, col=(i % 2) + 1)
fig.update_layout(layout)
fig.update_xaxes(title=f'{feature_name}')
fig.update_xaxes(title=f'{feature_name}', range=_calculate_axis_range(part_dep['feature_values']))
fig.update_yaxes(range=_calculate_axis_range(part_dep['partial_dependence']))
else:
trace = _go.Scatter(x=part_dep['feature_values'],
y=part_dep['partial_dependence'],
Expand Down
7 changes: 6 additions & 1 deletion evalml/tests/model_understanding_tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
partial_dependence,
precision_recall_curve,
roc_curve,
visualize_decision_tree
visualize_decision_tree,
)
from evalml.objectives import CostBenefitMatrix
from evalml.pipelines import (
Expand Down Expand Up @@ -923,6 +923,11 @@ def test_graph_partial_dependence_multiclass(logistic_regression_multiclass_pipe
assert len(data['x']) == 20
assert len(data['y']) == 20

# Check that all the subplots axes have the same range
for suplot_1_axis, suplot_2_axis in [('axis2', 'axis3'), ('axis2', 'axis4'), ('axis3', 'axis4')]:
for axis_type in ['x', 'y']:
assert fig_dict['layout'][axis_type + suplot_1_axis]['range'] == fig_dict['layout'][axis_type + suplot_2_axis]['range']

fig = graph_partial_dependence(pipeline, X, feature='magnesium', class_label='class_1', grid_resolution=20)

assert isinstance(fig, go.Figure)
Expand Down

0 comments on commit 35be5f4

Please sign in to comment.