Remove Strategy.init_optimizers #11236

Merged (12 commits) on Dec 23, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -343,6 +343,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed support for Python 3.6 ([#11117](https://github.com/PyTorchLightning/pytorch-lightning/pull/11117))
 
+
+- Removed `Strategy.init_optimizers` in favor of `Strategy.setup_optimizers` ([#11236](https://github.com/PyTorchLightning/pytorch-lightning/pull/11236))
+
 ### Fixed
 
 - Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
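
For custom strategies that previously overrode the removed hook, the override point is now `setup_optimizers`. A minimal sketch of the migration, assuming the 1.6-era import paths (`pytorch_lightning.strategies`, `pytorch_lightning.trainer.states`); `MyStrategy` is a hypothetical subclass, not part of this PR:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.trainer.states import TrainerFn


class MyStrategy(SingleDeviceStrategy):
    # Before this PR: override `init_optimizers(trainer, model)` and return
    # (optimizers, lr_schedulers, optimizer_frequencies).
    # After this PR: override `setup_optimizers(trainer)` and assign the same
    # three lists onto the strategy instead of returning them.
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
            return
        super().setup_optimizers(trainer)  # fills self.optimizers, self.lr_schedulers, ...
        # custom post-processing of self.optimizers / self.lr_schedulers goes here
```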
14 changes: 12 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -562,11 +562,21 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Tuple[List, List, List]:
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        """Creates optimizers and schedulers.
+
+        Args:
+            trainer: the Trainer, these optimizers should be connected to
+        """
+        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+            return
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
-        return [], [], []  # empty optimizers, schedulers and frequencies
+        # empty optimizers, schedulers and frequencies
+        self.optimizers = []
+        self.lr_schedulers = []
+        self.optimizer_frequencies = []
 
     @property
     def handles_gradient_accumulation(self) -> bool:
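
For context on why the DeepSpeed strategy leaves these lists empty: when the optimizer and scheduler are declared in the DeepSpeed config (or come from `configure_optimizers`), DeepSpeed builds them itself during `_initialize_deepspeed_train`, so there is nothing for Lightning to register here. A rough sketch of the config-driven path; the values and the exact strategy class/argument names are illustrative assumptions, not taken from this PR:

```python
# Illustrative DeepSpeed configuration declaring the optimizer and scheduler.
# With a config like this, Lightning's setup_optimizers intentionally leaves
# self.optimizers / self.lr_schedulers / self.optimizer_frequencies empty and
# lets the DeepSpeed engine create them during initialization.
deepspeed_config = {
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
    "train_micro_batch_size_per_gpu": 8,
}
# Typically passed to the DeepSpeed strategy's `config` argument, e.g.
# Trainer(strategy=DeepSpeedStrategy(config=deepspeed_config), ...)  # class name per 1.6-era API
```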
10 changes: 2 additions & 8 deletions pytorch_lightning/strategies/strategy.py
@@ -104,12 +104,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         """
         if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
             return
-        optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers(
-            trainer=trainer, model=self.lightning_module
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
+            self.lightning_module
         )
-        self.optimizers = optimizers
-        self.lr_schedulers = lr_schedulers
-        self.optimizer_frequencies = optimizer_frequencies
 
     def setup(self, trainer: "pl.Trainer") -> None:
         """Setup plugins for the trainer fit and creates optimizers.
@@ -377,9 +374,6 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
         """
         return dataloader
 
-    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        return _init_optimizers_and_lr_schedulers(model)
-
     @property
     def restore_checkpoint_after_setup(self) -> bool:
         """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin
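
The base implementation now delegates directly to `_init_optimizers_and_lr_schedulers(self.lightning_module)`, which reads the module's `configure_optimizers` and normalizes the result into the three lists stored on the strategy. A small sketch of that mapping (the toy module is illustrative, not from this PR):

```python
import torch
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.02)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # After Strategy.setup_optimizers runs, this ends up roughly as:
        #   strategy.optimizers            -> [optimizer]
        #   strategy.lr_schedulers         -> [scheduler config dict wrapping `scheduler`]
        #   strategy.optimizer_frequencies -> [] (only filled for dict returns with a "frequency" key)
        return [optimizer], [scheduler]
```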
22 changes: 12 additions & 10 deletions pytorch_lightning/tuner/lr_finder.py
@@ -24,7 +24,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _get_default_scheduler_config
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -99,14 +99,14 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self._total_batch_idx = 0  # for debug purpose
 
     def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        """Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified
-        optimizer together with a new scheduler that that takes care of the learning rate search."""
-        init_optimizers = trainer.strategy.init_optimizers
+        """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
+        optimizer together with a new scheduler that takes care of the learning rate search."""
+        setup_optimizers = trainer.strategy.setup_optimizers
 
-        @wraps(init_optimizers)
-        def func(trainer, model):
-            # Decide the structure of the output from trainer.strategy.init_optimizers
-            optimizers, _, _ = init_optimizers(trainer, model)
+        @wraps(setup_optimizers)
+        def func(trainer):
+            # Decide the structure of the output from _init_optimizers_and_lr_schedulers
+            optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)
 
             if len(optimizers) != 1:
                 raise MisconfigurationException(
@@ -126,7 +126,9 @@ def func(trainer, model):
             sched_config = _get_default_scheduler_config()
             sched_config.update({"scheduler": scheduler, "interval": "step"})
 
-            return [optimizer], [sched_config], []
+            trainer.strategy.optimizers = [optimizer]
+            trainer.strategy.lr_schedulers = [sched_config]
+            trainer.strategy.optimizer_frequencies = []
 
         return func
 
@@ -232,7 +234,7 @@ def lr_find(
     trainer.save_checkpoint(str(save_path))
 
     # Configure optimizer and scheduler
-    trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
 
     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
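
The patched `setup_optimizers` keeps the user's optimizer but swaps in a per-step scheduler that sweeps the learning rate across the search range. The real sweep schedulers live elsewhere in lr_finder.py; the following is only an illustrative stand-in built on `LambdaLR`, with hypothetical names:

```python
import torch


def exponential_lr_sweep(optimizer, lr_min, lr_max, num_training):
    """Illustrative per-step scheduler that grows the lr exponentially from
    lr_min to lr_max over num_training steps (a stand-in, not the PR's code)."""

    base_lr = optimizer.defaults["lr"]

    def lr_lambda(step):
        progress = min(step / max(num_training - 1, 1), 1.0)
        target_lr = lr_min * (lr_max / lr_min) ** progress
        return target_lr / base_lr  # LambdaLR multiplies the base lr by this factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

In the patched path above, such a scheduler would be wrapped in `_get_default_scheduler_config()` with `"interval": "step"` and written to `trainer.strategy.lr_schedulers`; end users reach this code through `trainer.tuner.lr_find(model)` or `Trainer(auto_lr_find=True)` followed by `trainer.tune(model)`.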