Do not force sync_dist=True on epoch end (#13364)
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people committed Jul 22, 2022
1 parent 9596fab commit 238c991
Showing 3 changed files with 47 additions and 4 deletions.
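
In short: epoch-end reduction of logged values no longer flips `sync_dist` on behind the user's back; outside fault-tolerant training, Lightning now emits a `PossibleUserWarning` recommending that the sync be requested explicitly. A minimal sketch of the recommended user-side call (hypothetical model and metric names; `self.log(..., on_epoch=True, sync_dist=True)` is the standard Lightning logging API):

import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        # on_epoch=True accumulates this value over the epoch; sync_dist=True reduces it across
        # ranks explicitly, which is what the warning added in this commit recommends.
        self.log("val_loss", loss, on_epoch=True, sync_dist=True)
        return loss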
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -144,6 +144,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


- Raised a warning instead of forcing `sync_dist=True` on epoch end ([#13364](https://github.com/Lightning-AI/lightning/pull/13364))


- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -24,11 +24,13 @@
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache

_IN_METRIC = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors
_OUT_METRIC = Union[Tensor, Dict[str, Tensor]]
@@ -522,12 +524,26 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
cache = result_metric._forward_cache
elif not on_step and result_metric.meta.on_epoch:
if result_metric._computed is None:
# always reduce on epoch end
should = result_metric.meta.sync.should
result_metric.meta.sync.should = True
if not result_metric.meta.sync.should and distributed_available():
# ensure sync happens for FT since during a failure, the metrics are synced and saved to the
# checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
# run, and on other ranks, they are 0. So we need to make sure they are synced in further training
# to ensure correct calculation.
if _fault_tolerant_training():
result_metric.meta.sync.should = True
else:
warning_cache.warn(
f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
" when logging on epoch level in distributed setting to accumulate the metric across"
" devices.",
category=PossibleUserWarning,
)
result_metric.compute()
result_metric.meta.sync.should = should

cache = result_metric._computed

if cache is not None:
if not isinstance(cache, Tensor):
raise ValueError(
@@ -536,6 +552,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
)
if not result_metric.meta.enable_graph:
return cache.detach()

return cache

def valid_items(self) -> Generator:
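If leaving an epoch-level value un-synced is intentional (for example, a rank-local debugging metric), the new recommendation can be silenced like any other Python warning. A small sketch, assuming the warning goes through the standard `warnings` machinery as the `warning_cache.warn(..., category=PossibleUserWarning)` call above suggests:

import warnings

from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Silence only the sync_dist recommendation introduced by this commit; other warnings still surface.
warnings.filterwarnings("ignore", category=PossibleUserWarning, message=".*sync_dist=True.*")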
25 changes: 24 additions & 1 deletion tests/tests_pytorch/core/test_metric_result_integration.py
@@ -34,7 +34,9 @@
_ResultMetric,
_Sync,
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call


class DummyMetric(Metric):
@@ -456,6 +458,8 @@ def on_train_epoch_end(self) -> None:
"limit_val_batches": 0,
"accelerator": accelerator,
"devices": devices,
"enable_progress_bar": False,
"enable_model_summary": False,
}
trainer_kwargs.update(kwargs)
trainer = Trainer(**trainer_kwargs)
@@ -471,7 +475,7 @@ def on_train_epoch_end(self) -> None:
)
ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")

trainer = Trainer(**trainer_kwargs, enable_progress_bar=False, enable_model_summary=False)
trainer = Trainer(**trainer_kwargs)
trainer.fit(model, ckpt_path=ckpt_path)
assert model.has_validated_sum

@@ -659,3 +663,22 @@ def on_train_start(self):
)
with pytest.raises(ValueError, match=r"compute\(\)` return of.*foo' must be a tensor"):
trainer.fit(model)


@pytest.mark.parametrize("distributed_env", [True, False])
def test_logger_sync_dist(distributed_env):
# self.log('bar', 7, ..., sync_dist=False)
meta = _Metadata("foo", "bar")
meta.sync = _Sync(_should=False)
result_metric = _ResultMetric(metadata=meta, is_tensor=True)
result_metric.update(torch.tensor(7.0), 10)

warning_ctx = pytest.warns if distributed_env else no_warning_call

with mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
return_value=distributed_env,
):
with warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"):
value = _ResultCollection._get_cache(result_metric, on_step=False)
assert value == 7.0
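
The fault-tolerant branch in `_get_cache` above (and the restart test that resumes from `.pl_auto_save.ckpt`) only applies when fault-tolerant training is enabled, which in this release line is switched on via an environment variable rather than a Trainer argument. A sketch, assuming the experimental `PL_FAULT_TOLERANT_TRAINING` flag of the 1.6/1.7 era:

import os

# Experimental flag read by `_fault_tolerant_training()`; with it set, the epoch-end sync is still
# forced so metric states restored from a mid-epoch checkpoint stay consistent across ranks,
# instead of only a warning being raised.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1, enable_progress_bar=False)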
