**DeepSEED: 3D Squeeze-and-Excitation Encoder-Decoder Convolutional Neural Networks for Pulmonary Nodule Detection**    
*Yuemeng Li, Yong Fan*   
[[paper](https://arxiv.org/abs/1904.03501)]   
ISBI 2020   

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SqueezeExcitation3D(nn.Module):
    """Channel-wise squeeze-and-excitation gate for volumetric feature maps.

    The input is average-pooled to a single value per channel, passed through
    a two-layer 1x1x1 convolutional bottleneck (SiLU, then Sigmoid), and the
    resulting per-channel weights are multiplied back onto the input.

    Args:
        in_dim: Number of channels of the input (and of the output).
        reduction_ratio: Divisor applied to ``in_dim`` to obtain the
            bottleneck width. The width is clamped to at least one channel.
    """

    def __init__(self, in_dim, reduction_ratio=4) -> None:
        super().__init__()

        # Integer division gives 0 when in_dim < reduction_ratio, which would
        # create a zero-channel bottleneck; keep at least one channel.
        squeezed_dim = max(1, in_dim // reduction_ratio)

        # Collapse every spatial dimension to size 1 (one value per channel).
        self.squeeze = nn.AdaptiveAvgPool3d(1)

        self.excitation = nn.Sequential(
            nn.Conv3d(in_channels=in_dim, out_channels=squeezed_dim, kernel_size=1, stride=1, bias=False),
            nn.SiLU(),
            nn.Conv3d(in_channels=squeezed_dim, out_channels=in_dim, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by its learned per-channel gate (same shape as ``x``)."""
        gate = self.excitation(self.squeeze(x))

        # gate has spatial size 1x1x1 and broadcasts over the volume of x.
        return x * gate

In [4]:
class ResBlock3D_SE(nn.Module):
    """3D residual block followed by a squeeze-and-excitation gate.

    The main path is Conv3d -> BatchNorm3d -> SiLU -> Conv3d -> BatchNorm3d.
    The shortcut is the identity when ``stride == 1`` and
    ``in_dim == out_dim``; otherwise it is a strided 1x1x1 convolution
    followed by batch normalisation.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Kernel size of both main-path convolutions.
        stride: Stride of the first main-path convolution and of the
            projection shortcut.
        padding: Padding of both main-path convolutions.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=1):
        super().__init__()

        self.silu = nn.SiLU()

        main_path = [
            nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
            nn.BatchNorm3d(out_dim),
            nn.SiLU(inplace=True),
            # Only the first convolution is strided; the second keeps the resolution.
            nn.Conv3d(out_dim, out_dim, kernel_size, 1, padding, bias=False),
            nn.BatchNorm3d(out_dim),
        ]
        self.conv = nn.Sequential(*main_path)

        self.se = SqueezeExcitation3D(out_dim, reduction_ratio=4)

        shape_is_preserved = stride == 1 and in_dim == out_dim
        if shape_is_preserved:
            self.downsample = nn.Identity()
        else:
            # Project the input so it can be added to the main-path output.
            self.downsample = nn.Sequential(
                nn.Conv3d(in_dim, out_dim, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_dim)
            )

    def forward(self, x):
        """Apply the residual block and the SE gate to ``x``."""
        shortcut = self.downsample(x)

        features = self.conv(x)
        features += shortcut
        features = self.silu(features)

        # NOTE(review): the shortcut is added a second time after the SE gate;
        # confirm this double skip connection is intended.
        gated = self.se(features)
        gated += shortcut

        return self.silu(gated)

In [5]:
class ResNet_Encoder(nn.Module):
    """Residual SE encoder producing a bottleneck tensor plus skip features.

    The stem is a stride-2 7x7x7 convolution followed by a stride-2 max-pool.
    Four stages follow; each stage starts with a stride-2 ``ResBlock3D_SE``
    and ends with a stride-1 one.

    Args:
        in_dim: Number of channels of the input volume.
        hidden_dim: Indexable of channel widths; entries 0..4 are used
            (stem, then stages 1-4).
    """

    def __init__(self, in_dim, hidden_dim) -> None:
        # The one-argument form `super(ResNet_Encoder).__init__()` never runs
        # nn.Module.__init__, so the first submodule assignment would raise
        # AttributeError. Use the zero-argument form instead.
        super().__init__()

        self.init_conv = nn.Sequential(
            nn.Conv3d(in_channels=in_dim, out_channels=hidden_dim[0], kernel_size=7, stride=2, padding=3),
            nn.BatchNorm3d(hidden_dim[0]),
            nn.SiLU()
        )

        self.max_pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_stage(hidden_dim[0], hidden_dim[1])
        self.layer2 = self._make_stage(hidden_dim[1], hidden_dim[2])
        self.layer3 = self._make_stage(hidden_dim[2], hidden_dim[3])
        self.layer4 = self._make_stage(hidden_dim[3], hidden_dim[4])

    @staticmethod
    def _make_stage(in_ch, out_ch):
        """Build one stage: a stride-2 block that changes width, then a stride-1 block."""
        return nn.Sequential(
            ResBlock3D_SE(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            ResBlock3D_SE(out_ch, out_ch)
        )

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(h5, stage_outputs)`` where ``h5`` is the output of the
            last stage and ``stage_outputs`` is ``[h1, h2, h3, h4]``: the
            pooled stem output followed by the outputs of stages 1-3.
        """
        h1 = self.max_pool(self.init_conv(x))

        h2 = self.layer1(h1)
        h3 = self.layer2(h2)
        h4 = self.layer3(h3)
        h5 = self.layer4(h4)

        stage_outputs = [h1, h2, h3, h4]

        return h5, stage_outputs

In [None]:
class ResNet_Decoder(nn.Module):
    """Decoder that upsamples the bottleneck and fuses encoder skip features.

    At each of the four fusion steps the current tensor is trilinearly
    resampled to the spatial size of the matching skip tensor, concatenated
    with it along the channel axis, and refined by two ``ResBlock3D_SE``
    blocks. The result is then resampled to ``out_size`` and mapped to
    ``out_dim`` channels.

    Args:
        out_dim: Number of channels of the final output.
        hidden_dim: Indexable of channel widths; entries 0..4 are used and
            are expected to match the widths of the encoder features.
        out_size: Spatial size the final feature map is resampled to.
            Defaults to ``(240, 240, 155)``, the previously hard-coded value.
    """

    def __init__(self, out_dim, hidden_dim, out_size=(240, 240, 155)) -> None:
        super().__init__()

        # h5(+)h4
        self.conv1 = nn.Sequential(
            ResBlock3D_SE(hidden_dim[4] + hidden_dim[3], hidden_dim[4]),
            ResBlock3D_SE(hidden_dim[4], hidden_dim[3])
        )

        # h4(+)h3
        self.conv2 = nn.Sequential(
            ResBlock3D_SE(hidden_dim[3] + hidden_dim[2], hidden_dim[3]),
            ResBlock3D_SE(hidden_dim[3], hidden_dim[2])
        )

        # h3(+)h2
        self.conv3 = nn.Sequential(
            ResBlock3D_SE(hidden_dim[2] + hidden_dim[1], hidden_dim[2]),
            ResBlock3D_SE(hidden_dim[2], hidden_dim[1])
        )

        # h2(+)h1
        self.conv4 = nn.Sequential(
            ResBlock3D_SE(hidden_dim[1] + hidden_dim[0], hidden_dim[1]),
            ResBlock3D_SE(hidden_dim[1], hidden_dim[0])
        )

        self.up5 = nn.Upsample(size=out_size, mode='trilinear')
        self.conv5 = nn.Sequential(
            ResBlock3D_SE(hidden_dim[0], hidden_dim[0]),
            ResBlock3D_SE(hidden_dim[0], out_dim=out_dim)
        )

    @staticmethod
    def _upsample_and_concat(x, skip):
        """Resample ``x`` to the spatial size of ``skip`` and concatenate on channels.

        Resampling to the skip's own size (rather than by a fixed factor of 2)
        keeps the concatenation valid when a strided encoder stage received an
        odd-sized input, e.g. 15 -> 8, where doubling would give 16. When the
        skip is exactly twice as large this equals a x2 trilinear upsample.
        """
        x = F.interpolate(x, size=skip.shape[2:], mode='trilinear')
        return torch.concat([x, skip], dim=1)

    def forward(self, enc_out, stage_outputs):
        """Decode ``enc_out`` using the skip features in ``stage_outputs``.

        Args:
            enc_out: Bottleneck tensor to start decoding from.
            stage_outputs: Sequence of at least four skip tensors; it is
                consumed from the last element (fused first) backwards.

        Returns:
            Tensor with ``out_dim`` channels and spatial size ``out_size``.
        """
        # The first fusion must start from the encoder output (the original
        # code read an unassigned local here).
        h1 = self.conv1(self._upsample_and_concat(enc_out, stage_outputs[-1]))
        h2 = self.conv2(self._upsample_and_concat(h1, stage_outputs[-2]))
        h3 = self.conv3(self._upsample_and_concat(h2, stage_outputs[-3]))
        h4 = self.conv4(self._upsample_and_concat(h3, stage_outputs[-4]))

        h5 = self.up5(h4)
        h5 = self.conv5(h5)

        return h5

In [None]:
class DeepSEED(nn.Module):
    """Encoder-decoder network assembled from ``ResNet_Encoder`` and ``ResNet_Decoder``.

    Args:
        in_dim: Number of input channels, forwarded to the encoder.
        out_dim: Number of output channels, forwarded to the decoder.
        hidden_dim: Channel widths shared by the encoder and the decoder.
    """

    def __init__(self, in_dim, out_dim, hidden_dim:list) -> None:
        super(DeepSEED, self).__init__()

        # Both halves receive the same width list so their skip channels line up.
        self.encoder = ResNet_Encoder(in_dim, hidden_dim=hidden_dim)
        self.decoder = ResNet_Decoder(out_dim, hidden_dim=hidden_dim)

    def forward(self, x):
        """Encode ``x`` and decode the bottleneck together with the skip features."""
        bottleneck, skips = self.encoder(x)
        return self.decoder(bottleneck, skips)