diff --git a/nbs/tsdataset.ipynb b/nbs/tsdataset.ipynb index 92afc3eb3..9781ef7a7 100644 --- a/nbs/tsdataset.ipynb +++ b/nbs/tsdataset.ipynb @@ -61,7 +61,7 @@ "import torch\n", "import utilsforecast.processing as ufp\n", "from torch.utils.data import Dataset, DataLoader\n", - "from utilsforecast.compat import DataFrame, pl_Series" + "from utilsforecast.compat import DataFrame, pl_Series, pl_DataFrame" ] }, { @@ -354,7 +354,10 @@ "\n", " # Static features\n", " if static_df is not None:\n", - " static_cols = static_df.columns.drop(id_col)\n", + " if isinstance(static_df, pd.DataFrame):\n", + " static_cols = static_df.columns.drop(id_col)\n", + " elif isinstance(static_df, pl_DataFrame):\n", + " static_cols = static_df.columns.remove(id_col)\n", " static = ufp.to_numpy(static_df[static_cols])\n", " else:\n", " static = None\n", diff --git a/neuralforecast/tsdataset.py b/neuralforecast/tsdataset.py index e251bba47..597bbfc28 100644 --- a/neuralforecast/tsdataset.py +++ b/neuralforecast/tsdataset.py @@ -13,7 +13,7 @@ import torch import utilsforecast.processing as ufp from torch.utils.data import Dataset, DataLoader -from utilsforecast.compat import DataFrame, pl_Series +from utilsforecast.compat import DataFrame, pl_Series, pl_DataFrame # %% ../nbs/tsdataset.ipynb 5 class TimeSeriesLoader(DataLoader): @@ -317,7 +317,10 @@ def from_df( # Static features if static_df is not None: - static_cols = static_df.columns.drop(id_col) + if isinstance(static_df, pd.DataFrame): + static_cols = static_df.columns.drop(id_col) + elif isinstance(static_df, pl_DataFrame): + static_cols = static_df.columns.remove(id_col) static = ufp.to_numpy(static_df[static_cols]) else: static = None