In [1]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
from monai.networks.nets import BasicUnetPlusPlus
import monai
import math
import torch
import torch.nn as nn

monai.config.print_config()

MONAI version: 1.2.0+63.g5feb3530
Numpy version: 1.25.1
Pytorch version: 2.0.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 5feb353030e0bb204e21e1de338cd81b5972bb8a
MONAI __file__: /Users/hung.nh/codespace/yauangon/MONAI/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transf

In [71]:
# Reference signature of BasicUnetPlusPlus.__init__:
#     spatial_dims: int = 3
#     in_channels: int = 1
#     out_channels: int = 2
#     features: Sequence[int] = (32, 32, 64, 128, 256, 32)
#     deep_supervision: bool = False
#     act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True})
#     norm: str | tuple = ("instance", {"affine": True})
#     bias: bool = True
#     dropout: float | tuple = 0.0
#     upsample: str = "deconv"

# Try to build a 3D UNet++ with every norm registered in MONAI's Norm factory,
# passing the norm by name only (no extra arguments). Construction failures are
# printed instead of raised so that the loop visits every registered norm.
norm_factories = monai.networks.layers.factories.Norm.factories
print(norm_factories.keys())
for norm_layer in norm_factories:
    unet_kwargs = dict(
        spatial_dims=3,
        in_channels=3,
        out_channels=3,
        features=(32, 32, 64, 128, 256, 32),
        norm=norm_layer,
    )
    try:
        model = BasicUnetPlusPlus(**unet_kwargs)
    except Exception as msg:
        print(f"Exception layer: {msg}")
    

dict_keys(['INSTANCE', 'BATCH', 'GROUP', 'LAYER', 'LOCALRESPONSE', 'SYNCBATCH', 'INSTANCE_NVFUSER'])
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
Exception layer: __init__() missing 1 required positional argument: 'num_groups'
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
Exception layer: __init__() missing 1 required positional argument: 'normalized_shape'
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
Exception layer: __init__() missing 1 required positional argument: 'size'
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).




# Normalization
UNet++ uses the same `TwoConv`, `Down`, and `UpCat` blocks as UNet. Therefore, you can refer to `modules/UNet_input_size_constrains.ipynb` for a detailed breakdown. In summary, the constraints for these types of normalization are:

- Instance Norm: the product of the spatial dimensions must be > 1 (excluding the channel and batch dimensions).
- Batch Norm: the product of the spatial dimensions must be > 1 (excluding the channel and batch dimensions). For best training results, `batch_size` should be larger than 1.
- Local Response Norm: no constraint.
- Other normalization layers: please refer to `modules/UNet_input_size_constrains.ipynb`.

Because UNet++ has 4 down-sampling blocks, each with a 2x pooling kernel, and offers no argument to change this behavior, the smallest edge length that survives all down-sampling steps is `2**4 = 16`. If every edge is 16, the feature map after the last down-sampling block has `shape = [..., ..., 1, 1]` (or `[..., ..., 1, 1, 1]` for 3D), which causes an error in the normalization layer.

See the test code below for examples of batch norm and instance norm

In [80]:
from typing import Dict


def make_model_with_layer(layer_norm):
    """Build a 2D ``BasicUnetPlusPlus`` that uses ``layer_norm`` for normalization.

    Args:
        layer_norm: value forwarded unchanged as the ``norm`` argument of
            ``BasicUnetPlusPlus``.

    Returns:
        A ``BasicUnetPlusPlus`` with ``spatial_dims=2``, 3 input channels,
        1 output channel and features ``(32, 32, 64, 128, 256, 32)``.
    """
    model_kwargs = dict(
        spatial_dims=2,
        in_channels=3,
        out_channels=1,
        features=(32, 32, 64, 128, 256, 32),
        norm=layer_norm,
    )
    return BasicUnetPlusPlus(**model_kwargs)

