From 574fe60f721b0918dd740dd19ac93dbff1967d72 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 5 Nov 2021 16:39:57 +0800 Subject: [PATCH 1/6] enhance dynunet based on brats21 1st solution Signed-off-by: Yiheng Wang --- monai/networks/blocks/dynunet_block.py | 9 ++-- monai/networks/nets/dynunet.py | 68 +++++++++++++++++++------- tests/test_dynunet.py | 3 +- tests/test_dynunet_block.py | 34 +++++++------ 4 files changed, 76 insertions(+), 38 deletions(-) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index 248bc5341c..aaf01cfdf3 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -33,7 +33,7 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. - dropout: dropout probability + dropout: dropout probability. """ @@ -100,7 +100,7 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. - dropout: dropout probability + dropout: dropout probability. """ @@ -155,7 +155,8 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. - dropout: dropout probability + dropout: dropout probability. + trans_bias: transposed convolution bias. """ @@ -169,6 +170,7 @@ def __init__( upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, ): super().__init__() upsample_stride = upsample_kernel_size @@ -179,6 +181,7 @@ def __init__( kernel_size=upsample_kernel_size, stride=upsample_stride, dropout=dropout, + bias=trans_bias, conv_only=True, is_transposed=True, ) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 3e846b9b7b..8627d2680a 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -33,11 +33,11 @@ class DynUNetSkipLayer(nn.Module): heads: List[torch.Tensor] - def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + def __init__(self, index, heads, downsample, upsample, next_layer, super_head=None): super().__init__() self.downsample = downsample - self.upsample = upsample self.next_layer = next_layer + self.upsample = upsample self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - - self.heads[self.index] = self.super_head(upout) + if self.super_head is not None and self.index > 0: + self.heads[self.index - 1] = self.super_head(upout) return upout @@ -57,6 +57,7 @@ class DynUNet(nn.Module): This reimplementation of a dynamic UNet (DynUNet) is based on: `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + `Optimized U-Net for Brain Tumor Segmentation `_. This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: @@ -89,6 +90,12 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + filters: number of output channels for each blocks. 
Different from nnU-Net, in this implementation we add + this argument to make the network more flexible. As shown in the third reference, one way to determine + this argument is like: + ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``. + The above way is used in the network that wins task 1 in the BraTS21 Challenge. + If not specified, the way which nnUNet used will be employed. Defaults to ``None``. dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. @@ -109,6 +116,7 @@ class DynUNet(nn.Module): Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. Defaults to ``False``. + trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``. """ def __init__( @@ -119,11 +127,13 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + filters: Optional[Sequence[int]] = None, dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, + trans_bias: bool = False, ): super().__init__() self.spatial_dims = spatial_dims @@ -135,21 +145,26 @@ def __init__( self.norm_name = norm_name self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] + self.trans_bias = trans_bias + if filters is not None: + self.filters = filters + self.check_filters() + else: + self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() self.downsamples = self.get_downsamples() self.bottleneck = self.get_bottleneck() self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num + self.deep_supervision_heads = self.get_deep_supervision_heads() self.apply(self.initialize_weights) self.check_kernel_stride() self.check_deep_supr_num() # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) + self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ @@ -162,22 +177,27 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): if len(downsamples) != len(upsamples): raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck + super_head_flag = False if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads + rest_heads = superheads elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] + rest_heads = 
nn.ModuleList() else: - current_head, rest_heads = superheads[0], superheads[1:] + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] + else: + rest_heads = nn.ModuleList() # create the next layer down, this will stop at the bottleneck layer next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) + if super_head_flag: + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer, superheads[0]) + else: + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer) self.skip_layers = create_skips( 0, @@ -212,13 +232,19 @@ def check_deep_supr_num(self): if deep_supr_num < 1: raise AssertionError("deep_supr_num should be larger than 0.") + def check_filters(self): + filters = self.filters + if len(filters) < len(self.strides): + raise AssertionError("length of filters should be no less than the length of strides.") + else: + self.filters = filters[: len(self.strides)] + def forward(self, x): out = self.skip_layers(x) out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - feature_maps = self.heads[1 : self.deep_supr_num + 1] - for feature_map in feature_maps: + for feature_map in self.heads: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -257,7 +283,9 @@ def get_upsamples(self): inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size) + return self.get_module_list( + inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias + ) def get_module_list( self, @@ -267,6 +295,7 @@ def get_module_list( strides: Sequence[Union[Sequence[int], int]], conv_block: nn.Module, upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, + trans_bias: bool = False, ): layers = [] if upsample_kernel_size is not None: @@ -282,6 +311,7 @@ def get_module_list( "norm_name": self.norm_name, "dropout": self.dropout, "upsample_kernel_size": up_kernel, + "trans_bias": trans_bias, } layer = conv_block(**params) layers.append(layer) @@ -301,7 +331,9 @@ def get_module_list( return nn.ModuleList(layers) def get_deep_supervision_heads(self): - return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + if not self.deep_supervision: + return nn.ModuleList() + return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) @staticmethod def initialize_weights(module): diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 18fe146a40..a3cc192d1b 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -66,8 +66,9 @@ "kernel_size": (3, (1, 1, 3), 3, 3), "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), + "filters": (64, 96, 128, 192), "norm_name": ("INSTANCE", {"affine": True}), - "deep_supervision": False, + "deep_supervision": True, "res_block": res_block, "dropout": ("alphadropout", {"p": 0.25}), }, diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 7e832f6d81..ec638b30a0 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -49,22 +49,24 @@ for stride in [1, 2]: for norm_name in ["batch", "instance"]: 
for in_size in [15, 16]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) + for trans_bias in [True, False]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "upsample_kernel_size": stride, + "trans_bias": trans_bias, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): From 5d20dba06a3993efcae2780a4a8f15c7a2042e59 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 5 Nov 2021 17:29:49 +0800 Subject: [PATCH 2/6] add act_name argument Signed-off-by: Yiheng Wang --- monai/networks/blocks/dynunet_block.py | 11 +++++++++-- monai/networks/nets/dynunet.py | 7 +++++++ tests/test_dynunet.py | 1 + tests/test_dynunet_block.py | 1 + 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index aaf01cfdf3..43d3c46cc9 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -33,6 +33,7 @@ class UnetResBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. dropout: dropout probability. """ @@ -45,6 +46,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), dropout: Optional[Union[Tuple, str, float]] = None, ): super().__init__() @@ -63,7 +65,7 @@ def __init__( self.conv3 = get_conv_layer( spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -100,6 +102,7 @@ class UnetBasicBlock(nn.Module): kernel_size: convolution kernel size. stride: convolution stride. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. dropout: dropout probability. 
""" @@ -112,6 +115,7 @@ def __init__( kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), dropout: Optional[Union[Tuple, str, float]] = None, ): super().__init__() @@ -127,7 +131,7 @@ def __init__( self.conv2 = get_conv_layer( spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True ) - self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01})) + self.lrelu = get_act_layer(name=act_name) self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) @@ -155,6 +159,7 @@ class UnetUpBlock(nn.Module): stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. dropout: dropout probability. trans_bias: transposed convolution bias. @@ -169,6 +174,7 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), dropout: Optional[Union[Tuple, str, float]] = None, trans_bias: bool = False, ): @@ -193,6 +199,7 @@ def __init__( stride=1, dropout=dropout, norm_name=norm_name, + act_name=act_name, ) def forward(self, inp, skip): diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 8627d2680a..a0e9c739cc 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -98,6 +98,7 @@ class DynUNet(nn.Module): If not specified, the way which nnUNet used will be employed. Defaults to ``None``. dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the last feature map, but also the previous feature maps that come from the intermediate up sample layers. 
@@ -130,6 +131,7 @@ def __init__( filters: Optional[Sequence[int]] = None, dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, @@ -143,6 +145,7 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.act_name = act_name self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock self.trans_bias = trans_bias @@ -257,6 +260,7 @@ def get_input_block(self): self.kernel_size[0], self.strides[0], self.norm_name, + self.act_name, dropout=self.dropout, ) @@ -268,6 +272,7 @@ def get_bottleneck(self): self.kernel_size[-1], self.strides[-1], self.norm_name, + self.act_name, dropout=self.dropout, ) @@ -309,6 +314,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, "dropout": self.dropout, "upsample_kernel_size": up_kernel, "trans_bias": trans_bias, @@ -324,6 +330,7 @@ def get_module_list( "kernel_size": kernel, "stride": stride, "norm_name": self.norm_name, + "act_name": self.act_name, "dropout": self.dropout, } layer = conv_block(**params) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index a3cc192d1b..ca19ea2b47 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -43,6 +43,7 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), "deep_supervision": False, "res_block": res_block, "dropout": None, diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index ec638b30a0..de3c018d78 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -35,6 +35,7 @@ "out_channels": 16, "kernel_size": kernel_size, "norm_name": norm_name, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), "stride": stride, }, (1, 16, *([in_size] * spatial_dims)), From 57d0c45ee99cd94caae8195f696e5aaa779a0188 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 5 Nov 2021 22:07:20 +0800 Subject: [PATCH 3/6] remove v1 and change error type Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 14 +-- monai/networks/nets/dynunet_v1.py | 147 ------------------------------ tests/test_dynunet_v1.py | 128 -------------------------- 3 files changed, 7 insertions(+), 282 deletions(-) delete mode 100644 monai/networks/nets/dynunet_v1.py delete mode 100644 tests/test_dynunet_v1.py diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index a0e9c739cc..cf34ae8044 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -179,7 +179,7 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") + raise ValueError(f"{len(downsamples)} != {len(upsamples)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck @@ -214,31 +214,31 @@ def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." 
if len(kernels) != len(strides) or len(kernels) < 3: - raise AssertionError(error_msg) + raise ValueError(error_msg) for idx, k_i in enumerate(kernels): kernel, stride = k_i, strides[idx] if not isinstance(kernel, int): error_msg = f"length of kernel_size in block {idx} should be the same as spatial_dims." if len(kernel) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) if not isinstance(stride, int): error_msg = f"length of stride in block {idx} should be the same as spatial_dims." if len(stride) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) def check_deep_supr_num(self): deep_supr_num, strides = self.deep_supr_num, self.strides num_up_layers = len(strides) - 1 if deep_supr_num >= num_up_layers: - raise AssertionError("deep_supr_num should be less than the number of up sample layers.") + raise ValueError("deep_supr_num should be less than the number of up sample layers.") if deep_supr_num < 1: - raise AssertionError("deep_supr_num should be larger than 0.") + raise ValueError("deep_supr_num should be larger than 0.") def check_filters(self): filters = self.filters if len(filters) < len(self.strides): - raise AssertionError("length of filters should be no less than the length of strides.") + raise ValueError("length of filters should be no less than the length of strides.") else: self.filters = filters[: len(self.strides)] diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py deleted file mode 100644 index 4c910157c9..0000000000 --- a/monai/networks/nets/dynunet_v1.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Sequence, Union - -import torch -import torch.nn as nn - -from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1 -from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer -from monai.utils import deprecated - -__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"] - - -@deprecated( - since="0.6.0", - removed="0.8.0", - msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.", -) -class DynUNetV1(DynUNet): - """ - This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - strides: convolution strides for each blocks. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - dropout: dropout ratio. Defaults to no dropout. - norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". - deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. 
- res_block: whether to use residual connection based convolution blocks during the network. - Defaults to ``False``. - - .. deprecated:: 0.6.0 - Use :class:`monai.networks.nets.DynUNet` instead. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Sequence[Union[Sequence[int], int]], - strides: Sequence[Union[Sequence[int], int]], - upsample_kernel_size: Sequence[Union[Sequence[int], int]], - dropout: float = 0.0, - norm_name: str = "instance", - deep_supervision: bool = False, - deep_supr_num: int = 1, - res_block: bool = False, - ): - nn.Module.__init__(self) - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.strides = strides - self.upsample_kernel_size = upsample_kernel_size - self.norm_name = norm_name - self.dropout = dropout - self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] - self.input_block = self.get_input_block() - self.downsamples = self.get_downsamples() - self.bottleneck = self.get_bottleneck() - self.upsamples = self.get_upsamples() - self.output_block = self.get_output_block(0) - self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() - self.deep_supr_num = deep_supr_num - self.apply(self.initialize_weights) - self.check_kernel_stride() - self.check_deep_supr_num() - - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - - def create_skips(index, downsamples, upsamples, superheads, bottleneck): - """ - Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is - done recursively from the top down since a recursive nn.Module subclass is being used to be compatible - with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` - since the `input_block` is passed to this function as the first item in `downsamples`, however this - shouldn't be associated with a supervision head. 
- """ - - if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") - - if len(downsamples) == 0: # bottom of the network, pass the bottleneck block - return bottleneck - if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] - else: - current_head, rest_heads = superheads[0], superheads[1:] - - # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) - - def get_upsamples(self): - inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] - strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] - upsample_kernel_size = self.upsample_kernel_size[::-1] - return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size) - - @staticmethod - def initialize_weights(module): - name = module.__class__.__name__.lower() - if "conv3d" in name or "conv2d" in name: - nn.init.kaiming_normal_(module.weight, a=0.01) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif "norm" in name: - nn.init.normal_(module.weight, 1.0, 0.02) - nn.init.zeros_(module.bias) - - -DynUnetV1 = DynunetV1 = DynUNetV1 diff --git a/tests/test_dynunet_v1.py b/tests/test_dynunet_v1.py deleted file mode 100644 index fc216c145b..0000000000 --- a/tests/test_dynunet_v1.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from typing import Any, Sequence, Union - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.dynunet_v1 import DynUNetV1 -from tests.utils import skip_if_quick, test_script_save - -device = "cuda" if torch.cuda.is_available() else "cpu" - -strides: Sequence[Union[Sequence[int], int]] -kernel_size: Sequence[Any] -expected_shape: Sequence[Any] - -TEST_CASE_DYNUNET_2D = [] -for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: - for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: - for in_channels in [2, 3]: - for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "batch", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_2D.append(test_case) - -TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides -for out_channels in [2, 3]: - for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) - test_case = [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": (3, (1, 1, 3), 3, 3), - "strides": ((1, 2, 1), 2, 2, 1), - "upsample_kernel_size": (2, 2, 1), - "norm_name": "instance", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_3D.append(test_case) - -TEST_CASE_DEEP_SUPERVISION = [] -for spatial_dims in [2, 3]: - for res_block in [True, False]: - for deep_supr_num in [1, 2]: - for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]: - scale = strides[0] - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 1, - "out_channels": 2, - "kernel_size": [3] * len(strides), - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "group", - "deep_supervision": True, - "deep_supr_num": deep_supr_num, - "res_block": res_block, - }, - (1, 1, *[in_size] * spatial_dims), - (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims), - ] - TEST_CASE_DEEP_SUPERVISION.append(test_case) - - -@skip_if_quick -class TestDynUNet(unittest.TestCase): - @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with eval_mode(net): - result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - def test_script(self): - input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] - net = DynUNetV1(**input_param) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) - - -class TestDynUNetDeepSupervision(unittest.TestCase): - @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNetV1(**input_param).to(device) - with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) - self.assertEqual(results.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() From cd38e9e2ce15b38473c376290a5d63fc538036c6 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 8 Nov 2021 23:43:06 +0800 
Subject: [PATCH 4/6] skip torchscript check before pt 1.6 for dynunet Signed-off-by: Yiheng Wang --- tests/test_dynunet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index ca19ea2b47..b288238cf6 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,7 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -111,6 +111,7 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + @SkipIfBeforePyTorchVersion((1, 6)) def test_script(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] net = DynUNet(**input_param) From 9a1cc09716cd0f1eafaad926b43f23f12df8d1fe Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Nov 2021 19:31:08 +0000 Subject: [PATCH 5/6] set seed Signed-off-by: Wenqi Li --- tests/test_network_consistency.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index 9698a40116..ccccd9e7f0 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -20,6 +20,7 @@ from parameterized.parameterized import parameterized import monai.networks.nets as nets +from monai.utils import set_determinism extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None) @@ -33,6 +34,12 @@ class TestNetworkConsistency(unittest.TestCase): + def setUp(self): + set_determinism(0) + + def tearDown(self): + set_determinism(None) + @skipIf( len(TESTS) == 0, "To run these tests, clone https://github.com/Project-MONAI/MONAI-extra-test-data and set MONAI_EXTRA_TEST_DATA", From 91e74e55cd952e235a9ae0af580023433ab2b63f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 9 Nov 2021 16:11:25 +0800 Subject: [PATCH 6/6] modify to keep deep supervision heads Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 38 ++++++++++++++-------------------- tests/test_dynunet.py | 3 +-- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index cf34ae8044..696c9d25dc 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -33,11 +33,11 @@ class DynUNetSkipLayer(nn.Module): heads: List[torch.Tensor] - def __init__(self, index, heads, downsample, upsample, next_layer, super_head=None): + def __init__(self, index, heads, downsample, upsample, super_head, next_layer): super().__init__() self.downsample = downsample - self.next_layer = next_layer self.upsample = upsample + self.next_layer = next_layer self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - if self.super_head is not None and self.index > 0: - self.heads[self.index - 1] = self.super_head(upout) + + self.heads[self.index] = self.super_head(upout) return upout @@ -160,14 +160,14 @@ def __init__( self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supr_num = deep_supr_num self.deep_supervision_heads = self.get_deep_supervision_heads() + self.deep_supr_num = deep_supr_num self.apply(self.initialize_weights) 
self.check_kernel_stride() self.check_deep_supr_num() # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ @@ -180,27 +180,22 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): if len(downsamples) != len(upsamples): raise ValueError(f"{len(downsamples)} != {len(upsamples)}") + if (len(downsamples) - len(superheads)) not in (1, 0): + raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck - super_head_flag = False if index == 0: # don't associate a supervision head with self.input_block - rest_heads = superheads + current_head, rest_heads = nn.Identity(), superheads elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - rest_heads = nn.ModuleList() + current_head, rest_heads = nn.Identity(), superheads[1:] else: - if len(superheads) > 0: - super_head_flag = True - rest_heads = superheads[1:] - else: - rest_heads = nn.ModuleList() + current_head, rest_heads = superheads[0], superheads[1:] # create the next layer down, this will stop at the bottleneck layer next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - if super_head_flag: - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer, superheads[0]) - else: - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer) + + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) self.skip_layers = create_skips( 0, @@ -247,7 +242,8 @@ def forward(self, x): out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - for feature_map in self.heads: + feature_maps = self.heads[1 : self.deep_supr_num + 1] + for feature_map in feature_maps: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -338,9 +334,7 @@ def get_module_list( return nn.ModuleList(layers) def get_deep_supervision_heads(self): - if not self.deep_supervision: - return nn.ModuleList() - return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) + return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) @staticmethod def initialize_weights(module): diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index b288238cf6..ca19ea2b47 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,7 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -111,7 +111,6 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @SkipIfBeforePyTorchVersion((1, 6)) def test_script(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] net = DynUNet(**input_param)
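
For reference, a minimal usage sketch of the `DynUNet` interface after this series is applied (the concrete values below are illustrative and not part of the patches). The `filters` list follows the BraTS21 1st-place recipe quoted in the docstring added in PATCH 1, and `act_name`, `trans_bias`, `deep_supervision`/`deep_supr_num` are the arguments touched above:

import torch
from monai.networks.nets import DynUNet

strides = [1, 2, 2, 2, 2]
net = DynUNet(
    spatial_dims=3,
    in_channels=4,    # illustrative: e.g. the four BraTS MRI modalities
    out_channels=3,   # illustrative: e.g. the three BraTS tumor sub-regions
    kernel_size=[3] * len(strides),
    strides=strides,
    upsample_kernel_size=strides[1:],
    # BraTS21 1st-place filter recipe from the docstring above
    filters=[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)],
    norm_name=("INSTANCE", {"affine": True}),
    act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
    deep_supervision=True,
    deep_supr_num=2,  # must be >= 1 and < the number of upsample layers (4 here)
    res_block=True,
    trans_bias=True,  # new in this series: bias for the transposed convolutions
)

x = torch.randn(1, 4, 64, 64, 64)

net.train()
y = net(x)  # deep supervision active: shape (1, 1 + deep_supr_num, 3, 64, 64, 64)

net.eval()
with torch.no_grad():
    y = net(x)  # single output head: shape (1, 3, 64, 64, 64)

In training mode with deep supervision enabled, the extra heads are interpolated to the main output's spatial size and stacked along dim 1, which is why the first output shape above carries the extra `1 + deep_supr_num` axis.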