Fix plugin closure execution order
carmocca committed Sep 2, 2021
1 parent 7535093 commit 5840b7b
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,9 +97,9 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # APEX amp does not support closures
         optimizer.step(**kwargs)
         return False

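Note: for context, here is a minimal sketch of the call order this commit establishes for backends that cannot take a closure in `optimizer.step`: run the closure (forward + backward) first, let the base class fire its hooks with gradients already populated, then step the optimizer manually. The class and method names below are simplified stand-ins, not the actual Lightning plugin API.

from typing import Any, Callable

import torch


class BasePrecisionSketch:
    """Simplified stand-in for the base precision plugin."""

    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, closure: Callable[[], Any]) -> bool:
        # In the real plugin, hooks such as `on_before_optimizer_step` run here,
        # i.e. after the closure has already produced gradients (see the diff above).
        return True


class ClosurelessPrecisionSketch(BasePrecisionSketch):
    """An APEX-style backend that cannot pass a closure to `optimizer.step`."""

    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, closure: Callable[[], Any]) -> bool:
        closure()  # forward + backward run explicitly, since the backend ignores closures
        super().pre_optimizer_step(optimizer, closure)  # hooks now see the computed gradients
        optimizer.step()  # would ideally live in a dedicated `optimizer_step` hook
        return False  # tell the caller not to call `optimizer.step(closure)` itself

The deepspeed_precision.py change below follows the same order; the only difference is that the update goes through the wrapped engine (`deepspeed_engine.step()`) rather than the raw optimizer.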
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -42,9 +42,9 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # DeepSpeed does not support closures
         deepspeed_engine = model.trainer.model
         deepspeed_engine.step()
         return False
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -97,6 +97,7 @@ def pre_optimizer_step(
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
result = True
# FIXME: is this correct for manual?
if model.automatic_optimization:
result = lambda_closure()
self.scaler.unscale_(optimizer)
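Note: for reference, the underlying `torch.cuda.amp` pattern that the native plugin wraps looks roughly as follows. This is a standalone sketch with a toy model (it assumes a CUDA device and is not Lightning code); it shows why the closure, which runs the scaled backward, has to execute before `scaler.unscale_(optimizer)` and the scaler-driven step.

import torch
from torch import nn

model = nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()


def closure() -> torch.Tensor:
    """Forward + backward, playing the role of Lightning's training-step closure."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()  # backward on the scaled loss
    return loss


loss = closure()            # gradients exist after this point (still scaled)
scaler.unscale_(optimizer)  # unscale so gradient clipping/inspection sees real values
scaler.step(optimizer)      # skips the step if non-finite gradients were found
scaler.update()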
9 changes: 4 additions & 5 deletions tests/models/test_hooks.py
@@ -277,6 +277,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("plugins") == "deepspeed"
+        using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
         out = []
         on_before_optimizer_step = [
             dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
@@ -292,10 +293,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
dict(name="Callback.on_batch_start", args=(trainer, model)),
dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
dict(name="on_train_batch_start", args=(ANY, i, 0)),
# these are before the training step because
# they are not part of the `training_step_and_backward` closure, however,
# with native amp, the closure is run first and then the optimizer step.
*(on_before_optimizer_step if not using_native_amp else []),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else on_before_optimizer_step),
dict(name="forward", args=(ANY,)),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -308,7 +307,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
*([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name="Callback.on_after_backward", args=(trainer, model)),
dict(name="on_after_backward"),
*(on_before_optimizer_step if using_native_amp else []),
*(on_before_optimizer_step if using_plugin else []),
dict(
name="optimizer_step",
args=(current_epoch, i, ANY, 0, ANY),
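Note: the expected-hook lists in this test are compared against recorded calls using `unittest.mock.ANY` as a wildcard for arguments the test does not pin down. A tiny self-contained sketch of that technique (the recorder class is hypothetical, not the test's actual HookedModel/HookedCallback helpers):

from unittest.mock import ANY


class CallRecorder:
    """Hypothetical recorder in the spirit of the test's hooked model/callback."""

    def __init__(self) -> None:
        self.calls = []

    def hook(self, name, *args):
        self.calls.append(dict(name=name, args=args))


recorder = CallRecorder()
recorder.hook("on_after_backward")
recorder.hook("optimizer_step", 0, 3, object(), 0)

expected = [
    dict(name="on_after_backward", args=()),
    # ANY matches the positional arguments we do not care about, exactly as in the diff above
    dict(name="optimizer_step", args=(0, 3, ANY, 0)),
]
assert recorder.calls == expected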
