diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py
index 248bc5341c..43d3c46cc9 100644
--- a/monai/networks/blocks/dynunet_block.py
+++ b/monai/networks/blocks/dynunet_block.py
@@ -33,7 +33,8 @@ class UnetResBlock(nn.Module):
         kernel_size: convolution kernel size.
         stride: convolution stride.
         norm_name: feature normalization type and arguments.
-        dropout: dropout probability
+        act_name: activation layer type and arguments.
+        dropout: dropout probability.
 
     """
 
@@ -45,6 +46,7 @@ def __init__(
         kernel_size: Union[Sequence[int], int],
         stride: Union[Sequence[int], int],
         norm_name: Union[Tuple, str],
+        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
         dropout: Optional[Union[Tuple, str, float]] = None,
     ):
         super().__init__()
@@ -63,7 +65,7 @@ def __init__(
         self.conv3 = get_conv_layer(
             spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
         )
-        self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
+        self.lrelu = get_act_layer(name=act_name)
         self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
         self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
         self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
@@ -100,7 +102,8 @@ class UnetBasicBlock(nn.Module):
         kernel_size: convolution kernel size.
         stride: convolution stride.
         norm_name: feature normalization type and arguments.
-        dropout: dropout probability
+        act_name: activation layer type and arguments.
+        dropout: dropout probability.
 
     """
 
@@ -112,6 +115,7 @@ def __init__(
         kernel_size: Union[Sequence[int], int],
         stride: Union[Sequence[int], int],
         norm_name: Union[Tuple, str],
+        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
         dropout: Optional[Union[Tuple, str, float]] = None,
     ):
         super().__init__()
@@ -127,7 +131,7 @@ def __init__(
         self.conv2 = get_conv_layer(
             spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
         )
-        self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
+        self.lrelu = get_act_layer(name=act_name)
         self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
         self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
 
@@ -155,7 +159,9 @@ class UnetUpBlock(nn.Module):
         stride: convolution stride.
         upsample_kernel_size: convolution kernel size for transposed convolution layers.
         norm_name: feature normalization type and arguments.
-        dropout: dropout probability
+        act_name: activation layer type and arguments.
+        dropout: dropout probability.
+        trans_bias: whether to set the bias parameter in the transposed convolution layer. Defaults to ``False``.
""" @@ -168,7 +174,9 @@ def __init__( stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, ): super().__init__() upsample_stride = upsample_kernel_size @@ -179,6 +187,7 @@ def __init__( kernel_size=upsample_kernel_size, stride=upsample_stride, dropout=dropout, + bias=trans_bias, conv_only=True, is_transposed=True, ) @@ -190,6 +199,7 @@ def __init__( stride=1, dropout=dropout, norm_name=norm_name, + act_name=act_name, ) def forward(self, inp, skip): diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 3e846b9b7b..696c9d25dc 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -57,6 +57,7 @@ class DynUNet(nn.Module): This reimplementation of a dynamic UNet (DynUNet) is based on: `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + `Optimized U-Net for Brain Tumor Segmentation `_. This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: @@ -89,8 +90,15 @@ class DynUNet(nn.Module): strides: convolution strides for each blocks. upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should equal to strides[1:]. + filters: number of output channels for each blocks. Different from nnU-Net, in this implementation we add + this argument to make the network more flexible. As shown in the third reference, one way to determine + this argument is like: + ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``. + The above way is used in the network that wins task 1 in the BraTS21 Challenge. + If not specified, the way which nnUNet used will be employed. Defaults to ``None``. dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the last feature map, but also the previous feature maps that come from the intermediate up sample layers. @@ -109,6 +117,7 @@ class DynUNet(nn.Module): Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. Defaults to ``False``. + trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``. 
""" def __init__( @@ -119,11 +128,14 @@ def __init__( kernel_size: Sequence[Union[Sequence[int], int]], strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], + filters: Optional[Sequence[int]] = None, dropout: Optional[Union[Tuple, str, float]] = None, norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), deep_supervision: bool = False, deep_supr_num: int = 1, res_block: bool = False, + trans_bias: bool = False, ): super().__init__() self.spatial_dims = spatial_dims @@ -133,9 +145,15 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name + self.act_name = act_name self.dropout = dropout self.conv_block = UnetResBlock if res_block else UnetBasicBlock - self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] + self.trans_bias = trans_bias + if filters is not None: + self.filters = filters + self.check_filters() + else: + self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() self.downsamples = self.get_downsamples() self.bottleneck = self.get_bottleneck() @@ -161,9 +179,9 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): """ if len(downsamples) != len(upsamples): - raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") + raise ValueError(f"{len(downsamples)} != {len(upsamples)}") if (len(downsamples) - len(superheads)) not in (1, 0): - raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") + raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck @@ -191,26 +209,33 @@ def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." if len(kernels) != len(strides) or len(kernels) < 3: - raise AssertionError(error_msg) + raise ValueError(error_msg) for idx, k_i in enumerate(kernels): kernel, stride = k_i, strides[idx] if not isinstance(kernel, int): error_msg = f"length of kernel_size in block {idx} should be the same as spatial_dims." if len(kernel) != self.spatial_dims: - raise AssertionError(error_msg) + raise ValueError(error_msg) if not isinstance(stride, int): error_msg = f"length of stride in block {idx} should be the same as spatial_dims." 
                 if len(stride) != self.spatial_dims:
-                    raise AssertionError(error_msg)
+                    raise ValueError(error_msg)
 
     def check_deep_supr_num(self):
         deep_supr_num, strides = self.deep_supr_num, self.strides
         num_up_layers = len(strides) - 1
         if deep_supr_num >= num_up_layers:
-            raise AssertionError("deep_supr_num should be less than the number of up sample layers.")
+            raise ValueError("deep_supr_num should be less than the number of up sample layers.")
         if deep_supr_num < 1:
-            raise AssertionError("deep_supr_num should be larger than 0.")
+            raise ValueError("deep_supr_num should be larger than 0.")
+
+    def check_filters(self):
+        filters = self.filters
+        if len(filters) < len(self.strides):
+            raise ValueError("length of filters should be no less than the length of strides.")
+        else:
+            self.filters = filters[: len(self.strides)]
 
     def forward(self, x):
         out = self.skip_layers(x)
@@ -231,6 +256,7 @@ def get_input_block(self):
             self.kernel_size[0],
             self.strides[0],
             self.norm_name,
+            self.act_name,
             dropout=self.dropout,
         )
 
@@ -242,6 +268,7 @@ def get_bottleneck(self):
             self.kernel_size[-1],
             self.strides[-1],
             self.norm_name,
+            self.act_name,
             dropout=self.dropout,
         )
 
@@ -257,7 +284,9 @@ def get_upsamples(self):
         inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
         strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
         upsample_kernel_size = self.upsample_kernel_size[::-1]
-        return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size)
+        return self.get_module_list(
+            inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias
+        )
 
     def get_module_list(
         self,
@@ -267,6 +296,7 @@ def get_module_list(
         strides: Sequence[Union[Sequence[int], int]],
         conv_block: nn.Module,
         upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None,
+        trans_bias: bool = False,
     ):
         layers = []
         if upsample_kernel_size is not None:
@@ -280,8 +310,10 @@ def get_module_list(
                     "kernel_size": kernel,
                     "stride": stride,
                     "norm_name": self.norm_name,
+                    "act_name": self.act_name,
                     "dropout": self.dropout,
                     "upsample_kernel_size": up_kernel,
+                    "trans_bias": trans_bias,
                 }
                 layer = conv_block(**params)
                 layers.append(layer)
@@ -294,6 +326,7 @@ def get_module_list(
                     "kernel_size": kernel,
                     "stride": stride,
                     "norm_name": self.norm_name,
+                    "act_name": self.act_name,
                     "dropout": self.dropout,
                 }
                 layer = conv_block(**params)
diff --git a/monai/networks/nets/dynunet_v1.py b/monai/networks/nets/dynunet_v1.py
deleted file mode 100644
index 4c910157c9..0000000000
--- a/monai/networks/nets/dynunet_v1.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright 2020 - 2021 MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import List, Sequence, Union
-
-import torch
-import torch.nn as nn
-
-from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1
-from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer
-from monai.utils import deprecated
-
-__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"]
-
-
-@deprecated(
-    since="0.6.0",
-    removed="0.8.0",
-    msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.",
-)
-class DynUNetV1(DynUNet):
-    """
-    This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead.
-
-    Args:
-        spatial_dims: number of spatial dimensions.
-        in_channels: number of input channels.
-        out_channels: number of output channels.
-        kernel_size: convolution kernel size.
-        strides: convolution strides for each blocks.
-        upsample_kernel_size: convolution kernel size for transposed convolution layers.
-        dropout: dropout ratio. Defaults to no dropout.
-        norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance".
-        deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
-        deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1.
-        res_block: whether to use residual connection based convolution blocks during the network.
-            Defaults to ``False``.
-
-    .. deprecated:: 0.6.0
-        Use :class:`monai.networks.nets.DynUNet` instead.
-
-    """
-
-    def __init__(
-        self,
-        spatial_dims: int,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Sequence[Union[Sequence[int], int]],
-        strides: Sequence[Union[Sequence[int], int]],
-        upsample_kernel_size: Sequence[Union[Sequence[int], int]],
-        dropout: float = 0.0,
-        norm_name: str = "instance",
-        deep_supervision: bool = False,
-        deep_supr_num: int = 1,
-        res_block: bool = False,
-    ):
-        nn.Module.__init__(self)
-        self.spatial_dims = spatial_dims
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.strides = strides
-        self.upsample_kernel_size = upsample_kernel_size
-        self.norm_name = norm_name
-        self.dropout = dropout
-        self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1  # type: ignore
-        self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
-        self.input_block = self.get_input_block()
-        self.downsamples = self.get_downsamples()
-        self.bottleneck = self.get_bottleneck()
-        self.upsamples = self.get_upsamples()
-        self.output_block = self.get_output_block(0)
-        self.deep_supervision = deep_supervision
-        self.deep_supervision_heads = self.get_deep_supervision_heads()
-        self.deep_supr_num = deep_supr_num
-        self.apply(self.initialize_weights)
-        self.check_kernel_stride()
-        self.check_deep_supr_num()
-
-        # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
-        self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)
-
-        def create_skips(index, downsamples, upsamples, superheads, bottleneck):
-            """
-            Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
-            done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
-            with Torchscript.
-
-            Initially the length of `downsamples` will be one more than that of `superheads`
-            since the `input_block` is passed to this function as the first item in `downsamples`, however this
-            shouldn't be associated with a supervision head.
-            """
-
-            if len(downsamples) != len(upsamples):
-                raise AssertionError(f"{len(downsamples)} != {len(upsamples)}")
-            if (len(downsamples) - len(superheads)) not in (1, 0):
-                raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}")
-
-            if len(downsamples) == 0:  # bottom of the network, pass the bottleneck block
-                return bottleneck
-            if index == 0:  # don't associate a supervision head with self.input_block
-                current_head, rest_heads = nn.Identity(), superheads
-            elif not self.deep_supervision:  # bypass supervision heads by passing nn.Identity in place of a real one
-                current_head, rest_heads = nn.Identity(), superheads[1:]
-            else:
-                current_head, rest_heads = superheads[0], superheads[1:]
-
-            # create the next layer down, this will stop at the bottleneck layer
-            next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)
-
-            return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)
-
-        self.skip_layers = create_skips(
-            0,
-            [self.input_block] + list(self.downsamples),
-            self.upsamples[::-1],
-            self.deep_supervision_heads,
-            self.bottleneck,
-        )
-
-    def get_upsamples(self):
-        inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
-        strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
-        upsample_kernel_size = self.upsample_kernel_size[::-1]
-        return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size)
-
-    @staticmethod
-    def initialize_weights(module):
-        name = module.__class__.__name__.lower()
-        if "conv3d" in name or "conv2d" in name:
-            nn.init.kaiming_normal_(module.weight, a=0.01)
-            if module.bias is not None:
-                nn.init.constant_(module.bias, 0)
-        elif "norm" in name:
-            nn.init.normal_(module.weight, 1.0, 0.02)
-            nn.init.zeros_(module.bias)
-
-
-DynUnetV1 = DynunetV1 = DynUNetV1
diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index 18fe146a40..ca19ea2b47 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -43,6 +43,7 @@
                         "strides": strides,
                         "upsample_kernel_size": strides[1:],
                         "norm_name": "batch",
+                        "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}),
                         "deep_supervision": False,
                         "res_block": res_block,
                         "dropout": None,
@@ -66,8 +67,9 @@
             "kernel_size": (3, (1, 1, 3), 3, 3),
             "strides": ((1, 2, 1), 2, 2, 1),
             "upsample_kernel_size": (2, 2, 1),
+            "filters": (64, 96, 128, 192),
             "norm_name": ("INSTANCE", {"affine": True}),
-            "deep_supervision": False,
+            "deep_supervision": True,
             "res_block": res_block,
             "dropout": ("alphadropout", {"p": 0.25}),
         },
diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py
index 7e832f6d81..de3c018d78 100644
--- a/tests/test_dynunet_block.py
+++ b/tests/test_dynunet_block.py
@@ -35,6 +35,7 @@
                             "out_channels": 16,
                             "kernel_size": kernel_size,
                             "norm_name": norm_name,
+                            "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}),
                             "stride": stride,
                         },
                         (1, 16, *([in_size] * spatial_dims)),
@@ -49,22 +50,24 @@
         for stride in [1, 2]:
             for norm_name in ["batch", "instance"]:
                 for in_size in [15, 16]:
-                    out_size = in_size * stride
-                    test_case = [
-                        {
-                            "spatial_dims": spatial_dims,
-                            "in_channels": in_channels,
-                            "out_channels": out_channels,
-                            "kernel_size": kernel_size,
-                            "norm_name": norm_name,
-                            "stride": stride,
-                            "upsample_kernel_size": stride,
"upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) + for trans_bias in [True, False]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "upsample_kernel_size": stride, + "trans_bias": trans_bias, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): diff --git a/tests/test_dynunet_v1.py b/tests/test_dynunet_v1.py deleted file mode 100644 index fc216c145b..0000000000 --- a/tests/test_dynunet_v1.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from typing import Any, Sequence, Union - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.dynunet_v1 import DynUNetV1 -from tests.utils import skip_if_quick, test_script_save - -device = "cuda" if torch.cuda.is_available() else "cpu" - -strides: Sequence[Union[Sequence[int], int]] -kernel_size: Sequence[Any] -expected_shape: Sequence[Any] - -TEST_CASE_DYNUNET_2D = [] -for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: - for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: - for in_channels in [2, 3]: - for res_block in [True, False]: - out_channels = 2 - in_size = 64 - spatial_dims = 2 - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "batch", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_2D.append(test_case) - -TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides -for out_channels in [2, 3]: - for res_block in [True, False]: - in_channels = 1 - in_size = 64 - expected_shape = (1, out_channels, 64, 32, 64) - test_case = [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": (3, (1, 1, 3), 3, 3), - "strides": ((1, 2, 1), 2, 2, 1), - "upsample_kernel_size": (2, 2, 1), - "norm_name": "instance", - "deep_supervision": False, - "res_block": res_block, - }, - (1, in_channels, in_size, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_3D.append(test_case) - -TEST_CASE_DEEP_SUPERVISION = [] -for spatial_dims in [2, 3]: - for res_block in [True, False]: - for deep_supr_num in [1, 2]: - for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 
-                scale = strides[0]
-                test_case = [
-                    {
-                        "spatial_dims": spatial_dims,
-                        "in_channels": 1,
-                        "out_channels": 2,
-                        "kernel_size": [3] * len(strides),
-                        "strides": strides,
-                        "upsample_kernel_size": strides[1:],
-                        "norm_name": "group",
-                        "deep_supervision": True,
-                        "deep_supr_num": deep_supr_num,
-                        "res_block": res_block,
-                    },
-                    (1, 1, *[in_size] * spatial_dims),
-                    (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims),
-                ]
-                TEST_CASE_DEEP_SUPERVISION.append(test_case)
-
-
-@skip_if_quick
-class TestDynUNet(unittest.TestCase):
-    @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D)
-    def test_shape(self, input_param, input_shape, expected_shape):
-        net = DynUNetV1(**input_param).to(device)
-        with eval_mode(net):
-            result = net(torch.randn(input_shape).to(device))
-            self.assertEqual(result.shape, expected_shape)
-
-    def test_script(self):
-        input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]
-        net = DynUNetV1(**input_param)
-        test_data = torch.randn(input_shape)
-        test_script_save(net, test_data)
-
-
-class TestDynUNetDeepSupervision(unittest.TestCase):
-    @parameterized.expand(TEST_CASE_DEEP_SUPERVISION)
-    def test_shape(self, input_param, input_shape, expected_shape):
-        net = DynUNetV1(**input_param).to(device)
-        with torch.no_grad():
-            results = net(torch.randn(input_shape).to(device))
-            self.assertEqual(results.shape, expected_shape)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py
index 9698a40116..ccccd9e7f0 100644
--- a/tests/test_network_consistency.py
+++ b/tests/test_network_consistency.py
@@ -20,6 +20,7 @@
 from parameterized.parameterized import parameterized
 
 import monai.networks.nets as nets
+from monai.utils import set_determinism
 
 extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None)
 
@@ -33,6 +34,12 @@
 
 
 class TestNetworkConsistency(unittest.TestCase):
+    def setUp(self):
+        set_determinism(0)
+
+    def tearDown(self):
+        set_determinism(None)
+
     @skipIf(
         len(TESTS) == 0,
         "To run these tests, clone https://github.com/Project-MONAI/MONAI-extra-test-data and set MONAI_EXTRA_TEST_DATA",
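
Note (appended for illustration, not part of the patch): a minimal usage sketch of the new ``filters``, ``act_name``, and ``trans_bias`` arguments introduced above. The network configuration and input shape are arbitrary example values borrowed from the 3D test case; only the three new arguments come from this change::

    import torch

    from monai.networks.nets import DynUNet

    strides = ((1, 2, 1), 2, 2, 1)  # anisotropic example strides, as in the 3D test case
    net = DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        kernel_size=(3, (1, 1, 3), 3, 3),
        strides=strides,
        upsample_kernel_size=strides[1:],
        # BraTS21-style channel progression from the docstring, truncated to len(strides)
        filters=[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)],
        norm_name=("INSTANCE", {"affine": True}),
        act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        deep_supervision=False,
        res_block=True,
        trans_bias=True,  # enable bias in the transposed convolutions of the up blocks
    )
    with torch.no_grad():
        out = net(torch.randn(1, 1, 64, 64, 64))
    print(out.shape)  # torch.Size([1, 2, 64, 32, 64]) -- spatial size divided by strides[0]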