In [1]:
# """
#    Author: Aaron Liu
#    Email: tl254@duke.edu
#    Created on: June 16 2021
#    Code structure reference: https://github.com/milesial/Pytorch-UNet
# """

# import torch
# import torch.nn as nn
# import torch.nn.functional as F


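# class LevelBlock(nn.Module):
#     """CSP-style level block (channel split + concat).
#
#     The first quarter of the input channels bypasses the convolutions
#     (2x2 max-pooled when in_channels < out_channels); the remaining channels
#     go through (BN ==> ReLU ==> grouped Conv) x 2. The two parts are then
#     concatenated, giving out_channels channels.
#     """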

#     def __init__(
#         self,
#         in_channels,
#         out_channels,
#         cardinality=16,
#         stride=(1, 1),
#     ):
#         super().__init__()

#         self.stride = stride
#         self.quarter = in_channels // 4
#         self.max_pool = nn.MaxPool2d((2, 2))
#         self.activation = nn.ReLU(inplace=True)
#         self.use_maxpool = in_channels < out_channels
#         self.cardinality = cardinality
#         self.bn1 = nn.BatchNorm2d(3 * self.quarter)
#         self.stacked_blocks1 = nn.Conv2d(
#             3 * self.quarter,
#             out_channels - self.quarter,
#             kernel_size=3,
#             padding=1,
#             groups=self.cardinality,
#             stride=self.stride[0],
#         )
#         self.bn2 = nn.BatchNorm2d(out_channels - self.quarter)

#         self.stacked_blocks2 = nn.Conv2d(
#             out_channels - self.quarter,
#             out_channels - self.quarter,
#             kernel_size=3,
#             padding=1,
#             groups=self.cardinality,
#             stride=self.stride[1],
#         )

#     def forward(self, x):
#         # CSP design: the first quarter of channels bypasses the convs; the rest is processed
#         part1, part2 = x[:, : self.quarter], x[:, self.quarter :]
#         x = self.stacked_blocks1(self.activation(self.bn1(part2)))
#         x = self.stacked_blocks2(self.activation(self.bn2(x)))
#         if self.use_maxpool:
#             part1 = self.max_pool(part1)
#         x = torch.cat([part1, x], dim=1)

#         return x


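# class UpSamplingConcatenate(nn.Module):
#     """Upsample x1 with a stride-2 transposed conv, zero-pad it to x2's
#     spatial size, and concatenate [x2, x1] along the channel axis."""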

#     def __init__(self, in_channels, out_channels):
#         super().__init__()

#         self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         diffY = x2.size()[2] - x1.size()[2]
#         diffX = x2.size()[3] - x1.size()[3]
#         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

#         x = torch.cat([x2, x1], dim=1)

#         return x

In [2]:
# class CSPResXUNet(nn.Module):
#     def __init__(self, n_channels, n_classes):
#         super(CSPResXUNet, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes

#         # Encoding
#         self.level1 = nn.Sequential(
#             nn.Conv2d(n_channels, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(64, 64, kernel_size=3, padding=1),
#         )
#         self.level2 = LevelBlock(64, 128, stride=(2, 1))
#         self.level3 = LevelBlock(128, 256, stride=(2, 1))
#         self.level4 = LevelBlock(256, 512, stride=(2, 1))
#         self.level5 = LevelBlock(512, 256, stride=(1, 1))
#         self.level6 = LevelBlock(256, 128, stride=(1, 1))
#         self.level7 = LevelBlock(128, 64, stride=(1, 1))

#         self.up1 = UpSamplingConcatenate(512, 256)
#         self.up2 = UpSamplingConcatenate(256, 128)
#         self.up3 = UpSamplingConcatenate(128, 64)

#         self.shortcut1 = nn.Conv2d(n_channels, 64, kernel_size=1)
#         self.shortcut2 = nn.Conv2d(64, 128, kernel_size=1, stride=2)
#         self.shortcut3 = nn.Conv2d(128, 256, kernel_size=1, stride=2)
#         self.shortcut5 = nn.Conv2d(512, 256, kernel_size=1)
#         self.shortcut6 = nn.Conv2d(256, 128, kernel_size=1)
#         self.shortcut7 = nn.Conv2d(128, 64, kernel_size=1)

#         self.outconv = nn.Conv2d(64, n_classes, kernel_size=1)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         # Encoding
#         x1 = self.level1(x)
#         x2_in = x1 + self.shortcut1(x)
#         x2 = self.level2(x2_in)
#         x3_in = x2 + self.shortcut2(x1)
#         x3 = self.level3(x3_in)
#         x4_in = x3 + self.shortcut3(x2)

#         # Bridge
#         x4 = self.level4(x4_in)

#         # Decoding
#         x_cat = self.up1(x4, x4_in)
#         x5 = self.level5(x_cat)
#         x_cat = self.up2(x5 + self.shortcut5(x_cat), x3_in)
#         x6 = self.level6(x_cat)
#         x_cat = self.up3(x6 + self.shortcut6(x_cat), x2_in)
#         x7 = self.level7(x_cat)
#         x = self.outconv(x7 + self.shortcut7(x_cat))
#         x = self.sigmoid(x)

#         return x

In [6]:
from cspresxunet import CSPResXUNet
from torchsummary import summary
# Instantiate the network. The positional args are presumably
# (n_channels, n_classes), matching the commented-out definition in cell [2];
# confirm against cspresxunet.py.
model = CSPResXUNet(
    1,  # input channels (Conv2d-1 in the summary output has 640 = 1*64*9 + 64 params)
    1,  # presumably output classes -- TODO confirm
)

In [7]:
# x = torch.rand(1, 1, 112, 112)

In [8]:
# from resunet import ResUNet

# from unet import UNet

In [9]:
# Layer-by-layer summary for a single-channel 112x112 input.
# torchsummary.summary defaults to device="cuda" and builds CUDA input tensors
# whenever a GPU is available, even if the model's weights are still on the CPU,
# which raises an input/weight type mismatch. Pass the device the model's
# parameters actually live on so this cell runs on both CPU-only and GPU hosts.
summary(
    model,
    input_size=(1, 112, 112),
    device="cuda" if next(model.parameters()).is_cuda else "cpu",
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]             640
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
            Conv2d-4         [-1, 64, 112, 112]          36,928
            Conv2d-5         [-1, 64, 112, 112]             128
       BatchNorm2d-6         [-1, 48, 112, 112]              96
              ReLU-7         [-1, 48, 112, 112]               0
            Conv2d-8          [-1, 112, 56, 56]           3,136
       BatchNorm2d-9          [-1, 112, 56, 56]             224
             ReLU-10          [-1, 112, 56, 56]               0
           Conv2d-11          [-1, 112, 56, 56]           7,168
        MaxPool2d-12           [-1, 16, 56, 56]               0
       LevelBlock-13          [-1, 128, 56, 56]               0
           Conv2d-14          [-1, 128,