# api_forecasting.py - example of time series forecasting through the FEDOT API.
import logging
import random
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from fedot import Fedot
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import TsForecastingParams, Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root
# Errors raised while a log record is being handled are ignored instead of reported
logging.raiseExceptions = False

# Directory with the example time series shipped in the project
_TS_EXAMPLES_DATA_PATH = fedot_project_root().joinpath('examples/data/ts')

# Dataset name -> path of the csv file that get_ts_data() reads for it
TS_DATASETS = {
    **{
        dataset_name: _TS_EXAMPLES_DATA_PATH.joinpath(file_name)
        for dataset_name, file_name in (
            ('m4_daily', 'M4Daily.csv'),
            ('m4_monthly', 'M4Monthly.csv'),
            ('m4_quarterly', 'M4Quarterly.csv'),
            ('m4_weekly', 'M4Weekly.csv'),
            ('m4_yearly', 'M4Yearly.csv'),
            ('australia', 'australia.csv'),
            ('beer', 'beer.csv'),
            ('salaries', 'salaries.csv'),
            ('stackoverflow', 'stackoverflow.csv'),
        )
    },
    # The only entry located outside of the examples directory
    'test_sea': fedot_project_root().joinpath('test', 'data', 'simple_sea_level.csv'),
}
def get_ts_data(dataset='m4_monthly', horizon: int = 30, m4_id=None, validation_blocks=None):
    """Read an example time series and split it into train and test data.

    Args:
        dataset: key of ``TS_DATASETS`` selecting the csv file to read.
        horizon: forecast length put into ``TsForecastingParams``.
        m4_id: value of the ``label`` column to keep when ``dataset`` contains ``'m4'``;
            when falsy, one of the labels present in the file is picked at random.
        validation_blocks: forwarded unchanged to ``train_test_data_setup``.

    Returns:
        Tuple ``(train_data, test_data, label)``. ``label`` is the selected ``label``
        value for m4 datasets and the dataset name for all others.
    """
    frame = pd.read_csv(TS_DATASETS[dataset])

    if 'm4' in dataset:
        # Keep only the rows whose 'label' matches the requested (or randomly drawn) one
        label = m4_id if m4_id else random.choice(np.unique(frame['label']))
        print(label)
        frame = frame[frame['label'] == label]
        idx = frame['datetime'].values
    else:
        label = dataset
        idx = frame['idx'].values
        # 'australia' has a non-datetime index, so it is the only one left unconverted
        if dataset != 'australia':
            idx = pd.to_datetime(idx)

    values = frame['value'].values
    forecasting_task = Task(TaskTypesEnum.ts_forecasting,
                            TsForecastingParams(forecast_length=horizon))

    # The same array is used as both features and target of the series
    full_input = InputData(idx=idx,
                           features=values,
                           target=values,
                           task=forecasting_task,
                           data_type=DataTypesEnum.ts)

    train_data, test_data = train_test_data_setup(full_input, validation_blocks=validation_blocks)
    return train_data, test_data, label
def run_ts_forecasting_example(dataset='australia', horizon: int = 30, timeout: float = None,
                               visualization=False, validation_blocks=2, with_tuning=True):
    """Fit a FEDOT model on an example time series and return its forecast.

    Args:
        dataset: name of the dataset, passed to ``get_ts_data``.
        horizon: forecast length used both for the data and for the model.
        timeout: passed to ``Fedot`` as ``timeout``.
        visualization: when true, show the fitted pipeline and plot the train
            features, the test target and the forecast.
        validation_blocks: passed to ``get_ts_data``.
        with_tuning: passed to ``Fedot`` as ``with_tuning``.

    Returns:
        The result of ``model.forecast`` on the test data.
    """
    # The label returned by get_ts_data is not needed here
    train_data, test_data, _ = get_ts_data(dataset, horizon, validation_blocks=validation_blocks)

    # init model for the time series forecasting
    forecasting_task = Task(TaskTypesEnum.ts_forecasting,
                            TsForecastingParams(forecast_length=horizon))
    model = Fedot(problem='ts_forecasting',
                  task_params=forecasting_task.task_params,
                  timeout=timeout,
                  n_jobs=-1,
                  metric='mae',
                  with_tuning=with_tuning)

    model.fit(train_data)
    forecast = model.forecast(test_data)

    if not visualization:
        return forecast

    model.current_pipeline.show()
    plotted_series = (
        (train_data.idx, train_data.features, 'features'),
        (test_data.idx, test_data.target, 'target'),
        (test_data.idx, forecast, 'fedot'),
    )
    for x_values, y_values, legend_name in plotted_series:
        plt.plot(x_values, y_values, label=legend_name)
    plt.grid()
    plt.legend()
    plt.show()

    return forecast
if __name__ == '__main__':
    # Script entry point: run the example on the monthly M4 data and show the plots
    run_ts_forecasting_example(
        dataset='m4_monthly',
        horizon=14,
        timeout=2.,
        validation_blocks=None,
        visualization=True,
    )