Skip to content

Commit

Permalink
Allow to skip plot_show() to be able to get the current axis (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum committed Mar 18, 2021
1 parent 2bf3a26 commit a2b13d5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 6 deletions.
12 changes: 10 additions & 2 deletions causalimpact/main.py
Expand Up @@ -223,7 +223,8 @@ def __init__(
def plot(
self,
panels: List[str] = ['original', 'pointwise', 'cumulative'],
figsize: Tuple[int] = (10, 7)
figsize: Tuple[int] = (10, 7),
show: bool = True
) -> None:
"""
Plots the graphic with results associated to Causal Impact.
Expand All @@ -238,9 +239,16 @@ def plot(
data and its forecasts.
figsize: Tuple[int]
Sets the width and height of the figure to plot.
show: bool
If `True` then plots the figure by running `plt.plot()`.
If `False` then nothing will be plotted which allows for accessing and
manipulating the figure and axis of the plot, i.e., the figure can be saved
and the styling can be modified. To get the axis, just run:
`import matplotlib.pyplot as plt; ax = plt.gca()` or the figure:
`fig = plt.gcf()`. Defaults to `True`.
"""
plotter.plot(self.inferences, self.pre_data, self.post_data, panels=panels,
figsize=figsize)
figsize=figsize, show=show)

def summary(self, output: str = 'summary', digits: int = 2) -> str:
"""
Expand Down
13 changes: 10 additions & 3 deletions causalimpact/plot.py
Expand Up @@ -26,7 +26,8 @@ def plot(
pre_data: pd.DataFrame,
post_data: pd.DataFrame,
panels=['original', 'pointwise', 'cumulative'],
figsize=(10, 7)
figsize=(10, 7),
show=True
) -> None:
"""Plots inferences results related to causal impact analysis.
Expand All @@ -36,7 +37,12 @@ def plot(
Indicates which plot should be considered in the graphics.
figsize: tuple.
Changes the size of the graphics plotted.
show: bool.
If true, runs plt.show(), i.e., displays the figure.
If false, it gives acess to the axis, i.e., the figure can be saved
and the style of the plot can be modified by getting the axis with
`ax = plt.gca()` or the figure with `fig = plt.gcf()`.
Defaults to True.
Raises
------
RuntimeError: if inferences were not computed yet.
Expand Down Expand Up @@ -132,7 +138,8 @@ def plot(
ax.axhline(y=0, color='gray', linestyle='--')
ax.legend()
ax.grid(True, color='gainsboro')
plt.show()
if show:
plt.show()


def get_plotter(): # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion tests/test_main.py
Expand Up @@ -199,7 +199,7 @@ def test_plotter(monkeypatch, rand_data, pre_int_period, post_int_period):
ci.plot()
plotter_mock.plot.assert_called_with('inferences', 'pre_data', 'post_data',
panels=['original', 'pointwise', 'cumulative'],
figsize=(10, 7))
figsize=(10, 7), show=True)


def test_summarizer(monkeypatch, rand_data, pre_int_period, post_int_period):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_plot.py
Expand Up @@ -681,3 +681,43 @@ def test_plot_raises_wrong_input_panel(rand_data, pre_int_period, post_int_perio
'"test" is not a valid panel. Valid panels are: '
'"original", "pointwise", "cumulative".'
)


def test_plot_original_panel_gap_data_show_is_false(
rand_data, pre_int_gap_period, post_int_gap_period, inferences, monkeypatch
):
plot_mock = mock.Mock()
pre_data = rand_data.loc[pre_int_gap_period[0]: pre_int_gap_period[1]]
post_data = rand_data.loc[post_int_gap_period[0]: post_int_gap_period[1]]
pre_post_index = pre_data.index.union(post_data.index)
monkeypatch.setattr("causalimpact.plot.get_plotter", plot_mock)
plotter.plot(inferences, pre_data, post_data, panels=["original"], show=False)

plot_mock.assert_called_once()
plot_mock.return_value.figure.assert_called_with(figsize=(10, 7))
plot_mock.return_value.subplot.assert_any_call(1, 1, 1)
ax_mock = plot_mock.return_value.subplot.return_value
ax_args = ax_mock.plot.call_args_list

assert_array_equal(pre_post_index, ax_args[0][0][0])
assert_array_equal(
pd.concat([pre_data.iloc[:, 0], post_data.iloc[:, 0]]), ax_args[0][0][1]
)
assert ax_args[0][0][2] == "k"
assert ax_args[0][1] == {"label": "y"}
assert_array_equal(pre_post_index[1:], ax_args[1][0][0])
assert_array_equal(inferences["complete_preds_means"].iloc[1:], ax_args[1][0][1])
assert ax_args[1][1] == {"color": "orangered", "ls": "dashed", "label": "Predicted"}

ax_mock.axvline.assert_called_with(pre_int_gap_period[1], c="gray", linestyle="--")

ax_args = ax_mock.fill_between.call_args_list[0]
assert_array_equal(ax_args[0][0], pre_post_index[1:])
assert_array_equal(ax_args[0][1], inferences["complete_preds_lower"].iloc[1:])
assert_array_equal(ax_args[0][2], inferences["complete_preds_upper"].iloc[1:])
assert ax_args[1] == {"color": (1.0, 0.4981, 0.0549), "alpha": 0.4}

ax_mock.grid.assert_called_with(True, color="gainsboro")
ax_mock.legend.assert_called()
# If `show == False` then `plt.show()` should not have been called
plot_mock.return_value.show.assert_not_called()

0 comments on commit a2b13d5

Please sign in to comment.