Skip to content

Commit

Permalink
Plot for time series regression (#1483)
Browse files Browse the repository at this point in the history
* Adding graph_prediction_vs_target_over_time function to model_understanding.

* Adding PR 1483 to release notes.

* graph_prediction_vs_target_over_time only supports time series regression pipelines

* Skip test_graph_prediction_vs_target_over_time_value_error if plotly not installed.

* Adding graph_prediction_vs_target_over_time to api ref.

* Adding graph_prediction_vs_target_over_time to model_understanding __init__

* Changing y-axis label in graph_prediction_vs_target_over_time.

* Using actual instead of target in function name. Adding get_prediction_vs_actual_over_time_data

* Adding PR 1483 to future releases section of the release notes.

* Fixing typo in error_msg for graph_prediction_vs_actual_over_time
  • Loading branch information
freddyaboulton committed Dec 2, 2020
1 parent e3acef6 commit 3d69967
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/api_reference.rst
Expand Up @@ -268,6 +268,8 @@ Graph Utils
binary_objective_vs_threshold
graph_binary_objective_vs_threshold
graph_prediction_vs_actual
get_prediction_vs_actual_over_time_data
graph_prediction_vs_actual_over_time

.. currentmodule:: evalml.model_understanding.prediction_explanations

Expand Down
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Expand Up @@ -3,6 +3,7 @@ Release Notes

**Future Releases**
* Enhancements
* Added ``graph_prediction_vs_actual_over_time`` and ``get_prediction_vs_actual_over_time_data`` to the model understanding module for time series problems :pr:`1483`
* Fixes
* Fix Windows CI jobs: install ``numba`` via conda, required for ``shap`` :pr:`1490`
* Changes
Expand Down
4 changes: 3 additions & 1 deletion evalml/model_understanding/__init__.py
Expand Up @@ -12,6 +12,8 @@
graph_binary_objective_vs_threshold,
partial_dependence,
graph_partial_dependence,
graph_prediction_vs_actual
graph_prediction_vs_actual,
get_prediction_vs_actual_over_time_data,
graph_prediction_vs_actual_over_time,
)
from .prediction_explanations import explain_prediction, explain_predictions_best_worst, explain_predictions
47 changes: 47 additions & 0 deletions evalml/model_understanding/graphs.py
Expand Up @@ -555,3 +555,50 @@ def graph_prediction_vs_actual(y_true, y_pred, outlier_threshold=None):
marker=_go.scatter.Marker(color=color),
name=name))
return _go.Figure(layout=layout, data=data)


def _drop_index(values):
    """Return ``values`` with its index labels discarded if it is a pd.Series, else unchanged."""
    if isinstance(values, pd.Series):
        return values.reset_index(drop=True)
    return values


def get_prediction_vs_actual_over_time_data(pipeline, X, y, dates):
    """Get the data needed for the prediction_vs_actual_over_time plot.

    Arguments:
        pipeline (PipelineBase): Fitted time series regression pipeline.
        X (pd.DataFrame): Features used to generate new predictions.
        y (pd.Series): Target values to compare predictions against.
        dates (pd.Series): Dates corresponding to the target values and predictions.

    Returns:
        pd.DataFrame: One row per observation with the columns "dates", "target"
            and "prediction". Rows are matched by position, so the result uses a
            default integer index.
    """
    prediction = pipeline.predict(X, y)
    # pd.DataFrame aligns Series inputs on their index labels. If dates, y and
    # the predictions carry different indexes, that alignment pads the frame
    # with NaN rows instead of pairing the i-th date with the i-th target and
    # prediction. Dropping the labels first makes the pairing positional.
    return pd.DataFrame({"dates": _drop_index(dates),
                         "target": _drop_index(y),
                         "prediction": _drop_index(prediction)})


