Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WAITING FOR PL CORE MAINTAINER OPINION] Bugfix/17958 multi optimizer step count behaviour #19589

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ test: clean

# run tests with coverage
python -m coverage run --source src/lightning/pytorch -m pytest src/lightning/pytorch tests/tests_pytorch -v
# DO NOT SUBMIT
python -m coverage run --source lightning/pytorch -m pytest ../tests/tests_pytorch/trainer ../tests/tests_pytorch/loops -v --ignore ../tests/tests_pytorch/models/test_onnx.py
python -m coverage run --source lightning/pytorch -m pytest tests/tests_pytorch/loops/optimization/test_manual_loop.py::test_multiple_optimizers -v -s
python -m coverage run --source src/lightning/app -m pytest tests/tests/app -v
python -m coverage run --source src/lightning/fabric -m pytest src/lightning/fabric tests/tests_fabric -v
python -m coverage report
Expand Down
1 change: 1 addition & 0 deletions docs/source-pytorch/model/manual_optimization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Manual Optimization
*******************

.. DO NOT SUBMIT update + include example.
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to
manually manage the optimization process, especially when dealing with multiple optimizers at the same time.

Expand Down
13 changes: 9 additions & 4 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,10 +951,15 @@ def configure_optimizers(self) -> OptimizerLRScheduler:

- **Single optimizer**.
- **List or Tuple** of optimizers.
- **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
(or multiple ``lr_scheduler_config``).
- **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
key whose value is a single LR scheduler or ``lr_scheduler_config``.
- **Two lists** - The first list has one or more optimizers, and the second has one or more LR schedulers
(or ``lr_scheduler_config``s).
- **Dictionary**, with:
- an ``"optimizer"`` key
- an (optional) ``"lr_scheduler"`` key, whose value is an LR scheduler or ``lr_scheduler_config`` (one
or a list of more)
- an (optional) ``"should_increment"`` key, which is only relevant for the manual optimization mode, see
:ref:`manual optimization<common/optimization:Manual optimization>`

- **None** - Fit will run without any optimizer.

The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
Expand Down
107 changes: 81 additions & 26 deletions src/lightning/pytorch/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple, OptimizerLRScheduler


def do_nothing_closure() -> None:
Expand All @@ -50,6 +50,8 @@ def __init__(self, optimizer: Optimizer):
# to inject logic around the optimizer step, particularly useful with manual optimization
self._on_before_step = do_nothing_closure
self._on_after_step = do_nothing_closure
# Only used on manual optimization to decide which optimizers count towards increasing the global step counter.
self._should_increment: Optional[bool] = None
# imitate the class of the wrapped object to make isinstance checks work
self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

Expand Down Expand Up @@ -157,7 +159,9 @@ def closure_dis():

@classmethod
def _to_lightning_optimizer(
cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
cls,
optimizer: Union[Optimizer, "LightningOptimizer"],
strategy: "pl.strategies.Strategy",
) -> "LightningOptimizer":
# the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
# tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
Expand All @@ -171,7 +175,7 @@ def __getattr__(self, item: Any) -> Any:

def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]:
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], Optional[List[bool]]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
from lightning.pytorch.trainer import call

Expand All @@ -183,7 +187,7 @@ def _init_optimizers_and_lr_schedulers(
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, monitor = _configure_optimizers(optim_conf)
optimizers, lr_schedulers, monitor, should_increment = _configure_optimizers(optim_conf)
lr_scheduler_configs = (
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
if model.automatic_optimization
Expand All @@ -192,19 +196,36 @@ def _init_optimizers_and_lr_schedulers(
_validate_multiple_optimizers_support(optimizers, model)
_validate_optimizers_attached(optimizers, lr_scheduler_configs)
_validate_scheduler_api(lr_scheduler_configs, model)
return optimizers, lr_scheduler_configs
return optimizers, lr_scheduler_configs, should_increment


def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple],
) -> Tuple[List, List, Optional[str]]:
def _configure_optimizers(optim_conf: OptimizerLRScheduler) -> Tuple[List, List, Optional[str], Optional[List]]:
optimizers, lr_schedulers = [], []
monitor = None
should_increment = None
_as_list = lambda values: list(values) if isinstance(values, (list, tuple)) else [values]

def _handle_single_dict(optim_conf: dict) -> Tuple[List, List, Optional[str], Optional[List]]:
_validate_optim_conf_dict(optim_conf)
optimizers = _as_list(optim_conf["optimizer"])
monitor = optim_conf.get("monitor")
lr_schedulers = _as_list(optim_conf.get("lr_scheduler", []))
should_increment = optim_conf.get("should_increment")
if should_increment and len(optimizers) > len(_as_list(should_increment)):
# `_validate_optim_conf_dict` checks `should_increment` to have length 1 if list
single_val = should_increment[0] if isinstance(should_increment, (list, tuple)) else should_increment
should_increment = [single_val for _ in optimizers]
return optimizers, monitor, lr_schedulers, should_increment

# single output, single optimizer
if isinstance(optim_conf, Optimizable):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers

# single list or tuple of one or more optimizers
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
optimizers = list(optim_conf)

# two lists, optimizer(s) + lr scheduler(s)
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
Expand All @@ -214,24 +235,41 @@ def _configure_optimizers(
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]

# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
optimizers, lr_schedulers, monitor, should_increment = _handle_single_dict(optim_conf)

# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"]) for opt_dict in optim_conf if "lr_scheduler" in opt_dict
]
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
optimizers = list(optim_conf)
optimizers, lr_schedulers, monitor, should_increment = [], [], [], []

