diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f7fd2e2956..5a2bb13fec 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.deprecate_utils import deprecated, deprecated_arg from monai.utils.misc import ensure_tuple, set_determinism from monai.utils.module import pytorch_after @@ -93,7 +93,13 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f return labels +@deprecated(since="0.8.0", msg_suffix="use `monai.utils.misc.sample_slices` instead.") def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: + """ + .. deprecated:: 0.8.0 + Use `monai.utils.misc.sample_slices` instead. + + """ slices = [slice(None)] * len(tensor.shape) slices[1] = slice(*slicevals) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 0c04680234..59b779e3b9 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -54,6 +54,7 @@ issequenceiterable, list_to_dict, progress_bar, + sample_slices, set_determinism, star_zip_with, zip_with, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index eae0580696..7aa6c5bbc3 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,6 +22,7 @@ import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.utils.module import version_leq __all__ = [ @@ -44,6 +45,7 @@ "ImageMetaKey", "is_module_ver_at_least", "has_option", + "sample_slices", ] _seed = None @@ -366,3 +368,21 @@ def is_module_ver_at_least(module, version): """ test_ver = ".".join(map(str, version)) return module.__version__ != test_ver and version_leq(test_ver, module.__version__) + + +def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, *slicevals: int) -> NdarrayOrTensor: + """sample several slices of input numpy array or Tensor on specified `dim`. + + Args: + data: input data to sample slices, can be numpy array or PyTorch Tensor. + dim: expected dimension index to sample slices, default to `1`. + as_indices: if `True`, `slicevals` arg will be treated as the expected indices of slice, like: `1, 3, 5` + means `data[..., [1, 3, 5], ...]`, if `False`, `slicevals` arg will be treated as args for `slice` func, + like: `1, None` means `data[..., [1:], ...]`, `1, 5` means `data[..., [1: 5], ...]`. + slicevals: indices of slices or start and end indices of expected slices, depends on `as_indices` flag. + + """ + slices = [slice(None)] * len(data.shape) + slices[dim] = slicevals if as_indices else slice(*slicevals) # type: ignore + + return data[tuple(slices)] diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py new file mode 100644 index 0000000000..117d39b486 --- /dev/null +++ b/tests/test_sample_slices.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.utils import sample_slices +from tests.utils import TEST_NDARRAYS, assert_allclose + +# test data[:, [1, ], ...] +TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, True, (1,), torch.tensor([[[1, 0]]])] +# test data[:, [0, 2], ...] +TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, True, (0, 2), torch.tensor([[[0, 2], [4, 5]]])] +# test data[:, [0: 2], ...] +TEST_CASE_3 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 2), torch.tensor([[[0, 2], [1, 0]]])] +# test data[:, [1: ], ...] +TEST_CASE_4 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (1, None), torch.tensor([[[1, 0], [4, 5]]])] +# test data[:, [0: 3: 2], ...] +TEST_CASE_5 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 3, 2), torch.tensor([[[0, 2], [4, 5]]])] + + +class TestSampleSlices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_shape(self, input_data, dim, as_indices, vals, expected_result): + for p in TEST_NDARRAYS: + result = sample_slices(p(input_data), dim, as_indices, *vals) + assert_allclose(p(expected_result), result) + + +if __name__ == "__main__": + unittest.main()