Merged

Ae docs #3067

Changes from all commits
52 commits
8c35fb0
Adding documentation for AutoEncoder and VarAutoEncoder
ericspod Sep 24, 2021
d303886
Docs Update
ericspod Sep 30, 2021
4170f3b
cuCIM Transform (#2932)
bhashemian Sep 16, 2021
6ff564a
torch `SpatialCrop`, `SpatialCropd` (#2963)
rijobro Sep 16, 2021
32589cc
Update with cupy.ndarray (#2965)
bhashemian Sep 16, 2021
3aa24ea
Fix for Jupyter plotting (#2964)
ericspod Sep 16, 2021
c240f82
fixes tutorial issue 352 (#2968)
wyli Sep 17, 2021
efc52e9
2231 Enhance tensor transforms (#2966)
Nic-Ma Sep 17, 2021
e47f3b6
Torch `RandCropByPosNegLabel`, `RandCropByPosNegLabeld`, `RandCropByL…
rijobro Sep 19, 2021
061300a
enhance affinegrid to use torch backend (#2969)
wyli Sep 20, 2021
692be31
Add make the name of wsi reader lowercase (#2973)
bhashemian Sep 20, 2021
0d7c487
update multimodal doc and model names (Transchex) (#2979)
ahatamiz Sep 20, 2021
08d8b7a
2914 remove the deprecated API for v0.7 (#2981)
wyli Sep 20, 2021
9f0fa08
2985 - enhance nightly test (#2986)
wyli Sep 21, 2021
c662a7f
dimension check for pretrained model weights (#2984)
neuronflow Sep 21, 2021
fa4034c
Refactor unnecessary `else` / `elif` when `if` block has a `return` s…
deepsource-autofix[bot] Sep 22, 2021
5601868
3000 Support not copy in CacheDataset (#3001)
Nic-Ma Sep 22, 2021
89c5b16
3002 - compatibility with torch 1.9.1 (#3003)
wyli Sep 22, 2021
ced1cbe
Fix cucim dep compatibility (#3006)
wyli Sep 22, 2021
b9eac8a
update cucim dep (#3007)
wyli Sep 23, 2021
5b2a195
[DLMED] fix CI test (#3008)
Nic-Ma Sep 23, 2021
2938412
3009 Update highlights web page (#3013)
Nic-Ma Sep 23, 2021
579c61a
2914 release note, and what's new for v0.7 (#2992)
wyli Sep 23, 2021
5036654
delaying the removal (#3015)
wyli Sep 24, 2021
4b06838
update preview version tag (#3016)
wyli Sep 24, 2021
167cbd2
Enhance 0.7 README doc (#3017)
Nic-Ma Sep 24, 2021
624a3d6
3020 Enhance what's new for transfomer networks (#3019)
Nic-Ma Sep 24, 2021
defb2bc
Torch `GaussianSmooth`, `RandGaussianSmooth`, `GaussianSharpen`, `Ran…
rijobro Sep 24, 2021
655a16b
apply pyupgrade (#3026)
Borda Sep 25, 2021
087ed84
3018 enhance version string check (#3022)
wyli Sep 27, 2021
2748e07
3028 update ignite CI tests to 0.4.6 (#3029)
Nic-Ma Sep 27, 2021
613bed6
enhance DataStats to include dtype (#3043)
wyli Sep 28, 2021
f27cc4e
enhance error handling (#3042)
wyli Sep 28, 2021
27bb163
2792 - Torch GibbsNoise, RandGibbsNoise, KSpaceSpikeNoise, RandKSpace…
rijobro Sep 29, 2021
a4ca5fc
AdjustContrast, AdjustContrastd, RandAdjustContrast, RandAdjustContra…
rijobro Sep 29, 2021
6fb63ac
remove false positive tests (#3046)
wyli Sep 29, 2021
ef79653
Torch `Spacing`, `Spacingd` (#3045)
rijobro Sep 29, 2021
c048809
2975 torch fgbgtoindices (#3038)
wyli Sep 30, 2021
5b2f8b1
3051 Fix dtype issue in Spacing transform (#3052)
Nic-Ma Sep 30, 2021
af6ffa6
Documentation update
ericspod Oct 1, 2021
dc8de2a
Adding documentation
ericspod Oct 4, 2021
99b185a
Create transform images (#3039)
rijobro Sep 30, 2021
8279055
2975 Fix the perf issue of RandCropByPosNegLabel (#3050)
Nic-Ma Oct 1, 2021
8938881
[DLMED] fix broken link (#3059)
Nic-Ma Oct 2, 2021
0d0d6f8
remove redundant noqa (#3027)
Borda Oct 3, 2021
7bbfde7
[DLMED] enhance label classes (#3061)
Nic-Ma Oct 4, 2021
e88b170
[DLMED] enhance ScaleIntensity (#3062)
Nic-Ma Oct 4, 2021
c3d72ec
Extra transform examples (#3056)
rijobro Oct 4, 2021
e66cea0
update backend (#3065)
wyli Oct 4, 2021
61999c0
3063 Fix the complex Tensor issue in type conversion (#3064)
Nic-Ma Oct 4, 2021
56e93bb
Formatting
ericspod Oct 4, 2021
fb92e92
Merge branch 'dev' into ae_docs
ericspod Oct 4, 2021
89 changes: 79 additions & 10 deletions monai/networks/nets/autoencoder.py
@@ -23,7 +23,69 @@

class AutoEncoder(nn.Module):
"""
Base class for the architecture implementing :py:class:`monai.networks.nets.VarAutoEncoder`.
Simple definition of an autoencoder and base class for the architecture implementing
:py:class:`monai.networks.nets.VarAutoEncoder`. The network is composed of an encode sequence of blocks, followed
by an intermediary sequence of blocks, and finally a decode sequence of blocks. The encode and decode blocks are
default :py:class:`monai.networks.blocks.Convolution` instances with the encode blocks having the given stride
and the decode blocks having transpose convolutions with the same stride. If `num_res_units` is given, residual
blocks are used instead.

By default the intermediary sequence is empty, but if `inter_channels` is given to specify the output channels of
blocks then this will become a sequence of Convolution blocks, or of residual blocks if `num_inter_units` is
given. The optional parameter `inter_dilations` can be used to specify the dilation values of the convolutions in
these blocks, which allows a network to use dilated kernels in this middle section. Since the intermediary section
isn't meant to change the size of the output, the stride for all these kernels is 1.

Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
strides: sequence of convolution strides. The length of `strides` should equal `len(channels) - 1`.
kernel_size: convolution kernel size, the value(s) should be odd. If sequence,
its length should equal the number of spatial dimensions. Defaults to 3.
up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,
its length should equal the number of spatial dimensions. Defaults to 3.
num_res_units: number of residual units. Defaults to 0.
inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.
inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.
num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0.
act: activation type and arguments. Defaults to PReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
dropout: dropout ratio. Defaults to no dropout.
bias: whether to have a bias term in convolution blocks. Defaults to True.
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.

.. deprecated:: 0.6.0
``dimensions`` is deprecated, use ``spatial_dims`` instead.

Examples::

from monai.networks.nets import AutoEncoder

# 3 layers, each down/up sampling its input by a factor of 2, with no intermediate layer
net = AutoEncoder(
spatial_dims=2,
in_channels=1,
out_channels=1,
channels=(2, 4, 8),
strides=(2, 2, 2)
)

# 1 layer downsampling by 2, followed by a sequence of residual units with 2 convolutions defined by
# progressively increasing dilations, then a final upsample layer
net = AutoEncoder(
spatial_dims=2,
in_channels=1,
out_channels=1,
channels=(4,),
strides=(2,),
inter_channels=(8, 8, 8),
inter_dilations=(1, 2, 4),
num_inter_units=2
)

"""

@deprecated_arg(
@@ -48,13 +110,6 @@ def __init__(
bias: bool = True,
dimensions: Optional[int] = None,
) -> None:
"""
Initialize the AutoEncoder.

.. deprecated:: 0.6.0
``dimensions`` is deprecated, use ``spatial_dims`` instead.

"""

super().__init__()
self.dimensions = spatial_dims if dimensions is None else dimensions
@@ -87,6 +142,9 @@ def __init__(
def _get_encode_module(
self, in_channels: int, channels: Sequence[int], strides: Sequence[int]
) -> Tuple[nn.Sequential, int]:
"""
Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`.
"""
encode = nn.Sequential()
layer_channels = in_channels

@@ -98,6 +156,10 @@ def _get_encode_module(
return encode, layer_channels

def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tuple[nn.Module, int]:
"""
Returns the intermediate block of the network which accepts input from the encoder and whose output goes
to the decoder.
"""
# Define some types
intermediate: nn.Module
unit: nn.Module
@@ -145,6 +207,9 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu
def _get_decode_module(
self, in_channels: int, channels: Sequence[int], strides: Sequence[int]
) -> Tuple[nn.Sequential, int]:
"""
Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`.
"""
decode = nn.Sequential()
layer_channels = in_channels

@@ -156,7 +221,9 @@ def _get_decode_module(
return decode, layer_channels

def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:

"""
Returns a single layer of the encoder part of the network.
"""
mod: nn.Module
if self.num_res_units > 0:
mod = ResidualUnit(
@@ -187,7 +254,9 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
return mod

def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:

"""
Returns a single layer of the decoder part of the network.
"""
decode = nn.Sequential()

conv = Convolution(
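The down/up-sampling scheme the new AutoEncoder docstring describes can be sanity-checked with a short sketch. This is an illustrative stand-in built from plain `torch.nn` layers, not MONAI's `Convolution`/`ResidualUnit` blocks; the layer names and the padding/output_padding choices here are assumptions, not the library's defaults:

```python
import torch
import torch.nn as nn

# Minimal sketch of the encode/decode pattern described in the docstring:
# each encode block halves the spatial size with a strided convolution, and
# each decode block doubles it with a transposed convolution of the same
# stride. Mirrors channels=(2, 4, 8), strides=(2, 2, 2) from the example.
channels = (2, 4, 8)

encode = nn.Sequential()
prev = 1  # in_channels
for i, c in enumerate(channels):
    encode.add_module(f"encode_{i}", nn.Conv2d(prev, c, kernel_size=3, stride=2, padding=1))
    prev = c

decode = nn.Sequential()
# mirror the channel sequence back down to out_channels=1
for i, c in enumerate(reversed((1,) + channels[:-1])):
    decode.add_module(
        f"decode_{i}",
        nn.ConvTranspose2d(prev, c, kernel_size=3, stride=2, padding=1, output_padding=1),
    )
    prev = c

net = nn.Sequential(encode, decode)
x = torch.rand(1, 1, 32, 32)
y = net(x)
print(tuple(y.shape))  # the three down/up-samplings cancel: (1, 1, 32, 32)
```

Because every encode stride has a matching decode stride, the output spatial size matches the input, which is why the docstring's first example reconstructs same-sized images.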
75 changes: 36 additions & 39 deletions monai/networks/nets/classifier.py
@@ -25,6 +25,19 @@ class Classifier(Regressor):
Defines a classification network from Regressor by specifying the output shape as a single dimensional tensor
with size equal to the number of classes to predict. The final activation function can also be specified, e.g.
softmax or sigmoid.

Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
classes: integer stating the dimension of the final output tensor
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
"""

def __init__(
@@ -41,20 +54,6 @@ def __init__(
bias: bool = True,
last_act: Optional[str] = None,
) -> None:
"""
Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
classes: integer stating the dimension of the final output tensor
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
"""
super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)

if last_act is not None:
@@ -68,6 +67,18 @@ class Discriminator(Classifier):
"""
Defines a discriminator network from Classifier with a single output value and sigmoid activation by default. This
is meant for use with GANs or other applications requiring a generic discriminator network.

Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
"""

def __init__(
@@ -83,19 +94,6 @@ def __init__(
bias: bool = True,
last_act=Act.SIGMOID,
) -> None:
"""
Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
"""
super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, last_act)


@@ -104,6 +102,17 @@ class Critic(Classifier):
Defines a critic network from Classifier with a single output value and no final activation. The final layer is
`nn.Flatten` instead of `nn.Linear`, the final result is computed as the mean over the first dimension. This is
meant to be used with Wasserstein GANs.

Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
"""

def __init__(
@@ -118,18 +127,6 @@ def __init__(
dropout: Optional[float] = 0.25,
bias: bool = True,
) -> None:
"""
Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
"""
super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None)

def _get_final_layer(self, in_shape: Sequence[int]):
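The three classifier.py docstrings above describe closely related heads that differ only in their final layer: Classifier projects to `classes` outputs, Discriminator squashes a single output through a sigmoid, and Critic drops the final Linear and activation, averaging the flattened features instead. A hedged sketch of just those final layers, using plain `torch.nn` (the shapes and layer compositions here are illustrative, not the MONAI implementation):

```python
import torch
import torch.nn as nn

# Stand-in for the output of the last convolutional block of the network
features = torch.rand(4, 8, 2, 2)  # (batch, channels, H, W)

# Classifier-style head: flatten, then Linear to `classes` outputs
classes = 3
classifier_head = nn.Sequential(nn.Flatten(), nn.Linear(8 * 2 * 2, classes))
logits = classifier_head(features)

# Discriminator-style head: single output squashed into (0, 1) by a sigmoid
disc_head = nn.Sequential(nn.Flatten(), nn.Linear(8 * 2 * 2, 1), nn.Sigmoid())
prob = disc_head(features)

# Critic-style head: nn.Flatten instead of nn.Linear, no final activation;
# the result is the mean over the flattened feature dimension, as used in
# Wasserstein GANs where the score is unbounded
critic_score = features.flatten(1).mean(dim=1)

print(logits.shape, prob.shape, critic_score.shape)
```

The Critic's unbounded mean output is what makes it suitable for Wasserstein losses, while the Discriminator's sigmoid fits the standard GAN binary-classification objective.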
51 changes: 45 additions & 6 deletions monai/networks/nets/fullyconnectednet.py
@@ -30,9 +30,24 @@ def _get_adn_layer(

class FullyConnectedNet(nn.Sequential):
"""
Plain full-connected layer neural network
Simple fully connected neural network composed of a sequence of linear layers with PReLU activation and
dropout. The network accepts input with `in_channels` channels, has output with `out_channels` channels, and
hidden layer output channels given in `hidden_channels`. If `bias` is True then linear units have a bias term.

Args:
in_channels: number of input channels.
out_channels: number of output channels.
hidden_channels: number of output channels for each hidden layer.
dropout: dropout ratio. Defaults to no dropout.
act: activation type and arguments. Defaults to PReLU.
bias: whether to have a bias term in linear units. Defaults to True.
adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`.

Examples::

# accepts 4 values and infers 3 values as output, has 3 hidden layers with 10, 20, 10 values as output
net = FullyConnectedNet(4, 3, [10, 20, 10], dropout=0.2)

The network uses dropout and, by default, PReLU activation
"""

def __init__(
@@ -53,8 +68,11 @@ def __init__(
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = list(hidden_channels)
self.act = act
self.dropout = dropout
self.adn_ordering = adn_ordering

self.add_module("flatten", nn.Flatten())
self.adn_layer = _get_adn_layer(act, dropout, adn_ordering)

prev_channels = self.in_channels
for i, c in enumerate(hidden_channels):
@@ -64,13 +82,34 @@ def __init__(
self.add_module("output", nn.Linear(prev_channels, out_channels, bias))

def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequential:
seq = nn.Sequential(nn.Linear(in_channels, out_channels, bias))
seq.add_module("ADN", self.adn_layer)
seq = nn.Sequential(
nn.Linear(in_channels, out_channels, bias), _get_adn_layer(self.act, self.dropout, self.adn_ordering)
)
return seq


class VarFullyConnectedNet(nn.Module):
"""Variational fully-connected network."""
"""
Variational fully-connected network. This is composed of an encode layer, a reparameterization layer, and then a
decode layer.

Args:
in_channels: number of input channels.
out_channels: number of output channels.
latent_size: number of latent variables to use.
encode_channels: number of output channels for each hidden layer of the encode half.
decode_channels: number of output channels for each hidden layer of the decode half.
dropout: dropout ratio. Defaults to no dropout.
act: activation type and arguments. Defaults to PReLU.
bias: whether to have a bias term in linear units. Defaults to True.
adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`.

Examples::

# accepts inputs with 4 values, uses a latent space of 2 variables, and produces outputs of 3 values
net = VarFullyConnectedNet(4, 3, 2, [5, 10], [10, 5])

"""

def __init__(
self,
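The layer pattern FullyConnectedNet's new docstring describes — a chain of Linear layers, each followed by activation and dropout, with sizes taken from `hidden_channels` — can be sketched directly. This mirrors the docstring's `FullyConnectedNet(4, 3, [10, 20, 10], dropout=0.2)` example, but uses plain `PReLU`/`Dropout` modules rather than MONAI's ADN block, so the exact ordering of activation/dropout is an assumption here:

```python
import torch
import torch.nn as nn

# Sketch of FullyConnectedNet's structure: hidden Linear layers each followed
# by activation and dropout, then a final Linear with no activation.
in_channels, out_channels = 4, 3
hidden_channels = [10, 20, 10]
dropout = 0.2

layers = []
prev = in_channels
for c in hidden_channels:
    layers += [nn.Linear(prev, c), nn.PReLU(), nn.Dropout(dropout)]
    prev = c
layers.append(nn.Linear(prev, out_channels))  # output layer
net = nn.Sequential(*layers)

x = torch.rand(7, in_channels)  # batch of 7 inputs with 4 values each
y = net(x)
print(tuple(y.shape))  # (7, 3)
```

The variational variant documented above adds a reparameterization step between the encode and decode halves, sampling the latent vector as `mu + eps * exp(0.5 * logvar)` in the usual VAE fashion.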