diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddaf4288a0202..6ca4a599978d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
 
+- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
+
+
 ## [1.3.2] - 2021-05-18
 
 ### Changed
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 9f2b63f014da1..e25f46d9ec239 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -83,19 +83,21 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        lambda_closure()
         if not pl_module.automatic_optimization:
             self.scaler.unscale_(optimizer)
             pl_module.trainer.call_hook("on_after_backward")
+            self.scaler.step(optimizer)
+            self.scaler.update()
+        else:
+            result = lambda_closure()
+            # lambda_closure returning None indicates that backward has been skipped
+            if result is not None:
+                self.scaler.step(optimizer)
+                self.scaler.update()
         return False
 
-    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
-        """Updates the GradScaler"""
-        self.scaler.step(optimizer)
-        self.scaler.update()
-
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 6d0dbed2cf88b..cf58427b071ce 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -99,6 +99,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
     trainer.fit(model)
 
 
+@RunIf(min_gpus=1, amp_native=True)
+def test_amp_skip_optimizer(tmpdir):
+    """
+    Test that optimizers can be skipped when using amp
+    """
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x: torch.Tensor):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 1:
+                return None
+            output = self(batch)
+            return self.loss(batch, output)
+
+        def configure_optimizers(self):
+            return [
+                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
+                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
+            ]
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        fast_dev_run=1,
+        amp_backend='native',
+        precision=16,
+    )
+    model = CustomBoringModel()
+    trainer.fit(model)
+
+
 @RunIf(min_gpus=2, amp_apex=True, special=True)
 @pytest.mark.parametrize("amp_level", ['O2'])
 def test_amp_apex_ddp_fit(amp_level, tmpdir):
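
Note (not part of the patch): below is a minimal, torch-only sketch of the failure mode this change addresses, assuming a CUDA device and the `torch.cuda.amp.GradScaler` API; the layer and optimizer names are illustrative, not taken from the library. When `training_step` returns `None` for an optimizer, no backward runs for it, so the scaler records no inf checks for that optimizer; unconditionally calling `scaler.step()`/`scaler.update()` (the removed `post_optimizer_step` behaviour) then trips the scaler's internal assertion. The patched `pre_optimizer_step` only steps and updates the scaler when the closure returns a non-`None` result.

```python
# Illustrative sketch only (requires a CUDA device, mirroring the test's min_gpus=1).
import torch

device = "cuda"
layer1 = torch.nn.Linear(32, 32).to(device)
layer2 = torch.nn.Linear(32, 2).to(device)
opt1 = torch.optim.SGD(layer1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(layer2.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

# Optimizer 0: normal iteration -- scale the loss, backward, then step/update.
with torch.cuda.amp.autocast():
    loss = layer1(torch.randn(4, 32, device=device)).sum()
scaler.scale(loss).backward()
scaler.step(opt1)
scaler.update()

# Optimizer 1: the step was "skipped" (training_step returned None), so no
# backward ran and layer2 has no gradients. Stepping the scaler anyway, as the
# removed post_optimizer_step did, raises the assertion error the fix avoids.
try:
    scaler.step(opt2)
except AssertionError as err:
    print(f"stepping a skipped optimizer fails: {err}")
```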