Merged
Changes from all commits
36 commits
dd80bfa
improve test
tchaton Aug 26, 2021
2d7981f
resolve bug
tchaton Aug 26, 2021
ba50ead
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
97ee3f8
update changelog
tchaton Aug 26, 2021
c8cd17e
remove test
tchaton Aug 26, 2021
a1602ae
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
226f9a4
update
tchaton Aug 26, 2021
7dc0246
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
54df136
improvement
tchaton Aug 26, 2021
6918724
update
tchaton Aug 26, 2021
cb51f34
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
ac2a13e
resolve tests
tchaton Aug 26, 2021
4ed928e
update on comments
tchaton Aug 27, 2021
fdbd065
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
55238de
resolve test
tchaton Aug 27, 2021
a7c596b
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
c87691f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
18f148f
resolve typo
tchaton Aug 27, 2021
f7912ba
update
tchaton Aug 27, 2021
662f720
update
tchaton Aug 27, 2021
7f77ba0
Update tests/core/test_metric_result_integration.py
rohitgr7 Aug 27, 2021
b72571d
Refactor and simplify
carmocca Aug 27, 2021
9d2a785
update
tchaton Aug 27, 2021
5c2367c
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
95b6a33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
5a4eab2
update
tchaton Aug 27, 2021
266da25
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
d2c5bc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
eb791d6
update
tchaton Aug 27, 2021
622881b
update
tchaton Aug 27, 2021
cdf6438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
0d47ec2
Push back changes :)
carmocca Aug 27, 2021
c5ef444
Fix test
carmocca Aug 27, 2021
3696300
Cache callback metrics
carmocca Aug 27, 2021
4afa246
Fix test
carmocca Aug 27, 2021
d4177cb
Merge branch 'master' into logging
carmocca Aug 27, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -241,6 +241,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))


- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))


- Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))


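For context, the entry above fixes the user-facing `self.log` call when `sync_dist=True` is combined with a non-default `reduce_fx`. A minimal usage sketch, not part of the diff (the model, metric names, and constant loss are illustrative only):

import torch
from pytorch_lightning import LightningModule


class SyncDistExample(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(1.0)  # placeholder value
        # reduced with `mean` across processes and across the epoch
        self.log("loss_mean", loss, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
        # reduced with `max` across processes and across the epoch
        self.log("loss_max", loss, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
        return loss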
59 changes: 34 additions & 25 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -45,22 +45,33 @@ class MetricSource(LightningEnum):
@dataclass
class _Sync:
fn: Optional[Callable] = None
should: bool = False
_should: bool = False
rank_zero_only: bool = False
op: Optional[str] = None
group: Optional[Any] = None

def __post_init__(self) -> None:
if self.fn is None:
self.fn = self.no_op
self._generate_sync_fn()

@property
def should(self) -> bool:
return self._should

@should.setter
def should(self, should: bool) -> None:
self._should = should
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

def _generate_sync_fn(self) -> None:
"""Used to compute the syncing function and cache it."""
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
# save the function as `_fn` as the metadata is re-created and the object references need to match.
self._fn = partial(fn, reduce_op=self.op, group=self.group)

@property
def __call__(self) -> Any:
return (
partial(self.fn, reduce_op=self.op, group=self.group)
if self.should and not self.rank_zero_only
else self.no_op
)
return self._fn

@staticmethod
def no_op(value: Any, *_, **__) -> Any:
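The hunk above caches the resolved syncing callable in `_fn` and regenerates it whenever `should` is reassigned, instead of building a new `partial` on every call. A standalone sketch of that caching pattern, under simplified assumptions and not the Lightning class itself:

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CachedSync:
    fn: Optional[Callable] = None
    _should: bool = False

    def __post_init__(self) -> None:
        self._generate_fn()

    @property
    def should(self) -> bool:
        return self._should

    @should.setter
    def should(self, should: bool) -> None:
        self._should = should
        # keep the cached callable consistent with the new flag
        self._generate_fn()

    def _generate_fn(self) -> None:
        # fall back to a pass-through when syncing is disabled or no fn was given
        self._fn = self.fn if (self.fn is not None and self._should) else (lambda value, *_, **__: value)

    def __call__(self, value: Any) -> Any:
        return self._fn(value)


sync = CachedSync(fn=lambda value, *_, **__: value * 2)
assert sync(3) == 3  # syncing disabled: pass-through
sync.should = True
assert sync(3) == 6  # the setter regenerated the cached callable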
@@ -75,31 +86,28 @@ class _Metadata:
logger: bool = True
on_step: bool = False
on_epoch: bool = True
_reduce_fx: Callable = torch.mean
reduce_fx: Callable = torch.mean
enable_graph: bool = False
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
_sync: Optional[_Sync] = None

@property
def reduce_fx(self) -> Callable:
return self._reduce_fx
def __post_init__(self) -> None:
self._parse_reduce_fx()

@reduce_fx.setter
def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
def _parse_reduce_fx(self) -> None:
error = (
"Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
f" Found: {reduce_fx}"
f" Found: {self.reduce_fx}"
)
self._reduce_fx = reduce_fx
if isinstance(reduce_fx, str):
reduce_fx = reduce_fx.lower()
if isinstance(self.reduce_fx, str):
reduce_fx = self.reduce_fx.lower()
if reduce_fx == "avg":
reduce_fx = "mean"
if reduce_fx not in ("min", "max", "mean", "sum"):
raise MisconfigurationException(error)
self._reduce_fx = getattr(torch, reduce_fx)
self.reduce_fx = getattr(torch, reduce_fx)
elif self.is_custom_reduction:
raise MisconfigurationException(error)
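The parsing above lower-cases string reductions, maps "avg" to "mean", resolves the result to the matching `torch` function, and rejects anything outside min/max/mean/sum (including custom callables). A standalone approximation of that logic, with an illustrative helper name and a plain ValueError in place of MisconfigurationException:

import torch

_SUPPORTED = ("min", "max", "mean", "sum")


def parse_reduce_fx(reduce_fx):
    # only strings are normalized here; torch.{min,max,mean,sum} callables pass through unchanged
    if isinstance(reduce_fx, str):
        reduce_fx = reduce_fx.lower()
        if reduce_fx == "avg":
            reduce_fx = "mean"
        if reduce_fx not in _SUPPORTED:
            raise ValueError(f"Unsupported reduce_fx: {reduce_fx!r}")
        return getattr(torch, reduce_fx)
    return reduce_fx


assert parse_reduce_fx("avg") is torch.mean
assert parse_reduce_fx("SUM") is torch.sum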

@@ -178,11 +186,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
if self.is_tensor:
value = value.float()
self._forward_cache = value
# performance: no need to accumulate on values only logged on_step
if self.meta.on_step and not self.meta.on_epoch:
self.value = self.meta.sync(value)
self._forward_cache = self.value = self.meta.sync(value)
return
self._forward_cache = value
# perform accumulation with reduction
if self.meta.is_mean_reduction:
self.value += value.mean() * batch_size
@@ -201,8 +209,7 @@ def compute(self) -> torch.Tensor:
if self.meta.is_mean_reduction:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction:
return value
return value
return self.value.compute()

def reset(self) -> None:
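For the mean reduction, `update` above accumulates `value.mean() * batch_size` along with the cumulated batch size, and `compute` divides the two after syncing, i.e. a batch-size-weighted mean. A plain-tensor sketch of that arithmetic, ignoring the distributed sync (which would simply sum both accumulators across processes first):

import torch

# two logged steps with different batch sizes
values = [torch.tensor(2.0), torch.tensor(4.0)]
batch_sizes = [torch.tensor(10.0), torch.tensor(30.0)]

accumulated = sum(v * b for v, b in zip(values, batch_sizes))  # 2*10 + 4*30 = 140
cumulated_batch_size = sum(batch_sizes)  # 40
epoch_mean = accumulated / cumulated_batch_size  # 3.5, not the unweighted 3.0
assert epoch_mean.item() == 3.5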
@@ -449,12 +456,12 @@ def log(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
meta.reduce_fx = reduce_fx
meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)

# register logged value if it doesn't exist
if key not in self:
@@ -680,6 +687,8 @@ def load_state_dict(

if not metrics:
return

# iterate through result metrics and re-attach Metric references on reload.
result_metrics = self.result_metrics
for metric_attribute, metric in metrics.items():
for result_metric in result_metrics:
12 changes: 8 additions & 4 deletions pytorch_lightning/utilities/distributed.py
@@ -181,10 +181,14 @@ def sync_ddp(
if group is None:
group = torch.distributed.group.WORLD

op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM

if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
divide_by_world_size = True
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
op = ReduceOp.SUM
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
else:
op = reduce_op

# sync all processes before reduction
torch.distributed.barrier(group=group)
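The change above resolves string reduce ops to the matching `torch.distributed.ReduceOp` member and only uses the SUM-plus-divide path for "avg"/"mean", so "max", "min" and "sum" now reduce with the intended op instead of always summing. A standalone sketch of that resolution (assumes a torch build with distributed support; the helper name is illustrative):

from torch.distributed import ReduceOp


def resolve_reduce_op(reduce_op):
    """Return the ReduceOp to use and whether the result must be divided by the world size."""
    divide_by_world_size = False
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            # e.g. "sum" -> ReduceOp.SUM, "max" -> ReduceOp.MAX
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op
    return op, divide_by_world_size


assert resolve_reduce_op("mean") == (ReduceOp.SUM, True)
assert resolve_reduce_op("max") == (ReduceOp.MAX, False)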
4 changes: 2 additions & 2 deletions tests/core/test_metric_result_integration.py
@@ -27,7 +27,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -336,7 +336,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
# default sync fn
new_results = ResultCollection(False, device)
new_results.load_state_dict(state_dict, map_location="cpu")
assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op
assert new_results["validation_step.v"].meta.sync.fn is None

# check map location
assert new_results["validation_step.v"].value.device.type == "cpu"
2 changes: 1 addition & 1 deletion tests/core/test_results.py
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])
sync = _Sync(sync_ddp_if_available, should=True, op="SUM")
sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
actual = sync(tensor)
assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

75 changes: 56 additions & 19 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -359,37 +359,74 @@ def get_expected(on_epoch, values):
assert is_included if should_include else not is_included


@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
class LoggingSyncDistModel(BoringModel):
def __init__(self, fake_result):
super().__init__()
self.fake_result = fake_result

@property
def rank(self) -> int:
return self.trainer.global_rank

def training_step(self, batch, batch_idx):
value = self.fake_result + self.rank
self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max")

self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
return super().validation_step(batch, batch_idx)


@pytest.mark.parametrize(
"gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
)
def test_logging_sync_dist_true(tmpdir, gpus):
"""
Tests to ensure that the sync_dist flag works (should just return the original value)
"""
fake_result = 1

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().validation_step(batch, batch_idx)

model = TestModel()
model = LoggingSyncDistModel(fake_result)
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
limit_train_batches=3,
limit_val_batches=3,
weights_summary=None,
gpus=gpus,
)
trainer.fit(model)

assert trainer.logged_metrics["foo"] == fake_result
assert trainer.logged_metrics["foo_2"] == 2
assert trainer.logged_metrics["bar"] == fake_result
num_devices = 1 if gpus is None else gpus
use_multiple_devices = num_devices > 1
total = fake_result * num_devices + 1

metrics = trainer.callback_metrics
assert metrics["foo"] == total if use_multiple_devices else fake_result
assert metrics["foo_2"] == 2 * num_devices
assert metrics["foo_3"] == 2
assert metrics["foo_4"] == total / num_devices if use_multiple_devices else 1
assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2
assert metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2
assert metrics["foo_7"] == 2 * num_devices * 3
assert metrics["foo_8"] == 2
assert metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result
assert metrics["foo_10"] == 2
assert metrics["bar"] == fake_result * 3 * num_devices
assert metrics["bar_2"] == fake_result
assert metrics["bar_3"] == 2 + int(use_multiple_devices)
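As a sanity check on the expectations above: rank r logs fake_result + r, each epoch runs 3 batches, and values are summed across ranks and, for on_epoch metrics, across batches. A worked check for the 2-GPU case, mirroring the assertions rather than extending the test file:

fake_result, num_devices, batches = 1, 2, 3
# "foo": on_step sum across ranks of the last logged step -> 1 + 2 = 3
assert sum(fake_result + r for r in range(num_devices)) == fake_result * num_devices + 1
# "foo_6": on_epoch sum across ranks and across the 3 batches -> 3*1 + 3*2 = 9
assert sum((fake_result + r) * batches for r in range(num_devices)) == fake_result * batches * num_devices + batches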


@RunIf(min_gpus=2, special=True)