Merged
110 changes: 110 additions & 0 deletions src/chronos/base.py
@@ -17,8 +17,10 @@
if TYPE_CHECKING:
    import datasets
    import fev
    import pandas as pd
    from transformers import PreTrainedModel


from .utils import left_pad_and_stack_1D


@@ -53,6 +55,14 @@ def __init__(self, inner_model: "PreTrainedModel"):
        # for easy access to the inner HF-style model
        self.inner_model = inner_model

    @property
    def model_context_length(self) -> int:
        raise NotImplementedError()

    @property
    def model_prediction_length(self) -> int:
        raise NotImplementedError()

    def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
@@ -122,6 +132,106 @@ def predict_quantiles(
"""
raise NotImplementedError()

def predict_df(
self,
df: "pd.DataFrame",
*,
id_column: str = "item_id",
timestamp_column: str = "timestamp",
target: str = "target",
prediction_length: int | None = None,
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
**predict_kwargs,
) -> "pd.DataFrame":
"""
Perform forecasting on time series data in a long-format pandas DataFrame.

Parameters
----------
df
Time series data in long format with an id column, a timestamp, and one target column.
Any other columns, if present, will be ignored
id_column
The name of the column which contains the unique time series identifiers, by default "item_id"
timestamp_column
The name of the column which contains timestamps, by default "timestamp"
All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
target
The name of the column which contains the target variables to be forecasted, by default "target"
prediction_length
Number of steps to predict for each time series
quantile_levels
Quantile levels to compute
**predict_kwargs
Additional arguments passed to predict_quantiles

Returns
-------
The forecasts dataframe generated by the model with the following columns
- `id_column`: The time series ID
- `timestamp_column`: Future timestamps
- "target_name": The name of the target column
- "predictions": The point predictions generated by the model
- One column for predictions at each quantile level in `quantile_levels`
"""
try:
import pandas as pd

from .df_utils import convert_df_input_to_list_of_dicts_input
except ImportError:
raise ImportError("pandas is required for predict_df. Please install it with `pip install pandas`.")

if not isinstance(target, str):
raise ValueError(
f"Expected `target` to be str, but found {type(target)}. {self.__class__.__name__} only supports univariate forecasting."
)

if prediction_length is None:
Contributor: again, if there is any way at all of reducing the repeat in Chronos2 I think we should consider it.
            prediction_length = self.model_prediction_length

        inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
            df=df,
            future_df=None,
            id_column=id_column,
            timestamp_column=timestamp_column,
            target_columns=[target],
            prediction_length=prediction_length,
        )

        # NOTE: any covariates, if present, are ignored here
        context = [torch.tensor(item["target"]).squeeze(0) for item in inputs]  # squeeze the extra variate dim

        # Generate forecasts
        quantiles, mean = self.predict_quantiles(
            inputs=context,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
            limit_prediction_length=False,
            **predict_kwargs,
        )

        quantiles_np = quantiles.numpy()  # [n_series, horizon, num_quantiles]
        mean_np = mean.numpy()  # [n_series, horizon]

        results_dfs = []
        for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
            q_pred = quantiles_np[i]  # (horizon, num_quantiles)
            point_pred = mean_np[i]  # (horizon)

            series_forecast_data = {id_column: series_id, timestamp_column: future_ts, "target_name": target}
            series_forecast_data["predictions"] = point_pred
            for q_idx, q_level in enumerate(quantile_levels):
                series_forecast_data[str(q_level)] = q_pred[:, q_idx]

            results_dfs.append(pd.DataFrame(series_forecast_data))

        predictions_df = pd.concat(results_dfs, ignore_index=True)
        predictions_df.set_index(id_column, inplace=True)
        predictions_df = predictions_df.loc[original_order]
        predictions_df.reset_index(inplace=True)

        return predictions_df

    def predict_fev(
        self, task: "fev.Task", batch_size: int = 32, **kwargs
    ) -> tuple[list["datasets.DatasetDict"], float]:
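To make the new API concrete, here is a minimal usage sketch of `predict_df` on a toy long-format DataFrame. The checkpoint name and the synthetic data are illustrative assumptions, not part of this PR.

# Illustrative sketch only: the checkpoint "amazon/chronos-t5-small" and the toy
# data are assumptions; any Chronos pipeline exposing predict_df should work.
import numpy as np
import pandas as pd

from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-t5-small")

# Long-format input: one row per (item_id, timestamp) observation, regular frequency.
timestamps = pd.date_range("2024-01-01", periods=48, freq="D")
df = pd.DataFrame(
    {
        "item_id": ["series_a"] * 48 + ["series_b"] * 48,
        "timestamp": list(timestamps) * 2,
        "target": np.random.randn(96).cumsum(),
    }
)

predictions = pipeline.predict_df(
    df,
    prediction_length=12,
    quantile_levels=[0.1, 0.5, 0.9],
)
# Expected columns: item_id, timestamp, target_name, predictions, "0.1", "0.5", "0.9"
print(predictions.head())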
8 changes: 8 additions & 0 deletions src/chronos/chronos.py
@@ -377,6 +377,14 @@ def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    @property
    def model_context_length(self) -> int:
        return self.model.config.context_length

    @property
    def model_prediction_length(self) -> int:
        return self.model.config.prediction_length

    def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
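Finally, a short sketch of how the new properties might be read in practice. The checkpoint name and the values in the comments are assumptions for illustration, not asserted by this diff.

# Illustrative sketch: checkpoint name and commented values are assumptions.
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-small")

# The new properties read the limits from the underlying config, so callers no
# longer need to reach into pipeline.model.config directly.
print(pipeline.model_context_length)     # e.g. 512 for the T5-based checkpoints
print(pipeline.model_prediction_length)  # e.g. 64; also the default horizon for predict_df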