Skip to content

Commit

Permalink
Merge branch '813-AHNet' of github.com:yiheng-wang-nv/MONAI into 813-…
Browse files Browse the repository at this point in the history
…AHNet
  • Loading branch information
yiheng-wang-nv committed Aug 4, 2020
2 parents a0981cd + f6790a2 commit 46902bd
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions monai/networks/nets/ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Conv, Norm, Pool


Expand All @@ -33,17 +34,38 @@ def __init__(

super(Bottleneck3x3x1, self).__init__()

conv3d_type: Type[nn.Conv3d] = Conv[Conv.CONV, 3]
norm3d_type: Type[nn.BatchNorm3d] = Norm[Norm.BATCH, 3]
relu_type: Type[nn.ReLU] = Act[Act.RELU]
pool3d_type: Type[nn.MaxPool3d] = Pool[Pool.MAX, 3]

self.conv1 = conv3d_type(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = norm3d_type(planes)
self.conv2 = conv3d_type(planes, planes, kernel_size=(3, 3, 1), stride=stride, padding=(1, 1, 0), bias=False,)
self.bn2 = norm3d_type(planes)
self.conv3 = conv3d_type(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm3d_type(planes * 4)
self.conv1 = Convolution(
dimensions=3,
in_channels=inplanes,
out_channels=planes,
kernel_size=1,
act=("relu", {"inplace": True}),
norm=Norm.BATCH,
bias=False,
)
self.conv2 = Convolution(
dimensions=3,
in_channels=planes,
out_channels=planes,
strides=stride,
kernel_size=(3, 3, 1),
act=("relu", {"inplace": True}),
norm=Norm.BATCH,
bias=False,
)
self.conv3 = Convolution(
dimensions=3,
in_channels=planes,
out_channels=planes * 4,
kernel_size=1,
act=None,
norm=Norm.BATCH,
bias=False,
)

self.relu = relu_type(inplace=True)
self.downsample = downsample
self.stride = stride
Expand All @@ -53,15 +75,8 @@ def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
residual = self.downsample(x)
Expand Down

0 comments on commit 46902bd

Please sign in to comment.