Skip to content

Commit

Permalink
check for level when prediction_intervals are set (#615)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Aug 21, 2023
1 parent c5702cf commit 5d7c1ac
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
33 changes: 30 additions & 3 deletions nbs/src/core/core.ipynb
Expand Up @@ -57,6 +57,7 @@
"#| hide\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=FutureWarning)\n",
"warnings.filterwarnings('always', category=UserWarning)\n",
"\n",
"from nbdev.showdoc import add_docs, show_doc\n",
"from statsforecast.models import Naive"
Expand All @@ -75,6 +76,7 @@
"import random\n",
"import re\n",
"import reprlib\n",
"import warnings\n",
"from itertools import product\n",
"from os import cpu_count\n",
"from typing import Any, List, Optional, Union, Dict\n",
Expand Down Expand Up @@ -128,7 +130,7 @@
"outputs": [],
"source": [
"#| hide\n",
"from fastcore.test import test_eq, test_fail\n",
"from fastcore.test import test_eq, test_fail, test_warns\n",
"from statsforecast.utils import generate_series"
]
},
Expand Down Expand Up @@ -1410,6 +1412,12 @@
" DataFrame with `models` columns for point predictions and probabilistic\n",
" predictions for all fitted `models`.\n",
" \"\"\"\n",
" \n",
" if any(getattr(m, 'prediction_intervals', None) is not None for m in self.models) and level is None:\n",
" warnings.warn(\n",
" \"Prediction intervals are set but `level` was not provided. \"\n",
" \"Predictions won't have intervals.\"\n",
" )\n",
" X, level = self._parse_X_level(h=h, X=X_df, level=level)\n",
" if self.n_jobs == 1:\n",
" fcsts, cols = self.ga.predict(fm=self.fitted_, h=h, X=X, level=level)\n",
Expand Down Expand Up @@ -1459,6 +1467,8 @@
" DataFrame with `models` columns for point predictions and probabilistic\n",
" predictions for all fitted `models`.\n",
" \"\"\"\n",
" if prediction_intervals is not None and level is None:\n",
" raise ValueError('You must specify `level` when using `prediction_intervals`') \n",
" self._set_prediction_intervals(prediction_intervals=prediction_intervals)\n",
" self._prepare_fit(df, sort_df)\n",
" X, level = self._parse_X_level(h=h, X=X_df, level=level)\n",
Expand Down Expand Up @@ -1625,6 +1635,8 @@
" raise Exception('you must define `n_windows` or `test_size`')\n",
" else:\n",
" raise Exception('you must define `n_windows` or `test_size` but not both')\n",
" if prediction_intervals is not None and level is None:\n",
" raise ValueError('You must specify `level` when using `prediction_intervals`') \n",
" self._set_prediction_intervals(prediction_intervals=prediction_intervals)\n",
" self._prepare_fit(df, sort_df)\n",
" series_sizes = np.diff(self.ga.indptr)\n",
Expand Down Expand Up @@ -2190,6 +2202,8 @@
" sort_df: bool = True,\n",
" prediction_intervals: Optional[ConformalIntervals] = None,\n",
" ):\n",
" if prediction_intervals is not None and level is None:\n",
" raise ValueError('You must specify `level` when using `prediction_intervals`')\n",
" if self._is_native(df=df):\n",
" return super().forecast(\n",
" h=h,\n",
Expand Down Expand Up @@ -3225,7 +3239,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6a36a749-811f-4bc2-bcfe-40b57ee2deb9",
"id": "028f1935-4305-4e0a-a0b9-0bb4afea900f",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -3246,7 +3260,20 @@
"test_eq(\n",
" sf.predict(h=12, level=[80, 90]),\n",
" sf.forecast(df=series_subset, h=12, level=[80, 90], prediction_intervals=ConformalIntervals(h=12)),\n",
")"
")\n",
"\n",
"# test errors/warnings are raised when level is not specified\n",
"intervals = ConformalIntervals(h=12)\n",
"sf2 = StatsForecast(\n",
" models=[ADIDA()],\n",
" freq='D', \n",
" n_jobs=1,\n",
")\n",
"sf2.fit(df=series_subset, prediction_intervals=intervals)\n",
"test_warns(lambda: sf2.predict(h=12))\n",
"test_fail(lambda: sf2.forecast(df=series_subset, h=12, prediction_intervals=intervals))\n",
"test_fail(lambda: sf2.fit_predict(df=series_subset, h=12, prediction_intervals=intervals))\n",
"test_fail(lambda: sf2.cross_validation(df=series_subset, h=12, prediction_intervals=intervals))"
]
},
{
Expand Down
25 changes: 25 additions & 0 deletions statsforecast/core.py
Expand Up @@ -9,6 +9,7 @@
import random
import re
import reprlib
import warnings
from itertools import product
from os import cpu_count
from typing import Any, List, Optional, Union, Dict
Expand Down Expand Up @@ -956,6 +957,18 @@ def predict(
DataFrame with `models` columns for point predictions and probabilistic
predictions for all fitted `models`.
"""

if (
any(
getattr(m, "prediction_intervals", None) is not None
for m in self.models
)
and level is None
):
warnings.warn(
"Prediction intervals are set but `level` was not provided. "
"Predictions won't have intervals."
)
X, level = self._parse_X_level(h=h, X=X_df, level=level)
if self.n_jobs == 1:
fcsts, cols = self.ga.predict(fm=self.fitted_, h=h, X=X, level=level)
Expand Down Expand Up @@ -1005,6 +1018,10 @@ def fit_predict(
DataFrame with `models` columns for point predictions and probabilistic
predictions for all fitted `models`.
"""
if prediction_intervals is not None and level is None:
raise ValueError(
"You must specify `level` when using `prediction_intervals`"
)
self._set_prediction_intervals(prediction_intervals=prediction_intervals)
self._prepare_fit(df, sort_df)
X, level = self._parse_X_level(h=h, X=X_df, level=level)
Expand Down Expand Up @@ -1180,6 +1197,10 @@ def cross_validation(
raise Exception("you must define `n_windows` or `test_size`")
else:
raise Exception("you must define `n_windows` or `test_size` but not both")
if prediction_intervals is not None and level is None:
raise ValueError(
"You must specify `level` when using `prediction_intervals`"
)
self._set_prediction_intervals(prediction_intervals=prediction_intervals)
self._prepare_fit(df, sort_df)
series_sizes = np.diff(self.ga.indptr)
Expand Down Expand Up @@ -1857,6 +1878,10 @@ def forecast(
sort_df: bool = True,
prediction_intervals: Optional[ConformalIntervals] = None,
):
if prediction_intervals is not None and level is None:
raise ValueError(
"You must specify `level` when using `prediction_intervals`"
)
if self._is_native(df=df):
return super().forecast(
h=h,
Expand Down

0 comments on commit 5d7c1ac

Please sign in to comment.