From 17c19bd0e5ffba49eab942214065afc9a08aad76 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 7 Nov 2021 13:28:34 +0000 Subject: [PATCH 1/8] torch version check Signed-off-by: Wenqi Li --- monai/engines/trainer.py | 8 ++--- monai/networks/layers/simplelayers.py | 9 +++--- monai/networks/utils.py | 4 +-- monai/transforms/intensity/array.py | 4 +-- monai/utils/__init__.py | 2 +- monai/utils/module.py | 41 +++++++++++++++++++++--- tests/test_map_label_value.py | 4 +-- tests/test_pytorch_version_after.py | 45 +++++++++++++++++++++++++++ tests/utils.py | 7 ++--- 9 files changed, 99 insertions(+), 25 deletions(-) create mode 100644 tests/test_pytorch_version_after.py diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 3f8065f7a3..c7a8b49e30 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import PT_BEFORE_1_7, min_version, optional_import +from monai.utils import min_version, optional_import, pytorch_after from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -190,7 +190,7 @@ def _compute_pred_loss(): self.network.train() # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -359,7 +359,7 @@ def _iteration( d_total_loss = torch.zeros(1) for _ in range(self.d_train_steps): # `set_to_none` only work from PyTorch 1.7.0 - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -377,7 +377,7 @@ def _iteration( non_blocking=engine.non_blocking, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 24fe91687b..9d8d82a1c8 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -21,18 +21,17 @@ from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( - PT_BEFORE_1_7, ChannelMatching, InvalidPyTorchVersionError, SkipMode, look_up_option, optional_import, - version_leq, + pytorch_after, ) from monai.utils.misc import issequenceiterable _C, _ = optional_import("monai._C") -if not PT_BEFORE_1_7: +if pytorch_after(1, 7): fft, _ = optional_import("torch.fft") __all__ = [ @@ -295,7 +294,7 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso x = x.view(1, kernel.shape[0], *spatials) conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1] if "padding" not in kwargs: - if version_leq(torch.__version__, "1.10.0b"): + if pytorch_after(1, 10): # even-sized kernels are not supported kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] else: @@ -387,7 +386,7 @@ class HilbertTransform(nn.Module): def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) super().__init__() diff --git a/monai/networks/utils.py b/monai/networks/utils.py index c60867d8e6..ef0cff0eed 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -22,7 +22,7 @@ from monai.utils.deprecate_utils import deprecated_arg from monai.utils.misc import ensure_tuple, set_determinism -from monai.utils.module import PT_BEFORE_1_7 +from monai.utils.module import pytorch_after __all__ = [ "one_hot", @@ -464,7 +464,7 @@ def convert_to_torchscript( with torch.no_grad(): script_module = torch.jit.script(model) if filename_or_obj is not None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): torch.jit.save(m=script_module, f=filename_or_obj) else: torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f24b184ec8..a73cbada73 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -30,7 +30,6 @@ from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where from monai.utils import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, convert_data_type, convert_to_dst_type, @@ -38,6 +37,7 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends @@ -1056,7 +1056,7 @@ class DetectEnvelope(Transform): def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: - if PT_BEFORE_1_7: + if not pytorch_after(1, 7): raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) if axis < 0: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 42eba2e67f..46af5e550d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -57,7 +57,6 @@ zip_with, ) from .module import ( - PT_BEFORE_1_7, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, @@ -70,6 +69,7 @@ look_up_option, min_version, optional_import, + pytorch_after, require_pkg, version_leq, ) diff --git a/monai/utils/module.py b/monai/utils/module.py index 12ffb27b82..c066d628cf 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -36,8 +36,8 @@ "get_full_type_name", "get_package_version", "get_torch_version_tuple", - "PT_BEFORE_1_7", "version_leq", + "pytorch_after", ] @@ -450,7 +450,38 @@ def _try_cast(val: str): return True -try: - PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") -except (AttributeError, TypeError): - PT_BEFORE_1_7 = True +def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: + """ + Compute whether the current pytorch version is after or equal to the specified version. + + Args: + major: major version number to be compared with + minor: minor version number to be compared with + patch: patch version number to be compared with + current_ver_string: if None, `torch.__version__` will be used. + + Returns: + True if the current pytorch version is greater than or equal to the specified version. + """ + try: + if current_ver_string is None: + current_ver_string = torch.__version__ + c_major, c_minor, c_patch = current_ver_string.split("+", 1)[0].split(".", 3) + except (AttributeError, ValueError, TypeError): + c_major, c_minor = get_torch_version_tuple() + c_patch = 0 + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + if c_mn != mn: + return c_mn > mn + is_prerelease = "a" in c_patch + try: + c_patch = int(c_patch) if not is_prerelease else int(c_patch.split("a", 1)[0]) + except (AttributeError, TypeError, ValueError): + c_patch = 0 + is_prerelease = True + if c_patch != patch: + return c_patch > patch + if is_prerelease: + return False + return True diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index 2de549ad23..009dbd4281 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import MapLabelValue -from monai.utils import PT_BEFORE_1_7 +from monai.utils import pytorch_after from tests.utils import TEST_NDARRAYS TESTS = [] @@ -34,7 +34,7 @@ ] ) # PyTorch 1.5.1 doesn't support rich dtypes - if not PT_BEFORE_1_7: + if pytorch_after(1, 7): TESTS.append( [ {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py new file mode 100644 index 0000000000..a596f2c97e --- /dev/null +++ b/tests/test_pytorch_version_after.py @@ -0,0 +1,45 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from parameterized import parameterized + +from monai.utils import pytorch_after + +TEST_CASES = ( + (1, 5, 9, "1.6.0"), + (1, 6, 0, "1.6.0"), + (1, 6, 1, "1.6.0", False), + (1, 7, 0, "1.6.0", False), + (2, 6, 0, "1.6.0", False), + (0, 6, 0, "1.6.0a0+3fd9dcf"), + (1, 5, 9, "1.6.0a0+3fd9dcf"), + (1, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 1, "1.6.0a0+3fd9dcf", False), + (2, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease + (1, 6, 0, "1.6.0rc0", False), + (1, 6, 0, "1.6", True), + (1, 6, 0, "1", True), +) + + +class TestPytorchVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, p, current, expected=True): + """Test version_leq with `a` and `b`""" + self.assertEqual(pytorch_after(a, b, p, current), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 97e57c6be6..3e0da22559 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -37,7 +37,7 @@ from monai.networks import convert_to_torchscript from monai.utils import optional_import from monai.utils.misc import is_module_ver_at_least -from monai.utils.module import version_leq +from monai.utils.module import pytorch_after, version_leq from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") @@ -193,7 +193,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) + self.version_too_old = not pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -207,8 +207,7 @@ class SkipIfAtLeastPyTorchVersion: def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.max_version)) - self.version_too_new = version_leq(test_ver, torch.__version__) + self.version_too_new = pytorch_after(*pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( From 3171f6b7263dde5d8dfe21a34eaf4426eb7b5f58 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 7 Nov 2021 13:29:40 +0000 Subject: [PATCH 2/8] temp tests Signed-off-by: Wenqi Li --- .github/workflows/pythonapp-gpu.yml | 1 + .github/workflows/pythonapp-min.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 2ca0492e5c..1a6840f7d8 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -6,6 +6,7 @@ on: branches: - main - releasing/* + - torch-version-check pull_request: concurrency: diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 002701c5ad..0cfd97e82c 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -7,6 +7,7 @@ on: - dev - main - releasing/* + - torch-version-check pull_request: concurrency: From 001badefaa6b58a1d7df8234532e125e144deaa9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 7 Nov 2021 13:31:06 +0000 Subject: [PATCH 3/8] additional cases Signed-off-by: Wenqi Li --- tests/test_pytorch_version_after.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py index a596f2c97e..fb396ec620 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_pytorch_version_after.py @@ -31,6 +31,8 @@ (1, 6, 0, "1.6.0rc0", False), (1, 6, 0, "1.6", True), (1, 6, 0, "1", True), + (1, 6, 0, "1.6.0+cpu", True), + (1, 6, 1, "1.6.0+cpu", False), ) From 67842d86e8069199b96cd52ce49903168943c6e7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 7 Nov 2021 14:14:38 +0000 Subject: [PATCH 4/8] fixes tests Signed-off-by: Wenqi Li --- monai/networks/layers/simplelayers.py | 5 +++-- monai/utils/module.py | 28 +++++++++++++++++++-------- tests/test_pytorch_version_after.py | 2 +- tests/utils.py | 1 - 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 9d8d82a1c8..4f365d169e 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -295,10 +295,11 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1] if "padding" not in kwargs: if pytorch_after(1, 10): + kwargs["padding"] = "same" + else: # even-sized kernels are not supported kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] - else: - kwargs["padding"] = "same" + if "stride" not in kwargs: kwargs["stride"] = 1 output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs) diff --git a/monai/utils/module.py b/monai/utils/module.py index c066d628cf..24d1fb2d96 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -8,7 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum +import os +import re import sys import warnings from functools import wraps @@ -453,6 +456,8 @@ def _try_cast(val: str): def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: """ Compute whether the current pytorch version is after or equal to the specified version. + The current system pytorch version is determined by `torch.__version__` or + via system environment variable `PYTORCH_VER`. Args: major: major version number to be compared with @@ -465,23 +470,30 @@ def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: """ try: if current_ver_string is None: - current_ver_string = torch.__version__ - c_major, c_minor, c_patch = current_ver_string.split("+", 1)[0].split(".", 3) + _env_var = os.environ.get("PYTORCH_VER", "") + current_ver_string = _env_var if _env_var else torch.__version__ + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) + while len(parts) < 3: + parts += ["0"] + c_major, c_minor, c_patch = parts[:3] except (AttributeError, ValueError, TypeError): c_major, c_minor = get_torch_version_tuple() - c_patch = 0 + c_patch = "0" c_mn = int(c_major), int(c_minor) mn = int(major), int(minor) if c_mn != mn: return c_mn > mn - is_prerelease = "a" in c_patch + is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower()) + c_p = 0 try: - c_patch = int(c_patch) if not is_prerelease else int(c_patch.split("a", 1)[0]) + p_reg = re.search(r"\d+", f"{c_patch}") + if p_reg: + c_p = int(p_reg.group()) except (AttributeError, TypeError, ValueError): - c_patch = 0 is_prerelease = True - if c_patch != patch: - return c_patch > patch + patch = int(patch) + if c_p != patch: + return c_p > patch # type: ignore if is_prerelease: return False return True diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py index fb396ec620..17dc870602 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_pytorch_version_after.py @@ -30,7 +30,7 @@ (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease (1, 6, 0, "1.6.0rc0", False), (1, 6, 0, "1.6", True), - (1, 6, 0, "1", True), + (1, 6, 0, "1", False), (1, 6, 0, "1.6.0+cpu", True), (1, 6, 1, "1.6.0+cpu", False), ) diff --git a/tests/utils.py b/tests/utils.py index 3e0da22559..7ed4ba1745 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,7 +36,6 @@ from monai.data import create_test_image_2d, create_test_image_3d from monai.networks import convert_to_torchscript from monai.utils import optional_import -from monai.utils.misc import is_module_ver_at_least from monai.utils.module import pytorch_after, version_leq from monai.utils.type_conversion import convert_data_type From 5cc18ec5b1b642d285510d582fb339be638a2b47 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Nov 2021 11:25:32 +0000 Subject: [PATCH 5/8] update unit test names Signed-off-by: Wenqi Li --- tests/test_pytorch_version_after.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py index 17dc870602..6b5ca6fdb0 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_pytorch_version_after.py @@ -39,7 +39,7 @@ class TestPytorchVersionCompare(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_compare(self, a, b, p, current, expected=True): - """Test version_leq with `a` and `b`""" + """Test pytorch_after with a and b""" self.assertEqual(pytorch_after(a, b, p, current), expected) From 5fa7715e8c6c316fb5baf40ad535339d52c1c407 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Nov 2021 14:34:11 +0000 Subject: [PATCH 6/8] remove temp tests Signed-off-by: Wenqi Li --- .github/workflows/pythonapp-gpu.yml | 1 - .github/workflows/pythonapp-min.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 1a6840f7d8..2ca0492e5c 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -6,7 +6,6 @@ on: branches: - main - releasing/* - - torch-version-check pull_request: concurrency: diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 0cfd97e82c..002701c5ad 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -7,7 +7,6 @@ on: - dev - main - releasing/* - - torch-version-check pull_request: concurrency: From 9b191df30b3eaa4c11c9ae4bfd6d0c4a01811f63 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 10 Nov 2021 16:59:56 +0000 Subject: [PATCH 7/8] update based on comments Signed-off-by: Wenqi Li --- monai/utils/module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/utils/module.py b/monai/utils/module.py index 24d1fb2d96..89e2fb778a 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -468,10 +468,14 @@ def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: Returns: True if the current pytorch version is greater than or equal to the specified version. """ + try: if current_ver_string is None: _env_var = os.environ.get("PYTORCH_VER", "") current_ver_string = _env_var if _env_var else torch.__version__ + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) while len(parts) < 3: parts += ["0"] From 0d053e3fcf7521214651f8b06521980d1d58729f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 10 Nov 2021 20:10:21 +0000 Subject: [PATCH 8/8] fixes codeformat Signed-off-by: Wenqi Li --- monai/utils/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 89e2fb778a..e475ae1957 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -475,7 +475,7 @@ def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool: current_ver_string = _env_var if _env_var else torch.__version__ ver, has_ver = optional_import("pkg_resources", name="parse_version") if has_ver: - return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") + return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) while len(parts) < 3: parts += ["0"]