6 changes: 3 additions & 3 deletions .github/workflows/cron.yml
@@ -65,7 +65,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp.
container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp.
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
@@ -109,7 +109,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp.
container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp.
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
@@ -207,7 +207,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
needs: cron-gpu # so that monai itself is verified first
container:
image: nvcr.io/nvidia/pytorch:22.03-py3 # testing with the latest pytorch base image
image: nvcr.io/nvidia/pytorch:22.04-py3 # testing with the latest pytorch base image
options: "--gpus all --ipc=host"
runs-on: [self-hosted, linux, x64, common]
steps:
4 changes: 2 additions & 2 deletions .github/workflows/pythonapp-gpu.yml
@@ -46,9 +46,9 @@ jobs:
base: "nvcr.io/nvidia/pytorch:21.10-py3"
- environment: PT111+CUDA116
# we explicitly set pytorch to -h to avoid pip install error
# 22.03: 1.12.0a0+2c916ef
# 22.04: 1.12.0a0+bd13bc6
pytorch: "-h"
base: "nvcr.io/nvidia/pytorch:22.03-py3"
base: "nvcr.io/nvidia/pytorch:22.04-py3"
- environment: PT110+CUDA102
pytorch: "torch==1.10.2 torchvision==0.11.3"
base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
2 changes: 1 addition & 1 deletion Dockerfile
@@ -11,7 +11,7 @@

# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.03-py3
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.04-py3
FROM ${PYTORCH_IMAGE}

LABEL maintainer="monai.contact@gmail.com"
11 changes: 11 additions & 0 deletions monai/data/meta_tensor.py
@@ -235,6 +235,17 @@ def affine(self, d: torch.Tensor) -> None:
"""Set the affine."""
self.meta["affine"] = d

def new_empty(self, size, dtype=None, device=None, requires_grad=False):
"""
must be defined for deepcopy to work

See:
- https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty
"""
return type(self)(
self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
)

@staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
"""
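A minimal usage sketch of the new `new_empty` override (the `monai.data.MetaTensor` import path and the `meta` keyword are assumed from how MetaTensor is used elsewhere in MONAI): the diff notes that `new_empty` must be defined for `deepcopy` to work, so copying a `MetaTensor` should now preserve the subclass.

import copy
import torch
from monai.data import MetaTensor

t = MetaTensor(torch.zeros(1, 1, 4, 4), meta={"affine": torch.eye(4)})
t2 = copy.deepcopy(t)      # deepcopy allocates via new_empty, so the MetaTensor type is kept
print(type(t2).__name__)   # expected: MetaTensor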
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
@@ -20,7 +20,7 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbeddingBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
from .segresnet_block import ResBlock
from .selfattention import SABlock
109 changes: 87 additions & 22 deletions monai/networks/blocks/patchembedding.py
@@ -9,15 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Sequence, Union
from typing import Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.layers import Conv
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

@@ -98,38 +98,103 @@ def __init__(
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def trunc_normal_(self, tensor, mean, std, a, b):
# From PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor

def forward(self, x):
x = self.patch_embeddings(x)
if self.pos_embed == "conv":
x = x.flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings


class PatchEmbed(nn.Module):
"""
Patch embedding block based on: "Liu et al.,
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>"
https://github.com/microsoft/Swin-Transformer

Unlike the ViT patch embedding block: (1) the input is padded to satisfy window size requirements, (2) the output
is normalized if specified, and (3) position embedding is not used.

Example::

>>> from monai.networks.blocks import PatchEmbed
>>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
"""

def __init__(
self,
patch_size: Union[Sequence[int], int] = 2,
in_chans: int = 1,
embed_dim: int = 48,
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
spatial_dims: int = 3,
) -> None:
"""
Args:
patch_size: size of the patch.
in_chans: number of input channels.
embed_dim: number of linear projection output channels.
norm_layer: normalization layer.
spatial_dims: number of spatial dimensions.
"""

super().__init__()

if not (spatial_dims == 2 or spatial_dims == 3):
raise ValueError("spatial dimension should be 2 or 3.")

patch_size = ensure_tuple_rep(patch_size, spatial_dims)
self.patch_size = patch_size
self.embed_dim = embed_dim
self.proj = Conv[Conv.CONV, spatial_dims](
in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None

def forward(self, x):
x_shape = x.size()
if len(x_shape) == 5:
_, _, d, h, w = x_shape
if w % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
if h % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
if d % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))

elif len(x_shape) == 4:
_, _, h, w = x.size()
if w % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1]))
if h % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0]))

x = self.proj(x)
if self.norm is not None:
x_shape = x.size()
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
if len(x_shape) == 5:
d, wh, ww = x_shape[2], x_shape[3], x_shape[4]
x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
elif len(x_shape) == 4:
wh, ww = x_shape[2], x_shape[3]
x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww)
return x
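A small sketch of how the new `PatchEmbed` block behaves on inputs whose spatial sizes are not multiples of `patch_size` (argument names taken from the diff above): the forward pass pads the input first, then projects it with a strided convolution.

import torch
import torch.nn as nn
from monai.networks.blocks import PatchEmbed

embed = PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
x = torch.randn(1, 1, 15, 15, 15)   # 15 is not divisible by the patch size of 2
y = embed(x)                        # input is padded to 16^3 before the projection
print(y.shape)                      # expected: torch.Size([1, 48, 8, 8, 8])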
2 changes: 2 additions & 0 deletions monai/networks/layers/__init__.py
@@ -10,6 +10,7 @@
# limitations under the License.

from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .drop_path import DropPath
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
from .filtering import BilateralFilter, PHLFilter
from .gmm import GaussianMixtureModel
@@ -27,3 +28,4 @@
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
45 changes: 45 additions & 0 deletions monai/networks/layers/drop_path.py
@@ -0,0 +1,45 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn


class DropPath(nn.Module):
"""Stochastic drop paths per sample for residual blocks.
Based on:
https://github.com/rwightman/pytorch-image-models
"""

def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
"""
Args:
drop_prob: drop path probability.
scale_by_keep: scaling by non-dropped probability.
"""
super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep

if not (0 <= drop_prob <= 1):
raise ValueError("Drop path prob should be between 0 and 1.")

def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor

def forward(self, x):
return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
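A short sketch of the new `DropPath` layer in isolation (import path as registered in `monai/networks/layers/__init__.py` above): in training mode whole samples in the batch are zeroed with probability `drop_prob` and the survivors are rescaled, while in eval mode the layer acts as an identity.

import torch
from monai.networks.layers import DropPath

dp = DropPath(drop_prob=0.25)
x = torch.ones(8, 16, 4, 4)

dp.train()
y = dp(x)                        # roughly a quarter of the samples are zeroed, the rest scaled by 1/0.75
dp.eval()
print(torch.equal(dp(x), x))     # expected: True (identity at inference time)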
29 changes: 28 additions & 1 deletion monai/networks/layers/factories.py
@@ -60,11 +60,14 @@ def use_factory(fact_args):
layer = use_factory( (fact.TEST, kwargs) )
"""

import warnings
from typing import Any, Callable, Dict, Tuple, Type, Union

import torch.nn as nn

from monai.utils import look_up_option
from monai.utils import look_up_option, optional_import

InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]

@@ -242,6 +245,30 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm


@Norm.factory_function("instance_nvfuser")
def instance_nvfuser_factory(dim):
"""
`InstanceNorm3dNVFuser` is a faster version of the InstanceNorm layer, implemented in `apex`.
It only supports 3d tensors as input and requires CUDA and a non-Windows OS.
In this function, if the required `apex.normalization.InstanceNorm3dNVFuser` is not available,
`nn.InstanceNorm3d` will be returned instead.
This layer is based on a customized autograd function, which is currently not supported in TorchScript;
please switch to `nn.InstanceNorm3d` if TorchScript is necessary.

Please check the following link for more details about how to install `apex`:
https://github.com/NVIDIA/apex#installation

"""
types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
if dim != 3:
warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
return types[dim - 1]
if not has_nvfuser:
warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.")
return nn.InstanceNorm3d
return InstanceNorm3dNVFuser


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
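A hedged sketch of looking up the new `instance_nvfuser` entry through the `Norm` factory: the fallback behaviour (a warning plus `nn.InstanceNorm3d`) is what the factory function above documents when `apex` is missing or `dim != 3`. The constructor arguments shown are those of `nn.InstanceNorm3d`; apex's fused layer is assumed to mirror them.

from monai.networks.layers import Norm

# InstanceNorm3dNVFuser if apex is installed, otherwise nn.InstanceNorm3d (with a warning)
norm_type = Norm["instance_nvfuser", 3]
layer = norm_type(32, affine=True)   # note: the fused apex layer expects CUDA tensors at call time
print(norm_type.__name__)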
64 changes: 64 additions & 0 deletions monai/networks/layers/weight_init.py
@@ -0,0 +1,64 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""Tensor initialization with truncated normal distribution.
Based on:
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
https://github.com/rwightman/pytorch-image-models

Args:
tensor: an n-dimensional `torch.Tensor`.
mean: the mean of the normal distribution.
std: the standard deviation of the normal distribution.
a: the minimum cutoff value.
b: the maximum cutoff value.
"""

def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
"""Tensor initialization with truncated normal distribution.
Based on:
https://github.com/rwightman/pytorch-image-models

Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""

if not std > 0:
raise ValueError("the standard deviation should be greater than zero.")

if a >= b:
raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).")

return _no_grad_trunc_normal_(tensor, mean, std, a, b)
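A minimal usage sketch of the exported `trunc_normal_` helper (the same call `PatchEmbeddingBlock` now makes for its position embeddings): initialization happens in place and all values end up clamped to `[a, b]`.

import torch
from monai.networks.layers import trunc_normal_

w = torch.empty(3, 5)
trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)   # in-place truncated-normal init
assert float(w.min()) >= -2.0 and float(w.max()) <= 2.0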
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -80,6 +80,7 @@
seresnext50,
seresnext101,
)
from .swin_unetr import SwinUNETR
from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .unet import UNet, Unet
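A hedged instantiation sketch for the newly exported `SwinUNETR`; the keyword names follow common MONAI usage and are assumptions here, with `monai/networks/nets/swin_unetr.py` holding the authoritative signature.

import torch
from monai.networks.nets import SwinUNETR

# img_size / in_channels / out_channels / feature_size are assumed keyword names
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)
with torch.no_grad():
    out = model(torch.randn(1, 1, 96, 96, 96))
print(out.shape)   # expected: torch.Size([1, 2, 96, 96, 96])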