diff --git a/CHANGELOG.md b/CHANGELOG.md index ac3ba286f..5de407a45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Rename confidence interval to prediction interval, start working with quantiles instead of interval_width ([#285](https://github.com/tinkoff-ai/etna/pull/285)) +- Changed format of forecast and test dataframes in WandbLogger ([#309](https://github.com/tinkoff-ai/etna/pull/309)) ### Fixed diff --git a/etna/loggers/wandb_logger.py b/etna/loggers/wandb_logger.py index b3d70d6a7..15ba060d7 100644 --- a/etna/loggers/wandb_logger.py +++ b/etna/loggers/wandb_logger.py @@ -117,31 +117,6 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs): """ pass - @staticmethod - def _prepare_table(df_raw: pd.DataFrame) -> pd.DataFrame: - """ - Prepare dataframe to be sent to wandb. - - Parameters - ---------- - df_raw: - Dataframe to change - - Returns - ------- - result: pd.DataFrame - """ - df = df_raw.copy() - squashed_columns = [ - f"{segment_column}/{feature_column}" - for segment_column, feature_column in zip( - df.columns.get_level_values("segment"), df.columns.get_level_values("feature") - ) - ] - df.columns = squashed_columns - df.reset_index(inplace=True) - return df - def log_backtest_metrics( self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame ): @@ -158,10 +133,11 @@ def log_backtest_metrics( Fold information from backtest """ from etna.analysis import plot_backtest_interactive + from etna.datasets import TSDataset if self.table: self.experiment.summary["metrics"] = wandb.Table(data=metrics_df) - self.experiment.summary["forecast"] = wandb.Table(data=self._prepare_table(forecast_df)) + self.experiment.summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast_df)) self.experiment.summary["fold_info"] = wandb.Table(data=fold_info_df) if self.plot: @@ -202,13 +178,15 @@ def log_backtest_run(self, metrics: pd.DataFrame, forecast: pd.DataFrame, test: test: Dataframe with ground truth """ + from etna.datasets import TSDataset + columns_name = list(metrics.columns) metrics.reset_index(inplace=True) metrics.columns = ["segment"] + columns_name if self.table: self.experiment.summary["metrics"] = wandb.Table(data=metrics) - self.experiment.summary["forecast"] = wandb.Table(data=self._prepare_table(forecast)) - self.experiment.summary["test"] = wandb.Table(data=self._prepare_table(test)) + self.experiment.summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast)) + self.experiment.summary["test"] = wandb.Table(data=TSDataset.to_flatten(test)) metrics_dict = ( metrics.drop(["segment"], axis=1)