add xpu support
jingxu10 committed May 26, 2023
1 parent 0cc458e commit d2ccff8
Showing 14 changed files with 163 additions and 33 deletions.
6 changes: 5 additions & 1 deletion src/lightning/fabric/accelerators/__init__.py
@@ -16,7 +16,11 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator  # noqa: F401
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
-
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
+if _LIGHTNING_XPU_AVAILABLE:
+    if "xpu" not in ACCELERATOR_REGISTRY:
+        from lightning_xpu.fabric import XPUAccelerator
+        XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
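The effect of this guarded registration, as a minimal sketch (assuming the optional lightning-xpu package is installed; the registry is dict-like, which the `"xpu" not in ACCELERATOR_REGISTRY` check above relies on):

```python
# "xpu" appears among the registered accelerator names only when the
# lightning-xpu extension is installed; otherwise the registry is unchanged.
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY

print(list(ACCELERATOR_REGISTRY))  # e.g. ['cpu', 'cuda', 'mps', 'tpu', 'xpu']
if "xpu" in ACCELERATOR_REGISTRY:
    print("Intel XPU accelerator is registered")
```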
10 changes: 8 additions & 2 deletions src/lightning/fabric/cli.py
@@ -24,12 +24,15 @@
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
 from lightning.fabric.strategies import STRATEGY_REGISTRY
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE

 _log = logging.getLogger(__name__)

 _CLICK_AVAILABLE = RequirementCache("click")

-_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
+_SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"]
+if _LIGHTNING_XPU_AVAILABLE:
+    _SUPPORTED_ACCELERATORS.append("xpu")


 def _get_supported_strategies() -> List[str]:
@@ -146,13 +149,16 @@ def _set_env_variables(args: Namespace) -> None:
 def _get_num_processes(accelerator: str, devices: str) -> int:
     """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
     if accelerator == "gpu":
-        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
+        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_xpu=True)
     elif accelerator == "cuda":
         parsed_devices = CUDAAccelerator.parse_devices(devices)
     elif accelerator == "mps":
         parsed_devices = MPSAccelerator.parse_devices(devices)
     elif accelerator == "tpu":
         raise ValueError("Launching processes for TPU through the CLI is not supported.")
+    elif accelerator == "xpu":
+        from lightning_xpu.fabric import XPUAccelerator
+        parsed_devices = XPUAccelerator.parse_devices(devices)
     else:
         return CPUAccelerator.parse_devices(devices)
     return len(parsed_devices) if parsed_devices is not None else 0
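A sketch of the CLI-facing effect, assuming lightning-xpu is installed and the machine exposes two Intel GPUs (`_get_num_processes` is the private helper changed above, so this is illustration rather than public API):

```python
# The launcher now counts XPU processes the same way it does for CUDA;
# the device string "2" requests the first two devices.
from lightning.fabric.cli import _get_num_processes

num_procs = _get_num_processes(accelerator="xpu", devices="2")
print(num_procs)  # 2, i.e. one process per Intel GPU
```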
33 changes: 30 additions & 3 deletions src/lightning/fabric/connector.py
@@ -23,6 +23,7 @@
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
 from lightning.fabric.plugins import (
     CheckpointIO,
     DeepSpeedPrecision,
@@ -288,6 +289,13 @@ def _check_config_and_set_final_flags(
                         f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                     )
                 self._accelerator_flag = "cuda"
+            if self._strategy_flag.parallel_devices[0].type == "xpu":
+                if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"):
+                    raise ValueError(
+                        f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                        f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                    )
+                self._accelerator_flag = "xpu"
             self._parallel_devices = self._strategy_flag.parallel_devices

     def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
@@ -313,6 +321,11 @@ def _choose_auto_accelerator(self) -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if _LIGHTNING_XPU_AVAILABLE:
+            from lightning_xpu.fabric import XPUAccelerator
+            if XPUAccelerator.is_available():
+                return "xpu"
+
         return "cpu"

     @staticmethod
@@ -321,6 +334,10 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if _LIGHTNING_XPU_AVAILABLE:
+            from lightning_xpu.fabric import XPUAccelerator
+            if XPUAccelerator.is_available():
+                return "xpu"
         raise RuntimeError("No supported gpu backend found!")

     def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -378,8 +395,14 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             return "ddp"
         if len(self._parallel_devices) <= 1:
             # TODO: Change this once gpu accelerator was renamed to cuda accelerator
-            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+            supported_accelerators = [CUDAAccelerator, MPSAccelerator]
+            supported_accelerators_str = ["cuda", "gpu", "mps"]
+            if _LIGHTNING_XPU_AVAILABLE:
+                from lightning_xpu.fabric import XPUAccelerator
+                supported_accelerators.append(XPUAccelerator)
+                supported_accelerators_str.append("xpu")
+            if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in tuple(supported_accelerators_str)
             ):
                 device = _determine_root_gpu_device(self._parallel_devices)
             else:
@@ -462,7 +485,11 @@ def _check_and_init_precision(self) -> Precision:
             if self._precision_input == "16-mixed"
             else "Using bfloat16 Automatic Mixed Precision (AMP)"
         )
-        device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+        device = "cuda"
+        if self._accelerator_flag == "cpu":
+            device = "cpu"
+        elif self._accelerator_flag == "xpu":
+            device = "xpu"

         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]
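The net effect of the connector changes: with `accelerator="auto"`, selection now falls through MPS, then CUDA, then XPU, then CPU, and the AMP precision plugins receive the matching device string. A hedged sketch, assuming lightning-xpu is installed and an Intel GPU is the only accelerator present:

```python
from lightning.fabric import Fabric

# On a machine with neither MPS nor CUDA, "auto" should now resolve to "xpu".
fabric = Fabric(accelerator="auto", devices=1, precision="bf16-mixed")
print(fabric.device)  # e.g. device(type='xpu', index=0)
```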
9 changes: 6 additions & 3 deletions src/lightning/fabric/strategies/ddp.py
@@ -115,9 +115,12 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
-        # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
-        with ctx:
+        if self.root_device.type == "cuda":
+            # https://pytorch.org/docs/stable/notes/cuda.html#id5
+            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+            with ctx:
+                return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

     def module_to_device(self, module: Module) -> None:
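The new gate matters because `torch.cuda.stream` presumes a CUDA runtime: on an XPU or CPU root device, the side-stream optimization from the linked CUDA note must be skipped entirely. The same pattern in isolation, as a standalone sketch (the helper name is hypothetical):

```python
import torch
from contextlib import nullcontext

def _ddp_wrap_context(root_device: torch.device, device_ids):
    # Enter a CUDA side stream only when the root device really is CUDA;
    # every other backend (cpu, xpu) gets a no-op context instead.
    if root_device.type == "cuda" and device_ids is not None:
        return torch.cuda.stream(torch.cuda.Stream())
    return nullcontext()
```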
29 changes: 18 additions & 11 deletions src/lightning/fabric/utilities/device_parser.py
@@ -17,6 +17,7 @@
 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from lightning.fabric.utilities.types import _DEVICE
+from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE


 def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
@@ -49,6 +50,7 @@ def _parse_gpu_ids(
     gpus: Optional[Union[int, str, List[int]]],
     include_cuda: bool = False,
     include_mps: bool = False,
+    include_xpu: bool = False,
 ) -> Optional[List[int]]:
     """
     Parses the GPU IDs given in the format as accepted by the
@@ -62,6 +64,7 @@
         Any int N > 0 indicates that GPUs [0..N) should be used.
         include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
         include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
+        include_xpu: A boolean value indicating whether to include Intel GPU devices for GPU parsing.
@@ -71,7 +74,7 @@
             If no GPUs are available but the value of gpus variable indicates request for GPUs

     .. note::
-        ``include_cuda`` and ``include_mps`` default to ``False`` so that you only
+        ``include_cuda``, ``include_mps`` and ``include_xpu`` default to ``False`` so that you only
         have to specify which device type to use and all other devices are not disabled.
     """
     # Check that gpus param is None, Int, String or Sequence of Ints
@@ -84,22 +87,22 @@
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")

     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus

     # Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
     _check_unique(gpus)

-    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)


 def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
@@ -112,7 +115,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
     return int(s.strip())


-def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
+def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
     """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
     the GPUs is not available.
@@ -126,9 +129,9 @@ def _sanitize_gpu_ids(
         MisconfigurationException:
             If machine has fewer available GPUs than requested.
     """
-    if sum((include_cuda, include_mps)) == 0:
+    if sum((include_cuda, include_mps, include_xpu)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -138,7 +141,7 @@


 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
+    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool,
 ) -> Optional[List[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -148,19 +151,23 @@
     if not gpus:  # gpus==0
         return None
     if gpus == -1:
-        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)

     return list(range(gpus))


-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
+def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False) -> List[int]:
     """
     Returns:
         A list of all available GPUs
     """
     cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else []
     mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else []
-    return cuda_gpus + mps_gpus
+    xpu_gpus = []
+    if _LIGHTNING_XPU_AVAILABLE:
+        import lightning_xpu.fabric as accelerator_xpu
+        xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []
+    return cuda_gpus + mps_gpus + xpu_gpus


 def _check_unique(device_ids: List[int]) -> None:
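With the flag threaded through every helper, callers opt in to Intel GPUs explicitly. A sketch using the private parser (results depend on the machine):

```python
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

# Request all visible devices (-1), considering CUDA and XPU but not MPS.
ids = _parse_gpu_ids(-1, include_cuda=True, include_xpu=True)
print(ids)  # e.g. [0, 1]; raises MisconfigurationException if none are available
```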
9 changes: 7 additions & 2 deletions src/lightning/fabric/utilities/distributed.py
@@ -217,7 +217,7 @@ def _init_dist_connection(
     Args:
         cluster_environment: ``ClusterEnvironment`` instance
-        torch_distributed_backend: Backend to use (includes `nccl` and `gloo`)
+        torch_distributed_backend: Backend to use (includes `nccl`, `gloo` and `ccl`)
         global_rank: Rank of the current process
         world_size: Number of processes in the group
         kwargs: Kwargs for ``init_process_group``
@@ -248,7 +248,12 @@


 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    if device.type == "cuda":
+        return "nccl"
+    elif device.type == "xpu":
+        return "ccl"
+    else:
+        return "gloo"


 class _DatasetSamplerWrapper(Dataset):
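The backend choice can be exercised directly; `ccl` assumes Intel's oneCCL bindings for PyTorch are installed on XPU machines:

```python
import torch
from lightning.fabric.utilities.distributed import _get_default_process_group_backend_for_device

print(_get_default_process_group_backend_for_device(torch.device("cuda", 0)))  # nccl
print(_get_default_process_group_backend_for_device(torch.device("xpu", 0)))   # ccl
print(_get_default_process_group_backend_for_device(torch.device("cpu")))      # gloo
```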
10 changes: 6 additions & 4 deletions src/lightning/fabric/utilities/imports.py
@@ -16,7 +16,7 @@
 import platform
 import sys

-from lightning_utilities.core.imports import compare_version
+from lightning_utilities.core.imports import compare_version, RequirementCache

 _IS_WINDOWS = platform.system() == "Windows"

@@ -25,8 +25,10 @@
 # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

-_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
-_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
-_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
+_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
 _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
 _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
+
+_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu")
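`RequirementCache` evaluates the requirement lazily and is truthy only when the distribution is importable, which is why the flag can be checked cheaply at many call sites:

```python
from lightning_utilities.core.imports import RequirementCache

_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu")

# Truthiness reflects whether `lightning-xpu` is installed, so optional imports
# can be guarded without a try/except at every call site.
if _LIGHTNING_XPU_AVAILABLE:
    from lightning_xpu.fabric import XPUAccelerator
```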
34 changes: 34 additions & 0 deletions src/lightning/pytorch/_graveyard/xpu.py
@@ -0,0 +1,34 @@
+import sys
+from typing import Any
+
+import lightning.pytorch as pl
+
+
+def _patch_sys_modules() -> None:
+    self = sys.modules[__name__]
+    sys.modules["lightning.pytorch.accelerators.xpu"] = self
+
+
+class XPUAccelerator:
+    auto_device_count = ...
+    get_parallel_devices = ...
+    is_available = ...
+    parse_devices = ...
+    setup_device = ...
+    teardown = ...
+
+    def __init__(self, *_: Any, **__: Any) -> None:
+        raise NotImplementedError(
+            "The `XPUAccelerator` class has been moved to an external package."
+            " Install the extension package as `pip install lightning-xpu`"
+            " and import with `from lightning_xpu import XPUAccelerator`."
+            " Please see: https://github.com/Lightning-AI/lightning-XPU for more details."
+        )
+
+
+def _patch_classes() -> None:
+    setattr(pl.accelerators, "XPUAccelerator", XPUAccelerator)
+
+
+_patch_sys_modules()
+_patch_classes()
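The shim keeps the old import location alive but fails loudly with migration instructions. A sketch of its observable behavior (assuming Lightning imports the graveyard module on package import, as it does for its other shims):

```python
import lightning.pytorch as pl
import lightning.pytorch._graveyard.xpu  # idempotent: installs the patched class

try:
    pl.accelerators.XPUAccelerator()
except NotImplementedError as err:
    print(err)  # points users to `pip install lightning-xpu`
```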
9 changes: 6 additions & 3 deletions src/lightning/pytorch/strategies/ddp.py
@@ -184,9 +184,12 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
-        with ctx:
+        if self.root_device.type == "cuda":
+            # https://pytorch.org/docs/stable/notes/cuda.html#id5
+            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+            with ctx:
+                return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        else:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

     def setup_distributed(self) -> None:
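Taken together with the registry, connector, and distributed-backend changes, the intended user-facing surface would look roughly like the following; this is a hedged sketch, since full `Trainer` support depends on the remaining files of this commit and on lightning-xpu being installed:

```python
import lightning.pytorch as pl

# Hypothetical end-to-end usage: "xpu" resolves through the accelerator
# registry, DDP selects the "ccl" backend, and no CUDA stream context is used.
trainer = pl.Trainer(accelerator="xpu", devices=2, strategy="ddp")
```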