Skip to content
8 changes: 7 additions & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
import torch.nn as nn

from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.deprecate_utils import deprecated, deprecated_arg
from monai.utils.misc import ensure_tuple, set_determinism
from monai.utils.module import pytorch_after

Expand Down Expand Up @@ -93,7 +93,13 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f
return labels


@deprecated(since="0.8.0", msg_suffix="use `monai.utils.misc.sample_slices` instead.")
def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor:
"""
.. deprecated:: 0.8.0
Comment thread
Nic-Ma marked this conversation as resolved.
Use `monai.utils.misc.sample_slices` instead.

"""
slices = [slice(None)] * len(tensor.shape)
slices[1] = slice(*slicevals)

Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
issequenceiterable,
list_to_dict,
progress_bar,
sample_slices,
set_determinism,
star_zip_with,
zip_with,
Expand Down
20 changes: 20 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.module import version_leq

__all__ = [
Expand All @@ -44,6 +45,7 @@
"ImageMetaKey",
"is_module_ver_at_least",
"has_option",
"sample_slices",
]

_seed = None
Expand Down Expand Up @@ -366,3 +368,21 @@ def is_module_ver_at_least(module, version):
"""
test_ver = ".".join(map(str, version))
return module.__version__ != test_ver and version_leq(test_ver, module.__version__)


def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, *slicevals: int) -> NdarrayOrTensor:
"""sample several slices of input numpy array or Tensor on specified `dim`.

Args:
data: input data to sample slices, can be numpy array or PyTorch Tensor.
dim: expected dimension index to sample slices, default to `1`.
as_indices: if `True`, `slicevals` arg will be treated as the expected indices of slice, like: `1, 3, 5`
means `data[..., [1, 3, 5], ...]`, if `False`, `slicevals` arg will be treated as args for `slice` func,
like: `1, None` means `data[..., [1:], ...]`, `1, 5` means `data[..., [1: 5], ...]`.
slicevals: indices of slices or start and end indices of expected slices, depends on `as_indices` flag.

"""
slices = [slice(None)] * len(data.shape)
slices[dim] = slicevals if as_indices else slice(*slicevals) # type: ignore

return data[tuple(slices)]
41 changes: 41 additions & 0 deletions tests/test_sample_slices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.utils import sample_slices
from tests.utils import TEST_NDARRAYS, assert_allclose

# test data[:, [1, ], ...]
TEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, True, (1,), torch.tensor([[[1, 0]]])]
# test data[:, [0, 2], ...]
TEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, True, (0, 2), torch.tensor([[[0, 2], [4, 5]]])]
# test data[:, [0: 2], ...]
TEST_CASE_3 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 2), torch.tensor([[[0, 2], [1, 0]]])]
# test data[:, [1: ], ...]
TEST_CASE_4 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (1, None), torch.tensor([[[1, 0], [4, 5]]])]
# test data[:, [0: 3: 2], ...]
TEST_CASE_5 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 3, 2), torch.tensor([[[0, 2], [4, 5]]])]


class TestSampleSlices(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_shape(self, input_data, dim, as_indices, vals, expected_result):
for p in TEST_NDARRAYS:
result = sample_slices(p(input_data), dim, as_indices, *vals)
assert_allclose(p(expected_result), result)


if __name__ == "__main__":
unittest.main()