From 5d7c1ac118a2a245334e20a75d665570f2eccb6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 21 Aug 2023 15:28:49 -0600 Subject: [PATCH] check for level when prediction_intervals are set (#615) --- nbs/src/core/core.ipynb | 33 ++++++++++++++++++++++++++++++--- statsforecast/core.py | 25 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/nbs/src/core/core.ipynb b/nbs/src/core/core.ipynb index df029edc6..76a829f7c 100644 --- a/nbs/src/core/core.ipynb +++ b/nbs/src/core/core.ipynb @@ -57,6 +57,7 @@ "#| hide\n", "import warnings\n", "warnings.filterwarnings('ignore', category=FutureWarning)\n", + "warnings.filterwarnings('always', category=UserWarning)\n", "\n", "from nbdev.showdoc import add_docs, show_doc\n", "from statsforecast.models import Naive" @@ -75,6 +76,7 @@ "import random\n", "import re\n", "import reprlib\n", + "import warnings\n", "from itertools import product\n", "from os import cpu_count\n", "from typing import Any, List, Optional, Union, Dict\n", @@ -128,7 +130,7 @@ "outputs": [], "source": [ "#| hide\n", - "from fastcore.test import test_eq, test_fail\n", + "from fastcore.test import test_eq, test_fail, test_warns\n", "from statsforecast.utils import generate_series" ] }, @@ -1410,6 +1412,12 @@ " DataFrame with `models` columns for point predictions and probabilistic\n", " predictions for all fitted `models`.\n", " \"\"\"\n", + " \n", + " if any(getattr(m, 'prediction_intervals', None) is not None for m in self.models) and level is None:\n", + " warnings.warn(\n", + " \"Prediction intervals are set but `level` was not provided. \"\n", + " \"Predictions won't have intervals.\"\n", + " )\n", " X, level = self._parse_X_level(h=h, X=X_df, level=level)\n", " if self.n_jobs == 1:\n", " fcsts, cols = self.ga.predict(fm=self.fitted_, h=h, X=X, level=level)\n", @@ -1459,6 +1467,8 @@ " DataFrame with `models` columns for point predictions and probabilistic\n", " predictions for all fitted `models`.\n", " \"\"\"\n", + " if prediction_intervals is not None and level is None:\n", + " raise ValueError('You must specify `level` when using `prediction_intervals`') \n", " self._set_prediction_intervals(prediction_intervals=prediction_intervals)\n", " self._prepare_fit(df, sort_df)\n", " X, level = self._parse_X_level(h=h, X=X_df, level=level)\n", @@ -1625,6 +1635,8 @@ " raise Exception('you must define `n_windows` or `test_size`')\n", " else:\n", " raise Exception('you must define `n_windows` or `test_size` but not both')\n", + " if prediction_intervals is not None and level is None:\n", + " raise ValueError('You must specify `level` when using `prediction_intervals`') \n", " self._set_prediction_intervals(prediction_intervals=prediction_intervals)\n", " self._prepare_fit(df, sort_df)\n", " series_sizes = np.diff(self.ga.indptr)\n", @@ -2190,6 +2202,8 @@ " sort_df: bool = True,\n", " prediction_intervals: Optional[ConformalIntervals] = None,\n", " ):\n", + " if prediction_intervals is not None and level is None:\n", + " raise ValueError('You must specify `level` when using `prediction_intervals`')\n", " if self._is_native(df=df):\n", " return super().forecast(\n", " h=h,\n", @@ -3225,7 +3239,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6a36a749-811f-4bc2-bcfe-40b57ee2deb9", + "id": "028f1935-4305-4e0a-a0b9-0bb4afea900f", "metadata": {}, "outputs": [], "source": [ @@ -3246,7 +3260,20 @@ "test_eq(\n", " sf.predict(h=12, level=[80, 90]),\n", " sf.forecast(df=series_subset, h=12, level=[80, 90], prediction_intervals=ConformalIntervals(h=12)),\n", - ")" + ")\n", + "\n", + "# test errors/warnings are raised when level is not specified\n", + "intervals = ConformalIntervals(h=12)\n", + "sf2 = StatsForecast(\n", + " models=[ADIDA()],\n", + " freq='D', \n", + " n_jobs=1,\n", + ")\n", + "sf2.fit(df=series_subset, prediction_intervals=intervals)\n", + "test_warns(lambda: sf2.predict(h=12))\n", + "test_fail(lambda: sf2.forecast(df=series_subset, h=12, prediction_intervals=intervals))\n", + "test_fail(lambda: sf2.fit_predict(df=series_subset, h=12, prediction_intervals=intervals))\n", + "test_fail(lambda: sf2.cross_validation(df=series_subset, h=12, prediction_intervals=intervals))" ] }, { diff --git a/statsforecast/core.py b/statsforecast/core.py index be94593b4..4140983bd 100644 --- a/statsforecast/core.py +++ b/statsforecast/core.py @@ -9,6 +9,7 @@ import random import re import reprlib +import warnings from itertools import product from os import cpu_count from typing import Any, List, Optional, Union, Dict @@ -956,6 +957,18 @@ def predict( DataFrame with `models` columns for point predictions and probabilistic predictions for all fitted `models`. """ + + if ( + any( + getattr(m, "prediction_intervals", None) is not None + for m in self.models + ) + and level is None + ): + warnings.warn( + "Prediction intervals are set but `level` was not provided. " + "Predictions won't have intervals." + ) X, level = self._parse_X_level(h=h, X=X_df, level=level) if self.n_jobs == 1: fcsts, cols = self.ga.predict(fm=self.fitted_, h=h, X=X, level=level) @@ -1005,6 +1018,10 @@ def fit_predict( DataFrame with `models` columns for point predictions and probabilistic predictions for all fitted `models`. """ + if prediction_intervals is not None and level is None: + raise ValueError( + "You must specify `level` when using `prediction_intervals`" + ) self._set_prediction_intervals(prediction_intervals=prediction_intervals) self._prepare_fit(df, sort_df) X, level = self._parse_X_level(h=h, X=X_df, level=level) @@ -1180,6 +1197,10 @@ def cross_validation( raise Exception("you must define `n_windows` or `test_size`") else: raise Exception("you must define `n_windows` or `test_size` but not both") + if prediction_intervals is not None and level is None: + raise ValueError( + "You must specify `level` when using `prediction_intervals`" + ) self._set_prediction_intervals(prediction_intervals=prediction_intervals) self._prepare_fit(df, sort_df) series_sizes = np.diff(self.ga.indptr) @@ -1857,6 +1878,10 @@ def forecast( sort_df: bool = True, prediction_intervals: Optional[ConformalIntervals] = None, ): + if prediction_intervals is not None and level is None: + raise ValueError( + "You must specify `level` when using `prediction_intervals`" + ) if self._is_native(df=df): return super().forecast( h=h,