def test_min_dim():
    """Probe 2D input shapes around the minimum edge length for each norm type.

    One model is built per norm name ('instance', 'batch'). For each model a
    forward pass is run while varying the input height (16, then 32) and then
    the batch size (1, then 2). Exceptions raised by the forward pass are
    printed instead of propagated, so every configuration is attempted.
    """
    min_edge = 16
    # The second tensor axis is the channel axis; 3 matches `in_channels`
    # used by `make_model_with_layer`.
    base_batch, channels, height, width = 1, 3, min_edge, min_edge
    norm_names = ['instance', 'batch']

    print("Prepare model")
    models_by_norm = {name: make_model_with_layer(name) for name in norm_names}

    for name in norm_names:
        print("="*10 + f" USING NORM LAYER: {name.upper()} " + "="*10)
        net = models_by_norm[name]

        print("_" * 10 + " Changing the H dimension of 2D input " + "_" * 10)
        for trial_height in (height, height * 2):
            try:
                sample = torch.ones(base_batch, channels, trial_height, width)
                print(f">> Using Input.shape={sample.shape}")
                net(sample)
            except Exception as msg:
                print(f">>>> Exception: {msg}\n")

        print("_" * 10 + " Changing the batch size " + "_" * 10)
        for trial_batch in (1, 2):
            try:
                sample = torch.ones(trial_batch, channels, height, width)
                print(f">> Input.shape={sample.shape}")
                net(sample)
            except Exception as msg:
                print(f">> Exception: {msg}\n")

# test_min_dim only runs forward passes, so gradient tracking is switched off.
with torch.no_grad():
    test_min_dim()

Prepare model
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
__________ Changing the H dimension of 2D input __________
>> Using Input.shape=torch.Size([1, 3, 16, 16])
>>>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])

>> Using Input.shape=torch.Size([1, 3, 32, 16])
__________ Changing the batch size __________
>> Input.shape=torch.Size([1, 3, 16, 16])
>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])

>> Input.shape=torch.Size([2, 3, 16, 16])
>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([2, 256, 1, 1])

__________ Changing the H dimension of 2D input __________
>> Using Input.shape=torch.Size([1, 3, 16, 16])
>>>> Exception: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])

