# Copyright WillianFuks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main class definition for running Causal Impact analysis.
"""
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import tensorflow_probability as tfp
import causalimpact.data as cidata
import causalimpact.inferences as inferrer
import causalimpact.model as cimodel
import causalimpact.plot as plotter
import causalimpact.summary as summarizer
from causalimpact.misc import maybe_unstandardize
class CausalImpact():
"""
Main class used to run the Causal Impact algorithm implemented by Google as
described in the official
[paper](https://google.github.io/CausalImpact/CausalImpact.html).
The algorithm fits a structural state space model to the observed data `y` and
uses Bayesian inference to find the posterior P(z|y), where `z` represents the
chosen model parameters (such as level, trend, season, and so on).
In this package, the fitting method can be either 'Hamiltonian Monte Carlo' ('hmc'
for short, more accurate but slower) or 'Variational Inference' ('vi', faster but
less accurate), both available in TensorFlow Probability.
Args
----
data: Union[np.array, pd.DataFrame]
The first column must contain the `y` values whose future values will be
forecasted, while the remaining columns contain the covariates `X` used in the
linear regression component of the model (if no linear regression is intended,
simply don't supply `X`).
If `data` is a pandas DataFrame, its index can be defined either as a
`RangeIndex`, `Index` or `DatetimeIndex`.
In case of the second, a conversion to datetime type is automatically
attempted; in case of failure, the original index is kept untouched.
pre_period: Union[List[int], List[str], List[pd.Timestamp]]
A list of size two containing either `int`, `str` or `pd.Timestamp` values
referencing the beginning and end of the range used as pre-intervention data.
As an example, valid inputs are:
- [0, 30]
- ['20200101', '20200130']
- [pd.to_datetime('20200101'), pd.to_datetime('20200130')]
- [pd.Timestamp('20200101'), pd.Timestamp('20200130')]
The latter formats can be used only if the input `data` is a pandas DataFrame
whose index is based on datetime values.
post_period: Union[List[int], List[str], List[pd.Timestamp]]
The same as `pre_period` but references where the post-intervention
data begins and ends. This is the data that will be compared against the
counter-factual forecasts.
model: Optional[tfp.sts.StructuralTimeSeries]
If `None`, a default `tfp.sts.LocalLevel` model is built internally;
otherwise the given `model` is used for fitting and forecasting.
model_args: Dict[str, Any]
Sets general variables for building and running the state space model. Possible
values are:
standardize: bool
If `True`, standardizes data to have zero mean and unitary standard
deviation.
prior_level_sd: Optional[float]
Prior value for the local level standard deviation. If `None` then an
automatic optimization of the local level is performed. This is
recommended when there's uncertainty about what prior value is
appropriate for the data.
In general, if the covariates are expected to be good descriptors of the
observed response then this value can be low (such as the default of
0.01). In cases when the linear regression is not quite expected to fully
explain the observed data, the value 0.1 can be used.
fit_method: str
Which method to use for the Bayesian algorithm. Can be either 'vi'
(default) or 'hmc' (more precision but much slower).
nseasons: int
Specifies the duration of the period of the seasonal component; if input
data is specified in terms of days, then choosing nseasons=7 adds a weekly
seasonal effect.
season_duration: int
Specifies how many data points each value in the season spans over. A good
example for understanding this argument is hourly input data: to model a
weekly season on such data, specify `nseasons=7` and `season_duration=24`,
which means each value building the season component is repeated for 24 data
points. The default value of 1 means each seasonal value spans just one point
(which in practice changes nothing). If this value is specified and bigger
than 1 then `nseasons` must be specified and bigger than 1 as well.
alpha: float
A float ranging between 0 and 1 indicating the significance level used when
statistically testing for signal presence in the post-intervention period.
Returns
-------
Causal Impact object with inferences, summary and plotting functionalities.
Examples
--------
Input data can be a `numpy.array`:
```python
import numpy as np
data = np.random.rand(100, 2)
pre_period = [0, 69]
post_period = [70, 99]
ci = CausalImpact(data, pre_period, post_period)
print(ci.summary())
print(ci.summary('report'))
ci.plot()
```
Using pandas DataFrames:
```python
import pandas as pd
df = pd.read_csv('tests/fixtures/arma_data.csv')
pre_period = [0, 69]
post_period = [70, 99]
ci = CausalImpact(df, pre_period, post_period)
print(ci.summary())
```
Using pandas DataFrames with pandas timestamps:
```python
df = pd.read_csv('tests/fixtures/arma_data.csv')
df = df.set_index(pd.date_range(start='20200101', periods=len(df)))
pre_period = [pd.to_datetime('20200101'), pd.to_datetime('20200311')]
post_period = [pd.to_datetime('20200312'), pd.to_datetime('20200410')]
ci = CausalImpact(df, pre_period, post_period)
print(ci.summary())
```
Using a weekly seasonal component on daily data:
```python
df = pd.read_csv('tests/fixtures/arma_data.csv')
df = df.set_index(pd.date_range(start='20200101', periods=len(df)))
pre_period = ['20200101', '20200311']
post_period = ['20200312', '20200410']
ci = CausalImpact(df, pre_period, post_period, model_args={'nseasons': 7})
print(ci.summary())
```
Using a weekly seasonal component on hourly data:
```python
df = pd.read_csv('tests/fixtures/arma_data.csv')
df = df.set_index(pd.date_range(start='20200101', periods=len(df), freq='H'))
pre_period = ['20200101 00:00:00', '20200311 23:00:00']
post_period = ['20200312 00:00:00', '20200410 23:00:00']
ci = CausalImpact(df, pre_period, post_period, model_args={'nseasons': 7,
'season_duration': 24})
print(ci.summary())
```
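Passing other `model_args` values, such as the fitting method and the local level
prior (an illustrative sketch; the specific values below are arbitrary choices,
not recommendations):
```python
df = pd.read_csv('tests/fixtures/arma_data.csv')
pre_period = [0, 69]
post_period = [70, 99]
# 'hmc' is slower but more precise than the default 'vi';
# prior_level_sd=0.1 loosens the prior on the local level component
ci = CausalImpact(df, pre_period, post_period,
                  model_args={'fit_method': 'hmc', 'prior_level_sd': 0.1})
print(ci.summary())
```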
Using a customized model:
```python
import tensorflow_probability as tfp
data = tfp.sts.regularize_series(data).astype('float32')
pre_period = ['20200101', '20200401']
post_period = ['20200402', '20200501']
obs_series = data.loc[:pre_period[1], 0]
local_linear = tfp.sts.LocalLinearTrend(observed_time_series=obs_series)
seasonal = tfp.sts.Seasonal(num_seasons=7, observed_time_series=obs_series)
model = tfp.sts.Sum([local_linear, seasonal], observed_time_series=obs_series)
ci = CausalImpact(data, pre_period, post_period, model=model,
model_args={'standardize': False})
print(ci.summary())
```
Notice that for custom models no assumptions are made about the input data used
to build the model. This can lead to errors if `standardize` is set to `True`,
because the model was built with the original data while tfcausalimpact will
internally standardize it, so the data no longer matches what the model was built
on. To avoid that, all data processing must be done before calling CausalImpact.
For instance:
```python
import tensorflow_probability as tfp
from causalimpact.misc import standardize
data = tfp.sts.regularize_series(data).astype('float32')
normed_data = standardize(data)[0]
pre_period = ['20200101', '20200401']
post_period = ['20200402', '20200501']
obs_series = normed_data.loc[:pre_period[1], 0]
local_linear = tfp.sts.LocalLinearTrend(observed_time_series=obs_series)
seasonal = tfp.sts.Seasonal(num_seasons=7, observed_time_series=obs_series)
model = tfp.sts.Sum([local_linear, seasonal], observed_time_series=obs_series)
ci = CausalImpact(data, pre_period, post_period, model=model,
model_args={'standardize': True})
print(ci.summary())
```
Custom models also require that the data has no "holes" at its estimated
frequency, otherwise fitting won't work either. This can be a problem when working
with a linear regression component, which requires the whole data to be valid
before running inference. One way to handle that is to apply
`tfp.sts.regularize_series` to the input data and fill with zeros any covariate
values that end up being null. For instance:
```python
import numpy as np
import tensorflow_probability as tfp
from causalimpact.misc import standardize
pre_period = ['20200101', '20200311']
post_period = ['20200312', '20200409']
reg_data = tfp.sts.regularize_series(data)
normed_data = standardize(reg_data.astype(np.float32))[0]
obs_data = normed_data.loc[pre_period[0]: pre_period[1]].iloc[:, 0]
design_matrix_data = normed_data.iloc[:, 1:].fillna(0).values.reshape(
    -1, normed_data.shape[1] - 1)
linear_level = tfp.sts.LocalLinearTrend(observed_time_series=obs_data)
linear_reg = tfp.sts.LinearRegression(design_matrix=design_matrix_data)
model = tfp.sts.Sum([linear_level, linear_reg], observed_time_series=obs_data)
ci = CausalImpact(data, pre_period, post_period, model=model)
```
"""
def __init__(
self,
data: Union[np.array, pd.DataFrame],
pre_period: Union[List[int], List[str], List[pd.Timestamp]],
post_period: Union[List[int], List[str], List[pd.Timestamp]],
model: Optional[tfp.sts.StructuralTimeSeries] = None,
model_args: Dict[str, Any] = {},
alpha: float = 0.05
):
processed_input = cidata.process_input_data(data, pre_period, post_period,
model, model_args, alpha)
self.data = data
self.processed_data_index = processed_input['data'].index
self.pre_period = processed_input['pre_period']
self.post_period = processed_input['post_period']
self.pre_data = processed_input['pre_data']
self.post_data = processed_input['post_data']
self.alpha = processed_input['alpha']
self.model_args = processed_input['model_args']
self.model = processed_input['model']
self.normed_pre_data = processed_input['normed_pre_data']
self.normed_post_data = processed_input['normed_post_data']
self.observed_time_series = processed_input['observed_time_series']
self.mu_sig = processed_input['mu_sig']
self._mask = processed_input['mask']
self._fit_model()
self._process_posterior_inferences()
self._summarize_inferences()
def plot(
self,
panels: List[str] = ['original', 'pointwise', 'cumulative'],
figsize: Tuple[int] = (10, 7),
show: bool = True
) -> None:
"""
Plots the graphics with the results of the Causal Impact analysis.
Args
----
panels: List[str]
Which graphics to plot. 'original' plots the original data, the forecast
means and credible intervals related to the fitted model.
'pointwise' plots the pointwise differences between observed data and
predictions. Finally, 'cumulative' is a cumulative summation over the real
data and its forecasts.
figsize: Tuple[int]
Sets the width and height of the figure to plot.
show: bool
If `True`, displays the resulting figure. Defaults to `True`.
If `False`, nothing is plotted, which allows accessing and manipulating the
figure and axes of the plot, i.e., the figure can be saved and its styling
modified before display. To get the axes, run
`import matplotlib.pyplot as plt; ax = plt.gca()`; for the figure, run
`fig = plt.gcf()` (see the sketch below).
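A minimal sketch of saving the figure instead of displaying it (assumes `ci` is a
fitted CausalImpact instance; the output file name is purely illustrative):
```python
import matplotlib.pyplot as plt
ci.plot(show=False)               # builds the figure without displaying it
fig = plt.gcf()                   # grab the figure just built
fig.savefig('causal_impact.png')  # illustrative output path
```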
"""
plotter.plot(self.inferences, self.pre_data, self.post_data[self._mask],
panels=panels, figsize=figsize, show=show)
def summary(self, output: str = 'summary', digits: int = 2) -> str:
"""
Builds the summary report and returns it as a string.
Args
----
output: str
Can be either "summary" or "report". The first is a simpler output informing
general metrics such as the expected absolute or relative effect, while the
second builds a more verbose text report of the analysis (see the example
below).
digits: int
Defines the number of digits after the decimal point to round. For
`digits=2`, value 1.566 becomes 1.57.
Returns
-------
summary: str
Contains results of the causal impact analysis.
Raises
------
ValueError: If input `output` is not either 'summary' or 'report'.
If input `digits` is not of type integer.
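Example (a sketch, assuming `ci` is a fitted CausalImpact instance):
```python
print(ci.summary())                  # compact table of general metrics
print(ci.summary(output='report'))   # more verbose text report
print(ci.summary(digits=3))          # values rounded to 3 decimal places
```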
"""
if not isinstance(digits, int):
raise ValueError(
f'Input value for digits must be integer. Received "{type(digits)}" '
'instead.'
)
result = summarizer.summary(self.summary_data, self.p_value, self.alpha,
output, digits)
return result
def _fit_model(self) -> None:
"""
Use observed data `Y` to find the posterior `P(Z|Y)` where `Z` represents the
structural components that were used for building the model (such as local level
factor or seasonal components).
"""
model_samples, model_kernel_results = cimodel.fit_model(
self.model,
self.observed_time_series,
self.model_args['fit_method'],
)
self.model_samples = model_samples
self.model_kernel_results = model_kernel_results
def _summarize_inferences(self) -> None:
"""
After processing predictions and forecasts, uses these values to build the
summary data required for reporting and plotting.
As inferring the frequency step when processing input data can introduce `NaN`
values, a boolean mask identifying those potential holes is created and applied
to filter them out.
Finishes by estimating the p-value used to determine whether the impact is
statistically significant or not.
"""
post_preds_means = self.inferences['post_preds_means']
post_data_sum = self.post_data.iloc[:, 0].sum()
niter = self.model_args['niter']
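# Draw `niter` trajectories from the posterior predictive distribution, map them
# back to the original scale (when data was standardized) and keep only the time
# points flagged by `self._mask`, dropping holes added by frequency regularization.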
simulated_ys = maybe_unstandardize(
np.squeeze(self.posterior_dist.sample(niter).numpy()),
self.mu_sig
)[:, self._mask]
self.summary_data = inferrer.summarize_posterior_inferences(
post_preds_means,
self.post_data[self._mask],
simulated_ys,
self.alpha
)
self.p_value = inferrer.compute_p_value(simulated_ys, post_data_sum)
def _process_posterior_inferences(self) -> None:
"""
Run `inferrer` to process data forecasts and predictions. Results feed the
summary table as well as the plotting functionalities.
"""
num_steps_forecast = len(self.post_data)
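# `one_step_dist` is the one-step-ahead predictive distribution over the observed
# (pre-intervention) series; `posterior_dist` forecasts the post-intervention
# period. Both feed the inference compilation below.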
self.one_step_dist = cimodel.build_one_step_dist(self.model,
self.observed_time_series,
self.model_samples)
self.posterior_dist = cimodel.build_posterior_dist(self.model,
self.observed_time_series,
self.model_samples,
num_steps_forecast)
self.inferences = inferrer.compile_posterior_inferences(
self.processed_data_index,
self._mask,
self.pre_data,
self.post_data,
self.one_step_dist,
self.posterior_dist,
self.mu_sig,
self.alpha,
self.model_args['niter']
)