In [65]:
from cspunet import CSPUNet
from torchsummary import summary

In [314]:
"""
   Author: Aaron Liu
   Email: tl254@duke.edu
   Created on: July 2 2021
   Code structure reference: https://github.com/milesial/Pytorch-UNet
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedMBConv(nn.Module):
    """Single-conv expansion block with an optional channel-split skip path.

    The block expands its input 4x with one ``kernel_size`` convolution,
    optionally re-weights the expanded channels with a squeeze-and-excitation
    gate, and projects back down with a 1x1 convolution.

    When ``stride == 1`` and ``in_channels == out_channels`` the input is split
    along the channel axis: the first ``in_channels // 2`` channels bypass the
    block untouched, the remaining channels are processed and receive a
    residual add. The output is ``cat([processed, bypassed])``, i.e. the
    bypassed channels end up *last*.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the expansion convolution; padding is
            ``kernel_size // 2``.
        stride: Stride of the expansion convolution. Any value other than 1
            disables the skip path.
        se: Whether to apply the squeeze-and-excitation gate.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, se=True):
        super().__init__()
        self.se = se
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.activation = nn.ReLU(inplace=True)
        self.skip_connection = (
            self.stride == 1 and self.in_channels == self.out_channels
        )
        # Number of leading channels that bypass the block (0 = no split).
        self.half = self.in_channels // 2 if self.skip_connection else 0

        branch_channels = self.in_channels - self.half
        expanded_channels = 4 * branch_channels

        # Previously the kernel size was hard-coded to 3 (while the padding
        # followed ``kernel_size``) and ``stride`` was never applied; both are
        # now wired through so the constructor arguments take effect.
        self.conv = nn.Conv2d(
            branch_channels,
            expanded_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            bias=False,
            padding=self.kernel_size // 2,
        )
        self.bn1 = nn.BatchNorm2d(expanded_channels)
        if self.se:
            # The gate keeps the full expanded width (no bottleneck).
            squeezed_channels = expanded_channels
            self.reduce = nn.Conv2d(
                in_channels=expanded_channels,
                out_channels=squeezed_channels,
                kernel_size=1,
            )
            self.expand = nn.Conv2d(
                in_channels=squeezed_channels,
                out_channels=expanded_channels,
                kernel_size=1,
            )

        self.projection = nn.Conv2d(
            expanded_channels,
            self.out_channels - self.half,
            kernel_size=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(self.out_channels - self.half)

    def forward(self, x):
        """Apply the block to ``x`` (channels on dim 1) and return the result."""
        # With half == 0, part1 is empty and part2 is the whole input.
        part1, part2 = x[:, : self.half], x[:, self.half :]
        out = self.activation(self.bn1(self.conv(part2)))

        # squeeze and excitation block
        if self.se:
            out_squeezed = F.adaptive_avg_pool2d(out, 1)
            out_squeezed = F.relu(self.reduce(out_squeezed))
            out_squeezed = self.expand(out_squeezed)
            out = torch.sigmoid(out_squeezed) * out

        out = self.bn2(self.projection(out))

        if self.skip_connection:
            # Residual on the processed channels, then re-attach the bypassed ones.
            out = torch.cat([out + part2, part1], dim=1)
        return out


class MBConv(nn.Module):
    """Expansion / depthwise / projection block with an optional channel-split skip.

    The block applies a 1x1 expansion convolution, a depthwise
    ``kernel_size`` convolution (carrying ``stride``), an optional
    squeeze-and-excitation gate, and a 1x1 projection.

    When ``stride == 1`` and ``in_channels == out_channels`` the input is split
    along the channel axis: the first ``in_channels // 2`` channels bypass the
    block untouched, the remaining channels are processed and receive a
    residual add. The output is ``cat([processed, bypassed])``, i.e. the
    bypassed channels end up *last*.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        t: Expansion factor; the nominal hidden width is ``in_channels * t``.
        kernel_size: Kernel size of the depthwise convolution; padding is
            ``kernel_size // 2``.
        stride: Stride of the depthwise convolution. Any value other than 1
            disables the skip path.
        se: Whether to apply the squeeze-and-excitation gate.
    """

    def __init__(
        self, in_channels, out_channels, t=4, kernel_size=3, stride=1, se=True
    ):
        super().__init__()
        self.se = se
        self.t = t
        self.in_channels = in_channels
        self.mid_channels = self.in_channels * self.t
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.activation = nn.ReLU(inplace=True)
        self.skip_connection = (
            self.stride == 1 and self.in_channels == self.out_channels
        )
        # Number of leading channels that bypass the block (0 = no split).
        self.half = self.in_channels // 2 if self.skip_connection else 0

        hidden_channels = self.mid_channels - self.half

        # The expansion consumes the *processed* slice x[:, half:], which has
        # in_channels - half channels. Using `half` here (as before) only
        # matched for even channel counts and crashed for odd ones.
        self.expansion = nn.Conv2d(
            self.in_channels - self.half, hidden_channels, kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.dwise = nn.Conv2d(
            hidden_channels,
            hidden_channels,
            groups=hidden_channels,
            kernel_size=self.kernel_size,
            bias=False,
            stride=self.stride,
            padding=self.kernel_size // 2,
        )
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.projection = nn.Conv2d(
            hidden_channels,
            self.out_channels - self.half,
            kernel_size=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(self.out_channels - self.half)
        if self.se:
            se_ratio = 0.25
            # Clamp so very narrow blocks never get a zero-channel bottleneck.
            squeezed_channels = max(1, int(hidden_channels * se_ratio))
            self.reduce = nn.Conv2d(
                in_channels=hidden_channels,
                out_channels=squeezed_channels,
                kernel_size=1,
            )
            self.expand = nn.Conv2d(
                in_channels=squeezed_channels,
                out_channels=hidden_channels,
                kernel_size=1,
            )

    def forward(self, x):
        """Apply the block to ``x`` (channels on dim 1) and return the result."""
        # With half == 0, part1 is empty and part2 is the whole input.
        part1, part2 = x[:, : self.half], x[:, self.half :]
        out = self.activation(self.bn1(self.expansion(part2)))
        out = self.activation(self.bn2(self.dwise(out)))

        # squeeze and excitation block
        if self.se:
            out_squeezed = F.adaptive_avg_pool2d(out, 1)
            out_squeezed = F.relu(self.reduce(out_squeezed))
            out_squeezed = self.expand(out_squeezed)
            out = torch.sigmoid(out_squeezed) * out

        out = self.bn3(self.projection(out))

        if self.skip_connection:
            # Residual on the processed channels, then re-attach the bypassed ones.
            out = torch.cat([out + part2, part1], dim=1)
        return out


class DoubleConv(nn.Module):
    """Two (3x3 conv => BN => ReLU) stages plus a 1x1-conv shortcut.

    The shortcut maps the block input to ``out_channels`` with a 1x1
    convolution and is added to the output of the second stage.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        mid_channels: Width between the two stages; falls back to
            ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # First stage.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.relu = nn.ReLU(inplace=True)

        # Second stage.
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 projection so the input can be added to the output.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """Return the two-stage conv output plus the projected input."""
        features = self.conv1(x)
        features = self.bn1(features)
        features = self.relu(features)

        features = self.conv2(features)
        features = self.bn2(features)
        features = self.relu(features)

        projected = self.shortcut(x)
        return features + projected


class Up(nn.Module):
    """Upsample, align with the skip tensor, concatenate, then DoubleConv.

    Args:
        in_channels: Channels of the tensor to be upsampled.
        out_channels: Channels produced by the transposed convolution and by
            the block as a whole; the skip tensor passed to ``forward`` must
            therefore also carry ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(2 * out_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, fit it to ``x2``'s height/width, and fuse the two.

        A positive size difference zero-pads the upsampled tensor; a negative
        one crops it (``F.pad`` accepts negative padding).
        """
        upsampled = self.up(x1)

        # Height / width gap between the skip tensor and the upsampled one.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features first, upsampled features second.
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

In [315]:
class CSPUNet(nn.Module):
    """U-Net-style network with FusedMBConv / MBConv encoder stages.

    The encoder is a strided stem followed by six stages (``down1`` ..
    ``down6``); the decoder is six ``Up`` blocks, each fusing the running
    decoder tensor with the matching encoder output. A transposed convolution
    followed by a sigmoid produces the final map.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of channels of the output map.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoding
        self.stem = nn.Sequential(
            nn.Conv2d(n_channels, 24, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(24),
            nn.ReLU(inplace=True),
        )
        # NOTE: every repeated block is built by a comprehension so that each
        # position in a stage owns its own parameters. ``N * [Module(...)]``
        # repeats ONE module instance N times, which silently shares weights
        # across all repeats. state_dict keys are unchanged by this fix.
        self.down1 = nn.Sequential(
            *[FusedMBConv(24, 24, stride=1) for _ in range(2)]
        )
        self.down2 = nn.Sequential(
            *[FusedMBConv(24, 24) for _ in range(3)],
            FusedMBConv(24, 48, stride=2),
        )
        self.down3 = nn.Sequential(
            *[FusedMBConv(48, 48) for _ in range(3)],
            FusedMBConv(48, 64, stride=2),
        )
        self.down4 = nn.Sequential(
            *[MBConv(64, 64, t=4) for _ in range(5)],
            MBConv(64, 128, stride=2, t=4),
        )
        self.down5 = nn.Sequential(
            *[MBConv(128, 128, t=6) for _ in range(8)],
            MBConv(128, 160, stride=1, t=6),
        )
        self.down6 = nn.Sequential(
            *[MBConv(160, 160, t=6) for _ in range(14)],
            MBConv(160, 256, stride=2, t=6),
        )

        # Decoding: each Up's out_channels equals the channel count of the
        # encoder tensor it is fused with in forward().
        self.up1 = Up(256, 160)
        self.up2 = Up(160, 128)
        self.up3 = Up(128, 64)
        self.up4 = Up(64, 48)
        self.up5 = Up(48, 24)
        self.up6 = Up(24, 24)

        # Head: the 2x transposed conv mirrors the stride-2 stem.
        self.outconv = nn.ConvTranspose2d(24, n_classes, kernel_size=2, stride=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, decoder and head; return the sigmoid-activated map."""
        # Encoding
        x0 = self.stem(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        # Decoding: walk back up, fusing with encoder outputs deepest-first.
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        x = self.up5(x, x1)
        x = self.up6(x, x0)

        # Head
        x = self.sigmoid(self.outconv(x))

        return x

In [316]:
import torch

# Smoke test: build a model with 1 input channel and 1 output channel, then
# run a single forward pass on a random batch of two 1x112x112 inputs.
model = CSPUNet(1, 1)
x = torch.rand(2, 1, 112, 112)
out = model(x)

In [317]:
# Display the shape of the forward-pass output.
out.shape

torch.Size([2, 1, 112, 112])

In [318]:
# Print torchsummary's per-layer table for a single 1x112x112 input.
summary(model, input_size=(1, 112, 112))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 24, 56, 56]             216
       BatchNorm2d-2           [-1, 24, 56, 56]              48
              ReLU-3           [-1, 24, 56, 56]               0
            Conv2d-4           [-1, 48, 56, 56]           5,184
       BatchNorm2d-5           [-1, 48, 56, 56]              96
              ReLU-6           [-1, 48, 56, 56]               0
            Conv2d-7             [-1, 48, 1, 1]           2,352
            Conv2d-8             [-1, 48, 1, 1]           2,352
            Conv2d-9           [-1, 12, 56, 56]             576
      BatchNorm2d-10           [-1, 12, 56, 56]              24
      FusedMBConv-11           [-1, 24, 56, 56]               0
           Conv2d-12           [-1, 48, 56, 56]           5,184
      BatchNorm2d-13           [-1, 48, 56, 56]              96
             ReLU-14           [-1, 48,