diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 294d1c50b41..655ad68d38a 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -16,6 +16,7 @@ import torch.nn as nn import torch.nn.functional as F +from monai.networks.blocks.convolutions import Convolution from monai.networks.layers.factories import Act, Conv, Norm, Pool @@ -33,17 +34,38 @@ def __init__( super(Bottleneck3x3x1, self).__init__() - conv3d_type: Type[nn.Conv3d] = Conv[Conv.CONV, 3] - norm3d_type: Type[nn.BatchNorm3d] = Norm[Norm.BATCH, 3] relu_type: Type[nn.ReLU] = Act[Act.RELU] pool3d_type: Type[nn.MaxPool3d] = Pool[Pool.MAX, 3] - self.conv1 = conv3d_type(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = norm3d_type(planes) - self.conv2 = conv3d_type(planes, planes, kernel_size=(3, 3, 1), stride=stride, padding=(1, 1, 0), bias=False,) - self.bn2 = norm3d_type(planes) - self.conv3 = conv3d_type(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = norm3d_type(planes * 4) + self.conv1 = Convolution( + dimensions=3, + in_channels=inplanes, + out_channels=planes, + kernel_size=1, + act=("relu", {"inplace": True}), + norm=Norm.BATCH, + bias=False, + ) + self.conv2 = Convolution( + dimensions=3, + in_channels=planes, + out_channels=planes, + strides=stride, + kernel_size=(3, 3, 1), + act=("relu", {"inplace": True}), + norm=Norm.BATCH, + bias=False, + ) + self.conv3 = Convolution( + dimensions=3, + in_channels=planes, + out_channels=planes * 4, + kernel_size=1, + act=None, + norm=Norm.BATCH, + bias=False, + ) + self.relu = relu_type(inplace=True) self.downsample = downsample self.stride = stride @@ -53,15 +75,8 @@ def forward(self, x): residual = x out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - out = self.conv3(out) - out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x)