diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 093065394b337..724b5b6f244c1 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Dict, Union import torch @@ -19,7 +19,7 @@ import pytorch_lightning as pl -class Accelerator: +class Accelerator(ABC): """The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. Currently there are accelerators for: @@ -45,7 +45,7 @@ def setup(self, trainer: "pl.Trainer") -> None: """ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: - """Gets stats for a given device. + """Get stats for a given device. Args: device: device for which to get stats @@ -58,4 +58,4 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod @abstractmethod def auto_device_count() -> int: - """Get the devices when set to auto.""" + """Get the device count when set to auto.""" diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 10d1caec5db21..3e2ec15216841 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -336,7 +336,9 @@ def creates_processes_externally(self) -> bool: @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): - pass + @staticmethod + def auto_device_count() -> int: + return 1 class Prec(PrecisionPlugin): pass