diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index adcba8e770..5b4e2f7285 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -65,7 +65,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp. + container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -109,7 +109,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.03"] # 21.02, 21.10 for backward comp. + container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.04"] # 21.02, 21.10 for backward comp. container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -207,7 +207,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:22.03-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:22.04-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, common] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 50bbe13062..2fdfa5a80f 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -46,9 +46,9 @@ jobs: base: "nvcr.io/nvidia/pytorch:21.10-py3" - environment: PT111+CUDA116 # we explicitly set pytorch to -h to avoid pip install error - # 22.03: 1.12.0a0+2c916ef + # 22.04: 1.12.0a0+bd13bc6 pytorch: "-h" - base: "nvcr.io/nvidia/pytorch:22.03-py3" + base: "nvcr.io/nvidia/pytorch:22.04-py3" - environment: PT110+CUDA102 pytorch: "torch==1.10.2 torchvision==0.11.3" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" diff --git a/Dockerfile b/Dockerfile index dc76584d5a..1b022fc92e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. 
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.03-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.04-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index aae012fec0..4f4fd44497 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -235,6 +235,17 @@ def affine(self, d: torch.Tensor) -> None: """Set the affine.""" self.meta["affine"] = d + def new_empty(self, size, dtype=None, device=None, requires_grad=False): + """ + must be defined for deepcopy to work + + See: + - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty + """ + return type(self)( + self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad) + ) + @staticmethod def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict): """ diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 0fdc944760..b6328734b0 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -20,7 +20,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .mlp import MLPBlock -from .patchembedding import PatchEmbeddingBlock +from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 4c7263c6d5..f02f6342e8 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
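Note on the `MetaTensor.new_empty` override above: as its docstring states, it must be defined for `copy.deepcopy` to work, because `torch.Tensor.__deepcopy__` relies on `new_empty` returning an instance of the subclass. A minimal sketch of the intended behaviour (assuming `MetaTensor` simply wraps a plain tensor, as in this diff):

    import copy

    import torch

    from monai.data.meta_tensor import MetaTensor

    t = MetaTensor(torch.zeros(2, 3))
    t2 = copy.deepcopy(t)          # deepcopy goes through Tensor.__deepcopy__, which calls new_empty
    assert isinstance(t2, MetaTensor)  # the subclass type is preserved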
- -import math -from typing import Sequence, Union +from typing import Sequence, Type, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm -from monai.networks.layers import Conv +from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option @@ -98,34 +98,18 @@ def __init__( ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def trunc_normal_(self, tensor, mean, std, a, b): - # From PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - def forward(self, x): x = self.patch_embeddings(x) if self.pos_embed == "conv": @@ -133,3 +117,84 @@ def forward(self, x): embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings + + +class PatchEmbed(nn.Module): + """ + Patch embedding block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if + specified (3) position embedding is not used. + + Example:: + + >>> from monai.networks.blocks import PatchEmbed + >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3) + """ + + def __init__( + self, + patch_size: Union[Sequence[int], int] = 2, + in_chans: int = 1, + embed_dim: int = 48, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + spatial_dims: int = 3, + ) -> None: + """ + Args: + patch_size: dimension of patch size. + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + norm_layer: normalization layer. + spatial_dims: spatial dimension. 
+ """ + + super().__init__() + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.proj = Conv[Conv.CONV, spatial_dims]( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + _, _, d, h, w = x_shape + if w % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2])) + if h % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1])) + if d % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0])) + + elif len(x_shape) == 4: + _, _, h, w = x.size() + if w % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1])) + if h % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0])) + + x = self.proj(x) + if self.norm is not None: + x_shape = x.size() + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if len(x_shape) == 5: + d, wh, ww = x_shape[2], x_shape[3], x_shape[4] + x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww) + elif len(x_shape) == 4: + wh, ww = x_shape[2], x_shape[3] + x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww) + return x diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 5115c00af3..f122dccee6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding +from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args from .filtering import BilateralFilter, PHLFilter from .gmm import GaussianMixtureModel @@ -27,3 +28,4 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/drop_path.py b/monai/networks/layers/drop_path.py new file mode 100644 index 0000000000..7bb209ed25 --- /dev/null +++ b/monai/networks/layers/drop_path.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class DropPath(nn.Module): + """Stochastic drop paths per sample for residual blocks. + Based on: + https://github.com/rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: + """ + Args: + drop_prob: drop path probability. + scale_by_keep: scaling by non-dropped probability. 
+ """ + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + if not (0 <= drop_prob <= 1): + raise ValueError("Drop path prob should be between 0 and 1.") + + def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 6379f49449..b808c24de0 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -60,11 +60,14 @@ def use_factory(fact_args): layer = use_factory( (fact.TEST, kwargs) ) """ +import warnings from typing import Any, Callable, Dict, Tuple, Type, Union import torch.nn as nn -from monai.utils import look_up_option +from monai.utils import look_up_option, optional_import + +InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -242,6 +245,30 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: return nn.SyncBatchNorm +@Norm.factory_function("instance_nvfuser") +def instance_nvfuser_factory(dim): + """ + `InstanceNorm3dNVFuser` is a faster version of the InstanceNorm layer, implemented in `apex`. + It only supports 3d tensors as input and requires a CUDA device and a non-Windows OS. + In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist, + `nn.InstanceNorm3d` will be returned instead. + This layer is based on a customized autograd function, which is currently not supported in TorchScript. + Please switch to `nn.InstanceNorm3d` if TorchScript is required. + + Please check the following link for more details about how to install `apex`: + https://github.com/NVIDIA/apex#installation + + """ + types = (nn.InstanceNorm1d, nn.InstanceNorm2d) + if dim != 3: + warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.") + return types[dim - 1] + if not has_nvfuser: + warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") + return nn.InstanceNorm3d + return InstanceNorm3dNVFuser + + Act.add_factory_callable("elu", lambda: nn.modules.ELU) Act.add_factory_callable("relu", lambda: nn.modules.ReLU) Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py new file mode 100644 index 0000000000..9b81ef17f8 --- /dev/null +++ b/monai/networks/layers/weight_init.py @@ -0,0 +1,64 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + """Tensor initialization with truncated normal distribution. + Based on: + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor`. + mean: the mean of the normal distribution. + std: the standard deviation of the normal distribution. + a: the minimum cutoff value. + b: the maximum cutoff value. + """ + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + """Tensor initialization with truncated normal distribution. + Based on: + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + + if not std > 0: + raise ValueError("the standard deviation should be greater than zero.") + + if a >= b: + raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).") + + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 16686fa25c..394ff51907 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -80,6 +80,7 @@ seresnext50, seresnext101, ) +from .swin_unetr import SwinUNETR from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex from .unet import UNet, Unet diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index e858dcbb9b..053ab255b8 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -104,6 +104,8 @@ class DynUNet(nn.Module): If not specified, the way which nnUNet used will be employed. Defaults to ``None``. dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + `INSTANCE_NVFUSER` is a faster version of the instance norm layer; it can be used when: + 1) `spatial_dims=3`, 2) a CUDA device is available, 3) `apex` is installed, and 4) a non-Windows OS is used. act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the final feature map diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py new file mode 100644 index 0000000000..d898da9884 --- /dev/null +++ b/monai/networks/nets/swin_unetr.py @@ -0,0 +1,982 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch.nn import LayerNorm + +from monai.networks.blocks import MLPBlock as Mlp +from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock +from monai.networks.layers import DropPath, trunc_normal_ +from monai.utils import ensure_tuple_rep, optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + + +class SwinUNETR(nn.Module): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + """ + + def __init__( + self, + img_size: Union[Sequence[int], int], + in_channels: int, + out_channels: int, + depths: Sequence[int] = (2, 2, 2, 2), + num_heads: Sequence[int] = (3, 6, 12, 24), + feature_size: int = 48, + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + normalize: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. + + Examples:: + + # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. + >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) + + # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. + >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) + + # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. 
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) + + """ + + super().__init__() + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(2, spatial_dims) + window_size = ensure_tuple_rep(7, spatial_dims) + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + for m, p in zip(img_size, patch_size): + for i in range(5): + if m % np.power(p, i + 1) != 0: + raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") + + if not (0 <= drop_rate <= 1): + raise ValueError("dropout rate should be between 0 and 1.") + + if not (0 <= attn_drop_rate <= 1): + raise ValueError("attention dropout rate should be between 0 and 1.") + + if not (0 <= dropout_path_rate <= 1): + raise ValueError("drop path rate should be between 0 and 1.") + + if feature_size % 12 != 0: + raise ValueError("feature_size should be divisible by 12.") + + self.normalize = normalize + + self.swinViT = SwinTransformer( + in_chans=in_channels, + embed_dim=feature_size, + window_size=window_size, + patch_size=patch_size, + depths=depths, + num_heads=num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dropout_path_rate, + norm_layer=nn.LayerNorm, + use_checkpoint=use_checkpoint, + spatial_dims=spatial_dims, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + 
spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def load_from(self, weights): + + with torch.no_grad(): + self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) + self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) + for bname, block in self.swinViT.layers1[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers1") + self.swinViT.layers1[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.reduction.weight"] + ) + self.swinViT.layers1[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.weight"] + ) + self.swinViT.layers1[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers2[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers2") + self.swinViT.layers2[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.reduction.weight"] + ) + self.swinViT.layers2[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.weight"] + ) + self.swinViT.layers2[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers3[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers3") + self.swinViT.layers3[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.reduction.weight"] + ) + self.swinViT.layers3[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.weight"] + ) + self.swinViT.layers3[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers4[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers4") + self.swinViT.layers4[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.reduction.weight"] + ) + self.swinViT.layers4[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.weight"] + ) + self.swinViT.layers4[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.bias"] + ) + self.swinViT.norm.weight.copy_(weights["state_dict"]["module.norm.weight"]) + self.swinViT.norm.bias.copy_(weights["state_dict"]["module.norm.bias"]) + + def forward(self, x_in): + hidden_states_out = self.swinViT(x_in, self.normalize) + enc0 = self.encoder1(x_in) + enc1 = self.encoder2(hidden_states_out[0]) + enc2 = self.encoder3(hidden_states_out[1]) + enc3 = self.encoder4(hidden_states_out[2]) + dec4 = self.encoder10(hidden_states_out[4]) + dec3 = self.decoder5(dec4, hidden_states_out[3]) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + return logits + + +def window_partition(x, window_size): + """window partition operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x: input tensor. + window_size: local window size. 
+ """ + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + x = x.view( + b, + d // window_size[0], + window_size[0], + h // window_size[1], + window_size[1], + w // window_size[2], + window_size[2], + c, + ) + windows = ( + x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) + ) + elif len(x_shape) == 4: + b, h, w, c = x.shape + x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) + return windows + + +def window_reverse(windows, window_size, dims): + """window reverse operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + windows: windows tensor. + window_size: local window size. + dims: dimension values. + """ + if len(dims) == 4: + b, d, h, w = dims + x = windows.view( + b, + d // window_size[0], + h // window_size[1], + w // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) + + elif len(dims) == 3: + b, h, w = dims + x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """Computing window size based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x_size: input size. + window_size: local window size. + shift_size: window shifting size. + """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention(nn.Module): + """ + Window based multi-head self attention module with relative position bias based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + attn_drop: attention dropout rate. + proj_drop: dropout rate of output. 
+ """ + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + mesh_args = torch.meshgrid.__kwdefaults__ + + if len(self.window_size) == 3: + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + num_heads, + ) + ) + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + elif len(self.window_size) == 2: + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask): + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = q @ k.transpose(-2, -1) + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:n, :n].reshape(-1) + ].reshape(n, n, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + 
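Note: the `window_partition` / `window_reverse` helpers defined above are exact inverses of each other when the spatial sizes are multiples of the window size. A small round-trip sketch (channels-last input, as used inside the Swin blocks; the import path follows this new module):

    import torch

    from monai.networks.nets.swin_unetr import window_partition, window_reverse

    x = torch.randn(2, 8, 8, 8, 96)                     # (b, d, h, w, c)
    windows = window_partition(x, (4, 4, 4))            # -> (2 * 2 * 2 * 2, 4 * 4 * 4, 96) = (16, 64, 96)
    restored = window_reverse(windows, (4, 4, 4), dims=[2, 8, 8, 8])
    assert torch.equal(restored, x)                     # partition followed by reverse is lossless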
def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + shift_size: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: str = "GELU", + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + shift_size: window shift size. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: stochastic depth rate. + act_layer: activation layer. + norm_layer: normalization layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") + + def forward_part1(self, x, mask_matrix): + x_shape = x.size() + x = self.norm1(x) + if len(x_shape) == 5: + b, d, h, w, c = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] + pad_b = (window_size[1] - h % window_size[1]) % window_size[1] + pad_r = (window_size[2] - w % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, dp, hp, wp, _ = x.shape + dims = [b, dp, hp, wp] + + elif len(x_shape) == 4: + b, h, w, c = x.shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + pad_l = pad_t = 0 + pad_r = (window_size[0] - h % window_size[0]) % window_size[0] + pad_b = (window_size[1] - w % window_size[1]) % window_size[1] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, hp, wp, _ = x.shape + dims = [b, hp, wp] + + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + x_windows = window_partition(shifted_x, window_size) + attn_windows = self.attn(x_windows, mask=attn_mask) + attn_windows = attn_windows.view(-1, *(window_size + (c,))) + shifted_x = window_reverse(attn_windows, window_size, dims) + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + else: + x = shifted_x + + if len(x_shape) == 5: + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :d, :h, :w, :].contiguous() + elif len(x_shape) == 4: + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + return 
x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def load_from(self, weights, n_block, layer): + root = f"module.{layer}.0.blocks.{n_block}." + block_names = [ + "norm1.weight", + "norm1.bias", + "attn.relative_position_bias_table", + "attn.relative_position_index", + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "norm2.weight", + "norm2.bias", + "mlp.linear1.weight", + "mlp.linear1.bias", + "mlp.linear2.weight", + "mlp.linear2.bias", + ] + with torch.no_grad(): + self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) + self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) + self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) + self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) + self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) + self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) + self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) + self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) + self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) + self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) + self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) + self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) + self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) + self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) + + def forward(self, x, mask_matrix): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + return x + + +class PatchMerging(nn.Module): + """ + Patch merging layer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 + ) -> None: # type: ignore + """ + Args: + dim: number of feature channels. + norm_layer: normalization layer. + spatial_dims: number of spatial dims. 
+ """ + + super().__init__() + self.dim = dim + if spatial_dims == 3: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8 * dim) + elif spatial_dims == 2: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + + elif len(x_shape) == 4: + b, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3], -1) + + x = self.norm(x) + x = self.reduction(x) + return x + + +def compute_mask(dims, window_size, shift_size, device): + """Computing region masks based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + dims: dimension values. + window_size: local window size. + shift_size: shift size. + device: device. + """ + + cnt = 0 + + if len(dims) == 3: + d, h, w = dims + img_mask = torch.zeros((1, d, h, w, 1), device=device) + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + + elif len(dims) == 2: + h, w = dims + img_mask = torch.zeros((1, h, w, 1), device=device) + for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +class BasicLayer(nn.Module): + """ + Basic Swin Transformer layer in one stage based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: Sequence[int], + drop_path: list, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + drop: float = 0.0, + attn_drop: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + downsample: isinstance = None, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + depths: number of layers in each stage. + num_heads: number of attention heads. + window_size: local window size. + drop_path: stochastic depth rate. + mlp_ratio: ratio of mlp hidden dim to embedding dim. 
+ qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + norm_layer: normalization layer. + downsample: downsample layer at the end of the layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.no_shift = tuple(0 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth) + ] + ) + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + b, c, d, h, w = x_shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c d h w -> b d h w c") + dp = int(np.ceil(d / window_size[0])) * window_size[0] + hp = int(np.ceil(h / window_size[1])) * window_size[1] + wp = int(np.ceil(w / window_size[2])) * window_size[2] + attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, d, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b d h w c -> b c d h w") + + elif len(x_shape) == 4: + b, c, h, w = x_shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c h w -> b h w c") + hp = int(np.ceil(h / window_size[0])) * window_size[0] + wp = int(np.ceil(w / window_size[1])) * window_size[1] + attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b h w c -> b c h w") + return x + + +class SwinTransformer(nn.Module): + """ + Swin Transformer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + in_chans: int, + embed_dim: int, + window_size: Sequence[int], + patch_size: Sequence[int], + depths: Sequence[int], + num_heads: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + patch_norm: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + window_size: local window size. + patch_size: patch size. + depths: number of layers in each stage. + num_heads: number of attention heads. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + drop_path_rate: stochastic depth rate. + norm_layer: normalization layer. 
+ patch_norm: add normalization after patch embedding. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: spatial dimension. + """ + + super().__init__() + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.window_size = window_size + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, # type: ignore + spatial_dims=spatial_dims, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers1 = nn.ModuleList() + self.layers2 = nn.ModuleList() + self.layers3 = nn.ModuleList() + self.layers4 = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=PatchMerging, + use_checkpoint=use_checkpoint, + ) + if i_layer == 0: + self.layers1.append(layer) + elif i_layer == 1: + self.layers2.append(layer) + elif i_layer == 2: + self.layers3.append(layer) + elif i_layer == 3: + self.layers4.append(layer) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.norm = norm_layer(self.num_features) + + def proj_out(self, x, normalize=False): + if normalize: + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") + return x + + def forward(self, x, normalize=False): + x0 = self.patch_embed(x) + x0 = self.pos_drop(x0) + x0_out = self.proj_out(x0, normalize) + x1 = self.layers1[0](x0.contiguous()) + x1_out = self.proj_out(x1, normalize) + x2 = self.layers2[0](x1.contiguous()) + x2_out = self.proj_out(x2, normalize) + x3 = self.layers3[0](x2.contiguous()) + x3_out = self.proj_out(x3, normalize) + x4 = self.layers4[0](x3.contiguous()) + x4_out = self.proj_out(x4, normalize) + return [x0_out, x1_out, x2_out, x3_out, x4_out] diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py new file mode 100644 index 0000000000..f8ea454228 --- /dev/null +++ b/tests/test_drop_path.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
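For orientation, the encoder above produces one feature map per stage: `SwinTransformer.forward` returns the outputs of the patch embedding and of `layers1`-`layers4`, which `SwinUNETR.forward` consumes as skip connections. A rough sketch with the defaults used in this file (patch size 2, `feature_size=48`; requires `einops`):

    import torch

    from monai.networks.nets import SwinUNETR

    net = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)
    feats = net.swinViT(torch.randn(1, 1, 96, 96, 96), True)  # normalize intermediate outputs
    print([tuple(f.shape) for f in feats])
    # expected shapes: (1, 48, 48, 48, 48), (1, 96, 24, 24, 24), (1, 192, 12, 12, 12),
    #                  (1, 384, 6, 6, 6), (1, 768, 3, 3, 3)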
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import DropPath + +TEST_CASES = [ + [{"drop_prob": 0.0, "scale_by_keep": True}, (1, 8, 8)], + [{"drop_prob": 0.7, "scale_by_keep": False}, (2, 16, 16, 16)], + [{"drop_prob": 0.3, "scale_by_keep": True}, (6, 16, 12)], +] + +TEST_ERRORS = [[{"drop_prob": 2, "scale_by_keep": False}, (1, 24, 6)]] + + +class TestDropPath(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + dr_path = DropPath(**input_param) + out = dr_path(im) + self.assertEqual(out.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + DropPath(**input_param) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 36ac9d0309..14006b96e6 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,10 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import test_script_save +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save + +_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -118,6 +121,33 @@ def test_script(self): test_script_save(net, test_data) +@skip_if_no_cuda +@skip_if_windows +@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.") +class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): + @parameterized.expand([TEST_CASE_DYNUNET_3D[0]]) + def test_consistency(self, input_param, input_shape, _): + for eps in [1e-4, 1e-5]: + for momentum in [0.1, 0.01]: + for affine in [True, False]: + norm_param = {"eps": eps, "momentum": momentum, "affine": affine} + input_param["norm_name"] = ("instance", norm_param) + input_param_fuser = input_param.copy() + input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param) + for memory_format in [torch.contiguous_format, torch.channels_last_3d]: + net = DynUNet(**input_param).to("cuda:0", memory_format=memory_format) + net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format) + net_fuser.load_state_dict(net.state_dict()) + + input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format) + with eval_mode(net): + result = net(input_tensor) + with eval_mode(net_fuser): + result_fuser = net_fuser(input_tensor) + + torch.testing.assert_close(result, result_fuser) + + class TestDynUNetDeepSupervision(unittest.TestCase): @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 4af2b47ba5..6971eb0463 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -13,10 +13,11 @@ from unittest import skipUnless import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -48,6 +49,26 @@ test_case[0]["spatial_dims"] = 2 # type: ignore TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) 
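The `instance_nvfuser` entry registered in `factories.py` is selected through the usual `norm_name` mechanism; the consistency test above builds the same `DynUNet` twice and only swaps the norm. A hedged usage sketch (the network hyper-parameters below are illustrative only; the factory falls back to `nn.InstanceNorm3d` with a warning when `apex` is not installed):

    import torch

    from monai.networks.nets import DynUNet

    net = DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        kernel_size=[3, 3, 3, 3],
        strides=[1, 2, 2, 2],
        upsample_kernel_size=[2, 2, 2],
        norm_name=("instance_nvfuser", {"affine": True}),
    ).to("cuda:0")  # the fused kernel requires a CUDA device and a non-Windows OS
    y = net(torch.randn(1, 1, 64, 64, 64, device="cuda:0"))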
+TEST_CASE_PATCHEMBED = [] +for patch_size in [2]: + for in_chans in [1, 4]: + for img_size in [96]: + for embed_dim in [6, 12]: + for norm_layer in [nn.LayerNorm]: + for nd in [2, 3]: + test_case = [ + { + "patch_size": (patch_size,) * nd, + "in_chans": in_chans, + "embed_dim": embed_dim, + "norm_layer": norm_layer, + "spatial_dims": nd, + }, + (2, in_chans, *([img_size] * nd)), + (2, embed_dim, *([img_size // patch_size] * nd)), + ] + TEST_CASE_PATCHEMBED.append(test_case) + class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK) @@ -115,5 +136,19 @@ def test_ill_arg(self): ) +class TestPatchEmbed(unittest.TestCase): + @parameterized.expand(TEST_CASE_PATCHEMBED) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = PatchEmbed(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + PatchEmbed(patch_size=(2, 2, 2), in_chans=1, embed_dim=24, norm_layer=nn.LayerNorm, spatial_dims=5) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py new file mode 100644 index 0000000000..0d48e99c44 --- /dev/null +++ b/tests/test_swin_unetr.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
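Unlike `PatchEmbeddingBlock`, the new `PatchEmbed` pads inputs whose spatial size is not a multiple of the patch size before applying the strided projection, so odd-sized inputs are still accepted. A short sketch of that behaviour:

    import torch
    import torch.nn as nn

    from monai.networks.blocks import PatchEmbed

    embed = PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    out = embed(torch.randn(1, 1, 63, 63, 63))  # 63 is padded to 64 before the stride-2 convolution
    print(out.shape)                            # torch.Size([1, 48, 32, 32, 32])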
+ +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.swin_unetr import SwinUNETR +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_SWIN_UNETR = [] +for attn_drop_rate in [0.4]: + for in_channels in [1]: + for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: + for out_channels in [2]: + for img_size in [64]: + for feature_size in [12]: + for norm_name in ["instance"]: + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size,) * nd, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + }, + (2, in_channels, *([img_size] * nd)), + (2, out_channels, *([img_size] * nd)), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_SWIN_UNETR.append(test_case) + + +class TestSWINUNETR(unittest.TestCase): + @parameterized.expand(TEST_CASE_SWIN_UNETR) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SwinUNETR(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(128, 128, 128), + feature_size=24, + norm_name="instance", + attn_drop_rate=4, + ) + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=2, img_size=(96, 96), feature_size=48, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=50, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(85, 85, 85), + feature_size=24, + norm_name="instance", + drop_rate=0.4, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py new file mode 100644 index 0000000000..c850ff4ce6 --- /dev/null +++ b/tests/test_weight_init.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
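`trunc_normal_` (exercised by `tests/test_weight_init.py` below) initializes a tensor in place and clamps the samples to `[a, b]`; it is the same routine `PatchEmbeddingBlock` now uses for its position embeddings. A minimal sketch:

    import torch

    from monai.networks.layers import trunc_normal_

    w = torch.empty(3, 16, 16)
    trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)      # in-place initialization
    assert float(w.min()) >= -2.0 and float(w.max()) <= 2.0  # samples are clamped to the cutoffs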
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import trunc_normal_ + +TEST_CASES = [ + [{"mean": 0.0, "std": 1.0, "a": 2, "b": 4}, (6, 12, 3, 1, 7)], + [{"mean": 0.3, "std": 0.6, "a": -1.0, "b": 1.3}, (1, 4, 4, 4)], + [{"mean": 0.1, "std": 0.4, "a": 1.3, "b": 1.8}, (5, 7, 7, 8, 9)], +] + +TEST_ERRORS = [ + [{"mean": 0.0, "std": 1.0, "a": 5, "b": 1.1}, (1, 1, 8, 8, 8)], + [{"mean": 0.3, "std": -0.1, "a": 1.0, "b": 2.0}, (8, 5, 2, 6, 9)], + [{"mean": 0.7, "std": 0.0, "a": 0.1, "b": 2.0}, (4, 12, 23, 17)], +] + + +class TestWeightInit(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + self.assertEqual(im.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + + +if __name__ == "__main__": + unittest.main()
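As a closing illustration, `DropPath` is meant to sit on the residual branch of a block (as `SwinTransformerBlock` does above): during training it zeroes the whole branch for a random subset of samples and, by default, rescales the kept ones. The block below is a hypothetical toy example, not part of this change:

    import torch
    import torch.nn as nn

    from monai.networks.layers import DropPath

    class ToyResidualBlock(nn.Module):
        """Hypothetical residual block showing where DropPath is applied."""

        def __init__(self, channels: int, drop_path_prob: float = 0.1) -> None:
            super().__init__()
            self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
            self.drop_path = DropPath(drop_path_prob)

        def forward(self, x):
            # the residual branch is dropped per sample during training
            return x + self.drop_path(self.conv(x))

    block = ToyResidualBlock(8).train()
    y = block(torch.randn(2, 8, 16, 16, 16))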