1/n Move precision plugin into strategy - update reference #10570

Merged
merged 27 commits into from Nov 19, 2021
Commits (27)
dc50966
1/n move precision plugin into strategy - update reference
four4fish Nov 16, 2021
8fd976e
update precision plugin reference in tpu_spawn
awaelchli Nov 16, 2021
ecadcbd
add missing reference in error message
awaelchli Nov 16, 2021
846b595
add back removed license line
awaelchli Nov 16, 2021
42b0325
update references in tests
awaelchli Nov 16, 2021
85f058b
update reference in trainer
awaelchli Nov 16, 2021
cff894e
update return annotation for precision_plugin property on TTP
awaelchli Nov 16, 2021
7e6d635
simplify access to precision plugin reference in sharded plug
awaelchli Nov 16, 2021
7c0e651
add changelog
four4fish Nov 16, 2021
ae6d6c5
remove precision property from ttp and add deprecation message
four4fish Nov 16, 2021
9936c51
fix make doc and update precision reference
four4fish Nov 17, 2021
6120d05
simplify a reference to precision
four4fish Nov 17, 2021
d86e212
Update CHANGELOG.md
four4fish Nov 17, 2021
e4e9384
Update accelerator precision
four4fish Nov 17, 2021
abd3ac8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2021
2711145
Add none check for precision plugin
four4fish Nov 17, 2021
173c4df
Update ipu.py
four4fish Nov 17, 2021
39ee314
update precision_plugin param deprecation message
four4fish Nov 17, 2021
9cd599b
Update accelerator.py
four4fish Nov 17, 2021
5de3120
Remove deprecated warning
four4fish Nov 17, 2021
c3ca785
keep accelerator api
four4fish Nov 18, 2021
9f06093
update deprecation message and docs
four4fish Nov 18, 2021
8eccdc0
fix comments format
four4fish Nov 18, 2021
19dde86
remove string comment
awaelchli Nov 18, 2021
06e33ad
fix duplicated deprecation message
awaelchli Nov 18, 2021
9c6fd3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2021
97ba08b
Apply suggestions from code review
four4fish Nov 18, 2021
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


-


@@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505))


-
- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


-
@@ -139,6 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481))


- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))

### Fixed

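As a minimal construction sketch of the ownership change recorded in the changelog entries above (the concrete classes `CPUAccelerator`, `DDPPlugin`, and `PrecisionPlugin` are illustrative choices, not something this PR mandates):

```python
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin

# New ownership: the strategy (training type plugin) holds the precision plugin.
strategy = DDPPlugin(precision_plugin=PrecisionPlugin())
accelerator = CPUAccelerator(precision_plugin=None, training_type_plugin=strategy)

# Deprecated path: passing the precision plugin to the Accelerator still works,
# but it is only forwarded onto the strategy (see the deprecated parameter below).
legacy = CPUAccelerator(precision_plugin=PrecisionPlugin(), training_type_plugin=DDPPlugin())

# Either way, references resolve through the strategy.
assert accelerator.training_type_plugin.precision_plugin is strategy.precision_plugin
```

Passing `precision_plugin=None` to the `Accelerator` (allowed by the new `Optional` signature) leaves the strategy's own plugin untouched.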
59 changes: 40 additions & 19 deletions pytorch_lightning/accelerators/accelerator.py
@@ -25,6 +25,7 @@
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -44,15 +45,23 @@ class Accelerator:
One to handle differences from the training routine and one to handle different precisions.
"""

def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None:
def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None:
"""
Args:
precision_plugin: the plugin to handle precision-specific parts

.. deprecated::
The ``precision_plugin`` parameter has been deprecated and will be removed soon.
Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead.

training_type_plugin: the plugin to handle different training routines
"""
self.precision_plugin = precision_plugin

self.training_type_plugin = training_type_plugin

if precision_plugin is not None:
self.training_type_plugin._precision_plugin = precision_plugin

self.optimizers: List = []
self.lr_schedulers: List = []
self.optimizer_frequencies: List = []
@@ -84,7 +93,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)

self.precision_plugin.pre_dispatch()
self.training_type_plugin.precision_plugin.pre_dispatch()

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
"""Moves the state of the optimizers to the GPU if needed."""
@@ -96,12 +105,12 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.dispatch(trainer)
self.precision_plugin.dispatch(trainer)
self.training_type_plugin.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch(trainer)
self.precision_plugin.post_dispatch()
self.training_type_plugin.precision_plugin.post_dispatch()

@property
def model(self) -> Module:
@@ -159,31 +168,31 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:

See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
"""
with self.precision_plugin.train_step_context():
with self.training_type_plugin.precision_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())

def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
"""The actual validation step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
"""
with self.precision_plugin.val_step_context():
with self.training_type_plugin.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())

def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
"""The actual test step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
"""
with self.precision_plugin.test_step_context():
with self.training_type_plugin.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())

def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
"""The actual predict step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
"""
with self.precision_plugin.predict_step_context():
with self.training_type_plugin.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())

def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
@@ -193,11 +202,11 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
closure_loss: a tensor holding the loss value to backpropagate
"""
self.training_type_plugin.pre_backward(closure_loss)
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)

self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
self.training_type_plugin.post_backward(closure_loss)

return closure_loss
@@ -208,7 +217,7 @@ def optimizer_step(
opt_idx: int,
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
"""performs the actual optimizer step.

@@ -220,7 +229,7 @@
**kwargs: Any extra arguments to ``optimizer.step``
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients."""
@@ -248,26 +257,38 @@ def setup_training_type_plugin(self) -> None:

def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator."""
model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
self.model, self.optimizers, self.lr_schedulers
)
self.model = model
self.optimizers = optimizers
self.lr_schedulers = schedulers

@property
def amp_backend(self) -> Optional[LightningEnum]:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
return AMPType.NATIVE
return None

@property
def precision(self) -> Union[str, int]:
return self.precision_plugin.precision
"""The type of precision being used with this accelerator.

.. deprecated::
This property has been deprecated and will be removed soon.
Use ``training_type_plugin.precision_plugin.precision`` instead.
"""
rank_zero_deprecation(
f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon"
f" Use `training_type_plugin.precision_plugin.precision` instead."
)
return self.training_type_plugin.precision_plugin.precision

@property
def scaler(self) -> Optional["GradScaler"]:
return getattr(self.precision_plugin, "scaler", None)
return getattr(self.training_type_plugin.precision_plugin, "scaler", None)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
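Correspondingly, user-facing lookups move from the accelerator to the strategy-owned plugin. A hedged access sketch, assuming a default-constructed `Trainer` on the post-PR API:

```python
import pytorch_lightning as pl

trainer = pl.Trainer()

# Deprecated after this PR: trainer.accelerator.precision (the accelerator's
# scaler/amp_backend accessors also just delegate to the strategy now).
# Preferred: go through the strategy's precision plugin directly.
precision = trainer.training_type_plugin.precision_plugin.precision
scaler = getattr(trainer.training_type_plugin.precision_plugin, "scaler", None)
print(precision, scaler)  # e.g. 32, None for the default full-precision setup
```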
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -36,10 +36,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
ValueError:
If the precision or training type plugin are unsupported.
"""
if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
# this configuration should have been avoided in the accelerator connector
raise ValueError(
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
f" found: {self.training_type_plugin.precision_plugin}."
)
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise ValueError(
2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py
@@ -108,7 +108,7 @@ def __init__(
)
self._accelerator = self._accelerator_connector.accelerator
self._strategy = self._accelerator.training_type_plugin
self._precision_plugin = self._accelerator.precision_plugin
self._precision_plugin = self._strategy.precision_plugin
self._models_setup: int = 0

# wrap the run method so we can inject setup logic or spawn processes for the user
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -36,6 +36,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
@@ -86,6 +87,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
@@ -96,6 +98,7 @@
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.interactive_ddp_procs = []
self._num_nodes = 1
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -29,6 +29,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
@@ -65,6 +66,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
@@ -74,6 +76,7 @@
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self._num_nodes = 1
self.sync_batchnorm = False
8 changes: 5 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -30,6 +30,7 @@
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
@@ -129,6 +130,7 @@ def __init__(
synchronize_checkpoint_boundary: bool = False,
load_full_weights: bool = False,
partition_module: bool = True,
precision_plugin: Optional[PrecisionPlugin] = None,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -273,6 +275,7 @@ def __init__(
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
precision_plugin=precision_plugin,
)

self.config = self._load_config(config)
@@ -331,7 +334,7 @@ def __init__(

@property
def precision(self) -> Union[str, int]:
return self._precision or self.lightning_module.trainer.precision
return self._precision or self.precision_plugin.precision

@property
def amp_level(self) -> Optional[str]:
@@ -456,8 +459,7 @@ def init_deepspeed(self):
"DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
)

precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)

if self.zero_stage_3 and self.partition_module:
# Ensure the entire model has been moved to the appropriate device
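Inside the strategies themselves, the DeepSpeed diff above (and the analogous fully sharded change below) swaps Trainer round-trips (`self.lightning_module.trainer...`) for the plugin the strategy now owns. A hypothetical custom strategy can lean on the same property; this sketch assumes the post-PR `precision_plugin` attribute is populated before the hook runs:

```python
from pytorch_lightning.plugins import DDPPlugin

class PrecisionAwareDDPPlugin(DDPPlugin):
    """Hypothetical subclass illustrating the strategy-local precision lookup."""

    def model_to_device(self) -> None:
        # Before this PR: self.lightning_module.trainer.accelerator.precision
        # After this PR: the strategy asks the precision plugin it owns.
        print(f"[rank {self.global_rank}] precision={self.precision_plugin.precision}")
        super().model_to_device()
```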
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/dp.py
@@ -18,6 +18,7 @@

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
@@ -35,8 +36,14 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)

@property
def global_rank(self) -> int:
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -18,6 +18,7 @@

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType
@@ -46,6 +47,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
"""Plugin for Fully Sharded Data Parallel provided by FairScale.

@@ -97,6 +99,7 @@ def __init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
@@ -124,7 +127,7 @@ def setup_distributed(self) -> None:

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
precision = self.lightning_module.trainer.precision
precision = self.precision_plugin.precision

def wrap_policy(*args, **kwargs):
return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -21,6 +21,7 @@

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
@@ -41,8 +42,14 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
rank_zero_only.rank = self.global_rank

@property
Expand Down