diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py
index f4b258f60b655..54b6ac16dc992 100644
--- a/src/lightning/fabric/accelerators/__init__.py
+++ b/src/lightning/fabric/accelerators/__init__.py
@@ -17,10 +17,11 @@
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
 from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+
 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
-if _LIGHTNING_XPU_AVAILABLE:
-    if "xpu" not in ACCELERATOR_REGISTRY:
-        from lightning_xpu.fabric import XPUAccelerator
-        XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
+if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY:
+    from lightning_xpu.fabric import XPUAccelerator
+
+    XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index 323299f7c4c8f..bebf197581ba8 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -158,6 +158,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
         raise ValueError("Launching processes for TPU through the CLI is not supported.")
     elif accelerator == "xpu":
         from lightning_xpu.fabric import XPUAccelerator
+
         parsed_devices = XPUAccelerator.parse_devices(devices)
     else:
         return CPUAccelerator.parse_devices(devices)
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index c31470d374b12..931e996344e20 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -23,7 +23,6 @@
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
-from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 from lightning.fabric.plugins import (
     CheckpointIO,
     DeepSpeedPrecision,
@@ -64,7 +63,7 @@
 from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
 from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
-from lightning.fabric.utilities.imports import _IS_INTERACTIVE
+from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE

 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -323,6 +322,7 @@ def _choose_auto_accelerator(self) -> str:
                 return "cuda"
             if _LIGHTNING_XPU_AVAILABLE:
                 from lightning_xpu.fabric import XPUAccelerator
+
                 if XPUAccelerator.is_available():
                     return "xpu"

@@ -336,6 +336,7 @@ def _choose_gpu_accelerator_backend() -> str:
             return "cuda"
         if _LIGHTNING_XPU_AVAILABLE:
             from lightning_xpu.fabric import XPUAccelerator
+
             if XPUAccelerator.is_available():
                 return "xpu"
         raise RuntimeError("No supported gpu backend found!")
@@ -399,6 +400,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
            supported_accelerators_str = ["cuda", "gpu", "mps"]
            if _LIGHTNING_XPU_AVAILABLE:
                from lightning_xpu.fabric import XPUAccelerator
+
                supported_accelerators.append(XPUAccelerator)
                supported_accelerators_str.append("xpu")
            if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or (
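Note: the hunks above all rely on the same guarded-import pattern: the optional lightning_xpu package is imported only behind the _LIGHTNING_XPU_AVAILABLE flag, so environments without the integration never touch it. Below is a minimal, self-contained sketch of that pattern; only lightning_xpu.fabric.XPUAccelerator and its is_available() check come from the diff, everything else here is an illustrative stand-in.

from importlib.util import find_spec

# Stand-in for the _LIGHTNING_XPU_AVAILABLE flag: true only when the optional
# `lightning_xpu` package can be imported in the current environment.
_XPU_PACKAGE_AVAILABLE = find_spec("lightning_xpu") is not None


def choose_gpu_backend() -> str:
    # Illustrative fallback chain; the real connector checks MPS and CUDA first.
    if _XPU_PACKAGE_AVAILABLE:
        from lightning_xpu.fabric import XPUAccelerator  # imported lazily, as in the diff

        if XPUAccelerator.is_available():
            return "xpu"
    return "cpu"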
diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py
index 7cd7c53296928..eca550d344e21 100644
--- a/src/lightning/fabric/utilities/device_parser.py
+++ b/src/lightning/fabric/utilities/device_parser.py
@@ -16,8 +16,8 @@
 import lightning.fabric.accelerators as accelerators  # avoid circular dependency
 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
 from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.types import _DEVICE
 from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
+from lightning.fabric.utilities.types import _DEVICE


 def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
@@ -87,14 +87,17 @@ def _parse_gpu_ids(
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    gpus = _normalize_parse_gpu_input_to_list(
+        gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")

     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
+        == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus
@@ -115,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in
         return int(s.strip())


-def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _sanitize_gpu_ids(
+    gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """Checks that each of the GPUs in the list is actually available.

     Raises a MisconfigurationException if any of the GPUs is not available.
@@ -131,7 +136,9 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
     """
    if sum((include_cuda, include_mps, include_xpu)) == 0:
        raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
+    all_available_gpus = _get_all_available_gpus(
+        include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -141,7 +148,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:


 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool,
+    gpus: Union[int, List[int], Tuple[int, ...]],
+    include_cuda: bool,
+    include_mps: bool,
+    include_xpu: bool,
 ) -> Optional[List[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -156,7 +166,9 @@ def _normalize_parse_gpu_input_to_list(
         return list(range(gpus))


-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
+def _get_all_available_gpus(
+    include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """
     Returns:
         A list of all available GPUs
@@ -166,6 +178,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals
     xpu_gpus = []
     if _LIGHTNING_XPU_AVAILABLE:
         import lightning_xpu.fabric as accelerator_xpu
+
         xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []
     return cuda_gpus + mps_gpus + xpu_gpus
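The device_parser.py hunks above are formatting-only, but they trace the flow worth keeping in mind: a raw devices value is normalized to a list of indices and then sanitized against the union of visible CUDA, MPS and (when include_xpu is set) XPU devices. A rough, self-contained sketch of that flow, with a hard-coded `available` list standing in for _get_all_available_gpus():

from typing import List, Optional, Tuple, Union


def normalize(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]:
    # Mirrors _normalize_parse_gpu_input_to_list: lists/tuples pass through,
    # -1 means "all visible devices", a positive int means "the first N devices".
    if isinstance(gpus, (list, tuple)):
        return [int(g) for g in gpus]
    if gpus == -1:
        return list(range(4))  # pretend 4 devices are visible in this sketch
    return list(range(gpus)) if gpus > 0 else None


def sanitize(gpus: List[int], available: List[int]) -> List[int]:
    # Mirrors _sanitize_gpu_ids: every requested index must actually be visible.
    for gpu in gpus:
        if gpu not in available:
            raise ValueError(f"Requested GPU {gpu}, but only {available} are visible.")
    return gpus


print(sanitize(normalize(2), available=[0, 1, 2, 3]))  # -> [0, 1]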
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 0e510fbd95a39..382f916f01c96 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -447,13 +447,12 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             from lightning_habana import SingleHPUStrategy

             return SingleHPUStrategy(device=torch.device("hpu"))
-        if self._accelerator_flag == "xpu":
-            if not _LIGHTNING_XPU_AVAILABLE:
-                raise ImportError(
-                    "You have asked for XPU but you miss install related integration."
-                    " Please run `pip install lightning-xpu` or see for further instructions"
-                    " in https://github.com/Lightning-AI/lightning-XPU/."
-                )
+        if self._accelerator_flag == "xpu" and not _LIGHTNING_XPU_AVAILABLE:
+            raise ImportError(
+                "You have asked for XPU but you miss install related integration."
+                " Please run `pip install lightning-xpu` or see for further instructions"
+                " in https://github.com/Lightning-AI/lightning-XPU/."
+            )
         if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator):
             if self._parallel_devices and len(self._parallel_devices) > 1:
                 return XLAStrategy.strategy_name
diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py
index 58ad7a84ee04f..65c3789b30a0f 100644
--- a/src/lightning/pytorch/trainer/setup.py
+++ b/src/lightning/pytorch/trainer/setup.py
@@ -28,7 +28,11 @@
     XLAProfiler,
 )
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE, _LIGHTNING_XPU_AVAILABLE
+from lightning.pytorch.utilities.imports import (
+    _LIGHTNING_GRAPHCORE_AVAILABLE,
+    _LIGHTNING_HABANA_AVAILABLE,
+    _LIGHTNING_XPU_AVAILABLE,
+)
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
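Taken together, the registration in fabric/accelerators/__init__.py and the checks in the two connectors make "xpu" selectable like any other accelerator string once the optional lightning-xpu package is installed, and raise the ImportError above (pointing at `pip install lightning-xpu`) otherwise. A hedged usage sketch, not a tested recipe; behaviour ultimately depends on the lightning-xpu integration:

from lightning.fabric import Fabric
from lightning.pytorch import Trainer

# Explicit selection: resolved through ACCELERATOR_REGISTRY in Fabric, and guarded
# by the ImportError above in the Trainer when lightning-xpu is not installed.
fabric = Fabric(accelerator="xpu", devices=1)
trainer = Trainer(accelerator="xpu", devices=1)

# Auto selection: _choose_auto_accelerator falls through to "xpu" only when
# _LIGHTNING_XPU_AVAILABLE is set and XPUAccelerator.is_available() is True.
fabric_auto = Fabric(accelerator="auto", devices=1)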