Run plugin closure before on_before_optimizer_step [1/2] (#9288)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
carmocca and awaelchli committed Sep 7, 2021
1 parent d49709e commit 6892d53
Showing 8 changed files with 96 additions and 55 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -285,6 +285,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))


- Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))


- Fixed the Native AMP plugin closure not running with manual optimization ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))


- Fixed bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


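For context on the two entries added here (#9288): `on_before_optimizer_step` is typically used to inspect or log gradients, which only exist once the closure has run the backward pass, hence the reordering in this commit. A minimal sketch of such a hook (the module and logged metric are illustrative, not part of this commit):

```python
import torch
from pytorch_lightning import LightningModule


class GradNormModel(LightningModule):
    """Illustrative module that inspects gradients in `on_before_optimizer_step`."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # gradients are only populated once the closure has run `backward`,
        # which is why the plugin closure now runs before this hook
        grad_norms = [p.grad.norm() for p in self.parameters() if p.grad is not None]
        self.log("grad_norm", torch.stack(grad_norms).sum())
```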
6 changes: 2 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -630,10 +630,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
- :class:`~torch.Tensor` - The loss tensor
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
- ``None`` - Training will skip to the next batch
Note:
Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
- ``None`` - Training will skip to the next batch. This is only for automatic optimization.
This is not supported for multi-GPU or TPU, or using ``DeepSpeed``.
In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
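As a hedged illustration of the updated docstring, a `training_step` under automatic optimization may return `None` to skip the rest of the loop for that batch. The skip condition below is made up for the example, and it assumes the `layer`/`loss` attributes of the `BoringModel` test helper used elsewhere in this diff:

```python
class SkipBatchModel(BoringModel):
    """Illustrative only: skips the optimizer update when the loss is tiny."""

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        if loss < 1e-2:
            # returning ``None`` skips backward and the optimizer step for this batch;
            # as the docstring notes, this only applies to automatic optimization and
            # is not supported with multi-GPU, TPU, or DeepSpeed
            return None
        return loss
```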
9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,10 +97,13 @@ def pre_optimizer_step(
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
result = lambda_closure() # APEX amp does not support closures
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # APEX amp does not support closures
optimizer.step(**kwargs)
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
optimizer.step(**kwargs)
return False

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
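For background on the `# APEX amp does not support closures` comment: with Apex, the backward pass has to be wrapped in `amp.scale_loss`, so the plugin evaluates the closure itself and then steps the optimizer rather than handing the closure to `optimizer.step`. A rough sketch of the underlying Apex pattern (plain training-loop code, not the plugin itself):

```python
from apex import amp


def apex_iteration(model, optimizer, batch):
    # roughly what the Lightning closure does internally under Apex
    loss = model(batch).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # the closure is evaluated up front (see `result = lambda_closure()` above),
    # and the step is then issued separately
    optimizer.step()
    optimizer.zero_grad()
```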
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -20,6 +20,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

@@ -42,9 +43,14 @@ def pre_optimizer_step(
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
result = lambda_closure() # DeepSpeed does not support closures
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# in manual optimization, the closure does not return a value
if model.automatic_optimization and result is None:
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # DeepSpeed does not support closures
deepspeed_engine = model.trainer.model
deepspeed_engine.step()
return False
10 changes: 5 additions & 5 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -95,13 +95,13 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
result = True
if model.automatic_optimization:
result = lambda_closure()
result = lambda_closure() # native amp does not support closures
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer)
self.scaler.update()
return False
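The native AMP plugin follows the standard `torch.cuda.amp.GradScaler` recipe: unscale the gradients before any hook that wants to see them, let `scaler.step` skip the update on non-finite gradients, then update the scale factor. A minimal sketch of that recipe in plain PyTorch (not the plugin code itself):

```python
import torch

scaler = torch.cuda.amp.GradScaler()


def native_amp_iteration(model, optimizer, batch):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).sum()
    # the closure's backward scales the loss
    scaler.scale(loss).backward()
    # unscale before gradient clipping / `on_before_optimizer_step`
    scaler.unscale_(optimizer)
    # `scaler.step` skips `optimizer.step` if non-finite gradients are found
    scaler.step(optimizer)
    scaler.update()
```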
15 changes: 9 additions & 6 deletions tests/models/test_hooks.py
@@ -275,6 +275,7 @@ def _train_batch(self, *args, **kwargs):
def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
using_native_amp = kwargs.get("amp_backend") == "native"
using_deepspeed = kwargs.get("plugins") == "deepspeed"
using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
out = []
on_before_optimizer_step = [
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
@@ -290,10 +291,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
dict(name="Callback.on_batch_start", args=(trainer, model)),
dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
dict(name="on_train_batch_start", args=(ANY, i, 0)),
# these are before the training step because
# they are not part of the `training_step_and_backward` closure, however,
# with native amp, the closure is run first and then the optimizer step.
*(on_before_optimizer_step if not using_native_amp else []),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else on_before_optimizer_step),
dict(name="forward", args=(ANY,)),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -306,7 +305,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
*([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name="Callback.on_after_backward", args=(trainer, model)),
dict(name="on_after_backward"),
*(on_before_optimizer_step if using_native_amp else []),
*(on_before_optimizer_step if using_plugin else []),
dict(
name="optimizer_step",
args=(current_epoch, i, ANY, 0, ANY),
@@ -322,6 +321,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
@staticmethod
def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
using_deepspeed = kwargs.get("plugins") == "deepspeed"
using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
out = []
for i in range(batches):
out.extend(
@@ -342,8 +342,11 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
dict(name="on_after_backward"),
# `manual_backward` calls the previous 3
dict(name="manual_backward", args=(ANY,)),
*([dict(name="closure")] if using_plugin else []),
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else [dict(name="closure")]),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
@@ -439,7 +442,7 @@ def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
self.manual_backward(loss)
opt.step()
opt.step(lambda: called.append({"name": "closure"}))
return {"loss": loss}

model = TestModel(called)
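The last change above passes a closure to `opt.step` so the test can observe when the precision plugin runs it. In user code under manual optimization the same pattern looks roughly like this (a sketch assuming the `BoringModel` helper; per the CHANGELOG entry, the native AMP plugin now actually runs such a closure):

```python
class ManualClosureModel(BoringModel):
    """Illustrative manual-optimization module that hands a closure to `optimizer.step`."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            # forward + backward live inside the closure so that precision
            # plugins can execute it at the right point
            loss = self.loss(batch, self.layer(batch))
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        opt.step(closure)
```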
12 changes: 12 additions & 0 deletions tests/plugins/test_deepspeed_plugin.py
@@ -796,3 +796,15 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
trainer.fit(model)

_assert_save_model_is_equal(model, tmpdir, trainer)


@RunIf(min_gpus=1, deepspeed=True)
def test_deepspeed_skip_backward_raises(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
return None

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], gpus=1, fast_dev_run=True, precision=16)
with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
trainer.fit(model)
85 changes: 49 additions & 36 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -64,17 +64,50 @@ def configure_optimizers(self):
return optimizer, optimizer_2


def test_multiple_optimizers_manual_no_return(tmpdir):
@pytest.mark.parametrize(
"kwargs",
[
{},
pytest.param({"gpus": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(amp_native=True, min_gpus=1)),
pytest.param(
{"gpus": 1, "precision": 16, "amp_backend": "apex", "amp_level": "O2"},
marks=RunIf(amp_apex=True, min_gpus=1),
),
],
)
def test_multiple_optimizers_manual_no_return(tmpdir, kwargs):
apex_optimizer_patches = []
apex_optimizer_steps = []

class TestModel(ManualOptModel):
def training_step(self, batch, batch_idx):
# avoid returning a value
super().training_step(batch, batch_idx)

def training_epoch_end(self, outputs) -> None:
def training_epoch_end(self, outputs):
# outputs is empty as training_step does not return
# and it is not automatic optimization
assert not outputs

def on_train_start(self):
if kwargs.get("amp_backend") != "apex":
return
# extremely ugly. APEX patches all the native torch optimizers on `_initialize` which we call on
# `ApexMixedPrecisionPlugin.dispatch`. Additionally, their replacement `new_step` functions are locally
# defined so can't even patch those, thus we need to create the mock after APEX has been initialized
nonlocal apex_optimizer_patches, apex_optimizer_steps
for opt in self.trainer.optimizers:
# `amp.scale_loss` will also patch the step to avoid it when gradient overflow happens. avoid it
opt._amp_stash.already_patched = True
patch = mock.patch.object(opt, "step")
apex_optimizer_patches.append(patch)
apex_optimizer_steps.append(patch.start())

def on_train_end(self):
if kwargs.get("amp_backend") == "apex":
for p in apex_optimizer_patches:
p.stop()

model = TestModel()
model.val_dataloader = None

@@ -86,12 +119,26 @@ def training_epoch_end(self, outputs) -> None:
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
**kwargs,
)

if kwargs.get("amp_backend") == "native":
# mock the scaler instead of the optimizer step because it can be skipped with NaNs
scaler_step_patch = mock.patch.object(
trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step
)
scaler_step = scaler_step_patch.start()

with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
trainer.fit(model)
assert bwd_mock.call_count == limit_train_batches * 3

if kwargs.get("amp_backend") == "native":
scaler_step_patch.stop()
assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches
if kwargs.get("amp_backend") == "apex":
assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches


def test_multiple_optimizers_manual_return(tmpdir):
class TestModel(ManualOptModel):
@@ -171,40 +218,6 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
assert bwd_mock.call_count == limit_train_batches * 3


@RunIf(min_gpus=1, amp_apex=True)
def test_multiple_optimizers_manual_apex_no_return(tmpdir):
class TestModel(ManualOptModel):
def training_step(self, batch, batch_idx):
# avoid returning a value
super().training_step(batch, batch_idx)

def training_epoch_end(self, outputs) -> None:
# outputs is empty as training_step does not return
# and it is not automatic optimization
assert len(outputs) == 0

model = TestModel()
model.val_dataloader = None

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
precision=16,
amp_level="O2",
amp_backend="apex",
gpus=1,
)

with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
trainer.fit(model)
assert bwd_mock.call_count == limit_train_batches * 3


class ManualOptimizationExtendedModel(BoringModel):

count = 0
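A side note on the test machinery above: `mock.patch.object(..., wraps=...)` is used so the real `scaler.step` and `Accelerator.backward` still execute while their call counts are recorded. A minimal standalone sketch of that pattern (the class and method names are illustrative):

```python
from unittest import mock


class Scaler:
    def step(self, value):
        return value + 1


scaler = Scaler()
with mock.patch.object(scaler, "step", wraps=scaler.step) as spy:
    assert scaler.step(1) == 2   # the wrapped real method still runs
    assert spy.call_count == 1   # while calls are recorded on the mock
```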
