# 2D Model Inference on a 3D Volume  

Use case: A 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices taken from a 3D volume (for example, a CT scan).

After extending the sliding window inferer as described in this tutorial, it can handle the entire flow shown below:
![image](../figures/2d_inference_3d_input.png)

The inputs are a *3D volume* and a *2D model*; the output is a *3D volume* in which the 2D slice predictions are aggregated.
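
To make the flow concrete, here is a minimal sketch of what slice-wise inference amounts to conceptually: take each 2D slice along one spatial axis, run the 2D model on it, and stack the predictions back into a volume. This is only an illustration, not the MONAI implementation; the helper name `infer_slice_by_slice` and the `toy_model` convolution are hypothetical stand-ins for a real trained network.

```python
import torch


def infer_slice_by_slice(volume: torch.Tensor, model, axis: int = 0) -> torch.Tensor:
    """Apply a 2D model to every slice of a (N, C, D, H, W) volume along one spatial axis."""
    dim = axis + 2  # skip the batch and channel dimensions
    outputs = []
    with torch.no_grad():
        for i in range(volume.shape[dim]):
            slice_2d = volume.select(dim, i)  # drops `dim`, e.g. (N, C, H, W) for axis=0
            outputs.append(model(slice_2d))
    # Re-insert the sliced dimension so the result is a 5D volume again
    return torch.stack(outputs, dim=dim)


# Hypothetical stand-in for a trained 2D network (any module mapping 4D -> 4D works)
toy_model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
volume = torch.ones(1, 1, 8, 32, 32)

prediction = infer_slice_by_slice(volume, toy_model, axis=0)
print(prediction.shape)  # torch.Size([1, 1, 8, 32, 32])
```

A plain loop like this processes one full slice at a time. The sliding window approach used in the rest of this tutorial additionally allows windows smaller than a slice and batching of several windows per forward pass.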



[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/2d_inference_3d_volume.ipynb)


In [1]:
# Install monai
!python -c "import monai" || pip install -q "monai-weekly"

In [2]:
# Import libs
from monai.inferers import SlidingWindowInferer
import torch
from typing import Callable, Any
from monai.networks.nets import UNet

## Overriding SlidingWindowInferer
The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`.

In [3]:
class YourSlidingWindowInferer(SlidingWindowInferer):
    def __init__(self, spatial_dim: int = 0, *args, **kwargs):
        # Set dim to slice the volume across, for example, `0` could slide over axial slices,
        # `1` over coronal slices
        # and `2` over sagittal slices.
        self.spatial_dim = spatial_dim

        super().__init__(*args, **kwargs)

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:

        assert (
            0 <= self.spatial_dim < 3
        ), "`spatial_dim` must be `0`, `1` or `2`, indexing `[D, H, W]` respectively"

        # Check whether the roi size (e.g. a 2D roi) and the input volume size (3D input) mismatch
        if len(self.roi_size) != len(inputs.shape[2:]):

            # If they mismatch and roi_size is 2D, add a singleton dimension to roi_size
            if len(self.roi_size) == 2:
                self.roi_size = list(self.roi_size)
                self.roi_size.insert(self.spatial_dim, 1)
            else:
                raise RuntimeError(
                    "Currently, only 2D `roi_size` is supported, cannot broadcast to volume."
                )

        # Forward any extra arguments to the wrapped network instead of dropping them
        return super().__call__(
            inputs, lambda x: self.network_wrapper(network, x, *args, **kwargs)
        )

    def network_wrapper(self, network, x, *args, **kwargs):
        """
        Wrapper handles cases where inference needs to be done using
        2D models over 3D volume inputs.
        """
        # If the roi size is 1 along the slicing dim of [D, H, W], then the patch is 2D
        # and needs to be handled accordingly

        if self.roi_size[self.spatial_dim] == 1:
            # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
            x = x.squeeze(dim=self.spatial_dim + 2)
            out = network(x, *args, **kwargs)
            # Unsqueeze the network output so it is [N, C, D, H, W] as expected by
            # the default SlidingWindowInferer class
            return out.unsqueeze(dim=self.spatial_dim + 2)

        else:
            return network(x, *args, **kwargs)
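
The two shape manipulations in the class above can be checked in isolation. The sketch below mirrors the `roi_size` expansion and the squeeze/unsqueeze round trip outside of MONAI; `expand_roi_size` and `toy_model` are hypothetical names introduced only for this illustration, and the convolution is a stand-in for a real 2D network.

```python
import torch


def expand_roi_size(roi_size, spatial_dim):
    """Insert a singleton at `spatial_dim` so a 2D ROI becomes a one-slice-thick 3D ROI."""
    roi = list(roi_size)
    roi.insert(spatial_dim, 1)
    return roi


axial_roi = expand_roi_size((256, 256), spatial_dim=0)   # [1, 256, 256]
coronal_roi = expand_roi_size((64, 256), spatial_dim=1)  # [64, 1, 256]

# Hypothetical stand-in for a 2D network that preserves the spatial size
toy_model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

spatial_dim = 1
patch = torch.ones(1, 1, *coronal_roi)         # (N, C, D, 1, W): one coronal window
squeezed = patch.squeeze(dim=spatial_dim + 2)  # (N, C, D, W): what the 2D model sees
with torch.no_grad():
    out = toy_model(squeezed).unsqueeze(dim=spatial_dim + 2)  # back to (N, C, D, 1, W)

print(axial_roi, coronal_roi)
print(squeezed.shape, out.shape)
```

The `+ 2` offset skips the batch and channel dimensions, so `spatial_dim` always indexes into `[D, H, W]`.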

## Testing added functionality
Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above.

In [4]:
# Create a 2D UNet with randomly initialized weights for testing purposes

# 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2,
)

# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)
input_volume = torch.ones(1, 1, 64, 256, 256)

# Create an instance of YourSlidingWindowInferer with a 256x256 (HxW) roi_size, sliding over the D axis
axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)

output = axial_inferer(input_volume, net)

# Output is a 3D volume with 2D slices aggregated
print("Axial Inferer Output Shape: ", output.shape)

# Create an instance of YourSlidingWindowInferer with a 64x256 (DxW) roi_size, sliding over the H axis
coronal_inferer = YourSlidingWindowInferer(
    roi_size=(64, 256),
    sw_batch_size=1,
    spatial_dim=1,  # Spatial dim to slice along is added here
    cval=-1,
)

output = coronal_inferer(input_volume, net)

# Output is a 3D volume with 2D slices aggregated
print("Coronal Inferer Output Shape: ", output.shape)

Axial Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])
Coronal Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])
