
[1/2] Collaborative Strategy #12842

Merged: 51 commits, May 5, 2022

Commits
f24cbe8  Add code portion of collaborative (Apr 21, 2022)
01793ce  Add CHANGELOG.md (Apr 21, 2022)
e471ed9  Update pytorch_lightning/strategies/collaborative.py (Apr 22, 2022)
9752173  Apply suggestions from code review (Apr 22, 2022)
9aa4a3f  Address reviews (Apr 22, 2022)
6af3af0  Protect variables (Apr 22, 2022)
1733471  Add test, clean up a bit (Apr 25, 2022)
176e98c  Merge branch 'master' into feat/collab_training_1n (Apr 25, 2022)
e4d404a  Use requests exception (Apr 25, 2022)
920d6b2  Test raise exception (Apr 25, 2022)
cda9a14  Update pytorch_lightning/strategies/collaborative.py (Apr 25, 2022)
93ff15f  fix test (Apr 25, 2022)
831d911  Merge remote-tracking branch 'origin/feat/collab_training_1n' into fe… (Apr 25, 2022)
db0a411  Apply suggestions from code review (Apr 25, 2022)
dae542d  Add comments (Apr 25, 2022)
0a6f44d  Address reviews (Apr 25, 2022)
8291521  Update tests/strategies/test_collaborative.py (Apr 25, 2022)
ea17d5b  Address reviews (Apr 25, 2022)
3c8b276  Merge remote-tracking branch 'origin/feat/collab_training_1n' into fe… (Apr 25, 2022)
c15aa95  Add test (Apr 26, 2022)
7fc1ee0  Try to fix tests (Apr 26, 2022)
341c03d  Merge branch 'master' into feat/collab_training_1n (Apr 26, 2022)
0626046  Merge branch 'master' into feat/collab_training_1n (Apr 27, 2022)
b29ddb3  Attempt to use sys_platform (Apr 27, 2022)
d62f600  Add test! (Apr 27, 2022)
49deba2  Merge branch 'master' into feat/collab_training_1n (Apr 27, 2022)
6861031  Apply suggestions from code review (Apr 28, 2022)
46d95a7  code review (Apr 28, 2022)
185d5c3  Fix condition (Apr 28, 2022)
77a787a  Fix mypy. Raise error with ReduceLROnPlateau (carmocca, Apr 28, 2022)
e43c7b5  Apply some of my comments (carmocca, Apr 28, 2022)
4e9e6c9  Apply suggestions from code review (Apr 28, 2022)
08448d6  Add todo and message (Apr 28, 2022)
bc0f503  Merge branch 'master' into feat/collab_training_1n (kaushikb11, Apr 29, 2022)
59e9834  Address reviews (May 4, 2022)
af30886  Removing variables to see what breaks (May 4, 2022)
d7b3c2e  Add global rank property (May 4, 2022)
4ee41b0  Add commas (May 4, 2022)
f7252b5  Attempt to fix override (May 4, 2022)
c13fec0  Fix spelling error (May 4, 2022)
3e43572  Fix mypy (May 4, 2022)
047400b  More types (May 4, 2022)
340916b  fix check (May 4, 2022)
adac127  Add var (May 4, 2022)
37bfa2c  Merge branch 'master' into feat/collab_training_1n (May 4, 2022)
6597526  Move install (May 4, 2022)
514d356  Check linux, address reviews (May 5, 2022)
66fc203  Update pytorch_lightning/strategies/collaborative.py (May 5, 2022)
779d3d1  Address review (May 5, 2022)
ad83816  Update pytorch_lightning/strategies/collaborative.py (May 5, 2022)
a77cf16  Fix condition (May 5, 2022)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))


- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))


- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))


1 change: 1 addition & 0 deletions pytorch_lightning/strategies/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
529 changes: 529 additions & 0 deletions pytorch_lightning/strategies/collaborative.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -35,6 +35,7 @@
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
_GROUP_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_HYDRA_AVAILABLE,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -105,6 +105,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group")
_HIVEMIND_AVAILABLE = _package_available("hivemind")
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _package_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
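The `_HIVEMIND_AVAILABLE` flag follows the same availability-check pattern as the other `_*_AVAILABLE` constants in this file. A minimal stdlib sketch of what such a `_package_available` helper typically does (the real helper lives in `pytorch_lightning/utilities/imports.py` and may differ in detail):

```python
import importlib.util


def package_available(name: str) -> bool:
    # True if the package can be found on the import path without importing it.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # find_spec raises when a parent package of a dotted name is missing.
        return False


HIVEMIND_AVAILABLE = package_available("hivemind")
```

Checking `find_spec` rather than importing keeps module import cheap and side-effect free, which matters for flags evaluated at library import time.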
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/types.py
@@ -68,6 +68,9 @@ class _LRScheduler(_Stateful, Protocol):
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...

def step(self, epoch: Optional[int] = None) -> None:
...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@@ -91,6 +94,9 @@ def __init__(
) -> None:
...

def step(self, metrics: Union[float, int, torch.Tensor], epoch: Optional[int] = None) -> None:
...


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
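The `step` methods added above extend `typing.Protocol` classes, so any scheduler with a structurally matching `step` satisfies the type without inheriting from it. A self-contained sketch of that pattern (names here are illustrative, not Lightning's):

```python
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class Steppable(Protocol):
    # Structural interface: any object with a matching `step` conforms,
    # no inheritance required.
    def step(self, epoch: Optional[int] = None) -> None:
        ...


class ToyScheduler:
    def __init__(self) -> None:
        self.calls = 0

    def step(self, epoch: Optional[int] = None) -> None:
        self.calls += 1


sched = ToyScheduler()
sched.step()
# runtime_checkable isinstance checks only inspect attribute names,
# not signatures; static checkers like mypy verify the full signature.
print(isinstance(sched, Steppable))
```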
1 change: 1 addition & 0 deletions requirements/strategies.txt
@@ -1,3 +1,4 @@
fairscale>=0.4.5
deepspeed<0.6.0
horovod>=0.21.2,!=0.24.0 # no need to install with [pytorch] as pytorch is already installed
hivemind>=1.0.1; sys_platform == 'linux'
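The `; sys_platform == 'linux'` suffix on the hivemind requirement is a PEP 508 environment marker: pip evaluates the condition at install time and simply skips the package on other platforms. The equivalent runtime check in Python is a sketch like:

```python
import sys

# pip installs `hivemind>=1.0.1; sys_platform == 'linux'` only when the
# marker holds; this mirrors that condition at runtime.
is_linux = sys.platform == "linux"
print(f"hivemind would be installed here: {is_linux}")
```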
7 changes: 7 additions & 0 deletions tests/helpers/runif.py
@@ -26,6 +26,7 @@
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_IPU_AVAILABLE,
@@ -84,6 +85,7 @@ def __new__(
omegaconf: bool = False,
slow: bool = False,
bagua: bool = False,
hivemind: bool = False,
**kwargs,
):
"""
@@ -111,6 +113,7 @@
omegaconf: Require that omry/omegaconf is installed.
slow: Mark the test as slow, our CI will run it in a separate job.
bagua: Require that BaguaSys/bagua is installed.
hivemind: Require that Hivemind is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
@@ -231,6 +234,10 @@ def __new__(
conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Bagua")

if hivemind:
conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Hivemind")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
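The `RunIf` change above follows an aggregation pattern: each requested requirement appends a skip condition plus a human-readable reason, and the final `pytest.mark.skipif` reason names only the requirements whose condition actually triggered. A stdlib-only sketch of that logic (function and parameter names are illustrative):

```python
def build_skip_reason(requirements: dict) -> tuple:
    """Map {requirement name: is_missing} to (should_skip, reason string)."""
    conditions = []
    reasons = []
    for name, missing in requirements.items():
        conditions.append(missing)
        reasons.append(name)
    # Keep only the reasons whose condition fired, as RunIf does.
    active = [rs for cond, rs in zip(conditions, reasons) if cond]
    return any(conditions), f"Requires: [{' + '.join(active)}]"


skip, reason = build_skip_reason({"Hivemind": True, "Bagua": False})
print(skip, reason)
```

In the real helper the resulting pair feeds `pytest.mark.skipif(condition=..., reason=...)`, so a skipped test reports exactly which optional dependency was absent.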