fix cross_validation results with uneven windows (#989)
jmoralez committed May 4, 2024
1 parent 0c1a760 commit b85b07d
Showing 5 changed files with 78 additions and 11 deletions.
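In short: `cross_validation` preallocates `n_windows * h` forecast rows per series, but a series that is too short cannot fill every window, and those surplus rows previously leaked into the results. A minimal sketch of the window arithmetic (illustrative only; it mirrors the `(serie.size - 1) // self.h` comment introduced by this commit and uses the sizes from the new test further down):

```python
# Rows allocated vs. rows actually valid per series (toy numbers, not library code).
h, n_windows = 5, 3
for size in (10, 15):
    allocated = n_windows * h            # rows reserved for this series
    max_windows = (size - 1) // h        # windows a series of this length supports
    valid = min(n_windows, max_windows) * h
    print(f"length {size}: allocated {allocated} rows, valid {valid}")
# length 10: allocated 15 rows, valid 5
# length 15: allocated 15 rows, valid 10
```

The diff below trims the allocation down to the valid rows before assembling the returned frame.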
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -20,7 +20,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, macos-13, windows-latest]
         python-version: ['3.8', '3.9', '3.10', '3.11']
         exclude:
           - os: windows-latest
55 changes: 52 additions & 3 deletions nbs/core.ipynb
@@ -932,9 +932,29 @@
" output_length = len(model.loss.output_names)\n",
" fcsts[:,col_idx:(col_idx + output_length)] = model_fcsts\n",
" col_idx += output_length\n",
" if self.scalers_: \n",
" indptr = np.append(0, np.full(self.dataset.n_groups, self.h * n_windows).cumsum())\n",
" fcsts = self._scalers_target_inverse_transform(fcsts, indptr)\n",
" # we may have allocated more space than needed\n",
" # each serie can produce at most (serie.size - 1) // self.h CV windows\n",
" effective_sizes = ufp.counts_by_id(fcsts_df, id_col)['counts'].to_numpy()\n",
" needs_trim = effective_sizes.sum() != fcsts.shape[0]\n",
" if self.scalers_ or needs_trim:\n",
" indptr = np.arange(\n",
" 0,\n",
" n_windows * self.h * (self.dataset.n_groups + 1),\n",
" n_windows * self.h,\n",
" dtype=np.int32,\n",
" )\n",
" if self.scalers_:\n",
" fcsts = self._scalers_target_inverse_transform(fcsts, indptr)\n",
" if needs_trim:\n",
" # we keep only the effective samples of each serie from the cv results\n",
" trimmed = np.empty_like(\n",
" fcsts, shape=(effective_sizes.sum(), fcsts.shape[1])\n",
" )\n",
" cv_indptr = np.append(0, effective_sizes).cumsum(dtype=np.int32)\n",
" for i in range(fcsts.shape[1]):\n",
" ga = GroupedArray(fcsts[:, i], indptr)\n",
" trimmed[:, i] = ga._tails(cv_indptr)\n",
" fcsts = trimmed\n",
"\n",
" self._fitted = True\n",
"\n",
@@ -2204,6 +2224,7 @@
" Y_hat_df[Y_hat_df_cv.columns],\n",
" Y_hat_df_cv,\n",
" check_dtype=False,\n",
" atol=1e-5,\n",
" )"
]
},
@@ -2218,6 +2239,34 @@
"test_cross_validation(AirPassengersPanel, AirPassengersStatic, h=12, test_size=12)"
]
},
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "03396c73",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "#| hide\n",
+   "# test cv with series of different sizes\n",
+   "series = pd.DataFrame({\n",
+   "    'unique_id': np.repeat([0, 1], [10, 15]),\n",
+   "    'ds': np.arange(25),\n",
+   "    'y': np.random.rand(25),\n",
+   "})\n",
+   "nf = NeuralForecast(\n",
+   "    freq=1,\n",
+   "    models=[MLP(input_size=5, h=5, max_steps=0, enable_progress_bar=False)]\n",
+   ")\n",
+   "cv_df = nf.cross_validation(df=series, n_windows=3, step_size=5)\n",
+   "expected = pd.DataFrame({\n",
+   "    'unique_id': np.repeat([0, 1], [5, 10]),\n",
+   "    'ds': np.hstack([np.arange(5, 10), np.arange(15, 25)]),\n",
+   "    'cutoff': np.repeat([4, 14, 19], 5)\n",
+   "})\n",
+   "expected = expected.merge(series, on=['unique_id', 'ds'])\n",
+   "pd.testing.assert_frame_equal(expected, cv_df.drop(columns='MLP'))"
+  ]
+ },
{
"cell_type": "code",
"execution_count": null,
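The expected frame in the test above encodes the same arithmetic: with `h=5` and `step_size=5`, the 10-point series supports only the window with cutoff 4, while the 15-point series supports cutoffs 14 and 19. A hypothetical helper (not library API, and assuming `step_size == h` as in the test) that reproduces those cutoffs:

```python
import numpy as np

def expected_cutoffs(last_ds, size, h, n_windows, step_size):
    # windows are counted back from the end of the series; per the commit's
    # comment, a series of length `size` supports at most (size - 1) // h windows
    k = min(n_windows, (size - 1) // h)
    return np.array(sorted(last_ds - h - step_size * i for i in range(k)))

print(expected_cutoffs(last_ds=9, size=10, h=5, n_windows=3, step_size=5))   # [4]
print(expected_cutoffs(last_ds=24, size=15, h=5, n_windows=3, step_size=5))  # [14 19]
```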
4 changes: 2 additions & 2 deletions nbs/models.bitcn.ipynb
@@ -407,7 +407,7 @@
"Y_test_df = Y_df[Y_df.ds>'1959-12-31'] # 12 test\n",
"\n",
"dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
"model = BiTCN(h=12, input_size=24, max_steps=500, scaler_type='standard')\n",
"model = BiTCN(h=12, input_size=24, max_steps=5, scaler_type='standard')\n",
"model.fit(dataset=dataset)\n",
"y_hat = model.predict(dataset=dataset)\n",
"Y_test_df['BiTCN'] = y_hat\n",
@@ -449,7 +449,7 @@
" BiTCN(h=12,\n",
" input_size=24,\n",
" loss=GMM(n_components=7, return_params=True, level=[80,90]),\n",
" max_steps=500,\n",
" max_steps=5,\n",
" scaler_type='standard',\n",
" futr_exog_list=['y_[lag12]'],\n",
" hist_exog_list=None,\n",
2 changes: 1 addition & 1 deletion nbs/models.hint.ipynb
@@ -495,7 +495,7 @@
" for parent_idx, children_list in parent_children_dict.items():\n",
" parent_value = hint_mean[parent_idx]\n",
" children_sum = hint_mean[children_list].sum()\n",
" np.testing.assert_allclose(children_sum, parent_value)"
" np.testing.assert_allclose(children_sum, parent_value, rtol=1e-6)"
]
},
{
26 changes: 22 additions & 4 deletions neuralforecast/core.py
@@ -874,11 +874,29 @@ def _no_refit_cross_validation(
             output_length = len(model.loss.output_names)
             fcsts[:, col_idx : (col_idx + output_length)] = model_fcsts
             col_idx += output_length
-        if self.scalers_:
-            indptr = np.append(
-                0, np.full(self.dataset.n_groups, self.h * n_windows).cumsum()
+        # we may have allocated more space than needed
+        # each serie can produce at most (serie.size - 1) // self.h CV windows
+        effective_sizes = ufp.counts_by_id(fcsts_df, id_col)["counts"].to_numpy()
+        needs_trim = effective_sizes.sum() != fcsts.shape[0]
+        if self.scalers_ or needs_trim:
+            indptr = np.arange(
+                0,
+                n_windows * self.h * (self.dataset.n_groups + 1),
+                n_windows * self.h,
+                dtype=np.int32,
             )
-            fcsts = self._scalers_target_inverse_transform(fcsts, indptr)
+            if self.scalers_:
+                fcsts = self._scalers_target_inverse_transform(fcsts, indptr)
+            if needs_trim:
+                # we keep only the effective samples of each serie from the cv results
+                trimmed = np.empty_like(
+                    fcsts, shape=(effective_sizes.sum(), fcsts.shape[1])
+                )
+                cv_indptr = np.append(0, effective_sizes).cumsum(dtype=np.int32)
+                for i in range(fcsts.shape[1]):
+                    ga = GroupedArray(fcsts[:, i], indptr)
+                    trimmed[:, i] = ga._tails(cv_indptr)
+                fcsts = trimmed
 
         self._fitted = True
 
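The trimming path relies on `GroupedArray._tails`, an internal utility; as used here, for each series it keeps the trailing rows of that series' fixed-size block, i.e. the most recent valid windows. A rough NumPy-only stand-in, assuming those semantics (the helper name and shapes are illustrative, not library API):

```python
import numpy as np

def trim_to_tails(fcsts, group_size, effective_sizes):
    """Keep the last effective_sizes[g] rows of each group's fixed-size block.

    group_size is n_windows * h, the number of rows preallocated per series.
    """
    blocks = []
    for g, keep in enumerate(effective_sizes):
        block = fcsts[g * group_size : (g + 1) * group_size]
        blocks.append(block[group_size - keep:])  # valid forecasts sit at the tail
    return np.vstack(blocks)

# shapes from the new test: 2 series, h=5, n_windows=3 -> 15 rows preallocated each
fcsts = np.arange(30.0).reshape(30, 1)
print(trim_to_tails(fcsts, group_size=15, effective_sizes=np.array([5, 10])).shape)  # (15, 1)
```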
