From 4f687e2631c6142fafc4a0f9caf3020cb8ae0f86 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Fri, 11 Jun 2021 16:18:47 -0700
Subject: [PATCH 1/3] Make optimizers skippable when using amp

---
 .../plugins/precision/native_amp.py | 14 +++---
 tests/plugins/test_amp_plugins.py   | 50 ++++++++++++++++++-
 2 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 9f2b63f014da1..e25f46d9ec239 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -83,19 +83,21 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        lambda_closure()
         if not pl_module.automatic_optimization:
             self.scaler.unscale_(optimizer)
             pl_module.trainer.call_hook("on_after_backward")
+            self.scaler.step(optimizer)
+            self.scaler.update()
+        else:
+            result = lambda_closure()
+            # lambda_closure returning None indicates that backward has been skipped
+            if result is not None:
+                self.scaler.step(optimizer)
+                self.scaler.update()
         return False
 
-    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
-        """Updates the GradScaler"""
-        self.scaler.step(optimizer)
-        self.scaler.update()
-
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 6d0dbed2cf88b..55d66aec97595 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -18,10 +18,10 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
-from tests.helpers import BoringModel
+from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 
 
@@ -99,6 +99,52 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
     trainer.fit(model)
 
 
+@RunIf(min_gpus=1, amp_native=True)
+def test_amp_skip_optimizer(tmpdir):
+    """
+    Test that optimizers can be skipped when using amp
+    """
+
+    class CustomBoringModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x: torch.Tensor):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def loss(self, batch, prediction):
+            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+            return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 1:
+                return None
+            output = self(batch)
+            return self.loss(batch, output)
+
+        def configure_optimizers(self):
+            return [
+                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
+                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
+            ]
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        max_epochs=1,
+        limit_train_batches=1,
+        amp_backend='native',
+        precision=16,
+    )
+    model = CustomBoringModel()
+    trainer.fit(model, datamodule=BoringDataModule())
+
+
 @RunIf(min_gpus=2, amp_apex=True, special=True)
 @pytest.mark.parametrize("amp_level", ['O2'])
 def test_amp_apex_ddp_fit(amp_level, tmpdir):
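The functional change in PATCH 1/3 is that `pre_optimizer_step` now calls `self.scaler.step(optimizer)` and `self.scaler.update()` only when the closure actually ran a backward pass (a `None` result means the step was skipped), and the unconditional `post_optimizer_step` hook is removed. This follows the `torch.cuda.amp` contract: `GradScaler.step()` expects gradients to have been produced through `scaler.scale(loss).backward()` in the same iteration, so stepping an optimizer whose backward was skipped triggers the assertion error mentioned in the CHANGELOG patch below. The plain-PyTorch sketch that follows only illustrates that underlying contract; it is not code from the patch, and `model`, `loss_fn`, `loader`, and the per-batch `skip` flag are placeholder assumptions.

    import torch

    def train_one_epoch(model, loss_fn, loader, optimizer, device="cuda"):
        scaler = torch.cuda.amp.GradScaler()
        for batch, skip in loader:  # assume the loader also yields a per-batch skip flag
            optimizer.zero_grad()
            loss = None
            if not skip:
                # Run the forward pass under autocast and scale the loss before backward.
                with torch.cuda.amp.autocast():
                    loss = loss_fn(model(batch.to(device)))
                scaler.scale(loss).backward()
            # Mirror of the plugin change: step/update only when backward ran this
            # iteration; otherwise GradScaler.step() has no recorded inf checks for
            # this optimizer and fails with an assertion error.
            if loss is not None:
                scaler.step(optimizer)
                scaler.update()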
From 60a90ffae29eec8e4d93b61b66aaeead5fdeefa8 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Tue, 15 Jun 2021 09:50:54 -0700
Subject: [PATCH 2/3] Address comments

---
 tests/plugins/test_amp_plugins.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 55d66aec97595..cf58427b071ce 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -18,10 +18,10 @@
 import pytest
 import torch
 
-from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
@@ -105,7 +105,7 @@ def test_amp_skip_optimizer(tmpdir):
     """
     Test that optimizers can be skipped when using amp
     """
 
-    class CustomBoringModel(LightningModule):
+    class CustomBoringModel(BoringModel):
 
         def __init__(self):
             super().__init__()
@@ -117,10 +117,6 @@ def forward(self, x: torch.Tensor):
             x = self.layer2(x)
             return x
 
-        def loss(self, batch, prediction):
-            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
-            return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
-
         def training_step(self, batch, batch_idx, optimizer_idx):
             if optimizer_idx == 1:
                 return None
@@ -136,13 +132,12 @@ def configure_optimizers(self):
     trainer = Trainer(
         default_root_dir=tmpdir,
         gpus=1,
-        max_epochs=1,
-        limit_train_batches=1,
+        fast_dev_run=1,
         amp_backend='native',
         precision=16,
     )
     model = CustomBoringModel()
-    trainer.fit(model, datamodule=BoringDataModule())
+    trainer.fit(model)
 
 
 @RunIf(min_gpus=2, amp_apex=True, special=True)

From 9f8569fe3fb752006e40a9c9ba0196aa07d72d33 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Tue, 15 Jun 2021 15:42:37 -0700
Subject: [PATCH 3/3] CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e63484ca8612..d2bdcd294698d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -254,6 +254,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LearningRateMonitor` keys not properly setup when running with `BackboneFinetuning` Callback ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835))
 
 
+- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
+
+
 ## [1.3.2] - 2021-05-18
 
 ### Changed
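Taken together, the series makes this end-to-end pattern work under native AMP: with automatic optimization and multiple optimizers, returning None from training_step for a given optimizer_idx skips that optimizer's step for the batch. The sketch below is a minimal, self-contained usage example modeled on the test above; SkippableModel, its random-data train_dataloader, and the Trainer arguments are illustrative choices, not code from the patches.

    import torch
    from pytorch_lightning import LightningModule, Trainer

    # Illustrative module (not from the patches): two optimizers, the second of
    # which is skipped on every batch by returning None from training_step.
    class SkippableModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer2(self.layer1(x))

        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 1:
                return None  # skip the second optimizer for this batch
            out = self(batch)
            return torch.nn.functional.mse_loss(out, torch.ones_like(out))

        def train_dataloader(self):
            # Random data is enough for a smoke run.
            return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

        def configure_optimizers(self):
            return [
                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
            ]

    trainer = Trainer(gpus=1, precision=16, amp_backend="native", fast_dev_run=True)
    trainer.fit(SkippableModel())  # the skipped optimizer no longer raises an assertion error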