This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Tabular Forecasting

The Task

Tabular (or timeseries) forecasting is the task of using historical data to predict future values of a time-varying quantity, such as stock prices or temperature. In Flash, the TabularForecaster (flash.tabular.forecasting.model) and TabularForecastingData (flash.tabular.forecasting.data) classes provide timeseries forecasting built on PyTorch Forecasting.
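To make "use historical data to predict future values" concrete, here is a minimal sketch of a seasonal-naive baseline, a deliberately simple forecasting technique that repeats the last observed season. It is not what TabularForecaster does internally; the function name and toy data are hypothetical and only illustrate the inputs (a history and a forecast horizon) and the output (predicted future values) of the task.

```python
from typing import List


def seasonal_naive_forecast(history: List[float], season_length: int, horizon: int) -> List[float]:
    """Forecast `horizon` future steps by repeating the last observed season.

    Hypothetical toy baseline for illustration only; not Flash's model.
    """
    if season_length <= 0 or len(history) < season_length:
        raise ValueError("history must contain at least one full season")
    last_season = history[-season_length:]
    # Future step h (0-based) copies the value observed one season earlier.
    return [last_season[h % season_length] for h in range(horizon)]


# Hypothetical toy history with a period of 4 steps.
history = [1.0, 3.0, 2.0, 0.0, 1.1, 3.2, 2.1, 0.1]
print(seasonal_naive_forecast(history, season_length=4, horizon=6))
# [1.1, 3.2, 2.1, 0.1, 1.1, 3.2]
```

Learned models such as NBeats aim to beat baselines like this one by picking up trends and interactions across many series.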


Example

Let's look at training the NBeats model on synthetic data with seasonal changes. Such data could represent many naturally occurring timeseries, for example energy demand, which fluctuates throughout the day but is also expected to change with the season. This example is a Flash reimplementation of the NBeats tutorial from the PyTorch Forecasting docs. Unlike more complex models such as the Temporal Fusion Transformer, NBeats takes no additional inputs: it uses only the past values of the target series.
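The upstream tutorial ships its own data generator. As a hypothetical stand-in written with the standard library only, the sketch below produces seasonal series in the long layout that PyTorch Forecasting's TimeSeriesDataSet works with: one row per (series, time step), with an integer time index column and a series id column. The function name, column names (`series`, `time_idx`, `value`) and parameter values are assumptions for illustration.

```python
import math
import random
from typing import Dict, List


def make_seasonal_series(
    n_series: int = 3,
    timesteps: int = 48,
    period: int = 24,
    seed: int = 42,
) -> List[Dict[str, float]]:
    """Generate hypothetical long-format rows: one row per (series, time_idx)."""
    rng = random.Random(seed)  # seeded so the data is reproducible
    rows = []
    for series in range(n_series):
        # Each series gets its own amplitude, linear trend and phase.
        amplitude = rng.uniform(0.5, 2.0)
        trend = rng.uniform(-0.01, 0.01)
        phase = rng.uniform(0.0, 2 * math.pi)
        for t in range(timesteps):
            seasonal = amplitude * math.sin(2 * math.pi * t / period + phase)
            noise = rng.gauss(0.0, 0.1)
            rows.append({"series": series, "time_idx": t, "value": seasonal + trend * t + noise})
    return rows


rows = make_seasonal_series()
print(len(rows))  # 3 series x 48 steps = 144 rows
```

In practice these rows would be loaded into a pandas DataFrame before being handed to Flash.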

Once we've created the data as a DataFrame, we can build the TabularForecastingData from it using the TabularForecastingData.from_data_frame method. To this method, we pass any configuration arguments that should be used when internally constructing the PyTorch Forecasting TimeSeriesDataSet.
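Two of the TimeSeriesDataSet parameters that matter most for NBeats are `max_encoder_length` (how many past steps the model sees) and `max_prediction_length` (how many future steps it predicts). The sketch below is a conceptual, standard-library illustration of that fixed-length windowing on a single series; the function name is hypothetical, and the real TimeSeriesDataSet additionally handles multiple series, variable-length windows, covariates and normalisation.

```python
from typing import List, Tuple


def sliding_windows(
    values: List[float], encoder_length: int, prediction_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Cut one series into (encoder, decoder) pairs of fixed lengths.

    Hypothetical conceptual sketch, not PyTorch Forecasting's implementation.
    """
    window = encoder_length + prediction_length
    pairs = []
    for start in range(len(values) - window + 1):
        encoder = values[start : start + encoder_length]  # past steps the model sees
        decoder = values[start + encoder_length : start + window]  # future steps to predict
        pairs.append((encoder, decoder))
    return pairs


series = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
for past, future in sliding_windows(series, encoder_length=3, prediction_length=2):
    print(past, "->", future)
# [0.0, 1.0, 2.0] -> [3.0, 4.0]
# [1.0, 2.0, 3.0] -> [4.0, 5.0]
```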

Next, we create the TabularForecaster and train it on the data. We then use the trained TabularForecaster for inference and, finally, save the model. Here's the full example:
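Before training, forecasting data is usually split by time rather than at random, so that validation measures how well the model predicts genuinely unseen future steps. One common pattern, and as far as we recall the one used in the upstream NBeats tutorial (treat this as an assumption), is to hold out the last `max_prediction_length` steps: the training rows stop at a cutoff, while the validation rows keep the full history. The helper below is a hypothetical, standard-library sketch of that split on long-format rows.

```python
from typing import Dict, List, Tuple

Row = Dict[str, float]


def split_at_cutoff(rows: List[Row], max_prediction_length: int) -> Tuple[List[Row], List[Row], float]:
    """Hypothetical helper: hold out the last `max_prediction_length` steps of every series."""
    training_cutoff = max(row["time_idx"] for row in rows) - max_prediction_length
    # Training rows stop at the cutoff; validation keeps the full history so the
    # model can be scored on the held-out horizon after the cutoff.
    train = [row for row in rows if row["time_idx"] <= training_cutoff]
    val = list(rows)
    return train, val, training_cutoff


# Hypothetical data: 2 series x 10 time steps.
rows = [{"series": s, "time_idx": t, "value": float(t)} for s in range(2) for t in range(10)]
train, val, cutoff = split_at_cutoff(rows, max_prediction_length=3)
print(cutoff, len(train), len(val))  # 6 14 20
```

The two row sets would then correspond to separate training and validation DataFrames.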

../../../flash_examples/tabular_forecasting.py

To learn how to view the available backbones and heads for this task, see backbones_heads.

Note

Read more about our integration with PyTorch Forecasting to see how to use your Flash model with its built-in plotting capabilities.