diff --git a/notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb b/notebooks/deploy-chronos-to-amazon-sagemaker.ipynb
similarity index 65%
rename from notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb
rename to notebooks/deploy-chronos-to-amazon-sagemaker.ipynb
index 486f0ae0..c7e3f3c0 100644
--- a/notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb
+++ b/notebooks/deploy-chronos-to-amazon-sagemaker.ipynb
@@ -2,16 +2,18 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "2da16609",
"metadata": {},
"source": [
- "# SageMaker JumpStart - Deploy Chronos endpoints to AWS for production use"
+ "# SageMaker JumpStart - Deploy Chronos-2 endpoints to AWS for production use"
]
},
{
"cell_type": "markdown",
+ "id": "9f776b6f",
"metadata": {},
"source": [
- "In this demo notebook, we will walk through the process of using the **SageMaker Python SDK** to deploy a **Chronos** model to a cloud endpoint on AWS. To simplify deployment, we will leverage **SageMaker JumpStart**.\n",
+ "In this demo notebook, we will walk through the process of using the **SageMaker Python SDK** to deploy a **Chronos-2** model to a cloud endpoint on AWS. To simplify deployment, we will leverage **SageMaker JumpStart**.\n",
"\n",
"### Why Deploy to an Endpoint?\n",
"So far, we’ve seen how to run models locally, which is useful for experimentation. However, in a production setting, a forecasting model is typically just one component of a larger system. Running models locally doesn’t scale well and lacks the reliability needed for real-world applications.\n",
@@ -21,6 +23,40 @@
},
{
"cell_type": "markdown",
+ "id": "3c77da85",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "
⚠️ Looking for Chronos-Bolt or original Chronos?\n",
+ "This notebook covers
Chronos-2, our latest and recommended model. For documentation on older models (Chronos-Bolt and original Chronos), see the
legacy deployment walkthrough.\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "59912f1f",
+ "metadata": {},
+ "source": [
+ "### Chronos-2 vs. Previous Models\n",
+ "\n",
+ "**Chronos-2** is a foundation model for time series forecasting that builds on [Chronos](https://arxiv.org/abs/2403.07815) and [Chronos-Bolt](https://aws.amazon.com/blogs/machine-learning/fast-and-accurate-zero-shot-forecasting-with-chronos-bolt-and-autogluon/). It offers significant improvements in capabilities, better accuracy, and can handle diverse forecasting scenarios not supported by earlier models.\n",
+ "\n",
+ "| Capability | Chronos-2 | Chronos-Bolt | Chronos |\n",
+ "|------------|-----------|--------------|----------|\n",
+ "| Univariate Forecasting | ✅ | ✅ | ✅ |\n",
+ "| Cross-learning across items | ✅ | ❌ | ❌ |\n",
+ "| Multivariate Forecasting | ✅ | ❌ | ❌ |\n",
+ "| Past-only (real/categorical) covariates | ✅ | ❌ | ❌ |\n",
+ "| Known future (real/categorical) covariates | ✅ | 🧩 | ❌ |\n",
+ "| Max. Context Length | 8192 | 2048 | 512 |\n",
+ "| Max. Prediction Length | 1024 | 64 | 64 |\n",
+ "\n",
+ "🧩 Chronos-Bolt does not natively support future covariates, but they can be combined with external covariate regressors (see [AutoGluon tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html#incorporating-the-covariates)). This only models per-timestep effects, not effects across time. In contrast, Chronos-2 supports all covariate types natively."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a0583a66",
"metadata": {},
"source": [
"## Deploy the model"
@@ -28,6 +64,7 @@
},
{
"cell_type": "markdown",
+ "id": "690d9093",
"metadata": {},
"source": [
"First, update the SageMaker SDK to access the latest models:"
@@ -35,7 +72,8 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
+ "id": "a4ed0fb5",
"metadata": {},
"outputs": [],
"source": [
@@ -44,36 +82,38 @@
},
{
"cell_type": "markdown",
+ "id": "a07054a7",
"metadata": {},
"source": [
"We create a `JumpStartModel` with the necessary configuration based on the model ID. The key parameters are:\n",
- "- `model_id`: Specifies the model to use. Here, we choose the [Chronos-Bolt (Base)](https://huggingface.co/amazon/chronos-bolt-base) model. Currently, the following model IDs are supported:\n",
- " - `autogluon-forecasting-chronos-bolt-base` - [Chronos-Bolt (Base)](https://huggingface.co/amazon/chronos-bolt-base).\n",
- " - `autogluon-forecasting-chronos-bolt-small` - [Chronos-Bolt (Small)](https://huggingface.co/amazon/chronos-bolt-small).\n",
- " - [Original Chronos models](https://huggingface.co/amazon/chronos-t5-small) in sizes `small`, `base` and `large` can be accessed, e.g., as `autogluon-forecasting-chronos-t5-small`. Note that these models require a GPU to run, are much slower and don't support covariates. Therefore, for most practical purposes we recommend using Chronos-Bolt models instead.\n",
- "- `instance_type`: Defines the AWS instance for serving the endpoint. We use `ml.c5.2xlarge` to run the model on CPU. To use a GPU, select an instance like `ml.g5.2xlarge`, or choose other CPU options such as `ml.m5.xlarge` or `ml.m5.4xlarge`. You can check the pricing for different SageMaker instance types for real-time inference [here](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
+ "- `model_id`: Specifies the model to use. We use `pytorch-forecasting-chronos-2` for the [Chronos-2](https://github.com/amazon-science/chronos-forecasting) model.\n",
+ "- `instance_type`: Defines the AWS instance for serving the endpoint. Chronos-2 currently requires a **GPU instance** from the `ml.g5`, `ml.g6`, `ml.g6e`, or `ml.g4dn` families with a single GPU. The model does not benefit from multi-GPU instances. **CPU support is coming soon**.\n",
+ "\n",
+ " You can check the pricing for different SageMaker instance types for real-time inference [here](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
"\n",
"The `JumpStartModel` will automatically set the necessary attributes such as `image_uri` based on the chosen `model_id` and `instance_type`."
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
+ "id": "ffbae4f0",
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.jumpstart.model import JumpStartModel\n",
"\n",
- "model_id = \"autogluon-forecasting-chronos-bolt-base\"\n",
- "\n",
"model = JumpStartModel(\n",
- " model_id=model_id,\n",
- " instance_type=\"ml.c5.2xlarge\",\n",
+ " model_id=\"pytorch-forecasting-chronos-2\",\n",
+ " instance_type=\"ml.g5.2xlarge\",\n",
+ " # You might need to provide the SageMaker execution role to ensure necessary AWS resources are accessible\n",
+ " # role=\"arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-XXXXXXXXXXXXXXX\",\n",
")"
]
},
{
"cell_type": "markdown",
+ "id": "eb864ee1",
"metadata": {},
"source": [
"Next, we deploy the model and create an endpoint. Deployment typically takes a few minutes, as SageMaker provisions the instance, loads the model, and sets up the endpoint for inference.\n"
@@ -82,6 +122,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "7fd0068b",
"metadata": {},
"outputs": [],
"source": [
@@ -90,44 +131,41 @@
},
{
"cell_type": "markdown",
+ "id": "f4dd66e7",
"metadata": {},
"source": [
- "> **Note:** Once the endpoint is deployed, it remains active and incurs charges on your AWS account until it is deleted. The cost depends on factors such as the instance type, the region where the endpoint is hosted, and the duration it remains running. To avoid unnecessary charges, make sure to delete the endpoint when it is no longer needed. For detailed pricing information, refer to the [SageMaker AI pricing page](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
- "\n",
- "\n",
- "If the previous step results in an error, you may need to update the model configuration. For example, specifying a `role` when creating the `JumpStartModel` ensures the necessary AWS resources are accessible."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# model = JumpStartModel(role=\"your-sagemaker-execution-role\", model_id=model_id, instance_type=\"ml.c5.2xlarge\")"
+ "> **Note:** Once the endpoint is deployed, it remains active and incurs charges on your AWS account until it is deleted. The cost depends on factors such as the instance type, the region where the endpoint is hosted, and the duration it remains running. To avoid unnecessary charges, make sure to delete the endpoint when it is no longer needed. For detailed pricing information, refer to the [SageMaker AI pricing page](https://aws.amazon.com/sagemaker-ai/pricing/)."
]
},
{
"cell_type": "markdown",
+ "id": "48ce52ef",
"metadata": {},
"source": [
- "Alternatively, you can create a predictor for an existing endpoint."
+ "Alternatively, you can connect to an existing endpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "a09367fc",
"metadata": {},
"outputs": [],
"source": [
- "# from sagemaker.predictor import retrieve_default\n",
+ "# from sagemaker.predictor import Predictor\n",
+ "# from sagemaker.serializers import JSONSerializer\n",
+ "# from sagemaker.deserializers import JSONDeserializer\n",
"\n",
- "# endpoint_name = \"NAME-OF-EXISTING-ENDPOINT\"\n",
- "# predictor = retrieve_default(endpoint_name)"
+ "# predictor = Predictor(\n",
+ "# \"NAME_OF_EXISTING_ENDPOINT\",\n",
+ "# serializer=JSONSerializer(),\n",
+ "# deserializer=JSONDeserializer(),\n",
+ "# )"
]
},
{
"cell_type": "markdown",
+ "id": "f3def973",
"metadata": {},
"source": [
"## Querying the endpoint"
@@ -135,6 +173,7 @@
},
{
"cell_type": "markdown",
+ "id": "6fbe39f9",
"metadata": {},
"source": [
"We can now invoke the endpoint to make a forecast. We send a **payload** to the endpoint, which includes historical time series values and configuration parameters, such as the prediction length. The endpoint processes this input and returns a **response** containing the forecasted values based on the provided data."
@@ -143,7 +182,10 @@
{
"cell_type": "code",
"execution_count": 2,
- "metadata": {},
+ "id": "1ae7b33e",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
"outputs": [],
"source": [
"# Define a utility function to print the response in a pretty format\n",
@@ -166,20 +208,28 @@
" return pformat(nested_round(data), width=150, sort_dicts=False)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "cb1629c7",
+ "metadata": {},
+ "source": [
+ "### Univariate forecasting"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 3,
- "id": "07605824",
+ "id": "320a9c49",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'predictions': [{'mean': [-1.58, 0.52, 1.88, 1.39, -1.03, -3.34, -2.67, -0.64, 0.96, 1.59],\n",
- " '0.1': [-4.17, -2.71, -1.7, -2.35, -4.79, -6.98, -6.59, -4.87, -3.45, -2.89],\n",
- " '0.5': [-1.58, 0.52, 1.88, 1.39, -1.03, -3.34, -2.67, -0.64, 0.96, 1.59],\n",
- " '0.9': [1.47, 4.47, 6.27, 5.98, 3.5, 1.11, 2.06, 4.47, 6.41, 7.17]}]}\n"
+ "{'predictions': [{'mean': [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
+ " '0.1': [-1.69, 2.84, 4.0, 0.97, -3.77, -6.19, -5.34, -1.77, 2.55, 3.61],\n",
+ " '0.5': [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
+ " '0.9': [1.03, 5.0, 6.31, 3.81, -0.85, -3.89, -2.89, 1.84, 5.59, 6.44]}]}\n"
]
}
],
@@ -198,6 +248,7 @@
},
{
"cell_type": "markdown",
+ "id": "0f37d392",
"metadata": {},
"source": [
"A payload may also contain **multiple time series**, potentially including `start` and `item_id` fields."
@@ -206,23 +257,23 @@
{
"cell_type": "code",
"execution_count": 4,
- "id": "d476c397",
+ "id": "14c62c74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'predictions': [{'mean': [1.41, 1.5, 1.49, 1.45, 1.51],\n",
- " '0.1': [0.12, -0.08, -0.25, -0.41, -0.45],\n",
- " '0.5': [1.41, 1.5, 1.49, 1.45, 1.51],\n",
- " '0.9': [3.29, 3.82, 4.09, 4.3, 4.56],\n",
+ "{'predictions': [{'mean': [1.7, 1.95, 1.66, 1.55, 1.84],\n",
+ " '0.1': [0.28, 0.32, -0.08, -0.35, -0.18],\n",
+ " '0.5': [1.7, 1.95, 1.66, 1.55, 1.84],\n",
+ " '0.9': [3.09, 3.77, 3.62, 3.58, 4.22],\n",
" 'item_id': 'product_A',\n",
" 'start': '2024-01-01T10:00:00'},\n",
- " {'mean': [-1.22, -1.3, -1.3, -1.14, -1.13],\n",
- " '0.1': [-4.51, -5.48, -6.12, -6.5, -7.1],\n",
- " '0.5': [-1.22, -1.3, -1.3, -1.14, -1.13],\n",
- " '0.9': [2.84, 4.02, 4.92, 5.99, 6.79],\n",
+ " {'mean': [-1.21, -1.4, -1.27, -1.34, -1.27],\n",
+ " '0.1': [-4.19, -5.84, -6.38, -7.53, -8.0],\n",
+ " '0.5': [-1.21, -1.4, -1.27, -1.34, -1.27],\n",
+ " '0.9': [2.02, 2.92, 3.55, 4.62, 5.66],\n",
" 'item_id': 'product_B',\n",
" 'start': '2024-02-02T10:00:00'}]}\n"
]
@@ -255,22 +306,31 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "id": "6ae41cdc",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
"source": [
- "Chronos-Bolt models also support forecasting with covariates (a.k.a. exogenous features or related time series). These can be provided using the `past_covariates` and `future_covariates` keys."
+ "### Forecasting with covariates\n",
+ "\n",
+ "Chronos-2 models also support forecasting with **covariates** (a.k.a. exogenous features or related time series). These can be provided using the `past_covariates` and `future_covariates` keys.\n",
+ "\n",
+ "**Note:** If you only provide `past_covariates` without matching keys in `future_covariates`, the model will treat them as past-only covariates (features that are only available historically but not in the future).\n",
+ "If future values of covariates are available, it is recommended to provide them in `future_covariates` as this typically results in more accurate forecasts."
]
},
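The past-only vs. known-future distinction described above comes down to key membership between the two dictionaries. A tiny standalone sketch of that rule (plain Python, illustrative only — not the server's implementation):

```python
# Covariates present in BOTH dicts are treated as known-future covariates;
# covariates present only in past_covariates are treated as past-only.
past_covariates = {
    "feat_1": [3.0, 6.0, 9.0],    # numeric, future values also provided below
    "feat_2": ["A", "B", "B"],    # categorical, future values also provided below
    "feat_3": [10.0, 20.0, 30.0], # past-only: no future values available
}
future_covariates = {
    "feat_1": [6.0, 3.0],
    "feat_2": ["A", "A"],
}

known_future = [k for k in past_covariates if k in future_covariates]
past_only = [k for k in past_covariates if k not in future_covariates]
print(known_future)  # ['feat_1', 'feat_2']
print(past_only)     # ['feat_3']
```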
{
"cell_type": "code",
"execution_count": 5,
+ "id": "e57f1541",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'predictions': [{'mean': [1.41, 1.5, 1.49], '0.1': [0.12, -0.08, -0.25], '0.5': [1.41, 1.5, 1.49], '0.9': [3.29, 3.82, 4.09]},\n",
- " {'mean': [-1.22, -1.3, -1.3], '0.1': [-4.51, -5.48, -6.12], '0.5': [-1.22, -1.3, -1.3], '0.9': [2.84, 4.02, 4.92]}]}\n"
+ "{'predictions': [{'mean': [1.73, 2.09, 1.73], '0.1': [0.36, 0.6, 0.17], '0.5': [1.73, 2.09, 1.73], '0.9': [3.11, 3.8, 3.52]},\n",
+ " {'mean': [-0.61, -0.41, -1.43], '0.1': [-4.16, -5.59, -7.53], '0.5': [-0.61, -0.41, -1.43], '0.9': [3.12, 4.56, 3.91]}]}\n"
]
}
],
@@ -282,7 +342,10 @@
" # past_covariates must have the same length as \"target\"\n",
" \"past_covariates\": {\n",
" \"feat_1\": [3.0, 6.0, 9.0, 6.0, 1.5, 6.0, 9.0, 6.0, 3.0],\n",
+ " # Categorical covariates should be provided as strings\n",
" \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\", \"A\", \"B\"],\n",
+ " # feat_3 is a past-only covariate (not present in future_covariates)\n",
+ " \"feat_3\": [10.0, 20.0, 30.0, 20.0, 5.0, 20.0, 30.0, 20.0, 10.0],\n",
" },\n",
" # future_covariates must have length equal to \"prediction_length\"\n",
" \"future_covariates\": {\n",
@@ -295,6 +358,7 @@
" \"past_covariates\": {\n",
" \"feat_1\": [0.6, 1.2, 1.8, 1.2, 0.3, 1.2, 1.8],\n",
" \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\"],\n",
+ " \"feat_3\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
" },\n",
" \"future_covariates\": {\n",
" \"feat_1\": [1.2, 0.3, 4.4],\n",
@@ -313,37 +377,88 @@
},
{
"cell_type": "markdown",
+ "id": "76c88a22",
+ "metadata": {},
+ "source": [
+ "### Multivariate forecasting\n",
+ "\n",
+ "Chronos-2 also supports **multivariate forecasting**, where multiple related time series are forecasted jointly. For multivariate forecasting, provide the target as a list of lists, where each inner list represents one dimension of the multivariate series."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b73609be",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'predictions': [{'mean': [[3.66, 3.55, 3.5, 3.42], [2.0, 2.05, 2.19, 2.23], [3.33, 3.27, 3.25, 3.22]],\n",
+ " '0.1': [[1.98, 1.52, 1.17, 0.88], [0.84, 0.18, 0.0, -0.25], [2.5, 2.27, 2.08, 1.94]],\n",
+ " '0.5': [[3.66, 3.55, 3.5, 3.42], [2.0, 2.05, 2.19, 2.23], [3.33, 3.27, 3.25, 3.22]],\n",
+ " '0.9': [[5.75, 6.25, 6.59, 7.0], [3.8, 4.47, 4.88, 5.31], [4.38, 4.62, 4.78, 5.0]]}]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "payload = {\n",
+ " \"inputs\": [\n",
+ " {\n",
+ " # For multivariate forecasting, target is a list of lists\n",
+ " # Each inner list represents one dimension with the same length\n",
+ " # np.array(target) would have shape [num_dimensions, length]\n",
+ " \"target\": [\n",
+ " [1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0], # Dimension 1\n",
+ " [5.0, 4.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0], # Dimension 2\n",
+ " [2.0, 2.5, 3.0, 2.5, 2.0, 2.5, 3.0, 3.5], # Dimension 3\n",
+ " ],\n",
+ " },\n",
+ " ],\n",
+ " \"parameters\": {\n",
+ " \"prediction_length\": 4,\n",
+ " \"quantile_levels\": [0.1, 0.5, 0.9],\n",
+ " },\n",
+ "}\n",
+ "response = predictor.predict(payload)\n",
+ "print(pretty_format(response))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9479e7a3",
"metadata": {},
"source": [
"## Endpoint API\n",
"So far, we have explored several examples of querying the endpoint with different payload structures. Below is a comprehensive API specification detailing all supported parameters, their meanings, and how they affect the model’s predictions.\n",
"\n",
"* **inputs** (required): List with at most 1000 time series that need to be forecasted. Each time series is represented by a dictionary with the following keys:\n",
- " * **target** (required): List of observed numeric time series values. \n",
+ " * **target** (required): Observed time series values.\n",
+ " - For univariate forecasting: List of numeric values.\n",
+ " - For multivariate forecasting: List of lists, where each inner list represents one dimension. All dimensions must have the same length. If converted to a numpy array via `np.array(target)`, the shape would be `[num_dimensions, length]`.\n",
" - It is recommended that each time series contains at least 30 observations.\n",
" - If any time series contains fewer than 5 observations, an error will be raised.\n",
- " * **item_id**: String that uniquely identifies each time series. \n",
+ " * **item_id**: String that uniquely identifies each time series.\n",
" - If provided, the ID must be unique for each time series.\n",
" - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n",
- " * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`). \n",
+ " * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`).\n",
" - If **start** field is provided, then **freq** must also be provided as part of **parameters**.\n",
" - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n",
" * **past_covariates**: Dictionary containing the past values of the covariates for this time series.\n",
- " - If **past_covariates** field is provided, then **future_covariates** must be provided as well with the same keys.\n",
    "      - Each key in **past_covariates** corresponds to the name of the covariate. Each value must be an array consisting of all-numeric or all-string values, with the length equal to the length of the **target**.\n",
+ " - Covariates that appear only in **past_covariates** (and not in **future_covariates**) are treated as past-only covariates.\n",
" * **future_covariates**: Dictionary containing the future values of the covariates for this time series (values during the forecast horizon).\n",
- " - If **future_covariates** field is provided, then **past_covariates** must be provided as well with the same keys.\n",
    "      - Each key in **future_covariates** corresponds to the name of the covariate. Each value must be an array consisting of all-numeric or all-string values, with the length equal to **prediction_length**.\n",
- " - If both **past_covariates** and **future_covariates** are provided, a regression model specified by **covariate_model** will be used to incorporate the covariate information into the forecast.\n",
+ " - Covariates that appear in both **past_covariates** and **future_covariates** are treated as known future covariates.\n",
"* **parameters**: Optional parameters to configure the model.\n",
- " * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. Defaults to `1`.\n",
- " - Recommended to keep prediction_length <= 64 since larger values will result in inaccurate quantile forecasts. Values above 1000 will raise an error.\n",
- " * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`. \n",
- " - Note that Chronos-Bolt cannot produce quantiles outside the [0.1, 0.9] range (predictions outside the range will be clipped).\n",
- " * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data. \n",
+ " * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. Defaults to `1`. Values up to `1024` are supported.\n",
+    "  * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`.\n",
+ " - Chronos-2 natively supports quantile levels in range `[0.01, 0.99]`. Predictions outside the range will be clipped.\n",
+ " * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data.\n",
" - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n",
" * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out of memory errors. Defaults to `256`.\n",
- " * **covariate_model**: Name of the tabular regression model applied to the covariates. Possible options: `GBM` (LightGBM), `LR` (linear regression), `RF` (random forest), `CAT` (CatBoost), `XGB` (XGBoost). Defaults to `GBM`.\n",
+ " * **predict_batches_jointly**: If `True`, the model will apply group attention to all items in the batch, instead of processing each item separately (described as \"full cross-learning mode\" in the [technical report](https://www.arxiv.org/abs/2510.15821)). This may produce more accurate forecasts at the cost of lower inference speed. Defaults to `False`.\n",
"\n",
"All keys not marked with (required) are optional.\n",
"\n",
@@ -352,6 +467,7 @@
},
{
"cell_type": "markdown",
+ "id": "5ee1e161",
"metadata": {},
"source": [
"## Working with long-format data frames"
@@ -359,6 +475,7 @@
},
{
"cell_type": "markdown",
+ "id": "a744884c",
"metadata": {},
"source": [
"The endpoint communicates using JSON format for both input and output. However, in practice, time series data is often stored in a **long-format data frame** (where each row represents a timestamp for a specific item).\n",
@@ -374,7 +491,8 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
+ "id": "6ecba0ca",
"metadata": {},
"outputs": [
{
@@ -472,7 +590,7 @@
"4 661.0 "
]
},
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -489,6 +607,7 @@
},
{
"cell_type": "markdown",
+ "id": "4288470c",
"metadata": {},
"source": [
"We split the data into two parts:\n",
@@ -498,7 +617,8 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
+ "id": "d95eb6d7",
"metadata": {},
"outputs": [],
"source": [
@@ -512,7 +632,8 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
+ "id": "c2482ffd",
"metadata": {},
"outputs": [
{
@@ -610,7 +731,7 @@
"4 661.0 "
]
},
- "execution_count": 8,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -621,8 +742,11 @@
},
{
"cell_type": "code",
- "execution_count": 9,
- "metadata": {},
+ "execution_count": 10,
+ "id": "c47905f8",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
"outputs": [
{
"data": {
@@ -706,7 +830,7 @@
"27 1062_101 2018-07-09 1.000000 0.0 0.0"
]
},
- "execution_count": 9,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -717,6 +841,7 @@
},
{
"cell_type": "markdown",
+ "id": "e9f54de7",
"metadata": {},
"source": [
"We can now convert this data into a JSON payload."
@@ -724,8 +849,11 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "metadata": {},
+ "execution_count": 11,
+ "id": "55bbad68",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
"outputs": [],
"source": [
"def convert_df_to_payload(\n",
@@ -733,59 +861,67 @@
" future_df=None,\n",
" prediction_length=1,\n",
" freq=\"D\",\n",
- " target_col=\"target\",\n",
- " id_col=\"item_id\",\n",
- " timestamp_col=\"timestamp\",\n",
+ " target=\"target\",\n",
+ " id_column=\"item_id\",\n",
+ " timestamp_column=\"timestamp\",\n",
"):\n",
" \"\"\"\n",
" Converts past and future DataFrames into JSON payload format for the Chronos endpoint.\n",
"\n",
" Args:\n",
- " past_df (pd.DataFrame): Historical data with `target_col`, `timestamp_col`, and `id_col`.\n",
- " future_df (pd.DataFrame, optional): Future covariates with `timestamp_col` and `id_col`.\n",
+ " past_df (pd.DataFrame): Historical data with `target`, `timestamp_column`, and `id_column`.\n",
+ " future_df (pd.DataFrame, optional): Future covariates with `timestamp_column` and `id_column`.\n",
+ " Covariates in past_df but not in future_df are treated as past-only covariates.\n",
" prediction_length (int): Number of future time steps to predict.\n",
" freq (str): Pandas-compatible frequency of the time series.\n",
- " target_col (str): Column name for target values.\n",
- " id_col (str): Column name for item IDs.\n",
- " timestamp_col (str): Column name for timestamps.\n",
+ " target (str or list[str]): Column name(s) for target values.\n",
+ " Use a string for univariate forecasting or a list of strings for multivariate forecasting.\n",
+ " id_column (str): Column name for item IDs.\n",
+ " timestamp_column (str): Column name for timestamps.\n",
"\n",
" Returns:\n",
" dict: JSON payload formatted for the Chronos endpoint.\n",
" \"\"\"\n",
- " past_df = past_df.sort_values([id_col, timestamp_col])\n",
+ " past_df = past_df.sort_values([id_column, timestamp_column])\n",
" if future_df is not None:\n",
- " future_df = future_df.sort_values([id_col, timestamp_col])\n",
+ " future_df = future_df.sort_values([id_column, timestamp_column])\n",
"\n",
- " covariate_cols = list(past_df.columns.drop([target_col, id_col, timestamp_col]))\n",
- " if covariate_cols and (future_df is None or not set(covariate_cols).issubset(future_df.columns)):\n",
- " raise ValueError(f\"If past_df contains covariates {covariate_cols}, they should also be present in future_df\")\n",
+ " target_cols = [target] if isinstance(target, str) else target\n",
+ " past_covariate_cols = list(past_df.columns.drop([*target_cols, id_column, timestamp_column]))\n",
+ " future_covariate_cols = [] if future_df is None else [col for col in past_covariate_cols if col in future_df.columns]\n",
"\n",
" inputs = []\n",
- " for item_id, past_group in past_df.groupby(id_col):\n",
- " target_values = past_group[target_col].tolist()\n",
+ " for item_id, past_group in past_df.groupby(id_column):\n",
+ " if len(target_cols) > 1:\n",
+ " target_values = [past_group[col].tolist() for col in target_cols]\n",
+ " series_length = len(target_values[0])\n",
+ " else:\n",
+ " target_values = past_group[target_cols[0]].tolist()\n",
+ " series_length = len(target_values)\n",
"\n",
- " if len(target_values) < 5:\n",
+ " if series_length < 5:\n",
" raise ValueError(f\"Time series '{item_id}' has fewer than 5 observations.\")\n",
"\n",
" series_dict = {\n",
" \"target\": target_values,\n",
" \"item_id\": str(item_id),\n",
- " \"start\": past_group[timestamp_col].iloc[0].isoformat(),\n",
+ " \"start\": past_group[timestamp_column].iloc[0].isoformat(),\n",
" }\n",
"\n",
- " if covariate_cols:\n",
- " series_dict[\"past_covariates\"] = past_group[covariate_cols].to_dict(orient=\"list\")\n",
- " future_group = future_df[future_df[id_col] == item_id]\n",
+ " if past_covariate_cols:\n",
+ " series_dict[\"past_covariates\"] = past_group[past_covariate_cols].to_dict(orient=\"list\")\n",
+ "\n",
+ " if future_covariate_cols:\n",
+ " future_group = future_df[future_df[id_column] == item_id]\n",
" if len(future_group) != prediction_length:\n",
" raise ValueError(\n",
" f\"future_df must contain exactly {prediction_length=} values for each item_id from past_df \"\n",
" f\"(got {len(future_group)=}) for {item_id=}\"\n",
" )\n",
- " series_dict[\"future_covariates\"] = future_group[covariate_cols].to_dict(orient=\"list\")\n",
+ " series_dict[\"future_covariates\"] = future_group[future_covariate_cols].to_dict(orient=\"list\")\n",
"\n",
" inputs.append(series_dict)\n",
"\n",
- "\n",
" return {\n",
" \"inputs\": inputs,\n",
" \"parameters\": {\"prediction_length\": prediction_length, \"freq\": freq},\n",
@@ -794,7 +930,8 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
+ "id": "d5226957",
"metadata": {},
"outputs": [],
"source": [
@@ -803,12 +940,13 @@
" future_df,\n",
" prediction_length=prediction_length,\n",
" freq=freq,\n",
- " target_col=\"unit_sales\",\n",
+ " target=\"unit_sales\",\n",
")"
]
},
{
"cell_type": "markdown",
+ "id": "4611c3e6",
"metadata": {},
"source": [
"We can now send the payload to the endpoint."
@@ -816,8 +954,11 @@
},
{
"cell_type": "code",
- "execution_count": 12,
- "metadata": {},
+ "execution_count": 13,
+ "id": "504a731e",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
"outputs": [],
"source": [
"response = predictor.predict(payload)"
@@ -825,17 +966,21 @@
},
{
"cell_type": "markdown",
+ "id": "742be985",
"metadata": {},
"source": [
- "Note how Chronos-Bolt generated predictions for >300 time series in the dataset (with covariates!) in less than 2 seconds, even when running on a small CPU instance.\n",
+ "Note how Chronos-2 generated predictions for >300 time series in the dataset (with covariates!) in less than 2 seconds.\n",
"\n",
"Finally, we can convert the response back to a long-format data frame."
]
},
{
"cell_type": "code",
- "execution_count": 13,
- "metadata": {},
+ "execution_count": 14,
+ "id": "a48807f6",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
"outputs": [],
"source": [
"def convert_response_to_df(response, freq=\"D\"):\n",
@@ -848,18 +993,35 @@
"\n",
" Returns:\n",
" pd.DataFrame: Long-format DataFrame with timestamps, item_id, and forecasted values.\n",
+ " For multivariate forecasts, creates separate rows for each target dimension (target_1, target_2, etc.).\n",
" \"\"\"\n",
" dfs = []\n",
" for forecast in response[\"predictions\"]:\n",
- " forecast_df = pd.DataFrame(forecast).drop(columns=[\"start\"])\n",
- " forecast_df[\"timestamp\"] = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast_df))\n",
- " dfs.append(forecast_df)\n",
- " return pd.concat(dfs)"
+ " if isinstance(forecast[\"mean\"], list) and isinstance(forecast[\"mean\"][0], list):\n",
+ " # Multivariate forecast\n",
+ " timestamps = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast[\"mean\"][0]))\n",
+ " for dim_idx in range(len(forecast[\"mean\"])):\n",
+ " dim_data = {\"item_id\": forecast.get(\"item_id\"), \"timestamp\": timestamps, \"target\": f\"target_{dim_idx + 1}\"}\n",
+ " for key, value in forecast.items():\n",
+ " if key not in [\"item_id\", \"start\"]:\n",
+ " dim_data[key] = value[dim_idx]\n",
+ " dfs.append(pd.DataFrame(dim_data))\n",
+ " else:\n",
+ " # Univariate forecast\n",
+ " forecast_df = pd.DataFrame(forecast).drop(columns=[\"start\"])\n",
+ " forecast_df[\"timestamp\"] = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast_df))\n",
+ " # Reorder columns to have item_id and timestamp first\n",
+ " cols = [\"item_id\", \"timestamp\"] + [c for c in forecast_df.columns if c not in [\"item_id\", \"timestamp\"]]\n",
+ " forecast_df = forecast_df[cols]\n",
+ " dfs.append(forecast_df)\n",
+ "\n",
+ " return pd.concat(dfs, ignore_index=True)"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
+ "id": "ce0cf954",
"metadata": {},
"outputs": [
{
@@ -883,74 +1045,74 @@
" \n",
" \n",
" | \n",
+ " item_id | \n",
+ " timestamp | \n",
" mean | \n",
" 0.1 | \n",
" 0.5 | \n",
" 0.9 | \n",
- " item_id | \n",
- " timestamp | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
- " 315.504037 | \n",
- " 210.074945 | \n",
- " 315.504037 | \n",
- " 487.484408 | \n",
" 1062_101 | \n",
" 2018-06-11 | \n",
+ " 320.0 | \n",
+ " 186.0 | \n",
+ " 320.0 | \n",
+ " 488.0 | \n",
"
\n",
" \n",
" | 1 | \n",
- " 315.364478 | \n",
- " 200.272695 | \n",
- " 315.364478 | \n",
- " 508.145850 | \n",
" 1062_101 | \n",
" 2018-06-18 | \n",
+ " 318.0 | \n",
+ " 175.0 | \n",
+ " 318.0 | \n",
+ " 496.0 | \n",
"
\n",
" \n",
" | 2 | \n",
- " 310.507265 | \n",
- " 193.902630 | \n",
- " 310.507265 | \n",
- " 511.559740 | \n",
" 1062_101 | \n",
" 2018-06-25 | \n",
+ " 316.0 | \n",
+ " 169.0 | \n",
+ " 316.0 | \n",
+ " 508.0 | \n",
"
\n",
" \n",
" | 3 | \n",
- " 317.322873 | \n",
- " 200.051215 | \n",
- " 317.322873 | \n",
- " 525.013830 | \n",
" 1062_101 | \n",
" 2018-07-02 | \n",
+ " 316.0 | \n",
+ " 171.0 | \n",
+ " 316.0 | \n",
+ " 506.0 | \n",
"
\n",
" \n",
" | 4 | \n",
- " 319.089405 | \n",
- " 199.634549 | \n",
- " 319.089405 | \n",
- " 534.102518 | \n",
" 1062_101 | \n",
" 2018-07-09 | \n",
+ " 310.0 | \n",
+ " 165.0 | \n",
+ " 310.0 | \n",
+ " 506.0 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " mean 0.1 0.5 0.9 item_id timestamp\n",
- "0 315.504037 210.074945 315.504037 487.484408 1062_101 2018-06-11\n",
- "1 315.364478 200.272695 315.364478 508.145850 1062_101 2018-06-18\n",
- "2 310.507265 193.902630 310.507265 511.559740 1062_101 2018-06-25\n",
- "3 317.322873 200.051215 317.322873 525.013830 1062_101 2018-07-02\n",
- "4 319.089405 199.634549 319.089405 534.102518 1062_101 2018-07-09"
+ " item_id timestamp mean 0.1 0.5 0.9\n",
+ "0 1062_101 2018-06-11 320.0 186.0 320.0 488.0\n",
+ "1 1062_101 2018-06-18 318.0 175.0 318.0 496.0\n",
+ "2 1062_101 2018-06-25 316.0 169.0 316.0 508.0\n",
+ "3 1062_101 2018-07-02 316.0 171.0 316.0 506.0\n",
+ "4 1062_101 2018-07-09 310.0 165.0 310.0 506.0"
]
},
- "execution_count": 14,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -962,6 +1124,7 @@
},
{
"cell_type": "markdown",
+ "id": "e89cbc36",
"metadata": {},
"source": [
"## Clean up the endpoint\n",
@@ -970,8 +1133,11 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
+ "execution_count": 16,
+ "id": "9602a0a5",
+ "metadata": {
+ "lines_to_next_cell": 3
+ },
"outputs": [],
"source": [
"predictor.delete_predictor()"
@@ -979,9 +1145,13 @@
}
],
"metadata": {
- "instance_type": "ml.t3.medium",
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "main_language": "python",
+ "notebook_metadata_filter": "-all"
+ },
"kernelspec": {
- "display_name": "base",
+ "display_name": "sm",
"language": "python",
"name": "python3"
},
@@ -995,9 +1165,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.10"
+ "version": "3.11.11"
}
},
"nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 5
}