Skip to content

Commit

Permalink
test static_df with polars (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jan 22, 2024
1 parent 4af1f67 commit fb0a2d6
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2206,23 +2206,33 @@
"#| polars\n",
"models = [LSTM(h=12, input_size=24, max_steps=5, hist_exog_list=['zeros'], scaler_type='robust')]\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(AirPassengersPanel_train)\n",
"nf.fit(AirPassengersPanel_train, static_df=AirPassengersStatic)\n",
"insample_preds = nf.predict_insample()\n",
"preds = nf.predict()\n",
"cv_res = nf.cross_validation(df=AirPassengersPanel_train)\n",
"cv_res = nf.cross_validation(df=AirPassengersPanel_train, static_df=AirPassengersStatic)\n",
"\n",
"renamer = {'unique_id': 'uid', 'ds': 'time', 'y': 'target'}\n",
"inverse_renamer = {v: k for k, v in renamer.items()}\n",
"AirPassengers_pl = polars.from_pandas(AirPassengersPanel_train)\n",
"AirPassengers_pl = AirPassengers_pl.rename(renamer)\n",
"AirPassengersStatic_pl = polars.from_pandas(AirPassengersStatic)\n",
"AirPassengersStatic_pl = AirPassengersStatic_pl.rename({'unique_id': 'uid'})\n",
"nf = NeuralForecast(models=models, freq='1mo')\n",
"nf.fit(\n",
" AirPassengers_pl, id_col='uid', time_col='time', target_col='target'\n",
" AirPassengers_pl,\n",
" static_df=AirPassengersStatic_pl,\n",
" id_col='uid',\n",
" time_col='time',\n",
" target_col='target',\n",
")\n",
"insample_preds_pl = nf.predict_insample()\n",
"preds_pl = nf.predict()\n",
"cv_res_pl = nf.cross_validation(\n",
" df=AirPassengers_pl, id_col='uid', time_col='time', target_col='target'\n",
" df=AirPassengers_pl,\n",
" static_df=AirPassengersStatic_pl,\n",
" id_col='uid',\n",
" time_col='time',\n",
" target_col='target',\n",
")\n",
"\n",
"def assert_equal_dfs(pandas_df, polars_df):\n",
Expand Down

0 comments on commit fb0a2d6

Please sign in to comment.