From 4be56fba81b9e1be99300970495b7add924d5c84 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 6 Jan 2022 20:42:40 -0600 Subject: [PATCH 1/3] Add SliceInferer for slice-by-slice infer Signed-off-by: Suraj Pai --- docs/source/inferers.rst | 6 ++++ monai/inferers/__init__.py | 2 +- monai/inferers/inferer.py | 67 +++++++++++++++++++++++++++++++++++++- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index e358e603bd..5f6eac6ef0 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -36,3 +36,9 @@ Inferers .. autoclass:: SaliencyInferer :members: :special-members: __call__ + +`SliceInferer` +~~~~~~~~~~~~~~~~~ +.. autoclass:: SliceInferer + :members: + :special-members: __call__ diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 20d829297f..3447782be9 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .inferer import Inferer, SaliencyInferer, SimpleInferer, SlidingWindowInferer +from .inferer import Inferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer from .utils import sliding_window_inference diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index c7b70e06ca..289db505d5 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -19,7 +19,7 @@ from monai.utils import BlendMode, PytorchPadMode from monai.visualize import CAM, GradCAM, GradCAMpp -__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer"] +__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"] class Inferer(ABC): @@ -221,3 +221,68 @@ def __call__(self, inputs: torch.Tensor, network: nn.Module, *args: Any, **kwarg cam = GradCAMpp(network, self.target_layers, *self.args, **self.kwargs) return cam(inputs, self.class_idx, *args, **kwargs) + + +class SliceInferer(SlidingWindowInferer): + """ + SliceInferer extends SlidingWindowInferer to provide slice-by-slice (2D) inference + when provided a 3D volume. + + Args: + spatial_dim: Spatial dimension over which the slice-by-slice inference runs on the 3D volume. + For example ``0`` could slide over axial slices. ``1`` over coronal slices and ``2`` over sagittal slices. + args: other optional args to be passed to the `__init__` of SliceInferer. + kwargs: other optional keyword args to be passed to `__init__` of SliceInferer. + + + ``args``, ``kwargs`` follow :py:class:`monai.inferer.SlidingWindowInferer`. + + """ + + def __init__(self, spatial_dim: int = 0, *args, **kwargs) -> None: + self.spatial_dim = spatial_dim + super().__init__(*args, **kwargs) + + def __call__( + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> torch.Tensor: + """ + Args: + inputs: 3D input for inference + network: 2D model to execute inference on slices in the 3D input + args: optional args to be passed to ``network``. + kwargs: optional keyword args to be passed to ``network``. + """ + assert self.spatial_dim < 3, "`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively" + + # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch + if len(self.roi_size) != len(inputs.shape[2:]): + + # If they mismatch and roi_size is 2D add another dimension to roi size + if len(self.roi_size) == 2: + self.roi_size = list(self.roi_size) + self.roi_size.insert(self.spatial_dim, 1) + else: + raise RuntimeError("Currently, only 2D `roi_size` is supported, cannot broadcast to volume. ") + + return super().__call__(inputs, lambda x: self.network_wrapper(network, x)) + + def network_wrapper( + self, network: Callable[..., torch.Tensor], x: torch.Tensor, *args, **kwargs + ) -> Callable[..., torch.Tensor]: + """ + Wrapper handles cases where inference needs to be done using + 2D models over 3D volume inputs. + """ + # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs + # be handled accordingly + + if self.roi_size[self.spatial_dim] == 1: + # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D. + x = x.squeeze(dim=self.spatial_dim + 2) + out = network(x, *args, **kwargs) + # Unsqueeze the network output so it is [N, C, D, H, W] as expected by + # the default SlidingWindowInferer class + return out.unsqueeze(dim=self.spatial_dim + 2) + else: + return network(x, *args, **kwargs) From 462ff6b2235356c77e8491fe25f183bab9929d93 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 17 Jan 2022 21:14:12 -0500 Subject: [PATCH 2/3] Add tests + fix review comments Signed-off-by: Suraj Pai --- docs/source/inferers.rst | 2 +- monai/inferers/inferer.py | 50 +++++++++++++++--------------------- tests/test_slice_inferer.py | 51 +++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 30 deletions(-) create mode 100644 tests/test_slice_inferer.py diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 5f6eac6ef0..ac638eb38d 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -38,7 +38,7 @@ Inferers :special-members: __call__ `SliceInferer` -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~ .. autoclass:: SliceInferer :members: :special-members: __call__ diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 289db505d5..b133cdd182 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -231,11 +231,12 @@ class SliceInferer(SlidingWindowInferer): Args: spatial_dim: Spatial dimension over which the slice-by-slice inference runs on the 3D volume. For example ``0`` could slide over axial slices. ``1`` over coronal slices and ``2`` over sagittal slices. - args: other optional args to be passed to the `__init__` of SliceInferer. - kwargs: other optional keyword args to be passed to `__init__` of SliceInferer. + args: other optional args to be passed to the `__init__` of base class SlidingWindowInferer. + kwargs: other optional keyword args to be passed to `__init__` of base class SlidingWindowInferer. - - ``args``, ``kwargs`` follow :py:class:`monai.inferer.SlidingWindowInferer`. + Note: + ``roi_size`` in SliceInferer is expected to be a 2D tuple when a 3D volume is provided. This allows + sliding across slices along the 3D volume using a selected ``spatial_dim``. """ @@ -253,36 +254,27 @@ def __call__( args: optional args to be passed to ``network``. kwargs: optional keyword args to be passed to ``network``. """ - assert self.spatial_dim < 3, "`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively" - - # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch - if len(self.roi_size) != len(inputs.shape[2:]): + if self.spatial_dim > 2: + raise ValueError("`spatial_dim` can only be `[H, W, D]` with `0, 1, 2` respectively.") - # If they mismatch and roi_size is 2D add another dimension to roi size - if len(self.roi_size) == 2: - self.roi_size = list(self.roi_size) - self.roi_size.insert(self.spatial_dim, 1) - else: - raise RuntimeError("Currently, only 2D `roi_size` is supported, cannot broadcast to volume. ") + # Check if ``roi_size`` tuple is 2D and ``inputs`` tensor is 3D + if len(self.roi_size) == 2 and len(inputs.shape[2:]) == 3: + self.roi_size = list(self.roi_size) + self.roi_size.insert(self.spatial_dim, 1) + else: + raise RuntimeError("Currently, only 2D `roi_size` with 3D `inputs` tensor is supported.") - return super().__call__(inputs, lambda x: self.network_wrapper(network, x)) + return super().__call__(inputs, lambda x: self.network_wrapper(network, x, *args, **kwargs)) def network_wrapper( self, network: Callable[..., torch.Tensor], x: torch.Tensor, *args, **kwargs ) -> Callable[..., torch.Tensor]: """ - Wrapper handles cases where inference needs to be done using - 2D models over 3D volume inputs. + Wrapper handles inference for 2D models over 3D volume inputs. """ - # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs - # be handled accordingly - - if self.roi_size[self.spatial_dim] == 1: - # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D. - x = x.squeeze(dim=self.spatial_dim + 2) - out = network(x, *args, **kwargs) - # Unsqueeze the network output so it is [N, C, D, H, W] as expected by - # the default SlidingWindowInferer class - return out.unsqueeze(dim=self.spatial_dim + 2) - else: - return network(x, *args, **kwargs) + # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D. + x = x.squeeze(dim=self.spatial_dim + 2) + out = network(x, *args, **kwargs) + # Unsqueeze the network output so it is [N, C, D, H, W] as expected by + # the default SlidingWindowInferer class + return out.unsqueeze(dim=self.spatial_dim + 2) diff --git a/tests/test_slice_inferer.py b/tests/test_slice_inferer.py new file mode 100644 index 0000000000..0aa87b75cf --- /dev/null +++ b/tests/test_slice_inferer.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import SliceInferer +from monai.networks.nets import UNet + +TEST_CASES = ["0", "1", "2"] + + +class TestSliceInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, spatial_dim): + spatial_dim = int(spatial_dim) + + model = UNet( + spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2 + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + + # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W) + input_volume = torch.ones(1, 1, 64, 256, 256) + + # Remove spatial dim to slide across from the roi_size + roi_size = list(input_volume.shape[2:]) + roi_size.pop(spatial_dim) + + # Initialize and run inferer + inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1) + result = inferer(input_volume, model) + + self.assertTupleEqual(result.shape, input_volume.shape) + + +if __name__ == "__main__": + unittest.main() From 81baf9dd712458c03e6d1fe798badc654410c992 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 24 Jan 2022 11:47:07 -0500 Subject: [PATCH 3/3] Update error message Signed-off-by: Suraj Pai --- monai/inferers/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 326e41acf7..7185674d79 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -256,7 +256,7 @@ def __call__( kwargs: optional keyword args to be passed to ``network``. """ if self.spatial_dim > 2: - raise ValueError("`spatial_dim` can only be `[H, W, D]` with `0, 1, 2` respectively.") + raise ValueError("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.") # Check if ``roi_size`` tuple is 2D and ``inputs`` tensor is 3D if len(self.roi_size) == 2 and len(inputs.shape[2:]) == 3: