Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[timeseries] Speed up prediction for GluonTS models #2593

Merged
merged 8 commits into from
Dec 22, 2022

Conversation

shchur
Copy link
Collaborator

@shchur shchur commented Dec 21, 2022

Description of changes:

  • Significantly speed up prediction for GluonTS models. The two main sources are:
    • Use of pd.Series instead of TimeSeriesDataFrame inside SimpleGluonTSDataset to avoid copying the static features.
    • Speed up get_forecast_horizon_index_ts_dataframe with a more efficient groupby operation
  • Clean up redundant logic in AbstractGluonTSModel. predict.

Benchmarking results on the M5 dataset

(only training for 1 epoch / 1 batch, so the training time/performance are meaningless here)

With this PR

5K items / 7M rows

Training timeseries model DeepAR.
        -0.9107       = Validation score (-mean_wQuantileLoss)
        5.99    s     = Training runtime
        47.57   s     = Validation (prediction) runtime

30K items / 47M rows

Training timeseries model DeepAR.
        -0.7019       = Validation score (-mean_wQuantileLoss)
        11.92   s     = Training runtime
        298.31  s     = Validation (prediction) runtime

Current master branch
5K items / 7M rows

Training timeseries model DeepAR.
        -0.8331       = Validation score (-mean_wQuantileLoss)
        6.61    s     = Training runtime
        81.16   s     = Validation (prediction) runtime

30K items / 47M rows

Training timeseries model DeepAR.
        -0.8672       = Validation score (-mean_wQuantileLoss)
        15.64   s     = Training runtime
        1107.97 s     = Validation (prediction) runtime
Code to reproduce the results
import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor


prediction_length = 28
raw_data = pd.read_parquet("../m5/data/target.parquet")
static = pd.read_parquet("../m5/data/static.parquet")

raw_data["item_id"] = raw_data["item_id"].astype("str")
raw_data["timestamp"] = pd.to_datetime(raw_data["timestamp"])
static["item_id"] = static["item_id"].astype("str")
static.set_index("item_id", inplace=True)

print(f"Loaded dataset with {len(raw_data)} rows and {raw_data['item_id'].nunique()} items.")
df = TimeSeriesDataFrame(raw_data, static_features=static)

predictor = TimeSeriesPredictor(prediction_length=prediction_length, target="demand")
predictor.fit(
    train_data=df,
    hyperparameters={"DeepAR": {"epochs": 1, "num_batches_per_epoch": 1}},
)

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@github-actions
Copy link

Job PR-2593-4386fe1 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2593/4386fe1/index.html

@shchur shchur added this to the 0.6.2 Release milestone Dec 21, 2022
@canerturkmen canerturkmen added the module: timeseries related to the timeseries module label Dec 22, 2022
Copy link
Contributor

@canerturkmen canerturkmen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, and very impressive speedups!! Some minor comments.

Copy link
Contributor

@canerturkmen canerturkmen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks!

@github-actions
Copy link

Job PR-2593-6bb80b7 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2593/6bb80b7/index.html

@shchur shchur merged commit f6311cf into autogluon:master Dec 22, 2022
@shchur shchur deleted the faster-gluonts-pred branch December 22, 2022 17:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: timeseries related to the timeseries module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants