Skip to content

Commit

Permalink
enhance upsampling layer
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Oct 19, 2020
1 parent d94c4d3 commit 4eef561
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 85 deletions.
4 changes: 2 additions & 2 deletions docs/source/networks.rst
Expand Up @@ -40,8 +40,8 @@ Blocks
.. autoclass:: MCFCN
:members:

`No New Unet Block`
~~~~~~~~~~~~~~~~~~~
`Dynamic-Unet Block`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UnetResBlock
:members:
.. autoclass:: UnetBasicBlock
Expand Down
6 changes: 3 additions & 3 deletions monai/networks/blocks/fcn.py
Expand Up @@ -108,7 +108,8 @@ class FCN(nn.Module):
Using the second mode cannot guarantee the model's reproducibility. Defaults to ``bilinear``.
- ``transpose``, uses transposed convolution layers.
- ``bilinear``, uses bilinear interpolate.
- ``bilinear``, uses bilinear interpolation.
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr.
"""
Expand Down Expand Up @@ -157,9 +158,8 @@ def __init__(
self.up_conv = UpSample(
dimensions=2,
in_channels=self.out_channels,
out_channels=self.out_channels,
scale_factor=2,
with_conv=True,
mode="deconv",
)

def forward(self, x: torch.Tensor):
Expand Down
28 changes: 15 additions & 13 deletions monai/networks/blocks/segresnet_block.py
Expand Up @@ -9,11 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.factories import Act, Norm
from monai.utils import InterpolateMode, UpsampleMode


def get_norm_layer(spatial_dims: int, in_channels: int, norm_name: str, num_groups: int = 8):
Expand Down Expand Up @@ -46,19 +49,18 @@ def get_conv_layer(
)


def get_upsample_layer(spatial_dims: int, in_channels: int, upsample_mode: str = "trilinear", scale_factor: int = 2):
up_module: nn.Module
if upsample_mode == "transpose":
up_module = UpSample(
spatial_dims,
in_channels,
scale_factor=scale_factor,
with_conv=True,
)
else:
upsample_mode = "bilinear" if spatial_dims == 2 else "trilinear"
up_module = nn.Upsample(scale_factor=scale_factor, mode=upsample_mode, align_corners=False)
return up_module
def get_upsample_layer(
spatial_dims: int, in_channels: int, upsample_mode: Union[UpsampleMode, str] = "nontrainable", scale_factor: int = 2
):
return UpSample(
dimensions=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
scale_factor=scale_factor,
mode=upsample_mode,
interp_mode=InterpolateMode.LINEAR,
align_corners=False,
)


class ResBlock(nn.Module):
Expand Down
139 changes: 101 additions & 38 deletions monai/networks/blocks/upsample.py
Expand Up @@ -16,62 +16,110 @@

from monai.networks.layers.factories import Conv, Pad, Pool
from monai.networks.utils import icnr_init, pixelshuffle
from monai.utils import UpsampleMode, ensure_tuple_rep
from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep


class UpSample(nn.Module):
class UpSample(nn.Sequential):
"""
Upsample with either kernel 1 conv + interpolation or transposed conv.
Upsamples data by `scale_factor`.
Supported modes are:
- "deconv": uses a transposed convolution.
- "nontrainable": uses :py:class:`torch.nn.Upsample`.
- "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
This module can optionally take a pre-convolution
(often used to map the number of features from `in_channels` to `out_channels`).
"""

def __init__(
self,
dimensions: int,
in_channels: int,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
scale_factor: Union[Sequence[float], float] = 2,
with_conv: bool = False,
mode: Union[UpsampleMode, str] = UpsampleMode.LINEAR,
mode: Union[UpsampleMode, str] = UpsampleMode.DECONV,
pre_conv: Optional[Union[nn.Module, str]] = "default",
interp_mode: Union[InterpolateMode, str] = InterpolateMode.LINEAR,
align_corners: Optional[bool] = True,
bias: bool = True,
apply_pad_pool: bool = True,
) -> None:
"""
Args:
dimensions: number of spatial dimensions of the input image.
in_channels: number of channels of the input image.
out_channels: number of channels of the output image. Defaults to `in_channels`.
scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
with_conv: whether to use a transposed convolution for upsampling. Defaults to False.
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
pre_conv: a conv block applied before upsampling. Defaults to None.
When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when
Only used in the "nontrainable" or "pixelshuffle" mode.
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
The interpolation mode. Defaults to ``"linear"``.
See also: https://pytorch.org/docs/stable/nn.html#upsample
align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
Only used in the nontrainable mode.
bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
Only used in the pixelshuffle mode.
"""
super().__init__()
scale_factor_ = ensure_tuple_rep(scale_factor, dimensions)
if not out_channels:
out_channels = in_channels
if not with_conv:
mode = UpsampleMode(mode)
linear_mode = [UpsampleMode.LINEAR, UpsampleMode.BILINEAR, UpsampleMode.TRILINEAR]
if mode in linear_mode: # choose mode based on dimensions
mode = linear_mode[dimensions - 1]
self.upsample = nn.Sequential(
Conv[Conv.CONV, dimensions](in_channels=in_channels, out_channels=out_channels, kernel_size=1),
nn.Upsample(scale_factor=scale_factor_, mode=mode.value, align_corners=align_corners),
up_mode = UpsampleMode(mode)
if up_mode == UpsampleMode.DECONV:
if not in_channels:
raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
self.add_module(
"deconv",
Conv[Conv.CONVTRANS, dimensions](
in_channels=in_channels,
out_channels=out_channels or in_channels,
kernel_size=scale_factor_,
stride=scale_factor_,
bias=bias,
),
)
else:
self.upsample = Conv[Conv.CONVTRANS, dimensions](
in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor_, stride=scale_factor_
elif up_mode == UpsampleMode.NONTRAINABLE:
if pre_conv == "default" and (out_channels != in_channels): # defaults to no conv if out_chns==in_chns
if not in_channels:
raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
self.add_module(
"preconv",
Conv[Conv.CONV, dimensions](
in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
),
)
elif pre_conv is not None and pre_conv != "default":
self.add_module("preconv", pre_conv) # type: ignore

interp_mode = InterpolateMode(interp_mode)
linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
if interp_mode in linear_mode: # choose mode based on dimensions
interp_mode = linear_mode[dimensions - 1]
self.add_module(
"upsample_non_trainable",
nn.Upsample(scale_factor=scale_factor_, mode=interp_mode.value, align_corners=align_corners),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).
"""
return torch.as_tensor(self.upsample(x))
elif up_mode == UpsampleMode.PIXELSHUFFLE:
self.add_module(
"pixelshuffle",
SubpixelUpsample(
dimensions=dimensions,
in_channels=in_channels,
out_channels=out_channels,
scale_factor=scale_factor_[0], # isotropic
conv_block=pre_conv,
apply_pad_pool=apply_pad_pool,
bias=bias,
),
)
else:
raise NotImplementedError(f"Unsupported upsampling mode {mode}.")


class SubpixelUpsample(nn.Module):
Expand Down Expand Up @@ -102,21 +150,29 @@ class SubpixelUpsample(nn.Module):
def __init__(
self,
dimensions: int,
in_channels: int,
in_channels: Optional[int],
out_channels: Optional[int] = None,
scale_factor: int = 2,
conv_block: Optional[nn.Module] = None,
conv_block: Optional[Union[nn.Module, str]] = "default",
apply_pad_pool: bool = True,
bias: bool = True,
) -> None:
"""
Args:
dimensions: number of spatial dimensions of the input image.
in_channels: number of channels of the input image.
out_channels: optional number of channels of the output image.
scale_factor: multiplier for spatial size. Defaults to 2.
conv_block: a conv block to extract feature maps before upsampling. Defaults to None.
When ``conv_block is None``, one reserved conv layer will be utilized.
- When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
- When ``conv_block`` is an ``nn.module``,
please ensure the output number of channels is divisible ``(scale_factor ** dimensions)``.
apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution
component of subpixel convolutions described in Aitken et al.
bias: whether to have a bias term in the default conv_block. Defaults to True.
"""
super().__init__()

Expand All @@ -126,17 +182,18 @@ def __init__(
self.dimensions = dimensions
self.scale_factor = scale_factor

if conv_block is None:
conv_out_channels = in_channels * (scale_factor ** dimensions)
if conv_block == "default":
out_channels = out_channels or in_channels
if not out_channels:
raise ValueError("in_channels need to be specified.")
conv_out_channels = out_channels * (scale_factor ** dimensions)
self.conv_block = Conv[Conv.CONV, dimensions](
in_channels=in_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=1,
in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias
)

icnr_init(self.conv_block, self.scale_factor)
elif conv_block is None:
self.conv_block = nn.Identity()
else:
self.conv_block = conv_block

Expand All @@ -157,6 +214,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).
"""
x = self.conv_block(x)
if x.shape[1] % (self.scale_factor ** self.dimensions) != 0:
raise ValueError(
f"Number of channels after `conv_block` ({x.shape[1]}) must be evenly "
"divisible by scale_factor ** dimensions "
f"({self.scale_factor}^{self.dimensions}={self.scale_factor**self.dimensions})."
)
x = pixelshuffle(x, self.dimensions, self.scale_factor)
x = self.pad_pool(x)
return x
28 changes: 15 additions & 13 deletions monai/networks/nets/segresnet.py
Expand Up @@ -9,15 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence
from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.segresnet_block import *
from monai.networks.layers.factories import Act, Dropout
from monai.utils import UpsampleMode


class SegResNet(nn.Module):
Expand All @@ -39,13 +39,14 @@ class SegResNet(nn.Module):
use_conv_final: if add a final convolution block to output. Defaults to ``True``.
blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
upsample_mode: [``"transpose"``, ``"bilinear"``, ``"trilinear"``]
upsample_mode: [``"transpose"``, ``"nontrainable"``, ``"pixelshuffle"``]
The mode of upsampling manipulations.
Using the last two modes cannot guarantee the model's reproducibility. Defaults to``trilinear``.
Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.
- ``transpose``, uses transposed convolution layers.
- ``bilinear``, uses bilinear interpolate.
- ``trilinear``, uses trilinear interpolate.
- ``nontrainable``, uses non-trainable `linear` interpolation.
- ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
"""

def __init__(
Expand All @@ -60,7 +61,7 @@ def __init__(
use_conv_final: bool = True,
blocks_down: tuple = (1, 2, 2, 4),
blocks_up: tuple = (1, 1, 1),
upsample_mode: str = "trilinear",
upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
):
super().__init__()

Expand All @@ -73,7 +74,7 @@ def __init__(
self.dropout_prob = dropout_prob
self.norm_name = norm_name
self.num_groups = num_groups
self.upsample_mode = upsample_mode
self.upsample_mode = UpsampleMode(upsample_mode)
self.use_conv_final = use_conv_final
self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
self.down_layers = self._make_down_layers()
Expand Down Expand Up @@ -187,13 +188,14 @@ class SegResNetVAE(SegResNet):
use_conv_final: if add a final convolution block to output. Defaults to ``True``.
blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
upsample_mode: [``"transpose"``, ``"bilinear"``, ``"trilinear"``]
upsample_mode: [``"transpose"``, ``"nontrainable"``, ``"pixelshuffle"``]
The mode of upsampling manipulations.
Using the last two modes cannot guarantee the model's reproducibility. Defaults to``trilinear``.
Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to `nontrainable`.
- ``transpose``, uses transposed convolution layers.
- ``bilinear``, uses bilinear interpolate.
- ``trilinear``, uses trilinear interpolate.
- ``nontrainable``, uses non-trainable `linear` interpolation.
- ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
use_vae: if use the variational autoencoder (VAE) during training. Defaults to ``False``.
input_image_size: the size of images to input into the network. It is used to
determine the in_features of the fc layer in VAE. When ``use_vae == True``, please
Expand All @@ -220,7 +222,7 @@ def __init__(
use_conv_final: bool = True,
blocks_down: tuple = (1, 2, 2, 4),
blocks_up: tuple = (1, 1, 1),
upsample_mode: str = "trilinear",
upsample_mode: Union[UpsampleMode, str] = "nontrainable",
):
super(SegResNetVAE, self).__init__(
spatial_dims=spatial_dims,
Expand Down
4 changes: 2 additions & 2 deletions monai/networks/utils.py
Expand Up @@ -218,8 +218,8 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T

if channels % scale_divisor != 0:
raise ValueError(
f"Number of input channels ({channels}) must be evenly \
divisible by scale_factor ** dimensions ({scale_divisor})."
f"Number of input channels ({channels}) must be evenly "
f"divisible by scale_factor ** dimensions ({factor}**{dim}={scale_divisor})."
)

org_channels = channels // scale_divisor
Expand Down

0 comments on commit 4eef561

Please sign in to comment.