Skip to content

Commit

Permalink
Change data format in WandbLogger (tinkoff-ai#309)
Browse files Browse the repository at this point in the history
* Update data format in logger

* Upd CHANGELOG
  • Loading branch information
julia-shenshina committed Nov 24, 2021
1 parent 644b33a commit 50988f4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Rename confidence interval to prediction interval, start working with quantiles instead of interval_width ([#285](https://github.com/tinkoff-ai/etna/pull/285))
- Changed format of forecast and test dataframes in WandbLogger ([#309](https://github.com/tinkoff-ai/etna/pull/309))

### Fixed

Expand Down
34 changes: 6 additions & 28 deletions etna/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,6 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs):
"""
pass

@staticmethod
def _prepare_table(df_raw: pd.DataFrame) -> pd.DataFrame:
"""
Prepare dataframe to be sent to wandb.
Parameters
----------
df_raw:
Dataframe to change
Returns
-------
result: pd.DataFrame
"""
df = df_raw.copy()
squashed_columns = [
f"{segment_column}/{feature_column}"
for segment_column, feature_column in zip(
df.columns.get_level_values("segment"), df.columns.get_level_values("feature")
)
]
df.columns = squashed_columns
df.reset_index(inplace=True)
return df

def log_backtest_metrics(
self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
):
Expand All @@ -158,10 +133,11 @@ def log_backtest_metrics(
Fold information from backtest
"""
from etna.analysis import plot_backtest_interactive
from etna.datasets import TSDataset

if self.table:
self.experiment.summary["metrics"] = wandb.Table(data=metrics_df)
self.experiment.summary["forecast"] = wandb.Table(data=self._prepare_table(forecast_df))
self.experiment.summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast_df))
self.experiment.summary["fold_info"] = wandb.Table(data=fold_info_df)

if self.plot:
Expand Down Expand Up @@ -202,13 +178,15 @@ def log_backtest_run(self, metrics: pd.DataFrame, forecast: pd.DataFrame, test:
test:
Dataframe with ground truth
"""
from etna.datasets import TSDataset

columns_name = list(metrics.columns)
metrics.reset_index(inplace=True)
metrics.columns = ["segment"] + columns_name
if self.table:
self.experiment.summary["metrics"] = wandb.Table(data=metrics)
self.experiment.summary["forecast"] = wandb.Table(data=self._prepare_table(forecast))
self.experiment.summary["test"] = wandb.Table(data=self._prepare_table(test))
self.experiment.summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast))
self.experiment.summary["test"] = wandb.Table(data=TSDataset.to_flatten(test))

metrics_dict = (
metrics.drop(["segment"], axis=1)
Expand Down

0 comments on commit 50988f4

Please sign in to comment.