
Removed unnecessary _move_optimizer_state method overrides #10849

Merged
10 commits merged on Dec 2, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))


- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))


### Fixed

- Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
11 changes: 0 additions & 11 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -14,15 +14,12 @@
import os
from typing import Any, Dict, Optional

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH
@@ -66,14 +63,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.setup_optimizers(trainer)
self.setup_precision_plugin()

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
"""Moves the state of the optimizers to the TPU if needed."""
# TODO: `self.root_device` would raise error if called outside the spawn process
# while training on 8 and more cores.
for opt in self.optimizers:
for p, v in opt.state.items():
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)

def model_to_device(self) -> None:
self.model.to(self.root_device)

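With the override deleted above, `SingleTPUPlugin` (and `TPUSpawnPlugin` below) simply inherit `_move_optimizer_state` from the base training type plugin, whose updated implementation appears in the last hunk of this diff. A minimal sketch of that inheritance pattern, using hypothetical toy classes rather than the actual Lightning plugins:

```python
# Minimal sketch with hypothetical toy classes (not the actual Lightning plugins):
# once the base class falls back to `self.root_device` itself, a TPU plugin no
# longer needs its own `_move_optimizer_state` override.
from typing import List, Optional

import torch
from torch.optim import Optimizer


class ToyBasePlugin:
    def __init__(self, optimizers: List[Optimizer], root_device: torch.device) -> None:
        self.optimizers = optimizers
        self.root_device = root_device

    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        # Move every tensor in the optimizer state to the requested device,
        # defaulting to the plugin's root device (mirrors the base-class hunk below).
        target = device or self.root_device
        for opt in self.optimizers:
            for p, v in opt.state.items():
                opt.state[p] = {k: t.to(target) if isinstance(t, torch.Tensor) else t for k, t in v.items()}


class ToySingleTPUPlugin(ToyBasePlugin):
    # No override needed: the inherited method already targets `self.root_device`,
    # which for the real TPU plugins would be the XLA device.
    pass
```

The only behavioural requirement is that the fallback resolves to the plugin's `root_device` at call time, which is exactly what the removed TPU overrides were doing by hand.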
12 changes: 2 additions & 10 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -33,7 +33,7 @@
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -60,7 +60,7 @@ def __init__(
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
debug: bool = False,
**_: Any
**_: Any,
) -> None:
checkpoint_io = checkpoint_io or XLACheckpointIO()
super().__init__(
@@ -128,14 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.setup_optimizers(trainer)
self.setup_precision_plugin()

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
"""Moves the state of the optimizers to the TPU if needed."""
# TODO: `self.root_device` would raise error if called outside the spawn process
# while training on 8 and more cores.
for opt in self.optimizers:
for p, v in opt.state.items():
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)

def _setup_model(self, model: Module) -> Module:
return model

@@ -108,10 +108,9 @@ def setup_precision_plugin(self) -> None:

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
"""Moves the state of the optimizers to the GPU if needed."""
device = device or self.root_device
for opt in self.optimizers:
for p, v in opt.state.items():
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
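For reference, a hedged usage sketch of the utilities used in the hunk above (it assumes pytorch_lightning is installed; the state dict and device are invented for illustration, only `apply_to_collection` and `move_data_to_device` are real Lightning utilities):

```python
# Hedged usage sketch of the helpers from the base-class hunk above (assumes
# pytorch_lightning is installed; the state dict and device are invented).
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

state = {"exp_avg": torch.zeros(3), "exp_avg_sq": torch.zeros(3), "step": 1}
device = torch.device("cpu")  # inside a plugin this would be `device or self.root_device`

# Walks the collection and applies `move_data_to_device` to every torch.Tensor
# leaf, passing non-tensor entries (like the `step` counter) through untouched.
moved = apply_to_collection(state, torch.Tensor, move_data_to_device, device)
print({key: type(value).__name__ for key, value in moved.items()})
```

Non-tensor entries such as the `step` counter pass through unchanged, so only the tensor state is moved to the target device.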