diff --git a/src/fev/adapters.py b/src/fev/adapters.py index 6a64028..e551156 100644 --- a/src/fev/adapters.py +++ b/src/fev/adapters.py @@ -100,7 +100,12 @@ class GluonTSAdapter(PandasAdapter): """Converts dataset to format required by GluonTS.""" @staticmethod - def _convert_dtypes(df: pd.DataFrame, float_dtype: str = "float32") -> pd.DataFrame: + def _convert_dtypes( + df: pd.DataFrame, + id_column: str, + category_as_ordinal: bool = False, + float_dtype: str = "float32", + ) -> pd.DataFrame: """Convert numeric dtypes to float32 and object dtypes to category""" astype_dict = {} for col in df.columns: @@ -108,7 +113,11 @@ def _convert_dtypes(df: pd.DataFrame, float_dtype: str = "float32") -> pd.DataFr astype_dict[col] = "category" elif pd.api.types.is_numeric_dtype(df[col]): astype_dict[col] = float_dtype - return df.astype(astype_dict) + df = df.astype(astype_dict) + if category_as_ordinal: + cat_cols = [col for col in df.select_dtypes(include="category").columns if col != id_column] + df = df.assign(**{col: df[col].cat.codes for col in cat_cols}) + return df @classmethod def convert_input_data( @@ -135,10 +144,10 @@ def convert_input_data( static_columns=static_columns, ) - past_df = cls._convert_dtypes(past_df) - future_df = cls._convert_dtypes(future_df) + past_df = cls._convert_dtypes(past_df, id_column=id_column, category_as_ordinal=True) + future_df = cls._convert_dtypes(future_df, id_column=id_column, category_as_ordinal=True) if static_df is not None: - static_df = cls._convert_dtypes(static_df.set_index(id_column)) + static_df = cls._convert_dtypes(static_df.set_index(id_column), id_column=id_column) else: static_df = pd.DataFrame()