Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbeddingBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
from .segresnet_block import ResBlock
from .selfattention import SABlock
Expand Down
109 changes: 87 additions & 22 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Sequence, Union
from typing import Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.layers import Conv
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Expand Down Expand Up @@ -98,38 +98,103 @@ def __init__(
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def trunc_normal_(self, tensor, mean, std, a, b):
# From PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor

def forward(self, x):
    """
    Embed the input and add the learned position embeddings.

    Args:
        x: input passed to ``self.patch_embeddings``.

    Returns:
        ``self.dropout`` applied to the sum of the patch embeddings and
        ``self.position_embeddings``.
    """
    patches = self.patch_embeddings(x)
    if self.pos_embed == "conv":
        # collapse everything from dim 2 onwards into one axis, then swap the last two axes
        # (presumably channels-first conv output -> (batch, n_patches, hidden) -- confirm against __init__)
        patches = patches.flatten(2).transpose(-1, -2)
    return self.dropout(patches + self.position_embeddings)


class PatchEmbed(nn.Module):
    """
    Patch embedding block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Compared with the ViT patch embedding block: (1) the input is zero-padded at the end of each spatial
    axis until its size is a multiple of the patch size, (2) the projected features are normalized when a
    normalization layer is given, and (3) no position embedding is added.

    Example::

        >>> from monai.networks.blocks import PatchEmbed
        >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    """

    def __init__(
        self,
        patch_size: Union[Sequence[int], int] = 2,
        in_chans: int = 1,
        embed_dim: int = 48,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            patch_size: patch size, a single int or one value per spatial dimension.
            in_chans: number of input channels.
            embed_dim: number of output channels of the patch projection.
            norm_layer: normalization layer class, instantiated with ``embed_dim``; ``None`` disables it.
            spatial_dims: number of spatial dimensions (2 or 3).

        Raises:
            ValueError: when ``spatial_dims`` is neither 2 nor 3.
        """

        super().__init__()

        if spatial_dims != 2 and spatial_dims != 3:
            raise ValueError("spatial dimension should be 2 or 3.")

        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.embed_dim = embed_dim
        # non-overlapping patches: kernel size and stride are both the patch size
        conv_type = Conv[Conv.CONV, spatial_dims]
        self.proj = conv_type(
            in_channels=in_chans, out_channels=embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """
        Pad the input to a multiple of the patch size, project it and optionally normalize it.

        Args:
            x: input tensor; padding and reshaping are only applied to 4D or 5D inputs.

        Returns:
            the projected tensor; when a norm layer is set and the input is 4D/5D, the normalized
            features are reshaped back to ``(-1, embed_dim, *spatial)``.
        """
        n_spatial = x.dim() - 2
        if n_spatial in (2, 3):
            sizes = x.size()[2:]
            # walk from the last axis to the first, matching the layout F.pad expects:
            # (left, right) pairs starting at the last dimension
            for offset, axis in enumerate(reversed(range(n_spatial))):
                patch = self.patch_size[axis]
                remainder = sizes[axis] % patch
                if remainder != 0:
                    x = F.pad(x, (0, 0) * offset + (0, patch - remainder))

        x = self.proj(x)
        if self.norm is None:
            return x

        projected_shape = x.size()
        # the norm layer acts on the last axis, so move channels last first
        x = self.norm(x.flatten(2).transpose(1, 2))
        if len(projected_shape) in (4, 5):
            x = x.transpose(1, 2).view(-1, self.embed_dim, *projected_shape[2:])
        return x
2 changes: 2 additions & 0 deletions monai/networks/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .drop_path import DropPath
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
from .filtering import BilateralFilter, PHLFilter
from .gmm import GaussianMixtureModel
Expand All @@ -27,3 +28,4 @@
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
45 changes: 45 additions & 0 deletions monai/networks/layers/drop_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn


class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.
    Based on:
    https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
        """
        Args:
            drop_prob: probability of dropping the path for a sample, in ``[0, 1]``.
            scale_by_keep: whether to divide the kept samples by the keep probability.

        Raises:
            ValueError: when ``drop_prob`` is outside ``[0, 1]``.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

        if not 0 <= drop_prob <= 1:
            raise ValueError("Drop path prob should be between 0 and 1.")

    def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
        """
        Apply per-sample drop path to ``x``.

        Args:
            x: input tensor; its first dimension is treated as the sample dimension.
            drop_prob: probability of zeroing a sample.
            training: the input is returned unchanged when this is false.
            scale_by_keep: divide the surviving samples by ``1 - drop_prob``.

        Returns:
            ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by the random mask.
        """
        if not training or drop_prob == 0.0:
            return x

        keep_prob = 1 - drop_prob
        # one Bernoulli draw per sample, broadcast across every remaining dimension
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        # skip the rescale when keep_prob is 0 to avoid dividing by zero
        if scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask

    def forward(self, x):
        """Run ``drop_path`` with the module's configuration and current training flag."""
        return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
64 changes: 64 additions & 0 deletions monai/networks/layers/weight_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""Tensor initialization with truncated normal distribution.
Based on:
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
https://github.com/rwightman/pytorch-image-models

Args:
tensor: an n-dimensional `torch.Tensor`.
mean: the mean of the normal distribution.
std: the standard deviation of the normal distribution.
a: the minimum cutoff value.
b: the maximum cutoff value.
"""

def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Validate the arguments, then fill ``tensor`` from a truncated normal distribution.
    Based on:
    https://github.com/rwightman/pytorch-image-models

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution; must be greater than zero
        a: the minimum cutoff value; must be smaller than ``b``
        b: the maximum cutoff value

    Returns:
        the result of ``_no_grad_trunc_normal_(tensor, mean, std, a, b)``.

    Raises:
        ValueError: when ``std`` is not greater than zero, or when ``a >= b``.
    """
    std_is_positive = std > 0
    if not std_is_positive:
        raise ValueError("the standard deviation should be greater than zero.")

    if b <= a:
        raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).")

    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
seresnext50,
seresnext101,
)
from .swin_unetr import SwinUNETR
from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .unet import UNet, Unet
Expand Down
Loading