Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further clean up aggregation logic #12053

Merged
merged 5 commits into from Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -580,6 +580,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))


- Removed `_aggregate_metrics`, `_reduce_agg_metrics`, and `_finalize_agg_metrics` from `LightningLoggerBase` ([#12053](https://github.com/PyTorchLightning/pytorch-lightning/pull/12053))


### Fixed

- Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)
Expand Down
70 changes: 3 additions & 67 deletions pytorch_lightning/loggers/base.py
Expand Up @@ -19,7 +19,7 @@
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from weakref import ReferenceType

import numpy as np
Expand Down Expand Up @@ -123,66 +123,6 @@ def update_agg_funcs(
"`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
)

def _aggregate_metrics(
self, metrics: Dict[str, float], step: Optional[int] = None
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded

Returns:
Step and aggregated metrics. The return value could be ``None``. In such case, metrics
are added to the aggregation list, but not aggregated yet.
"""
# if you still receiving metric from the same step, just accumulate it
if step == self._prev_step:
self._metrics_to_agg.append(metrics)
return step, None

# compute the metrics
agg_step, agg_mets = self._reduce_agg_metrics()

# as new step received reset accumulator
self._metrics_to_agg = [metrics]
self._prev_step = step
return agg_step, agg_mets

def _reduce_agg_metrics(self):
"""Aggregate accumulated metrics.

See deprecation warning below.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
elif len(self._metrics_to_agg) == 1:
agg_mets = self._metrics_to_agg[0]
else:
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
return self._prev_step, agg_mets

def _finalize_agg_metrics(self):
"""This shall be called before save/close.

See deprecation warning below.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
agg_step, metrics_to_log = self._reduce_agg_metrics()
self._metrics_to_agg = []

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.
Expand All @@ -195,10 +135,7 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
daniellepintz marked this conversation as resolved.
Show resolved Hide resolved

if metrics_to_log:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
self.log_metrics(metrics=metrics, step=step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
Expand All @@ -221,7 +158,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
Args:
params: :class:`~argparse.Namespace` containing the hyperparameters
args: Optional positional arguments, depends on the specific logger being used
kwargs: Optional keywoard arguments, depends on the specific logger being used
kwargs: Optional keyword arguments, depends on the specific logger being used
"""

def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
Expand All @@ -235,7 +172,6 @@ def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:

def save(self) -> None:
"""Save log data."""
self._finalize_agg_metrics()
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.
Expand Down