Commit

Introduce base collective and main subclasses (#15016)
Co-authored-by: otaj <ota@lightning.ai>
carmocca and otaj committed Oct 7, 2022
1 parent 7e518ca commit 62ca073
Showing 26 changed files with 713 additions and 42 deletions.
3 changes: 1 addition & 2 deletions src/lightning_lite/plugins/__init__.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
@@ -28,10 +27,10 @@
     "CheckpointIO",
     "TorchCheckpointIO",
     "XLACheckpointIO",
-    "Precision",
     "DeepSpeedPrecision",
     "DoublePrecision",
     "NativeMixedPrecision",
+    "Precision",
     "TPUPrecision",
     "TPUBf16Precision",
 ]
9 changes: 9 additions & 0 deletions src/lightning_lite/plugins/collectives/__init__.py
@@ -0,0 +1,9 @@
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
"Collective",
"TorchCollective",
"SingleDeviceCollective",
]
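With these re-exports, downstream code can import the collective classes from the package root instead of the individual modules. A minimal sketch (only the import path comes from this commit; the assertions are illustrative):

from lightning_lite.plugins.collectives import Collective, SingleDeviceCollective, TorchCollective

# Both concrete classes are expected to implement the abstract Collective interface introduced below.
assert issubclass(SingleDeviceCollective, Collective)
assert issubclass(TorchCollective, Collective)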
137 changes: 137 additions & 0 deletions src/lightning_lite/plugins/collectives/collective.py
@@ -0,0 +1,137 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup


class Collective(ABC):
"""Interface for collective operations.
Supports communications between multiple processes and multiple nodes. A collective owns a group.
.. warning::
This API is experimental and subject to change
"""

def __init__(self) -> None:
self._group: Optional[CollectibleGroup] = None

@property
@abstractmethod
def rank(self) -> int:
...

@property
@abstractmethod
def world_size(self) -> int:
...

@property
def group(self) -> CollectibleGroup:
if self._group is None:
raise RuntimeError(
f"`{type(self).__name__}` does not own a group. HINT: try `collective.create_group().group`"
)
return self._group

@abstractmethod
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
...

@abstractmethod
def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
...

@abstractmethod
def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
...

@abstractmethod
def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
...

@abstractmethod
def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
...

@abstractmethod
def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
...

@abstractmethod
def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
...

@abstractmethod
def all_to_all(
self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
) -> List[torch.Tensor]:
...

@abstractmethod
def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
...

@abstractmethod
def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
...

@abstractmethod
def barrier(self, device_ids: Optional[List[int]] = None) -> None:
...

@classmethod
@abstractmethod
def is_available(cls) -> bool:
...

@classmethod
@abstractmethod
def is_initialized(cls) -> bool:
...

@classmethod
@abstractmethod
def init_group(cls, **kwargs: Any) -> None:
...

@classmethod
@abstractmethod
def new_group(cls, **kwargs: Any) -> CollectibleGroup:
...

@classmethod
@abstractmethod
def destroy_group(cls, group: CollectibleGroup) -> None:
...

@classmethod
@abstractmethod
def _convert_to_native_op(cls, op: str) -> Any:
...

def setup(self, **kwargs: Any) -> Self: # type: ignore[valid-type]
if not self.is_initialized():
self.init_group(**kwargs)
return self

def create_group(self, **kwargs: Any) -> Self: # type: ignore[valid-type]
"""Create a group.
This assumes that :meth:`~lightning_lite.plugins.collectives.Collective.init_group` has been
called already by the user.
"""
if self._group is not None:
raise RuntimeError(f"`{type(self).__name__}` already owns a group.")
self._group = self.new_group(**kwargs)
return self

def teardown(self) -> Self: # type: ignore[valid-type]
if self._group is None:
raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
self.destroy_group(self._group)
self._group = None
return self
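
A short usage sketch of the lifecycle the base class defines, using the SingleDeviceCollective added later in this commit as the concrete implementation; it only exercises the setup/create_group/teardown flow and the group-ownership error shown above:

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective()

# Accessing `group` before `create_group()` raises, because the collective does not own a group yet.
try:
    collective.group
except RuntimeError as err:
    print(err)  # hints at `collective.create_group().group`

# `setup()` initializes the backend if needed and `create_group()` creates the owned group;
# both return `self`, so the calls chain.
collective.setup().create_group()
print(collective.rank, collective.world_size)  # 0 1

# `teardown()` destroys the owned group and clears ownership.
collective.teardown()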
81 changes: 81 additions & 0 deletions src/lightning_lite/plugins/collectives/single_device_collective.py
@@ -0,0 +1,81 @@
from typing import Any, List

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup


class SingleDeviceCollective(Collective):
@property
def rank(self) -> int:
return 0

@property
def world_size(self) -> int:
return 1

def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
return tensor

def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
return tensor

def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
return tensor

def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
return [tensor]

def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
return [tensor]

def scatter(
self,
tensor: torch.Tensor,
scatter_list: List[torch.Tensor],
*_: Any,
**__: Any,
) -> torch.Tensor:
return scatter_list[0]

def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
return input_list[0]

def all_to_all(
self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
) -> List[torch.Tensor]:
return input_tensor_list

def send(self, *_: Any, **__: Any) -> None:
pass

def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
return tensor

def barrier(self, *_: Any, **__: Any) -> None:
pass

@classmethod
def is_available(cls) -> bool:
return True # vacuous truth

@classmethod
def is_initialized(cls) -> bool:
return True # vacuous truth

@classmethod
def init_group(cls, **_: Any) -> None:
pass

@classmethod
def new_group(cls, **_: Any) -> CollectibleGroup:
return object() # type: ignore[return-value]

@classmethod
def destroy_group(cls, group: CollectibleGroup) -> None:
pass

@classmethod
def _convert_to_native_op(cls, op: str) -> str:
return op
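
A quick sketch of the no-op semantics above: on a single device every collective operation simply hands its input back, so the class can stand in wherever a Collective is expected (tensor values below are illustrative):

import torch

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective().setup().create_group()
t = torch.arange(4, dtype=torch.float32)

# Broadcast and reductions return the input tensor unchanged.
assert collective.broadcast(t, src=0) is t
assert collective.all_reduce(t, op="sum") is t

# Gather-style ops wrap the single tensor in a one-element list.
assert collective.all_gather([torch.empty_like(t)], t)[0] is t
assert collective.gather(t)[0] is t

# Scatter-style ops hand back the first (and only) input.
assert collective.scatter(torch.empty_like(t), [t]) is t
assert collective.reduce_scatter(torch.empty_like(t), [t]) is t

collective.teardown()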
