Rename the TrainingTypePlugin base to Strategy #11120

Merged
merged 9 commits on Dec 20, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -122,6 +122,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))


- Renamed the `TrainingTypePlugin` to `Strategy` ([#11120](https://github.com/PyTorchLightning/pytorch-lightning/pull/11120))


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
2 changes: 1 addition & 1 deletion docs/source/_static/images/accelerator/overview.svg
(SVG image diff not shown)
2 changes: 1 addition & 1 deletion docs/source/api_references.rst
@@ -146,7 +146,7 @@ Training Type Plugins
:nosignatures:
:template: classtemplate.rst

TrainingTypePlugin
Strategy
SingleDevicePlugin
ParallelPlugin
DataParallelPlugin
6 changes: 3 additions & 3 deletions docs/source/common/checkpointing.rst
@@ -318,7 +318,7 @@ Customize Checkpointing


Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
that is managed by the ``TrainingTypePlugin``. ``CheckpointIO`` is different from :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
that is managed by the ``Strategy``. ``CheckpointIO`` is different from :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` methods as it determines how the checkpoint is saved/loaded to storage rather than
what's saved in the checkpoint.

@@ -342,7 +342,7 @@ Built-in Checkpoint IO Plugins
Custom Checkpoint IO Plugin
===========================

``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``TrainingTypePlugin`` as shown below:
``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``Strategy`` as shown below:

.. code-block:: python

@@ -372,7 +372,7 @@ Custom Checkpoint IO Plugin
)
trainer.fit(model)

# or pass into TrainingTypePlugin
# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
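The documentation example above is truncated by the collapsed diff. Purely as orientation, a minimal sketch of the pattern the prose describes — extending ``CheckpointIO`` and handing the instance to the ``Trainer`` — could look like the following; the class name, method set, and signatures are illustrative placeholders rather than the exact snippet from the docs:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import CheckpointIO


    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path, storage_options=None):
            # write the checkpoint dict to `path` using any storage backend
            ...

        def load_checkpoint(self, path, storage_options=None):
            # read and return the checkpoint dict from `path`
            ...

        def remove_checkpoint(self, path):
            # delete a checkpoint at `path`
            ...


    # pass the instance directly to the Trainer; it can typically also be handed
    # to a training type plugin / strategy via its `checkpoint_io` argument
    trainer = Trainer(plugins=[CustomCheckpointIO()])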
8 changes: 4 additions & 4 deletions docs/source/extensions/plugins.rst
@@ -29,8 +29,8 @@ We expose Accelerators and Plugins mainly for expert users that want to extend Lightning

There are two types of Plugins in Lightning with different responsibilities:

TrainingTypePlugin
------------------
Strategy
--------

- Launching and teardown of training processes (if applicable)
- Setup communication between processes (NCCL, GLOO, MPI, ...)
@@ -70,7 +70,7 @@ Expert users may choose to extend an existing plugin by overriding its methods.
device_ids=...,
)

or by subclassing the base classes :class:`~pytorch_lightning.plugins.training_type.TrainingTypePlugin` or
or by subclassing the base classes :class:`~pytorch_lightning.plugins.training_type.Strategy` or
:class:`~pytorch_lightning.plugins.precision.PrecisionPlugin` to create new ones. These custom plugins
can then be passed into the Trainer directly or via a (custom) accelerator:

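The code block that followed this sentence is collapsed in the diff. As a hedged sketch of the pattern described — subclassing an existing plugin and passing the instance to the ``Trainer`` — something along these lines would apply; the class name and the overridden hook are hypothetical choices, not the collapsed original:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin


    class CustomDDPPlugin(DDPPlugin):
        def configure_ddp(self):
            # tweak how the DistributedDataParallel wrapper is created
            super().configure_ddp()


    trainer = Trainer(strategy=CustomDDPPlugin(), gpus=2)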
@@ -105,7 +105,7 @@ Training Type Plugins
:nosignatures:
:template: classtemplate.rst

TrainingTypePlugin
Strategy
SingleDevicePlugin
ParallelPlugin
DataParallelPlugin
2 changes: 1 addition & 1 deletion pytorch_lightning/distributed/dist.py
@@ -29,7 +29,7 @@ class LightningDistributed:
def __init__(self, rank=None, device=None):
rank_zero_deprecation(
"LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
"Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
"Broadcast logic is implemented directly in the :class:`Strategy` implementations."
)
self.rank = rank
self.device = device
10 changes: 5 additions & 5 deletions pytorch_lightning/lite/lite.py
@@ -26,7 +26,7 @@

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, Strategy, TPUSpawnPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TBroadcast
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
@@ -69,7 +69,7 @@ class LightningLite(ABC):
def __init__(
self,
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, TrainingTypePlugin]] = None,
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
@@ -451,13 +451,13 @@ def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerato
f" Choose one of {supported} or pass in a `Accelerator` instance."
)

def _check_strategy_support(self, strategy: Optional[Union[str, TrainingTypePlugin]]) -> None:
def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> None:
supported = [t.lower() for t in self._supported_strategy_types()]
valid = strategy is None or isinstance(strategy, TrainingTypePlugin) or strategy in supported
valid = strategy is None or isinstance(strategy, Strategy) or strategy in supported
if not valid:
raise MisconfigurationException(
f"`strategy={repr(strategy)}` is not a valid choice."
f" Choose one of {supported} or pass in a `TrainingTypePlugin` instance."
f" Choose one of {supported} or pass in a `Strategy` instance."
)

@staticmethod
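For context on the ``strategy`` argument validated above, a minimal usage sketch of ``LightningLite`` at this point in the API might look as follows; the ``run`` body is elided and the chosen flags are only examples:

.. code-block:: python

    from pytorch_lightning.lite import LightningLite


    class Lite(LightningLite):
        def run(self):
            # set up models/optimizers with self.setup(...) and call self.backward(loss)
            ...


    # `strategy` accepts either one of the supported strings or a Strategy instance
    Lite(accelerator="cpu", strategy="ddp_spawn", devices=2).run()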
4 changes: 2 additions & 2 deletions pytorch_lightning/lite/wrappers.py
@@ -20,7 +20,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.plugins import PrecisionPlugin, Strategy
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


@@ -29,7 +29,7 @@ def _do_nothing_closure() -> None:


class _LiteOptimizer:
def __init__(self, optimizer: Optimizer, strategy: TrainingTypePlugin) -> None:
def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
"""LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
step calls to the strategy plugin.

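The docstring above describes ``_LiteOptimizer`` as a thin wrapper that forwards optimizer step calls to the strategy. As a generic illustration of that delegation pattern only — not the actual ``_LiteOptimizer`` code, and with a hypothetical ``optimizer_step`` call on a hypothetical strategy object:

.. code-block:: python

    from torch.optim import Optimizer


    class _DelegatingOptimizer:
        def __init__(self, optimizer: Optimizer, strategy) -> None:
            self._optimizer = optimizer
            self._strategy = strategy

        def step(self, closure=None):
            # route the step through the strategy so distributed/precision handling
            # can wrap the underlying optimizer (method name is illustrative)
            return self._strategy.optimizer_step(self._optimizer, closure=closure)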
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/__init__.py
@@ -33,9 +33,9 @@
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy

PLUGIN = Union[TrainingTypePlugin, PrecisionPlugin, ClusterEnvironment, CheckpointIO]
PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO]
PLUGIN_INPUT = Union[PLUGIN, str]

__all__ = [
@@ -63,7 +63,7 @@
"TPUPrecisionPlugin",
"TPUBf16PrecisionPlugin",
"TPUSpawnPlugin",
"TrainingTypePlugin",
"Strategy",
"ParallelPlugin",
"DDPShardedPlugin",
"DDPSpawnShardedPlugin",
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/io/checkpoint_plugin.py
@@ -18,7 +18,7 @@


class CheckpointIO(ABC):
"""Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``.
"""Interface to save/load checkpoints as they are saved through the ``Strategy``.

Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may
require particular handling depending on the plugin.
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -17,7 +17,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -122,7 +122,7 @@ def is_register_plugins_overridden(plugin: type) -> bool:
plugin_attr = getattr(plugin, method_name)
previous_super_cls = inspect.getmro(plugin)[1]

if issubclass(previous_super_cls, TrainingTypePlugin):
if issubclass(previous_super_cls, Strategy):
super_attr = getattr(previous_super_cls, method_name)
else:
return False
@@ -133,5 +133,5 @@ def is_register_plugins_overridden(plugin: type) -> bool:
def call_training_type_register_plugins(root: Path, base_module: str) -> None:
module = importlib.import_module(base_module)
for _, mod in getmembers(module, isclass):
if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod):
if issubclass(mod, Strategy) and is_register_plugins_overridden(mod):
mod.register_plugins(TrainingTypePluginsRegistry)
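For reference, the discovery loop above calls ``register_plugins`` on every ``Strategy`` subclass that overrides it. A hedged sketch of such an override, following the pattern the built-in plugins use — the registry key is hypothetical and the exact ``register`` signature may differ in detail:

.. code-block:: python

    from pytorch_lightning.plugins import DDPPlugin


    class NoUnusedParamsDDPPlugin(DDPPlugin):
        @classmethod
        def register_plugins(cls, plugin_registry) -> None:
            plugin_registry.register(
                "ddp_no_unused_params",  # hypothetical registry key
                cls,
                description="DDP with find_unused_parameters disabled",
                find_unused_parameters=False,
            )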
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -205,7 +205,7 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -
torch.nn.utils.clip_grad_norm_(parameters, clip_val)

def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something when ``TrainingTypePlugin.dispatch()`` gets called."""
"""Hook to do something when ``Strategy.dispatch()`` gets called."""

@contextlib.contextmanager
def forward_context(self) -> Generator[None, None, None]:
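As a standalone illustration of the ``torch.nn.utils.clip_grad_norm_`` call used by ``clip_grad_by_norm`` above (the model and clip value are arbitrary examples):

.. code-block:: python

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # rescales gradients in place so that their global norm does not exceed 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)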
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/__init__.py
@@ -11,4 +11,4 @@
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy # noqa: F401
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -24,12 +24,12 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


class ParallelPlugin(TrainingTypePlugin, ABC):
class ParallelPlugin(Strategy, ABC):
"""Plugin for training with multiple processes in parallel."""

def __init__(
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -18,11 +18,11 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy
from pytorch_lightning.utilities import _XLA_AVAILABLE


class SingleDevicePlugin(TrainingTypePlugin):
class SingleDevicePlugin(Strategy):
"""Plugin that handles communication on a single device."""

def __init__(
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -227,7 +227,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if _invalid_reduce_op or _invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
)

output = xm.mesh_reduce("reduce", output, sum)
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -37,7 +37,7 @@
TBroadcast = TypeVar("TBroadcast")


class TrainingTypePlugin(ABC):
class Strategy(ABC):
"""Base class for all training type plugins that change the behaviour of the training, validation and test-
loop."""

@@ -55,7 +55,7 @@ def __init__(
self.optimizers: List[Optimizer] = []
self.lr_schedulers: List[_LRScheduler] = []
self.optimizer_frequencies: List[int] = []
if is_overridden("post_dispatch", self, parent=TrainingTypePlugin):
if is_overridden("post_dispatch", self, parent=Strategy):
rank_zero_deprecation(
f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7."
f" Move your implementation to `{self.__class__.__name__}.teardown()` instead."
Expand Down