From 9e3587bd61e797dacbf0611abb01892f76f6df56 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 23 Jun 2021 11:09:50 +0100
Subject: [PATCH 1/4] Add torchelastic check

---
 pytorch_lightning/utilities/device_parser.py |  7 +++++++
 tests/models/test_gpu.py                     | 24 ++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index ecb5d6ac00a03..51483fc568ce9 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -16,6 +16,7 @@
 
 import torch
 
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _compare_version
@@ -78,6 +79,12 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
     gpus = _normalize_parse_gpu_input_to_list(gpus)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
+
+    if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
+        # omit the sanity check on torchelastic, as it exposes
+        # only one visible GPU per process by default
+        return gpus
+
     gpus = _sanitize_gpu_ids(gpus)
     return gpus
 
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 65a1e093a9e96..5636887eb1d66 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import operator
+import os
 from collections import namedtuple
 from unittest.mock import patch
 
+import mock
 import pytest
 import torch
 
@@ -329,3 +331,25 @@ def to(self, *args, **kwargs):
     with patch.object(batch, 'to', wraps=batch.to) as mocked:
         batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
         mocked.assert_called_with(torch.device('cuda', 0))
+
+
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "1",
+        "RANK": "3",
+        "WORLD_SIZE": "4",
+        "LOCAL_WORLD_SIZE": "2",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=1)
+@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
+def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
+    """
+    Ensure that when using torchelastic with nproc_per_node set to the default of 1,
+    we omit sanitizing the gpus, as only one of the GPUs is visible per process.
+    """
+    trainer = Trainer(gpus=gpus)
+    assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus)
+    assert trainer.gpus == gpus
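The guard added in PATCH 1/4 is easiest to see in isolation. Below is a minimal, self-contained sketch of the same logic: `_is_using_torchelastic`, `_all_available_gpus` and `parse_gpu_ids_sketch` are hypothetical stand-ins for the PyTorch Lightning internals (`TorchElasticEnvironment.is_using_torchelastic`, `_get_all_available_gpus`, `parse_gpu_ids`), and the environment-variable heuristic is an assumption about how torchelastic detection works, not the library's actual implementation.

```python
import os
from typing import List, Optional

import torch


def _is_using_torchelastic() -> bool:
    # Assumed heuristic: torchelastic exports these variables for every
    # worker it launches. Stand-in for
    # TorchElasticEnvironment.is_using_torchelastic(), not the library's code.
    required = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
    return required.issubset(os.environ)


def _all_available_gpus() -> List[int]:
    # Stand-in for _get_all_available_gpus(): ids of the GPUs this process sees.
    return list(range(torch.cuda.device_count()))


def parse_gpu_ids_sketch(gpus: List[int]) -> Optional[List[int]]:
    # With nproc_per_node=1 each elastic worker is pinned to a single GPU via
    # CUDA_VISIBLE_DEVICES, so a request such as [0, 1, 2] names devices on the
    # whole node, not devices visible to this one process. The per-process
    # sanity check would then raise spuriously, so it is skipped.
    if _is_using_torchelastic() and len(gpus) != 1 and len(_all_available_gpus()) == 1:
        return gpus
    # Outside torchelastic, every requested id must be visible to this process.
    available = _all_available_gpus()
    for gpu in gpus:
        if gpu not in available:
            raise ValueError(f"You requested GPU {gpu}, but only {available} are available.")
    return gpus
```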
From b5a800137906470655153c3024c012c6c9d8f8c9 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 23 Jun 2021 11:13:07 +0100
Subject: [PATCH 2/4] Add changelog

---
 CHANGELOG.md             | 3 +++
 tests/models/test_gpu.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c1ac8a689ce15..f562e7fad0a49 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))
 
 
+- Add torchelastic check when sanitizing GPUs ([#8095](https://github.com/PyTorchLightning/pytorch-lightning/pull/8095))
+
+
 ### Changed
 
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 5636887eb1d66..6da3225435e18 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -347,7 +347,7 @@ def to(self, *args, **kwargs):
 @pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
 def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
     """
-    Ensure that when using torchelastic with nproc_per_node set to the default of 1,
+    Ensure that when using torchelastic with nproc_per_node set to the default of 1 process per GPU device,
     we omit sanitizing the gpus, as only one of the GPUs is visible per process.
     """
     trainer = Trainer(gpus=gpus)

From da2aa1ada7b498bd5f48df86cf5e30e344044997 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 23 Jun 2021 11:35:56 +0100
Subject: [PATCH 3/4] Address review

---
 pytorch_lightning/utilities/device_parser.py |  3 +--
 tests/models/test_gpu.py                     | 46 ++++++++++----------
 2 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index 51483fc568ce9..bfbd1847a3092 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -81,8 +81,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
         raise MisconfigurationException("GPUs requested but none are available.")
 
     if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
-        # omit the sanity check on torchelastic, as it exposes
-        # only one visible GPU per process by default
+        # omit the sanity check on torchelastic, as it exposes only one visible GPU per process by default
        return gpus
 
     gpus = _sanitize_gpu_ids(gpus)
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 6da3225435e18..107c05e8e17f3 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -23,6 +23,7 @@
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _compare_version
@@ -221,6 +222,29 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun
         device_parser.parse_gpu_ids(gpus)
 
 
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "1",
+        "RANK": "3",
+        "WORLD_SIZE": "4",
+        "LOCAL_WORLD_SIZE": "2",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=1)
+@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
+def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
+    """
+    Ensure that when using torchelastic with nproc_per_node set to the default of 1 process per GPU device,
+    we omit sanitizing the gpus, as only one of the GPUs is visible per process.
+    """
+    trainer = Trainer(gpus=gpus)
+    assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment)
+    assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus)
+    assert trainer.gpus == gpus
+
+
 @RunIf(min_gpus=1)
 def test_single_gpu_batch_parse():
     trainer = Trainer(gpus=1)
@@ -331,25 +355,3 @@ def to(self, *args, **kwargs):
     with patch.object(batch, 'to', wraps=batch.to) as mocked:
         batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
         mocked.assert_called_with(torch.device('cuda', 0))
-
-
-@mock.patch.dict(
-    os.environ, {
-        "CUDA_VISIBLE_DEVICES": "0",
-        "LOCAL_RANK": "1",
-        "GROUP_RANK": "1",
-        "RANK": "3",
-        "WORLD_SIZE": "4",
-        "LOCAL_WORLD_SIZE": "2",
-    }
-)
-@mock.patch('torch.cuda.device_count', return_value=1)
-@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
-def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
-    """
-    Ensure that when using torchelastic with nproc_per_node set to the default of 1 process per GPU device,
-    we omit sanitizing the gpus, as only one of the GPUs is visible per process.
-    """
-    trainer = Trainer(gpus=gpus)
-    assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus)
-    assert trainer.gpus == gpus
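The test above simulates a torchelastic worker purely through environment variables. Here is a minimal, self-contained illustration of that technique, using the same variables the test injects; the per-variable interpretations in the comments are assumptions about the torchelastic launcher's conventions rather than something the patch itself asserts.

```python
import os
from unittest import mock

# Assumed reading of the values: the second worker on the second node of a
# 2-node x 2-worker elastic job, pinned to a single GPU.
ELASTIC_WORKER_ENV = {
    "CUDA_VISIBLE_DEVICES": "0",  # this worker sees exactly one GPU
    "LOCAL_RANK": "1",            # worker index within its node
    "GROUP_RANK": "1",            # index of the node (worker group)
    "LOCAL_WORLD_SIZE": "2",      # workers per node
    "RANK": "3",                  # global index: GROUP_RANK * LOCAL_WORLD_SIZE + LOCAL_RANK
    "WORLD_SIZE": "4",            # total workers: 2 nodes * 2 per node
}

with mock.patch.dict(os.environ, ELASTIC_WORKER_ENV):
    # Inside the context this process looks like an elastic worker, so code
    # that inspects os.environ (such as the torchelastic check in
    # parse_gpu_ids) takes the early-return branch.
    assert os.environ["RANK"] == "3"
# On exit, mock.patch.dict restores the previous contents of os.environ.
```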
+ """ + trainer = Trainer(gpus=gpus) + assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment) + assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus) + assert trainer.gpus == gpus + + @RunIf(min_gpus=1) def test_single_gpu_batch_parse(): trainer = Trainer(gpus=1) @@ -331,25 +355,3 @@ def to(self, *args, **kwargs): with patch.object(batch, 'to', wraps=batch.to) as mocked: batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) mocked.assert_called_with(torch.device('cuda', 0)) - - -@mock.patch.dict( - os.environ, { - "CUDA_VISIBLE_DEVICES": "0", - "LOCAL_RANK": "1", - "GROUP_RANK": "1", - "RANK": "3", - "WORLD_SIZE": "4", - "LOCAL_WORLD_SIZE": "2", - } -) -@mock.patch('torch.cuda.device_count', return_value=1) -@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0']) -def test_torchelastic_gpu_parsing(mocked_device_count, gpus): - """ - Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device - That we omit sanitizing the gpus as only one of the GPUs is visible. - """ - trainer = Trainer(gpus=gpus) - assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus) - assert trainer.gpus == gpus From a09e1e9296ed6a0f83013e569e5cf252a2d7d51b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 23 Jun 2021 12:35:52 +0100 Subject: [PATCH 4/4] fix --- tests/models/test_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 107c05e8e17f3..cd7c90552ab2e 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -14,9 +14,9 @@ import operator import os from collections import namedtuple +from unittest import mock from unittest.mock import patch -import mock import pytest import torch