Disable non blocking to device with MPS #14368

Merged: 19 commits, Aug 26, 2022
Commits (19)
92ea1a0  disable non-blocking for mps due to race condition bug (j0rd1smit, Aug 23, 2022)
437a366  fixed typo (j0rd1smit, Aug 23, 2022)
401f97a  fixed: unknown mps device for non arm systems (j0rd1smit, Aug 23, 2022)
a112f95  Removed unrobust test case (j0rd1smit, Aug 23, 2022)
1e63b36  moved _MPS_DEVICES such that we used in apply_func (j0rd1smit, Aug 23, 2022)
e56026d  Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (j0rd1smit, Aug 23, 2022)
02220a7  Resolve circular dependencies (carmocca, Aug 23, 2022)
fa5ea1b  Comment rewording (carmocca, Aug 23, 2022)
a81ee6f  changed torchElasticEnvironment to a global import (j0rd1smit, Aug 24, 2022)
8762196  simplified if statement to blocking device type (j0rd1smit, Aug 24, 2022)
c81b8c2  Added change to CHANGELOG (j0rd1smit, Aug 24, 2022)
9ff968a  Update src/pytorch_lightning/utilities/apply_func.py (justusschock, Aug 24, 2022)
3add8c9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 24, 2022)
b89148d  fixed mypy not detecting casting of device (j0rd1smit, Aug 24, 2022)
a89eaeb  Moved check into if statement to mainain original behavior (j0rd1smit, Aug 25, 2022)
39af0b9  Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (justusschock, Aug 25, 2022)
523b3a8  Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (Borda, Aug 26, 2022)
f364736  Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (carmocca, Aug 26, 2022)
fa87e28  Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (carmocca, Aug 26, 2022)
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -98,6 +98,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
 
 
+- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))
+
+
 - Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))
 
 
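The entry references issue [#13285](https://github.com/Lightning-AI/lightning/issues/13285) and, indirectly, the upstream race condition [pytorch/pytorch#83015](https://github.com/pytorch/pytorch/issues/83015): tensors copied to the `mps` backend with `non_blocking=True` could end up holding incorrect values. Below is a minimal, hypothetical sketch of the kind of transfer affected, not the reproduction script from the issue:

```python
import torch

# Hypothetical illustration only: the asynchronous copy below is what Lightning
# used to issue and what this PR disables for MPS targets.
if torch.backends.mps.is_available():
    x = torch.randn(1024)
    y_async = x.to(torch.device("mps"), non_blocking=True)  # could race and yield wrong values
    y_sync = x.to(torch.device("mps"))                       # behavior after this PR: blocking copy
    print(torch.equal(y_sync.cpu(), x))
```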
6 changes: 3 additions & 3 deletions src/pytorch_lightning/accelerators/cpu.py
@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities.device_parser import parse_cpu_cores
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
 from pytorch_lightning.utilities.types import _DEVICE
@@ -42,13 +42,13 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> int:
         """Accelerator device parsing logic."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return devices
 
     @staticmethod
     def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
 
     @staticmethod
5 changes: 3 additions & 2 deletions src/pytorch_lightning/accelerators/hpu.py
@@ -17,8 +17,9 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser
+from pytorch_lightning.utilities.device_parser import parse_hpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug
 
 if _HPU_AVAILABLE:
@@ -61,7 +62,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]:
         """Accelerator device parsing logic."""
-        return device_parser.parse_hpus(devices)
+        return parse_hpus(devices)
 
     @staticmethod
     def get_parallel_devices(devices: int) -> List[torch.device]:
2 changes: 1 addition & 1 deletion src/pytorch_lightning/accelerators/ipu.py
@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
 
 
 class IPUAccelerator(Accelerator):
@@ -15,7 +15,7 @@
 import os
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
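The four import changes above implement the "Resolve circular dependencies" commit: each name is now imported from the submodule that defines it (`pytorch_lightning.utilities.device_parser`, `pytorch_lightning.utilities.imports`) rather than from the `pytorch_lightning.utilities` package, so resolution no longer depends on the names re-exported by the package `__init__` being bound at import time. A small sketch of the pattern; the usage lines are hypothetical illustrations rather than code from this PR:

```python
# Import directly from the defining submodules, as the diffs above do, instead of
# relying on re-exports like `from pytorch_lightning.utilities import device_parser`.
from pytorch_lightning.utilities.device_parser import parse_cpu_cores
from pytorch_lightning.utilities.imports import _IPU_AVAILABLE

print(parse_cpu_cores(2))  # hypothetical call: validate and return a CPU core count
print(_IPU_AVAILABLE)      # availability flag, imported without the package re-export
```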
8 changes: 6 additions & 2 deletions src/pytorch_lightning/utilities/apply_func.py
@@ -38,7 +38,7 @@
 Batch = type(None)
 
 
-_CPU_DEVICES = ("cpu", torch.device("cpu"))
+_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
 
 
 def to_dtype_tensor(
@@ -322,6 +322,9 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
         - :class:`torch.device`
     """
 
+    if isinstance(device, str):
+        device = torch.device(device)
+
     def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_LEGACY and isinstance(data, Batch):
@@ -342,7 +345,8 @@ def batch_to(data: Any) -> Any:
 
         kwargs = {}
         # Don't issue non-blocking transfers to CPU
-        if isinstance(data, Tensor) and device not in _CPU_DEVICES:
+        # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
+        if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
             kwargs["non_blocking"] = True
         data_output = data.to(device, **kwargs)
         if data_output is not None:
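Net effect of the `apply_func.py` changes: `move_data_to_device` normalizes string devices to `torch.device` and only requests `non_blocking=True` when the target device type is neither `cpu` nor `mps`. A standalone sketch of that decision logic (an approximation for illustration, not the Lightning implementation itself):

```python
from typing import Union

import torch

_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def to_device(data: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
    """Mirror of the check above: async copies only where they are known to be safe."""
    if isinstance(device, str):
        device = torch.device(device)
    kwargs = {}
    # CPU transfers gain nothing from non_blocking; MPS is excluded because of
    # the race condition linked in the diff (pytorch/pytorch#83015).
    if isinstance(data, torch.Tensor) and device.type not in _BLOCKING_DEVICE_TYPES:
        kwargs["non_blocking"] = True  # e.g. CUDA: overlap the copy with computation
    return data.to(device, **kwargs)


x = torch.ones(4)
print(to_device(x, "cpu").device)  # blocking copy
if torch.cuda.is_available():
    print(to_device(x, "cuda").device)  # non-blocking copy
```

Matching on `device.type`, per the "simplified if statement to blocking device type" commit, also keeps indexed forms such as `torch.device("cpu", 0)` on the blocking path, which a membership test against a fixed tuple of device objects could miss.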
1 change: 1 addition & 0 deletions src/pytorch_lightning/utilities/device_parser.py
@@ -93,6 +93,7 @@ def parse_gpu_ids(
     gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
+
     if (
         TorchElasticEnvironment.detect()
         and len(gpus) != 1