From b2d4c7ee0c2d6b6d90b93b2f7c30d1ade3d8bdda Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 3 Sep 2021 03:45:49 +0200
Subject: [PATCH] Fix manual optimization on AMP and skipping backward

---
 pytorch_lightning/core/lightning.py                |  2 +-
 pytorch_lightning/plugins/precision/apex_amp.py    |  9 ++++++---
 .../plugins/precision/deepspeed_precision.py       | 11 +++++++----
 pytorch_lightning/plugins/precision/native_amp.py  | 11 +++++------
 4 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 5954af1c671e6..2ce0301a190ad 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -632,7 +632,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         - ``None`` - Training will skip to the next batch
 
         Note:
-            Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
+            Returning ``None`` is currently not supported for multi-GPU or TPU.
 
         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 7c2c6e9bbcabe..b90682923851c 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,10 +97,13 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # APEX amp does not support closures
+        result = lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        optimizer.step(**kwargs)
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            optimizer.step(**kwargs)
         return False
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index fdb47c686464c..32f842bf28eed 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -42,11 +42,14 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # DeepSpeed does not support closures
+        result = lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        deepspeed_engine = model.trainer.model
-        deepspeed_engine.step()
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            deepspeed_engine = model.trainer.model
+            deepspeed_engine.step()
         return False
 
     def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 78b0df7cb3697..4f52d52a1d693 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -96,14 +96,13 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        result = True
-        # FIXME: is this correct for manual?
-        if model.automatic_optimization:
-            result = lambda_closure()
+        result = lambda_closure()
         self.scaler.unscale_(optimizer)
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # lambda_closure returning None indicates that backward has been skipped
-        if result is not None:
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
         return False
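
The two user-facing behaviours touched by this patch can be exercised with a minimal sketch like the one below (not part of the patch; module, layer, and batch shapes are hypothetical, and it assumes running with something like Trainer(precision=16, gpus=1)). The first module skips a batch by returning None from training_step under automatic optimization; the second uses manual optimization, where the closure passed to the precision plugin returns no value and must not be mistaken for a skipped backward.

import torch
import pytorch_lightning as pl


class SkipBatchModel(pl.LightningModule):
    # Automatic optimization: returning ``None`` skips backward and the
    # optimizer step for this batch (now also honoured by the AMP plugins).
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        if not torch.isfinite(loss):
            return None  # skip this batch
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class ManualOptModel(pl.LightningModule):
    # Manual optimization: the user drives backward/step, so the training
    # closure returns nothing; the plugins must still call optimizer.step.
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)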