3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -371,6 +371,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
 
 
+- Fixed `self.optimizers()` not returning a single optimizer if it had been wrapped ([#8326](https://github.com/PyTorchLightning/pytorch-lightning/pull/8326))
+
+
 - Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))
 
 
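With this change, `self.optimizers()` under manual optimization returns the single optimizer directly, even when the trainer has wrapped it in a `LightningOptimizer`, instead of a one-element list. A minimal usage sketch, not part of this diff; the module, layer sizes, and optimizer choice are illustrative:

```python
import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    """Illustrative module using manual optimization."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # A single configured optimizer now comes back directly,
        # so indexing with [0] is no longer needed.
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```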
7 changes: 5 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -116,7 +116,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # deprecated, will be removed in 1.6
         self._loaded_optimizer_states_dict = {}
 
-    def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(
+        self,
+        use_pl_optimizer: bool = True
+    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
         """
         Returns the optimizer(s) that are being used during training. Useful for manual optimization.
 
@@ -134,7 +137,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
         opts = self.trainer.optimizers
 
         # single optimizer
-        if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer):
+        if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], (Optimizer, LightningOptimizer)):
             return opts[0]
         # multiple opts
         return opts
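The previous check only matched a bare `torch.optim.Optimizer`, so when the single optimizer had been wrapped in a `LightningOptimizer` the one-element list fell through to the "multiple opts" branch and was returned unchanged; widening the `isinstance` check and the return annotation fixes that. For call sites that want to handle single- and multi-optimizer setups uniformly, a hypothetical normalizer (a sketch, not part of the PR) could look like this:

```python
from typing import List, Union

from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer

OptimizerLike = Union[Optimizer, LightningOptimizer]


def as_optimizer_list(opts: Union[OptimizerLike, List[OptimizerLike]]) -> List[OptimizerLike]:
    # self.optimizers() now returns a lone (possibly wrapped) optimizer
    # directly; wrap it back into a list when uniform handling is wanted.
    return opts if isinstance(opts, list) else [opts]
```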
4 changes: 0 additions & 4 deletions pytorch_lightning/core/optimizer.py
@@ -21,10 +21,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
-def is_lightning_optimizer(optimizer):
-    return isinstance(optimizer, LightningOptimizer)
-
-
 def do_nothing_closure():
     return
 
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -17,7 +17,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import is_lightning_optimizer
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
@@ -48,7 +48,7 @@ def configure_ddp(self):
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
-            if is_lightning_optimizer(optimizer):
+            if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
@@ -72,7 +72,7 @@ def _wrap_optimizers(self):
         self._reinit_optimizers_with_oss()
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
-        if is_lightning_optimizer(optimizer):
+        if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
         optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
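With the module-level `is_lightning_optimizer()` helper removed, callers check the type directly and reach the wrapped `torch.optim.Optimizer` through the private `_optimizer` attribute, as the plugin code above now does. A standalone sketch of that unwrap pattern (the function name is hypothetical):

```python
from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer


def unwrap_lightning_optimizer(optimizer) -> Optimizer:
    # Replacement pattern for the removed is_lightning_optimizer() helper:
    # check the type directly and unwrap via the private _optimizer attribute.
    if isinstance(optimizer, LightningOptimizer):
        optimizer = optimizer._optimizer
    return optimizer
```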
4 changes: 2 additions & 2 deletions tests/plugins/test_deepspeed_plugin.py
@@ -40,7 +40,7 @@ def __init__(self):
         self.layer = None
 
     def training_step(self, batch, batch_idx):
-        opt = self.optimizers()[0]
+        opt = self.optimizers()
         output = self(batch)
         loss = self.loss(batch, output)
         opt.zero_grad()
@@ -518,7 +518,7 @@ def training_step(self, batch, batch_idx):
         x, y = batch
         logits = self.forward(x)
         loss = F.cross_entropy(logits, y)
-        opt = self.optimizers()[0]
+        opt = self.optimizers()
         self.log('train_loss', loss, prog_bar=True)
         self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True)
         opt.zero_grad()