From a274e48b7b5d402a21a2e87a5965425f1d503764 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Jan 2022 00:12:51 +0800 Subject: [PATCH 1/5] [DLMED] enhance slice_channels Signed-off-by: Nic Ma --- monai/networks/utils.py | 12 ++++++-- tests/test_slice_channels.py | 55 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 tests/test_slice_channels.py diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f7fd2e2956..4db5957b2a 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -93,9 +93,17 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f return labels -def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: +def slice_channels(tensor: torch.Tensor, dim: int = 1, *slicevals: int) -> torch.Tensor: + """Slice several channels of input Tensor on specified `dim`. + + Args: + tensor: input Tensor data to slice. + dim: expected dimension index to slice channels, default to `1`. + slicevals: channel indices to slice. + + """ slices = [slice(None)] * len(tensor.shape) - slices[1] = slice(*slicevals) + slices[dim] = slicevals # type: ignore return tensor[slices] diff --git a/tests/test_slice_channels.py b/tests/test_slice_channels.py new file mode 100644 index 0000000000..f0999ba24c --- /dev/null +++ b/tests/test_slice_channels.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import one_hot, slice_channels + +TEST_CASE_1 = [ # single channel 2D, batch 2, shape (2, 1, 2, 2) + {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, + 0, + (1,), + (1, 3, 2, 2), +] + +TEST_CASE_2 = [ # single channel 1D, batch 2, shape (2, 1, 4) + {"labels": torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), "num_classes": 3}, + 1, + (1, 2), + (2, 2, 4), + np.array([[[1, 0, 0, 0], [0, 1, 1, 0]], [[0, 1, 0, 1], [1, 0, 0, 0]]]), +] + +TEST_CASE_3 = [ # single channel 2D, batch 2, shape (2, 1, 2, 2) + {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, + 0, + (), # select no channels + (0, 3, 2, 2), +] + + +class TestSliceChannels(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, dim, vals, expected_shape, expected_result=None): + result = one_hot(**input_data) + result = slice_channels(result, dim, *vals) + + self.assertEqual(result.shape, expected_shape) + if expected_result is not None: + self.assertTrue(np.allclose(expected_result, result.numpy())) + + +if __name__ == "__main__": + unittest.main() From 43dbc2a11fb2c2681371bd6c6dab2e8c506b3ed7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Jan 2022 07:17:05 +0800 Subject: [PATCH 2/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/networks/utils.py | 16 +++++++--------- monai/utils/__init__.py | 1 + monai/utils/misc.py | 17 +++++++++++++++++ ..._slice_channels.py => test_sample_slices.py} | 17 ++++++++++------- 4 files changed, 35 insertions(+), 16 deletions(-) rename tests/{test_slice_channels.py => test_sample_slices.py} (76%) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 4db5957b2a..5a2bb13fec 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.deprecate_utils import deprecated, deprecated_arg from monai.utils.misc import ensure_tuple, set_determinism from monai.utils.module import pytorch_after @@ -93,17 +93,15 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f return labels -def slice_channels(tensor: torch.Tensor, dim: int = 1, *slicevals: int) -> torch.Tensor: - """Slice several channels of input Tensor on specified `dim`. - - Args: - tensor: input Tensor data to slice. - dim: expected dimension index to slice channels, default to `1`. - slicevals: channel indices to slice. +@deprecated(since="0.8.0", msg_suffix="use `monai.utils.misc.sample_slices` instead.") +def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: + """ + .. deprecated:: 0.8.0 + Use `monai.utils.misc.sample_slices` instead. """ slices = [slice(None)] * len(tensor.shape) - slices[dim] = slicevals # type: ignore + slices[1] = slice(*slicevals) return tensor[slices] diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 0c04680234..59b779e3b9 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -54,6 +54,7 @@ issequenceiterable, list_to_dict, progress_bar, + sample_slices, set_determinism, star_zip_with, zip_with, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index eae0580696..813c71c648 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,6 +22,7 @@ import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.utils.module import version_leq __all__ = [ @@ -44,6 +45,7 @@ "ImageMetaKey", "is_module_ver_at_least", "has_option", + "sample_slices", ] _seed = None @@ -366,3 +368,18 @@ def is_module_ver_at_least(module, version): """ test_ver = ".".join(map(str, version)) return module.__version__ != test_ver and version_leq(test_ver, module.__version__) + + +def sample_slices(data: NdarrayOrTensor, dim: int = 1, *slicevals: int) -> NdarrayOrTensor: + """sample several slices of input numpy array or Tensor on specified `dim`. + + Args: + data: input data to sample slices, can be numpy array or PyTorch Tensor. + dim: expected dimension index to sample slices, default to `1`. + slicevals: indices of slice to sample. + + """ + slices = [slice(None)] * len(data.shape) + slices[dim] = slicevals # type: ignore + + return data[tuple(slices)] diff --git a/tests/test_slice_channels.py b/tests/test_sample_slices.py similarity index 76% rename from tests/test_slice_channels.py rename to tests/test_sample_slices.py index f0999ba24c..329b42fa0b 100644 --- a/tests/test_slice_channels.py +++ b/tests/test_sample_slices.py @@ -15,7 +15,9 @@ import torch from parameterized import parameterized -from monai.networks import one_hot, slice_channels +from monai.networks import one_hot +from monai.utils import sample_slices +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ # single channel 2D, batch 2, shape (2, 1, 2, 2) {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, @@ -40,15 +42,16 @@ ] -class TestSliceChannels(unittest.TestCase): +class TestSampleSlices(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, dim, vals, expected_shape, expected_result=None): - result = one_hot(**input_data) - result = slice_channels(result, dim, *vals) + onehot = one_hot(**input_data) + for p in TEST_NDARRAYS: + result = sample_slices(p(onehot), dim, *vals) - self.assertEqual(result.shape, expected_shape) - if expected_result is not None: - self.assertTrue(np.allclose(expected_result, result.numpy())) + self.assertEqual(result.shape, expected_shape) + if expected_result is not None: + assert_allclose(p(expected_result), result) if __name__ == "__main__": From 41a158cb0b9b2325fbef7fd248c34e80f7e47682 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Jan 2022 12:01:06 +0800 Subject: [PATCH 3/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- tests/test_sample_slices.py | 36 +++++++----------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py index 329b42fa0b..599cb2ae78 100644 --- a/tests/test_sample_slices.py +++ b/tests/test_sample_slices.py @@ -11,47 +11,25 @@ import unittest -import numpy as np import torch from parameterized import parameterized -from monai.networks import one_hot from monai.utils import sample_slices from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ # single channel 2D, batch 2, shape (2, 1, 2, 2) - {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, - 0, - (1,), - (1, 3, 2, 2), -] - -TEST_CASE_2 = [ # single channel 1D, batch 2, shape (2, 1, 4) - {"labels": torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), "num_classes": 3}, - 1, - (1, 2), - (2, 2, 4), - np.array([[[1, 0, 0, 0], [0, 1, 1, 0]], [[0, 1, 0, 1], [1, 0, 0, 0]]]), -] - -TEST_CASE_3 = [ # single channel 2D, batch 2, shape (2, 1, 2, 2) - {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, - 0, - (), # select no channels - (0, 3, 2, 2), -] +TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, (1,), (1, 1, 2), torch.tensor([[[1, 0]]])] + +TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, (0, 2), (1, 2, 2), torch.tensor([[[0, 2], [4, 5]]])] class TestSampleSlices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_data, dim, vals, expected_shape, expected_result=None): - onehot = one_hot(**input_data) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, input_data, dim, vals, expected_shape, expected_result): for p in TEST_NDARRAYS: - result = sample_slices(p(onehot), dim, *vals) + result = sample_slices(p(input_data), dim, *vals) self.assertEqual(result.shape, expected_shape) - if expected_result is not None: - assert_allclose(p(expected_result), result) + assert_allclose(p(expected_result), result) if __name__ == "__main__": From 1fce20129e69966afa860b60a5860d5f23962146 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Jan 2022 22:13:14 +0800 Subject: [PATCH 4/5] [DLMED] remove test Signed-off-by: Nic Ma --- tests/test_sample_slices.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py index 599cb2ae78..532e2d63c0 100644 --- a/tests/test_sample_slices.py +++ b/tests/test_sample_slices.py @@ -17,18 +17,16 @@ from monai.utils import sample_slices from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, (1,), (1, 1, 2), torch.tensor([[[1, 0]]])] +TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, (1,), torch.tensor([[[1, 0]]])] -TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, (0, 2), (1, 2, 2), torch.tensor([[[0, 2], [4, 5]]])] +TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, (0, 2), torch.tensor([[[0, 2], [4, 5]]])] class TestSampleSlices(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_data, dim, vals, expected_shape, expected_result): + def test_shape(self, input_data, dim, vals, expected_result): for p in TEST_NDARRAYS: result = sample_slices(p(input_data), dim, *vals) - - self.assertEqual(result.shape, expected_shape) assert_allclose(p(expected_result), result) From 63d4dbdf69e60f711ac3b6e40fa87ba016944555 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 7 Jan 2022 07:17:51 +0800 Subject: [PATCH 5/5] [DLMED] add flag Signed-off-by: Nic Ma --- monai/utils/misc.py | 9 ++++++--- tests/test_sample_slices.py | 19 +++++++++++++------ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 813c71c648..7aa6c5bbc3 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -370,16 +370,19 @@ def is_module_ver_at_least(module, version): return module.__version__ != test_ver and version_leq(test_ver, module.__version__) -def sample_slices(data: NdarrayOrTensor, dim: int = 1, *slicevals: int) -> NdarrayOrTensor: +def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, *slicevals: int) -> NdarrayOrTensor: """sample several slices of input numpy array or Tensor on specified `dim`. Args: data: input data to sample slices, can be numpy array or PyTorch Tensor. dim: expected dimension index to sample slices, default to `1`. - slicevals: indices of slice to sample. + as_indices: if `True`, `slicevals` arg will be treated as the expected indices of slice, like: `1, 3, 5` + means `data[..., [1, 3, 5], ...]`, if `False`, `slicevals` arg will be treated as args for `slice` func, + like: `1, None` means `data[..., [1:], ...]`, `1, 5` means `data[..., [1: 5], ...]`. + slicevals: indices of slices or start and end indices of expected slices, depends on `as_indices` flag. """ slices = [slice(None)] * len(data.shape) - slices[dim] = slicevals # type: ignore + slices[dim] = slicevals if as_indices else slice(*slicevals) # type: ignore return data[tuple(slices)] diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py index 532e2d63c0..117d39b486 100644 --- a/tests/test_sample_slices.py +++ b/tests/test_sample_slices.py @@ -17,16 +17,23 @@ from monai.utils import sample_slices from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, (1,), torch.tensor([[[1, 0]]])] - -TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, (0, 2), torch.tensor([[[0, 2], [4, 5]]])] +# test data[:, [1, ], ...] +TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, True, (1,), torch.tensor([[[1, 0]]])] +# test data[:, [0, 2], ...] +TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, True, (0, 2), torch.tensor([[[0, 2], [4, 5]]])] +# test data[:, [0: 2], ...] +TEST_CASE_3 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 2), torch.tensor([[[0, 2], [1, 0]]])] +# test data[:, [1: ], ...] +TEST_CASE_4 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (1, None), torch.tensor([[[1, 0], [4, 5]]])] +# test data[:, [0: 3: 2], ...] +TEST_CASE_5 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 3, 2), torch.tensor([[[0, 2], [4, 5]]])] class TestSampleSlices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_data, dim, vals, expected_result): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_shape(self, input_data, dim, as_indices, vals, expected_result): for p in TEST_NDARRAYS: - result = sample_slices(p(input_data), dim, *vals) + result = sample_slices(p(input_data), dim, as_indices, *vals) assert_allclose(p(expected_result), result)