Commit
Fix mypy errors in pytorch_lightning/strategies/sharded.py (#14184)
Co-authored-by: otaj <ota@lightning.ai>
lijm1358 and otaj committed Aug 27, 2022
1 parent af688de commit 03f2f32
Showing 2 changed files with 12 additions and 8 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
-    "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
19 changes: 12 additions & 7 deletions src/pytorch_lightning/strategies/sharded.py
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple
 
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
@@ -51,10 +51,11 @@ def connect(self, model: "pl.LightningModule") -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()
 
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
         # move the model to the correct device
@@ -64,6 +65,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         trainer_fn = trainer.state.fn
         if trainer_fn == TrainerFn.FITTING:
            if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
@@ -73,7 +75,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
     def configure_ddp(self) -> None:
         self._set_ddp_kwargs()
-        self.setup_optimizers(self.model.trainer)
+        assert self.lightning_module is not None
+        self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=_LightningModuleWrapperBase(self.model),
             optimizers=self.optimizers,
@@ -97,12 +101,13 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         return model, optimizers
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module is not None
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -135,7 +140,7 @@ def block_backward_sync(self) -> Generator:
         else:
             yield None
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass
 
     @classmethod
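Several of the changes above add an `assert x is not None` (or an `isinstance` assert) right before an Optional attribute is used, which is the standard way to narrow Optional types for mypy. A minimal sketch of that pattern, using simplified stand-in classes rather than the actual Lightning Strategy and Accelerator implementations:

from typing import Optional


class Accelerator:
    def setup(self) -> None:
        ...


class Strategy:
    def __init__(self) -> None:
        # The attribute starts out unset, so its declared type is Optional.
        self.accelerator: Optional[Accelerator] = None

    def setup(self) -> None:
        # Without this assert, mypy reports an error along the lines of
        # 'Item "None" of "Optional[Accelerator]" has no attribute "setup"'.
        # The assert narrows the type to Accelerator for the lines below.
        assert self.accelerator is not None
        self.accelerator.setup()

At runtime the assert also turns a silent None dereference into an immediate AssertionError, which is a common reason to prefer it over a `# type: ignore` comment.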