def graph_prediction_vs_actual_over_time(pipeline, X, y, dates):
    """Plot the target values and predictions against time on the x-axis.

    Arguments:
        pipeline (PipelineBase): Fitted time series regression pipeline.
        X (pd.DataFrame): Features used to generate new predictions.
        y (pd.Series): Target values to compare predictions against.
        dates (pd.Series): Values plotted along the x-axis for both traces.

    Returns:
        plotly.Figure showing the prediction vs actual over time.

    Raises:
        ValueError: If the pipeline's problem type is not time series regression.
    """
    _go = import_or_raise("plotly.graph_objects", error_msg="Cannot find dependency plotly.graph_objects")

    if pipeline.problem_type != ProblemTypes.TIME_SERIES_REGRESSION:
        message = ("graph_prediction_vs_actual_over_time only supports time series regression pipelines! "
                   f"Received {pipeline.problem_type!s}.")
        raise ValueError(message)

    plot_df = get_prediction_vs_actual_over_time_data(pipeline, X, y, dates)

    # (DataFrame column, legend name, line color) for each trace, in draw order.
    trace_specs = (
        ("target", "Target", '#1f77b4'),
        ("prediction", 'Prediction', '#d62728'),
    )
    traces = [
        _go.Scatter(x=plot_df["dates"], y=plot_df[column], mode='lines+markers', name=label,
                    line=dict(color=color))
        for column, label, color in trace_specs
    ]

    # No explicit tick format is set: plotly picks the best date format.
    layout = _go.Layout(title={'text': "Prediction vs Target over time"},
                        xaxis={'title': 'Time'},
                        yaxis={'title': 'Target Values and Predictions'})

    return _go.Figure(data=traces, layout=layout)
42 changes: 42 additions & 0 deletions evalml/tests/model_understanding_tests/test_graphs.py
Expand Up @@ -21,6 +21,7 @@
graph_permutation_importance,
graph_precision_recall_curve,
graph_prediction_vs_actual,
graph_prediction_vs_actual_over_time,
graph_roc_curve,
normalize_confusion_matrix,
partial_dependence,
Expand Down Expand Up @@ -954,3 +955,44 @@ def test_graph_prediction_vs_actual():
assert len(fig_dict['data'][2]['x']) == 2
assert len(fig_dict['data'][2]['y']) == 2
assert fig_dict['data'][2]['name'] == ">= outlier_threshold"


def test_graph_prediction_vs_actual_over_time():
    """Check the figure's layout text and its two 61-point traces for a mocked time series pipeline."""
    plotly_go = pytest.importorskip('plotly.graph_objects',
                                    reason='Skipping plotting test because plotly not installed')

    class MockPipeline:
        problem_type = ProblemTypes.TIME_SERIES_REGRESSION

        def predict(self, X, y):
            return y + 10

    # 2020-03-01 through 2020-04-30 inclusive spans 61 days, matching len(y).
    dates = pd.date_range("2020-03-01", "2020-04-30")
    y = np.arange(61)

    # The mock ignores X, so an empty frame is enough.
    fig = graph_prediction_vs_actual_over_time(MockPipeline(), X=pd.DataFrame(), y=y, dates=dates)
    assert isinstance(fig, plotly_go.Figure)

    fig_dict = fig.to_dict()
    layout = fig_dict['layout']
    assert layout['title']['text'] == 'Prediction vs Target over time'
    assert layout['xaxis']['title']['text'] == 'Time'
    assert layout['yaxis']['title']['text'] == 'Target Values and Predictions'

    traces = fig_dict['data']
    assert len(traces) == 2
    for trace, expected_color in zip(traces, ('#1f77b4', '#d62728')):
        assert trace['line']['color'] == expected_color
        assert len(trace['x']) == 61
        assert len(trace['y']) == 61


def test_graph_prediction_vs_actual_over_time_value_error():
    """Check that a pipeline whose problem type is not time series regression raises ValueError."""
    pytest.importorskip('plotly.graph_objects', reason='Skipping plotting test because plotly not installed')

    class NotTSPipeline:
        problem_type = ProblemTypes.REGRESSION

    expected_message = ("graph_prediction_vs_actual_over_time only supports time series regression pipelines! "
                        "Received regression.")
    non_ts_pipeline = NotTSPipeline()

    with pytest.raises(ValueError, match=expected_message):
        graph_prediction_vs_actual_over_time(non_ts_pipeline, None, None, None)

0 comments on commit 3d69967

Please sign in to comment.