From be4c24c484e7969045adafb5f695561ec9190005 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:43:16 +0000 Subject: [PATCH 01/71] Encapsulate extracting reference model within the plugin to allow custom wrapper logic to live within the plugin/accelerators --- pytorch_lightning/accelerators/accelerator.py | 17 ++++++++++++++ .../accelerators/ddp2_accelerator.py | 3 +++ .../accelerators/ddp_accelerator.py | 3 +++ .../accelerators/ddp_cpu_spawn_accelerator.py | 3 +++ .../accelerators/ddp_hpc_accelerator.py | 3 +++ .../accelerators/ddp_spawn_accelerator.py | 3 +++ .../accelerators/dp_accelerator.py | 3 +++ pytorch_lightning/plugins/ddp_plugin.py | 17 ++++++++++++++ .../trainer/connectors/model_connector.py | 22 ++++++++----------- 9 files changed, 61 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 5b9c3abc4709a..5e5bc9a8e126f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -208,6 +208,23 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return self.ddp_plugin.optimizer_state(optimizer) return optimizer.state_dict() + def reference_model(self, model): + """ + Override to modify returning base :class:`LightningModule` + when accessing variable and functions if the accelerator has wrapped the model. + + Example:: + ref_model = accelerator.reference_model(model) + ref_model.training_step(...) + + Args: + model: Accelerator model. + + Returns: Reference :class:`LightningModule`. + + """ + return model + def __getstate__(self): return { 'trainer': self.trainer, diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 9fcbdd4668ee9..5efb0d0a469ce 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -218,3 +218,6 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + + def reference_model(self, model): + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 69d41cd024646..8c82f1a43623b 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -319,3 +319,6 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) + + def reference_model(self, model): + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 2a090a72e2b5a..95ba7bfb18762 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -246,3 +246,6 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + + def reference_model(self, model): + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 2ff9c2b7ddaae..6522ae38dc1d4 100644 --- 
a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -213,3 +213,6 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + + def reference_model(self, model): + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index eac51393a5f2e..b0a9933c130a3 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -272,3 +272,6 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + + def reference_model(self, model): + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 2f6c5dce97c46..db79aaa09820a 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -172,3 +172,6 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) if state is not None: scheduler.load_state_dict(state) + + def reference_model(self, model: LightningDataParallel): + return model.module \ No newline at end of file diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 6a38da2e0c2bc..d4fc97d4e5a45 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -108,3 +108,20 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() + + def reference_model_in_plugin_wrapper(self, model: LightningDistributedDataParallel): + """ + Override to modify returning base :class:`LightningModule` + when accessing variable and functions outside of the parallel wrapper. + + Example:: + ref_model = ddp_plugin.reference_model(model) + ref_model.training_step(...) + + Args: + model: Model with parallel wrapper. + + Returns: Reference :class:`LightningModule` within parallel wrapper. + + """ + return model.module diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index dbdceb1532288..0c44e76319e71 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -17,10 +17,6 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. 
""" -from pytorch_lightning.overrides.data_parallel import ( - LightningDistributedDataParallel, - LightningDataParallel, -) class ModelConnector: @@ -28,12 +24,7 @@ def __init__(self, trainer): self.trainer = trainer def copy_trainer_model_properties(self, model): - if isinstance(model, LightningDataParallel): - ref_model = model.module - elif isinstance(model, LightningDistributedDataParallel): - ref_model = model.module - else: - ref_model = model + ref_model = self._reference_model(model) automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -55,6 +46,11 @@ def copy_trainer_model_properties(self, model): m.local_rank = self.trainer.local_rank def get_model(self): - is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel)) - model = self.trainer.model.module if is_dp_module else self.trainer.model - return model + return self._reference_model(self.trainer.model) + + def _reference_model(self, model): + if self.trainer.accelerator.ddp_plugin: + ref_model = self.trainer.accelerator.reference_model(model) + else: + ref_model = model + return ref_model From 5101696b63d26db7d0a700d5bac6f264e8a765b8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:46:57 +0000 Subject: [PATCH 02/71] Add missing new lines --- pytorch_lightning/accelerators/ddp2_accelerator.py | 1 + pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/dp_accelerator.py | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 5efb0d0a469ce..6043c7b85fda2 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -221,3 +221,4 @@ def sync_tensor(self, def reference_model(self, model): return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 8c82f1a43623b..825a6ccb137d3 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -321,4 +321,4 @@ def sync_tensor(self, return sync_ddp_if_available(tensor, group, reduce_op) def reference_model(self, model): - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 95ba7bfb18762..1dc735443e0a7 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -248,4 +248,4 @@ def sync_tensor(self, return sync_ddp_if_available(tensor, group, reduce_op) def reference_model(self, model): - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 
6522ae38dc1d4..ff571d111764c 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -215,4 +215,4 @@ def sync_tensor(self, return sync_ddp_if_available(tensor, group, reduce_op) def reference_model(self, model): - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index b0a9933c130a3..47866ed4656a8 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -274,4 +274,4 @@ def sync_tensor(self, return sync_ddp_if_available(tensor, group, reduce_op) def reference_model(self, model): - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) \ No newline at end of file + return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index db79aaa09820a..36e4da4df6f22 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -174,4 +174,4 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): scheduler.load_state_dict(state) def reference_model(self, model: LightningDataParallel): - return model.module \ No newline at end of file + return model.module From 078a829834e44579c936bdd465ff7e80d433190b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:48:27 +0000 Subject: [PATCH 03/71] Fix call to accelerator --- pytorch_lightning/trainer/connectors/model_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 0c44e76319e71..72e8ce417ecdd 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -49,7 +49,7 @@ def get_model(self): return self._reference_model(self.trainer.model) def _reference_model(self, model): - if self.trainer.accelerator.ddp_plugin: + if self.trainer.accelerator: ref_model = self.trainer.accelerator.reference_model(model) else: ref_model = model From aeab93c6b307ea26412841eba46d937e21dd7f33 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:52:25 +0000 Subject: [PATCH 04/71] Removed double blank --- pytorch_lightning/accelerators/ddp2_accelerator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 6043c7b85fda2..5efb0d0a469ce 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -221,4 +221,3 @@ def sync_tensor(self, def reference_model(self, model): return self.ddp_plugin.reference_model_in_plugin_wrapper(model) - From 95a1f1985125a5778dd7ea0c3734b761fabb447c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:59:17 +0000 Subject: [PATCH 05/71] Use accelerator backend --- pytorch_lightning/trainer/connectors/model_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 72e8ce417ecdd..efd22fee1d711 100644 --- 
a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -49,8 +49,8 @@ def get_model(self): return self._reference_model(self.trainer.model) def _reference_model(self, model): - if self.trainer.accelerator: - ref_model = self.trainer.accelerator.reference_model(model) + if self.trainer.accelerator_backend: + ref_model = self.trainer.accelerator_backend.reference_model(model) else: ref_model = model return ref_model From 84ccdbf8862805244121f425fe99ab3d91b9621a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 11:59:30 +0000 Subject: [PATCH 06/71] Handle case where wrapper has not been initialized within the plugin --- pytorch_lightning/accelerators/dp_accelerator.py | 5 ++++- pytorch_lightning/plugins/ddp_plugin.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 36e4da4df6f22..57cced6422e44 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -174,4 +174,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): scheduler.load_state_dict(state) def reference_model(self, model: LightningDataParallel): - return model.module + if isinstance(model, LightningDataParallel): + return model.module + else: + return model diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index d4fc97d4e5a45..8da7079a45695 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -109,7 +109,7 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() - def reference_model_in_plugin_wrapper(self, model: LightningDistributedDataParallel): + def reference_model_in_plugin_wrapper(self, model): """ Override to modify returning base :class:`LightningModule` when accessing variable and functions outside of the parallel wrapper. @@ -124,4 +124,7 @@ def reference_model_in_plugin_wrapper(self, model: LightningDistributedDataParal Returns: Reference :class:`LightningModule` within parallel wrapper. """ - return model.module + if isinstance(model, LightningDistributedDataParallel): + return model.module + else: + return model From 0864b1c8934fa4c93be27c75dcc4a8cf604f1502 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 12:36:23 +0000 Subject: [PATCH 07/71] Added basic get model tests, add better typing --- .../accelerators/dp_accelerator.py | 4 +- pytorch_lightning/plugins/ddp_plugin.py | 4 +- tests/trainer/properties/test_get_model.py | 84 +++++++++++++++++++ 3 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 tests/trainer/properties/test_get_model.py diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 57cced6422e44..da7e0c8893f00 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Union import torch from torch import optim +from pytorch_lightning import LightningModule from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.core.step_result import Result @@ -173,7 +175,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): if state is not None: scheduler.load_state_dict(state) - def reference_model(self, model: LightningDataParallel): + def reference_model(self, model: Union[LightningDataParallel, LightningModule]): if isinstance(model, LightningDataParallel): return model.module else: diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 8da7079a45695..7fb02c7aee427 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import torch.distributed as torch_distrib from pytorch_lightning import _logger as log @@ -109,7 +109,7 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() - def reference_model_in_plugin_wrapper(self, model): + def reference_model_in_plugin_wrapper(self, model: Union[LightningDistributedDataParallel, LightningModule]): """ Override to modify returning base :class:`LightningModule` when accessing variable and functions outside of the parallel wrapper. diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py new file mode 100644 index 0000000000000..84726093300e2 --- /dev/null +++ b/tests/trainer/properties/test_get_model.py @@ -0,0 +1,84 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch + +from pytorch_lightning import Trainer +from tests.base.boring_model import BoringModel + + +class TestGetModel(BoringModel): + def on_fit_start(self): + assert self == self.trainer.get_model() + + def on_fit_end(self): + assert self == self.trainer.get_model() + + +def test_get_model(tmpdir): + """ + Tests that trainer.get_model() extracts the model correctly + """ + + model = TestGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + +def test_get_model_ddp_cpu(tmpdir): + """ + Tests that trainer.get_model() extracts the model correctly when using ddp on cpu + """ + + model = TestGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + accelerator='ddp_cpu', + num_processes=2 + ) + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.parametrize("ddp_backend", [None, 'ddp_spawn']) +def test_get_model_ddp_gpu(tmpdir, ddp_backend): + """ + Tests that trainer.get_model() extracts the model correctly when using GPU + ddp accelerators + """ + + model = TestGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + gpus=1, + accelerator=ddp_backend + ) + trainer.fit(model) From 142a2d39694ebf1d371f909d4f92a04c26e0f55b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 12:37:41 +0000 Subject: [PATCH 08/71] Change model name --- tests/trainer/properties/test_get_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 84726093300e2..5a9fdd7c5a40a 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -19,7 +19,7 @@ from tests.base.boring_model import BoringModel -class TestGetModel(BoringModel): +class TrainerGetModel(BoringModel): def on_fit_start(self): assert self == self.trainer.get_model() @@ -32,7 +32,7 @@ def test_get_model(tmpdir): Tests that trainer.get_model() extracts the model correctly """ - model = TestGetModel() + model = TrainerGetModel() limit_train_batches = 2 trainer = Trainer( @@ -49,7 +49,7 @@ def test_get_model_ddp_cpu(tmpdir): Tests that trainer.get_model() extracts the model correctly when using ddp on cpu """ - model = TestGetModel() + model = TrainerGetModel() limit_train_batches = 2 trainer = Trainer( @@ -70,7 +70,7 @@ def test_get_model_ddp_gpu(tmpdir, ddp_backend): Tests that trainer.get_model() extracts the model correctly when using GPU + ddp accelerators """ - model = TestGetModel() + model = TrainerGetModel() limit_train_batches = 2 trainer = Trainer( From 6e548df4b359f475a325a2d79e860ee69be85400 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 13:10:51 +0000 Subject: [PATCH 09/71] Split GPU/DDP test --- tests/trainer/properties/test_get_model.py | 32 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 5a9fdd7c5a40a..8bb103c62edf8 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys import pytest import torch @@ -29,7 +30,7 @@ def on_fit_end(self): def test_get_model(tmpdir): """ - Tests that trainer.get_model() extracts the model correctly + Tests that :meth:`trainer.get_model` extracts the model correctly """ model = TrainerGetModel() @@ -46,7 +47,7 @@ def test_get_model(tmpdir): def test_get_model_ddp_cpu(tmpdir): """ - Tests that trainer.get_model() extracts the model correctly when using ddp on cpu + Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu """ model = TrainerGetModel() @@ -64,10 +65,29 @@ def test_get_model_ddp_cpu(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -@pytest.mark.parametrize("ddp_backend", [None, 'ddp_spawn']) -def test_get_model_ddp_gpu(tmpdir, ddp_backend): +def test_get_model_gpu(tmpdir): """ - Tests that trainer.get_model() extracts the model correctly when using GPU + ddp accelerators + Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + """ + + model = TrainerGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + gpus=1 + ) + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_get_model_ddp_gpu(tmpdir): + """ + Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators """ model = TrainerGetModel() @@ -79,6 +99,6 @@ def test_get_model_ddp_gpu(tmpdir, ddp_backend): limit_val_batches=2, max_epochs=1, gpus=1, - accelerator=ddp_backend + accelerator='ddp_spawn' ) trainer.fit(model) From aebb1a30ff04aaeeaa3c16e39ef0f1ae35bc201d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 13:42:27 +0000 Subject: [PATCH 10/71] Add stronger typing, skip ddp test on windows --- pytorch_lightning/accelerators/accelerator.py | 8 ++++---- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- .../accelerators/ddp_cpu_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/dp_accelerator.py | 2 +- pytorch_lightning/plugins/ddp_plugin.py | 3 ++- tests/trainer/properties/test_get_model.py | 1 + 9 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 5e5bc9a8e126f..e7b26ef646334 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os + from enum import Enum -from typing import Any, Optional, Union, List +from typing import Any, Optional, Union import torch from torch.optim import Optimizer @@ -23,7 +23,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict import torch.distributed as torch_distrib -from pytorch_lightning import _logger as log +from pytorch_lightning import LightningModule if torch.distributed.is_available(): from torch.distributed import ReduceOp @@ -208,7 +208,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return self.ddp_plugin.optimizer_state(optimizer) return optimizer.state_dict() - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: """ Override to modify returning base :class:`LightningModule` when accessing variable and functions if the accelerator has wrapped the model. diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 5efb0d0a469ce..5f1d004122fe0 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -219,5 +219,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 825a6ccb137d3..d60fe6d3d61b1 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -320,5 +320,5 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 1dc735443e0a7..001f9259bd419 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -247,5 +247,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index ff571d111764c..76e212035ce13 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -214,5 +214,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 47866ed4656a8..34b3b4de056e1 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -273,5 +273,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model): + def reference_model(self, model) -> LightningModule: return self.ddp_plugin.reference_model_in_plugin_wrapper(model) diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index da7e0c8893f00..4ecb2dea4bd91 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -175,7 +175,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): if state is not None: scheduler.load_state_dict(state) - def reference_model(self, model: Union[LightningDataParallel, LightningModule]): + def reference_model(self, model: Union[LightningDataParallel, LightningModule]) -> LightningModule: if isinstance(model, LightningDataParallel): return model.module else: diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 7fb02c7aee427..e99d95aafb495 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -109,7 +109,8 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() - def reference_model_in_plugin_wrapper(self, model: Union[LightningDistributedDataParallel, LightningModule]): + def reference_model_in_plugin_wrapper(self, model: Union[ + LightningDistributedDataParallel, LightningModule]) -> LightningModule: """ Override to modify returning base :class:`LightningModule` when accessing variable and functions outside of the parallel wrapper. 
diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 8bb103c62edf8..2ce20e05e64d6 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -45,6 +45,7 @@ def test_get_model(tmpdir): trainer.fit(model) +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_get_model_ddp_cpu(tmpdir): """ Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu From fa0480768d1d2eb85c0c6baa61a2a57ae395d1d9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 14:04:49 +0000 Subject: [PATCH 11/71] Fix import --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index e7b26ef646334..9ea43a9185d06 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -22,8 +22,8 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.core.lightning import LightningModule import torch.distributed as torch_distrib -from pytorch_lightning import LightningModule if torch.distributed.is_available(): from torch.distributed import ReduceOp From 47e562e835eb2a2345615d4efd4266cc0ea61c65 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 14:06:16 +0000 Subject: [PATCH 12/71] Fix import in dp --- pytorch_lightning/accelerators/dp_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4ecb2dea4bd91..b2f355bb36fc2 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -16,7 +16,7 @@ import torch from torch import optim -from pytorch_lightning import LightningModule +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.core.step_result import Result From 15734e9dc95f1cb3f40d8ef42caed4e4fc57d8c7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 14:18:52 +0000 Subject: [PATCH 13/71] Fixed PEP8 definition --- pytorch_lightning/plugins/ddp_plugin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index e99d95aafb495..87abbef67581d 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -109,8 +109,10 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() - def reference_model_in_plugin_wrapper(self, model: Union[ - LightningDistributedDataParallel, LightningModule]) -> LightningModule: + def reference_model_in_plugin_wrapper( + self, + model: Union[LightningDistributedDataParallel, LightningModule] + ) -> LightningModule: """ Override to modify returning base :class:`LightningModule` when accessing variable and functions outside of the parallel wrapper. 
From b44dd7507d4f6813e7e3b20b6143d9ec2619961e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 20 Nov 2020 19:20:37 +0000 Subject: [PATCH 14/71] Add ddp launcher for ddp testing --- tests/trainer/properties/test_get_model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 2ce20e05e64d6..36bed99498e68 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -17,6 +17,7 @@ import torch from pytorch_lightning import Trainer +from tests.backends.launcher import DDPLauncher from tests.base.boring_model import BoringModel @@ -86,7 +87,10 @@ def test_get_model_gpu(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -def test_get_model_ddp_gpu(tmpdir): +@DDPLauncher.run("--accelerator [accelerator]", + max_epochs=["1"], + accelerator=["ddp", "ddp_spawn"]) +def test_get_model_ddp_gpu(tmpdir, args=None): """ Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators """ @@ -100,6 +104,7 @@ def test_get_model_ddp_gpu(tmpdir): limit_val_batches=2, max_epochs=1, gpus=1, - accelerator='ddp_spawn' + accelerator=args.accelerator ) trainer.fit(model) + return 1 From 358f503848abf93dcd159a5219e0933b6b732373 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 11:39:00 +0000 Subject: [PATCH 15/71] Modify accelerator reference model to property, change name to reflect func --- pytorch_lightning/accelerators/accelerator.py | 10 ++++------ pytorch_lightning/accelerators/ddp2_accelerator.py | 5 +++-- pytorch_lightning/accelerators/ddp_accelerator.py | 5 +++-- .../accelerators/ddp_cpu_spawn_accelerator.py | 5 +++-- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 5 +++-- .../accelerators/ddp_spawn_accelerator.py | 5 +++-- pytorch_lightning/accelerators/dp_accelerator.py | 7 ++++--- pytorch_lightning/plugins/ddp_plugin.py | 7 +++---- .../trainer/connectors/model_connector.py | 11 ++++++----- 9 files changed, 32 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9ea43a9185d06..58b8bf2a80ebe 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -208,22 +208,20 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return self.ddp_plugin.optimizer_state(optimizer) return optimizer.state_dict() - def reference_model(self, model) -> LightningModule: + @property + def reference_model(self) -> LightningModule: """ Override to modify returning base :class:`LightningModule` when accessing variable and functions if the accelerator has wrapped the model. Example:: - ref_model = accelerator.reference_model(model) + ref_model = accelerator.reference_model ref_model.training_step(...) - Args: - model: Accelerator model. - Returns: Reference :class:`LightningModule`. 
""" - return model + return self.trainer.model def __getstate__(self): return { diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 5f1d004122fe0..97bfc39051010 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -219,5 +219,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model) -> LightningModule: - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + @property + def reference_model(self) -> LightningModule: + return self.ddp_plugin.module_from_plugin(self.trainer.model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index d60fe6d3d61b1..b86b767306566 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -320,5 +320,6 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model) -> LightningModule: - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + @property + def reference_model(self) -> LightningModule: + return self.ddp_plugin.module_from_plugin(self.trainer.model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 001f9259bd419..b80821df1a303 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -247,5 +247,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model) -> LightningModule: - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + @property + def reference_model(self) -> LightningModule: + return self.ddp_plugin.module_from_plugin(self.trainer.model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 76e212035ce13..7ef2feb79b2b3 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -214,5 +214,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model) -> LightningModule: - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + @property + def reference_model(self) -> LightningModule: + return self.ddp_plugin.module_from_plugin(self.trainer.model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 34b3b4de056e1..ea4b218091374 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -273,5 +273,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - def reference_model(self, model) -> LightningModule: - return self.ddp_plugin.reference_model_in_plugin_wrapper(model) + @property + def reference_model(self) -> LightningModule: + return self.ddp_plugin.module_from_plugin(self.trainer.model) diff --git a/pytorch_lightning/accelerators/dp_accelerator.py 
b/pytorch_lightning/accelerators/dp_accelerator.py index b2f355bb36fc2..53e4d0e3fc1eb 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -175,8 +175,9 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): if state is not None: scheduler.load_state_dict(state) - def reference_model(self, model: Union[LightningDataParallel, LightningModule]) -> LightningModule: + @property + def reference_model(self) -> LightningModule: + model = self.trainer.model if isinstance(model, LightningDataParallel): return model.module - else: - return model + return model diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 87abbef67581d..3f5840744e905 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -109,7 +109,7 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() - def reference_model_in_plugin_wrapper( + def module_from_plugin( self, model: Union[LightningDistributedDataParallel, LightningModule] ) -> LightningModule: @@ -118,7 +118,7 @@ def reference_model_in_plugin_wrapper( when accessing variable and functions outside of the parallel wrapper. Example:: - ref_model = ddp_plugin.reference_model(model) + ref_model = ddp_plugin.module_from_plugin(model) ref_model.training_step(...) Args: @@ -129,5 +129,4 @@ def reference_model_in_plugin_wrapper( """ if isinstance(model, LightningDistributedDataParallel): return model.module - else: - return model + return model diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index efd22fee1d711..eac8b2785d870 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -24,7 +24,7 @@ def __init__(self, trainer): self.trainer = trainer def copy_trainer_model_properties(self, model): - ref_model = self._reference_model(model) + ref_model = self._reference_model automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -46,11 +46,12 @@ def copy_trainer_model_properties(self, model): m.local_rank = self.trainer.local_rank def get_model(self): - return self._reference_model(self.trainer.model) + return self._reference_model - def _reference_model(self, model): + @property + def _reference_model(self): if self.trainer.accelerator_backend: - ref_model = self.trainer.accelerator_backend.reference_model(model) + ref_model = self.trainer.accelerator_backend.model else: - ref_model = model + ref_model = self.trainer.model return ref_model From 977625c289f3c9fc24714ce746c86c89af8f8e55 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 14:54:00 +0000 Subject: [PATCH 16/71] Revert property as this is incorrect.= --- .../trainer/connectors/model_connector.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index eac8b2785d870..efd22fee1d711 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -24,7 +24,7 @@ def __init__(self, trainer): self.trainer = trainer def copy_trainer_model_properties(self, model): - ref_model = 
self._reference_model + ref_model = self._reference_model(model) automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization self.trainer.train_loop.automatic_optimization = automatic_optimization @@ -46,12 +46,11 @@ def copy_trainer_model_properties(self, model): m.local_rank = self.trainer.local_rank def get_model(self): - return self._reference_model + return self._reference_model(self.trainer.model) - @property - def _reference_model(self): + def _reference_model(self, model): if self.trainer.accelerator_backend: - ref_model = self.trainer.accelerator_backend.model + ref_model = self.trainer.accelerator_backend.reference_model(model) else: - ref_model = self.trainer.model + ref_model = model return ref_model From b506a7e46a6ec92994e21f3ee6dae68a198021e8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 15:00:16 +0000 Subject: [PATCH 17/71] Revert across accelerators --- pytorch_lightning/accelerators/accelerator.py | 10 ++++++---- pytorch_lightning/accelerators/ddp2_accelerator.py | 5 ++--- pytorch_lightning/accelerators/ddp_accelerator.py | 5 ++--- .../accelerators/ddp_cpu_spawn_accelerator.py | 5 ++--- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 5 ++--- .../accelerators/ddp_spawn_accelerator.py | 5 ++--- .../trainer/connectors/model_connector.py | 6 ++---- 7 files changed, 18 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 58b8bf2a80ebe..9ea43a9185d06 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -208,20 +208,22 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return self.ddp_plugin.optimizer_state(optimizer) return optimizer.state_dict() - @property - def reference_model(self) -> LightningModule: + def reference_model(self, model) -> LightningModule: """ Override to modify returning base :class:`LightningModule` when accessing variable and functions if the accelerator has wrapped the model. Example:: - ref_model = accelerator.reference_model + ref_model = accelerator.reference_model(model) ref_model.training_step(...) + Args: + model: Accelerator model. + Returns: Reference :class:`LightningModule`. 
""" - return self.trainer.model + return model def __getstate__(self): return { diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 97bfc39051010..0b4c3208a8255 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -219,6 +219,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - @property - def reference_model(self) -> LightningModule: - return self.ddp_plugin.module_from_plugin(self.trainer.model) + def reference_model(self, model) -> LightningModule: + return self.ddp_plugin.module_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index b86b767306566..05b90cff10976 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -320,6 +320,5 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) - @property - def reference_model(self) -> LightningModule: - return self.ddp_plugin.module_from_plugin(self.trainer.model) + def reference_model(self, model) -> LightningModule: + return self.ddp_plugin.module_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index b80821df1a303..58e3f36918a70 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -247,6 +247,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - @property - def reference_model(self) -> LightningModule: - return self.ddp_plugin.module_from_plugin(self.trainer.model) + def reference_model(self, model) -> LightningModule: + return self.ddp_plugin.module_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 7ef2feb79b2b3..4dc1fc743725e 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -214,6 +214,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - @property - def reference_model(self) -> LightningModule: - return self.ddp_plugin.module_from_plugin(self.trainer.model) + def reference_model(self, model) -> LightningModule: + return self.ddp_plugin.module_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index ea4b218091374..b630ebef38ecc 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -273,6 +273,5 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) - @property - def reference_model(self) -> LightningModule: - return self.ddp_plugin.module_from_plugin(self.trainer.model) + def reference_model(self, model) -> LightningModule: + return self.ddp_plugin.module_from_plugin(model) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 
efd22fee1d711..bda33f18a7694 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -50,7 +50,5 @@ def get_model(self): def _reference_model(self, model): if self.trainer.accelerator_backend: - ref_model = self.trainer.accelerator_backend.reference_model(model) - else: - ref_model = model - return ref_model + return self.trainer.accelerator_backend.reference_model(model) + return model From 2e8585f46a08823a2ed568440aa0a410d15bf13d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 19 Nov 2020 10:21:34 +0000 Subject: [PATCH 18/71] Add base code --- pytorch_lightning/overrides/fairscale.py | 47 ++++++++++ .../plugins/sharded_native_amp_plugin.py | 31 +++++++ pytorch_lightning/plugins/sharded_plugin.py | 90 +++++++++++++++++++ 3 files changed, 168 insertions(+) create mode 100644 pytorch_lightning/overrides/fairscale.py create mode 100644 pytorch_lightning/plugins/sharded_native_amp_plugin.py create mode 100644 pytorch_lightning/plugins/sharded_plugin.py diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py new file mode 100644 index 0000000000000..6d23d67965253 --- /dev/null +++ b/pytorch_lightning/overrides/fairscale.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Union + +from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +from fairscale.optim import OSS +from torch import nn + + +class LightningShardedDataParallel(ShardedDataParallel): + def __init__( + self, + base_model: nn.Module, + sharded_optimizer: Union[OSS, List[OSS]], + process_group: Any = None, + broadcast_buffers: bool = True + ): + super().__init__( + base_model=base_model, + sharded_optimizer=sharded_optimizer, + process_group=process_group, + broadcast_buffers=broadcast_buffers + ) + self.module = base_model + + def forward(self, *inputs, **kwargs): + if self.enable_broadcast_buffers: + self.sync_buffers() + + if self.base_model.training: + outputs = self.base_model.training_step(*inputs, **kwargs) + elif self.base_model.testing: + outputs = self.base_model.test_step(*inputs, **kwargs) + else: + outputs = self.base_model.validation_step(*inputs, **kwargs) + return outputs diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py new file mode 100644 index 0000000000000..149b49762d802 --- /dev/null +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import cast + +from fairscale.optim import OSS +from fairscale.optim.grad_scaler import ShardedGradScaler + +from pytorch_lightning.plugins.native_amp import NativeAMPPlugin + + +class ShardedNativeAMPPlugin(NativeAMPPlugin): + @property + def scaler(self): + return ShardedGradScaler() + + def clip_gradients(self, grad_clip_val, model, optimizer): + max_norm = grad_clip_val + norm_type = float(2.0) + optimizer = cast(OSS, optimizer) + optimizer.clip_grad_norm(max_norm, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py new file mode 100644 index 0000000000000..ec99b42f4e1f4 --- /dev/null +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Any, Optional + +from fairscale.optim import OSS + +from pytorch_lightning import LightningModule +from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from pytorch_lightning.utilities import rank_zero_only + + +class DDPShardedPlugin(DDPPlugin): + + def configure_ddp( + self, model: LightningModule, device_ids: List[int] + ): + self._wrap_optimizers(model) + if model.trainer.testing: # Revert to standard DDP if using testing + super().configure_ddp( + model=model, + device_ids=device_ids + ) + else: + model = LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) + return model + + def optimizer_state(self, optimizer: OSS) -> Optional[dict]: + optimizer.consolidate_state_dict() + return self._optim_state_dict(optimizer) + + def on_before_forward(self, args: Any, model: LightningModule): + batch = args[0] + batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu) + args[0] = batch + return args + + @rank_zero_only + def _optim_state_dict(self, optimizer): + """ + Ensure we only return the state dict from the optimizer on rank 0. + Other ranks do not have the complete optimizer state. + Args: + optimizer: OSS Optimizer + Returns: + State dict if rank 0 else None. + """ + return optimizer.state_dict() + + def _wrap_optimizers(self, model): + trainer = model.trainer + if trainer.testing is True: + return + + self._reinit_with_fairscale_oss(trainer) + + def _reinit_with_fairscale_oss(self, trainer): + """ + Re-initialise optimizers to use OSS wrapper. We need to re-initialise due to + the parameters being sharded across distributed processes, each optimizing a partition. 
+ Args: + trainer: trainer object to reinit optimizers. + """ + optimizers = trainer.optimizers + lr_schedulers = trainer.lr_schedulers + for x, optimizer in enumerate(optimizers): + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS( + params=optimizer.param_groups, + optim=optim_class, + **optimizer.defaults + ) + optimizers[x] = zero_optimizer + for scheduler in lr_schedulers: + scheduler = scheduler['scheduler'] + if scheduler.optimizer == optimizer: + scheduler.optimizer = zero_optimizer + del optimizer From 9c34589493743c7dd023505223c009730567fb6f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 20 Nov 2020 14:33:01 +0000 Subject: [PATCH 19/71] Assert availability via imports --- pytorch_lightning/plugins/sharded_plugin.py | 23 ++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index ec99b42f4e1f4..cee59b256bbca 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -13,16 +13,27 @@ # limitations under the License. from typing import List, Any, Optional -from fairscale.optim import OSS - from pytorch_lightning import LightningModule -from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +try: + from fairscale.optim import OSS + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel +except (ModuleNotFoundError, ImportError): + FAIRSCALE_AVAILABLE = False +else: + FAIRSCALE_AVAILABLE = True class DDPShardedPlugin(DDPPlugin): + def __init__(self, **kwargs): + self._check_fairscale() + super().__init__(**kwargs) + def configure_ddp( self, model: LightningModule, device_ids: List[int] ): @@ -46,6 +57,12 @@ def on_before_forward(self, args: Any, model: LightningModule): args[0] = batch return args + def _check_fairscale(self): + if not FAIRSCALE_AVAILABLE: + raise MisconfigurationException( + 'Requested Fairscale Feature, but Fairscale is not installed.' + ) + @rank_zero_only def _optim_state_dict(self, optimizer): """ From 1e429bae58b41eacb9d2ae3c9f5dba9696d41c02 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 21 Nov 2020 11:40:38 +0000 Subject: [PATCH 20/71] Unified API upstream with suggestion to ben --- pytorch_lightning/overrides/fairscale.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 6d23d67965253..9691465103d14 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,28 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Union from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel -from fairscale.optim import OSS -from torch import nn class LightningShardedDataParallel(ShardedDataParallel): - def __init__( - self, - base_model: nn.Module, - sharded_optimizer: Union[OSS, List[OSS]], - process_group: Any = None, - broadcast_buffers: bool = True - ): - super().__init__( - base_model=base_model, - sharded_optimizer=sharded_optimizer, - process_group=process_group, - broadcast_buffers=broadcast_buffers - ) - self.module = base_model def forward(self, *inputs, **kwargs): if self.enable_broadcast_buffers: From 4ae6f0969a582a2fdcabbc03442832eb99f071e8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 13:59:02 +0000 Subject: [PATCH 21/71] Fixed reference --- pytorch_lightning/overrides/fairscale.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 9691465103d14..73d9a6e6fb107 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -21,10 +21,10 @@ def forward(self, *inputs, **kwargs): if self.enable_broadcast_buffers: self.sync_buffers() - if self.base_model.training: - outputs = self.base_model.training_step(*inputs, **kwargs) - elif self.base_model.testing: - outputs = self.base_model.test_step(*inputs, **kwargs) + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) else: - outputs = self.base_model.validation_step(*inputs, **kwargs) + outputs = self.module.validation_step(*inputs, **kwargs) return outputs From 50ed083fc7ee6b8771247de112ed64d6532be841 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 14:42:55 +0000 Subject: [PATCH 22/71] Add module wrapper code --- pytorch_lightning/plugins/sharded_plugin.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index cee59b256bbca..cc7220a9063f2 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Any, Optional - -from pytorch_lightning import LightningModule +from typing import List, Any, Optional, Union +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -105,3 +104,11 @@ def _reinit_with_fairscale_oss(self, trainer): if scheduler.optimizer == optimizer: scheduler.optimizer = zero_optimizer del optimizer + + def module_from_plugin( + self, + model: Union[LightningShardedDataParallel, LightningModule] + ) -> LightningModule: + if isinstance(model, LightningShardedDataParallel): + return model.module + return model From df416f6c78cf81518490553d077bc4b4b3146fb8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 15:06:11 +0000 Subject: [PATCH 23/71] Fix conversion in on_before_forward --- pytorch_lightning/plugins/sharded_plugin.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index cc7220a9063f2..fdb632486c666 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Any, Optional, Union +from typing import List, Optional, Union from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin @@ -50,11 +50,12 @@ def optimizer_state(self, optimizer: OSS) -> Optional[dict]: optimizer.consolidate_state_dict() return self._optim_state_dict(optimizer) - def on_before_forward(self, args: Any, model: LightningModule): + def on_before_forward(self, model: LightningModule, *args): + args = list(args) batch = args[0] batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu) args[0] = batch - return args + return tuple(args) def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: From c590e3a1665153deb9a3a9d46722114a3819c138 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 15:18:50 +0000 Subject: [PATCH 24/71] Ensure we check if we should use sharded amp plugin --- .../trainer/connectors/precision_connector.py | 27 ++++++++++++++----- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 09b3c0d6cbd12..aa14fc9600dc8 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional from pytorch_lightning import _logger as log from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.native_amp import NativeAMPPlugin +from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin +from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn @@ -24,7 +27,7 @@ def __init__(self, trainer): self.trainer = trainer self.backend = None - def on_trainer_init(self, precision, amp_level, amp_backend): + def on_trainer_init(self, precision, amp_level, amp_backend, plugins): # AMP init # These are the only lines needed after v0.8.0 # we wrap the user's forward with autocast and give it back at the end of fit @@ -33,14 +36,14 @@ def on_trainer_init(self, precision, amp_level, amp_backend): self.trainer.scaler = None self.trainer.amp_level = amp_level - self.init_amp(amp_backend) + self.init_amp(amp_backend, plugins) - def init_amp(self, amp_type: str): + def init_amp(self, amp_type: str, plugins: Optional[list]): assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported' self.trainer.amp_backend = None - self._setup_amp_backend(amp_type) + self._setup_amp_backend(amp_type, plugins) - def _setup_amp_backend(self, amp_type: str): + def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): if self.trainer.precision != 16: # no AMP requested, so we can leave now return @@ -54,9 +57,13 @@ def _setup_amp_backend(self, amp_type: str): ' We will attempt to use NVIDIA Apex for this session.') amp_type = 'apex' else: - log.info('Using native 16bit precision.') self.trainer.amp_backend = AMPType.NATIVE - self.backend = NativeAMPPlugin(self.trainer) + if plugins and self._sharded_in_plugins(plugins): + log.info('Using Sharded 16bit plugin.') + self.backend = ShardedNativeAMPPlugin(self.trainer) + else: + log.info('Using native 16bit precision.') + self.backend = NativeAMPPlugin(self.trainer) if amp_type == 'apex': if not APEX_AVAILABLE: @@ -79,3 +86,9 @@ def connect(self, model): self.trainer.optimizers = optimizers return model + + def _sharded_in_plugins(self, plugins): + for plugin in plugins: + if isinstance(plugin, DDPShardedPlugin): + return True + return False diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 43babc0b34088..7ef3d51db2fcb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -399,7 +399,7 @@ def __init__( ) # set precision - self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) + self.precision_connector.on_trainer_init(precision, amp_level, amp_backend, plugins) # last thing are the plugins which override whatever the trainer used by default self.plugin_connector.on_trainer_init(plugins) From 08d37d9cd2abc2b82e8c1e25d697040eef7b1284 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 23 Nov 2020 20:20:19 +0000 Subject: [PATCH 25/71] Fixed name ref --- pytorch_lightning/plugins/sharded_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index fdb632486c666..67902a1a60ce9 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -106,7 +106,7 @@ def _reinit_with_fairscale_oss(self, trainer): scheduler.optimizer = zero_optimizer del optimizer - def module_from_plugin( + def 
get_model_from_plugin( self, model: Union[LightningShardedDataParallel, LightningModule] ) -> LightningModule: From f765364c02081dc3ba8466d346bc82b3fa2bbfa7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 24 Nov 2020 18:05:00 +0000 Subject: [PATCH 26/71] Fixed configure_ddp, removed lr scheduler modification, added unit tests --- pytorch_lightning/plugins/sharded_plugin.py | 10 +- tests/plugins/test_sharded_plugin.py | 197 ++++++++++++++++++++ 2 files changed, 199 insertions(+), 8 deletions(-) create mode 100644 tests/plugins/test_sharded_plugin.py diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 67902a1a60ce9..b7806d2c14d8c 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -38,13 +38,12 @@ def configure_ddp( ): self._wrap_optimizers(model) if model.trainer.testing: # Revert to standard DDP if using testing - super().configure_ddp( + return super().configure_ddp( model=model, device_ids=device_ids ) else: - model = LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) - return model + return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) def optimizer_state(self, optimizer: OSS) -> Optional[dict]: optimizer.consolidate_state_dict() @@ -90,7 +89,6 @@ def _reinit_with_fairscale_oss(self, trainer): trainer: trainer object to reinit optimizers. """ optimizers = trainer.optimizers - lr_schedulers = trainer.lr_schedulers for x, optimizer in enumerate(optimizers): if not isinstance(optimizer, OSS): optim_class = type(optimizer) @@ -100,10 +98,6 @@ def _reinit_with_fairscale_oss(self, trainer): **optimizer.defaults ) optimizers[x] = zero_optimizer - for scheduler in lr_schedulers: - scheduler = scheduler['scheduler'] - if scheduler.optimizer == optimizer: - scheduler.optimizer = zero_optimizer del optimizer def get_model_from_plugin( diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py new file mode 100644 index 0000000000000..a4ff2e3466e99 --- /dev/null +++ b/tests/plugins/test_sharded_plugin.py @@ -0,0 +1,197 @@ +import os +import platform +import time +from distutils.version import LooseVersion +from unittest import mock + +import pytest +import torch +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin +from tests.base.boring_model import BoringModel, RandomDataset + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +def test_ddp_choice_sharded_cpu(tmpdir, ddp_backend, gpus, num_processes): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=gpus, + num_processes=num_processes, + distributed_backend=ddp_backend, + plugins=[DDPShardedPlugin()], + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + 
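A minimal usage sketch (not part of this diff), assuming the Trainer arguments exercised by the tests in this file: the sharded plugin is passed through the plugins argument, and precision=16 is optional, routing through the sharded native AMP plugin added in an earlier commit.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
    from tests.base.boring_model import BoringModel

    model = BoringModel()
    trainer = Trainer(
        gpus=2,
        accelerator='ddp_spawn',
        precision=16,                  # optional: selects ShardedNativeAMPPlugin in the precision connector
        plugins=[DDPShardedPlugin()],  # wraps optimizers in OSS and the model in LightningShardedDataParallel
        fast_dev_run=True,
    )
    trainer.fit(model)
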
+@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +def test_ddp_sharded_plugin_correctness_one_device(): + run_sharded_correctness(accelerator='ddp_cpu') + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +def test_ddp_sharded_plugin_correctness_one_gpu(): + run_sharded_correctness(gpus=1, accelerator='ddp_spawn') + + +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.6.0"), + reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +def test_ddp_sharded_plugin_correctness_amp_one_gpu(): + run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +def test_ddp_sharded_plugin_correctness_multi_gpu(): + run_sharded_correctness(gpus=2, accelerator='ddp_spawn') + + +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.6.0"), + reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): + run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') + + +class TestModel(BoringModel): + """ + Overrides training loader to ensure we enforce the same seed for all DDP processes. + """ + + def train_dataloader(self): + seed_everything(42) + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +def record_ddp_fit_model_stats(trainer, model, gpus): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The LightningModule. + gpus: Number of GPUs in test. + + Returns: + Max Memory if using GPUs, and total wall clock time. + + """ + max_memory = None + if gpus > 0: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + time_start = time.perf_counter() + trainer.fit(model) + + if gpus > 0: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def run_sharded_correctness( + accelerator='ddp_spawn', + gpus=0, + precision=32, + max_percent_speed_regression=0.1): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + Args: + accelerator: Accelerator type for test. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. 
+ max_percent_speed_regression: The maximum speed regression compared to normal DDP training + + """ + + # Train normal DDP + seed_everything(42) + ddp_model = TestModel() + + trainer = Trainer( + limit_val_batches=0.0, + fast_dev_run=False, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_ddp_memory, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + gpus=gpus + ) + + # Reset and train sharded DDP + seed_everything(42) + sharded_model = TestModel() + + trainer = Trainer( + limit_val_batches=0.0, + fast_dev_run=False, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[DDPShardedPlugin()], + ) + + max_sharded_memory, sharded_time = record_ddp_fit_model_stats( + trainer=trainer, + model=sharded_model, + gpus=gpus + ) + + # Assert model parameters are identical after fit + for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): + assert torch.equal(ddp_param, shard_param) + + # Assert speed parity + upper_bound_speed = ddp_time * (1 + max_percent_speed_regression) + assert sharded_time <= upper_bound_speed + + if gpus > 0: + # Assert CUDA memory parity + assert max_sharded_memory <= max_ddp_memory From 6b129216d00394fc00a0c088addd67573b5956ff Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 24 Nov 2020 19:23:55 +0000 Subject: [PATCH 27/71] Add catches around fairscale installation --- pytorch_lightning/plugins/sharded_native_amp_plugin.py | 9 +++++++-- pytorch_lightning/plugins/sharded_plugin.py | 2 +- .../trainer/connectors/precision_connector.py | 5 ++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py index 149b49762d802..f3b7fef55bb99 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -13,8 +13,13 @@ # limitations under the License. from typing import cast -from fairscale.optim import OSS -from fairscale.optim.grad_scaler import ShardedGradScaler +try: + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler +except (ModuleNotFoundError, ImportError): + FAIRSCALE_AMP_AVAILABLE = False +else: + FAIRSCALE_AMP_AVAILABLE = True from pytorch_lightning.plugins.native_amp import NativeAMPPlugin diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index b7806d2c14d8c..3ed30f4ab1006 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -59,7 +59,7 @@ def on_before_forward(self, model: LightningModule, *args): def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: raise MisconfigurationException( - 'Requested Fairscale Feature, but Fairscale is not installed.' + 'Sharded DDP Plugin requires Fairscale to be installed.' 
) @rank_zero_only diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index aa14fc9600dc8..d5c82a396de2a 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -16,9 +16,10 @@ from pytorch_lightning import _logger as log from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.native_amp import NativeAMPPlugin -from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin +from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin, FAIRSCALE_AMP_AVAILABLE from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class PrecisionConnector: @@ -59,6 +60,8 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): else: self.trainer.amp_backend = AMPType.NATIVE if plugins and self._sharded_in_plugins(plugins): + if not FAIRSCALE_AMP_AVAILABLE: + raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') log.info('Using Sharded 16bit plugin.') self.backend = ShardedNativeAMPPlugin(self.trainer) else: From 17f23e5e669229ca189f224533461bf24908fd53 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 24 Nov 2020 20:11:12 +0000 Subject: [PATCH 28/71] Ensure imports are not required explicitly for type casting --- pytorch_lightning/plugins/sharded_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 3ed30f4ab1006..1e801d7decb1d 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -45,7 +45,7 @@ def configure_ddp( else: return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) - def optimizer_state(self, optimizer: OSS) -> Optional[dict]: + def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: optimizer.consolidate_state_dict() return self._optim_state_dict(optimizer) @@ -102,7 +102,7 @@ def _reinit_with_fairscale_oss(self, trainer): def get_model_from_plugin( self, - model: Union[LightningShardedDataParallel, LightningModule] + model: Union['LightningShardedDataParallel', LightningModule] ) -> LightningModule: if isinstance(model, LightningShardedDataParallel): return model.module From a52e6a4a618b13b3b53528225223d91997c02edd Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 24 Nov 2020 21:12:18 +0000 Subject: [PATCH 29/71] Add additional checkpoint tests --- tests/plugins/test_sharded_plugin.py | 64 +++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index a4ff2e3466e99..7c06ca080916f 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,3 +1,4 @@ +import glob import os import platform import time @@ -9,8 +10,8 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin +from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from 
pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE from tests.base.boring_model import BoringModel, RandomDataset @@ -30,6 +31,7 @@ ["ddp_backend", "gpus", "num_processes"], [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_choice_sharded_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): def on_fit_start(self, trainer, pl_module): @@ -52,6 +54,59 @@ def on_fit_start(self, trainer, pl_module): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_cpu', + plugins=[DDPShardedPlugin()], + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1 + ) + + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param, shard_param) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + gpus=2, + accelerator='ddp_spawn', + plugins=[DDPShardedPlugin()], + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1 + ) + + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param, shard_param) + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_device(): run_sharded_correctness(accelerator='ddp_cpu') @@ -59,6 +114,7 @@ def test_ddp_sharded_plugin_correctness_one_device(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_gpu(): run_sharded_correctness(gpus=1, accelerator='ddp_spawn') @@ -69,6 +125,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_one_gpu(): run_sharded_correctness(gpus=1, precision=16, 
accelerator='ddp_spawn') @@ -76,6 +133,8 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu(): run_sharded_correctness(gpus=2, accelerator='ddp_spawn') @@ -86,6 +145,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') From bfe754da1202f706680cde8f04a3c5922ad383ec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 12:55:02 +0000 Subject: [PATCH 30/71] Removed comments, skip test --- pytorch_lightning/plugins/sharded_plugin.py | 16 +- tests/plugins/test_sharded_plugin.py | 179 ++++++++++++++++++-- 2 files changed, 167 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 1e801d7decb1d..5575baef0c1ab 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -37,7 +37,7 @@ def configure_ddp( self, model: LightningModule, device_ids: List[int] ): self._wrap_optimizers(model) - if model.trainer.testing: # Revert to standard DDP if using testing + if model.trainer.testing: # Revert to standard DDP if testing return super().configure_ddp( model=model, device_ids=device_ids @@ -64,14 +64,6 @@ def _check_fairscale(self): @rank_zero_only def _optim_state_dict(self, optimizer): - """ - Ensure we only return the state dict from the optimizer on rank 0. - Other ranks do not have the complete optimizer state. - Args: - optimizer: OSS Optimizer - Returns: - State dict if rank 0 else None. - """ return optimizer.state_dict() def _wrap_optimizers(self, model): @@ -82,12 +74,6 @@ def _wrap_optimizers(self, model): self._reinit_with_fairscale_oss(trainer) def _reinit_with_fairscale_oss(self, trainer): - """ - Re-initialise optimizers to use OSS wrapper. We need to re-initialise due to - the parameters being sharded across distributed processes, each optimizing a partition. - Args: - trainer: trainer object to reinit optimizers. 
- """ optimizers = trainer.optimizers for x, optimizer in enumerate(optimizers): if not isinstance(optimizer, OSS): diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 7c06ca080916f..9faa818e12119 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -11,6 +11,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE from tests.base.boring_model import BoringModel, RandomDataset @@ -32,7 +33,11 @@ [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_choice_sharded_cpu(tmpdir, ddp_backend, gpus, num_processes): +def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes): + """ + Test to ensure that plugin is correctly chosen + """ + class CB(Callback): def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin) @@ -52,18 +57,62 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes): + """ + Test to ensure that plugin native amp plugin is correctly chosen when using sharded + """ + + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin) + assert isinstance(trainer.precision_connector.backend, ShardedNativeAMPPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=gpus, + precision=16, + num_processes=num_processes, + distributed_backend=ddp_backend, + plugins=[DDPShardedPlugin()], + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly + """ model = BoringModel() trainer = Trainer( callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_cpu', plugins=[DDPShardedPlugin()], - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1 + fast_dev_run=True, ) trainer.fit(model) @@ -82,15 +131,16 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs + """ model = BoringModel() trainer = Trainer( callbacks=[ModelCheckpoint(dirpath=tmpdir, 
save_last=True)], gpus=2, accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1 + fast_dev_run=True, ) trainer.fit(model) @@ -104,11 +154,117 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): assert torch.equal(ddp_param, shard_param) +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): + """ + Test to ensure that resuming from checkpoint works + """ + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_cpu', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_cpu', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + return 1 + + +@pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): + """ + Test to ensure that resuming from checkpoint works when downsizing number of GPUS + """ + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_spawn', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + gpus=2, + ) + + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_spawn', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + gpus=1, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + return 1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): + """ + Test to ensure that resuming from checkpoint works when going from GPUs- > CPU + """ + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + accelerator='ddp_spawn', + plugins=[DDPShardedPlugin()], + gpus=1, + fast_dev_run=True + ) + + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + plugins=[DDPShardedPlugin()], + accelerator='ddp_cpu', + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path + ) + + trainer.fit(model) + return 1 + + @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def 
test_ddp_sharded_plugin_correctness_one_device(): - run_sharded_correctness(accelerator='ddp_cpu') + # Allow slightly slower speed due to one CPU machine doing rigorously memory saving calls + run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_regression=0.3) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -201,7 +357,8 @@ def run_sharded_correctness( accelerator: Accelerator type for test. gpus: Number of GPUS to enable. precision: Whether to use AMP or normal FP32 training. - max_percent_speed_regression: The maximum speed regression compared to normal DDP training + max_percent_speed_regression: The maximum speed regression compared to normal DDP training. + This is more a safety net for CI which can vary in speed. """ @@ -210,8 +367,6 @@ def run_sharded_correctness( ddp_model = TestModel() trainer = Trainer( - limit_val_batches=0.0, - fast_dev_run=False, max_epochs=1, gpus=gpus, precision=precision, @@ -229,8 +384,6 @@ def run_sharded_correctness( sharded_model = TestModel() trainer = Trainer( - limit_val_batches=0.0, - fast_dev_run=False, max_epochs=1, gpus=gpus, precision=precision, From 99326088bbd58b1fb33a0368e9ecfc8656645b73 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 15:38:54 +0000 Subject: [PATCH 31/71] Add additional test cases --- tests/plugins/test_sharded_plugin.py | 127 ++++++++++++++++++++++++++- 1 file changed, 123 insertions(+), 4 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 9faa818e12119..b95eec2bd0798 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -154,6 +154,35 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): assert torch.equal(ddp_param, shard_param) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_finetune(tmpdir): + """ + Test to ensure that we can save and restart training (simulate fine-tuning) + """ + model = BoringModel() + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + gpus=2, + accelerator='ddp_spawn', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + ) + trainer.fit(model) + + checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + trainer = Trainer( + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], + fast_dev_run=True, + ) + trainer.fit(saved_model) + return 1 + + @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -306,6 +335,42 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): + """ + Ensures same results using multiple optimizers across multiple GPUs + """ + run_sharded_correctness( + gpus=2, + 
accelerator='ddp_spawn', + model_cls=TestMultipleOptimizersModel, + max_percent_speed_regression=0.5 # multiple optimizers sharded across only two GPUs is costly. + ) + + +@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): + """ + Ensures using multiple optimizers across multiple GPUs with manual optimization + """ + run_sharded_correctness( + gpus=2, + accelerator='ddp_spawn', + model_cls=TestManualModel, + max_percent_speed_regression=0.5 # multiple optimizers sharded across only two GPUs is costly. + ) + + class TestModel(BoringModel): """ Overrides training loader to ensure we enforce the same seed for all DDP processes. @@ -316,6 +381,56 @@ def train_dataloader(self): return torch.utils.data.DataLoader(RandomDataset(32, 64)) +class TestManualModel(TestModel): + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + (opt_a, opt_b) = self.optimizers() + loss_1 = self.step(batch) + + self.manual_backward(loss_1, opt_a) + self.manual_optimizer_step(opt_a) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + self.manual_optimizer_step(opt_b) + + assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + @property + def automatic_optimization(self) -> bool: + return False + + +class TestMultipleOptimizersModel(TestModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + def record_ddp_fit_model_stats(trainer, model, gpus): """ Helper to calculate wall clock time for fit + max allocated memory. @@ -349,7 +464,8 @@ def run_sharded_correctness( accelerator='ddp_spawn', gpus=0, precision=32, - max_percent_speed_regression=0.1): + max_percent_speed_regression=0.2, + model_cls=TestModel): """ Ensures that the trained model is identical to the standard DDP implementation. Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. @@ -358,15 +474,17 @@ def run_sharded_correctness( gpus: Number of GPUS to enable. precision: Whether to use AMP or normal FP32 training. 
max_percent_speed_regression: The maximum speed regression compared to normal DDP training. - This is more a safety net for CI which can vary in speed. + This is more a safety net for variability in CI which can vary in speed, not for benchmarking. + model_cls: Model class to use for test. """ # Train normal DDP seed_everything(42) - ddp_model = TestModel() + ddp_model = model_cls() trainer = Trainer( + fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, @@ -381,9 +499,10 @@ def run_sharded_correctness( # Reset and train sharded DDP seed_everything(42) - sharded_model = TestModel() + sharded_model = model_cls() trainer = Trainer( + fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, From d8224687b8a54d466e18fa83c6e835e16d2cb7a3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 16:16:57 +0000 Subject: [PATCH 32/71] Move to percentage diff, increase diff --- tests/plugins/test_sharded_plugin.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index b95eec2bd0798..eb6bd56affaee 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -293,7 +293,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_device(): # Allow slightly slower speed due to one CPU machine doing rigorously memory saving calls - run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_regression=0.3) + run_sharded_correctness(accelerator='ddp_cpu') @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -348,7 +348,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): gpus=2, accelerator='ddp_spawn', model_cls=TestMultipleOptimizersModel, - max_percent_speed_regression=0.5 # multiple optimizers sharded across only two GPUs is costly. + max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -367,7 +367,6 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): gpus=2, accelerator='ddp_spawn', model_cls=TestManualModel, - max_percent_speed_regression=0.5 # multiple optimizers sharded across only two GPUs is costly. ) @@ -445,16 +444,18 @@ def record_ddp_fit_model_stats(trainer, model, gpus): """ max_memory = None + + time_start = time.perf_counter() if gpus > 0: torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() - time_start = time.perf_counter() trainer.fit(model) if gpus > 0: torch.cuda.synchronize() max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + total_time = time.perf_counter() - time_start return max_memory, total_time @@ -464,7 +465,7 @@ def run_sharded_correctness( accelerator='ddp_spawn', gpus=0, precision=32, - max_percent_speed_regression=0.2, + max_percent_speed_diff=0.25, model_cls=TestModel): """ Ensures that the trained model is identical to the standard DDP implementation. @@ -473,7 +474,7 @@ def run_sharded_correctness( accelerator: Accelerator type for test. gpus: Number of GPUS to enable. precision: Whether to use AMP or normal FP32 training. - max_percent_speed_regression: The maximum speed regression compared to normal DDP training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. This is more a safety net for variability in CI which can vary in speed, not for benchmarking. 
model_cls: Model class to use for test. @@ -518,12 +519,15 @@ def run_sharded_correctness( # Assert model parameters are identical after fit for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): - assert torch.equal(ddp_param, shard_param) + assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' - # Assert speed parity - upper_bound_speed = ddp_time * (1 + max_percent_speed_regression) - assert sharded_time <= upper_bound_speed + # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold + percent_diff = (abs(sharded_time - ddp_time) / sharded_time) + assert percent_diff <= max_percent_speed_diff, \ + f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' if gpus > 0: # Assert CUDA memory parity - assert max_sharded_memory <= max_ddp_memory + assert max_sharded_memory <= max_ddp_memory, \ + f'Sharded plugin used too much memory compared to DDP,' \ + f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}' From a311ee17aba4266a6e63f2dc85cb316e7a9ecd47 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 18:16:36 +0000 Subject: [PATCH 33/71] Add fairscale requirement as zip before release --- requirements/extra.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/extra.txt b/requirements/extra.txt index be21317a1d826..c5b20477a8ad0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,3 +8,4 @@ scikit-learn>=0.22.2 torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 onnxruntime>=1.3.0 +https://github.com/facebookresearch/fairscale/archive/2b121242baf4bc999bfe6c79c75faab4746195f0.zip \ No newline at end of file From ba312473f85f2269c4491a83cb9d0c6edef24e0b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 19:40:58 +0000 Subject: [PATCH 34/71] Add check to ensure 1.6 --- pytorch_lightning/plugins/sharded_plugin.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 5575baef0c1ab..d6d1186992af2 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,16 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from distutils.version import LooseVersion from typing import List, Optional, Union +import torch + from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException try: - from fairscale.optim import OSS - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + IS_TORCH_AT_LEAST_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0") + if IS_TORCH_AT_LEAST_1_6: + from fairscale.optim import OSS + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel except (ModuleNotFoundError, ImportError): FAIRSCALE_AVAILABLE = False else: @@ -59,7 +64,7 @@ def on_before_forward(self, model: LightningModule, *args): def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: raise MisconfigurationException( - 'Sharded DDP Plugin requires Fairscale to be installed.' 
+ 'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.' ) @rank_zero_only From 586f6c62ee578afcbab05e85dd99f7d07e671f2d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 20:16:20 +0000 Subject: [PATCH 35/71] Attempt try catch to prevent errors --- pytorch_lightning/overrides/fairscale.py | 30 ++++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 73d9a6e6fb107..4ab1933e4ae9c 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,20 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +try: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +except (ModuleNotFoundError, ImportError): + FAIRSCALE_SHARDED_AVAILABLE = False +else: + FAIRSCALE_SHARDED_AVAILABLE = True -from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + class LightningShardedDataParallel(ShardedDataParallel): -class LightningShardedDataParallel(ShardedDataParallel): + def forward(self, *inputs, **kwargs): + if self.enable_broadcast_buffers: + self.sync_buffers() - def forward(self, *inputs, **kwargs): - if self.enable_broadcast_buffers: - self.sync_buffers() - - if self.module.training: - outputs = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: - outputs = self.module.test_step(*inputs, **kwargs) - else: - outputs = self.module.validation_step(*inputs, **kwargs) - return outputs + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) + else: + outputs = self.module.validation_step(*inputs, **kwargs) + return outputs From 888b12bbc955c51ff73e79eacf9cd16a5dfdbc4d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 20:20:07 +0000 Subject: [PATCH 36/71] Add additional else check --- pytorch_lightning/plugins/sharded_plugin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index d6d1186992af2..45554e8617672 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -26,6 +26,8 @@ if IS_TORCH_AT_LEAST_1_6: from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + else: + FAIRSCALE_AVAILABLE = False # Requires AMP support except (ModuleNotFoundError, ImportError): FAIRSCALE_AVAILABLE = False else: From cf7a7f7b8d2a81df9b7a3f1f873825cf853f8b5a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 20:20:38 +0000 Subject: [PATCH 37/71] Add additional else check --- pytorch_lightning/plugins/sharded_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 45554e8617672..aacde746d1554 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -26,12 +26,12 @@ if IS_TORCH_AT_LEAST_1_6: from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + + FAIRSCALE_AVAILABLE = True else: FAIRSCALE_AVAILABLE = False # Requires AMP support except (ModuleNotFoundError, ImportError): 
FAIRSCALE_AVAILABLE = False -else: - FAIRSCALE_AVAILABLE = True class DDPShardedPlugin(DDPPlugin): From 9215908fedea24b4c6ef8be26cb10e8950fd2634 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 20:38:04 +0000 Subject: [PATCH 38/71] Removed line, dont abs --- pytorch_lightning/overrides/fairscale.py | 1 - tests/plugins/test_sharded_plugin.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 4ab1933e4ae9c..8d105fe896211 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -18,7 +18,6 @@ else: FAIRSCALE_SHARDED_AVAILABLE = True - class LightningShardedDataParallel(ShardedDataParallel): def forward(self, *inputs, **kwargs): diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index eb6bd56affaee..8a0156d93c1fd 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -522,7 +522,8 @@ def run_sharded_correctness( assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold - percent_diff = (abs(sharded_time - ddp_time) / sharded_time) + percent_diff = (sharded_time - ddp_time) / sharded_time + assert percent_diff <= max_percent_speed_diff, \ f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' From 6b93987b31d05f121c213aa6f28702e17c8af059 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 21:01:42 +0000 Subject: [PATCH 39/71] Revert "Add check to ensure 1.6" This reverts commit ba312473 --- pytorch_lightning/plugins/sharded_plugin.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index aacde746d1554..28d28c0a04d0e 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,25 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion from typing import List, Optional, Union -import torch - from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException try: - IS_TORCH_AT_LEAST_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0") - if IS_TORCH_AT_LEAST_1_6: - from fairscale.optim import OSS - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel - - FAIRSCALE_AVAILABLE = True - else: - FAIRSCALE_AVAILABLE = False # Requires AMP support + from fairscale.optim import OSS + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel except (ModuleNotFoundError, ImportError): FAIRSCALE_AVAILABLE = False @@ -66,7 +57,7 @@ def on_before_forward(self, model: LightningModule, *args): def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: raise MisconfigurationException( - 'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.' + 'Sharded DDP Plugin requires Fairscale to be installed.' 
) @rank_zero_only From 321e63ae8b07070b9f50da743cecd432683f0608 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 21:17:21 +0000 Subject: [PATCH 40/71] Fixes to import --- pytorch_lightning/overrides/fairscale.py | 14 +++++++++----- pytorch_lightning/plugins/sharded_plugin.py | 10 ++++++---- tests/plugins/test_sharded_plugin.py | 9 +++------ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 8d105fe896211..05ec35c1cbfd1 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: +from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE + +if _module_available('fairscale.nn.data_parallel.sharded_ddp') and NATIVE_AMP_AVALAIBLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel -except (ModuleNotFoundError, ImportError): - FAIRSCALE_SHARDED_AVAILABLE = False -else: - FAIRSCALE_SHARDED_AVAILABLE = True + class LightningShardedDataParallel(ShardedDataParallel): @@ -31,3 +30,8 @@ def forward(self, *inputs, **kwargs): else: outputs = self.module.validation_step(*inputs, **kwargs) return outputs + + + FAIRSCALE_SHARDED_AVAILABLE = True +else: + FAIRSCALE_SHARDED_AVAILABLE = False diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 28d28c0a04d0e..f1a4946226747 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -15,13 +15,15 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_only, _module_available, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -try: +if _module_available('fairscale.optim') and NATIVE_AMP_AVALAIBLE: from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel -except (ModuleNotFoundError, ImportError): + + FAIRSCALE_AVAILABLE = True +else: FAIRSCALE_AVAILABLE = False @@ -57,7 +59,7 @@ def on_before_forward(self, model: LightningModule, *args): def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: raise MisconfigurationException( - 'Sharded DDP Plugin requires Fairscale to be installed.' + 'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.' 
) @rank_zero_only diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 8a0156d93c1fd..1b06cf04cb17c 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -13,6 +13,7 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE from tests.base.boring_model import BoringModel, RandomDataset @@ -304,9 +305,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu(): run_sharded_correctness(gpus=1, accelerator='ddp_spawn') -@pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.6.0"), - reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @@ -324,9 +323,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): run_sharded_correctness(gpus=2, accelerator='ddp_spawn') -@pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.6.0"), - reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From 28afc46be29f6816139858fdab566725b9290a10 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 21:18:11 +0000 Subject: [PATCH 41/71] Removed lines --- pytorch_lightning/overrides/fairscale.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 05ec35c1cbfd1..17766673340b6 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -16,7 +16,6 @@ if _module_available('fairscale.nn.data_parallel.sharded_ddp') and NATIVE_AMP_AVALAIBLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(ShardedDataParallel): def forward(self, *inputs, **kwargs): @@ -31,7 +30,6 @@ def forward(self, *inputs, **kwargs): outputs = self.module.validation_step(*inputs, **kwargs) return outputs - FAIRSCALE_SHARDED_AVAILABLE = True else: FAIRSCALE_SHARDED_AVAILABLE = False From 6c8715e7395ed4c4f810a310448ea3b7ae2f46fd Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 21:31:36 +0000 Subject: [PATCH 42/71] Swap ordering of imports --- pytorch_lightning/overrides/fairscale.py | 2 +- pytorch_lightning/plugins/sharded_plugin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 17766673340b6..446c86e29b001 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -13,7 +13,7 @@ # limitations under the License. 
from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE -if _module_available('fairscale.nn.data_parallel.sharded_ddp') and NATIVE_AMP_AVALAIBLE: +if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.nn.data_parallel.sharded_ddp'): from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(ShardedDataParallel): diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index f1a4946226747..792e3a07dbbad 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities import rank_zero_only, _module_available, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _module_available('fairscale.optim') and NATIVE_AMP_AVALAIBLE: +if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.optim'): from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel From 5f2a64b778aaf7a56700fd417d573c5b8aa72a43 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 23:23:08 +0000 Subject: [PATCH 43/71] Add explicit checkpoints for tests --- tests/plugins/test_sharded_plugin.py | 34 +++++++++++----------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 1b06cf04cb17c..14cd7e26bdd0f 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,8 +1,6 @@ -import glob import os import platform import time -from distutils.version import LooseVersion from unittest import mock import pytest @@ -10,7 +8,7 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE @@ -110,7 +108,6 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): """ model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_cpu', plugins=[DDPShardedPlugin()], fast_dev_run=True, @@ -118,8 +115,8 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] - + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) saved_model = BoringModel.load_from_checkpoint(checkpoint_path) # Assert model parameters are identical after loading @@ -137,7 +134,6 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): """ model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], gpus=2, accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], @@ -146,8 +142,8 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] - + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) saved_model = BoringModel.load_from_checkpoint(checkpoint_path) # Assert model parameters are identical after loading @@ -165,7 +161,6 @@ def test_ddp_sharded_plugin_finetune(tmpdir): """ 
model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], gpus=2, accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], @@ -173,11 +168,11 @@ def test_ddp_sharded_plugin_finetune(tmpdir): ) trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) saved_model = BoringModel.load_from_checkpoint(checkpoint_path) trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], fast_dev_run=True, ) trainer.fit(saved_model) @@ -193,7 +188,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): """ model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_cpu', plugins=[DDPShardedPlugin()], fast_dev_run=True, @@ -201,12 +195,12 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_cpu', plugins=[DDPShardedPlugin()], fast_dev_run=True, @@ -228,7 +222,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): """ model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], fast_dev_run=True, @@ -237,12 +230,12 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], fast_dev_run=True, @@ -264,7 +257,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): """ model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], accelerator='ddp_spawn', plugins=[DDPShardedPlugin()], gpus=1, @@ -273,12 +265,12 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): trainer.fit(model) - checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0] + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) model = BoringModel() trainer = Trainer( - callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)], plugins=[DDPShardedPlugin()], accelerator='ddp_cpu', fast_dev_run=True, From 79527672cb8b254c01a03e2215b151c28a8d6b96 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 10:13:27 +0000 Subject: [PATCH 44/71] Remove amp check as guard now upstream --- pytorch_lightning/overrides/fairscale.py | 4 ++-- pytorch_lightning/plugins/sharded_native_amp_plugin.py | 10 ++++++---- pytorch_lightning/plugins/sharded_plugin.py | 6 +++--- requirements/extra.txt | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 446c86e29b001..3e84191b4cb24 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import _module_available -if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.nn.data_parallel.sharded_ddp'): +if _module_available('fairscale.nn.data_parallel.sharded_ddp'): from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(ShardedDataParallel): diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py index f3b7fef55bb99..d244f68dcca4c 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -13,13 +13,15 @@ # limitations under the License. from typing import cast -try: +from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE + +if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.optim'): from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler -except (ModuleNotFoundError, ImportError): - FAIRSCALE_AMP_AVAILABLE = False -else: + FAIRSCALE_AMP_AVAILABLE = True +else: + FAIRSCALE_AMP_AVAILABLE = False from pytorch_lightning.plugins.native_amp import NativeAMPPlugin diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 792e3a07dbbad..f45ce7fdddcec 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -15,10 +15,10 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import rank_zero_only, _module_available, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import rank_zero_only, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.optim'): +if _module_available('fairscale.optim'): from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel @@ -59,7 +59,7 @@ def on_before_forward(self, model: LightningModule, *args): def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: raise MisconfigurationException( - 'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.' + 'Sharded DDP Plugin requires Fairscale to be installed.' 
) @rank_zero_only diff --git a/requirements/extra.txt b/requirements/extra.txt index c5b20477a8ad0..e1f9bc18f2241 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ scikit-learn>=0.22.2 torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 onnxruntime>=1.3.0 -https://github.com/facebookresearch/fairscale/archive/2b121242baf4bc999bfe6c79c75faab4746195f0.zip \ No newline at end of file +https://github.com/facebookresearch/fairscale/archive/8e85ce8c93569017521d92ceb78dba2c57c955a0.zip # TODO temporary fix till release version \ No newline at end of file From 80e5329c1f7056ae5e293436e031cf1c41e0aa96 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 10:24:25 +0000 Subject: [PATCH 45/71] Add check for windows to plugin --- pytorch_lightning/plugins/sharded_plugin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index f45ce7fdddcec..35a71faa893bc 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform from typing import List, Optional, Union from pytorch_lightning.core.lightning import LightningModule @@ -18,7 +19,7 @@ from pytorch_lightning.utilities import rank_zero_only, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _module_available('fairscale.optim'): +if _module_available('fairscale.optim') and not platform.system() == "Windows": # Distributed not supported on windows from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel From fa5934492f7f2584416edcfcb5ad91ff29f313c8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 10:50:32 +0000 Subject: [PATCH 46/71] Fixes --- tests/plugins/test_sharded_plugin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 14cd7e26bdd0f..263919358488b 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -73,6 +73,7 @@ def on_fit_start(self, trainer, pl_module): [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded @@ -285,8 +286,8 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_device(): - # Allow slightly slower speed due to one CPU machine doing rigorously memory saving calls - run_sharded_correctness(accelerator='ddp_cpu') + # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls + run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") From 
8f9763166dbf7dcb1213111351fae09c5ea33ad7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 10:59:23 +0000 Subject: [PATCH 47/71] Add check to fairscale override --- pytorch_lightning/overrides/fairscale.py | 4 +++- pytorch_lightning/plugins/sharded_plugin.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 3e84191b4cb24..b53b92595dda9 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform + from pytorch_lightning.utilities import _module_available -if _module_available('fairscale.nn.data_parallel.sharded_ddp'): +if _module_available('fairscale.nn.data_parallel.sharded_ddp') and platform.system() != "Windows": from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(ShardedDataParallel): diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 35a71faa893bc..dcd9b6afbd286 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -19,7 +19,7 @@ from pytorch_lightning.utilities import rank_zero_only, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _module_available('fairscale.optim') and not platform.system() == "Windows": # Distributed not supported on windows +if _module_available('fairscale.optim') and platform.system() != "Windows": # Distributed not supported on windows from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel From 091c23639275d8c0888d26d7a995b680428168c9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 11:07:11 +0000 Subject: [PATCH 48/71] Ensure we do windows check first --- pytorch_lightning/overrides/fairscale.py | 2 +- pytorch_lightning/plugins/sharded_plugin.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index b53b92595dda9..e6bc5b3e255ea 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -15,7 +15,7 @@ from pytorch_lightning.utilities import _module_available -if _module_available('fairscale.nn.data_parallel.sharded_ddp') and platform.system() != "Windows": +if platform.system() != "Windows" and _module_available('fairscale.nn.data_parallel.sharded_ddp'): from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(ShardedDataParallel): diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index dcd9b6afbd286..2eb58093dc6ee 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -19,7 +19,7 @@ from pytorch_lightning.utilities import rank_zero_only, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _module_available('fairscale.optim') and platform.system() != "Windows": # Distributed not supported on windows +if platform.system() != "Windows" and _module_available('fairscale.optim'): # Distributed not supported on windows from fairscale.optim import OSS from 
pytorch_lightning.overrides.fairscale import LightningShardedDataParallel @@ -43,8 +43,7 @@ def configure_ddp( model=model, device_ids=device_ids ) - else: - return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) + return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: optimizer.consolidate_state_dict() From ff34a8fed918a588b3c14f98e420f8d0aa15cbc4 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 26 Nov 2020 16:37:22 +0000 Subject: [PATCH 49/71] Update tests/plugins/test_sharded_plugin.py Co-authored-by: Jirka Borovec --- tests/plugins/test_sharded_plugin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 263919358488b..76db75a21f1c8 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -458,8 +458,9 @@ def run_sharded_correctness( max_percent_speed_diff=0.25, model_cls=TestModel): """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + Args: accelerator: Accelerator type for test. gpus: Number of GPUS to enable. From 47c121ef1adab9aca9571538eb66a2945289e35a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 16:44:45 +0000 Subject: [PATCH 50/71] Addressed code review points --- benchmarks/test_sharded_correctness.py | 255 ++++++++++++++++++ pytorch_lightning/overrides/fairscale.py | 9 +- .../plugins/sharded_native_amp_plugin.py | 8 +- pytorch_lightning/plugins/sharded_plugin.py | 9 +- .../trainer/connectors/precision_connector.py | 7 +- pytorch_lightning/utilities/__init__.py | 2 + tests/plugins/test_sharded_plugin.py | 249 +---------------- 7 files changed, 270 insertions(+), 269 deletions(-) create mode 100644 benchmarks/test_sharded_correctness.py diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_correctness.py new file mode 100644 index 0000000000000..0b986f4199d17 --- /dev/null +++ b/benchmarks/test_sharded_correctness.py @@ -0,0 +1,255 @@ +import os +import platform +import time +from unittest import mock + +import pytest +import torch +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from tests.base.boring_model import BoringModel, RandomDataset + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_one_device(): + # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls + run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, 
reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_one_gpu(): + run_sharded_correctness(gpus=1, accelerator='ddp_spawn') + + +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_amp_one_gpu(): + run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_multi_gpu(): + run_sharded_correctness(gpus=2, accelerator='ddp_spawn') + + +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): + run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): + """ + Ensures same results using multiple optimizers across multiple GPUs + """ + run_sharded_correctness( + gpus=2, + accelerator='ddp_spawn', + model_cls=SeedTrainLoaderMultipleOptimizersModel, + max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers + ) + + +@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): + """ + Ensures using multiple optimizers across multiple GPUs with manual optimization + """ + run_sharded_correctness( + gpus=2, + accelerator='ddp_spawn', + model_cls=SeedTrainLoaderManualModel, + ) + + +class SeedTrainLoaderModel(BoringModel): + """ + Overrides training loader to ensure we enforce the same seed for all DDP processes. 
+ """ + + def train_dataloader(self): + seed_everything(42) + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +class SeedTrainLoaderManualModel(SeedTrainLoaderModel): + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + (opt_a, opt_b) = self.optimizers() + loss_1 = self.step(batch) + + self.manual_backward(loss_1, opt_a) + self.manual_optimizer_step(opt_a) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + self.manual_optimizer_step(opt_b) + + assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + @property + def automatic_optimization(self) -> bool: + return False + + +class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + +def record_ddp_fit_model_stats(trainer, model, gpus): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The LightningModule. + gpus: Number of GPUs in test. + + Returns: + Max Memory if using GPUs, and total wall clock time. + + """ + max_memory = None + + time_start = time.perf_counter() + if gpus > 0: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + trainer.fit(model) + + if gpus > 0: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def run_sharded_correctness( + accelerator='ddp_spawn', + gpus=0, + precision=32, + max_percent_speed_diff=0.25, + model_cls=SeedTrainLoaderModel): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + + Args: + accelerator: Accelerator type for test. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. + This is more a safety net for variability in CI which can vary in speed, not for benchmarking. + model_cls: Model class to use for test. 
+ + """ + + # Train normal DDP + seed_everything(42) + ddp_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_ddp_memory, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + gpus=gpus + ) + + # Reset and train sharded DDP + seed_everything(42) + sharded_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[DDPShardedPlugin()], + ) + + max_sharded_memory, sharded_time = record_ddp_fit_model_stats( + trainer=trainer, + model=sharded_model, + gpus=gpus + ) + + # Assert model parameters are identical after fit + for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): + assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' + + # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold + percent_diff = (sharded_time - ddp_time) / sharded_time + + assert percent_diff <= max_percent_speed_diff, \ + f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' + + if gpus > 0: + # Assert CUDA memory parity + assert max_sharded_memory <= max_ddp_memory, \ + f'Sharded plugin used too much memory compared to DDP,' \ + f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}' diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index e6bc5b3e255ea..df49199e2a504 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform +from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE -from pytorch_lightning.utilities import _module_available - -if platform.system() != "Windows" and _module_available('fairscale.nn.data_parallel.sharded_ddp'): +if FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel class LightningShardedDataParallel(ShardedDataParallel): @@ -32,6 +30,5 @@ def forward(self, *inputs, **kwargs): outputs = self.module.validation_step(*inputs, **kwargs) return outputs - FAIRSCALE_SHARDED_AVAILABLE = True else: - FAIRSCALE_SHARDED_AVAILABLE = False + LightningShardedDataParallel = None diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py index d244f68dcca4c..67aac2c3a2f82 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -13,16 +13,12 @@ # limitations under the License. 
from typing import cast -from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, FAIRSCALE_AVAILABLE -if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.optim'): +if NATIVE_AMP_AVALAIBLE and FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler - FAIRSCALE_AMP_AVAILABLE = True -else: - FAIRSCALE_AMP_AVAILABLE = False - from pytorch_lightning.plugins.native_amp import NativeAMPPlugin diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 2eb58093dc6ee..ba7b751fdedc4 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,22 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import List, Optional, Union from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import rank_zero_only, _module_available +from pytorch_lightning.utilities import rank_zero_only, FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if platform.system() != "Windows" and _module_available('fairscale.optim'): # Distributed not supported on windows +if FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel - FAIRSCALE_AVAILABLE = True -else: - FAIRSCALE_AVAILABLE = False - class DDPShardedPlugin(DDPPlugin): diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index d5c82a396de2a..5961ac03b4493 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -16,9 +16,10 @@ from pytorch_lightning import _logger as log from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.native_amp import NativeAMPPlugin -from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin, FAIRSCALE_AMP_AVAILABLE +from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin -from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn, \ + FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -60,7 +61,7 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): else: self.trainer.amp_backend = AMPType.NATIVE if plugins and self._sharded_in_plugins(plugins): - if not FAIRSCALE_AMP_AVAILABLE: + if not FAIRSCALE_AVAILABLE: raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') log.info('Using Sharded 16bit plugin.') self.backend = ShardedNativeAMPPlugin(self.trainer) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 4d91ec8c4f5e2..83a9144efad84 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
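The hunk below adds the FAIRSCALE_AVAILABLE flag next to _module_available, a helper these patches rely on but never show; a minimal stand-in consistent with how it is used here (an assumption, not the library's actual implementation) might look like:

import importlib.util
import platform

def _module_available(module_path: str) -> bool:
    # Simplified stand-in: True if the dotted module path can be resolved without importing it.
    try:
        return importlib.util.find_spec(module_path) is not None
    except ModuleNotFoundError:
        # A missing parent package also means the module is unavailable.
        return False

# Mirrors the guard added in the hunk below: sharded DDP needs fairscale and a non-Windows OS.
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')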
"""General utilities""" import importlib +import platform from enum import Enum import numpy @@ -43,6 +44,7 @@ def _module_available(module_path: str) -> bool: APEX_AVAILABLE = _module_available("apex.amp") NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") +FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 76db75a21f1c8..11ec6cbe73283 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,18 +1,16 @@ import os import platform -import time from unittest import mock import pytest import torch -from torch.utils.data.distributed import DistributedSampler -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE -from tests.base.boring_model import BoringModel, RandomDataset +from tests.base.boring_model import BoringModel @mock.patch.dict( @@ -280,246 +278,3 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): trainer.fit(model) return 1 - - -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_one_device(): - # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls - run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_one_gpu(): - run_sharded_correctness(gpus=1, accelerator='ddp_spawn') - - -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_amp_one_gpu(): - run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_multi_gpu(): - run_sharded_correctness(gpus=2, accelerator='ddp_spawn') - - -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") 
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): - run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): - """ - Ensures same results using multiple optimizers across multiple GPUs - """ - run_sharded_correctness( - gpus=2, - accelerator='ddp_spawn', - model_cls=TestMultipleOptimizersModel, - max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers - ) - - -@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): - """ - Ensures using multiple optimizers across multiple GPUs with manual optimization - """ - run_sharded_correctness( - gpus=2, - accelerator='ddp_spawn', - model_cls=TestManualModel, - ) - - -class TestModel(BoringModel): - """ - Overrides training loader to ensure we enforce the same seed for all DDP processes. - """ - - def train_dataloader(self): - seed_everything(42) - return torch.utils.data.DataLoader(RandomDataset(32, 64)) - - -class TestManualModel(TestModel): - def training_step(self, batch, batch_idx, optimizer_idx): - # manual - (opt_a, opt_b) = self.optimizers() - loss_1 = self.step(batch) - - self.manual_backward(loss_1, opt_a) - self.manual_optimizer_step(opt_a) - - # fake discriminator - loss_2 = self.step(batch[0]) - - # ensure we forward the correct params to the optimizer - # without retain_graph we can't do multiple backward passes - self.manual_backward(loss_2, opt_b, retain_graph=True) - self.manual_backward(loss_2, opt_a, retain_graph=True) - self.manual_optimizer_step(opt_b) - - assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer, optimizer_2 - - @property - def automatic_optimization(self) -> bool: - return False - - -class TestMultipleOptimizersModel(TestModel): - def training_step(self, batch, batch_idx, optimizer_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer, optimizer_2 - - -def record_ddp_fit_model_stats(trainer, model, gpus): - """ - Helper to calculate wall clock time for fit + max 
allocated memory. - - Args: - trainer: The trainer object. - model: The LightningModule. - gpus: Number of GPUs in test. - - Returns: - Max Memory if using GPUs, and total wall clock time. - - """ - max_memory = None - - time_start = time.perf_counter() - if gpus > 0: - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - trainer.fit(model) - - if gpus > 0: - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 - - total_time = time.perf_counter() - time_start - - return max_memory, total_time - - -def run_sharded_correctness( - accelerator='ddp_spawn', - gpus=0, - precision=32, - max_percent_speed_diff=0.25, - model_cls=TestModel): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. - - Args: - accelerator: Accelerator type for test. - gpus: Number of GPUS to enable. - precision: Whether to use AMP or normal FP32 training. - max_percent_speed_diff: The maximum speed difference compared to normal DDP training. - This is more a safety net for variability in CI which can vary in speed, not for benchmarking. - model_cls: Model class to use for test. - - """ - - # Train normal DDP - seed_everything(42) - ddp_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - ) - - max_ddp_memory, ddp_time = record_ddp_fit_model_stats( - trainer=trainer, - model=ddp_model, - gpus=gpus - ) - - # Reset and train sharded DDP - seed_everything(42) - sharded_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - plugins=[DDPShardedPlugin()], - ) - - max_sharded_memory, sharded_time = record_ddp_fit_model_stats( - trainer=trainer, - model=sharded_model, - gpus=gpus - ) - - # Assert model parameters are identical after fit - for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): - assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' - - # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold - percent_diff = (sharded_time - ddp_time) / sharded_time - - assert percent_diff <= max_percent_speed_diff, \ - f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' - - if gpus > 0: - # Assert CUDA memory parity - assert max_sharded_memory <= max_ddp_memory, \ - f'Sharded plugin used too much memory compared to DDP,' \ - f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}' From 8a0c8fe0bda62290f24db79285e17eecca0491a0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 16:48:21 +0000 Subject: [PATCH 51/71] Fixed imports, swap to relying on function for entire batch --- benchmarks/test_sharded_correctness.py | 4 ++-- pytorch_lightning/plugins/sharded_plugin.py | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_correctness.py index 0b986f4199d17..23027c430d490 100644 --- a/benchmarks/test_sharded_correctness.py +++ b/benchmarks/test_sharded_correctness.py @@ -8,8 +8,8 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE -from 
pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, FAIRSCALE_AVAILABLE from tests.base.boring_model import BoringModel, RandomDataset diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index ba7b751fdedc4..2654b329d9551 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -45,11 +45,7 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: return self._optim_state_dict(optimizer) def on_before_forward(self, model: LightningModule, *args): - args = list(args) - batch = args[0] - batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu) - args[0] = batch - return tuple(args) + return model.transfer_batch_to_device(args, model.trainer.root_gpu) def _check_fairscale(self): if not FAIRSCALE_AVAILABLE: From 29e310807cd589d90aabeb8ecade8ea5d2d67d8f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 17:07:48 +0000 Subject: [PATCH 52/71] Fix import order --- benchmarks/test_sharded_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_correctness.py index 23027c430d490..0f6d3658801d6 100644 --- a/benchmarks/test_sharded_correctness.py +++ b/benchmarks/test_sharded_correctness.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE from tests.base.boring_model import BoringModel, RandomDataset From c0e148bc27a51fa030f4e2380977285823d2e2b2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 17:37:37 +0000 Subject: [PATCH 53/71] Fix formatting --- benchmarks/test_sharded_correctness.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_correctness.py index 0f6d3658801d6..c6472f0b01d8a 100644 --- a/benchmarks/test_sharded_correctness.py +++ b/benchmarks/test_sharded_correctness.py @@ -41,8 +41,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu(): run_sharded_correctness(gpus=2, accelerator='ddp_spawn') @@ -59,8 +58,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): """ Ensures same results using multiple optimizers across multiple GPUs From 2e50c2e653ba417a80dd3517895be5dd6d06ba82 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: 
Thu, 26 Nov 2020 18:42:06 +0000 Subject: [PATCH 54/71] Remove else check --- pytorch_lightning/overrides/fairscale.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index df49199e2a504..c1d5769df6446 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -13,6 +13,7 @@ # limitations under the License. from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE +LightningShardedDataParallel = None if FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel @@ -29,6 +30,3 @@ def forward(self, *inputs, **kwargs): else: outputs = self.module.validation_step(*inputs, **kwargs) return outputs - -else: - LightningShardedDataParallel = None From ab655e511862297f5c668853a2456ce14a13f940 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 18:49:06 +0000 Subject: [PATCH 55/71] Removed old eval logic, added eval tests --- pytorch_lightning/plugins/sharded_plugin.py | 5 --- tests/plugins/test_sharded_plugin.py | 38 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 2654b329d9551..84b23c3230059 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -33,11 +33,6 @@ def configure_ddp( self, model: LightningModule, device_ids: List[int] ): self._wrap_optimizers(model) - if model.trainer.testing: # Revert to standard DDP if testing - return super().configure_ddp( - model=model, - device_ids=device_ids - ) return LightningShardedDataParallel(model, sharded_optimizer=model.trainer.optimizers) def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 11ec6cbe73283..87d568a3cdad5 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -278,3 +278,41 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): trainer.fit(model) return 1 + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_test(tmpdir): + """ + Test to ensure we can use test without fit + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_cpu', + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + ) + + trainer.test(model) + return 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_test_multigpu(tmpdir): + """ + Test to ensure we can use test without fit + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_spawn', + gpus=2, + plugins=[DDPShardedPlugin()], + fast_dev_run=True, + ) + + trainer.test(model) + return 1 From a9c316b6693955d354a3229593af46ba40835f5c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 19:00:55 +0000 Subject: [PATCH 56/71] Add additional check to ensure apex is not used with sharded --- .../trainer/connectors/precision_connector.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) 
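To make the intent of this check concrete, a sketch of the configuration it guards against (argument names follow the tests added later in this series; illustrative only, not a canonical reproduction):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin

# Requesting Apex AMP together with the sharded plugin is the combination this
# patch detects: here it emits a warning, and a follow-up patch turns it into a
# MisconfigurationException.
trainer = Trainer(
    fast_dev_run=True,
    distributed_backend='ddp_spawn',
    precision=16,
    amp_backend='apex',
    plugins=[DDPShardedPlugin()],
)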
diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 5961ac03b4493..bbb8f8c69307b 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -50,6 +50,7 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): # no AMP requested, so we can leave now return + using_sharded_plugin = self._check_sharded_plugin(plugins) amp_type = amp_type.lower() assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' if amp_type == 'native': @@ -60,10 +61,8 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): amp_type = 'apex' else: self.trainer.amp_backend = AMPType.NATIVE - if plugins and self._sharded_in_plugins(plugins): - if not FAIRSCALE_AVAILABLE: - raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') - log.info('Using Sharded 16bit plugin.') + if using_sharded_plugin: + log.info('Using sharded 16bit precision.') self.backend = ShardedNativeAMPPlugin(self.trainer) else: log.info('Using native 16bit precision.') @@ -73,6 +72,9 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): if not APEX_AVAILABLE: rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') + if using_sharded_plugin: + rank_zero_warn( + 'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16 bit precision.') else: log.info('Using APEX 16bit precision.') self.trainer.amp_backend = AMPType.APEX @@ -91,6 +93,13 @@ def connect(self, model): return model + def _check_sharded_plugin(self, plugins): + if plugins and self._sharded_in_plugins(plugins): + if not FAIRSCALE_AVAILABLE: + raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') + return True + return False + def _sharded_in_plugins(self, plugins): for plugin in plugins: if isinstance(plugin, DDPShardedPlugin): From 8dc857c38daebf284b29051570e245b1a5284ecc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 22:10:57 +0000 Subject: [PATCH 57/71] Ensure we add the condition to the case statement --- pytorch_lightning/trainer/connectors/precision_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index bbb8f8c69307b..e4ea99bdf209f 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -72,7 +72,7 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): if not APEX_AVAILABLE: rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' 
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') - if using_sharded_plugin: + elif using_sharded_plugin: rank_zero_warn( 'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16 bit precision.') else: From fc9b2bf0154f139c23b3a6e5aea6b57753ef323b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 22:45:21 +0000 Subject: [PATCH 58/71] Fix logic and add test for apex check, rename file, add DDP launcher tests --- ..._correctness.py => test_sharded_parity.py} | 21 +++++++++++++++++ .../trainer/connectors/precision_connector.py | 4 ++-- tests/plugins/test_sharded_plugin.py | 23 ++++++++++++++++++- 3 files changed, 45 insertions(+), 3 deletions(-) rename benchmarks/{test_sharded_correctness.py => test_sharded_parity.py} (87%) diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_parity.py similarity index 87% rename from benchmarks/test_sharded_correctness.py rename to benchmarks/test_sharded_parity.py index c6472f0b01d8a..78eb98e1482f0 100644 --- a/benchmarks/test_sharded_correctness.py +++ b/benchmarks/test_sharded_parity.py @@ -10,6 +10,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE +from tests.backends.launcher import DDPLauncher from tests.base.boring_model import BoringModel, RandomDataset @@ -55,6 +56,26 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32") +def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): + run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + + +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16") +def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): + run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index e4ea99bdf209f..71d884bdfe342 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -73,8 +73,8 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' 
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') elif using_sharded_plugin: - rank_zero_warn( - 'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16 bit precision.') + raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, ' + 'please using native AMP for 16 bit precision.') else: log.info('Using APEX 16bit precision.') self.trainer.amp_backend = AMPType.APEX diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 87d568a3cdad5..6d93929761920 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -9,7 +9,8 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, APEX_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel @@ -54,6 +55,26 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) +@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_invalid_apex_sharded(tmpdir): + """ + Test to ensure that we raise an error when we try to use apex and sharded + """ + + model = BoringModel() + with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp_spawn', + plugins=[DDPShardedPlugin()], + precision=16, + amp_backend='apex' + ) + + trainer.fit(model) + + @mock.patch.dict( os.environ, { From 508eaff541fcb86987cea4b8fa83e245f438c223 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 26 Nov 2020 23:01:04 +0000 Subject: [PATCH 59/71] Fix name --- benchmarks/test_sharded_parity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 78eb98e1482f0..7065284a5b4da 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -62,7 +62,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") @DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32") -def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): +def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) From bde2a129902e4775b060078ff8be03255e635d92 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 10:37:49 +0000 Subject: [PATCH 60/71] Fix var name --- benchmarks/test_sharded_parity.py | 6 +++--- pytorch_lightning/plugins/sharded_native_amp_plugin.py | 4 ++-- tests/plugins/test_sharded_plugin.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 7065284a5b4da..449c9d070f7f7 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.plugins.sharded_plugin import 
DDPShardedPlugin -from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE from tests.backends.launcher import DDPLauncher from tests.base.boring_model import BoringModel, RandomDataset @@ -30,7 +30,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu(): run_sharded_correctness(gpus=1, accelerator='ddp_spawn') -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @@ -47,7 +47,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): run_sharded_correctness(gpus=2, accelerator='ddp_spawn') -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py index 67aac2c3a2f82..581a84e5cbeef 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import cast -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, FAIRSCALE_AVAILABLE -if NATIVE_AMP_AVALAIBLE and FAIRSCALE_AVAILABLE: +if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 6d93929761920..a36e4575002be 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, APEX_AVAILABLE +from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel @@ -92,7 +92,7 @@ def test_invalid_apex_sharded(tmpdir): [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP") def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded From 10d41fb4ea7dd5e0e2bfea94edfeba08556ade9c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 12:25:44 +0000 Subject: [PATCH 61/71] Moved common functions into utilities --- benchmarks/test_sharded_parity.py | 168 +++++++++--------------------- benchmarks/utilities.py | 117 
+++++++++++++++++++++ 2 files changed, 169 insertions(+), 116 deletions(-) create mode 100644 benchmarks/utilities.py diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 449c9d070f7f7..b96af667e15a7 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -1,13 +1,13 @@ import os import platform -import time from unittest import mock import pytest import torch from torch.utils.data.distributed import DistributedSampler -from pytorch_lightning import Trainer, seed_everything +from benchmarks.utilities import plugin_parity_test +from pytorch_lightning import seed_everything from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE from tests.backends.launcher import DDPLauncher @@ -19,7 +19,12 @@ @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_device(): # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls - run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) + plugin_parity_test( + accelerator='ddp_cpu', + max_percent_speed_diff=0.5, + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -27,7 +32,12 @@ def test_ddp_sharded_plugin_correctness_one_device(): reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_gpu(): - run_sharded_correctness(gpus=1, accelerator='ddp_spawn') + plugin_parity_test( + gpus=1, + accelerator='ddp_spawn', + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @@ -36,7 +46,13 @@ def test_ddp_sharded_plugin_correctness_one_gpu(): reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_one_gpu(): - run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') + plugin_parity_test( + gpus=1, + precision=16, + accelerator='ddp_spawn', + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -44,7 +60,12 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu(): - run_sharded_correctness(gpus=2, accelerator='ddp_spawn') + plugin_parity_test( + gpus=2, + accelerator='ddp_spawn', + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @@ -53,7 +74,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): - run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') + plugin_parity_test( + gpus=2, + precision=16, + accelerator='ddp_spawn', + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) 
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -63,7 +90,13 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): reason="test should be run outside of pytest") @DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32") def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): - run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + plugin_parity_test( + gpus=args.gpus, + precision=args.precision, + accelerator=args.distributed_backend, + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -73,7 +106,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): reason="test should be run outside of pytest") @DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16") def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): - run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + plugin_parity_test( + gpus=args.gpus, + precision=args.precision, + accelerator=args.distributed_backend, + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -84,7 +123,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): """ Ensures same results using multiple optimizers across multiple GPUs """ - run_sharded_correctness( + plugin_parity_test( + plugin=DDPShardedPlugin(), gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderMultipleOptimizersModel, @@ -102,7 +142,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): """ Ensures using multiple optimizers across multiple GPUs with manual optimization """ - run_sharded_correctness( + plugin_parity_test( + plugin=DDPShardedPlugin(), gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, @@ -167,108 +208,3 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 - - -def record_ddp_fit_model_stats(trainer, model, gpus): - """ - Helper to calculate wall clock time for fit + max allocated memory. - - Args: - trainer: The trainer object. - model: The LightningModule. - gpus: Number of GPUs in test. - - Returns: - Max Memory if using GPUs, and total wall clock time. - - """ - max_memory = None - - time_start = time.perf_counter() - if gpus > 0: - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - trainer.fit(model) - - if gpus > 0: - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 - - total_time = time.perf_counter() - time_start - - return max_memory, total_time - - -def run_sharded_correctness( - accelerator='ddp_spawn', - gpus=0, - precision=32, - max_percent_speed_diff=0.25, - model_cls=SeedTrainLoaderModel): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. - - Args: - accelerator: Accelerator type for test. - gpus: Number of GPUS to enable. - precision: Whether to use AMP or normal FP32 training. - max_percent_speed_diff: The maximum speed difference compared to normal DDP training. 
- This is more a safety net for variability in CI which can vary in speed, not for benchmarking. - model_cls: Model class to use for test. - - """ - - # Train normal DDP - seed_everything(42) - ddp_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - ) - - max_ddp_memory, ddp_time = record_ddp_fit_model_stats( - trainer=trainer, - model=ddp_model, - gpus=gpus - ) - - # Reset and train sharded DDP - seed_everything(42) - sharded_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - plugins=[DDPShardedPlugin()], - ) - - max_sharded_memory, sharded_time = record_ddp_fit_model_stats( - trainer=trainer, - model=sharded_model, - gpus=gpus - ) - - # Assert model parameters are identical after fit - for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): - assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' - - # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold - percent_diff = (sharded_time - ddp_time) / sharded_time - - assert percent_diff <= max_percent_speed_diff, \ - f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' - - if gpus > 0: - # Assert CUDA memory parity - assert max_sharded_memory <= max_ddp_memory, \ - f'Sharded plugin used too much memory compared to DDP,' \ - f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}' diff --git a/benchmarks/utilities.py b/benchmarks/utilities.py new file mode 100644 index 0000000000000..1c07d2274c111 --- /dev/null +++ b/benchmarks/utilities.py @@ -0,0 +1,117 @@ +import time +from typing import Callable + +import torch +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.seed import seed_everything + + +def record_ddp_fit_model_stats(trainer, model, use_cuda): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The model to fit. + use_cuda: Whether to sync CUDA kernels. + + Returns: + Max Memory if using GPUs, and total wall clock time. + """ + max_memory = None + + time_start = time.perf_counter() + if use_cuda: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + trainer.fit(model) + + if use_cuda: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def plugin_parity_test( + model_cls: Callable, + plugin: DDPPlugin, + seed: int = 42, + accelerator: str = 'ddp_spawn', + gpus: int = 0, + precision: int = 32, + max_percent_speed_diff: float = 0.25): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + + Args: + model_cls: Model class to use for test. + plugin: Plugin to parity test. + seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. + accelerator: Accelerator type for test. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. 
+ This is more a safety net for variability in CI which can vary in speed, not for benchmarking. + + """ + + # Train normal DDP + seed_everything(seed) + ddp_model = model_cls() + use_cuda = gpus > 0 + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_memory_ddp, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + use_cuda=use_cuda + ) + + # Reset and train Custom DDP + seed_everything(seed) + custom_plugin_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[plugin], + ) + + max_memory_custom, custom_model_time = record_ddp_fit_model_stats( + trainer=trainer, + model=custom_plugin_model, + use_cuda=use_cuda + ) + + # Assert model parameters are identical after fit + for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): + assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' + + # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold + percent_diff = (custom_model_time - ddp_time) / custom_model_time + + assert percent_diff <= max_percent_speed_diff, \ + f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' + + if use_cuda: + # Assert CUDA memory parity + assert max_memory_custom <= max_memory_ddp, \ + f'Custom plugin used too much memory compared to DDP,' \ + f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}' From e52386b0038cca9702ae861ea53b41ee341f8a42 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 12:38:38 +0000 Subject: [PATCH 62/71] Combine utilities --- benchmarks/test_sharded_parity.py | 114 ++++++++++++++++++++++++++++- benchmarks/utilities.py | 117 ------------------------------ 2 files changed, 113 insertions(+), 118 deletions(-) delete mode 100644 benchmarks/utilities.py diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index b96af667e15a7..99e2cfc207b08 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -1,13 +1,16 @@ import os import platform +import time +from typing import Callable from unittest import mock import pytest import torch from torch.utils.data.distributed import DistributedSampler -from benchmarks.utilities import plugin_parity_test +from pytorch_lightning import Trainer from pytorch_lightning import seed_everything +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE from tests.backends.launcher import DDPLauncher @@ -208,3 +211,112 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + + +def record_ddp_fit_model_stats(trainer, model, use_cuda): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The model to fit. + use_cuda: Whether to sync CUDA kernels. + + Returns: + Max Memory if using GPUs, and total wall clock time. 
+ """ + max_memory = None + + time_start = time.perf_counter() + if use_cuda: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + trainer.fit(model) + + if use_cuda: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 + + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def plugin_parity_test( + model_cls: Callable, + plugin: DDPPlugin, + seed: int = 42, + accelerator: str = 'ddp_spawn', + gpus: int = 0, + precision: int = 32, + max_percent_speed_diff: float = 0.25): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + + Args: + model_cls: Model class to use for test. + plugin: Plugin to parity test. + seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. + accelerator: Accelerator type for test. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. + This is more a safety net for variability in CI which can vary in speed, not for benchmarking. + + """ + + # Train normal DDP + seed_everything(seed) + ddp_model = model_cls() + use_cuda = gpus > 0 + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_memory_ddp, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + use_cuda=use_cuda + ) + + # Reset and train Custom DDP + seed_everything(seed) + custom_plugin_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[plugin], + ) + + max_memory_custom, custom_model_time = record_ddp_fit_model_stats( + trainer=trainer, + model=custom_plugin_model, + use_cuda=use_cuda + ) + + # Assert model parameters are identical after fit + for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): + assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' + + # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold + percent_diff = (custom_model_time - ddp_time) / custom_model_time + + assert percent_diff <= max_percent_speed_diff, \ + f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' + + if use_cuda: + # Assert CUDA memory parity + assert max_memory_custom <= max_memory_ddp, \ + f'Custom plugin used too much memory compared to DDP,' \ + f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}' diff --git a/benchmarks/utilities.py b/benchmarks/utilities.py deleted file mode 100644 index 1c07d2274c111..0000000000000 --- a/benchmarks/utilities.py +++ /dev/null @@ -1,117 +0,0 @@ -import time -from typing import Callable - -import torch -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin - -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.seed import seed_everything - - -def record_ddp_fit_model_stats(trainer, model, use_cuda): - """ - Helper to calculate wall clock time for fit + max allocated memory. - - Args: - trainer: The trainer object. - model: The model to fit. - use_cuda: Whether to sync CUDA kernels. - - Returns: - Max Memory if using GPUs, and total wall clock time. 
- """ - max_memory = None - - time_start = time.perf_counter() - if use_cuda: - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - trainer.fit(model) - - if use_cuda: - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 - - total_time = time.perf_counter() - time_start - - return max_memory, total_time - - -def plugin_parity_test( - model_cls: Callable, - plugin: DDPPlugin, - seed: int = 42, - accelerator: str = 'ddp_spawn', - gpus: int = 0, - precision: int = 32, - max_percent_speed_diff: float = 0.25): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. - - Args: - model_cls: Model class to use for test. - plugin: Plugin to parity test. - seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. - accelerator: Accelerator type for test. - gpus: Number of GPUS to enable. - precision: Whether to use AMP or normal FP32 training. - max_percent_speed_diff: The maximum speed difference compared to normal DDP training. - This is more a safety net for variability in CI which can vary in speed, not for benchmarking. - - """ - - # Train normal DDP - seed_everything(seed) - ddp_model = model_cls() - use_cuda = gpus > 0 - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - ) - - max_memory_ddp, ddp_time = record_ddp_fit_model_stats( - trainer=trainer, - model=ddp_model, - use_cuda=use_cuda - ) - - # Reset and train Custom DDP - seed_everything(seed) - custom_plugin_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - plugins=[plugin], - ) - - max_memory_custom, custom_model_time = record_ddp_fit_model_stats( - trainer=trainer, - model=custom_plugin_model, - use_cuda=use_cuda - ) - - # Assert model parameters are identical after fit - for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): - assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' - - # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold - percent_diff = (custom_model_time - ddp_time) / custom_model_time - - assert percent_diff <= max_percent_speed_diff, \ - f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' - - if use_cuda: - # Assert CUDA memory parity - assert max_memory_custom <= max_memory_ddp, \ - f'Custom plugin used too much memory compared to DDP,' \ - f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}' From bd4223e951fea64e85f79be098ca9c59966053fa Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 13:22:51 +0000 Subject: [PATCH 63/71] Fix imports --- benchmarks/test_sharded_parity.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 99e2cfc207b08..f901a767ea308 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -6,10 +6,8 @@ import pytest import torch -from torch.utils.data.distributed import DistributedSampler -from pytorch_lightning import Trainer -from pytorch_lightning import seed_everything +from pytorch_lightning import Trainer, seed_everything from 
pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE From 5598dce1a98f348732dace7bdc47805168882b1f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 14:22:17 +0000 Subject: [PATCH 64/71] Remove unneeded check --- .../trainer/connectors/precision_connector.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 11e70b45af0c6..449bc95f791a1 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -51,7 +51,7 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): # no AMP requested, so we can leave now return - using_sharded_plugin = self._check_sharded_plugin(plugins) + using_sharded_plugin = self._check_using_sharded_plugin(plugins) amp_type = amp_type.lower() assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' if amp_type == 'native': @@ -94,14 +94,7 @@ def connect(self, model): return model - def _check_sharded_plugin(self, plugins): - if plugins and self._sharded_in_plugins(plugins): - if not FAIRSCALE_AVAILABLE: - raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') - return True - return False - - def _sharded_in_plugins(self, plugins): + def _check_using_sharded_plugin(self, plugins): for plugin in plugins: if isinstance(plugin, DDPShardedPlugin): return True From cdd2e122fcb3d37539bd4efbc8650948ef0eb291 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 14:30:57 +0000 Subject: [PATCH 65/71] Add none check for func --- .../trainer/connectors/precision_connector.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 449bc95f791a1..0fcce5c869032 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -94,8 +94,9 @@ def connect(self, model): return model - def _check_using_sharded_plugin(self, plugins): - for plugin in plugins: - if isinstance(plugin, DDPShardedPlugin): - return True + def _check_using_sharded_plugin(self, plugins: Optional[list]): + if plugins is not None: + for plugin in plugins: + if isinstance(plugin, DDPShardedPlugin): + return True return False From 4f693762ea265e854497bc0066d352a13265c62a Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 27 Nov 2020 14:45:15 +0000 Subject: [PATCH 66/71] Update pytorch_lightning/trainer/connectors/precision_connector.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/connectors/precision_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 0fcce5c869032..3c006a06ca5b3 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -74,8 +74,9 @@ def _setup_amp_backend(self, amp_type: str, plugins: Optional[list]): rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' 
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') elif using_sharded_plugin: - raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, ' - 'please using native AMP for 16 bit precision.') + raise MisconfigurationException( + 'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision.' + ) else: log.info('Using APEX 16bit precision.') self.trainer.amp_backend = AMPType.APEX From 17047737127734e6a99638f54df8362daf8b26cc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 14:50:02 +0000 Subject: [PATCH 67/71] Address code review --- benchmarks/test_sharded_parity.py | 8 ++------ pytorch_lightning/plugins/sharded_native_amp_plugin.py | 3 +-- .../trainer/connectors/precision_connector.py | 4 +--- tests/plugins/test_sharded_plugin.py | 6 ------ 4 files changed, 4 insertions(+), 17 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index f901a767ea308..ff1a59fa15fe8 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -1,13 +1,12 @@ import os import platform import time -from typing import Callable -from unittest import mock import pytest import torch from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE @@ -85,7 +84,6 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") @@ -101,7 +99,6 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") @@ -138,7 +135,6 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): """ Ensures using multiple optimizers across multiple GPUs with manual optimization @@ -242,7 +238,7 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda): def plugin_parity_test( - model_cls: Callable, + model_cls: LightningModule, plugin: DDPPlugin, seed: int = 42, accelerator: str = 'ddp_spawn', diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py index 581a84e5cbeef..a66b118da4b2b 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py @@ -14,13 +14,12 @@ from typing import cast from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, 
FAIRSCALE_AVAILABLE +from pytorch_lightning.plugins.native_amp import NativeAMPPlugin if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin - class ShardedNativeAMPPlugin(NativeAMPPlugin): @property diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 3c006a06ca5b3..8866607dc678c 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -16,11 +16,9 @@ from pytorch_lightning import _logger as log from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.native_amp import NativeAMPPlugin -from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn, \ - FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin - +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index a36e4575002be..527b85c0d0632 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -196,7 +196,6 @@ def test_ddp_sharded_plugin_finetune(tmpdir): fast_dev_run=True, ) trainer.fit(saved_model) - return 1 @pytest.mark.skipif(platform.system() == "Windows", @@ -228,7 +227,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): ) trainer.fit(model) - return 1 @pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") @@ -264,7 +262,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): ) trainer.fit(model) - return 1 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -298,7 +295,6 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): ) trainer.fit(model) - return 1 @pytest.mark.skipif(platform.system() == "Windows", @@ -316,7 +312,6 @@ def test_ddp_sharded_plugin_test(tmpdir): ) trainer.test(model) - return 1 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -336,4 +331,3 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir): ) trainer.test(model) - return 1 From bf9cf3dd016970cd7c3f0e95939357f53434a17b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 15:26:06 +0000 Subject: [PATCH 68/71] Tighten up regression testing --- benchmarks/test_sharded_parity.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index ff1a59fa15fe8..d746ce2ce47e7 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -6,7 +6,6 @@ import torch from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE @@ -18,10 +17,9 @@ reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not 
FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_device(): - # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls plugin_parity_test( accelerator='ddp_cpu', - max_percent_speed_diff=0.5, + max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel ) @@ -64,7 +62,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): gpus=2, accelerator='ddp_spawn', plugin=DDPShardedPlugin(), - model_cls=SeedTrainLoaderModel + model_cls=SeedTrainLoaderModel, + max_percent_speed_diff=0.15 ) @@ -126,7 +125,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderMultipleOptimizersModel, - max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers + max_percent_speed_diff=0.2 # Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -238,13 +237,13 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda): def plugin_parity_test( - model_cls: LightningModule, + model_cls: SeedTrainLoaderModel, plugin: DDPPlugin, seed: int = 42, accelerator: str = 'ddp_spawn', gpus: int = 0, precision: int = 32, - max_percent_speed_diff: float = 0.25): + max_percent_speed_diff: float = 0.1): """ Ensures that the trained model is identical to the standard DDP implementation. Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. From b4e8071de2f980365bec34552500167dcefc5923 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 15:49:02 +0000 Subject: [PATCH 69/71] Increase speed diff for drone --- benchmarks/test_sharded_parity.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index d746ce2ce47e7..584dcd20de782 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -63,7 +63,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): accelerator='ddp_spawn', plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.15 + max_percent_speed_diff=0.2 ) @@ -78,7 +78,8 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): precision=16, accelerator='ddp_spawn', plugin=DDPShardedPlugin(), - model_cls=SeedTrainLoaderModel + model_cls=SeedTrainLoaderModel, + max_percent_speed_diff=0.2 ) From d12577d3481c5dfd83dcf9666ccba5e21619f644 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 27 Nov 2020 16:28:24 +0000 Subject: [PATCH 70/71] Reduce speed diff further, lack of GPU saturation is causing regressive times on drone CI --- benchmarks/test_sharded_parity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 584dcd20de782..0d8a704fa0ac4 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -63,7 +63,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): accelerator='ddp_spawn', plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.2 + max_percent_speed_diff=0.25 ) @@ -79,7 +79,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): accelerator='ddp_spawn', plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.2 + max_percent_speed_diff=0.25 ) From 1719b2dca4c899c1933437c0663ffed7559569db Mon Sep 17 00:00:00 2001 
From: SeanNaren Date: Fri, 27 Nov 2020 20:21:50 +0000 Subject: [PATCH 71/71] Skip a few tests to reduce drone CI wait times --- benchmarks/test_sharded_parity.py | 1 + tests/plugins/test_sharded_plugin.py | 1 + 2 files changed, 2 insertions(+) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 0d8a704fa0ac4..2ecdc6c50e709 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -53,6 +53,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): ) +@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 527b85c0d0632..5010c39de7a80 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -229,6 +229,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): trainer.fit(model) +@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") @pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows",