diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 42be8009aab28..ed320f37d7006 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -95,6 +95,10 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self.set_world_ranks()
 
+    @property
+    def is_distributed(self) -> bool:
+        return True
+
     @property
     def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 8f5de9a6302aa..eb5f0a2fc1a7d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -421,7 +421,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if self.use_ddp2:
+        if isinstance(
+            self.distributed_backend, Accelerator
+        ) and self.distributed_backend.training_type_plugin is not None:
+            plugin = self.distributed_backend.training_type_plugin
+        elif self.use_ddp2:
             plugin = DDP2Plugin(
                 parallel_devices=self.parallel_devices,
                 cluster_environment=self.cluster_environment,
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e60b86513e5ff..9c3bccb4d3283 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -453,8 +453,9 @@ class Prec(PrecisionPlugin):
     class TrainTypePlugin(SingleDevicePlugin):
         pass
 
+    ttp = TrainTypePlugin(device=torch.device("cpu"))
     accelerator = Accel(
-        training_type_plugin=TrainTypePlugin(device=torch.device("cpu")),
+        training_type_plugin=ttp,
         precision_plugin=Prec(),
     )
     trainer = Trainer(
@@ -465,6 +466,25 @@ class TrainTypePlugin(SingleDevicePlugin):
     assert isinstance(trainer.accelerator, Accel)
     assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
     assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer.accelerator_connector.training_type_plugin is ttp
+
+    class DistributedPlugin(DDPPlugin):
+        pass
+
+    ttp = DistributedPlugin()
+    accelerator = Accel(
+        training_type_plugin=ttp,
+        precision_plugin=Prec(),
+    )
+    trainer = Trainer(
+        accelerator=accelerator,
+        fast_dev_run=True,
+        num_processes=2,
+    )
+    assert isinstance(trainer.accelerator, Accel)
+    assert isinstance(trainer.training_type_plugin, DistributedPlugin)
+    assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer.accelerator_connector.training_type_plugin is ttp
 
 
 @mock.patch.dict(