fix import and typo in AMP #4871

Merged (6 commits, Nov 26, 2020)
docs/source/amp.rst (4 changes: 2 additions & 2 deletions)
@@ -31,7 +31,7 @@ Native torch
When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.

.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVALAIBLE
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE

# turn on 16-bit
trainer = Trainer(precision=16)
@@ -73,7 +73,7 @@ Enable 16-bit
^^^^^^^^^^^^^

.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVALAIBLE
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE

# turn on 16-bit
trainer = Trainer(amp_level='O2', precision=16)
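
For orientation, here is a minimal, self-contained sketch of how the documented precision flags are used end to end. The model, data, and sizes are hypothetical, the run assumes a CUDA device, and the Apex variant additionally assumes NVIDIA Apex is installed; this is an illustration, not part of this PR.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer


    class TinyModel(LightningModule):
        """Hypothetical one-layer model used only to exercise the precision flags."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)

    # native AMP (PyTorch 1.6+): 16-bit on one GPU
    trainer = Trainer(precision=16, gpus=1, max_epochs=1)
    trainer.fit(TinyModel(), train_loader)

    # Apex variant mirroring the documented snippet (needs Apex installed):
    # trainer = Trainer(amp_level='O2', precision=16, gpus=1, max_epochs=1)
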
docs/source/conf.py (6 changes: 4 additions & 2 deletions)
@@ -357,8 +357,10 @@ def package_list_from_file(file):
import os
import torch

from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
APEX_AVAILABLE = importlib.util.find_spec("apex") is not None
from pytorch_lightning.utilities import (
NATIVE_AMP_AVAILABLE,
APEX_AVAILABLE,
)
XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None

docs/source/trainer.rst (2 changes: 1 addition & 1 deletion)
@@ -1177,7 +1177,7 @@ If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.

.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVALAIBLE
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE

# default used by the Trainer
trainer = Trainer(precision=32)
pytorch_lightning/core/hooks.py (9 changes: 1 addition & 8 deletions)
@@ -17,18 +17,11 @@
from typing import Any, Dict, List, Union

import torch
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
from torch import Tensor
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader


try:
from apex import amp
except ImportError:
amp = None


class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str):
pytorch_lightning/plugins/apex.py (6 changes: 2 additions & 4 deletions)
@@ -20,12 +20,10 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import AMPType, APEX_AVAILABLE

try:
if APEX_AVAILABLE:
from apex import amp
except ImportError:
amp = None
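
As a side note, the change above swaps a try/except import for a flag-guarded one. A minimal standalone sketch of that idiom, using a hypothetical package and function name, might look like this; note that the guarded name is left unbound when the flag is False, so call sites gate on the flag.

    from importlib.util import find_spec

    SOMEPKG_AVAILABLE = find_spec("somepkg") is not None  # hypothetical third-party package

    if SOMEPKG_AVAILABLE:
        import somepkg  # only bound when the package is importable


    def maybe_transform(x):
        # call sites check the flag, since `somepkg` is undefined otherwise
        if not SOMEPKG_AVAILABLE:
            return x
        return somepkg.transform(x)  # hypothetical function
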


class ApexPlugin(PrecisionPlugin):
@@ -25,16 +25,14 @@
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities import AMPType, rank_zero_warn, APEX_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
if APEX_AVAILABLE:
from apex import amp
except ImportError:
amp = None

try:
from omegaconf import Container
pytorch_lightning/trainer/connectors/precision_connector.py (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn


class PrecisionConnector:
@@ -48,7 +48,7 @@ def _setup_amp_backend(self, amp_type: str):
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
if not NATIVE_AMP_AVALAIBLE:
if not NATIVE_AMP_AVAILABLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
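
To make the fallback order concrete, here is an illustrative sketch (not the shipped method) of the selection the warning above describes: prefer native AMP when available, otherwise try Apex, otherwise stay in 32-bit.

    from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE


    def pick_amp_backend(requested: str = 'native') -> str:
        """Hypothetical helper mirroring the fallback logic described above."""
        requested = requested.lower()
        if requested == 'native' and NATIVE_AMP_AVAILABLE:
            return 'native'
        if APEX_AVAILABLE:
            return 'apex'
        return 'none'  # neither backend available: training stays in full precision
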
pytorch_lightning/trainer/data_loading.py (4 changes: 0 additions & 4 deletions)
@@ -32,10 +32,6 @@
from typing import Iterable

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
try:
from apex import amp
except ImportError:
amp = None

if TPU_AVAILABLE:
import torch_xla
pytorch_lightning/trainer/trainer.py (5 changes: 0 additions & 5 deletions)
@@ -69,11 +69,6 @@
'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
)

try:
from apex import amp
except ImportError:
amp = None


class Trainer(
TrainerProperties,
pytorch_lightning/trainer/training_tricks.py (5 changes: 0 additions & 5 deletions)
@@ -20,11 +20,6 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule

try:
from apex import amp
except ImportError:
amp = None

EPSILON = 1e-6
EPSILON_FP16 = 1e-5

pytorch_lightning/utilities/__init__.py (8 changes: 4 additions & 4 deletions)
@@ -26,23 +26,23 @@
def _module_available(module_path: str) -> bool:
"""Testing if given module is avalaible in your env

>>> _module_available('system')
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
mods = module_path.split('.')
assert mods, 'nothing given to test'
# it has to be tested as per partets
for i in range(1, len(mods)):
module_path = '.'.join(mods[:i])
for i in range(len(mods)):
module_path = '.'.join(mods[:i + 1])
if importlib.util.find_spec(module_path) is None:
return False
return True


APEX_AVAILABLE = _module_available("apex.amp")
NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
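
The key behavioural change above is that the loop now walks every dotted prefix including the full path, which is what lets NATIVE_AMP_AVAILABLE be computed via _module_available("torch.cuda.amp"). A standalone sketch of that prefix walk (an illustration, not the shipped helper):

    import importlib.util


    def module_available(module_path: str) -> bool:
        """Check 'a.b.c' as 'a', then 'a.b', then 'a.b.c'.

        Walking the prefixes avoids the ModuleNotFoundError that find_spec raises
        when a parent package (e.g. 'asdf' in 'asdf.bla') does not exist, and the
        `range(len(mods))` bound ensures the leaf module itself is tested -- the
        old `range(1, len(mods))` loop never checked the full dotted path (and
        checked nothing at all for a single-element name).
        """
        mods = module_path.split('.')
        for i in range(len(mods)):
            if importlib.util.find_spec('.'.join(mods[:i + 1])) is None:
                return False
        return True


    print(module_available("torch.cuda.amp"))  # True on a PyTorch 1.6+ install
    print(module_available("torch.nn.asdf"))   # False
    print(module_available("asdf.bla"))        # False, without raising
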
tests/backends/test_ddp.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@

from tests.backends import ddp_model
from tests.backends.launcher import DDPLauncher
from tests.utilities.dist import call_training_script
from tests.utilities.distributed import call_training_script


@pytest.mark.parametrize('cli_args', [
tests/models/test_horovod.py (4 changes: 2 additions & 2 deletions)
@@ -30,7 +30,7 @@
from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE
from tests.base import EvalModelTemplate
from tests.base.models import BasicGAN

@@ -157,7 +157,7 @@ def test_horovod_apex(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires torch.cuda.amp")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
def test_horovod_amp(tmpdir):
"""Test Horovod with multi-GPU support using native amp."""
trainer_options = dict(
tests/plugins/test_amp_plugin.py (19 changes: 5 additions & 14 deletions)
@@ -1,18 +1,15 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from distutils.version import LooseVersion
import torch


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6",
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@@ -46,10 +43,7 @@ def on_fit_start(self, trainer, pl_module):
trainer.fit(model)


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6",
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@@ -93,9 +87,7 @@ def on_after_backward(self):
assert norm.item() < 15.


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale(tmpdir):
model = GradientUnscaleBoringModel()
@@ -124,8 +116,7 @@ def on_after_backward(self):
assert norm.item() < 15.


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
model = UnscaleAccumulateGradBatchesBoringModel()
tests/trainer/test_trainer.py (4 changes: 2 additions & 2 deletions)
@@ -36,7 +36,7 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base import EvalModelTemplate, BoringModel


@@ -988,7 +988,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_gradient_clipping_fp16(tmpdir):
"""
Test gradient clipping with fp16
tests/trainer/test_trainer_tricks.py (4 changes: 2 additions & 2 deletions)
@@ -19,7 +19,7 @@

import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule
@@ -328,7 +328,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_auto_scale_batch_size_with_amp(tmpdir):
model = EvalModelTemplate()
batch_size_before = model.batch_size
File renamed without changes.
tests/utilities/test_imports.py (24 changes: 24 additions & 0 deletions)
@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities import _module_available


def test_module_exists():
"""Test if the some 3rd party libs are available"""
assert _module_available("torch")
assert _module_available("torch.nn.parallel")
assert not _module_available("torch.nn.asdf")
assert not _module_available("asdf")
assert not _module_available("asdf.bla.asdf")
File renamed without changes.