Introduce base collective and main subclasses #15016

Merged: 58 commits, Oct 7, 2022

Commits:
cd0287a  Introduce base collective and main subclasses (carmocca, Oct 6, 2022)
dbd2a20  Cherry-pick tag fix (carmocca, Oct 6, 2022)
e25e7ff  Sort methods (carmocca, Oct 6, 2022)
34aaf00  Fix import (carmocca, Oct 6, 2022)
515e54d  Support passing ReduceOp to TorchCollective (carmocca, Oct 6, 2022)
a2187a6  test (Oct 6, 2022)
63f0399  working full test (Oct 6, 2022)
6a6af88  Refactor impl (carmocca, Oct 6, 2022)
5298d9f  Remove extra argument (carmocca, Oct 6, 2022)
d40a7d1  I think we can assume that whoever runs our tests has torch.distributed? (carmocca, Oct 6, 2022)
266069b  Fix 1.9 compat (carmocca, Oct 6, 2022)
5c63779  convert_ops test (Oct 6, 2022)
021604e  mark (Oct 6, 2022)
b6f9356  1.9 compatibility (Oct 6, 2022)
37fc617  niceties (Oct 6, 2022)
299e5ef  teardown fixture (Oct 6, 2022)
6109b90  teardown fixture (Oct 6, 2022)
f3875f9  removing param (Oct 6, 2022)
fc7ff69  create and destroy tests (Oct 6, 2022)
5e832c6  Push current fixes before Ota gives me conflicts (carmocca, Oct 6, 2022)
76d92c9  All gather true test (carmocca, Oct 6, 2022)
5f5eab4  Fixing tests (carmocca, Oct 6, 2022)
814719f  Single device tests (carmocca, Oct 6, 2022)
96fee6b  Simplify tests (carmocca, Oct 6, 2022)
1354c08  Fix mypy (carmocca, Oct 6, 2022)
438b4fd  Reduce true test (carmocca, Oct 6, 2022)
9fafbfd  singledevice strategy test (Oct 6, 2022)
293d0f2  remove unfinished comment (Oct 6, 2022)
661abb8  Merge into 1 test to amortize launch (carmocca, Oct 6, 2022)
f368b89  Tiny docstring change (carmocca, Oct 6, 2022)
eccbd96  One fixture is enough (carmocca, Oct 6, 2022)
60b8741  Do we even need launch? (carmocca, Oct 6, 2022)
bcb4534  Assert not initialized in fixture (carmocca, Oct 6, 2022)
7c5b92a  add test, that recreate is possible (Oct 6, 2022)
308b27d  Revert to launch (carmocca, Oct 6, 2022)
2b40b96  distributed tests are passing now (Oct 6, 2022)
7c01765  fix other tests (Oct 7, 2022)
4c51bff  test two groups (Oct 7, 2022)
4ecafd1  Wrapper (carmocca, Oct 7, 2022)
d8c31ab  Move method (carmocca, Oct 7, 2022)
19fb7c3  merge (Oct 7, 2022)
9766deb  is_initialized/available (carmocca, Oct 7, 2022)
092dd0b  remove init_kwargs (Oct 7, 2022)
f1fc251  Replace RunIf (carmocca, Oct 7, 2022)
28d0d34  finalize tests (Oct 7, 2022)
1df314e  Docstring (carmocca, Oct 7, 2022)
31ae5fa  Cleanup logic (carmocca, Oct 7, 2022)
00e4df7  we are passing now (Oct 7, 2022)
7fb617d  Typing (carmocca, Oct 7, 2022)
76c29ae  Environ tests (carmocca, Oct 7, 2022)
ba313b9  Drop instantiate_group (carmocca, Oct 7, 2022)
f03cae8  Docstring (carmocca, Oct 7, 2022)
97b80fc  warning in docstring (carmocca, Oct 7, 2022)
08858c3  Cleanup env right after (carmocca, Oct 7, 2022)
545e911  unify types (Oct 7, 2022)
e556e5a  Fix mypy (carmocca, Oct 7, 2022)
67a5ee6  remove debug file (Oct 7, 2022)
d1e9209  test_two_groups is hanging in a job. try barrier (carmocca, Oct 7, 2022)
Changes from 2 commits
3 changes: 1 addition & 2 deletions src/lightning_lite/plugins/__init__.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
@@ -28,10 +27,10 @@
     "CheckpointIO",
     "TorchCheckpointIO",
     "XLACheckpointIO",
-    "Precision",
     "DeepSpeedPrecision",
     "DoublePrecision",
     "NativeMixedPrecision",
+    "Precision",
     "TPUPrecision",
     "TPUBf16Precision",
 ]
9 changes: 9 additions & 0 deletions src/lightning_lite/plugins/collectives/__init__.py
@@ -0,0 +1,9 @@
from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
from lightning_lite.plugins.collectives.torch_collective import TorchCollective

__all__ = [
"Collective",
"TorchCollective",
"SingleDeviceCollective",
]
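
The list above is the public import surface of the new subpackage. A quick sanity-check sketch, using only the names exported here:

# Sketch: the collectives are importable straight from the plugin subpackage.
from lightning_lite.plugins.collectives import Collective, SingleDeviceCollective, TorchCollective

# Both concrete collectives subclass the abstract base (see collective.py below).
assert issubclass(SingleDeviceCollective, Collective)
assert issubclass(TorchCollective, Collective)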
146 changes: 146 additions & 0 deletions src/lightning_lite/plugins/collectives/collective.py
@@ -0,0 +1,146 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup


class Collective(ABC):
    def __init__(self, instantiate_group: bool = False, **group_kwargs: Any) -> None:
        self._group_kwargs = group_kwargs
        self._group: Optional[CollectibleGroup] = None
        if instantiate_group:
            self.create_group()

    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
        if self._group is not None:
            raise RuntimeError(f"{type(self).__name__} already owns a group.")
        self._group_kwargs.update(kwargs)
        self._group = self.init_group(**self._group_kwargs)
        return self

    @property
    def group(self) -> CollectibleGroup:
        if self._group is None:
            raise RuntimeError(
                f"{type(self).__name__} does not own a group. HINT: try `collective.create_group().group`"
            )
        return self._group

    @property
    @abstractmethod
    def rank(self) -> int:
        pass

    @property
    @abstractmethod
    def world_size(self) -> int:
        pass

    @staticmethod
    @abstractmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        pass

    def teardown(self) -> None:
        if self._group is None:
            raise RuntimeError(f"{type(self).__name__} does not own a group to destroy.")
        self.destroy_group(self._group)
        self._group = None

    @staticmethod
    @abstractmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    @staticmethod
    @abstractmethod
    def _convert_to_native_op(op: str) -> Any:
        pass

    @abstractmethod
    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
        pass

    @abstractmethod
    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
        pass

    @abstractmethod
    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce(
        self,
        tensor: torch.Tensor,
        dst: int,
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def gather(
        self,
        tensor: torch.Tensor,
        gather_list: Optional[List[torch.Tensor]] = None,
        dst: int = 0,
    ) -> Optional[List[torch.Tensor]]:
        pass

    @abstractmethod
    def scatter(
        self,
        tensor: torch.Tensor,
        scatter_list: Optional[List[torch.Tensor]] = None,
        src: int = 0,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        op: str,
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def barrier(
        self,
        device_ids: Optional[List[int]] = None,
    ) -> None:
        pass
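
A minimal sketch of the ownership lifecycle the base class enforces, exercised here through the SingleDeviceCollective added below; this is illustrative usage, not part of the diff:

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective()  # instantiate_group defaults to False
try:
    _ = collective.group  # raises RuntimeError: no group owned yet
except RuntimeError as err:
    print(err)  # the message hints at `collective.create_group().group`

collective = collective.create_group()  # forwards the stored **group_kwargs to init_group; returns Self
_ = collective.group                    # the group is now owned
collective.teardown()                   # calls destroy_group(...) and releases ownership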
110 changes: 110 additions & 0 deletions src/lightning_lite/plugins/collectives/single_device_collective.py
@@ -0,0 +1,110 @@
from typing import Any, List, Optional

import torch

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup


class SingleDeviceCollective(Collective):
    @property
    def rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1

    @staticmethod
    def init_group(
        **kwargs: Any,
    ) -> CollectibleGroup:
        return object()  # type: ignore[return-value]

    @staticmethod
    def _convert_to_native_op(op: str) -> str:
        return op

    @staticmethod
    def destroy_group(group: CollectibleGroup) -> None:
        pass

    def send(self, *_: Any, **__: Any) -> None:
        pass

    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
        return tensor

    def broadcast(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def all_reduce(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def reduce(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return tensor

    def all_gather(
        self,
        tensor_list: List[torch.Tensor],
        tensor: torch.Tensor,
        **__: Any,
    ) -> List[torch.Tensor]:
        return [tensor]

    def gather(
        self,
        tensor: torch.Tensor,
        *_: Any,
        **__: Any,
    ) -> Optional[List[torch.Tensor]]:
        return [tensor]

    def scatter(  # type: ignore[override]
        self,
        tensor: torch.Tensor,
        scatter_list: List[torch.Tensor],  # it doesn't make sense to have a None here for a single device
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return scatter_list[0]

    def reduce_scatter(
        self,
        output: torch.Tensor,
        input_list: List[torch.Tensor],
        *_: Any,
        **__: Any,
    ) -> torch.Tensor:
        return input_list[0]

    def all_to_all(
        self,
        output_tensor_list: List[torch.Tensor],
        input_tensor_list: List[torch.Tensor],
        *_: Any,
        **__: Any,
    ) -> List[torch.Tensor]:
        return input_tensor_list

    def barrier(
        self,
        *_: Any,
        **__: Any,
    ) -> None:
        pass
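
Because every collective op is a pass-through on a single device, the class can be exercised without any distributed setup. A short sketch of those semantics, assuming only the code above:

import torch

from lightning_lite.plugins.collectives import SingleDeviceCollective

collective = SingleDeviceCollective(instantiate_group=True)
assert collective.rank == 0 and collective.world_size == 1

t = torch.arange(4.0)
assert collective.broadcast(t, src=0) is t      # identity
assert collective.all_reduce(t, op="sum") is t  # identity
assert collective.all_gather([], t) == [t]      # a one-member "world"
assert collective.scatter(t, [t]) is t          # first (and only) shard
collective.teardown()                           # destroy_group is a no-op here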