docs/source/conf.py (1 addition, 1 deletion)
@@ -360,8 +360,8 @@ def package_list_from_file(file):
from pytorch_lightning.utilities import (
NATIVE_AMP_AVAILABLE,
APEX_AVAILABLE,
XLA_AVAILABLE,
)
XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None


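For context, the detection idiom used in conf.py is a find_spec probe: importlib only looks the package up on sys.path instead of importing it, so the check is cheap and side-effect free. A minimal sketch contrasting it with the try/except import idiom that the rest of this PR removes (illustrative only, not taken from the repository):

```python
import importlib.util

# Spec lookup: cheap, nothing is actually imported.
XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None

# Older try/except idiom used elsewhere before this PR: imports the package as a side effect.
try:
    import torch_xla  # noqa: F401
    XLA_AVAILABLE_VIA_IMPORT = True
except ImportError:
    XLA_AVAILABLE_VIA_IMPORT = False
```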
pytorch_lightning/accelerators/accelerator_connector.py (1 addition, 8 deletions)
@@ -15,7 +15,7 @@
import os
import torch

from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -24,13 +24,6 @@
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.accelerators.accelerator import Accelerator

try:
import torch_xla
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
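The net effect in this file: the locally maintained try/except flag for torch_xla goes away and the connector reads the shared XLA_AVAILABLE constant from pytorch_lightning.utilities instead. A simplified sketch of the consumer side after the change (sanity_check_tpu_request is a hypothetical helper, not the connector's real method):

```python
from pytorch_lightning.utilities import XLA_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def sanity_check_tpu_request(tpu_cores):
    # Hypothetical helper: fail fast if TPU cores were requested without torch_xla installed.
    if tpu_cores is not None and not XLA_AVAILABLE:
        raise MisconfigurationException("TPU cores were requested, but torch_xla is not installed.")
```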
pytorch_lightning/accelerators/tpu_accelerator.py (1 addition, 5 deletions)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import math
import os
import re
from typing import Optional, Union, Any
@@ -24,12 +23,9 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
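Here the flag guards a module-level import: torch_xla is only imported when a TPU was actually detected, so the accelerator module stays importable on machines without XLA. A reduced sketch of the pattern (current_tpu_ordinal is a hypothetical helper):

```python
from pytorch_lightning.utilities import TPU_AVAILABLE

if TPU_AVAILABLE:
    # Only import torch_xla when a TPU was actually found.
    import torch_xla.core.xla_model as xm


def current_tpu_ordinal():
    # Hypothetical helper: touch torch_xla APIs only behind the availability flag.
    return xm.get_ordinal() if TPU_AVAILABLE else None
```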
pytorch_lightning/callbacks/early_stopping.py (1 addition, 5 deletions)
@@ -26,11 +26,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE

torch_inf = torch.tensor(np.Inf)

pytorch_lightning/core/lightning.py (2 additions, 11 deletions)
@@ -32,23 +32,14 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities import rank_zero_warn, AMPType, TPU_AVAILABLE
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer


TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

pytorch_lightning/trainer/data_loading.py (1 addition, 5 deletions)
@@ -22,19 +22,15 @@

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from copy import deepcopy
from typing import Iterable

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm

try:
pytorch_lightning/utilities/__init__.py (3 additions, 0 deletions)
@@ -21,6 +21,7 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils


def _module_available(module_path: str) -> bool:
@@ -44,6 +45,8 @@ def _module_available(module_path: str) -> bool:
APEX_AVAILABLE = _module_available("apex.amp")
NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
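These added lines make pytorch_lightning.utilities the single source of truth for the two flags: XLA_AVAILABLE answers "is torch_xla importable?", while TPU_AVAILABLE answers "does a TPU actually respond?", and every other file in this PR deletes its local copy in favour of this one. A short usage sketch of the distinction (illustrative):

```python
from pytorch_lightning.utilities import TPU_AVAILABLE, XLA_AVAILABLE

if XLA_AVAILABLE:
    # torch_xla is installed, but there may still be no TPU attached to this machine.
    import torch_xla.core.xla_model as xm

if TPU_AVAILABLE:
    # XLADeviceUtils.tpu_device_exists() confirmed a TPU, so xm can be used safely.
    device = xm.xla_device()
```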
pytorch_lightning/utilities/xla_device_utils.py (6 additions, 7 deletions)
@@ -18,11 +18,10 @@

import torch

TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
if TORCHXLA_AVAILABLE:
XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None

if XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
else:
xm = None


def inner_f(queue, func, *args, **kwargs): # pragma: no cover
@@ -66,7 +65,7 @@ def _fetch_xla_device_type(device: torch.device) -> str:
Return:
Returns a str of the device hardware type. i.e TPU
"""
if xm is not None:
if XLA_AVAILABLE:
return xm.xla_device_hw(device)

@staticmethod
@@ -77,7 +76,7 @@ def _is_device_tpu() -> bool:
Return:
A boolean value indicating if the xla device is a TPU device or not
"""
if xm is not None:
if XLA_AVAILABLE:
device = xm.xla_device()
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"
@@ -90,6 +89,6 @@ def tpu_device_exists() -> bool:
Return:
A boolean value indicating if a TPU device exists on the system
"""
if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
if XLADeviceUtils.TPU_AVAILABLE is None and XLA_AVAILABLE:
XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
return XLADeviceUtils.TPU_AVAILABLE
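tpu_device_exists() memoizes its answer in the class attribute TPU_AVAILABLE and only probes when torch_xla is importable; the probe itself is wrapped in pl_multi_process because creating an XLA device on a machine without a TPU can block. A condensed sketch of that caching logic, with the multiprocess wrapper elided (XLADeviceUtilsSketch is illustrative, not the real class):

```python
import importlib.util

XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None


class XLADeviceUtilsSketch:
    _tpu_available = None  # cached result of the potentially slow probe

    @classmethod
    def tpu_device_exists(cls):
        if cls._tpu_available is None and XLA_AVAILABLE:
            import torch_xla.core.xla_model as xm
            # The real implementation runs this probe through pl_multi_process.
            cls._tpu_available = xm.xla_device_hw(xm.xla_device()) == "TPU"
        return cls._tpu_available
```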
tests/models/test_tpu.py (1 addition, 4 deletions)
@@ -19,20 +19,17 @@

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators.accelerator import BackendType
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()

tests/utilities/test_xla_device_utils.py (2 additions, 5 deletions)
@@ -16,15 +16,12 @@
import pytest

import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import XLA_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test

try:
if XLA_AVAILABLE:
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
except ImportError as e:
XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
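Both test modules now gate on the shared flags with pytest.mark.skipif, in opposite directions: TPU tests skip when no TPU is available, while the absence test above skips when torch_xla is installed. A minimal sketch of the two guards (test names and bodies are placeholders):

```python
import pytest

from pytorch_lightning.utilities import TPU_AVAILABLE, XLA_AVAILABLE


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires a TPU machine")
def test_something_on_tpu():
    # Placeholder body: real TPU tests build a Trainer with tpu_cores set.
    assert TPU_AVAILABLE


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_behaviour_without_torch_xla():
    # Placeholder body: without torch_xla the TPU probe never runs, so the flag stays falsy.
    assert not TPU_AVAILABLE
```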