From 4e95713b4535ab9f1b890a8c24c06c41d95173d7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 24 Jan 2022 07:27:54 -0500 Subject: [PATCH 1/4] reduce only loss with dp --- pytorch_lightning/strategies/dp.py | 17 +++++++---------- tests/accelerators/test_dp.py | 26 ++++++++++++++++++++------ 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/strategies/dp.py b/pytorch_lightning/strategies/dp.py index 71d0090e2c8f8..7a2cbb30ba582 100644 --- a/pytorch_lightning/strategies/dp.py +++ b/pytorch_lightning/strategies/dp.py @@ -137,18 +137,15 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: return self.model(*args, **kwargs) def training_step_end(self, output): - if not is_overridden("training_step_end", self.lightning_module): - return self.reduce(output) - return output + if is_overridden("training_step_end", self.lightning_module): + return output - def validation_step_end(self, output): - if not is_overridden("validation_step_end", self.lightning_module): - return self.reduce(output) - return output + if isinstance(output, dict) and "loss" in output: + output["loss"] = self.reduce(output["loss"]) + + elif isinstance(output, torch.Tensor): + output = self.reduce(output) - def test_step_end(self, output): - if not is_overridden("test_step_end", self.lightning_module): - return self.reduce(output) return output def teardown(self) -> None: diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 7313728256b4e..eb72e7a731d09 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -134,8 +134,24 @@ def test_step(self, batch, batch_idx): def training_epoch_end(self, outputs): assert outputs[0]["loss"].shape == torch.Size([]) - assert outputs[0]["reduce_int"].item() == 0 # mean([0, 1]) = 0 - assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5 + self._assert_extra_outputs(outputs) + + def validation_epoch_end(self, outputs): + assert outputs[0]["x"].shape == 
torch.Size([2]) + self._assert_extra_outputs(outputs) + + def test_epoch_end(self, outputs): + assert outputs[0]["y"].shape == torch.Size([2]) + self._assert_extra_outputs(outputs) + + def _assert_extra_outputs(self, outputs): + out = outputs[0]["reduce_int"] + assert torch.eq(out, torch.tensor([0, 1], device="cuda:0")).all() + assert out.dtype is torch.int + + out = outputs[0]["reduce_float"] + assert torch.eq(out, torch.tensor([0.0, 1.0], device="cuda:0")).all() + assert out.dtype is torch.float def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): @@ -188,11 +204,9 @@ def test_dp_training_step_dict(tmpdir): trainer = pl.Trainer( default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, + fast_dev_run=True, gpus=2, strategy="dp", ) trainer.fit(model) + trainer.test(model) From 41c20f092cf951c6bd6986ce9dc458870c398fb7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 24 Jan 2022 07:36:05 -0500 Subject: [PATCH 2/4] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b056bc..9b930586916e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,6 +222,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) +- Enable reducing only loss when using DP ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103)) From 8d5093cb2c104180d0844aded2be4f9788c66ad8 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 25 Jan 2022 05:27:36 +0530 Subject: [PATCH 3/4] Update CHANGELOG.md Co-authored-by: Aki Nitta --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b930586916e4..52725083087f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,7 +222,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) -- Enable reducing only loss when using DP ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) +- Disabled reducing training_step output other than loss when using DP ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) ### Deprecated From 3f23334126c94f579b76e8541ffb9a5e58fd132d Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 28 Jan 2022 00:58:16 +0530 Subject: [PATCH 4/4] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52725083087f9..fff3976bc2434 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,9 +222,9 @@ The format is based on [Keep a 
Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) -- Disabled reducing training_step output other than loss when using DP ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) - +- When using DP (data-parallel), Lightning will no longer automatically reduce all tensors returned in `training_step`; it will only reduce the loss unless `training_step_end` is overridden ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) +- When using DP (data-parallel), the `training_epoch_end` hook will no longer receive reduced outputs from `training_step` and will instead get the full tensor of results from all GPUs ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594)) ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))