From f97cad8e27b9697ec6fa791a31e7293b88049d04 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 15:33:27 -0700 Subject: [PATCH 01/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 1048 +++++++++++++++++++++++++++++ tests/test_swin_unetr.py | 81 +++ 2 files changed, 1129 insertions(+) create mode 100644 monai/networks/nets/swin_unetr.py create mode 100644 tests/test_swin_unetr.py diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py new file mode 100644 index 0000000000..7bdb9ddb6b --- /dev/null +++ b/monai/networks/nets/swin_unetr.py @@ -0,0 +1,1048 @@ +# Copyright 2020 - 2022 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import reduce +from operator import mul +from typing import Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.layers import Conv +from monai.utils import ensure_tuple_rep, optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + + +class SwinUNETR(nn.Module): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + """ + + def __init__( + self, + img_size: Union[Sequence[int], int], + in_channels: int, + out_channels: int, + feature_size: int = 48, + depths: Tuple[int, int, int, int] = [2, 2, 2, 2], + num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_checkpoint: bool = False, + ) -> None: + """ + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + use_checkpoint: use gradient checkpointing for reduced memory usage. + + Examples:: + + # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. + >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) + + # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. + >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) + + # for 3D single channel input with size (96,96,96), 2-channel output and gradient checkpointing. 
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True) + + """ + + super().__init__() + + if isinstance(img_size, tuple) and len(img_size) == 2: + raise ValueError("3D Swin UNETR requires volumetric inputs.") + + spatial_dims = 3 + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(2, spatial_dims) + window_size = ensure_tuple_rep(7, spatial_dims) + for m, p in zip(img_size, patch_size): + for i in range(5): + if m % np.power(p, i + 1) != 0: + raise ValueError("img_size should be divisible by stage-wise image resolution.") + + if not (0 <= drop_rate <= 1): + raise ValueError("dropout rate should be between 0 and 1.") + + if not (0 <= attn_drop_rate <= 1): + raise ValueError("attention dropout rate should be between 0 and 1.") + + if not (0 <= dropout_path_rate <= 1): + raise ValueError("drop path rate should be between 0 and 1.") + + if feature_size % 12 != 0: + raise ValueError("feature_size should be divisible by 12.") + + self.swinViT = SwinTransformer( + in_chans=in_channels, + embed_dim=feature_size, + window_size=window_size, + patch_size=patch_size, + depths=depths, + num_heads=num_heads, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dropout_path_rate, + norm_layer=nn.LayerNorm, + use_checkpoint=use_checkpoint, + spatial_dims=spatial_dims, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, 
in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def load_from(self, weights): + + with torch.no_grad(): + self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) + self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) + for bname, block in self.swinViT.layers1[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers1") + self.swinViT.layers1[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.reduction.weight"] + ) + self.swinViT.layers1[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.weight"] + ) + self.swinViT.layers1[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers2[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers2") + self.swinViT.layers2[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.reduction.weight"] + ) + self.swinViT.layers2[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.weight"] + ) + self.swinViT.layers2[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers3[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers3") + self.swinViT.layers3[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.reduction.weight"] + ) + self.swinViT.layers3[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.weight"] + ) + self.swinViT.layers3[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers4[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers4") + self.swinViT.layers4[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.reduction.weight"] + ) + self.swinViT.layers4[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.weight"] + ) + self.swinViT.layers4[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.bias"] + ) + self.swinViT.norm.weight.copy_(weights["state_dict"]["module.norm.weight"]) + self.swinViT.norm.bias.copy_(weights["state_dict"]["module.norm.bias"]) + + def forward(self, x_in): + hidden_states_out = self.swinViT(x_in) + enc0 = self.encoder1(x_in) + enc1 = self.encoder2(hidden_states_out[0]) + enc2 = self.encoder3(hidden_states_out[1]) + enc3 = self.encoder4(hidden_states_out[2]) + dec4 = self.encoder10(hidden_states_out[4]) + dec3 = self.decoder5(dec4, hidden_states_out[3]) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + return logits + + +class Mlp(nn.Module): + """ + multi-layer perceptron based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + in_features: int, + hidden_features: type = None, + 
out_features: type = None, + act_layer: type = nn.GELU, + drop: float = 0.0, + ) -> None: + super().__init__() + """ + Args: + in_features: number of input feature channels. + hidden_features: number of hidden feature channels. + out_features: number of output feature channels. + act_layer: activation layer. + drop: dropout rate. + """ + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """window partition operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x: input tensor. + window_size: local window size. + """ + + B, D, H, W, C = x.shape + x = x.view( + B, + D // window_size[0], + window_size[0], + H // window_size[1], + window_size[1], + W // window_size[2], + window_size[2], + C, + ) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """window reverse operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + windows: windows tensor. + window_size: local window size. + B: batch size. + D: depth. + H: height. + W: width. + """ + + x = windows.view( + B, + D // window_size[0], + H // window_size[1], + W // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """Computing window size based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + x_size: input size. + window_size: local window size. + shift_size: window shifting size. + """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention(nn.Module): + """ + Window based multi-head self attention module with relative position bias based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + window_size: Tuple[int, int, int] = [7, 7, 7], + num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + qkv_bias: bool = False, + qk_scale: type = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """ + Args: + dim: number of feature channels. + window_size: local window size. + num_heads: number of attention heads. + qkv_bias: add a learnable bias to query, key, value. + qk_scale: override default qk scale of head_dim ** -0.5 if set. + attn_drop: attention dropout rate. + proj_drop: dropout rate of output. 
+ """ + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads) + ) + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = q @ k.transpose(-2, -1) + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1) + ].reshape(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + window_size: Tuple[int, int, int] = [7, 7, 7], + shift_size: Tuple[int, int, int] = [0, 0, 0], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: type = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: type = nn.GELU, + norm_layer: type = nn.LayerNorm, + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + shift_size: window shift size. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + qk_scale: override default qk scale of head_dim ** -0.5 if set. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: stochastic depth rate. + act_layer: activation layer. + norm_layer: normalization layer. 
+ use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = self.norm1(x) + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + x_windows = window_partition(shifted_x, window_size) + attn_windows = self.attn(x_windows, mask=attn_mask) + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def load_from(self, weights, n_block, layer): + ROOT = f"module.{layer}.0.blocks.{n_block}." 
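        # For example, with layer="layers1" and n_block="0", ROOT resolves to
        # "module.layers1.0.blocks.0.", so the names listed below index pretrained
        # checkpoint entries such as "module.layers1.0.blocks.0.attn.qkv.weight".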
+ block_names = [ + "norm1.weight", + "norm1.bias", + "attn.relative_position_bias_table", + "attn.relative_position_index", + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "norm2.weight", + "norm2.bias", + "mlp.fc1.weight", + "mlp.fc1.bias", + "mlp.fc2.weight", + "mlp.fc2.bias", + ] + with torch.no_grad(): + self.norm1.weight.copy_(weights["state_dict"][ROOT + block_names[0]]) + self.norm1.bias.copy_(weights["state_dict"][ROOT + block_names[1]]) + self.attn.relative_position_bias_table.copy_(weights["state_dict"][ROOT + block_names[2]]) + self.attn.relative_position_index.copy_(weights["state_dict"][ROOT + block_names[3]]) + self.attn.qkv.weight.copy_(weights["state_dict"][ROOT + block_names[4]]) + self.attn.qkv.bias.copy_(weights["state_dict"][ROOT + block_names[5]]) + self.attn.proj.weight.copy_(weights["state_dict"][ROOT + block_names[6]]) + self.attn.proj.bias.copy_(weights["state_dict"][ROOT + block_names[7]]) + self.norm2.weight.copy_(weights["state_dict"][ROOT + block_names[8]]) + self.norm2.bias.copy_(weights["state_dict"][ROOT + block_names[9]]) + self.mlp.fc1.weight.copy_(weights["state_dict"][ROOT + block_names[10]]) + self.mlp.fc1.bias.copy_(weights["state_dict"][ROOT + block_names[11]]) + self.mlp.fc2.weight.copy_(weights["state_dict"][ROOT + block_names[12]]) + self.mlp.fc2.bias.copy_(weights["state_dict"][ROOT + block_names[13]]) + + def forward(self, x, mask_matrix): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + return x + + +class PatchMerging(nn.Module): + """ + Patch merging layer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__(self, dim: int, norm_layer: type = nn.LayerNorm) -> None: + """ + Args: + dim: number of feature channels. + norm_layer: normalization layer. + """ + + super().__init__() + self.dim = dim + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8 * dim) + + def forward(self, x): + B, D, H, W, C = x.shape + pad_input = (H % 2 == 1) or (W % 2 == 1) or (D % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, D % 2, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + x = self.norm(x) + x = self.reduction(x) + return x + + +def compute_mask(D, H, W, window_size, shift_size, device): + + """Computing region masks based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Args: + D: depth value. + H: height value. + W: height value. + window_size: local window size. + device: device. 
+ """ + + img_mask = torch.zeros((1, D, H, W, 1), device=device) + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ + Basic Swin Transformer layer in one stage based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + depth: Tuple[int, int, int, int] = [2, 2, 2, 2], + num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + window_size: Tuple[int, int, int] = [7, 7, 7], + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: type = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + norm_layer: type = nn.LayerNorm, + downsample: type = None, + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + depths: number of layers in each stage. + num_heads: number of attention heads. + window_size: local window size. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + qk_scale: override default qk scale of head_dim ** -0.5 if set. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: stochastic depth rate. + norm_layer: normalization layer. + downsample: downsample layer at the end of the layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth) + ] + ) + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = rearrange(x, "b c d h w -> b d h w c") + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + + """Tensor initialization with truncated normal distribution. 
+ Based on: + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor`. + mean: the mean of the normal distribution. + std: the standard deviation of the normal distribution. + a: the minimum cutoff value. + b: the maximum cutoff value. + """ + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + + """Tensor initialization with truncated normal distribution. + Based on: + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class DropPath(nn.Module): + """Stochastic drop paths per sample for residual blocks. + Based on: + https://github.com/rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + """ + Args: + drop_prob: drop path probability. + scale_by_keep: scaling by non-dropped probability. + """ + + def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class PatchEmbed(nn.Module): + """ + Patch embedding block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + patch_size: Tuple[int, int, int] = [2, 2, 2], + in_chans: int = 1, + embed_dim: int = 96, + norm_layer: type = None, + spatial_dims: int = 3, + ) -> None: + """ + Args: + patch_size: dimension of patch size. + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + norm_layer: normalization layer. + spatial_dims: spatial dimension. 
+ """ + + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.proj = Conv[Conv.CONV, spatial_dims]( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + x = self.proj(x) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + return x + + +class SwinTransformer(nn.Module): + """ + Swin Transformer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + in_chans: int = 1, + embed_dim: int = 96, + window_size: Tuple[int, int, int] = [7, 7, 7], + patch_size: Tuple[int, int, int] = [2, 2, 2], + depths: Tuple[int, int, int, int] = [2, 2, 2, 2], + num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: type = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: type = nn.LayerNorm, + patch_norm: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + window_size: local window size. + patch_size: patch size. + depths: number of layers in each stage. + num_heads: number of attention heads. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + qk_scale: override default qk scale of head_dim ** -0.5 if set. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + drop_path_rate: stochastic depth rate. + norm_layer: normalization layer. + patch_norm: add normalization after patch embedding. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: spatial dimension. 
+ """ + + super().__init__() + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.window_size = window_size + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + spatial_dims=spatial_dims, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers1 = nn.ModuleList() + self.layers2 = nn.ModuleList() + self.layers3 = nn.ModuleList() + self.layers4 = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging, + use_checkpoint=use_checkpoint, + ) + if i_layer == 0: + self.layers1.append(layer) + elif i_layer == 1: + self.layers2.append(layer) + elif i_layer == 2: + self.layers3.append(layer) + elif i_layer == 3: + self.layers4.append(layer) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.norm = norm_layer(self.num_features) + + def proj_out(self, x): + n, ch, d, h, w = x.size() + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + return x + + def forward(self, x): + x0 = self.patch_embed(x) + x0 = self.pos_drop(x0) + x0_out = self.proj_out(x0) + x1 = self.layers1[0](x0.contiguous()) + x1_out = self.proj_out(x1) + x2 = self.layers2[0](x1.contiguous()) + x2_out = self.proj_out(x2) + x3 = self.layers3[0](x2.contiguous()) + x3_out = self.proj_out(x3) + x4 = self.layers4[0](x3.contiguous()) + x4_out = self.proj_out(x4) + return [x0_out, x1_out, x2_out, x3_out, x4_out] diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py new file mode 100644 index 0000000000..d2c01d3c69 --- /dev/null +++ b/tests/test_swin_unetr.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
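For reference, a minimal usage sketch of the SwinUNETR model added above (illustrative only; it assumes the optional einops dependency is installed and mirrors the shape checks in the test file below):

    >>> import torch
    >>> from monai.networks.nets.swin_unetr import SwinUNETR
    >>> net = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)
    >>> with torch.no_grad():
    ...     logits = net(torch.randn(1, 1, 96, 96, 96))
    >>> logits.shape
    torch.Size([1, 2, 96, 96, 96])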
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.swin_unetr import SwinUNETR + +TEST_CASE_UNETR = [] +for attn_drop_rate in [0.4]: + for in_channels in [1]: + for depth in [[2, 2, 4, 2]]: + for out_channels in [2]: + for img_size in [96, 128]: + for feature_size in [48]: + for norm_name in ["instance"]: + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size,) * 3, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + }, + (2, in_channels, *([img_size] * 3)), + (2, out_channels, *([img_size] * 3)), + ] + TEST_CASE_UNETR.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UNETR) + def test_shape(self, input_param, input_shape, expected_shape): + net = SwinUNETR(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(128, 128, 128), + feature_size=24, + norm_name="instance", + attn_drop_rate=4, + ) + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=2, img_size=(96, 96), feature_size=48, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=50, norm_name="instance") + + with self.assertRaises(ValueError): + SwinUNETR( + in_channels=1, + out_channels=3, + img_size=(85, 85, 85), + feature_size=24, + norm_name="instance", + drop_rate=0.4, + ) + + +if __name__ == "__main__": + unittest.main() From 34a1f9a70797d56fbd0f5c8a476c0a3d919f526d Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 16:02:23 -0700 Subject: [PATCH 02/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 119 +++++++++++++++--------------- 1 file changed, 60 insertions(+), 59 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 7bdb9ddb6b..9402dc41f5 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint as checkpoint from monai.networks.blocks.dynunet_block import UnetOutBlock -from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock from monai.networks.layers import Conv from monai.utils import ensure_tuple_rep, optional_import @@ -350,22 +350,22 @@ def window_partition(x, window_size): window_size: local window size. 
""" - B, D, H, W, C = x.shape + b, d, h, w, c = x.shape x = x.view( - B, - D // window_size[0], + b, + d // window_size[0], window_size[0], - H // window_size[1], + h // window_size[1], window_size[1], - W // window_size[2], + w // window_size[2], window_size[2], - C, + c, ) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), c) return windows -def window_reverse(windows, window_size, B, D, H, W): +def window_reverse(windows, window_size, b, d, h, w): """window reverse operation based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows " @@ -374,23 +374,23 @@ def window_reverse(windows, window_size, B, D, H, W): Args: windows: windows tensor. window_size: local window size. - B: batch size. - D: depth. - H: height. - W: width. + b: batch size. + d: depth. + h: height. + w: width. """ x = windows.view( - B, - D // window_size[0], - H // window_size[1], - W // window_size[2], + b, + d // window_size[0], + h // window_size[1], + w // window_size[2], window_size[0], window_size[1], window_size[2], -1, ) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) return x @@ -481,26 +481,26 @@ def __init__( self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index[:N, :N].reshape(-1) - ].reshape(N, N, -1) + self.relative_position_index[:n, :n].reshape(-1) + ].reshape(n, n, -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) + nw = mask.shape[0] + attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(b, n, c) x = self.proj(x) x = self.proj_drop(x) return x @@ -571,15 +571,15 @@ def __init__( self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward_part1(self, x, mask_matrix): - B, D, H, W, C = x.shape - window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + b, d, h, w, c = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) x = self.norm1(x) pad_l = pad_t = pad_d0 = 0 - pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] - pad_b = (window_size[1] - H % window_size[1]) % window_size[1] - pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] + pad_b = (window_size[1] - h % window_size[1]) % window_size[1] + pad_r = (window_size[2] - w % window_size[2]) % window_size[2] x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, 
pad_b, pad_d0, pad_d1)) - _, Dp, Hp, Wp, _ = x.shape + _, dp, hp, wp, _ = x.shape if any(i > 0 for i in shift_size): shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) @@ -589,14 +589,14 @@ def forward_part1(self, x, mask_matrix): attn_mask = None x_windows = window_partition(shifted_x, window_size) attn_windows = self.attn(x_windows, mask=attn_mask) - attn_windows = attn_windows.view(-1, *(window_size + (C,))) - shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) + attn_windows = attn_windows.view(-1, *(window_size + (c,))) + shifted_x = window_reverse(attn_windows, window_size, b, dp, hp, wp) if any(i > 0 for i in shift_size): x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) else: x = shifted_x if pad_d1 > 0 or pad_r > 0 or pad_b > 0: - x = x[:, :D, :H, :W, :].contiguous() + x = x[:, :d, :h, :w, :].contiguous() return x def forward_part2(self, x): @@ -671,10 +671,10 @@ def __init__(self, dim: int, norm_layer: type = nn.LayerNorm) -> None: self.norm = norm_layer(8 * dim) def forward(self, x): - B, D, H, W, C = x.shape - pad_input = (H % 2 == 1) or (W % 2 == 1) or (D % 2 == 1) + b, d, h, w, c = x.shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) if pad_input: - x = F.pad(x, (0, 0, 0, D % 2, 0, W % 2, 0, H % 2)) + x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) x0 = x[:, 0::2, 0::2, 0::2, :] x1 = x[:, 1::2, 0::2, 0::2, :] @@ -690,7 +690,7 @@ def forward(self, x): return x -def compute_mask(D, H, W, window_size, shift_size, device): +def compute_mask(d, h, w, window_size, shift_size, device): """Computing region masks based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows @@ -698,14 +698,15 @@ def compute_mask(D, H, W, window_size, shift_size, device): https://github.com/microsoft/Swin-Transformer Args: - D: depth value. - H: height value. - W: height value. + d: depth value. + h: height value. + w: height value. window_size: local window size. + shift_size: shift size. device: device. 
""" - img_mask = torch.zeros((1, D, H, W, 1), device=device) + img_mask = torch.zeros((1, d, h, w, 1), device=device) cnt = 0 for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): @@ -789,16 +790,16 @@ def __init__( self.downsample = downsample(dim=dim, norm_layer=norm_layer) def forward(self, x): - B, C, D, H, W = x.shape - window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + b, c, d, h, w = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) x = rearrange(x, "b c d h w -> b d h w c") - Dp = int(np.ceil(D / window_size[0])) * window_size[0] - Hp = int(np.ceil(H / window_size[1])) * window_size[1] - Wp = int(np.ceil(W / window_size[2])) * window_size[2] - attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + dp = int(np.ceil(d / window_size[0])) * window_size[0] + hp = int(np.ceil(h / window_size[1])) * window_size[1] + wp = int(np.ceil(w / window_size[2])) * window_size[2] + attn_mask = compute_mask(dp, hp, wp, window_size, shift_size, x.device) for blk in self.blocks: x = blk(x, attn_mask) - x = x.view(B, D, H, W, -1) + x = x.view(b, d, h, w, -1) if self.downsample is not None: x = self.downsample(x) x = rearrange(x, "b d h w c -> b c d h w") @@ -918,19 +919,19 @@ def __init__( self.norm = None def forward(self, x): - _, _, D, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if D % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + _, _, d, h, w = x.size() + if w % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2])) + if h % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1])) + if d % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0])) x = self.proj(x) if self.norm is not None: - D, Wh, Ww = x.size(2), x.size(3), x.size(4) + d, wh, ww = x.size(2), x.size(3), x.size(4) x = x.flatten(2).transpose(1, 2) x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww) return x From c292d2c5c60c1b79032c26c83d1de140fedf95bf Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 16:08:01 -0700 Subject: [PATCH 03/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 9402dc41f5..a4dbbaf93d 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -603,7 +603,7 @@ def forward_part2(self, x): return self.drop_path(self.mlp(self.norm2(x))) def load_from(self, weights, n_block, layer): - ROOT = f"module.{layer}.0.blocks.{n_block}." + root = f"module.{layer}.0.blocks.{n_block}." 
block_names = [ "norm1.weight", "norm1.bias", @@ -621,20 +621,20 @@ def load_from(self, weights, n_block, layer): "mlp.fc2.bias", ] with torch.no_grad(): - self.norm1.weight.copy_(weights["state_dict"][ROOT + block_names[0]]) - self.norm1.bias.copy_(weights["state_dict"][ROOT + block_names[1]]) - self.attn.relative_position_bias_table.copy_(weights["state_dict"][ROOT + block_names[2]]) - self.attn.relative_position_index.copy_(weights["state_dict"][ROOT + block_names[3]]) - self.attn.qkv.weight.copy_(weights["state_dict"][ROOT + block_names[4]]) - self.attn.qkv.bias.copy_(weights["state_dict"][ROOT + block_names[5]]) - self.attn.proj.weight.copy_(weights["state_dict"][ROOT + block_names[6]]) - self.attn.proj.bias.copy_(weights["state_dict"][ROOT + block_names[7]]) - self.norm2.weight.copy_(weights["state_dict"][ROOT + block_names[8]]) - self.norm2.bias.copy_(weights["state_dict"][ROOT + block_names[9]]) - self.mlp.fc1.weight.copy_(weights["state_dict"][ROOT + block_names[10]]) - self.mlp.fc1.bias.copy_(weights["state_dict"][ROOT + block_names[11]]) - self.mlp.fc2.weight.copy_(weights["state_dict"][ROOT + block_names[12]]) - self.mlp.fc2.bias.copy_(weights["state_dict"][ROOT + block_names[13]]) + self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) + self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) + self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) + self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) + self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) + self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) + self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) + self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) + self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) + self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) + self.mlp.fc1.weight.copy_(weights["state_dict"][root + block_names[10]]) + self.mlp.fc1.bias.copy_(weights["state_dict"][root + block_names[11]]) + self.mlp.fc2.weight.copy_(weights["state_dict"][root + block_names[12]]) + self.mlp.fc2.bias.copy_(weights["state_dict"][root + block_names[13]]) def forward(self, x, mask_matrix): shortcut = x From 83c43deb7d51d12823c8539131fc86e617bf1e3a Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 16:39:57 -0700 Subject: [PATCH 04/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 43 ++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index a4dbbaf93d..63d39f2bcc 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -41,8 +41,8 @@ def __init__( in_channels: int, out_channels: int, feature_size: int = 48, - depths: Tuple[int, int, int, int] = [2, 2, 2, 2], - num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + depths: Sequence[int] = None, + num_heads: Sequence[int] = None, norm_name: Union[Tuple, str] = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, @@ -85,6 +85,12 @@ def __init__( img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(2, spatial_dims) window_size = ensure_tuple_rep(7, spatial_dims) + if depths is None: + depths = [2, 2, 2, 2] + + if num_heads is None: + num_heads = [3, 6, 12, 24] + for m, p in zip(img_size, patch_size): for i in range(5): if m % 
np.power(p, i + 1) != 0: @@ -308,7 +314,7 @@ class Mlp(nn.Module): def __init__( self, in_features: int, - hidden_features: type = None, + hidden_features: int, out_features: type = None, act_layer: type = nn.GELU, drop: float = 0.0, @@ -432,8 +438,8 @@ class WindowAttention(nn.Module): def __init__( self, dim: int, - window_size: Tuple[int, int, int] = [7, 7, 7], - num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + num_heads: int, + window_size: Sequence[int] = None, qkv_bias: bool = False, qk_scale: type = None, attn_drop: float = 0.0, @@ -442,8 +448,8 @@ def __init__( """ Args: dim: number of feature channels. - window_size: local window size. num_heads: number of attention heads. + window_size: local window size. qkv_bias: add a learnable bias to query, key, value. qk_scale: override default qk scale of head_dim ** -0.5 if set. attn_drop: attention dropout rate. @@ -517,9 +523,9 @@ class SwinTransformerBlock(nn.Module): def __init__( self, dim: int, - num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], - window_size: Tuple[int, int, int] = [7, 7, 7], - shift_size: Tuple[int, int, int] = [0, 0, 0], + num_heads: int, + window_size: Sequence[int] = None, + shift_size: Sequence[int] = None, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: type = None, @@ -731,15 +737,15 @@ class BasicLayer(nn.Module): def __init__( self, dim: int, - depth: Tuple[int, int, int, int] = [2, 2, 2, 2], - num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], - window_size: Tuple[int, int, int] = [7, 7, 7], + depth: int, + num_heads: int, + window_size: Sequence[int] = None, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_scale: type = None, drop: float = 0.0, attn_drop: float = 0.0, - drop_path: float = 0.0, + drop_path: list = None, norm_layer: type = nn.LayerNorm, downsample: type = None, use_checkpoint: bool = False, @@ -892,7 +898,7 @@ class PatchEmbed(nn.Module): def __init__( self, - patch_size: Tuple[int, int, int] = [2, 2, 2], + patch_size: Sequence[int], in_chans: int = 1, embed_dim: int = 96, norm_layer: type = None, @@ -947,10 +953,10 @@ def __init__( self, in_chans: int = 1, embed_dim: int = 96, - window_size: Tuple[int, int, int] = [7, 7, 7], - patch_size: Tuple[int, int, int] = [2, 2, 2], - depths: Tuple[int, int, int, int] = [2, 2, 2, 2], - num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], + window_size: Sequence[int] = None, + patch_size: Sequence[int] = None, + depths: Sequence[int] = None, + num_heads: Sequence[int] = None, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: type = None, @@ -1006,6 +1012,7 @@ def __init__( dim=int(embed_dim * 2**i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], + window_size=self.window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, From 384d1e52613de516c7d766f62466c3e5a0b5e97f Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 18:00:59 -0700 Subject: [PATCH 05/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/test_swin_unetr.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index d2c01d3c69..ba790fc54b 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -10,19 +10,23 @@ # limitations under the License. 
import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets.swin_unetr import SwinUNETR +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASE_UNETR = [] for attn_drop_rate in [0.4]: for in_channels in [1]: - for depth in [[2, 2, 4, 2]]: + for depth in [[2, 2, 4, 2], [1, 2, 1, 1]]: for out_channels in [2]: - for img_size in [96, 128]: + for img_size in [64]: for feature_size in [48]: for norm_name in ["instance"]: test_case = [ @@ -43,6 +47,7 @@ class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_UNETR) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = SwinUNETR(**input_param) with eval_mode(net): From 177461004f8cbdf2edb6aa5eba2681fccff24cd5 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 20:07:41 -0700 Subject: [PATCH 06/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 68 ++++++++++--------------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 63d39f2bcc..a736950923 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -40,9 +40,9 @@ def __init__( img_size: Union[Sequence[int], int], in_channels: int, out_channels: int, + depths: Sequence[int] = (2, 2, 2, 2), + num_heads: Sequence[int] = (3, 6, 12, 24), feature_size: int = 48, - depths: Sequence[int] = None, - num_heads: Sequence[int] = None, norm_name: Union[Tuple, str] = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, @@ -85,11 +85,6 @@ def __init__( img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(2, spatial_dims) window_size = ensure_tuple_rep(7, spatial_dims) - if depths is None: - depths = [2, 2, 2, 2] - - if num_heads is None: - num_heads = [3, 6, 12, 24] for m, p in zip(img_size, patch_size): for i in range(5): @@ -117,7 +112,6 @@ def __init__( num_heads=num_heads, mlp_ratio=4.0, qkv_bias=True, - qk_scale=None, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dropout_path_rate, @@ -311,29 +305,20 @@ class Mlp(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: type = None, - act_layer: type = nn.GELU, - drop: float = 0.0, - ) -> None: + def __init__(self, in_features: int, hidden_features: int, act_layer: type = nn.GELU, drop: float = 0.0) -> None: super().__init__() """ Args: in_features: number of input feature channels. hidden_features: number of hidden feature channels. - out_features: number of output feature channels. act_layer: activation layer. drop: dropout rate. """ - out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, in_features) self.drop = nn.Dropout(drop) def forward(self, x): @@ -439,9 +424,8 @@ def __init__( self, dim: int, num_heads: int, - window_size: Sequence[int] = None, + window_size: Sequence[int], qkv_bias: bool = False, - qk_scale: type = None, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: @@ -451,7 +435,6 @@ def __init__( num_heads: number of attention heads. window_size: local window size. 
qkv_bias: add a learnable bias to query, key, value. - qk_scale: override default qk scale of head_dim ** -0.5 if set. attn_drop: attention dropout rate. proj_drop: dropout rate of output. """ @@ -461,7 +444,7 @@ def __init__( self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 + self.scale = head_dim**-0.5 self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads) ) @@ -524,11 +507,10 @@ def __init__( self, dim: int, num_heads: int, - window_size: Sequence[int] = None, - shift_size: Sequence[int] = None, + window_size: Sequence[int], + shift_size: Sequence[int], mlp_ratio: float = 4.0, qkv_bias: bool = True, - qk_scale: type = None, drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, @@ -544,7 +526,6 @@ def __init__( shift_size: window shift size. mlp_ratio: ratio of mlp hidden dim to embedding dim. qkv_bias: add a learnable bias to query, key, value. - qk_scale: override default qk scale of head_dim ** -0.5 if set. drop: dropout rate. attn_drop: attention dropout rate. drop_path: stochastic depth rate. @@ -566,7 +547,6 @@ def __init__( window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, - qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) @@ -739,15 +719,14 @@ def __init__( dim: int, depth: int, num_heads: int, - window_size: Sequence[int] = None, + window_size: Sequence[int], + drop_path: list, mlp_ratio: float = 4.0, qkv_bias: bool = False, - qk_scale: type = None, drop: float = 0.0, attn_drop: float = 0.0, - drop_path: list = None, - norm_layer: type = nn.LayerNorm, - downsample: type = None, + norm_layer: isinstance = nn.LayerNorm, + downsample: isinstance = None, use_checkpoint: bool = False, ) -> None: """ @@ -756,12 +735,11 @@ def __init__( depths: number of layers in each stage. num_heads: number of attention heads. window_size: local window size. + drop_path: stochastic depth rate. mlp_ratio: ratio of mlp hidden dim to embedding dim. qkv_bias: add a learnable bias to query, key, value. - qk_scale: override default qk scale of head_dim ** -0.5 if set. drop: dropout rate. attn_drop: attention dropout rate. - drop_path: stochastic depth rate. norm_layer: normalization layer. downsample: downsample layer at the end of the layer. use_checkpoint: use gradient checkpointing for reduced memory usage. @@ -781,7 +759,6 @@ def __init__( shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, @@ -951,19 +928,18 @@ class SwinTransformer(nn.Module): def __init__( self, - in_chans: int = 1, - embed_dim: int = 96, - window_size: Sequence[int] = None, - patch_size: Sequence[int] = None, - depths: Sequence[int] = None, - num_heads: Sequence[int] = None, + in_chans: int, + embed_dim: int, + window_size: Sequence[int], + patch_size: Sequence[int], + depths: Sequence[int], + num_heads: Sequence[int], mlp_ratio: float = 4.0, qkv_bias: bool = True, - qk_scale: type = None, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - norm_layer: type = nn.LayerNorm, + norm_layer: isinstance = nn.LayerNorm, patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, @@ -978,7 +954,6 @@ def __init__( num_heads: number of attention heads. mlp_ratio: ratio of mlp hidden dim to embedding dim. 
qkv_bias: add a learnable bias to query, key, value. - qk_scale: override default qk scale of head_dim ** -0.5 if set. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. drop_path_rate: stochastic depth rate. @@ -1013,12 +988,11 @@ def __init__( depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=self.window_size, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging, use_checkpoint=use_checkpoint, From a6dec2d1882981dcafd8e7687a11eec53344da19 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 22:04:31 -0700 Subject: [PATCH 07/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index a736950923..5244801db1 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -12,13 +12,14 @@ import math from functools import reduce from operator import mul -from typing import Sequence, Tuple, Union +from typing import Sequence, Tuple, Type, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint +from torch.nn import GELU, LayerNorm from monai.networks.blocks.dynunet_block import UnetOutBlock from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock @@ -305,7 +306,9 @@ class Mlp(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__(self, in_features: int, hidden_features: int, act_layer: type = nn.GELU, drop: float = 0.0) -> None: + def __init__( + self, in_features: int, hidden_features: int, act_layer: Type[GELU] = nn.GELU, drop: float = 0.0 + ) -> None: super().__init__() """ Args: @@ -514,8 +517,8 @@ def __init__( drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, - act_layer: type = nn.GELU, - norm_layer: type = nn.LayerNorm, + act_layer: Type[GELU] = nn.GELU, + norm_layer: Type[LayerNorm] = nn.LayerNorm, use_checkpoint: bool = False, ) -> None: """ @@ -644,7 +647,7 @@ class PatchMerging(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__(self, dim: int, norm_layer: type = nn.LayerNorm) -> None: + def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm) -> None: """ Args: dim: number of feature channels. 
@@ -725,8 +728,8 @@ def __init__( qkv_bias: bool = False, drop: float = 0.0, attn_drop: float = 0.0, - norm_layer: isinstance = nn.LayerNorm, - downsample: isinstance = None, + norm_layer: Type[LayerNorm] = nn.LayerNorm, + downsample: isinstance = None, # type: ignore use_checkpoint: bool = False, ) -> None: """ @@ -939,7 +942,7 @@ def __init__( drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - norm_layer: isinstance = nn.LayerNorm, + norm_layer: Type[LayerNorm] = nn.LayerNorm, patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, From ea2e6aa6ac5aba041fcd0cf7235670836891f898 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 4 Apr 2022 22:41:37 -0700 Subject: [PATCH 08/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 5244801db1..f804715c59 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -518,7 +518,7 @@ def __init__( attn_drop: float = 0.0, drop_path: float = 0.0, act_layer: Type[GELU] = nn.GELU, - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore use_checkpoint: bool = False, ) -> None: """ @@ -647,7 +647,7 @@ class PatchMerging(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm) -> None: + def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm) -> None: # type: ignore """ Args: dim: number of feature channels. @@ -728,7 +728,7 @@ def __init__( qkv_bias: bool = False, drop: float = 0.0, attn_drop: float = 0.0, - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore downsample: isinstance = None, # type: ignore use_checkpoint: bool = False, ) -> None: @@ -881,7 +881,7 @@ def __init__( patch_size: Sequence[int], in_chans: int = 1, embed_dim: int = 96, - norm_layer: type = None, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore spatial_dims: int = 3, ) -> None: """ @@ -942,7 +942,7 @@ def __init__( drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - norm_layer: Type[LayerNorm] = nn.LayerNorm, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, From 639b20c9c68f93f436d7d81015b5312fe32dd1fb Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 06:54:00 -0700 Subject: [PATCH 09/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index f804715c59..aca5cb51fa 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -976,7 +976,7 @@ def __init__( patch_size=self.patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None, + norm_layer=norm_layer if self.patch_norm else None, # type: ignore spatial_dims=spatial_dims, ) self.pos_drop = nn.Dropout(p=drop_rate) From 9f95464039bfdb1920c7f8884ffb0e6829aa2f22 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:18:34 -0700 Subject: [PATCH 10/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/test_swin_unetr.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index ba790fc54b..3e48804ebd 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,10 +24,10 @@ TEST_CASE_UNETR = [] for attn_drop_rate in [0.4]: for in_channels in [1]: - for depth in [[2, 2, 4, 2], [1, 2, 1, 1]]: + for depth in [[2, 2, 1, 1], [1, 2, 1, 1]]: for out_channels in [2]: for img_size in [64]: - for feature_size in [48]: + for feature_size in [12]: for norm_name in ["instance"]: test_case = [ { From 279f7e1f9672d1fed1caea339ae2594620a2e34f Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:25:08 -0700 Subject: [PATCH 11/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/min_tests.py | 1 + tests/test_swin_unetr.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index e0710a93ec..36b1d6f26b 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -159,6 +159,7 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", + "test_swin_unetr", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 3e48804ebd..1680567472 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,7 +24,7 @@ TEST_CASE_UNETR = [] for attn_drop_rate in [0.4]: for in_channels in [1]: - for depth in [[2, 2, 1, 1], [1, 2, 1, 1]]: + for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: for out_channels in [2]: for img_size in [64]: for feature_size in [12]: From 0e365adbdb29bedee65b6db1ba465ebcc5a35c6d Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:27:04 -0700 Subject: [PATCH 12/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/min_tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 36b1d6f26b..377a1d8820 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -159,6 +159,9 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", + "test_bundle_verify_metadata", + "test_bundle_verify_net", + "test_bundle_ckpt_export", "test_swin_unetr", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 4c1653ae27aa455c10e5aabc6d9398470c698136 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:29:29 -0700 Subject: [PATCH 13/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/min_tests.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index 377a1d8820..e0710a93ec 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -159,10 +159,6 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", - "test_bundle_verify_metadata", - "test_bundle_verify_net", - "test_bundle_ckpt_export", - "test_swin_unetr", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 29a8603048967a2d2643a39c8ece108e9c6372a1 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:48:23 -0700 Subject: [PATCH 14/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/min_tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index e0710a93ec..b6ffc0c620 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -95,6 +95,7 @@ def run_testsuit(): 
"test_integration_unet_2d", "test_integration_workflows", "test_integration_workflows_gan", + "test_integration_bundle_run", "test_invertd", "test_iterable_dataset", "test_keep_largest_connected_component", @@ -159,6 +160,10 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", + "test_bundle_verify_metadata", + "test_bundle_verify_net", + "test_bundle_ckpt_export", + "test_swin_unetr", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From b4153dfef56018fc2381b48ad61c25a295367de2 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 5 Apr 2022 08:49:39 -0700 Subject: [PATCH 15/26] add swin_unetr model Signed-off-by: ahatamizadeh --- tests/min_tests.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index b6ffc0c620..e0710a93ec 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -95,7 +95,6 @@ def run_testsuit(): "test_integration_unet_2d", "test_integration_workflows", "test_integration_workflows_gan", - "test_integration_bundle_run", "test_invertd", "test_iterable_dataset", "test_keep_largest_connected_component", @@ -160,10 +159,6 @@ def run_testsuit(): "test_zoomd", "test_prepare_batch_default_dist", "test_parallel_execution_dist", - "test_bundle_verify_metadata", - "test_bundle_verify_net", - "test_bundle_ckpt_export", - "test_swin_unetr", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From d83e10d48bd09013ba60de0096e5f4bd09f3f327 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Sun, 10 Apr 2022 10:49:57 -0700 Subject: [PATCH 16/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/blocks/mlp.py | 37 +++++++++++++++++--- monai/networks/nets/swin_unetr.py | 56 ++++++------------------------- 2 files changed, 43 insertions(+), 50 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index a1728365cf..0feeb044f3 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -9,8 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple, Union + import torch.nn as nn +from monai.networks.layers import get_act_layer +from monai.utils import look_up_option + +SUPPORTED_DROPOUT_MODE = {"vit", "swin"} + class MLPBlock(nn.Module): """ @@ -18,12 +25,26 @@ class MLPBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None: + def __init__( + self, + hidden_size: int, + mlp_dim: int, + dropout_rate: float = 0.0, + act: Union[Tuple, str] = "GELU", + dropout_mode="vit", + ) -> None: """ Args: hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. + mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used. dropout_rate: faction of the input units to drop. + act: activation type and arguments. Defaults to GELU. + dropout_mode: dropout mode, can be "vit" or "swin". 
+ "vit" mode uses two dropout instances as implemented in + https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 + "swin" corresponds to one instance as implemented in + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 + """ @@ -31,12 +52,18 @@ def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") - + mlp_dim = mlp_dim or hidden_size self.linear1 = nn.Linear(hidden_size, mlp_dim) self.linear2 = nn.Linear(mlp_dim, hidden_size) - self.fn = nn.GELU() + self.fn = get_act_layer(act) self.drop1 = nn.Dropout(dropout_rate) - self.drop2 = nn.Dropout(dropout_rate) + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) + if dropout_opt == "vit": + self.drop2 = nn.Dropout(dropout_rate) + elif dropout_opt == "swin": + self.drop2 = self.drop1 + else: + raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") def forward(self, x): x = self.fn(self.linear1(x)) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index aca5cb51fa..ecfc3cd5e1 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -22,6 +22,7 @@ from torch.nn import GELU, LayerNorm from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.blocks.mlp import MLPBlock as Mlp from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock from monai.networks.layers import Conv from monai.utils import ensure_tuple_rep, optional_import @@ -298,41 +299,6 @@ def forward(self, x_in): return logits -class Mlp(nn.Module): - """ - multi-layer perceptron based on: "Liu et al., - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows - " - https://github.com/microsoft/Swin-Transformer - """ - - def __init__( - self, in_features: int, hidden_features: int, act_layer: Type[GELU] = nn.GELU, drop: float = 0.0 - ) -> None: - super().__init__() - """ - Args: - in_features: number of input feature channels. - hidden_features: number of hidden feature channels. - act_layer: activation layer. - drop: dropout rate. 
- """ - - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, in_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - def window_partition(x, window_size): """window partition operation based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows @@ -517,7 +483,7 @@ def __init__( drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, - act_layer: Type[GELU] = nn.GELU, + act_layer: str = "GELU", norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore use_checkpoint: bool = False, ) -> None: @@ -557,7 +523,7 @@ def __init__( self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") def forward_part1(self, x, mask_matrix): b, d, h, w, c = x.shape @@ -604,10 +570,10 @@ def load_from(self, weights, n_block, layer): "attn.proj.bias", "norm2.weight", "norm2.bias", - "mlp.fc1.weight", - "mlp.fc1.bias", - "mlp.fc2.weight", - "mlp.fc2.bias", + "mlp.linear1.weight", + "mlp.linear1.bias", + "mlp.linear2.weight", + "mlp.linear2.bias", ] with torch.no_grad(): self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) @@ -620,10 +586,10 @@ def load_from(self, weights, n_block, layer): self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) - self.mlp.fc1.weight.copy_(weights["state_dict"][root + block_names[10]]) - self.mlp.fc1.bias.copy_(weights["state_dict"][root + block_names[11]]) - self.mlp.fc2.weight.copy_(weights["state_dict"][root + block_names[12]]) - self.mlp.fc2.bias.copy_(weights["state_dict"][root + block_names[13]]) + self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) + self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) + self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) + self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) def forward(self, x, mask_matrix): shortcut = x From 995250b9b1da9929d30bbfca9231412a3ff7ca3b Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Sun, 10 Apr 2022 10:58:14 -0700 Subject: [PATCH 17/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index ecfc3cd5e1..8b4c07945e 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from torch.nn import GELU, LayerNorm +from torch.nn import LayerNorm from monai.networks.blocks.dynunet_block import UnetOutBlock from monai.networks.blocks.mlp import MLPBlock as Mlp From ce9a374b9bf1319bcf6f33950d0b1b40e1e7ac13 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Sun, 10 Apr 2022 11:19:56 -0700 Subject: [PATCH 18/26] add swin_unetr model Signed-off-by: ahatamizadeh --- 
monai/networks/nets/swin_unetr.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 8b4c07945e..2544772e4a 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -49,6 +49,7 @@ def __init__( drop_rate: float = 0.0, attn_drop_rate: float = 0.0, dropout_path_rate: float = 0.0, + normalize: bool = False, use_checkpoint: bool = False, ) -> None: """ @@ -63,6 +64,7 @@ def __init__( drop_rate: dropout rate. attn_drop_rate: attention dropout rate. dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. use_checkpoint: use gradient checkpointing for reduced memory usage. Examples:: @@ -105,6 +107,8 @@ def __init__( if feature_size % 12 != 0: raise ValueError("feature_size should be divisible by 12.") + self.normalize = normalize + self.swinViT = SwinTransformer( in_chans=in_channels, embed_dim=feature_size, @@ -284,7 +288,7 @@ def load_from(self, weights): self.swinViT.norm.bias.copy_(weights["state_dict"]["module.norm.bias"]) def forward(self, x_in): - hidden_states_out = self.swinViT(x_in) + hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) enc1 = self.encoder2(hidden_states_out[0]) enc2 = self.encoder3(hidden_states_out[1]) @@ -977,23 +981,24 @@ def __init__( self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.norm = norm_layer(self.num_features) - def proj_out(self, x): - n, ch, d, h, w = x.size() - x = rearrange(x, "n c d h w -> n d h w c") - x = F.layer_norm(x, [ch]) - x = rearrange(x, "n d h w c -> n c d h w") + def proj_out(self, x, normalize=False): + if normalize: + n, ch, d, h, w = x.size() + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") return x - def forward(self, x): + def forward(self, x, normalize=False): x0 = self.patch_embed(x) x0 = self.pos_drop(x0) - x0_out = self.proj_out(x0) + x0_out = self.proj_out(x0, normalize) x1 = self.layers1[0](x0.contiguous()) - x1_out = self.proj_out(x1) + x1_out = self.proj_out(x1, normalize) x2 = self.layers2[0](x1.contiguous()) - x2_out = self.proj_out(x2) + x2_out = self.proj_out(x2, normalize) x3 = self.layers3[0](x2.contiguous()) - x3_out = self.proj_out(x3) + x3_out = self.proj_out(x3, normalize) x4 = self.layers4[0](x3.contiguous()) - x4_out = self.proj_out(x4) + x4_out = self.proj_out(x4, normalize) return [x0_out, x1_out, x2_out, x3_out, x4_out] From ec96fc3ae4ce2e0e44e70a39ba8a205453313c83 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 2 May 2022 01:25:05 -0700 Subject: [PATCH 19/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/patchembedding.py | 109 ++++-- monai/networks/layers/__init__.py | 2 + monai/networks/layers/drop_path.py | 45 +++ monai/networks/layers/weight_init.py | 66 ++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/swin_unetr.py | 478 +++++++++++------------- tests/test_drop_path.py | 43 +++ tests/test_patchembedding.py | 37 +- tests/test_swin_unetr.py | 37 +- tests/test_weight_init.py | 47 +++ 11 files changed, 573 insertions(+), 294 deletions(-) create mode 100644 monai/networks/layers/drop_path.py create mode 100644 monai/networks/layers/weight_init.py create mode 100644 tests/test_drop_path.py create mode 100644 tests/test_weight_init.py diff --git a/monai/networks/blocks/__init__.py 
b/monai/networks/blocks/__init__.py index 0fdc944760..b6328734b0 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -20,7 +20,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .mlp import MLPBlock -from .patchembedding import PatchEmbeddingBlock +from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 4c7263c6d5..f02f6342e8 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import math -from typing import Sequence, Union +from typing import Sequence, Type, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm -from monai.networks.layers import Conv +from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option @@ -98,34 +98,18 @@ def __init__( ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def trunc_normal_(self, tensor, mean, std, a, b): - # From PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - def forward(self, x): x = self.patch_embeddings(x) if self.pos_embed == "conv": @@ -133,3 +117,84 @@ def forward(self, x): embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings + + +class PatchEmbed(nn.Module): + """ + Patch embedding block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + + Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if + specified (3) position embedding is not used. 
+ + Example:: + + >>> from monai.networks.blocks import PatchEmbed + >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3) + """ + + def __init__( + self, + patch_size: Union[Sequence[int], int] = 2, + in_chans: int = 1, + embed_dim: int = 48, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + spatial_dims: int = 3, + ) -> None: + """ + Args: + patch_size: dimension of patch size. + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + norm_layer: normalization layer. + spatial_dims: spatial dimension. + """ + + super().__init__() + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.proj = Conv[Conv.CONV, spatial_dims]( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + _, _, d, h, w = x_shape + if w % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2])) + if h % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1])) + if d % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0])) + + elif len(x_shape) == 4: + _, _, h, w = x.size() + if w % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1])) + if h % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0])) + + x = self.proj(x) + if self.norm is not None: + x_shape = x.size() + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if len(x_shape) == 5: + d, wh, ww = x_shape[2], x_shape[3], x_shape[4] + x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww) + elif len(x_shape) == 4: + wh, ww = x_shape[2], x_shape[3] + x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww) + return x diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 5115c00af3..f122dccee6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding +from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args from .filtering import BilateralFilter, PHLFilter from .gmm import GaussianMixtureModel @@ -27,3 +28,4 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/drop_path.py b/monai/networks/layers/drop_path.py new file mode 100644 index 0000000000..f91f65d682 --- /dev/null +++ b/monai/networks/layers/drop_path.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class DropPath(nn.Module): + """Stochastic drop paths per sample for residual blocks. + Based on: + https://github.com/rwightman/pytorch-image-models + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: + """ + Args: + drop_prob: drop path probability. + scale_by_keep: scaling by non-dropped probability. + """ + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + if not (0 <= drop_prob <= 1): + raise ValueError("Drop path prob should be between 0 and 1.") + + def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py new file mode 100644 index 0000000000..2217bdb42c --- /dev/null +++ b/monai/networks/layers/weight_init.py @@ -0,0 +1,66 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + + """Tensor initialization with truncated normal distribution. + Based on: + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor`. + mean: the mean of the normal distribution. + std: the standard deviation of the normal distribution. + a: the minimum cutoff value. + b: the maximum cutoff value. + """ + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + + """Tensor initialization with truncated normal distribution. 
+ Based on: + https://github.com/rwightman/pytorch-image-models + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + + if not std > 0: + raise ValueError("the standard deviation should be greater than zero.") + + if a >= b: + raise ValueError("minimum cutoff value (a) should be greater than maximum cutoff value (b).") + + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 22fcef4903..d902ebff99 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -79,6 +79,7 @@ seresnext50, seresnext101, ) +from .swin_unetr import SwinUNETR from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex from .unet import UNet, Unet diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 2544772e4a..439009b57c 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -9,9 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from functools import reduce -from operator import mul from typing import Sequence, Tuple, Type, Union import numpy as np @@ -21,10 +18,9 @@ import torch.utils.checkpoint as checkpoint from torch.nn import LayerNorm -from monai.networks.blocks.dynunet_block import UnetOutBlock -from monai.networks.blocks.mlp import MLPBlock as Mlp -from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock -from monai.networks.layers import Conv +from monai.networks.blocks import MLPBlock as Mlp +from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock +from monai.networks.layers import DropPath, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import rearrange, _ = optional_import("einops", name="rearrange") @@ -51,6 +47,7 @@ def __init__( dropout_path_rate: float = 0.0, normalize: bool = False, use_checkpoint: bool = False, + spatial_dims: int = 3, ) -> None: """ Args: @@ -66,6 +63,7 @@ def __init__( dropout_path_rate: drop path rate. normalize: normalize output intermediate features in each stage. use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. Examples:: @@ -75,25 +73,24 @@ def __init__( # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) - # for 3D single channel input with size (96,96,96), 2-channel output and gradient checkpointing. - >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True) + # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. 
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) """ super().__init__() - if isinstance(img_size, tuple) and len(img_size) == 2: - raise ValueError("3D Swin UNETR requires volumetric inputs.") - - spatial_dims = 3 img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(2, spatial_dims) window_size = ensure_tuple_rep(7, spatial_dims) + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + for m, p in zip(img_size, patch_size): for i in range(5): if m % np.power(p, i + 1) != 0: - raise ValueError("img_size should be divisible by stage-wise image resolution.") + raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") if not (0 <= drop_rate <= 1): raise ValueError("dropout rate should be between 0 and 1.") @@ -229,12 +226,6 @@ def __init__( spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels ) # type: ignore - def proj_feat(self, x, hidden_size, feat_size): - - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - def load_from(self, weights): with torch.no_grad(): @@ -313,23 +304,30 @@ def window_partition(x, window_size): x: input tensor. window_size: local window size. """ - - b, d, h, w, c = x.shape - x = x.view( - b, - d // window_size[0], - window_size[0], - h // window_size[1], - window_size[1], - w // window_size[2], - window_size[2], - c, - ) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), c) + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + x = x.view( + b, + d // window_size[0], + window_size[0], + h // window_size[1], + window_size[1], + w // window_size[2], + window_size[2], + c, + ) + windows = ( + x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) + ) + elif len(x_shape) == 4: + b, h, w, c = x.shape + x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) return windows -def window_reverse(windows, window_size, b, d, h, w): +def window_reverse(windows, window_size, dims): """window reverse operation based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows " @@ -338,23 +336,26 @@ def window_reverse(windows, window_size, b, d, h, w): Args: windows: windows tensor. window_size: local window size. - b: batch size. - d: depth. - h: height. - w: width. + dims: dimension values. 
""" + if len(dims) == 4: + b, d, h, w = dims + x = windows.view( + b, + d // window_size[0], + h // window_size[1], + w // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) - x = windows.view( - b, - d // window_size[0], - h // window_size[1], - w // window_size[2], - window_size[0], - window_size[1], - window_size[2], - -1, - ) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) + elif len(dims) == 3: + b, h, w = dims + x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) return x @@ -418,21 +419,40 @@ def __init__( self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads) - ) - coords_d = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) - relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + + if len(self.window_size) == 3: + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + num_heads, + ) + ) + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + elif len(self.window_size) == 2: + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -442,7 +462,7 @@ def __init__( 
trunc_normal_(self.relative_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) - def forward(self, x, mask=None): + def forward(self, x, mask): b, n, c = x.shape qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] @@ -530,18 +550,34 @@ def __init__( self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") def forward_part1(self, x, mask_matrix): - b, d, h, w, c = x.shape - window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + x_shape = x.size() x = self.norm1(x) - pad_l = pad_t = pad_d0 = 0 - pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] - pad_b = (window_size[1] - h % window_size[1]) % window_size[1] - pad_r = (window_size[2] - w % window_size[2]) % window_size[2] - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) - _, dp, hp, wp, _ = x.shape + if len(x_shape) == 5: + b, d, h, w, c = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] + pad_b = (window_size[1] - h % window_size[1]) % window_size[1] + pad_r = (window_size[2] - w % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, dp, hp, wp, _ = x.shape + dims = [b, dp, hp, wp] + + elif len(x_shape) == 4: + b, h, w, c = x.shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + pad_l = pad_t = 0 + pad_r = (window_size[0] - h % window_size[0]) % window_size[0] + pad_b = (window_size[1] - w % window_size[1]) % window_size[1] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, hp, wp, _ = x.shape + dims = [b, hp, wp] if any(i > 0 for i in shift_size): - shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + if len(x_shape) == 5: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x @@ -549,13 +585,22 @@ def forward_part1(self, x, mask_matrix): x_windows = window_partition(shifted_x, window_size) attn_windows = self.attn(x_windows, mask=attn_mask) attn_windows = attn_windows.view(-1, *(window_size + (c,))) - shifted_x = window_reverse(attn_windows, window_size, b, dp, hp, wp) + shifted_x = window_reverse(attn_windows, window_size, dims) if any(i > 0 for i in shift_size): - x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + if len(x_shape) == 5: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) else: x = shifted_x - if pad_d1 > 0 or pad_r > 0 or pad_b > 0: - x = x[:, :d, :h, :w, :].contiguous() + + if len(x_shape) == 5: + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :d, :h, :w, :].contiguous() + elif len(x_shape) == 4: + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + return x def forward_part2(self, x): @@ -617,39 +662,60 @@ class PatchMerging(nn.Module): https://github.com/microsoft/Swin-Transformer """ - def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm) -> None: # type: ignore + def __init__( + self, dim: int, 
norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 + ) -> None: # type: ignore """ Args: dim: number of feature channels. norm_layer: normalization layer. + spatial_dims: number of spatial dims. """ super().__init__() self.dim = dim - self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) - self.norm = norm_layer(8 * dim) + if spatial_dims == 3: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8 * dim) + elif spatial_dims == 2: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) def forward(self, x): - b, d, h, w, c = x.shape - pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) - if pad_input: - x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) - - x0 = x[:, 0::2, 0::2, 0::2, :] - x1 = x[:, 1::2, 0::2, 0::2, :] - x2 = x[:, 0::2, 1::2, 0::2, :] - x3 = x[:, 0::2, 0::2, 1::2, :] - x4 = x[:, 1::2, 0::2, 1::2, :] - x5 = x[:, 0::2, 1::2, 0::2, :] - x6 = x[:, 0::2, 0::2, 1::2, :] - x7 = x[:, 1::2, 1::2, 1::2, :] - x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + + elif len(x_shape) == 4: + b, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3], -1) + x = self.norm(x) x = self.reduction(x) return x -def compute_mask(d, h, w, window_size, shift_size, device): +def compute_mask(dims, window_size, shift_size, device): """Computing region masks based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows @@ -657,25 +723,36 @@ def compute_mask(d, h, w, window_size, shift_size, device): https://github.com/microsoft/Swin-Transformer Args: - d: depth value. - h: height value. - w: height value. + dims: dimension values. window_size: local window size. shift_size: shift size. device: device. 
""" - img_mask = torch.zeros((1, d, h, w, 1), device=device) cnt = 0 - for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): - for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): - for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): - img_mask[:, d, h, w, :] = cnt + + if len(dims) == 3: + d, h, w = dims + img_mask = torch.zeros((1, d, h, w, 1), device=device) + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + + elif len(dims) == 2: + h, w = dims + img_mask = torch.zeros((1, h, w, 1), device=device) + for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + img_mask[:, h, w, :] = cnt cnt += 1 + mask_windows = window_partition(img_mask, window_size) mask_windows = mask_windows.squeeze(-1) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask @@ -721,6 +798,7 @@ def __init__( super().__init__() self.window_size = window_size self.shift_size = tuple(i // 2 for i in window_size) + self.no_shift = tuple(0 for i in window_size) self.depth = depth self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList( @@ -728,8 +806,8 @@ def __init__( SwinTransformerBlock( dim=dim, num_heads=num_heads, - window_size=window_size, - shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + window_size=self.window_size, + shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, @@ -743,151 +821,38 @@ def __init__( ) self.downsample = downsample if self.downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) + self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) def forward(self, x): - b, c, d, h, w = x.shape - window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) - x = rearrange(x, "b c d h w -> b d h w c") - dp = int(np.ceil(d / window_size[0])) * window_size[0] - hp = int(np.ceil(h / window_size[1])) * window_size[1] - wp = int(np.ceil(w / window_size[2])) * window_size[2] - attn_mask = compute_mask(dp, hp, wp, window_size, shift_size, x.device) - for blk in self.blocks: - x = blk(x, attn_mask) - x = x.view(b, d, h, w, -1) - if self.downsample is not None: - x = self.downsample(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - - """Tensor initialization with truncated normal distribution. - Based on: - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - https://github.com/rwightman/pytorch-image-models - - Args: - tensor: an n-dimensional `torch.Tensor`. - mean: the mean of the normal distribution. - std: the standard deviation of the normal distribution. - a: the minimum cutoff value. - b: the maximum cutoff value. 
- """ - - def norm_cdf(x): - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - - """Tensor initialization with truncated normal distribution. - Based on: - https://github.com/rwightman/pytorch-image-models - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - -class DropPath(nn.Module): - """Stochastic drop paths per sample for residual blocks. - Based on: - https://github.com/rwightman/pytorch-image-models - """ - - def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: - super(DropPath, self).__init__() - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - """ - Args: - drop_prob: drop path probability. - scale_by_keep: scaling by non-dropped probability. - """ - - def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return x * random_tensor - - def forward(self, x): - return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) - - -class PatchEmbed(nn.Module): - """ - Patch embedding block based on: "Liu et al., - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows - " - https://github.com/microsoft/Swin-Transformer - """ - - def __init__( - self, - patch_size: Sequence[int], - in_chans: int = 1, - embed_dim: int = 96, - norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore - spatial_dims: int = 3, - ) -> None: - """ - Args: - patch_size: dimension of patch size. - in_chans: dimension of input channels. - embed_dim: number of linear projection output channels. - norm_layer: normalization layer. - spatial_dims: spatial dimension. 
- """ - - super().__init__() - self.patch_size = patch_size - self.embed_dim = embed_dim - self.proj = Conv[Conv.CONV, spatial_dims]( - in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size - ) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - _, _, d, h, w = x.size() - if w % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2])) - if h % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1])) - if d % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0])) - x = self.proj(x) - if self.norm is not None: - d, wh, ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww) + x_shape = x.size() + if len(x_shape) == 5: + b, c, d, h, w = x_shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c d h w -> b d h w c") + dp = int(np.ceil(d / window_size[0])) * window_size[0] + hp = int(np.ceil(h / window_size[1])) * window_size[1] + wp = int(np.ceil(w / window_size[2])) * window_size[2] + attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, d, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b d h w c -> b c d h w") + + elif len(x_shape) == 4: + b, c, h, w = x_shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c h w -> b h w c") + hp = int(np.ceil(h / window_size[0])) * window_size[0] + wp = int(np.ceil(w / window_size[1])) * window_size[1] + attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b h w c -> b c h w") return x @@ -983,10 +948,17 @@ def __init__( def proj_out(self, x, normalize=False): if normalize: - n, ch, d, h, w = x.size() - x = rearrange(x, "n c d h w -> n d h w c") - x = F.layer_norm(x, [ch]) - x = rearrange(x, "n d h w c -> n c d h w") + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") return x def forward(self, x, normalize=False): diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py new file mode 100644 index 0000000000..f8ea454228 --- /dev/null +++ b/tests/test_drop_path.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import DropPath + +TEST_CASES = [ + [{"drop_prob": 0.0, "scale_by_keep": True}, (1, 8, 8)], + [{"drop_prob": 0.7, "scale_by_keep": False}, (2, 16, 16, 16)], + [{"drop_prob": 0.3, "scale_by_keep": True}, (6, 16, 12)], +] + +TEST_ERRORS = [[{"drop_prob": 2, "scale_by_keep": False}, (1, 24, 6)]] + + +class TestDropPath(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + dr_path = DropPath(**input_param) + out = dr_path(im) + self.assertEqual(out.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + DropPath(**input_param) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 4af2b47ba5..6971eb0463 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -13,10 +13,11 @@ from unittest import skipUnless import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -48,6 +49,26 @@ test_case[0]["spatial_dims"] = 2 # type: ignore TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) +TEST_CASE_PATCHEMBED = [] +for patch_size in [2]: + for in_chans in [1, 4]: + for img_size in [96]: + for embed_dim in [6, 12]: + for norm_layer in [nn.LayerNorm]: + for nd in [2, 3]: + test_case = [ + { + "patch_size": (patch_size,) * nd, + "in_chans": in_chans, + "embed_dim": embed_dim, + "norm_layer": norm_layer, + "spatial_dims": nd, + }, + (2, in_chans, *([img_size] * nd)), + (2, embed_dim, *([img_size // patch_size] * nd)), + ] + TEST_CASE_PATCHEMBED.append(test_case) + class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK) @@ -115,5 +136,19 @@ def test_ill_arg(self): ) +class TestPatchEmbed(unittest.TestCase): + @parameterized.expand(TEST_CASE_PATCHEMBED) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = PatchEmbed(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + PatchEmbed(patch_size=(2, 2, 2), in_chans=1, embed_dim=24, norm_layer=nn.LayerNorm, spatial_dims=5) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 1680567472..0d48e99c44 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -21,7 +21,7 @@ einops, has_einops = optional_import("einops") -TEST_CASE_UNETR = [] +TEST_CASE_SWIN_UNETR = [] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: @@ -29,24 +29,27 @@ for img_size in [64]: for feature_size in [12]: for norm_name in ["instance"]: - test_case = [ - { - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": (img_size,) * 3, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - }, - (2, in_channels, *([img_size] * 3)), - (2, out_channels, *([img_size] * 3)), - ] - 
TEST_CASE_UNETR.append(test_case) + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size,) * nd, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + }, + (2, in_channels, *([img_size] * nd)), + (2, out_channels, *([img_size] * nd)), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_SWIN_UNETR.append(test_case) -class TestPatchEmbeddingBlock(unittest.TestCase): - @parameterized.expand(TEST_CASE_UNETR) +class TestSWINUNETR(unittest.TestCase): + @parameterized.expand(TEST_CASE_SWIN_UNETR) @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = SwinUNETR(**input_param) diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py new file mode 100644 index 0000000000..c850ff4ce6 --- /dev/null +++ b/tests/test_weight_init.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import trunc_normal_ + +TEST_CASES = [ + [{"mean": 0.0, "std": 1.0, "a": 2, "b": 4}, (6, 12, 3, 1, 7)], + [{"mean": 0.3, "std": 0.6, "a": -1.0, "b": 1.3}, (1, 4, 4, 4)], + [{"mean": 0.1, "std": 0.4, "a": 1.3, "b": 1.8}, (5, 7, 7, 8, 9)], +] + +TEST_ERRORS = [ + [{"mean": 0.0, "std": 1.0, "a": 5, "b": 1.1}, (1, 1, 8, 8, 8)], + [{"mean": 0.3, "std": -0.1, "a": 1.0, "b": 2.0}, (8, 5, 2, 6, 9)], + [{"mean": 0.7, "std": 0.0, "a": 0.1, "b": 2.0}, (4, 12, 23, 17)], +] + + +class TestWeightInit(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + self.assertEqual(im.shape, input_shape) + + @parameterized.expand(TEST_ERRORS) + def test_ill_arg(self, input_param, input_shape): + with self.assertRaises(ValueError): + im = torch.rand(input_shape) + trunc_normal_(im, **input_param) + + +if __name__ == "__main__": + unittest.main() From a7fa66bf60558273972b223793512397a2ec47af Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 2 May 2022 11:28:57 -0700 Subject: [PATCH 20/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 439009b57c..02266c4f3c 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -430,7 +430,10 @@ def __init__( coords_d = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + if "indexing" in torch.meshgrid.__kwdefaults__: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_d, 
coords_h, coords_w)) coords_flatten = torch.flatten(coords, 1) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() From 2631fb32b1ddf5b2af3d2d7fe923ac9e4d11199e Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 2 May 2022 11:39:39 -0700 Subject: [PATCH 21/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 02266c4f3c..bd2ad44583 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -448,7 +448,10 @@ def __init__( ) coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + if "indexing" in torch.meshgrid.__kwdefaults__: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w)) coords_flatten = torch.flatten(coords, 1) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() From d1e081712de6be1ed1e432ceff0f9cf611a8545a Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 2 May 2022 13:39:14 -0700 Subject: [PATCH 22/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index bd2ad44583..893ff45ebb 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -430,7 +430,7 @@ def __init__( coords_d = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[2]) - if "indexing" in torch.meshgrid.__kwdefaults__: + if torch.meshgrid.__kwdefaults__.__contains__("indexing"): coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) else: coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) @@ -448,7 +448,7 @@ def __init__( ) coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - if "indexing" in torch.meshgrid.__kwdefaults__: + if torch.meshgrid.__kwdefaults__.__contains__("indexing"): coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) else: coords = torch.stack(torch.meshgrid(coords_h, coords_w)) From b6795e41f4dc49e741ce4997db9b7a78ed5ab609 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Mon, 2 May 2022 18:51:28 -0700 Subject: [PATCH 23/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/nets/swin_unetr.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 893ff45ebb..3d8e60a140 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -419,6 +419,7 @@ def __init__( self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 + mesh_args = torch.meshgrid.__kwdefaults__ if len(self.window_size) == 3: self.relative_position_bias_table = nn.Parameter( @@ -430,7 +431,7 @@ def __init__( coords_d = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[2]) - if torch.meshgrid.__kwdefaults__.__contains__("indexing"): + if mesh_args is not 
None: coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) else: coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) @@ -448,7 +449,7 @@ def __init__( ) coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - if torch.meshgrid.__kwdefaults__.__contains__("indexing"): + if mesh_args is not None: coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) else: coords = torch.stack(torch.meshgrid(coords_h, coords_w)) From 14f14530bd34ee79a2df07adfc66c38a2652e6a6 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Tue, 3 May 2022 14:18:12 -0700 Subject: [PATCH 24/26] add swin_unetr model Signed-off-by: ahatamizadeh --- monai/networks/layers/weight_init.py | 2 +- monai/networks/nets/swin_unetr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py index 2217bdb42c..516a314e8b 100644 --- a/monai/networks/layers/weight_init.py +++ b/monai/networks/layers/weight_init.py @@ -61,6 +61,6 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): raise ValueError("the standard deviation should be greater than zero.") if a >= b: - raise ValueError("minimum cutoff value (a) should be greater than maximum cutoff value (b).") + raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).") return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3d8e60a140..2aa23c1c3e 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2022 MONAI Consortium +# Copyright 2020 - 2022 -> (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From f5c64e1c2eb260ca5c457bc82d64ac52c395e171 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 May 2022 21:18:40 +0000 Subject: [PATCH 25/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/drop_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/drop_path.py b/monai/networks/layers/drop_path.py index f91f65d682..7bb209ed25 100644 --- a/monai/networks/layers/drop_path.py +++ b/monai/networks/layers/drop_path.py @@ -24,7 +24,7 @@ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None: drop_prob: drop path probability. scale_by_keep: scaling by non-dropped probability. """ - super(DropPath, self).__init__() + super().__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep From ee8ca910bb10b9a25ad33eb88064d4b47f6823a8 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 3 May 2022 21:47:08 +0000 Subject: [PATCH 26/26] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/networks/layers/weight_init.py | 2 -- monai/networks/nets/swin_unetr.py | 1 - 2 files changed, 3 deletions(-) diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py index 516a314e8b..9b81ef17f8 100644 --- a/monai/networks/layers/weight_init.py +++ b/monai/networks/layers/weight_init.py @@ -15,7 +15,6 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): - """Tensor initialization with truncated normal distribution. 
Based on: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -44,7 +43,6 @@ def norm_cdf(x): def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - """Tensor initialization with truncated normal distribution. Based on: https://github.com/rwightman/pytorch-image-models diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 2aa23c1c3e..d898da9884 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -723,7 +723,6 @@ def forward(self, x): def compute_mask(dims, window_size, shift_size, device): - """Computing region masks based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows "
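
For reference, the torch.meshgrid compatibility guard that PATCH 20-23 converge on can be exercised on its own. A minimal standalone sketch, assuming torch is installed and that releases prior to 1.10 expose torch.meshgrid without the "indexing" keyword (so its __kwdefaults__ is None there):

    import torch

    # dict containing "indexing" on torch >= 1.10, None on older releases
    mesh_args = torch.meshgrid.__kwdefaults__
    coords_h = torch.arange(7)  # 7 matches the window size used by SwinUNETR
    coords_w = torch.arange(7)
    if mesh_args is not None:
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
    else:
        coords = torch.stack(torch.meshgrid(coords_h, coords_w))
    # either branch yields a (2, 7, 7) tensor of per-axis window coordinates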
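
The proj_out change earlier in this series normalizes features with LayerNorm over the channel axis by moving channels last. A minimal sketch of that pattern for a 5D feature map, assuming einops is installed; the tensor shape is illustrative only:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    x = torch.randn(2, 48, 24, 24, 24)          # n, c, d, h, w feature map
    n, ch, d, h, w = x.shape
    x = rearrange(x, "n c d h w -> n d h w c")  # channels-last so layer_norm acts on c
    x = F.layer_norm(x, [ch])                   # normalize each spatial position across channels
    x = rearrange(x, "n d h w c -> n c d h w")  # restore channels-first layout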
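
The truncated-normal initializer touched by PATCH 24 and PATCH 26 fills a tensor in place and rejects std <= 0 or a >= b. A small usage sketch; the mean, std, and cutoff values below are illustrative, not taken from the patches:

    import torch
    from monai.networks.layers import trunc_normal_

    w = torch.empty(48, 96)
    trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)  # in-place init, samples confined to [a, b]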