Return the output of the optimizer step (#11711)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
carmocca and rohitgr7 committed Feb 9, 2022
1 parent 9e63281 commit 8822117
Showing 14 changed files with 71 additions and 67 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


- Return the output of the `optimizer.step`. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711))


- Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620))


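For context on the CHANGELOG entry above: the change builds on a plain PyTorch convention, namely that `torch.optim.Optimizer.step(closure)` returns whatever the closure returns. A minimal sketch outside Lightning (the model and data below are illustrative):

```python
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.rand(4, 2)

def closure():
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    return loss

# `step` forwards the closure's return value; this PR propagates that value
# through the LightningOptimizer, strategy, and precision-plugin layers.
loss = optimizer.step(closure)
```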
13 changes: 7 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -1541,18 +1541,19 @@ def optimizer_step(
using_lbfgs: bool = False,
) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
once per optimizer. This method (and ``zero_grad()``) won't be called during the
accumulation phase when ``Trainer(accumulate_grad_batches != 1)``.
Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
each optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer.
This method (and ``zero_grad()``) won't be called during the accumulation phase when
``Trainer(accumulate_grad_batches != 1)``. Overriding this hook has no benefit with manual optimization.
Args:
epoch: Current epoch
batch_idx: Index of current batch
optimizer: A PyTorch optimizer
optimizer_idx: If you used multiple optimizers, this indexes into that list.
optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the
optimizer_closure: The optimizer closure. This closure must be executed as it includes the
calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
on_tpu: ``True`` if TPU backward is required
using_native_amp: ``True`` if using native amp
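A note on the hunk above: with this change, the `optimizer.step(closure=optimizer_closure)` call inside an overridden `optimizer_step` forwards the closure result (typically the training-step output). A hedged sketch of such an override on a `LightningModule`, mirroring the warm-up test moved later in this diff; the 500-step horizon and base learning rate are illustrative:

```python
def optimizer_step(
    self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
    on_tpu=False, using_native_amp=False, using_lbfgs=False,
):
    # linear learning-rate warm-up over the first 500 global steps (illustrative schedule)
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * 0.01

    # the closure must be executed; its result is now returned by `step`
    loss = optimizer.step(closure=optimizer_closure)
```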
9 changes: 6 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -94,13 +94,16 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
yield
lightning_module.untoggle_optimizer(self._optimizer_idx)

def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
"""Performs a single optimization step (parameter update).
Args:
closure: An optional optimizer_closure.
closure: An optional optimizer closure.
kwargs: Any additional arguments to the ``optimizer.step()`` call.
Returns:
The output from the step call, which is generally the output of the closure execution.
Example::
# Scenario for a GAN using manual optimization
@@ -163,7 +166,7 @@ def closure_dis():
assert self._strategy is not None
assert self._strategy.lightning_module is not None
with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(
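A note on `LightningOptimizer.step` above: it now returns whatever the strategy's `optimizer_step` produces, which the new docstring describes as generally the output of the closure execution. A hedged sketch of manual optimization taking advantage of that (assumes `self.automatic_optimization = False`; `compute_loss` is a hypothetical helper on the module):

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        loss = self.compute_loss(batch)  # hypothetical loss computation
        opt.zero_grad()
        self.manual_backward(loss)
        return loss

    # the closure's return value is forwarded back through `step`
    loss = opt.step(closure=closure)
    self.log("train_loss", loss)
```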
4 changes: 2 additions & 2 deletions pytorch_lightning/lite/wrappers.py
@@ -54,9 +54,9 @@ def optimizer(self) -> Optimizer:
def state_dict(self) -> Dict[str, Tensor]:
return self._strategy.optimizer_state(self.optimizer)

def step(self, closure: Optional[Callable] = None) -> None:
def step(self, closure: Optional[Callable] = None) -> Any:
closure = closure or _do_nothing_closure
self._strategy.optimizer_step(
return self._strategy.optimizer_step(
self.optimizer,
opt_idx=0,
closure=closure,
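A note on the `_LiteOptimizer` change above: the wrapped `step` now forwards the strategy's return value to LightningLite users. A hedged sketch of a Lite run reading it (the model, data, and closure are illustrative):

```python
import torch
from pytorch_lightning.lite import LightningLite

class Lite(LightningLite):
    def run(self):
        model = torch.nn.Linear(2, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)
        batch = torch.rand(4, 2, device=self.device)

        def closure():
            optimizer.zero_grad()
            loss = model(batch).sum()
            self.backward(loss)
            return loss

        # previously `step` returned None; it now forwards the strategy's output
        step_output = optimizer.step(closure)
        print(step_output)

Lite().run()
```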
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -80,7 +80,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
@@ -90,7 +90,8 @@ def optimizer_step(
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
optimizer.step(**kwargs)
return optimizer.step(**kwargs)
return closure_result

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if "amp_scaling_state" in checkpoint:
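A note on the branch added above, which the DeepSpeed, IPU, and native AMP plugins mirror: a closure result of `None` signals that `backward` was skipped in automatic optimization (e.g. while accumulating gradients), so the optimizer step is skipped too; with this change both branches return a value. A condensed, hypothetical sketch of that shared control flow (not a verbatim copy of any one plugin):

```python
from typing import Any, Callable

from torch.optim import Optimizer

def step_or_skip(optimizer: Optimizer, closure: Callable[[], Any], automatic_optimization: bool, **kwargs: Any) -> Any:
    closure_result = closure()                 # runs training_step/backward in automatic optimization
    skipped_backward = closure_result is None  # manual-optimization closures also return None by design
    if not automatic_optimization or not skipped_backward:
        return optimizer.step(**kwargs)        # forward the optimizer's own return value
    return closure_result                      # backward/step skipped, e.g. while accumulating gradients
```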
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed.py
@@ -61,7 +61,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
@@ -76,7 +76,7 @@
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
deepspeed_engine.step(**kwargs)
return deepspeed_engine.step(**kwargs)

def clip_gradients(
self,
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/ipu.py
@@ -47,7 +47,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
"""IPUs handle the optimizer step internally."""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
@@ -64,6 +64,7 @@
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" requesting this feature."
)
return closure_result

def clip_gradients(
self,
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -74,7 +74,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
@@ -90,8 +90,10 @@ def optimizer_step(
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer, **kwargs)
step_output = self.scaler.step(optimizer, **kwargs)
self.scaler.update()
return step_output
return closure_result

def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
if _TORCH_GREATER_EQUAL_1_10:
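A note on the native AMP hunk above: `torch.cuda.amp.GradScaler.step(optimizer)` returns the value from `optimizer.step()` and may skip the step on non-finite gradients, which is what the plugin now surfaces. A minimal plain-PyTorch sketch (requires a CUDA device; model and data are illustrative):

```python
import torch

model = torch.nn.Linear(2, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
batch = torch.rand(4, 2, device="cuda")

with torch.cuda.amp.autocast():
    loss = model(batch).sum()

optimizer.zero_grad()
scaler.scale(loss).backward()
step_output = scaler.step(optimizer)  # forwards `optimizer.step()`'s return; may skip on inf/NaN gradients
scaler.update()
```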
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -146,11 +146,11 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
"""Hook to run the optimizer step."""
if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
optimizer.step(closure=closure, **kwargs)
return optimizer.step(closure=closure, **kwargs)

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
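A note on the base-class change above: because `PrecisionPlugin.optimizer_step` now returns the `optimizer.step(...)` result, a custom plugin can keep that value when overriding the hook by returning the `super()` call. A hedged sketch (the subclass and its attribute are hypothetical):

```python
from typing import Any, Callable

from pytorch_lightning.plugins import PrecisionPlugin

class RecordingPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that records the value forwarded by the base class."""

    def optimizer_step(self, model, optimizer, optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any) -> Any:
        step_output = super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
        # generally the closure result, i.e. the training-step output
        self.last_step_output = step_output
        return step_output
```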
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/tpu.py
@@ -36,7 +36,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any
) -> None:
) -> Any:
if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
@@ -49,3 +49,4 @@ def optimizer_step(
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" requesting this feature."
)
return closure_result
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/strategy.py
@@ -178,7 +178,7 @@ def optimizer_step(
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> None:
) -> Any:
"""performs the actual optimizer step.
Args:
@@ -189,7 +189,7 @@
**kwargs: Any extra arguments to ``optimizer.step``
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Setup a model and multiple optimizers together.
41 changes: 0 additions & 41 deletions tests/core/test_lightning_module.py
@@ -76,47 +76,6 @@ def test_property_logger(tmpdir):
assert model.logger == logger


def test_params_groups_and_state_are_accessible(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu=False,
using_native_amp=False,
using_lbfgs=False,
):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
for pg in optimizer.param_groups:
pg["lr"] = lr_scale * 0.01

optimizer.step(closure=optimizer_closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1, default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, accumulate_grad_batches=1
)

trainer.fit(model)


def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
35 changes: 33 additions & 2 deletions tests/core/test_lightning_optimizer.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import DEFAULT, patch
from unittest.mock import DEFAULT, Mock, patch

import pytest
import torch
@@ -95,7 +95,10 @@ def closure(opt):
opt_1.step()

closure(opt_2)
opt_2.step()
step_output = opt_2.step()
# check that the step output is returned with manual optimization
# since the optimizer is mocked, the step output is a Mock
assert isinstance(step_output, Mock)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -314,3 +317,31 @@ def test_lightning_optimizer_keeps_hooks(tmpdir):
assert len(optimizer._fwd_handles) == 1
del lightning_optimizer
assert len(optimizer._fwd_handles) == 1


def test_params_groups_and_state_are_accessible(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.__loss = loss
return loss

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **__):
# check attributes are accessible
assert all("lr" in pg for pg in optimizer.param_groups)
assert optimizer.state is optimizer._optimizer.state
assert optimizer.defaults is optimizer._optimizer.defaults

loss = optimizer.step(closure=optimizer_closure)
# the optimizer step still returns the loss
assert loss == self.__loss

model = TestModel()
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=0)
trainer.fit(model)
4 changes: 3 additions & 1 deletion tests/lite/test_wrappers.py
@@ -155,7 +155,9 @@ def test_lite_optimizer_steps():
"""Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
optimizer = Mock()
strategy = Mock()
strategy.optimizer_step.return_value = 123
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
lite_optimizer.step()
step_output = lite_optimizer.step()
assert step_output == 123
strategy.optimizer_step.assert_called_once()
strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)
