diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index b722e5d70f..d7a4692f98 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -235,6 +235,7 @@ def __init__( in_channels=in_channels, out_channels=out_channels, dropout_prob=dropout_prob, + act=act, norm=norm, use_conv_final=use_conv_final, blocks_down=blocks_down, @@ -318,25 +319,11 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): def forward(self, x): net_input = x - x = self.convInit(x) - if self.dropout_prob is not None: - x = self.dropout(x) - - down_x = [] - for down in self.down_layers: - x = down(x) - down_x.append(x) - + x, down_x = self.encode(x) down_x.reverse() vae_input = x - - for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): - x = up(x) + down_x[i + 1] - x = upl(x) - - if self.use_conv_final: - x = self.conv_final(x) + x = self.decode(x, down_x) if self.training: vae_loss = self._get_vae_loss(net_input, vae_input) diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index ea6ca5b5dd..eada35ed66 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -70,6 +70,7 @@ "init_filters": init_filters, "out_channels": out_channels, "upsample_mode": upsample_mode, + "act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), "input_image_size": ([16] * spatial_dims), "vae_estimate_std": vae_estimate_std, },