# DO NOT SUBMIT add a test for this case first that breaks
# If the user populated some `should_increment` but not all, the rest is assumed as `False`
# if (
# any("should_increment" in optim_dict for optim_dict in optim_conf) and not
# all("should_increment" in optim_dict for optim_dict in optim_conf)
# ):
# for optim_dict in optim_conf:
# optim_dict["should_increment"] = optim_dict.get("should_increment", False)

for optim_dict in optim_conf:
opt, lr_sch, mon, incr = _handle_single_dict(optim_dict)
optimizers.extend(opt)
if lr_sch:
lr_schedulers.extend(lr_sch)
# DO NOT SUBMIT update downstream code according to when multiple monitor values are optained
if mon:
monitor.extend(mon)
if incr:
should_increment.extend(incr)

# reset empty lists to None
if not should_increment:
should_increment = None
if not monitor:
monitor = None

# unknown configuration
else:
raise MisconfigurationException(
Expand All @@ -242,7 +280,7 @@ def _configure_optimizers(
" * ([`Optimizer`], [`LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
)
return optimizers, lr_schedulers, monitor
return optimizers, lr_schedulers, monitor, should_increment


def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
Expand Down Expand Up @@ -369,13 +407,30 @@ def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_conf
)


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "monitor"}
def _validate_optim_conf_dict(optim_conf: Dict[str, Any]) -> None:
# DO NOT SUBMIT there are prob some tests for this, add increments_step there
valid_keys = {"optimizer", "lr_scheduler", "monitor", "should_increment"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
)
if (
"should_increment" in optim_conf
and isinstance(optim_conf["should_increment"], (list, tuple))
and len(optim_conf["should_increment"]) > 1
):
# length needs to match optimizers length
if not isinstance(optim_conf["optimizer"], (list, tuple)) or len(optim_conf["should_increment"]) != len(
optim_conf["optimizer"]
):
num_opt = len(optim_conf["optimizer"]) if isinstance(optim_conf["optimizer"], (list, tuple)) else 1
rank_zero_warn(
f"`should_increment` values should equal number of optimizers it is passed along with,"
f" but found {optim_conf['should_increment']} (len={len(optim_conf['should_increment'])}) and"
f" {num_opt} optimizers",
category=RuntimeWarning,
)


class _MockOptimizer(Optimizer):
Expand Down
2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,4 @@
from lightning.pytorch.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports)
from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop # noqa: F401
from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401
from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401
from lightning.pytorch.loops.prediction_loop import _PredictionLoop # noqa: F401
from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop # noqa: F401
15 changes: 12 additions & 3 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,15 @@ class _FitLoop(_Loop):
...

Args:
min_epochs: The minimum number of epochs
max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
min_epochs: The minimum number of epochs, default 0
max_epochs: The maximum number of epochs, disabled by default (``None``).
If both ``max_epochs`` and ``max_steps`` are not specified,
:class:`~lightning.pytorch.trainer.trainer.Trainer` will default to ``max_epochs = 1000``.
To enable infinite training, set ``max_epochs = -1``.
min_steps: Force training for at least these number of steps. Disabled by default (``None``)
max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
and ``max_epochs = None``, :class:`~lightning.pytorch.trainer.trainer.Trainer` will default to
``max_epochs = 1000``. To enable infinite training, set ``max_epochs`` to ``-1``

