diff --git a/docs/source/conf.py b/docs/source/conf.py
index b05234fd16628..4a22f012c99cf 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -360,8 +360,8 @@ def package_list_from_file(file):
 from pytorch_lightning.utilities import (
     NATIVE_AMP_AVAILABLE,
     APEX_AVAILABLE,
+    XLA_AVAILABLE,
 )
-XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
 TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index f8d90945e9e77..c40600785a558 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -15,7 +15,7 @@ import os
 
 import torch
 
-from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -24,13 +24,6 @@
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
 from pytorch_lightning.accelerators.accelerator import Accelerator
 
-try:
-    import torch_xla
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index cc7da4dc10781..6e8314106112e 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-import math
 import os
 import re
 from typing import Optional, Union, Any
@@ -24,12 +23,9 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, TPU_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 
 if TPU_AVAILABLE:
     import torch_xla
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 972d16fd705a8..005a3f8cde4ad 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -26,11 +26,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
+from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
 
 
 torch_inf = torch.tensor(np.Inf)
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index bf9c30190cc8c..e888c6f4a87df 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -32,23 +32,14 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities import rank_zero_warn, AMPType
+from pytorch_lightning.utilities import rank_zero_warn, AMPType, TPU_AVAILABLE
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.parsing import (
-    AttributeDict,
-    collect_init_args,
-    get_init_args,
-)
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
-
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
 if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 270ee0d4db30b..9fe18d14a7f12 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -22,19 +22,15 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from copy import deepcopy
 from typing import Iterable
 
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-
 if TPU_AVAILABLE:
-    import torch_xla
     import torch_xla.core.xla_model as xm
 
 try:
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index da6e14a10cc03..3b4dcfc7061ff 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
+from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
 
 
 def _module_available(module_path: str) -> bool:
@@ -44,6 +45,8 @@ def _module_available(module_path: str) -> bool:
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 
+TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index 14a59fd105c5a..c6dd63237e121 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -18,11 +18,10 @@
 
 import torch
 
-TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
-if TORCHXLA_AVAILABLE:
+XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
+
+if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-else:
-    xm = None
 
 
 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -66,7 +65,7 @@ def _fetch_xla_device_type(device: torch.device) -> str:
         Return:
             Returns a str of the device hardware type. i.e TPU
         """
-        if xm is not None:
+        if XLA_AVAILABLE:
             return xm.xla_device_hw(device)
 
     @staticmethod
@@ -77,7 +76,7 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if xm is not None:
+        if XLA_AVAILABLE:
             device = xm.xla_device()
             device_type = XLADeviceUtils._fetch_xla_device_type(device)
             return device_type == "TPU"
@@ -90,6 +89,6 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
+        if XLADeviceUtils.TPU_AVAILABLE is None and XLA_AVAILABLE:
             XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
         return XLADeviceUtils.TPU_AVAILABLE
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 0077df283ee61..b69f1b60fcbf7 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -19,20 +19,17 @@
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer, seed_everything
-from pytorch_lightning.accelerators.accelerator import BackendType
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities import TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from tests.base import EvalModelTemplate
 from tests.base.datasets import TrialMNIST
 from tests.base.develop_utils import pl_multi_process_test
 
-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 if TPU_AVAILABLE:
     import torch_xla
-    import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
     SERIAL_EXEC = xmp.MpSerialExecutor()
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 10de63db049e7..1b3911d4152c0 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -16,15 +16,12 @@
 import pytest
 
 import pytorch_lightning.utilities.xla_device_utils as xla_utils
+from pytorch_lightning.utilities import XLA_AVAILABLE
 from tests.base.develop_utils import pl_multi_process_test
 
-try:
+if XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-except ImportError as e:
-    XLA_AVAILABLE = False
-
 
 @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
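
Example (not part of the patch): a minimal sketch of the consumption pattern this diff standardizes. XLA_AVAILABLE and TPU_AVAILABLE are computed once in pytorch_lightning/utilities/__init__.py and imported everywhere else, instead of each module re-deriving its own flag and wrapping torch_xla imports in try/except. The helper name describe_accelerator is illustrative only, and the imports assume a pytorch_lightning build that already includes this change.

# Hypothetical downstream module; mirrors the import guards used throughout the patch.
from pytorch_lightning.utilities import TPU_AVAILABLE, XLA_AVAILABLE

# torch_xla is only imported when a TPU device was actually detected.
if TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm


def describe_accelerator() -> str:
    """Return a short summary of the XLA/TPU state of this environment."""
    if not XLA_AVAILABLE:
        return "torch_xla is not installed"
    if not TPU_AVAILABLE:
        return "torch_xla is installed, but no TPU device was detected"
    # xm is bound here because TPU_AVAILABLE being True triggered the import above.
    return f"running on TPU device {xm.xla_device()}"


if __name__ == "__main__":
    print(describe_accelerator())

The flags stay boolean module-level constants, so the TPU probe behind XLADeviceUtils.tpu_device_exists() runs when pytorch_lightning.utilities is first imported and every consumer simply reads the cached result.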