From 89d57f616123be0036b944abda83ae9e09af8c81 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Wed, 2 Jun 2021 20:50:16 -0400
Subject: [PATCH 01/13] check distributed backend when selecting training type
 plugin

---
 .../trainer/connectors/accelerator_connector.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 4d692ec517d19..530e1c65b9df8 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -168,6 +168,7 @@ def handle_given_plugins(self) -> None:
         precision = None
         cluster_environment = None
 
+
         for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:
@@ -403,7 +404,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if self.use_ddp2:
+        if isinstance(self.distributed_backend, Accelerator) and self.distributed_backend.training_type_plugin is not None:
+            plugin = self.distributed_backend.training_type_plugin
+        elif self.use_ddp2:
             plugin = DDP2Plugin(
                 parallel_devices=self.parallel_devices,
                 cluster_environment=self.cluster_environment,

From cf5ba43653881ff010caf645076794f5cef43ac7 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Wed, 2 Jun 2021 20:50:38 -0400
Subject: [PATCH 02/13] add is_distributed attribute to DDPPlugin

---
 pytorch_lightning/plugins/training_type/ddp.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index e65a6512d3846..726fe2822aecf 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -95,6 +95,10 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self.set_world_ranks()
 
+    @property
+    def is_distributed(self) -> bool:
+        return True
+
     @property
     def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
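[Note on PATCH 02] DDPPlugin now advertises itself as distributed via the new
is_distributed property. Below is a minimal sketch of the kind of consumer
this enables; needs_distributed_sampler is a hypothetical helper written for
illustration, not part of the pytorch_lightning codebase.

    # Illustrative only: `needs_distributed_sampler` is a hypothetical
    # helper, not actual pytorch_lightning code. It shows how a
    # dataloader-configuring component could branch on the new
    # DDPPlugin.is_distributed property.
    from pytorch_lightning.plugins import DDPPlugin, TrainingTypePlugin


    def needs_distributed_sampler(plugin: TrainingTypePlugin) -> bool:
        # DDPPlugin.is_distributed is now unconditionally True, so any
        # DDP-based run opts into distributed sampling by default; plugins
        # without the property fall back to False.
        return getattr(plugin, "is_distributed", False)


    assert needs_distributed_sampler(DDPPlugin()) is True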
From c81f3290efd5d2e4f6853c6cf5ff8aab433d4cb7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 3 Jun 2021 04:04:19 +0000
Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../trainer/connectors/accelerator_connector.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 530e1c65b9df8..c01a6b30c561d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -168,7 +168,6 @@ def handle_given_plugins(self) -> None:
         precision = None
         cluster_environment = None
 
-
         for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:
@@ -404,7 +403,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if isinstance(self.distributed_backend, Accelerator) and self.distributed_backend.training_type_plugin is not None:
+        if isinstance(
+            self.distributed_backend, Accelerator
+        ) and self.distributed_backend.training_type_plugin is not None:
             plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(

From 5cdaeff93a62721a8aa91f236575c4b6f8b08b31 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 3 Jun 2021 00:04:41 -0400
Subject: [PATCH 04/13] remove unnecessary newline

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 530e1c65b9df8..b4f90fbc91dd3 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -168,7 +168,6 @@ def handle_given_plugins(self) -> None:
         precision = None
         cluster_environment = None
 
-
         for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:

From 59ebfbd92676adf6a3c8b658b68cdf3a42e2aed8 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 3 Jun 2021 13:51:17 -0400
Subject: [PATCH 05/13] add test for setting training_type_plugin

---
 .../connectors/test_accelerator_connector.py | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 tests/trainer/connectors/test_accelerator_connector.py

diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
new file mode 100644
index 0000000000000..bcfeda16b818e
--- /dev/null
+++ b/tests/trainer/connectors/test_accelerator_connector.py
@@ -0,0 +1,29 @@
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.plugins import PrecisionPlugin, DDPPlugin, SingleDevicePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+
+def test_accelerator_training_type_plugin():
+    """Test that the training_type_plugin pulled from accelearator"""
+
+    # check that this works for different types of plugins to ensure
+    # there are no dependencies on TrainingTypePlugin class refinements
+
+    precision_plugin = PrecisionPlugin()
+    singledev_plugin = SingleDevicePlugin(torch.device('cpu'))
+    accelerator = Accelerator(
+        precision_plugin=precision_plugin,
+        training_type_plugin=singledev_plugin,
+    )
+    trainer = Trainer(accelerator=accelerator)
+    assert trainer.training_type_plugin is singledev_plugin
+
+    precision_plugin = PrecisionPlugin()
+    ddp_plugin = DDPPlugin()
+    accelerator = Accelerator(
+        precision_plugin=precision_plugin,
+        training_type_plugin=ddp_plugin,
+    )
+    trainer = Trainer(accelerator=accelerator)
+    assert trainer.training_type_plugin is ddp_plugin

From c667471589593f0d5fa94fd6a54a0854a9e9ad60 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 3 Jun 2021 17:52:26 +0000
Subject: [PATCH 06/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/trainer/connectors/test_accelerator_connector.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
index bcfeda16b818e..e08b3d50b75e3 100644
--- a/tests/trainer/connectors/test_accelerator_connector.py
+++ b/tests/trainer/connectors/test_accelerator_connector.py
@@ -1,9 +1,11 @@
 import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.plugins import PrecisionPlugin, DDPPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin, SingleDevicePlugin
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 
+
 def test_accelerator_training_type_plugin():
     """Test that the training_type_plugin pulled from accelearator"""
 
From 767b678faf440c07042071f07a3a24333fb64d23 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 3 Jun 2021 13:52:48 -0400
Subject: [PATCH 07/13] fix typo

---
 tests/trainer/connectors/test_accelerator_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
index bcfeda16b818e..15b65b49daf09 100644
--- a/tests/trainer/connectors/test_accelerator_connector.py
+++ b/tests/trainer/connectors/test_accelerator_connector.py
@@ -5,7 +5,7 @@
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 
 def test_accelerator_training_type_plugin():
-    """Test that the training_type_plugin pulled from accelearator"""
+    """Test that training_type_plugin is pulled from accelearator"""
 
     # check that this works for different types of plugins to ensure
     # there are no dependencies on TrainingTypePlugin class refinements

From eb276d110e69ab2125dddbae8646950105bcec19 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 3 Jun 2021 14:01:02 -0400
Subject: [PATCH 08/13] read training_type_plugin from connector

---
 tests/trainer/connectors/test_accelerator_connector.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
index 484b6b563abca..12bef90e1a301 100644
--- a/tests/trainer/connectors/test_accelerator_connector.py
+++ b/tests/trainer/connectors/test_accelerator_connector.py
@@ -19,7 +19,7 @@ def test_accelerator_training_type_plugin():
         training_type_plugin=singledev_plugin,
     )
     trainer = Trainer(accelerator=accelerator)
-    assert trainer.training_type_plugin is singledev_plugin
+    assert trainer.accelerator_connector.training_type_plugin is singledev_plugin
 
     precision_plugin = PrecisionPlugin()
     ddp_plugin = DDPPlugin()
@@ -28,4 +28,4 @@ def test_accelerator_training_type_plugin():
         training_type_plugin=ddp_plugin,
     )
     trainer = Trainer(accelerator=accelerator)
-    assert trainer.training_type_plugin is ddp_plugin
+    assert trainer.accelerator_connector.training_type_plugin is ddp_plugin

From c0a90678dd765464b7399994ceee7a01f3fecdbf Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 3 Jun 2021 14:03:37 -0400
Subject: [PATCH 09/13] remove unused import

---
 tests/trainer/connectors/test_accelerator_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
index 12bef90e1a301..67c691d4ef27b 100644
--- a/tests/trainer/connectors/test_accelerator_connector.py
+++ b/tests/trainer/connectors/test_accelerator_connector.py
@@ -3,7 +3,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin, SingleDevicePlugin
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 
 
 def test_accelerator_training_type_plugin():
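[Note on PATCH 08] The assertions switch from trainer.training_type_plugin to
trainer.accelerator_connector.training_type_plugin. In the 1.3-era API
exercised here, Trainer.training_type_plugin resolves through the accelerator,
so the updated assertion verifies that the connector itself selected the
custom plugin rather than the accelerator merely holding on to it. The
following is only a sketch of that attribute relationship under this
assumption, not the actual Trainer source:

    # Simplified sketch of the attribute chain the test relies on;
    # TrainerSketch is illustrative, not pytorch_lightning.Trainer.
    class TrainerSketch:
        def __init__(self, accelerator_connector):
            self.accelerator_connector = accelerator_connector

        @property
        def accelerator(self):
            # The connector owns the accelerator it selected or was given.
            return self.accelerator_connector.accelerator

        @property
        def training_type_plugin(self):
            # The public property resolves through the accelerator, one
            # hop removed from the connector's own attribute.
            return self.accelerator.training_type_plugin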
From bd9916ce2ba04cfe4427df0fd5c0eec5426fff15 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Fri, 4 Jun 2021 14:15:05 -0400
Subject: [PATCH 10/13] move test to test_custom_accelerator

---
 .../test_accelerator_connector.py            | 22 +++++++++++++-
 .../connectors/test_accelerator_connector.py | 30 -------------------
 2 files changed, 21 insertions(+), 31 deletions(-)
 delete mode 100644 tests/trainer/connectors/test_accelerator_connector.py

diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e60b86513e5ff..9c3bccb4d3283 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -453,8 +453,9 @@ class Prec(PrecisionPlugin):
     class TrainTypePlugin(SingleDevicePlugin):
         pass
 
+    ttp = TrainTypePlugin(device=torch.device("cpu"))
     accelerator = Accel(
-        training_type_plugin=TrainTypePlugin(device=torch.device("cpu")),
+        training_type_plugin=ttp,
         precision_plugin=Prec(),
     )
     trainer = Trainer(
@@ -465,6 +466,25 @@ class TrainTypePlugin(SingleDevicePlugin):
     assert isinstance(trainer.accelerator, Accel)
     assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
     assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer.accelerator_connector.training_type_plugin is ttp
+
+    class DistributedPlugin(DDPPlugin):
+        pass
+
+    ttp = DistributedPlugin()
+    accelerator = Accel(
+        training_type_plugin=ttp,
+        precision_plugin=Prec(),
+    )
+    trainer = Trainer(
+        accelerator=accelerator,
+        fast_dev_run=True,
+        num_processes=2,
+    )
+    assert isinstance(trainer.accelerator, Accel)
+    assert isinstance(trainer.training_type_plugin, DistributedPlugin)
+    assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer.accelerator_connector.training_type_plugin is ttp
 
 
 @mock.patch.dict(
diff --git a/tests/trainer/connectors/test_accelerator_connector.py b/tests/trainer/connectors/test_accelerator_connector.py
deleted file mode 100644
index 67c691d4ef27b..0000000000000
--- a/tests/trainer/connectors/test_accelerator_connector.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import torch
-
-from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin, SingleDevicePlugin
-
-
-def test_accelerator_training_type_plugin():
-    """Test that training_type_plugin is pulled from accelearator"""
-
-    # check that this works for different types of plugins to ensure
-    # there are no dependencies on TrainingTypePlugin class refinements
-
-    precision_plugin = PrecisionPlugin()
-    singledev_plugin = SingleDevicePlugin(torch.device('cpu'))
-    accelerator = Accelerator(
-        precision_plugin=precision_plugin,
-        training_type_plugin=singledev_plugin,
-    )
-    trainer = Trainer(accelerator=accelerator)
-    assert trainer.accelerator_connector.training_type_plugin is singledev_plugin
-
-    precision_plugin = PrecisionPlugin()
-    ddp_plugin = DDPPlugin()
-    accelerator = Accelerator(
-        precision_plugin=precision_plugin,
-        training_type_plugin=ddp_plugin,
-    )
-    trainer = Trainer(accelerator=accelerator)
-    assert trainer.accelerator_connector.training_type_plugin is ddp_plugin
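[Note on PATCHES 11-12 below] The review suggestion in PATCH 11 hoists the
None check into an is_plugin variable, but that evaluates
self.distributed_backend.training_type_plugin before the isinstance guard;
when distributed_backend is a plain string such as "ddp" (or None), the
attribute access raises AttributeError, which the original short-circuiting
and-expression avoided. PATCH 12 restores the single-line form, presumably
for this reason. A minimal sketch of the failure mode, using stand-in objects
rather than the real AcceleratorConnector:

    # Stand-in objects only; this reproduces the evaluation-order issue,
    # not the real connector code.
    class FakeAccelerator:
        training_type_plugin = object()


    def select(distributed_backend):
        # PATCH 11's form: the attribute is read before the type guard.
        is_plugin = distributed_backend.training_type_plugin is not None
        return isinstance(distributed_backend, FakeAccelerator) and is_plugin


    select(FakeAccelerator())  # fine
    try:
        select("ddp")  # a string backend has no training_type_plugin
    except AttributeError as err:
        print(err)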
From fed73b93d13afdbab8758aa0d9f764727f358eae Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sat, 12 Jun 2021 10:11:01 +0200
Subject: [PATCH 11/13] Apply suggestions from code review

---
 .../trainer/connectors/accelerator_connector.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index c01a6b30c561d..8658dd0ef70c6 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -403,9 +403,8 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if isinstance(
-            self.distributed_backend, Accelerator
-        ) and self.distributed_backend.training_type_plugin is not None:
+        is_plugin = self.distributed_backend.training_type_plugin is not None
+        if isinstance(self.distributed_backend, Accelerator) and is_plugin:
             plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(

From 6aa4e7d81fd0c3084f16d8d51ef814b2894ba8c3 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Mon, 14 Jun 2021 10:11:17 -0700
Subject: [PATCH 12/13] Update
 pytorch_lightning/trainer/connectors/accelerator_connector.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 8658dd0ef70c6..b4f90fbc91dd3 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -403,8 +403,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        is_plugin = self.distributed_backend.training_type_plugin is not None
-        if isinstance(self.distributed_backend, Accelerator) and is_plugin:
+        if isinstance(self.distributed_backend, Accelerator) and self.distributed_backend.training_type_plugin is not None:
             plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(

From 2816921657d93b4743031eab908c82831d00bcea Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 14 Jun 2021 17:12:20 +0000
Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index b4f90fbc91dd3..c01a6b30c561d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -403,7 +403,9 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if isinstance(self.distributed_backend, Accelerator) and self.distributed_backend.training_type_plugin is not None:
+        if isinstance(
+            self.distributed_backend, Accelerator
+        ) and self.distributed_backend.training_type_plugin is not None:
            plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(
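[Note on the series] End to end, the change makes the connector honor a
training type plugin carried by a user-supplied Accelerator. A minimal usage
sketch against the 1.3-era API exercised by these patches (constructing the
Trainer is enough for the assertion; a LightningModule would still be needed
to actually call trainer.fit):

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator
    from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

    # Build an accelerator that carries its own training type plugin.
    plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = Accelerator(
        precision_plugin=PrecisionPlugin(),
        training_type_plugin=plugin,
    )

    # With this series applied, select_training_type_plugin() returns the
    # plugin from the accelerator instead of deriving one from trainer flags.
    trainer = Trainer(accelerator=accelerator, fast_dev_run=True)
    assert trainer.accelerator_connector.training_type_plugin is plugin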