# 2D Model Inference on a 3D Volume  

Use case: a 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices from a 3D volume (for example, a CT scan).

After extending the sliding window inferer as described in this tutorial, it can handle the entire flow as shown:
![image](../figures/2d_inference_3d_input.png)

The inputs are a *3D volume* and a *2D model*, and the output is a *3D volume* in which the 2D slice predictions are aggregated.
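Conceptually, slice-wise inference loops over one spatial axis, runs the 2D model on each slice, and stacks the predictions back along that axis. The sketch below shows this idea in plain PyTorch. The helper `infer_slice_by_slice` is hypothetical. It is not how MONAI implements the feature: it processes one slice at a time and does no batching, padding or device handling.

```python
import torch


def infer_slice_by_slice(volume, net2d, axis=0):
    """Hypothetical helper: apply a 2D network to every slice of an (N, C, D, H, W)
    volume along one spatial axis (0 = D, 1 = H, 2 = W) and restack the predictions."""
    dim = axis + 2  # skip the batch and channel dimensions
    outputs = []
    with torch.no_grad():
        for i in range(volume.shape[dim]):
            slice_2d = volume.select(dim, i)  # the sliced axis is removed -> 4D tensor
            outputs.append(net2d(slice_2d))
    return torch.stack(outputs, dim=dim)  # re-insert the sliced axis


net2d = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)  # stand-in 2D "model"
volume = torch.rand(1, 1, 4, 8, 8)  # (N, C, D, H, W)
print(infer_slice_by_slice(volume, net2d, axis=0).shape)
```

The result keeps the volume's spatial shape. Only the channel dimension changes, following the 2D model's output channels.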



[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/2d_inference_3d_volume.ipynb)


```python
# Install monai
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
```

```python
# Import libs
from monai.inferers import SliceInferer
import torch
from monai.networks.nets import UNet
```

## SliceInferer
The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer).

## Usage

```python
# Create a 2D UNet with randomly initialized weights for testing purposes
# 3-layer network with down/upsampling by a factor of 2 at each layer and 2-convolution residual units
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2,
)

# Initialize a dummy 3D volume tensor with shape (N, C, D, H, W)
input_volume = torch.ones(1, 1, 64, 256, 256)

# Create a SliceInferer with roi_size 256x256 (HxW) that slides over the D axis (spatial_dim defaults to 0)
# cval is the fill value used only if the input must be padded up to roi_size (not the case here)
axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)

output = axial_inferer(input_volume, net)

# The output is a 3D volume with the 2D slice predictions aggregated
print("Axial Inferer Output Shape: ", output.shape)

# Create a SliceInferer with roi_size 64x256 (DxW) that slides over the H axis
coronal_inferer = SliceInferer(
    roi_size=(64, 256),
    sw_batch_size=1,
    spatial_dim=1,  # spatial dimension to slice along: 0 = D, 1 = H, 2 = W
    cval=-1,
    progress=True,
)

output = coronal_inferer(input_volume, net)

# The output is a 3D volume with the 2D slice predictions aggregated
print("Coronal Inferer Output Shape: ", output.shape)
```

```
100%|██████████| 64/64 [00:00<00:00, 107.33it/s]
Axial Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])
100%|██████████| 256/256 [00:01<00:00, 177.69it/s]
Coronal Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])
```


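Note that `axial_inferer` runs 64 inference iterations, one per slice along D, and `coronal_inferer` runs 256, one per slice along H. This is because `sw_batch_size=1` sends a single slice through the network at each iteration.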