From bf44560860c9cce924657d52efe1242b199123a3 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 26 Nov 2020 18:04:21 +0100
Subject: [PATCH 1/5] xla

---
 docs/source/conf.py                                      | 2 +-
 pytorch_lightning/accelerators/accelerator_connector.py  | 9 +--------
 pytorch_lightning/utilities/__init__.py                  | 1 +
 pytorch_lightning/utilities/xla_device_utils.py          | 9 ++++-----
 tests/utilities/test_xla_device_utils.py                 | 7 ++-----
 5 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index b05234fd16628..4a22f012c99cf 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -360,8 +360,8 @@ def package_list_from_file(file):
 from pytorch_lightning.utilities import (
     NATIVE_AMP_AVAILABLE,
     APEX_AVAILABLE,
+    XLA_AVAILABLE,
 )
-XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
 
 TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
 
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index f8d90945e9e77..c40600785a558 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -15,7 +15,7 @@
 import os
 
 import torch
-from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -24,13 +24,6 @@
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
 from pytorch_lightning.accelerators.accelerator import Accelerator
 
-try:
-    import torch_xla
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index da6e14a10cc03..12a65503dc1b4 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -43,6 +43,7 @@ def _module_available(module_path: str) -> bool:
 
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
+XLA_AVAILABLE = _module_available("torch_xla")
 
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index 14a59fd105c5a..dbce791c1771a 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -18,11 +18,10 @@
 
 import torch
 
-TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
-if TORCHXLA_AVAILABLE:
+from pytorch_lightning.utilities import XLA_AVAILABLE
+
+if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-else:
-    xm = None
 
 
 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -90,6 +89,6 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
+        if XLADeviceUtils.TPU_AVAILABLE is None and XLA_AVAILABLE:
             XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
         return XLADeviceUtils.TPU_AVAILABLE
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 10de63db049e7..1b3911d4152c0 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -16,15 +16,12 @@
 import pytest
 
 import pytorch_lightning.utilities.xla_device_utils as xla_utils
+from pytorch_lightning.utilities import XLA_AVAILABLE
 from tests.base.develop_utils import pl_multi_process_test
 
-try:
+if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    XLA_AVAILABLE = True
-except ImportError as e:
-    XLA_AVAILABLE = False
 
-
 
 @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():

From e569472c97492b6bcec0ecad79fc1b02b2fe18e3 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 26 Nov 2020 18:08:11 +0100
Subject: [PATCH 2/5] tpu

---
 pytorch_lightning/accelerators/tpu_accelerator.py |  6 +-----
 pytorch_lightning/callbacks/early_stopping.py     |  6 +-----
 pytorch_lightning/core/lightning.py               | 13 ++-----------
 pytorch_lightning/trainer/data_loading.py         |  6 +-----
 pytorch_lightning/utilities/__init__.py           |  2 ++
 tests/models/test_tpu.py                          |  5 +----
 6 files changed, 8 insertions(+), 30 deletions(-)

diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index cc7da4dc10781..6e8314106112e 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-import math
 import os
 import re
 from typing import Optional, Union, Any
@@ -24,12 +23,9 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, TPU_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 if TPU_AVAILABLE:
     import torch_xla
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 972d16fd705a8..005a3f8cde4ad 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -26,11 +26,7 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
+from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
 
 torch_inf = torch.tensor(np.Inf)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index bf9c30190cc8c..e888c6f4a87df 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -32,23 +32,14 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities import rank_zero_warn, AMPType
+from pytorch_lightning.utilities import rank_zero_warn, AMPType, TPU_AVAILABLE
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.parsing import (
-    AttributeDict,
-    collect_init_args,
-    get_init_args,
-)
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
 if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 270ee0d4db30b..9fe18d14a7f12 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -22,19 +22,15 @@
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from copy import deepcopy
 from typing import Iterable
 
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
 if TPU_AVAILABLE:
-    import torch_xla
     import torch_xla.core.xla_model as xm
 
 try:
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 12a65503dc1b4..e52700c76b81e 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
+from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 
 
 def _module_available(module_path: str) -> bool:
@@ -44,6 +45,7 @@ def _module_available(module_path: str) -> bool:
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 XLA_AVAILABLE = _module_available("torch_xla")
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 0077df283ee61..b69f1b60fcbf7 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -19,20 +19,17 @@
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer, seed_everything
-from pytorch_lightning.accelerators.accelerator import BackendType
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities import TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from tests.base import EvalModelTemplate
 from tests.base.datasets import TrialMNIST
 from tests.base.develop_utils import pl_multi_process_test
 
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 if TPU_AVAILABLE:
     import torch_xla
-    import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
     SERIAL_EXEC = xmp.MpSerialExecutor()

From 02e055b4eaa32b8087f2037ce76743a4417495a9 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 26 Nov 2020 18:36:32 +0100
Subject: [PATCH 3/5] fix

---
 pytorch_lightning/utilities/__init__.py         | 2 +-
 pytorch_lightning/utilities/xla_device_utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index e52700c76b81e..51e209958d815 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -21,7 +21,7 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
+from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
 
 
 def _module_available(module_path: str) -> bool:
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index dbce791c1771a..788bf26c890ea 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -18,7 +18,7 @@
 
 import torch
 
-from pytorch_lightning.utilities import XLA_AVAILABLE
+XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
 
 if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm

From 11c06e0497c229cf9d422a561e6c94e8febff7b5 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 26 Nov 2020 20:00:27 +0100
Subject: [PATCH 4/5] fix

---
 pytorch_lightning/utilities/xla_device_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index 788bf26c890ea..c6dd63237e121 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -65,7 +65,7 @@ def _fetch_xla_device_type(device: torch.device) -> str:
         Return:
             Returns a str of the device hardware type. i.e TPU
         """
-        if xm is not None:
+        if XLA_AVAILABLE:
             return xm.xla_device_hw(device)
 
     @staticmethod
@@ -76,7 +76,7 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if xm is not None:
+        if XLA_AVAILABLE:
             device = xm.xla_device()
             device_type = XLADeviceUtils._fetch_xla_device_type(device)
             return device_type == "TPU"

From d3fb0ec8dea4fda2101f1d7c42b2463c080af782 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 27 Nov 2020 00:04:17 +0100
Subject: [PATCH 5/5] flake8

---
 pytorch_lightning/utilities/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 51e209958d815..3b4dcfc7061ff 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -44,7 +44,7 @@ def _module_available(module_path: str) -> bool:
 
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
-XLA_AVAILABLE = _module_available("torch_xla")
+
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps