-
Notifications
You must be signed in to change notification settings - Fork 445
Update README for Chronos-2 #324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
e1d5863 to
c8e03ce
Compare
README.md
Outdated
| timeseries_id = "DE" # Specific time series to visualize | ||
| history_length = 256 # The number of historical values to plot | ||
|
|
||
| ts_context = context_df.query(f"{id_column} == @timeseries_id").set_index(timestamp_column)[target] | ||
| ts_pred = pred_df.query(f"{id_column} == @timeseries_id and target_name == @target").set_index(timestamp_column)[ | ||
| ["0.1", "predictions", "0.9"] | ||
| ] | ||
| ts_ground_truth = test_df.query(f"{id_column} == @timeseries_id").set_index(timestamp_column)[target] | ||
|
|
||
| start_idx = max(0, len(ts_context) - history_length) | ||
| plot_cutoff = ts_context.index[start_idx] | ||
| ts_context = ts_context[ts_context.index >= plot_cutoff] | ||
| ts_ground_truth = ts_ground_truth[ts_ground_truth.index >= plot_cutoff] | ||
|
|
||
| fig = plt.figure(figsize=(12, 3)) | ||
| ax = fig.gca() | ||
| ts_context.plot(ax=ax, label=f"historical {target}", color="xkcd:azure") | ||
| ts_ground_truth.plot(ax=ax, label=f"future {target} (ground truth)", color="xkcd:grass green") | ||
| ts_pred["predictions"].plot(ax=ax, label="forecast", color="xkcd:violet") | ||
| ax.fill_between( | ||
| ts_pred.index, | ||
| ts_pred["0.1"], | ||
| ts_pred["0.9"], | ||
| alpha=0.7, | ||
| label="prediction interval", | ||
| color="xkcd:light lavender", | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be more concise:
import matplotlib.pyplot as plt
ts_context = context_df.set_index(timestamp_column)[target].tail(256)
ts_pred = pred_df.set_index(timestamp_column)
ts_ground_truth = test_df.set_index(timestamp_column)[target]
ts_context.plot(label=f"historical data", color="xkcd:azure", figsize=(12, 3))
ts_ground_truth.plot(label=f"future data (ground truth)", color="xkcd:grass green")
ts_pred["predictions"].plot(label="forecast", color="xkcd:violet")
plt.fill_between(
ts_pred.index,
ts_pred["0.1"],
ts_pred["0.9"],
alpha=0.7,
label="prediction interval",
color="xkcd:light lavender",
)
plt.legend()There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right. I was going for "more general" but I guess for README example this is good enough.
Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
README.md
Outdated
| target = "target" # Column name containing the values to forecast | ||
| prediction_length = 24 # Number of steps to forecast ahead | ||
| id_column = "id" # Column identifying different time series | ||
| timestamp_column = "timestamp" # Column containing datetime information |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about we remove these variables completely and make the example for Chronos-2 usage shorter?
import pandas as pd # requires: pip install pandas
from chronos import Chronos2Pipeline
pipeline = Chronos2Pipeline.from_pretrained("s3://autogluon/chronos-2", device_map="cuda")
# Load historical energy prices and past values of covariates
context_df = pd.read_parquet("https://autogluon.s3.amazonaws.com/datasets/timeseries/electricity_price/train.parquet")
# (Optional) Load future values of covariates
test_df = pd.read_parquet("https://autogluon.s3.amazonaws.com/datasets/timeseries/electricity_price/test.parquet")
future_df = test_df.drop(columns="target")
# Generate predictions with covariates
pred_df = pipeline.predict_df(
context_df,
future_df=future_df,
prediction_length = 24 # Number of steps to forecast
quantile_levels=[0.1, 0.5, 0.9], # Quantiles for probabilistic forecast
id_column = "id" # Column identifying different time series
timestamp_column = "timestamp" # Column with datetime information
target = "target" # Column(s) with time series values to predict
)dae1ba4 to
f233300
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Issue #, if available:
Description of changes:
TODOs:
Move pretraining related stuff to Chronos notebook.By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.