>> Using Input.shape=torch.Size([1, 3, 3

# Pooling (down-sampling) and Up-sampling

While all the internal padding is handled to ensure the input spatial shape is the same as the output spatial shape, you still need to be aware of the following implicit processing:

- Down-sampling path: MaxPooling with `kernel_size` and `stride` equal to 2. For any odd size, the last element is ignored. See the code block below for pooling with an odd shape.
- Up-sampling path: depending on the up-sampling mode, the `UpCat` module automatically pads (in `replicate` mode) the up-sampled features so that their spatial shape matches that of the encoder features (`x_e`). See the `UpCat` code below:

```python
def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
        """

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.
        """
        x_0 = self.upsample(x)

        if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor):
            if self.is_pad: # always True
                # handling spatial shapes due to the 2x max-pooling with odd edge lengths.
                dimensions = len(x.shape) - 2
                sp = [0] * (dimensions * 2)
                for i in range(dimensions):
                    if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
                        sp[i * 2 + 1] = 1
                x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
            x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
        else:
            x = self.convs(x_0)

        return x

```

In [86]:
# Pooling with an odd shape: a 5x5 grid whose entry (i, j) equals (i + 1) * (j + 1).
edge_values = torch.arange(1, 6, dtype=torch.float32)
un_pool = torch.outer(edge_values, edge_values)
# Add leading batch and channel axes -> shape (1, 1, 5, 5).
un_pool = un_pool[None, None]

# 2x2 max pooling: the fifth row and fifth column do not fill a complete
# window, so they are dropped from the result.
pool_2x2 = nn.MaxPool2d(kernel_size=2)
pooled = pool_2x2(un_pool)
print(f"{un_pool.shape=}, {pooled.shape=}")
print(un_pool)
print(pooled)

un_pool.shape=torch.Size([1, 1, 5, 5]), pooled.shape=torch.Size([1, 1, 2, 2])
tensor([[[[ 1.,  2.,  3.,  4.,  5.],
          [ 2.,  4.,  6.,  8., 10.],
          [ 3.,  6.,  9., 12., 15.],
          [ 4.,  8., 12., 16., 20.],
          [ 5., 10., 15., 20., 25.]]]])
tensor([[[[ 4.,  8.],
          [ 8., 16.]]]])


In [85]:
# # @torch.no_grad()
# def test():
#     model = BasicUnetPlusPlus(
#         spatial_dims=3,
#         in_channels=3,
#         out_channels=3,
#         features=(32, 32, 64, 128, 256, 32),
#         deep_supervision=True,
#     )
#     batch_size, spatial_dim, D, H, W = 2, 3, 16 * 2 + 1, 16 + 1, 16 + 1
#     x = torch.ones(
#         [batch_size, spatial_dim, D, H, W]
#     )
#     # y = model.forward(x)
#     # print(f"{x.shape=}")  
#     # print(f"{[o.shape for o in y]=}")
    
#     return model, x

# # Edge information is lost when an input dimension is not divisible by 2 during pooling.
# # Ensure the smallest image dimension is large enough:
# # the smallest feature map before the norm layer is [1x1x1],
# # which instance norm and batch norm do not allow, so make it at least
# # [2x1x1].
# from monai.networks.blocks import ADN
# with torch.no_grad():
#     model, x = test()
#     print(f"Input {x.shape=}")
#     model.eval()
#     # 2 conv
#     x_0_0 = model.conv_0_0(x)     
#     # down conv
#     x_1_0 = model.conv_1_0(x_0_0) 
#     print(f"{x_1_0.shape=}")
#     x_0_1 = model.upcat_0_1(x_1_0, x_0_0)
#     print(f"{x_0_1.shape=}")
#     # x_2_0 = model.conv_2_0(x_1_0) 
#     # print(f"{x_2_0.shape=}")
#     # x_3_0 = model.conv_3_0(x_2_0) 
#     # print(f"{x_3_0.shape=}") # 2 x 2 -> 1 x 1
#     # pooled = nn.MaxPool3d(kernel_size=2)(x_3_0)
#     # print(f"{pooled.shape=}") # 2 x 2 -> 1 x 1
#     # conved1 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)(pooled)
#     # conved2 = nn.Conv3d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)(conved1)
#     # print(f"{conved1.shape=}") # 2 x 2 -> 1 x 1
#     # print(f"{conved2.shape=}") # 2 x 2 -> 1 x 1
    
#     # normed = ADN(
#     #     ordering='NDA',
#     #     in_channels=256,
#     #     norm='batch',
#     #     norm_dim=3,
#     #     # dropout=dropout,
#     #     # dropout_dim=dropout_dim,
#     # )(conved2)
#     # print(f"{normed.shape=}") # 2 x 2 -> 1 x 1

#     # x_4_0 = model.conv_4_0(x_3_0) 
#     # print(f"{x_4_0.shape=}")

#     # Up path
#     # x_3_0 = model.conv_3_0(x_2_0)
#     # print(f"{x_1_0.shape=}")

BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).
Input x.shape=torch.Size([2, 3, 33, 17, 17])
x_1_0.shape=torch.Size([2, 32, 16, 8, 8])
x_0_1.shape=torch.Size([2, 32, 33, 17, 17])


37

un_pool.shape=torch.Size([1, 1, 5, 5]), pooled.shape=torch.Size([1, 1, 2, 2])
tensor([[[[ 1.,  2.,  3.,  4.,  5.],
          [ 2.,  4.,  6.,  8., 10.],
          [ 3.,  6.,  9., 12., 15.],
          [ 4.,  8., 12., 16., 20.],
          [ 5., 10., 15., 20., 25.]]]])
tensor([[[[ 4.,  8.],
          [ 8., 16.]]]])
