Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# UNet++ input size constraints

MONAI provides an enhanced version of UNet (``monai.networks.nets.UNet``), which supports residual units and exposes more hyperparameters (such as ``strides``, ``kernel_size`` and ``up_kernel_size``) than ``monai.networks.nets.BasicUNet``. However, ``UNet`` places constraints on both the network hyperparameters and the input size.

MONAI also provides a version of UNet++ (``monai.networks.nets.BasicUnetPlusPlus``) with a fixed number of down-sampling layers, each with a stride of 2. The options you can change are: the number of input and output channels, the number of hidden channels (for 6 different layers), the normalization and activation layers, the convolution bias, the dropout rate, and the up-sampling mode. As with `UNet`, the model configuration affects which input shapes are valid.

The constraints on the hyperparameters can be found in the docstring of the network; this tutorial focuses on how to determine a reasonable input size.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q monai-weekly

## Setup imports

In [2]:
from monai.networks.nets import BasicUnetPlusPlus
import monai
import torch
import torch.nn as nn
from typing import Dict

monai.config.print_config()

MONAI version: 1.2.0+63.g5feb3530
Numpy version: 1.25.1
Pytorch version: 2.0.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 5feb353030e0bb204e21e1de338cd81b5972bb8a
MONAI __file__: /Users/hung.nh/codespace/yauangon/MONAI/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: 2.13.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlfl

## Check UNet++ structure

![](../../figures/unet++.png)

In [None]:
model = BasicUnetPlusPlus(
    spatial_dims=3,
    in_channels=3,
    out_channels=3,
    features=(32, 32, 64, 128, 256, 32),
    # norm='localresponse',
    norm="batch",
)
print(model)

## Normalization

UNet++ uses the same `TwoConv`, `Down`, and `UpCat` blocks as UNet. Therefore, you can refer to `modules/UNet_input_size_constraints.ipynb` for a detailed breakdown. In summary, the constraints for these types of normalization are:

- Instance Norm: the product of the spatial dimensions must be > 1 (the channel and batch dimensions are not counted).
- Batch Norm: the product of the spatial dimensions and the batch size must be > 1 (the channel dimension is not counted). For training, `batch_size` should preferably be larger than 1.
- Local Response Norm: no constraint.
- Other normalizations: please refer to `modules/UNet_input_size_constraints.ipynb`.

UNet++ has 4 down-sampling blocks, each of which halves every spatial edge, and there is no argument to change this behavior. The smallest edge we can have is therefore `2**4 = 16`. If every edge is exactly 16, the feature map after the last down-sampling block has shape `[..., ..., 1, 1]` (or `[..., ..., 1, 1, 1]` for 3D), which can cause an error in the normalization layer.
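
The halving described above can be traced in plain Python. This is a minimal sketch, assuming that each of the four down-sampling blocks reduces an edge by floor division by 2 (as 2x max-pooling does); `bottleneck_edge` is a hypothetical helper for illustration and is not part of MONAI.

```python
def bottleneck_edge(edge: int, num_down: int = 4) -> int:
    """Edge length left after `num_down` 2x poolings (floor division each time)."""
    for _ in range(num_down):
        edge //= 2
    return edge


for edge in (8, 16, 31, 32):
    print(f"input edge {edge:>2} -> bottleneck edge {bottleneck_edge(edge)}")
```

Under this assumption, any edge from 16 to 31 collapses to a single element at the bottleneck, and an edge below 16 collapses to zero.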

See the test code below for examples with batch norm and instance norm.


In [4]:
print(monai.networks.layers.factories.Norm.factories.keys())

dict_keys(['INSTANCE', 'BATCH', 'GROUP', 'LAYER', 'LOCALRESPONSE', 'SYNCBATCH', 'INSTANCE_NVFUSER'])


In [5]:
def make_model_with_layer(layer_norm):
    return BasicUnetPlusPlus(
        spatial_dims=2, in_channels=3, out_channels=1, features=(32, 32, 64, 128, 256, 32), norm=layer_norm
    )


def test_min_dim():
    min_edge = 16
    # NOTE: the second entry is the number of input channels, not a spatial dimension
    batch_size, channels, height, width = 1, 3, min_edge, min_edge
    model_dict: Dict[str, BasicUnetPlusPlus] = {}
    print("Prepare model")
    for norm_layer in ["instance", "batch"]:
        model_dict[norm_layer] = make_model_with_layer(norm_layer)

    for norm_layer in ["instance", "batch"]:
        print("=" * 10 + f" USING NORM LAYER: {norm_layer.upper()} " + "=" * 10)
        model = model_dict[norm_layer]
        # Keep the batch size at 1 and double only the height
        print("_" * 10 + " Changing the H dimension of 2D input " + "_" * 10)
        for temp_height in [height, height * 2]:
            try:
                x = torch.ones(batch_size, channels, temp_height, width)
                print(f">> Using Input.shape={x.shape}")
                model(x)
            except Exception as msg:
                print(f">>>> Exception: {msg}\n")

        # Keep the minimum spatial size and change only the batch size
        print("_" * 10 + " Changing the batch size " + "_" * 10)
        for batch_size_tmp in [1, 2]:
            try:
                x = torch.ones(batch_size_tmp, channels, height, width)
                print(f">> Input.shape={x.shape}")
                model(x)
            except Exception as msg:
                print(f">> Exception: {msg}\n")


with torch.no_grad():
    test_min_dim()

Prepare model
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
__________ Changing the H dimension of 2D input __________
>> Using Input.shape=torch.Size([1, 3, 16, 16])
>>>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])

>> Using Input.shape=torch.Size([1, 3, 32, 16])
__________ Changing the batch size __________
>> Input.shape=torch.Size([1, 3, 16, 16])
>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])

>> Input.shape=torch.Size([2, 3, 16, 16])
>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([2, 256, 1, 1])

__________ Changing the H dimension of 2D input __________
>> Using Input.shape=torch.Size([1, 3, 16, 16])
>>>> Exception: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])

>> Using Input.shape=torch.Size([1, 3, 3

### Normalization conclusion

**Note:** These are lower bounds. A higher-resolution input is recommended.

By convention, let the input shape be `(B, C, D, H, W)` for a 3D model and `(B, C, H, W)` for a 2D model. Every one of `D, H, W` must be at least `16`, as mentioned above.

If you are using:
- Batch Norm: ensure that **at least one value** of `(D, H, W)` for 3D, or `(H, W)` for 2D, is:
  - `>= 32` if the batch size `== 1`
  - `>= 16` if the batch size `> 1`

- Instance Norm: ensure that **at least one value** of `(D, H, W)` for 3D, or `(H, W)` for 2D, is `>= 32`.

- Local Response Norm: no constraint.

- Other normalizations: these require the input shape to agree with the norm's parameters, so check the documentation of the layer before using it.
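
The rules in this list can be written down as a small pre-flight check. This is only a sketch of the lower bounds stated above, assuming a shape laid out as `(B, C, *spatial)`; `check_input_shape` is a hypothetical helper, not a MONAI function, and it encodes only the three norms discussed here.

```python
from typing import Sequence


def check_input_shape(shape: Sequence[int], norm: str) -> bool:
    """Return True if `shape` = (B, C, *spatial) meets the lower bounds listed above."""
    batch, spatial = shape[0], shape[2:]
    if min(spatial) < 16:  # every edge must survive four 2x poolings
        return False
    if norm == "localresponse":
        return True
    if norm == "instance":
        return max(spatial) >= 32
    if norm == "batch":
        return batch > 1 or max(spatial) >= 32
    raise ValueError(f"no rule encoded for norm={norm!r}")


print(check_input_shape((1, 3, 16, 16), "instance"))
print(check_input_shape((1, 3, 32, 16), "instance"))
print(check_input_shape((2, 3, 16, 16), "batch"))
```

Passing the check only means the shape clears these minimum sizes; it says nothing about whether the resolution is adequate for training.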


**Note**: you can pass arguments to a normalization layer, and some layers raise an error if you don't. See the example below.

In [9]:
# NOTE: this will raise an error because a required argument is missing
# model = BasicUnetPlusPlus(
#     spatial_dims=3,
#     in_channels=3,
#     out_channels=3,
#     features=(32, 32, 64, 128, 256, 32),
#     norm='localresponse',
# )

# NOTE: this will work fine
model = BasicUnetPlusPlus(
    spatial_dims=3,
    in_channels=3,
    out_channels=3,
    features=(32, 32, 64, 128, 256, 32),
    norm=("localresponse", {"size": 10}),
)

BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).


## Pooling (down-sampling) and Up-sampling

While all the internal padding is handled so that the output spatial shape matches the input spatial shape, you still need to be aware of this implicit processing:

- Down-sampling path: max-pooling with a kernel size and stride of 2. For any odd size, the last element is ignored. See the code cell further below for pooling with an odd shape.
- Up-sampling path: depending on the up-sampling mode, the `UpCat` module automatically pads (in `replicate` mode) the up-sampled features to the same spatial shape as the encoder features. See this code of `UpCat`:

```python
def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
    """

    Args:
        x: features to be upsampled.
        x_e: optional features from the encoder, if None, this branch is not in use.
    """
    x_0 = self.upsample(x)

    if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor):
        if self.is_pad:  # always True
            # handling spatial shapes due to the 2x max-pooling with odd edge lengths.
            dimensions = len(x.shape) - 2
            sp = [0] * (dimensions * 2)
            for i in range(dimensions):
                if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
                    sp[i * 2 + 1] = 1
            x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
        x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
    else:
        x = self.convs(x_0)

    return x
```
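
To see where this padding kicks in, the edge lengths can be traced through the encoder and back up. This is a rough sketch in plain Python, assuming that pooling floors the edge to half and that up-sampling exactly doubles it; `trace_edge` is a hypothetical helper for illustration only.

```python
from typing import List, Tuple


def trace_edge(edge: int, num_down: int = 4) -> Tuple[List[int], List[int]]:
    """Return the encoder edge per level and the padding added at each up-sampling step."""
    encoder = [edge]
    for _ in range(num_down):
        encoder.append(encoder[-1] // 2)  # 2x max-pooling ignores a trailing odd element

    pads = []
    for level in range(num_down, 0, -1):
        upsampled = encoder[level] * 2  # assumption: up-sampling exactly doubles the edge
        pads.append(encoder[level - 1] - upsampled)  # 1 if the skip edge was odd, else 0
    return encoder, pads


print(trace_edge(37))
print(trace_edge(64))
```

With an edge of 37 the encoder edges are 37, 18, 9, 4, 2, and one replicated element is added at the two levels whose edge was odd. With an edge of 64 nothing is dropped and nothing is padded.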


In [2]:
# Example of pooling with an odd shape
un_pool = (
    torch.Tensor(
        [
            [1, 2, 3, 4, 5],
            [1, 2, 3, 4, 5],
            [1, 2, 3, 4, 5],
            [1, 2, 3, 4, 5],
            [1, 2, 3, 4, 5],
        ]
    )
    * torch.Tensor([1, 2, 3, 4, 5])[..., None]
)
un_pool = un_pool[None, None, ...]

pooled = nn.MaxPool2d(kernel_size=2)(un_pool)
print(f"{un_pool.shape=}, {pooled.shape=}")
print(un_pool)
print(pooled)

un_pool.shape=torch.Size([1, 1, 5, 5]), pooled.shape=torch.Size([1, 1, 2, 2])
tensor([[[[ 1.,  2.,  3.,  4.,  5.],
          [ 2.,  4.,  6.,  8., 10.],
          [ 3.,  6.,  9., 12., 15.],
          [ 4.,  8., 12., 16., 20.],
          [ 5., 10., 15., 20., 25.]]]])
tensor([[[[ 4.,  8.],
          [ 8., 16.]]]])


### Up/Down sampling conclusion

It's best to keep every input spatial dimension (`D, H, W`) a multiple of 16 (`16 * x`). This makes full use of your data, since no element is dropped by pooling and no padding is required. Each input edge should also be at least 32 for safety against the normalization constraints mentioned above.
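
One way to apply this recommendation is to compute a target edge before padding or resizing the data. The sketch below is a hypothetical helper (`round_up_edge` is not a MONAI function), assuming a multiple of 16 and a lower bound of 32 as suggested in this tutorial; how you then pad or resize to that size is up to your preprocessing pipeline.

```python
def round_up_edge(edge: int, multiple: int = 16, minimum: int = 32) -> int:
    """Smallest multiple of `multiple` that is >= max(edge, minimum)."""
    edge = max(edge, minimum)
    return -(-edge // multiple) * multiple  # ceiling division, then scale back up


for edge in (10, 33, 37, 64):
    print(f"{edge} -> {round_up_edge(edge)}")
```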

## Features argument

The `features` argument decides the number of output channels for each level of convolution stacks. Exactly 6 values are required; any more or fewer will result in an error.

The original paper uses 2 settings:
- UNet: `features=(32, 64, 128, 256, 512, 32)`
- Wide UNet: `features=(35, 70, 140, 280, 560, 35)`

Note that `features[5]` is used only for the last up-sampling + convolution block. In the paper this block has the same width as the first level, so the settings above use `features[5] = features[0]`.
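
The role of each entry can be made explicit with a small labelling helper. This is only an illustrative sketch of the description above: `describe_features` and its label names are hypothetical, and the length check mimics, rather than reproduces, the validation done inside MONAI.

```python
from typing import Dict, Sequence


def describe_features(features: Sequence[int]) -> Dict[str, int]:
    """Label each of the six `features` entries by the role described above."""
    if len(features) != 6:
        raise ValueError(f"expected 6 feature values, got {len(features)}")
    # the first five entries are the per-level widths, the last one is the final up block
    labels = [f"level_{i}" for i in range(5)] + ["final_up_block"]
    return dict(zip(labels, features))


print(describe_features((32, 64, 128, 256, 512, 32)))
```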

In [13]:
model = BasicUnetPlusPlus(
    spatial_dims=3,
    in_channels=3,
    out_channels=3,
    features=(32, 64, 128, 256, 512, 32),
)

BasicUNetPlusPlus features: (32, 64, 128, 256, 512, 32).
