From 92ea1a0fc966cfc4186ab3f8bb518e4aa2dbc128 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Tue, 23 Aug 2022 20:53:54 +0200 Subject: [PATCH 01/14] disable non-blocking for mps due to race condition bug --- src/pytorch_lightning/utilities/apply_func.py | 5 ++++- tests/tests_pytorch/accelerators/test_mps.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 8729520ee9d96..efab0a9f8f0c5 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -40,6 +40,7 @@ _CPU_DEVICES = ("cpu", torch.device("cpu")) +_MPS_DEVICES = ("mps", torch.device("mps")) def to_dtype_tensor( @@ -343,7 +344,9 @@ def batch_to(data: Any) -> Any: kwargs = {} # Don't issue non-blocking transfers to CPU - if isinstance(data, Tensor) and device not in _CPU_DEVICES: + # Don't issue non-blocking transfers to MPS due to race condition bug: + # https://github.com/pytorch/pytorch/issues/83015 + if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES: kwargs["non_blocking"] = True data_output = data.to(device, **kwargs) if data_output is not None: diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 01e13e937b4d0..30389f20f59e6 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -162,3 +162,14 @@ def to(self, *args, **kwargs): assert batch.text.type() == "torch.mps.LongTensor" assert batch.label.type() == "torch.mps.LongTensor" + + +@RunIf(mps=True) +def test_data_is_no_changed_after_move_to_mps_device(): + trainer = Trainer(accelerator="mps", devices=1) + x = torch.zeros([10, 10]) + device = torch.device("mps") + + for _ in range(1000): + x_mps = trainer.strategy.batch_to_device(x.clone(), device) + torch.testing.assert_close(x_mps, x) From 437a3668778ae3c5dc1f77f2a26395f25bef7740 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Tue, 23 Aug 2022 21:04:39 +0200 Subject: [PATCH 02/14] fixed typo --- tests/tests_pytorch/accelerators/test_mps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 30389f20f59e6..83155fd723a44 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -165,7 +165,7 @@ def to(self, *args, **kwargs): @RunIf(mps=True) -def test_data_is_no_changed_after_move_to_mps_device(): +def test_data_is_not_changed_after_move_to_mps_device(): trainer = Trainer(accelerator="mps", devices=1) x = torch.zeros([10, 10]) device = torch.device("mps") From 401f97a31dca27d82f2c06f7c12c5ad795137906 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Tue, 23 Aug 2022 21:27:45 +0200 Subject: [PATCH 03/14] fixed: unknown mps device for non arm systems --- src/pytorch_lightning/utilities/apply_func.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index efab0a9f8f0c5..03b7a8402c9ba 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -15,6 +15,7 @@ import dataclasses import operator +import platform from abc import ABC from collections import defaultdict, OrderedDict from collections.abc import Mapping, Sequence @@ -27,7 +28,7 @@ from torch import Tensor from 
pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY +from pytorch_lightning.utilities.imports import _compare_version, _TORCH_GREATER_EQUAL_1_12, _TORCHTEXT_LEGACY from pytorch_lightning.utilities.warnings import rank_zero_deprecation if _TORCHTEXT_LEGACY: @@ -40,7 +41,10 @@ _CPU_DEVICES = ("cpu", torch.device("cpu")) -_MPS_DEVICES = ("mps", torch.device("mps")) +if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"): + _MPS_DEVICES = ("mps", torch.device("mps")) +else: + _MPS_DEVICES = ("mps",) def to_dtype_tensor( From a112f9587f90ae6801b859130058cc6ffa62483b Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Tue, 23 Aug 2022 22:23:37 +0200 Subject: [PATCH 04/14] Removed unrobust test case --- tests/tests_pytorch/accelerators/test_mps.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 83155fd723a44..01e13e937b4d0 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -162,14 +162,3 @@ def to(self, *args, **kwargs): assert batch.text.type() == "torch.mps.LongTensor" assert batch.label.type() == "torch.mps.LongTensor" - - -@RunIf(mps=True) -def test_data_is_not_changed_after_move_to_mps_device(): - trainer = Trainer(accelerator="mps", devices=1) - x = torch.zeros([10, 10]) - device = torch.device("mps") - - for _ in range(1000): - x_mps = trainer.strategy.batch_to_device(x.clone(), device) - torch.testing.assert_close(x_mps, x) From 1e63b365ae7d9a7ced55f2fd4dfcfd4ce55b9302 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Tue, 23 Aug 2022 22:26:26 +0200 Subject: [PATCH 05/14] moved _MPS_DEVICES such that we used in apply_func --- src/pytorch_lightning/accelerators/mps.py | 9 +-------- src/pytorch_lightning/utilities/apply_func.py | 5 ++--- src/pytorch_lightning/utilities/imports.py | 6 ++++++ 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/accelerators/mps.py b/src/pytorch_lightning/accelerators/mps.py index 5ebcb37cd0ed7..ca8eb5394518d 100644 --- a/src/pytorch_lightning/accelerators/mps.py +++ b/src/pytorch_lightning/accelerators/mps.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import Any, Dict, List, Optional, Union import torch @@ -19,15 +18,9 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.imports import _MPS_AVAILABLE, _PSUTIL_AVAILABLE from pytorch_lightning.utilities.types import _DEVICE -# For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and -# the ARM-based Apple Silicon processors. 
-_MPS_AVAILABLE = ( - _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64") -) - class MPSAccelerator(Accelerator): """Accelerator for Metal Apple Silicon GPU devices.""" diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 03b7a8402c9ba..5acd3a9a735ad 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -15,7 +15,6 @@ import dataclasses import operator -import platform from abc import ABC from collections import defaultdict, OrderedDict from collections.abc import Mapping, Sequence @@ -28,7 +27,7 @@ from torch import Tensor from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version, _TORCH_GREATER_EQUAL_1_12, _TORCHTEXT_LEGACY +from pytorch_lightning.utilities.imports import _compare_version, _MPS_AVAILABLE, _TORCHTEXT_LEGACY from pytorch_lightning.utilities.warnings import rank_zero_deprecation if _TORCHTEXT_LEGACY: @@ -41,7 +40,7 @@ _CPU_DEVICES = ("cpu", torch.device("cpu")) -if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"): +if _MPS_AVAILABLE: _MPS_DEVICES = ("mps", torch.device("mps")) else: _MPS_DEVICES = ("mps",) diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index ba437ad332dfa..7cd060c4a1d37 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -157,6 +157,12 @@ def __repr__(self) -> str: _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() +# For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and +# the ARM-based Apple Silicon processors. 
+_MPS_AVAILABLE = ( + _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64") +) + if _POPTORCH_AVAILABLE: import poptorch From 02220a711bbfb693a5b295a2304d9cd3033b8875 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Aug 2022 00:04:37 +0200 Subject: [PATCH 06/14] Resolve circular dependencies --- src/pytorch_lightning/accelerators/cpu.py | 6 +++--- src/pytorch_lightning/accelerators/hpu.py | 5 +++-- src/pytorch_lightning/accelerators/ipu.py | 2 +- src/pytorch_lightning/accelerators/mps.py | 9 ++++++++- .../plugins/environments/xla_environment.py | 2 +- src/pytorch_lightning/utilities/apply_func.py | 3 ++- src/pytorch_lightning/utilities/device_parser.py | 3 ++- src/pytorch_lightning/utilities/imports.py | 6 ------ 8 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/accelerators/cpu.py b/src/pytorch_lightning/accelerators/cpu.py index fea8ee70d17df..d0981e7269305 100644 --- a/src/pytorch_lightning/accelerators/cpu.py +++ b/src/pytorch_lightning/accelerators/cpu.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import device_parser +from pytorch_lightning.utilities.device_parser import parse_cpu_cores from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE from pytorch_lightning.utilities.types import _DEVICE @@ -42,13 +42,13 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> int: """Accelerator device parsing logic.""" - devices = device_parser.parse_cpu_cores(devices) + devices = parse_cpu_cores(devices) return devices @staticmethod def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" - devices = device_parser.parse_cpu_cores(devices) + devices = parse_cpu_cores(devices) return [torch.device("cpu")] * devices @staticmethod diff --git a/src/pytorch_lightning/accelerators/hpu.py b/src/pytorch_lightning/accelerators/hpu.py index 8fc242fa55f20..c85e81756c2a9 100644 --- a/src/pytorch_lightning/accelerators/hpu.py +++ b/src/pytorch_lightning/accelerators/hpu.py @@ -17,8 +17,9 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser +from pytorch_lightning.utilities.device_parser import parse_hpus from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HPU_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_debug if _HPU_AVAILABLE: @@ -61,7 +62,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]: """Accelerator device parsing logic.""" - return device_parser.parse_hpus(devices) + return parse_hpus(devices) @staticmethod def get_parallel_devices(devices: int) -> List[torch.device]: diff --git a/src/pytorch_lightning/accelerators/ipu.py b/src/pytorch_lightning/accelerators/ipu.py index b5110e58028a5..b09fd33c29227 100644 --- a/src/pytorch_lightning/accelerators/ipu.py +++ b/src/pytorch_lightning/accelerators/ipu.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities 
import _IPU_AVAILABLE +from pytorch_lightning.utilities.imports import _IPU_AVAILABLE class IPUAccelerator(Accelerator): diff --git a/src/pytorch_lightning/accelerators/mps.py b/src/pytorch_lightning/accelerators/mps.py index ca8eb5394518d..5ebcb37cd0ed7 100644 --- a/src/pytorch_lightning/accelerators/mps.py +++ b/src/pytorch_lightning/accelerators/mps.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform from typing import Any, Dict, List, Optional, Union import torch @@ -18,9 +19,15 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _MPS_AVAILABLE, _PSUTIL_AVAILABLE +from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.types import _DEVICE +# For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and +# the ARM-based Apple Silicon processors. +_MPS_AVAILABLE = ( + _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64") +) + class MPSAccelerator(Accelerator): """Accelerator for Metal Apple Silicon GPU devices.""" diff --git a/src/pytorch_lightning/plugins/environments/xla_environment.py b/src/pytorch_lightning/plugins/environments/xla_environment.py index a78ebeb36a6a4..4072f6f8715f5 100644 --- a/src/pytorch_lightning/plugins/environments/xla_environment.py +++ b/src/pytorch_lightning/plugins/environments/xla_environment.py @@ -15,7 +15,7 @@ import os from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.imports import _TPU_AVAILABLE if _TPU_AVAILABLE: import torch_xla.core.xla_env_vars as xenv diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 7a07dbc2c7780..5091a51b040be 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -25,8 +25,9 @@ import torch from torch import Tensor +from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version, _MPS_AVAILABLE, _TORCHTEXT_LEGACY +from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY from pytorch_lightning.utilities.warnings import rank_zero_deprecation if _TORCHTEXT_LEGACY: diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py index c76933e489db7..d6c37bb626443 100644 --- a/src/pytorch_lightning/utilities/device_parser.py +++ b/src/pytorch_lightning/utilities/device_parser.py @@ -17,7 +17,6 @@ import torch import torch.cuda -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _DEVICE @@ -109,6 +108,8 @@ def parse_gpu_ids( gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) if not gpus: 
raise MisconfigurationException("GPUs requested but none are available.") + from pytorch_lightning.plugins.environments import TorchElasticEnvironment + if ( TorchElasticEnvironment.detect() and len(gpus) != 1 diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index 7cd060c4a1d37..ba437ad332dfa 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -157,12 +157,6 @@ def __repr__(self) -> str: _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() -# For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and -# the ARM-based Apple Silicon processors. -_MPS_AVAILABLE = ( - _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64") -) - if _POPTORCH_AVAILABLE: import poptorch From fa5ea1b3f8aa9b00954adbc4ff1444f8cde511a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Aug 2022 00:06:12 +0200 Subject: [PATCH 07/14] Comment rewording --- src/pytorch_lightning/utilities/apply_func.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 5091a51b040be..db5a707024a31 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -347,8 +347,7 @@ def batch_to(data: Any) -> Any: kwargs = {} # Don't issue non-blocking transfers to CPU - # Don't issue non-blocking transfers to MPS due to race condition bug: - # https://github.com/pytorch/pytorch/issues/83015 + # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015 if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES: kwargs["non_blocking"] = True data_output = data.to(device, **kwargs) From a81ee6f4b0217d5cccd244a5788aa42ed0a63e7a Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Wed, 24 Aug 2022 10:42:42 +0200 Subject: [PATCH 08/14] changed torchElasticEnvironment to a global import --- src/pytorch_lightning/utilities/device_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py index d6c37bb626443..9f036481687b7 100644 --- a/src/pytorch_lightning/utilities/device_parser.py +++ b/src/pytorch_lightning/utilities/device_parser.py @@ -17,6 +17,7 @@ import torch import torch.cuda +from pytorch_lightning.plugins.environments import TorchElasticEnvironment from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _DEVICE @@ -108,7 +109,6 @@ def parse_gpu_ids( gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") - from pytorch_lightning.plugins.environments import TorchElasticEnvironment if ( TorchElasticEnvironment.detect() From 8762196015f57b84f2efceb4f7332ab7d1b778e9 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Wed, 24 Aug 2022 10:46:46 +0200 Subject: [PATCH 09/14] simplified if statement to blocking device type --- src/pytorch_lightning/utilities/apply_func.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py 
b/src/pytorch_lightning/utilities/apply_func.py index db5a707024a31..0aa8a6f0dc37e 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -25,7 +25,6 @@ import torch from torch import Tensor -from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -39,11 +38,7 @@ Batch = type(None) -_CPU_DEVICES = ("cpu", torch.device("cpu")) -if _MPS_AVAILABLE: - _MPS_DEVICES = ("mps", torch.device("mps")) -else: - _MPS_DEVICES = ("mps",) +_BLOCKING_DEVICE_TYPES = ("cpu", "mps") def to_dtype_tensor( @@ -327,6 +322,9 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: - :class:`torch.device` """ + if isinstance(device, str): + device = torch.device(device) + def batch_to(data: Any) -> Any: # try to move torchtext data first if _TORCHTEXT_LEGACY and isinstance(data, Batch): @@ -348,7 +346,7 @@ def batch_to(data: Any) -> Any: kwargs = {} # Don't issue non-blocking transfers to CPU # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015 - if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES: + if isinstance(data, Tensor) and device.type not in _BLOCKING_DEVICE_TYPES: kwargs["non_blocking"] = True data_output = data.to(device, **kwargs) if data_output is not None: From c81b8c2814ab968bfab1189c53375f9f93b482be Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Wed, 24 Aug 2022 10:53:17 +0200 Subject: [PATCH 10/14] Added change to CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 07c34bbc0e579..9687f8da9c5a2 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296)) +- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285)) + + ## [1.7.2] - 2022-08-17 ### Added From 9ff968adca79b86d993aa8411008653e4765c19c Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:56:51 +0200 Subject: [PATCH 11/14] Update src/pytorch_lightning/utilities/apply_func.py --- src/pytorch_lightning/utilities/apply_func.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 0aa8a6f0dc37e..e17b989e3e121 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -324,6 +324,8 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: if isinstance(device, str): device = torch.device(device) + + assert isinstance(device, torch.device) def batch_to(data: Any) -> Any: # try to move torchtext data first From 3add8c939633fc2c0ece727b1670b441d904bdaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Aug 2022 09:58:26 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index e17b989e3e121..e863f2e091f34 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -324,7 +324,7 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: if isinstance(device, str): device = torch.device(device) - + assert isinstance(device, torch.device) def batch_to(data: Any) -> Any: From b89148db8241ea665c39168d695df603a81bbe78 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Wed, 24 Aug 2022 12:55:07 +0200 Subject: [PATCH 13/14] fixed mypy not detecting casting of device --- src/pytorch_lightning/utilities/apply_func.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index e863f2e091f34..9cc5c3047a3fe 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -325,9 +325,10 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: if isinstance(device, str): device = torch.device(device) - assert isinstance(device, torch.device) - def batch_to(data: Any) -> Any: + # Check must happen inside the inner function else mypy won't register the casting in the outer function. 
+ assert isinstance(device, torch.device) + # try to move torchtext data first if _TORCHTEXT_LEGACY and isinstance(data, Batch): # TODO: also remove the torchtext dependency with Lightning 1.8 From a89eaeb146e472356a3abd51b38e6b190bf1bc76 Mon Sep 17 00:00:00 2001 From: Jordi Smit Date: Thu, 25 Aug 2022 09:28:12 +0200 Subject: [PATCH 14/14] Moved check into if statement to mainain original behavior --- src/pytorch_lightning/utilities/apply_func.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index 9cc5c3047a3fe..757640b965092 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -326,9 +326,6 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: device = torch.device(device) def batch_to(data: Any) -> Any: - # Check must happen inside the inner function else mypy won't register the casting in the outer function. - assert isinstance(device, torch.device) - # try to move torchtext data first if _TORCHTEXT_LEGACY and isinstance(data, Batch): # TODO: also remove the torchtext dependency with Lightning 1.8 @@ -349,7 +346,7 @@ def batch_to(data: Any) -> Any: kwargs = {} # Don't issue non-blocking transfers to CPU # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015 - if isinstance(data, Tensor) and device.type not in _BLOCKING_DEVICE_TYPES: + if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES: kwargs["non_blocking"] = True data_output = data.to(device, **kwargs) if data_output is not None:
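
Illustrative sketch (not part of any patch above): the end state of this series is that `move_data_to_device` normalizes a string device into a `torch.device` and only requests an asynchronous copy when the target device type is outside `_BLOCKING_DEVICE_TYPES` ("cpu", and "mps" because of the race condition in https://github.com/pytorch/pytorch/issues/83015). Below is a minimal standalone approximation of that final logic, assuming only `torch` is installed; the helper name `move_tensor_to_device` is invented for the example and is not part of the Lightning API.

    import torch

    # Device types that must receive blocking transfers: CPU by convention,
    # MPS because non-blocking copies can produce incorrect values (pytorch#83015).
    _BLOCKING_DEVICE_TYPES = ("cpu", "mps")

    def move_tensor_to_device(data: torch.Tensor, device) -> torch.Tensor:
        # Normalize strings such as "cpu", "mps", or "cuda:0" into torch.device objects.
        if isinstance(device, str):
            device = torch.device(device)
        kwargs = {}
        # Only ask for a non-blocking copy when the target device type is safe for it.
        if isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
            kwargs["non_blocking"] = True
        return data.to(device, **kwargs)

    # Usage: a copy to MPS is issued as a blocking transfer, so the values on the
    # device match the source tensor.
    # x_mps = move_tensor_to_device(torch.zeros(10, 10), "mps")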