Fix manual optimization on AMP and skipping backward
carmocca committed Sep 3, 2021
1 parent 066ae70 commit b2d4c7e
Showing 4 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -632,7 +632,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
             - ``None`` - Training will skip to the next batch
 
         Note:
-            Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
+            Returning ``None`` is currently not supported for multi-GPU or TPU.
 
         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
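For context on the docstring change above, here is a minimal, illustrative sketch (not part of this commit) of the documented behaviour: returning ``None`` from ``training_step`` makes Lightning skip backward and the optimizer update for that batch. The model and the skip condition are placeholders.

import torch
import pytorch_lightning as pl


class SkipBatchModule(pl.LightningModule):
    """Toy module that skips batches whose loss is non-finite."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        if not torch.isfinite(loss):
            return None  # skip this batch: no backward, no optimizer step
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)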
9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,10 +97,13 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # APEX amp does not support closures
+        result = lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        optimizer.step(**kwargs)
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            optimizer.step(**kwargs)
         return False
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
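The guard added above relies on a convention around the closure's return value: under automatic optimization the closure returns the ``training_step`` output, which is ``None`` when the user skipped the batch, while under manual optimization the closure returns nothing at all. A standalone sketch of that guard, with hypothetical names rather than Lightning's actual API:

from typing import Any, Callable, Optional


def maybe_run_optimizer_step(
    closure: Callable[[], Optional[Any]],
    step: Callable[[], None],
    automatic_optimization: bool,
) -> None:
    # run training_step (and, under automatic optimization, backward)
    result = closure()
    # a None result only means "backward was skipped" under automatic optimization;
    # manual-optimization closures have no return value at all
    skipped_backward = result is None
    if not automatic_optimization or not skipped_backward:
        step()  # e.g. optimizer.step(**kwargs) or an engine-specific step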
11 changes: 7 additions & 4 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -42,11 +42,14 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # DeepSpeed does not support closures
+        result = lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        deepspeed_engine = model.trainer.model
-        deepspeed_engine.step()
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            deepspeed_engine = model.trainer.model
+            deepspeed_engine.step()
         return False
 
     def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
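The reason a ``None`` result cannot be treated as a skip in manual optimization is visible from the user side: a manual-optimization ``training_step`` drives backward and the optimizer itself and typically returns nothing, so the engine step above must still run. A rough, illustrative sketch (not part of this commit; model details are placeholders):

import torch
import pytorch_lightning as pl


class ManualOptModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # nothing is returned here, so the closure result is None even though
        # backward did run

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)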
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -96,14 +96,13 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
result = True
# FIXME: is this correct for manual?
if model.automatic_optimization:
result = lambda_closure()
result = lambda_closure()
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer)
self.scaler.update()
return False
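The native AMP change follows the same pattern as plain PyTorch's ``torch.cuda.amp`` usage: only call ``scaler.step``/``scaler.update`` when backward actually ran, and let the scaler itself skip ``optimizer.step`` on inf/NaN gradients. A rough sketch outside Lightning (not part of this commit; ``compute_loss`` is a hypothetical user function that may return ``None`` to skip a batch):

import torch


def amp_training_step(model, optimizer, scaler, batch, compute_loss):
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)  # may return None to skip this batch
    if loss is not None:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before e.g. gradient clipping
        scaler.step(optimizer)      # skipped internally on inf/NaN gradients
        scaler.update()
    optimizer.zero_grad()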
