diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 3f8065f7a3..c7a8b49e30 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import PT_BEFORE_1_7, min_version, optional_import +from monai.utils import min_version, optional_import, pytorch_after from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -190,7 +190,7 @@ def _compute_pred_loss(): self.network.train() # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -359,7 +359,7 @@ def _iteration( d_total_loss = torch.zeros(1) for _ in range(self.d_train_steps): # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -377,7 +377,7 @@ def _iteration( non_blocking=engine.non_blocking, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 24fe91687b..4f365d169e 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -21,18 +21,17 @@ from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( - PT_BEFORE_1_7, ChannelMatching, InvalidPyTorchVersionError, SkipMode, look_up_option, optional_import, - version_leq, + pytorch_after, ) from monai.utils.misc import issequenceiterable _C, _ = optional_import("monai._C") -if not PT_BEFORE_1_7: +if pytorch_after(1, 7): fft, _ = optional_import("torch.fft") __all__ = [ @@ -295,11 +294,12 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso x = x.view(1, kernel.shape[0], *spatials) conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1] if "padding" not in kwargs: - if version_leq(torch.__version__, "1.10.0b"): + if pytorch_after(1, 10): + kwargs["padding"] = "same" + else: # even-sized kernels are not supported kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] - else: - kwargs["padding"] = "same" + if "stride" not in kwargs: kwargs["stride"] = 1 output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs) @@ -387,7 +387,7 @@ class HilbertTransform(nn.Module): def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) super().__init__() diff --git a/monai/networks/utils.py b/monai/networks/utils.py index c60867d8e6..ef0cff0eed 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -22,7 +22,7 @@ from monai.utils.deprecate_utils import deprecated_arg from monai.utils.misc import ensure_tuple, set_determinism -from monai.utils.module import PT_BEFORE_1_7 +from monai.utils.module import pytorch_after __all__ = [ "one_hot", @@ -464,7 +464,7 @@ def convert_to_torchscript( with torch.no_grad(): script_module = torch.jit.script(model) if filename_or_obj is not None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): torch.jit.save(m=script_module, f=filename_or_obj) else: torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 48ebd1788d..cffbc27d19 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -30,7 +30,6 @@ from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, convert_data_type, convert_to_dst_type, @@ -38,6 +37,7 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends @@ -1072,7 +1072,7 @@ class DetectEnvelope(Transform): def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) if axis < 0: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 3fa3b2d5f6..d58fcca32d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -58,7 +58,6 @@ zip_with, ) from .module import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, @@ -71,6 +70,7 @@ look_up_option, min_version, optional_import, + pytorch_after, require_pkg, version_leq, ) diff --git a/monai/utils/module.py b/monai/utils/module.py index 12ffb27b82..e475ae1957 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -8,7 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum +import os +import re import sys import warnings from functools import wraps @@ -36,8 +39,8 @@ "get_full_type_name", "get_package_version", "get_torch_version_tuple", - "PT_BEFORE_1_7", "version_leq", + "pytorch_after", ] @@ -450,7 +453,51 @@ def _try_cast(val: str): return True -try: - PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") -except (AttributeError, TypeError): - PT_BEFORE_1_7 = True +def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: + """ + Compute whether the current pytorch version is after or equal to the specified version. + The current system pytorch version is determined by `torch.__version__` or + via system environment variable `PYTORCH_VER`. + + Args: + major: major version number to be compared with + minor: minor version number to be compared with + patch: patch version number to be compared with + current_ver_string: if None, `torch.__version__` will be used. + + Returns: + True if the current pytorch version is greater than or equal to the specified version. + """ + + try: + if current_ver_string is None: + _env_var = os.environ.get("PYTORCH_VER", "") + current_ver_string = _env_var if _env_var else torch.__version__ + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) + while len(parts) < 3: + parts += ["0"] + c_major, c_minor, c_patch = parts[:3] + except (AttributeError, ValueError, TypeError): + c_major, c_minor = get_torch_version_tuple() + c_patch = "0" + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + if c_mn != mn: + return c_mn > mn + is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower()) + c_p = 0 + try: + p_reg = re.search(r"\d+", f"{c_patch}") + if p_reg: + c_p = int(p_reg.group()) + except (AttributeError, TypeError, ValueError): + is_prerelease = True + patch = int(patch) + if c_p != patch: + return c_p > patch # type: ignore + if is_prerelease: + return False + return True diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index 2de549ad23..009dbd4281 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import MapLabelValue -from monai.utils import PT_BEFORE_1_7 +from monai.utils import pytorch_after from tests.utils import TEST_NDARRAYS TESTS = [] @@ -34,7 +34,7 @@ ] ) # PyTorch 1.5.1 doesn't support rich dtypes - if not PT_BEFORE_1_7: + if pytorch_after(1, 7): TESTS.append( [ {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py new file mode 100644 index 0000000000..6b5ca6fdb0 --- /dev/null +++ b/tests/test_pytorch_version_after.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from parameterized import parameterized + +from monai.utils import pytorch_after + +TEST_CASES = ( + (1, 5, 9, "1.6.0"), + (1, 6, 0, "1.6.0"), + (1, 6, 1, "1.6.0", False), + (1, 7, 0, "1.6.0", False), + (2, 6, 0, "1.6.0", False), + (0, 6, 0, "1.6.0a0+3fd9dcf"), + (1, 5, 9, "1.6.0a0+3fd9dcf"), + (1, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 1, "1.6.0a0+3fd9dcf", False), + (2, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease + (1, 6, 0, "1.6.0rc0", False), + (1, 6, 0, "1.6", True), + (1, 6, 0, "1", False), + (1, 6, 0, "1.6.0+cpu", True), + (1, 6, 1, "1.6.0+cpu", False), +) + + +class TestPytorchVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, p, current, expected=True): + """Test pytorch_after with a and b""" + self.assertEqual(pytorch_after(a, b, p, current), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 97e57c6be6..7ed4ba1745 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,8 +36,7 @@ from monai.data import create_test_image_2d, create_test_image_3d from monai.networks import convert_to_torchscript from monai.utils import optional_import -from monai.utils.misc import is_module_ver_at_least -from monai.utils.module import version_leq +from monai.utils.module import pytorch_after, version_leq from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") @@ -193,7 +192,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) + self.version_too_old = not pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -207,8 +206,7 @@ class SkipIfAtLeastPyTorchVersion: def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.max_version)) - self.version_too_new = version_leq(test_ver, torch.__version__) + self.version_too_new = pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf(