**UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation**    
*Huimin Huang, Lanfen Lin, Ruofeng Tong, Hongjie Hu, Qiaowei Zhang, Yutaro Iwamoto, Xianhua Han, Yen-Wei Chen, Jian Wu*   
[[paper](https://arxiv.org/abs/2004.08790)]   
ICASSP(IEEE) 2020   

In [1]:
import torch
import torch.nn as nn

In [3]:
class ConvLayer(nn.Module):
    """Single convolution, optional normalization, then an activation.

    The layers are stored, in that order, in ``self.conv`` (an
    ``nn.Sequential``).

    Args:
        in_dim: Number of input channels handed to ``conv_type``.
        out_dim: Number of output channels handed to ``conv_type`` and, when
            present, to ``norm_type``.
        conv_type: Convolution class; instantiated as
            ``conv_type(in_dim, out_dim, kernel_size=3, stride=1, padding=1)``.
        norm_type: Normalization class instantiated as ``norm_type(out_dim)``.
            ``None`` leaves the normalization step out entirely.
        act_type: Activation class, instantiated without arguments.
    """

    def __init__(self, in_dim, out_dim,
                 conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.ReLU) -> None:
        super().__init__()

        # Assemble the stack step by step so the optional norm layer is the
        # only conditional part.
        stack = [conv_type(in_dim, out_dim, kernel_size=3, stride=1, padding=1)]
        if norm_type is not None:
            stack.append(norm_type(out_dim))
        stack.append(act_type())

        self.conv = nn.Sequential(*stack)

    def forward(self, inputs):
        """Run ``inputs`` through the conv / (norm) / activation stack."""
        return self.conv(inputs)

In [4]:
class ConvBlock(nn.Module):
    """Stack of ``num_conv`` consecutive ``ConvLayer`` modules.

    Channel progression is ``in_dim -> hidden_dim -> ... -> hidden_dim ->
    out_dim``. With ``num_conv == 1`` the block is a single
    ``in_dim -> out_dim`` layer and ``hidden_dim`` is unused.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        hidden_dim: Channel width between the layers; defaults to ``out_dim``.
        num_conv: Number of ``ConvLayer`` modules (must be positive).
        conv_type: Convolution class forwarded to every ``ConvLayer``.
        norm_type: Normalization class forwarded to every ``ConvLayer``
            (``None`` disables normalization).
        act_type: Activation class forwarded to every ``ConvLayer``.

    Note:
        Earlier revisions accepted ``conv_type`` / ``norm_type`` / ``act_type``
        but ignored them, always building Conv2d + BatchNorm2d + LeakyReLU.
        The arguments are now honored. The ``act_type`` default is
        ``nn.LeakyReLU`` so that blocks built with default arguments keep the
        exact layers they had before.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2,
                 conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) -> None:
        assert num_conv > 0

        super().__init__()

        if hidden_dim is None:
            hidden_dim = out_dim

        layer_kwargs = dict(conv_type=conv_type, norm_type=norm_type, act_type=act_type)

        if num_conv == 1:
            # Kept as a bare ConvLayer (not wrapped in a Sequential) so the
            # parameter names of single-layer blocks stay `blocks.conv.*`.
            self.blocks = ConvLayer(in_dim, out_dim, **layer_kwargs)
        else:
            # Consecutive pairs of `dims` give each layer's (in, out) channels.
            dims = [in_dim] + [hidden_dim] * (num_conv - 1) + [out_dim]
            self.blocks = nn.Sequential(
                *(ConvLayer(src, dst, **layer_kwargs) for src, dst in zip(dims, dims[1:]))
            )

    def forward(self, inputs):
        """Apply the stacked layers to ``inputs``."""
        return self.blocks(inputs)

In [None]:
class UpsamplingLayer(nn.Module):
    """Spatial upsampling by transposed convolution or by interpolation.

    Args:
        in_dim: Input channels of the transposed convolution. Ignored when
            ``is_deconv`` is False.
        out_dim: Output channels of the transposed convolution. Ignored when
            ``is_deconv`` is False.
        is_deconv: If True, use ``nn.ConvTranspose2d`` with kernel size and
            stride equal to ``scale_factor``; otherwise use ``nn.Upsample``,
            which has no parameters and leaves the channel count untouched.
        mode: Interpolation mode passed to ``nn.Upsample``. Ignored when
            ``is_deconv`` is True.
        scale_factor: Integer upsampling factor. Defaults to 2, which
            reproduces the previously hard-coded behavior.
    """

    def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear', scale_factor=2) -> None:
        super().__init__()

        if is_deconv:
            # kernel_size == stride with no padding multiplies each spatial
            # dimension by exactly `scale_factor`.
            self.upsampler = nn.ConvTranspose2d(
                in_dim, out_dim, kernel_size=scale_factor, stride=scale_factor, padding=0
            )
        else:
            self.upsampler = nn.Upsample(scale_factor=scale_factor, mode=mode)

    def forward(self, x):
        """Return ``x`` upsampled by the configured operator."""
        return self.upsampler(x)

In [5]:
class Encoder(nn.Module):
    """Five-stage convolutional encoder with 2x2 max-pooling between stages.

    Stage ``k`` (1-based) produces ``dim * 2**(k-1)`` channels. Pooling is
    applied between stages only, so four poolings separate the input from the
    stage-5 output.

    Args:
        in_channels: Number of channels of the input image (default 1).
        dim: Channel width of the first stage (default 64).

    Notes on consistency with the rest of this file:
        * ``dim`` defaults to 64 because ``Decoder``/``UNet3p`` apply
          ``nn.Conv2d(1024, ...)`` to the stage-5 output and
          ``FullScale_SkipConnection`` defaults to ``enc_init_dim=64``; a
          width of 32 yields only 512 bottleneck channels.
        * The stage-5 output is returned without a trailing max-pool:
          ``Decoder.dsv1`` upsamples it by 16, i.e. expects four poolings.
        * ``stage_outputs`` is ordered deepest first (stage 4 .. stage 1),
          matching the backbones in this file and the indexing used by
          ``FullScale_SkipConnection.forward``.
    """

    def __init__(self, in_channels=1, dim=64) -> None:
        super().__init__()

        self.dim = dim

        self.conv1 = ConvBlock(in_channels, self.dim, num_conv=2)
        self.conv2 = ConvBlock(self.dim, self.dim * 2, num_conv=2)
        self.conv3 = ConvBlock(self.dim * 2, self.dim * 4, num_conv=2)
        self.conv4 = ConvBlock(self.dim * 4, self.dim * 8, num_conv=2)
        self.conv5 = ConvBlock(self.dim * 8, self.dim * 16, num_conv=2)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, inputs):
        """Encode ``inputs``.

        Returns:
            A pair ``(conv5_out, stage_outputs)`` where ``stage_outputs`` is
            ``[conv4_out, conv3_out, conv2_out, conv1_out]``.
        """
        conv1_out = self.conv1(inputs)
        conv2_out = self.conv2(self.pool(conv1_out))
        conv3_out = self.conv3(self.pool(conv2_out))
        conv4_out = self.conv4(self.pool(conv3_out))
        # The bottleneck itself is not pooled; the decoder consumes it as-is.
        conv5_out = self.conv5(self.pool(conv4_out))

        stage_outputs = [conv4_out, conv3_out, conv2_out, conv1_out]

        return conv5_out, stage_outputs

In [None]:
class VGG16_backbone(nn.Module):
    """Encoder backed by the convolutional part of torchvision's VGG16.

    The ``features`` stack is split into stages at every ``nn.MaxPool2d``.
    For each stage both the pre-pooling activation (``conv{k}``) and the
    pooled activation (``pool{k}``) are collected. All returned tensors are
    detached, so no gradient flows back into the backbone through them.

    NOTE(review): the pretrained VGG16 presumably expects 3-channel input and
    its deepest stage presumably has 512 channels, whereas ``Decoder`` in this
    file applies ``nn.Conv2d(1024, ...)`` to the deepest feature. Confirm the
    channel configuration before enabling this backbone.
    """

    def __init__(self) -> None:
        super().__init__()
        import torchvision.models as models

        # weights=VGG16_Weights.IMAGENET1K_V1
        self.CNN_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features

    def _stage_boundaries(self):
        """Return ``[0, p1 + 1, p2 + 1, ...]`` where ``p_i`` are pooling-layer indices."""
        return [0] + [
            idx + 1
            for idx, layer in enumerate(self.CNN_encoder)
            if isinstance(layer, nn.MaxPool2d)
        ]

    def vgg_layer_forward(self, x, indices):
        """Run the layers ``[start_idx, end_idx)`` of the feature stack on ``x``.

        Args:
            x: Input tensor of the stage.
            indices: Two-element sequence ``(start_idx, end_idx)``; the layer
                at ``end_idx - 1`` is treated as the stage's final (pooling)
                layer.

        Returns:
            ``(pooling, output)``: the result after the final layer and the
            activation just before it.

        Raises:
            ValueError: If the index range is empty.
        """
        start_idx, end_idx = indices
        if end_idx <= start_idx:
            raise ValueError(f"empty layer range: {start_idx}..{end_idx}")

        output = x
        for idx in range(start_idx, end_idx - 1):
            output = self.CNN_encoder[idx](output)
        pooling = self.CNN_encoder[end_idx - 1](output)
        return pooling, output

    def vgg_forward(self, x):
        """Run all stages in sequence and collect ``conv{k}`` / ``pool{k}`` tensors."""
        out = {}
        boundaries = self._stage_boundaries()
        for stage, span in enumerate(zip(boundaries, boundaries[1:]), start=1):
            # Each stage consumes the pooled output of the previous one.
            x, output = self.vgg_layer_forward(x, span)
            out[f'pool{stage}'] = x
            out[f'conv{stage}'] = output
        return out

    def forward(self, inputs):
        """Return ``(conv5, [conv4, conv3, conv2, conv1])``, all detached."""
        vgg_enc_out = self.vgg_forward(inputs)

        vgg_conv1 = vgg_enc_out['conv1'].detach()
        vgg_conv2 = vgg_enc_out['conv2'].detach()
        vgg_conv3 = vgg_enc_out['conv3'].detach()
        vgg_conv4 = vgg_enc_out['conv4'].detach()
        vgg_conv5 = vgg_enc_out['conv5'].detach()

        stage_outputs = [vgg_conv4, vgg_conv3, vgg_conv2, vgg_conv1]

        return vgg_conv5, stage_outputs

In [None]:
class ResNet101_backbone(nn.Module):
    """Encoder built from the stem and the four residual stages of ResNet-101.

    ``conv1`` is the stem (conv + batch-norm + ReLU), ``init_pool`` the stem's
    max-pool, and ``conv2`` .. ``conv5`` are torchvision's ``layer1`` ..
    ``layer4``.

    NOTE(review): the channel counts and resolutions of the ResNet stages
    presumably differ from what ``FullScale_SkipConnection`` assumes
    (``enc_init_dim * 2**(L-1)`` channels and one halving per level). Confirm
    before plugging this backbone into ``UNet3p``.
    """

    def __init__(self) -> None:
        super().__init__()
        import torchvision.models as models

        self.ResNet101 = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.conv1 = nn.Sequential(
            self.ResNet101.conv1,  # (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            self.ResNet101.bn1,    # (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            self.ResNet101.relu,   # (relu): ReLU(inplace=True)
        )
        self.init_pool = self.ResNet101.maxpool  # (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.conv2 = self.ResNet101.layer1
        self.conv3 = self.ResNet101.layer2
        self.conv4 = self.ResNet101.layer3
        self.conv5 = self.ResNet101.layer4

    def forward(self, inputs):
        """Return ``(h5, [h4, h3, h2, h1])``.

        ``h1`` is the stem output taken before ``init_pool``; ``h2`` .. ``h5``
        are the outputs of the four residual stages.
        """
        h1 = self.conv1(inputs)
        p1 = self.init_pool(h1)

        h2 = self.conv2(p1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        h5 = self.conv5(h4)

        stage_outputs = [h4, h3, h2, h1]

        return h5, stage_outputs

In [None]:
class FullScale_SkipConnection(nn.Module):
    """Full-scale skip connection feeding one decoder level.

    One branch is built per source feature map. Each branch resamples its
    input to the resolution of decoder level ``level`` and maps it to
    ``skip_hidden_dim`` channels with a conv + ReLU (no normalization). The
    branch outputs are concatenated along the channel axis.

    Branch order (and therefore the expected order of ``stage_outputs``):
    index 0 is the deepest encoder feature (level ``dec_depth + 1``), followed
    by levels ``dec_depth, dec_depth - 1, ..., 1``.

    Resampling assumes the resolution halves with every level:
        * source deeper than ``level``: bilinear upsampling by
          ``2**(source - level)``; below the bottleneck the source is a
          decoder feature with ``decoding_dim`` channels;
        * source equal to ``level``: no resampling, encoder feature;
        * source shallower than ``level``: max-pooling by
          ``2**(level - source)``, encoder feature.

    Encoder level ``L`` is expected to carry ``enc_init_dim * 2**(L-1)``
    channels.

    Args:
        enc_init_dim: Channel count of encoder level 1.
        skip_hidden_dim: Output channels of every branch.
        decoding_dim: Channel count of decoder feature maps.
        dec_depth: Number of decoder levels.
        level: Target decoder level, ``1 <= level <= dec_depth``. The default
            ``0`` is intentionally invalid so the caller has to choose one.
    """

    def __init__(self, enc_init_dim=64, skip_hidden_dim=64, decoding_dim=320, dec_depth=4, level=0) -> None:
        assert 0 < level <= dec_depth, "level must satisfy 1 <= level <= dec_depth"
        super().__init__()

        self.feature_aggregator = nn.ModuleList([])

        # Bottleneck feature X_en at level dec_depth + 1.
        self.feature_aggregator.append(
            nn.Sequential(
                nn.Upsample(scale_factor=2 ** (dec_depth + 1 - level), mode='bilinear'),
                ConvLayer(enc_init_dim * 2 ** dec_depth, skip_hidden_dim, norm_type=None),
            )
        )

        for L in range(dec_depth, 0, -1):  # dec_depth, dec_depth - 1, ..., 1
            if L > level:
                # Deeper decoder feature: bring it up to the target resolution.
                branch = nn.Sequential(
                    nn.Upsample(scale_factor=2 ** (L - level), mode='bilinear'),
                    ConvLayer(decoding_dim, skip_hidden_dim, norm_type=None),
                )
            elif L == level:
                # Same-level encoder feature; norm_type=None -> conv + ReLU only.
                branch = ConvLayer(enc_init_dim * 2 ** (L - 1), skip_hidden_dim, norm_type=None)
            else:
                # Shallower encoder feature: pool it down to the target resolution.
                factor = 2 ** (level - L)
                branch = nn.Sequential(
                    nn.MaxPool2d(kernel_size=factor, stride=factor),
                    ConvLayer(enc_init_dim * 2 ** (L - 1), skip_hidden_dim, norm_type=None),
                )
            self.feature_aggregator.append(branch)

    def forward(self, stage_outputs):
        """Aggregate the feature maps in ``stage_outputs``.

        Args:
            stage_outputs: Sequence of feature maps ordered deepest first (see
                the class docstring). Entries beyond the number of branches
                are ignored.

        Returns:
            Channel-wise concatenation of all branch outputs.

        Raises:
            IndexError: If fewer feature maps than branches are supplied.
        """
        num_branches = len(self.feature_aggregator)
        if len(stage_outputs) < num_branches:
            raise IndexError(
                f"expected at least {num_branches} feature maps, got {len(stage_outputs)}"
            )

        skip_out = [
            branch(feature)
            for branch, feature in zip(self.feature_aggregator, stage_outputs)
        ]
        return torch.concat(skip_out, dim=1)


In [None]:
class Decoder(nn.Module):
    """Four-level decoder with full-scale skip connections and side outputs.

    For every level (4 down to 1) a ``FullScale_SkipConnection`` gathers the
    feature maps and a single-layer ``ConvBlock`` fuses them into a
    ``decoding_dim``-channel decoder feature. The ``dsv*`` heads turn the
    bottleneck and each decoder feature into a 2-channel map; ``dsv1`` ..
    ``dsv4`` first upsample by 16, 8, 4 and 2 respectively.
    """

    def __init__(self) -> None:
        super().__init__()

        self.decoding_dim = 320

        self.skip1 = FullScale_SkipConnection(dec_depth=4, level=4)
        self.conv1 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip2 = FullScale_SkipConnection(dec_depth=4, level=3)
        self.conv2 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip3 = FullScale_SkipConnection(dec_depth=4, level=2)
        self.conv3 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip4 = FullScale_SkipConnection(dec_depth=4, level=1)
        self.conv4 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        # Original author's note (translated): "deep supervision not implemented".
        # NOTE(review): the side-output heads are defined below, so the note
        # presumably refers to the training loss - confirm.
        self.dsv1 = nn.Sequential(
            nn.Upsample(scale_factor=16, mode='bilinear'),
            nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0),
        )
        self.dsv2 = nn.Sequential(
            nn.Upsample(scale_factor=8, mode='bilinear'),
            nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0),
        )
        self.dsv3 = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear'),
            nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0),
        )
        self.dsv4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0),
        )
        self.dsv5 = nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs, organ_flag=True):
        """Decode the encoder features.

        Args:
            enc_out: Deepest encoder feature.
            stage_outputs: Encoder features ordered deepest first
                (level 4, 3, 2, 1). The caller's sequence is not modified.
            organ_flag: Multiplier applied to every decoder feature before its
                side-output head. May be a Python scalar/bool, or a tensor
                with one value per sample; such a tensor is reshaped to
                ``(B, 1, 1, 1)`` so that it scales whole samples.

        Returns:
            ``(out, dsv4_out, dsv3_out, dsv2_out, dsv1_out)`` where ``out`` is
            produced from the level-1 decoder feature and ``dsv1_out`` from
            ``enc_out``.
        """
        # Working copy: encoder entries are replaced by decoder features as
        # soon as they are available, so shallower levels see decoded inputs.
        features = [enc_out, *stage_outputs]

        x4_de = self.conv1(self.skip1(features))

        features[1] = x4_de
        x3_de = self.conv2(self.skip2(features))

        features[2] = x3_de
        x2_de = self.conv3(self.skip3(features))

        features[3] = x2_de
        x1_de = self.conv4(self.skip4(features))

        if torch.is_tensor(organ_flag) and organ_flag.dim() != enc_out.dim():
            # A per-sample flag such as (B,), (B, 1) or (B, 1, 1) would
            # otherwise broadcast against the trailing (channel/spatial) axes.
            organ_flag = organ_flag.reshape(-1, 1, 1, 1)

        # Out-of-place products: no tensor that autograd may have saved is
        # overwritten.
        x4_de = x4_de * organ_flag
        x3_de = x3_de * organ_flag
        x2_de = x2_de * organ_flag
        x1_de = x1_de * organ_flag

        dsv1_out = self.dsv1(enc_out)
        dsv2_out = self.dsv2(x4_de)
        dsv3_out = self.dsv3(x3_de)
        dsv4_out = self.dsv4(x2_de)
        out = self.dsv5(x1_de)

        return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out


In [None]:
class UNet3p(nn.Module):
    """UNet 3+ model: encoder, classification guide and full-scale decoder.

    The classification guide (dropout, 1x1 conv to 2 channels, global average
    pooling, sigmoid) is applied to the deepest encoder feature. The arg-max
    over its two outputs is passed to the decoder as a per-sample multiplier.
    """

    def __init__(self) -> None:
        super().__init__()

        self.encoder = Encoder()
        # self.encoder = VGG16_backbone()
        # self.encoder = ResNet101_backbone()
        self.decoder = Decoder()
        self.classification_guide = nn.Sequential(
            nn.Dropout2d(),
            nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0),
            nn.AdaptiveAvgPool2d(1),
            nn.Sigmoid(),
        )

    def forward(self, inputs):
        """Run the full model.

        Returns:
            The decoder's tuple ``(out, dsv4_out, dsv3_out, dsv2_out, dsv1_out)``.
        """
        enc_out, stage_outputs = self.encoder(inputs)

        # keepdim=True keeps the flag at (B, 1, 1, 1) so it lines up with the
        # batch axis of the decoder features when multiplied.
        organ_flag = torch.argmax(self.classification_guide(enc_out), dim=1, keepdim=True)

        out, dsv4_out, dsv3_out, dsv2_out, dsv1_out = self.decoder(enc_out, stage_outputs, organ_flag)

        return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out