1/n move precision plugin into strategy - update reference
four4fish committed Nov 16, 2021
1 parent 01cf7a2 commit 4d910b4
Showing 27 changed files with 121 additions and 66 deletions.
3 changes: 1 addition & 2 deletions docs/source/extensions/accelerators.rst
@@ -26,8 +26,7 @@ One to handle differences from the training routine and one to handle different
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

accelerator = GPUAccelerator(
precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
training_type_plugin=DDPPlugin(),
training_type_plugin=DDPPlugin(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda")),
)
trainer = Trainer(accelerator=accelerator)

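For context, a self-contained version of the updated example might look like the sketch below (import paths follow the snippet above and the 1.5-era API; illustrative only, not part of the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin

# The precision plugin is now owned by the training type plugin (the strategy);
# the accelerator no longer takes it as a separate required argument.
accelerator = GPUAccelerator(
    training_type_plugin=DDPPlugin(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda")),
)
trainer = Trainer(accelerator=accelerator)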
3 changes: 1 addition & 2 deletions docs/source/extensions/plugins.rst
@@ -81,8 +81,7 @@ can then be passed into the Trainer directly or via a (custom) accelerator:
# fully custom accelerator and plugins
accelerator = MyAccelerator(
precision_plugin=CustomPrecisionPlugin(),
training_type_plugin=CustomDDPPlugin(),
training_type_plugin=CustomDDPPlugin(precision_plugin=CustomPrecisionPlugin()),
)
trainer = Trainer(accelerator=accelerator)
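The custom classes in this snippet are placeholders from the docs; a minimal sketch of what they could look like (hypothetical subclasses, shown only to make the example concrete):

from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin


class CustomPrecisionPlugin(PrecisionPlugin):
    # Override precision hooks here, e.g. backward() or optimizer_step().
    pass


class CustomDDPPlugin(DDPPlugin):
    # Override training-type hooks here, e.g. configure_ddp().
    pass


class MyAccelerator(GPUAccelerator):
    # Override hardware-specific hooks here, e.g. setup_environment().
    pass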
44 changes: 26 additions & 18 deletions pytorch_lightning/accelerators/accelerator.py
@@ -44,15 +44,20 @@ class Accelerator:
One to handle differences from the training routine and one to handle different precisions.
"""

def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None:
def __init__(
self, training_type_plugin: TrainingTypePlugin, precision_plugin: Optional[PrecisionPlugin] = None
) -> None:
"""
Args:
precision_plugin: the plugin to handle precision-specific parts
training_type_plugin: the plugin to handle different training routines
"""
self.precision_plugin = precision_plugin

self.training_type_plugin = training_type_plugin

if precision_plugin:
self.training_type_plugin._precision_plugin = precision_plugin

self.optimizers: List = []
self.lr_schedulers: List = []
self.optimizer_frequencies: List = []
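A hedged illustration of the new wiring (hypothetical usage, not part of the diff): a precision plugin passed to the accelerator is forwarded to the training type plugin, which now owns it.

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin

strategy = DDPPlugin()
accelerator = Accelerator(
    training_type_plugin=strategy,
    precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
)
# The accelerator stores the plugin on the strategy rather than on itself.
assert isinstance(strategy.precision_plugin, NativeMixedPrecisionPlugin)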
@@ -84,7 +89,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)

self.precision_plugin.pre_dispatch()
self.training_type_plugin.precision_plugin.pre_dispatch()

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
"""Moves the state of the optimizers to the GPU if needed."""
@@ -96,12 +101,12 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.dispatch(trainer)
self.precision_plugin.dispatch(trainer)
self.training_type_plugin.precision_plugin.dispatch(trainer)

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch(trainer)
self.precision_plugin.post_dispatch()
self.training_type_plugin.precision_plugin.post_dispatch()

@property
def model(self) -> Module:
@@ -159,31 +164,31 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
"""
with self.precision_plugin.train_step_context():
with self.training_type_plugin.precision_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())

def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
"""The actual validation step.
See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
"""
with self.precision_plugin.val_step_context():
with self.training_type_plugin.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())

def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
"""The actual test step.
See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
"""
with self.precision_plugin.test_step_context():
with self.training_type_plugin.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())

def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
"""The actual predict step.
See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
"""
with self.precision_plugin.predict_step_context():
with self.training_type_plugin.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())

def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
@@ -193,11 +198,11 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
closure_loss: a tensor holding the loss value to backpropagate
"""
self.training_type_plugin.pre_backward(closure_loss)
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)

self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
self.training_type_plugin.post_backward(closure_loss)

return closure_loss
@@ -220,7 +225,7 @@ def optimizer_step(
**kwargs: Any extra arguments to ``optimizer.step``
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients."""
@@ -248,26 +253,29 @@ def setup_training_type_plugin(self) -> None:

def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator."""
model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
self.model, self.optimizers, self.lr_schedulers
)
self.model = model
self.optimizers = optimizers
self.lr_schedulers = schedulers

@property
def amp_backend(self) -> Optional[LightningEnum]:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
return AMPType.NATIVE
return None

@property
def precision(self) -> Union[str, int]:
return self.precision_plugin.precision
"""deprecated."""
return self.training_type_plugin.precision

@property
def scaler(self) -> Optional["GradScaler"]:
return getattr(self.precision_plugin, "scaler", None)
return getattr(self.training_type_plugin.precision_plugin, "scaler", None)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
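Code that needs precision state now reaches it through the strategy. A hypothetical helper sketching the new access path (attribute names as in the diff; `trainer` is assumed to be a configured Trainer):

def get_precision_state(trainer):
    # Sketch only: the precision plugin now lives on the training type plugin.
    precision_plugin = trainer.accelerator.training_type_plugin.precision_plugin
    scaler = getattr(precision_plugin, "scaler", None)  # set for native AMP only
    return precision_plugin.precision, scaler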
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -36,10 +36,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
ValueError:
If the precision or training type plugin are unsupported.
"""
if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
# this configuration should have been avoided in the accelerator connector
raise ValueError(
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
f" found: {self.training_type_plugin}."
)
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise ValueError(
2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py
@@ -108,7 +108,7 @@ def __init__(
)
self._accelerator = self._accelerator_connector.accelerator
self._strategy = self._accelerator.training_type_plugin
self._precision_plugin = self._accelerator.precision_plugin
self._precision_plugin = self._strategy.precision_plugin
self._models_setup: int = 0

# wrap the run method so we can inject setup logic or spawn processes for the user
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -36,6 +36,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
@@ -86,6 +87,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
@@ -96,6 +98,7 @@
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.interactive_ddp_procs = []
self._num_nodes = 1
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -29,6 +29,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
@@ -65,6 +66,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
@@ -74,6 +76,7 @@
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self._num_nodes = 1
self.sync_batchnorm = False
8 changes: 5 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -30,6 +30,7 @@
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
@@ -129,6 +130,7 @@ def __init__(
synchronize_checkpoint_boundary: bool = False,
load_full_weights: bool = False,
partition_module: bool = True,
precision_plugin: Optional[PrecisionPlugin] = None,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -273,6 +275,7 @@ def __init__(
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
precision_plugin=precision_plugin,
)

self.config = self._load_config(config)
@@ -331,7 +334,7 @@ def __init__(

@property
def precision(self) -> Union[str, int]:
return self._precision or self.lightning_module.trainer.precision
return self._precision or self.precision_plugin.precision

@property
def amp_level(self) -> Optional[str]:
@@ -456,8 +459,7 @@ def init_deepspeed(self):
"DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
)

precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)

if self.zero_stage_3 and self.partition_module:
# Ensure the entire model has been moved to the appropriate device
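A hedged sketch of how a strategy can now receive the precision plugin directly instead of reaching into the Trainer (hypothetical usage; requires deepspeed to be installed, other constructor arguments omitted):

from pytorch_lightning.plugins import DeepSpeedPlugin, NativeMixedPrecisionPlugin

# Illustrative only: the strategy holds the precision plugin itself, so
# `self.precision` resolves without going through `lightning_module.trainer`.
strategy = DeepSpeedPlugin(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"))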
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/dp.py
@@ -18,6 +18,7 @@

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
@@ -35,8 +36,14 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)

@property
def global_rank(self) -> int:
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -18,6 +18,7 @@

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType
@@ -46,6 +47,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
"""Plugin for Fully Sharded Data Parallel provided by FairScale.
@@ -97,6 +99,7 @@ def __init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -21,6 +21,7 @@

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
@@ -41,8 +42,14 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
rank_zero_only.rank = self.global_rank

@property
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -22,6 +22,7 @@
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
@@ -64,6 +65,7 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
training_opts: Optional["poptorch.Options"] = None,
inference_opts: Optional["poptorch.Options"] = None,
) -> None:
@@ -116,8 +118,7 @@ def setup(self) -> None:
self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader

def pre_dispatch(self) -> None:
precision = self.lightning_module.trainer.precision
model = LightningIPUModule(self.lightning_module, precision)
model = LightningIPUModule(self.lightning_module, self.precision)
self.model = model

# reset the backup
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -23,6 +23,7 @@
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
@@ -36,8 +37,9 @@ def __init__(
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(checkpoint_io)
super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
self.parallel_devices = parallel_devices
self.cluster_environment = cluster_environment