"""

Expand All @@ -78,6 +85,8 @@ def __init__(
trainer: "pl.Trainer",
min_epochs: Optional[int] = 0,
max_epochs: Optional[int] = None,
min_steps: Optional[int] = None,
max_steps: int = -1,
) -> None:
super().__init__(trainer)
if isinstance(max_epochs, int) and max_epochs < -1:
Expand All @@ -88,7 +97,7 @@ def __init__(

self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.epoch_loop = _TrainingEpochLoop(trainer)
self.epoch_loop = _TrainingEpochLoop(trainer, min_steps, max_steps)
self.epoch_progress = _Progress()
self.max_batches: Union[int, float] = float("inf")

Expand Down
32 changes: 19 additions & 13 deletions src/lightning/pytorch/loops/optimization/manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.core.optimizer import do_nothing_closure
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.optimization.closure import OutputResult
from lightning.pytorch.loops.progress import _Progress, _ReadyCompletedTracker
Expand Down Expand Up @@ -97,9 +96,12 @@ def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:

def on_run_start(self) -> None:
# inject logic around the optimizer step
# DO NOT SUBMIT
# This is no longer reset (as that seems not needed) - so instead do this once and not again and again:
for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
lightning_optimizer._on_before_step = self._on_before_step
lightning_optimizer._on_after_step = self._on_after_step
incr = lightning_optimizer._should_increment
lightning_optimizer._on_before_step = self._get_on_before_optim_step_func(increment=incr)
lightning_optimizer._on_after_step = self._get_on_after_optim_step_func(increment=incr)

def advance(self, kwargs: OrderedDict) -> None:
"""Performs the training step for manual optimization.
Expand All @@ -121,16 +123,20 @@ def advance(self, kwargs: OrderedDict) -> None:
def on_run_end(self) -> _OUTPUTS_TYPE:
"""Returns the result of this loop, i.e., the post-processed outputs from the training step."""
output, self._output = self._output, {} # free memory
# reset logic around the optimizer step
for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
lightning_optimizer._on_before_step = do_nothing_closure
lightning_optimizer._on_after_step = do_nothing_closure
return output

def _on_before_step(self) -> None:
self.optim_step_progress.increment_ready()
self.trainer.profiler.start("optimizer_step")
def _get_on_before_optim_step_func(self, increment: bool) -> callable:
def _on_before_step() -> None:
if increment:
self.optim_step_progress.increment_ready()
self.trainer.profiler.start("optimizer_step")

return _on_before_step

def _get_on_after_optim_step_func(self, increment: bool) -> callable:
def _on_before_step() -> None:
self.trainer.profiler.stop("optimizer_step")
if increment:
self.optim_step_progress.increment_completed()

def _on_after_step(self) -> None:
self.trainer.profiler.stop("optimizer_step")
self.optim_step_progress.increment_completed()
return _on_before_step
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def init_deepspeed(self) -> None:

def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]:
assert self.lightning_module is not None
optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module)
optimizers, lr_schedulers, _ = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(lr_schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
Expand Down
21 changes: 20 additions & 1 deletion src/lightning/pytorch/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def optimizers(self) -> List[Optimizer]:
def optimizers(self, optimizers: List[Optimizer]) -> None:
self._optimizers = optimizers
self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers]
# The below is only relevant for the manual optimization loop, when for each optimizer it needs back to DO NOT SUBMIT
# DEFAULT UPDATE having only the last of all optimizers count towards the global step counter, so that if all optimizers are called during
"""Relevant only for manual optimization loop, in which case the user can manually set for each optimizer
whether it counts towards incrementing the global step counter."""
if self._which_optimizers_should_increment:
for opt, incr in zip(self._lightning_optimizers, self._which_optimizers_should_increment):
opt._should_increment = incr

# Default case
else:
for opt in self._lightning_optimizers:
opt._should_increment = False
if len(self._lightning_optimizers) >= 1:
self._lightning_optimizers[-1]._should_increment = True

def connect(self, model: "pl.LightningModule") -> None:
"""Called by the Trainer to connect the strategy with the model."""
Expand Down Expand Up @@ -136,7 +150,12 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:

"""
assert self.lightning_module is not None
self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module)
optimizers, self.lr_scheduler_configs, should_increment = _init_optimizers_and_lr_schedulers(
self.lightning_module
)
# DO NOT SUBMIT - bug: assert optimizers and should_increment equal size
self._which_optimizers_should_increment = should_increment
self.optimizers = optimizers

def setup(self, trainer: "pl.Trainer") -> None:
"""Sets up the accelerator, plugins and initializes the optimizers (if needed).
Expand Down
Loading
Loading