Merged
20 changes: 15 additions & 5 deletions monai/networks/blocks/dynunet_block.py
@@ -33,7 +33,8 @@ class UnetResBlock(nn.Module):
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
dropout: dropout probability
act_name: activation layer type and arguments.
dropout: dropout probability.

"""

@@ -45,6 +46,7 @@ def __init__(
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
dropout: Optional[Union[Tuple, str, float]] = None,
):
super().__init__()
@@ -63,7 +65,7 @@ def __init__(
self.conv3 = get_conv_layer(
spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
)
self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
self.lrelu = get_act_layer(name=act_name)
self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
@@ -100,7 +102,8 @@ class UnetBasicBlock(nn.Module):
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
dropout: dropout probability
act_name: activation layer type and arguments.
dropout: dropout probability.

"""

@@ -112,6 +115,7 @@ def __init__(
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
dropout: Optional[Union[Tuple, str, float]] = None,
):
super().__init__()
@@ -127,7 +131,7 @@ def __init__(
self.conv2 = get_conv_layer(
spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
)
self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
self.lrelu = get_act_layer(name=act_name)
self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

@@ -155,7 +159,9 @@ class UnetUpBlock(nn.Module):
stride: convolution stride.
upsample_kernel_size: convolution kernel size for transposed convolution layers.
norm_name: feature normalization type and arguments.
dropout: dropout probability
act_name: activation layer type and arguments.
dropout: dropout probability.
trans_bias: transposed convolution bias.

"""

@@ -168,7 +174,9 @@ def __init__(
stride: Union[Sequence[int], int],
upsample_kernel_size: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
dropout: Optional[Union[Tuple, str, float]] = None,
trans_bias: bool = False,
):
super().__init__()
upsample_stride = upsample_kernel_size
@@ -179,6 +187,7 @@ def __init__(
kernel_size=upsample_kernel_size,
stride=upsample_stride,
dropout=dropout,
bias=trans_bias,
conv_only=True,
is_transposed=True,
)
@@ -190,6 +199,7 @@ def __init__(
stride=1,
dropout=dropout,
norm_name=norm_name,
act_name=act_name,
)

def forward(self, inp, skip):
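To make the new block-level options concrete, below is a minimal usage sketch (not part of the PR): it builds a UnetResBlock with a non-default activation and a UnetUpBlock with trans_bias enabled. The channel counts, input shape, and the PReLU choice are illustrative assumptions.

```python
import torch
from monai.networks.blocks.dynunet_block import UnetResBlock, UnetUpBlock

# Residual block with a custom activation instead of the default LeakyReLU.
block = UnetResBlock(
    spatial_dims=3,
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    stride=2,
    norm_name=("INSTANCE", {"affine": True}),
    act_name=("prelu", {"init": 0.2}),
)

# Up-sampling block with a bias term on the transposed convolution.
up = UnetUpBlock(
    spatial_dims=3,
    in_channels=32,
    out_channels=16,
    kernel_size=3,
    stride=2,
    upsample_kernel_size=2,
    norm_name=("INSTANCE", {"affine": True}),
    act_name=("prelu", {"init": 0.2}),
    trans_bias=True,
)

x = torch.randn(1, 16, 32, 32, 32)
feat = block(x)    # downsampled features: (1, 32, 16, 16, 16)
out = up(feat, x)  # skip connection at input resolution: (1, 16, 32, 32, 32)
```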
51 changes: 42 additions & 9 deletions monai/networks/nets/dynunet.py
@@ -57,6 +57,7 @@ class DynUNet(nn.Module):
This reimplementation of a dynamic UNet (DynUNet) is based on:
`Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
`nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.
`Optimized U-Net for Brain Tumor Segmentation <https://arxiv.org/pdf/2110.03352.pdf>`_.

This model is more flexible compared with ``monai.networks.nets.UNet`` in three
places:
@@ -89,8 +90,15 @@ class DynUNet(nn.Module):
strides: convolution strides for each block.
upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should
equal strides[1:].
filters: number of output channels for each block. Unlike nnU-Net, this implementation adds
the argument to make the network more flexible. As shown in the third reference, one way to
determine this argument is:
``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``.
This scheme is used in the network that won task 1 of the BraTS21 challenge.
If not specified, the nnU-Net default is employed. Defaults to ``None``.
dropout: dropout ratio. Defaults to no dropout.
norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
If ``True``, in training mode, the forward function will output not only the last feature
map, but also the previous feature maps that come from the intermediate up sample layers.
@@ -109,6 +117,7 @@ class DynUNet(nn.Module):
Defaults to 1.
res_block: whether to use residual connection based convolution blocks throughout the network.
Defaults to ``False``.
trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``.
"""

def __init__(
@@ -119,11 +128,14 @@ def __init__(
kernel_size: Sequence[Union[Sequence[int], int]],
strides: Sequence[Union[Sequence[int], int]],
upsample_kernel_size: Sequence[Union[Sequence[int], int]],
filters: Optional[Sequence[int]] = None,
dropout: Optional[Union[Tuple, str, float]] = None,
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
deep_supervision: bool = False,
deep_supr_num: int = 1,
res_block: bool = False,
trans_bias: bool = False,
):
super().__init__()
self.spatial_dims = spatial_dims
@@ -133,9 +145,15 @@ def __init__(
self.strides = strides
self.upsample_kernel_size = upsample_kernel_size
self.norm_name = norm_name
self.act_name = act_name
self.dropout = dropout
self.conv_block = UnetResBlock if res_block else UnetBasicBlock
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
self.trans_bias = trans_bias
if filters is not None:
self.filters = filters
self.check_filters()
else:
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
self.input_block = self.get_input_block()
self.downsamples = self.get_downsamples()
self.bottleneck = self.get_bottleneck()
@@ -161,9 +179,9 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
"""

if len(downsamples) != len(upsamples):
raise AssertionError(f"{len(downsamples)} != {len(upsamples)}")
raise ValueError(f"{len(downsamples)} != {len(upsamples)}")
if (len(downsamples) - len(superheads)) not in (1, 0):
raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}")
raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}")

if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
return bottleneck
@@ -191,26 +209,33 @@ def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
error_msg = "length of kernel_size and strides should be the same, and no less than 3."
if len(kernels) != len(strides) or len(kernels) < 3:
raise AssertionError(error_msg)
raise ValueError(error_msg)

for idx, k_i in enumerate(kernels):
kernel, stride = k_i, strides[idx]
if not isinstance(kernel, int):
error_msg = f"length of kernel_size in block {idx} should be the same as spatial_dims."
if len(kernel) != self.spatial_dims:
raise AssertionError(error_msg)
raise ValueError(error_msg)
if not isinstance(stride, int):
error_msg = f"length of stride in block {idx} should be the same as spatial_dims."
if len(stride) != self.spatial_dims:
raise AssertionError(error_msg)
raise ValueError(error_msg)

def check_deep_supr_num(self):
deep_supr_num, strides = self.deep_supr_num, self.strides
num_up_layers = len(strides) - 1
if deep_supr_num >= num_up_layers:
raise AssertionError("deep_supr_num should be less than the number of up sample layers.")
raise ValueError("deep_supr_num should be less than the number of up sample layers.")
if deep_supr_num < 1:
raise AssertionError("deep_supr_num should be larger than 0.")
raise ValueError("deep_supr_num should be larger than 0.")

def check_filters(self):
filters = self.filters
if len(filters) < len(self.strides):
raise ValueError("length of filters should be no less than the length of strides.")
else:
self.filters = filters[: len(self.strides)]

def forward(self, x):
out = self.skip_layers(x)
@@ -231,6 +256,7 @@ def get_input_block(self):
self.kernel_size[0],
self.strides[0],
self.norm_name,
self.act_name,
dropout=self.dropout,
)

@@ -242,6 +268,7 @@ def get_bottleneck(self):
self.kernel_size[-1],
self.strides[-1],
self.norm_name,
self.act_name,
dropout=self.dropout,
)

@@ -257,7 +284,9 @@ def get_upsamples(self):
inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
upsample_kernel_size = self.upsample_kernel_size[::-1]
return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size)
return self.get_module_list(
inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias
)

def get_module_list(
self,
@@ -267,6 +296,7 @@ def get_module_list(
strides: Sequence[Union[Sequence[int], int]],
conv_block: nn.Module,
upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None,
trans_bias: bool = False,
):
layers = []
if upsample_kernel_size is not None:
@@ -280,8 +310,10 @@ def get_module_list(
"kernel_size": kernel,
"stride": stride,
"norm_name": self.norm_name,
"act_name": self.act_name,
"dropout": self.dropout,
"upsample_kernel_size": up_kernel,
"trans_bias": trans_bias,
}
layer = conv_block(**params)
layers.append(layer)
Expand All @@ -294,6 +326,7 @@ def get_module_list(
"kernel_size": kernel,
"stride": stride,
"norm_name": self.norm_name,
"act_name": self.act_name,
"dropout": self.dropout,
}
layer = conv_block(**params)
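Putting the network-level changes together, here is a hedged end-to-end sketch (not from the PR itself): the input shape, channel counts, and filter widths are illustrative, with the filters list following the BraTS21 recipe quoted in the docstring.

```python
import torch
from monai.networks.nets import DynUNet

strides = [1, 2, 2, 2]  # the first stride belongs to the input block
net = DynUNet(
    spatial_dims=3,
    in_channels=4,
    out_channels=3,
    kernel_size=[3, 3, 3, 3],
    strides=strides,
    upsample_kernel_size=strides[1:],
    filters=[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)],
    norm_name=("INSTANCE", {"affine": True}),
    act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
    res_block=True,
    trans_bias=True,
)

x = torch.randn(1, 4, 128, 128, 128)
print(net(x).shape)  # torch.Size([1, 3, 128, 128, 128])
```

With strides of length 4, the slice ``[: len(strides)]`` keeps the first four widths, [64, 96, 128, 192]; a filters list shorter than strides would be rejected by check_filters with a ValueError.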