**UNet++: A Nested U-Net Architecture for Medical Image Segmentation**    
*Zongwei Zhou, Md Mahfuzur Rahman Siddiquee, Nima Tajbakhsh, Jianming Liang*   
[[paper](https://arxiv.org/abs/1807.10165)]   
DLMIA 2018   

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from easydict import EasyDict as edict

# Shared hyper-parameter container for the network defined in this file.
args = edict()

# net dim
args.in_dim     = 1    # channels of the input image
args.init_dim   = 64   # channels produced by the first encoder stage
args.enc_depth  = 5    # number of encoder stages
args.net_dim    = [args.init_dim * 2 ** level for level in range(args.enc_depth)]  # [64, 128, 256, 512, 1024]
args.out_dim    = 2    # channels of each prediction map

# net operator
# Every layer in this file is a 2D operator (nn.Conv2d / nn.MaxPool2d /
# nn.ConvTranspose2d), so the settings are scalars that PyTorch expands to
# both spatial axes. Three-element tuples describe 3D operators and are
# rejected by the 2D layers.
conv_kwargs = edict()
conv_kwargs.kernel_size = 3
conv_kwargs.stride      = 1
conv_kwargs.padding     = 1

down_kwargs = edict()
down_kwargs.kernel_size = 2
down_kwargs.stride      = 2
down_kwargs.padding     = 0

up_kwargs = edict()
up_kwargs.kernel_size = 2
up_kwargs.stride      = 2
up_kwargs.padding     = 0
up_kwargs.scale       = 2   # scale factor for interpolation-based upsampling

args.conv_kwargs = conv_kwargs
args.down_kwargs = down_kwargs
args.up_kwargs   = up_kwargs

In [None]:
class ConvLayer(nn.Module):
    """One convolution followed by a normalisation layer and an activation.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        conv_type: convolution class, called as
            ``conv_type(in_dim, out_dim, **conv_kwargs)``.
        conv_kwargs: mapping of keyword arguments for ``conv_type``. When
            ``None``, ``kernel_size=3, stride=1, padding=1`` is used.
        norm_type: normalisation class, called as ``norm_type(out_dim)``.
        act_type: activation class, called with no arguments.
    """

    def __init__(self, in_dim, out_dim,
                 conv_type=nn.Conv2d, conv_kwargs=None,
                 norm_type=nn.BatchNorm2d,
                 act_type=nn.LeakyReLU) -> None:
        super().__init__()

        # Fall back to the 3/1/1 operator settings when none are supplied.
        if conv_kwargs is not None:
            operator_cfg = conv_kwargs
        else:
            operator_cfg = {"kernel_size": 3, "stride": 1, "padding": 1}

        # Built in application order: convolution, normalisation, activation.
        stages = [
            conv_type(in_dim, out_dim, **operator_cfg),
            norm_type(out_dim),
            act_type(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, inputs):
        """Apply convolution, normalisation and activation to ``inputs``."""
        features = self.conv(inputs)
        return features

In [None]:
class ConvBlock(nn.Module):
    """A stack of ``num_conv`` ConvLayers mapping ``in_dim`` to ``out_dim`` channels.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels of the last layer.
        hidden_dim: channel width between layers; defaults to ``out_dim``.
            Unused when ``num_conv == 1``.
        num_conv: number of ConvLayers in the block, must be positive.
        conv_type: convolution class forwarded to every ConvLayer.
        conv_kwargs: mapping of keyword arguments for ``conv_type``. When
            ``None``, ``kernel_size=3, stride=1, padding=1`` is used.
        norm_type: normalisation class forwarded to every ConvLayer.
        act_type: activation class forwarded to every ConvLayer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None,
                    num_conv=2,
                    conv_type=nn.Conv2d, conv_kwargs=None,
                    norm_type=nn.BatchNorm2d,
                    act_type=nn.LeakyReLU) -> None:
        assert num_conv > 0, "num_conv must be a positive integer"

        super(ConvBlock, self).__init__()

        if hidden_dim is None:
            hidden_dim = out_dim

        if conv_kwargs is None:
            # net operator
            conv_kwargs = {"kernel_size": 3, "stride": 1, "padding": 1}

        # Settings shared by every layer, so the single-layer and multi-layer
        # branches cannot drift apart (the single-layer branch used to drop
        # conv_kwargs and ignore act_type).
        layer_cfg = dict(conv_type=conv_type, conv_kwargs=conv_kwargs,
                         norm_type=norm_type, act_type=act_type)

        if num_conv == 1:
            self.blocks = ConvLayer(in_dim, out_dim, **layer_cfg)
        else:
            # in -> hidden, (num_conv - 2) x hidden -> hidden, hidden -> out
            layers = [ConvLayer(in_dim, hidden_dim, **layer_cfg)]
            layers += [ConvLayer(hidden_dim, hidden_dim, **layer_cfg)
                       for _ in range(num_conv - 2)]
            layers.append(ConvLayer(hidden_dim, out_dim, **layer_cfg))
            self.blocks = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run ``inputs`` through all layers of the block in order."""
        return self.blocks(inputs)

In [None]:
class UpsamplingLayer(nn.Module):
    """Spatial upsampling by transposed convolution or by interpolation.

    Args:
        in_dim: number of input channels (used by the transposed convolution).
        out_dim: number of output channels (used by the transposed convolution).
        is_deconv: when True use ``nn.ConvTranspose2d``; otherwise use
            ``nn.Upsample``, which leaves the channel count unchanged.
        mode: interpolation mode passed to ``nn.Upsample``; ignored when
            ``is_deconv`` is True.
        up_kwargs: mapping of operator settings. The optional ``scale`` entry
            is the ``nn.Upsample`` scale factor (default 2); every other
            entry is forwarded to ``nn.ConvTranspose2d``. When ``None``,
            ``kernel_size=2, stride=2, padding=0, scale=2`` is used.
    """

    def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear', up_kwargs=None) -> None:
        super(UpsamplingLayer, self).__init__()

        if up_kwargs is None:
            # 2D settings: this layer only builds 2D operators.
            up_kwargs = {'kernel_size': 2, 'stride': 2, 'padding': 0, 'scale': 2}

        # Work on a copy so the caller's mapping is left untouched, and split
        # off `scale`: nn.ConvTranspose2d does not accept that keyword.
        deconv_kwargs = dict(up_kwargs)
        scale = deconv_kwargs.pop('scale', 2)

        if is_deconv:
            self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, **deconv_kwargs)
        else:
            self.upsampler = nn.Upsample(scale_factor=scale, mode=mode)

    def forward(self, x):
        """Return the upsampled version of ``x``."""
        return self.upsampler(x)

In [None]:
class Encoder(nn.Module):
    """Contracting path: five ConvBlocks with max pooling between consecutive stages.

    Args:
        args: configuration object read by attribute. Uses ``in_dim``,
            ``net_dim`` (channel width of each of the five stages),
            ``conv_kwargs`` (forwarded to every ConvBlock) and ``down_kwargs``
            (forwarded to ``nn.MaxPool2d``). When ``None``, a single-channel
            input with widths ``[64, 128, 256, 512, 1024]`` is used.
    """

    def __init__(self, args=None) -> None:
        super(Encoder, self).__init__()

        if args is None:
            args = edict()
            # net dim
            args.in_dim     = 1
            args.net_dim    = [64, 128, 256, 512, 1024]

            # net operator
            conv_kwargs = edict()
            conv_kwargs.kernel_size = 3
            conv_kwargs.stride      = 1
            conv_kwargs.padding     = 1

            down_kwargs = edict()
            down_kwargs.kernel_size = 2
            down_kwargs.stride      = 2
            down_kwargs.padding     = 0

            args.conv_kwargs = conv_kwargs
            args.down_kwargs = down_kwargs

        self.conv1 = ConvBlock(args.in_dim, args.net_dim[0], num_conv=2, conv_kwargs=args.conv_kwargs)
        self.conv2 = ConvBlock(args.net_dim[0], args.net_dim[1], num_conv=2, conv_kwargs=args.conv_kwargs)
        self.conv3 = ConvBlock(args.net_dim[1], args.net_dim[2], num_conv=2, conv_kwargs=args.conv_kwargs)
        self.conv4 = ConvBlock(args.net_dim[2], args.net_dim[3], num_conv=2, conv_kwargs=args.conv_kwargs)
        self.conv5 = ConvBlock(args.net_dim[3], args.net_dim[4], num_conv=2, conv_kwargs=args.conv_kwargs)

        # A single pooling module is shared by all stages (it has no parameters).
        self.pool  = nn.MaxPool2d(**args.down_kwargs)

    def forward(self, inputs):
        """Encode ``inputs``.

        Returns:
            A pair ``(bottleneck, stage_outputs)``: the output of the fifth
            ConvBlock, and the list of the first four ConvBlock outputs
            ordered from shallowest to deepest.
        """
        conv1_out = self.conv1(inputs)
        conv2_out = self.conv2(self.pool(conv1_out))
        conv3_out = self.conv3(self.pool(conv2_out))
        conv4_out = self.conv4(self.pool(conv3_out))

        # The bottleneck is returned un-pooled. It is one pooling step below
        # conv4_out, so a single upsampling step of the same factor brings it
        # back to the resolution of stage_outputs[-1]; pooling it again would
        # leave it one level too small to be concatenated with that feature.
        conv5_out = self.conv5(self.pool(conv4_out))

        stage_outputs = [conv1_out, conv2_out, conv3_out, conv4_out]

        return conv5_out, stage_outputs

In [None]:
class SkipPathways(nn.Module):
    """Dense skip pathway of one resolution level (nodes X_{i,1} .. X_{i,path_length}).

    Node ``j`` is a ConvLayer applied to the channel-wise concatenation of
    all earlier nodes of this level (``X_{i,0} .. X_{i,j-1}``) and the
    upsampled node ``X_{i+1,j-1}`` of the level below.

    Args:
        in_dim: channels of the same-level encoder feature ``X_{i,0}`` plus
            ``lower_dim`` (the value callers in this file pass); the width of
            ``X_{i,0}`` is recovered as ``in_dim - lower_dim``.
        out_dim: channels produced by every node of the pathway and by every
            upsampler.
        lower_dim: channels of each lower-level feature before upsampling.
        path_length: number of nodes to build.

    Raises:
        ValueError: if ``in_dim`` does not exceed ``lower_dim``.
    """

    def __init__(self, in_dim, out_dim, lower_dim, path_length=1) -> None:
        super(SkipPathways, self).__init__()

        base_dim = in_dim - lower_dim  # channels of X_{i,0}
        if base_dim <= 0:
            raise ValueError(
                "in_dim must be larger than lower_dim "
                "(in_dim = channels of xi_0 + lower_dim)"
            )

        self.path_length = path_length

        self.conv_path = nn.ModuleList()
        self.upSamplers = nn.ModuleList()

        # Explicit 2D transposed-convolution settings: doubles the resolution.
        deconv_kwargs = {'kernel_size': 2, 'stride': 2, 'padding': 0}

        for idx in range(path_length):
            # Node idx+1 sees X_{i,0}, the idx nodes built so far, and one
            # upsampled lower feature of out_dim channels.
            dense_dim = base_dim + idx * out_dim + out_dim
            self.conv_path.append(ConvLayer(dense_dim, out_dim,
                             conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU))
            self.upSamplers.append(UpsamplingLayer(lower_dim, out_dim, is_deconv=True,
                                                   up_kwargs=deconv_kwargs))

    def forward(self, xi_0, lowers=None):
        '''
        xi_0: the same level encoder output (Xi_0)
        lowers: the 1 lower level encoder and pathways outputs (X_{i+1}j),
                one entry per node of this pathway

        Returns (id_x, xi_j): id_x is the channel-wise concatenation of xi_0
        and every node built here; xi_j is the list [xi_0, Xi_1, ...].
        '''
        lowers = [] if lowers is None else lowers
        assert self.path_length == len(lowers), \
            "expected one lower-level feature per pathway node"

        xi_j = [xi_0]
        id_x = xi_0  # running concatenation of every node of this level

        for upsampler, conv, lower in zip(self.upSamplers, self.conv_path, lowers):
            upsampled_lower = upsampler(lower)
            xi_inter = conv(torch.cat((id_x, upsampled_lower), dim=1))
            xi_j.append(xi_inter)
            id_x = torch.cat((id_x, xi_inter), dim=1)

        return id_x, xi_j

In [None]:
class Decoder(nn.Module):
    """Expanding path with nested skip pathways and four deep-supervision heads.

    Args:
        args: configuration object read by attribute. Uses ``net_dim`` (the
            last five entries are the channel widths of levels 0..4),
            ``out_dim`` (channels of each prediction), ``conv_kwargs``
            (forwarded to every ConvBlock) and ``up_kwargs`` (forwarded to
            every UpsamplingLayer, minus its ``scale`` entry). When ``None``,
            widths ``[64, 128, 256, 512, 1024]`` and two output channels are
            used.
    """

    def __init__(self, args=None) -> None:
        super(Decoder, self).__init__()

        if args is None:
            args = edict()

            # net dim
            args.net_dim    = [64, 128, 256, 512, 1024]
            args.out_dim    = 2

            # net operator (2D settings: every layer built here is 2D)
            conv_kwargs = edict()
            conv_kwargs.kernel_size = 3
            conv_kwargs.stride      = 1
            conv_kwargs.padding     = 1

            up_kwargs = edict()
            up_kwargs.kernel_size = 2
            up_kwargs.stride      = 2
            up_kwargs.padding     = 0
            up_kwargs.scale       = 2

            args.conv_kwargs = conv_kwargs
            args.up_kwargs   = up_kwargs

        d0, d1, d2, d3, d4 = args.net_dim[-5:]
        conv_kwargs = args.conv_kwargs
        # `scale` only applies to interpolation; a transposed convolution
        # does not accept it as a keyword.
        deconv_kwargs = {k: v for k, v in args.up_kwargs.items() if k != 'scale'}

        # X4_0 -> X3_1 : input is [up(X4_0), X3_0]
        self.up1    = UpsamplingLayer(d4, d3, is_deconv=True, up_kwargs=deconv_kwargs)
        self.conv1  = ConvBlock(2 * d3, d3, num_conv=2, conv_kwargs=conv_kwargs)

        # X3_1 && X2_0 -> X2_2 : input is [up(X3_1), X2_0, X2_1]
        self.up2    = UpsamplingLayer(d3, d2, is_deconv=True, up_kwargs=deconv_kwargs)
        self.skip2  = SkipPathways(in_dim=d2 + d3, out_dim=d2, lower_dim=d3, path_length=1)
        self.conv2  = ConvBlock(3 * d2, d2, num_conv=2, conv_kwargs=conv_kwargs)

        # X2_2 && X1_0 -> X1_3 : input is [up(X2_2), X1_0, X1_1, X1_2]
        self.up3    = UpsamplingLayer(d2, d1, is_deconv=True, up_kwargs=deconv_kwargs)
        self.skip3  = SkipPathways(in_dim=d1 + d2, out_dim=d1, lower_dim=d2, path_length=2)
        self.conv3  = ConvBlock(4 * d1, d1, num_conv=2, conv_kwargs=conv_kwargs)

        # X1_3 && X0_0 -> X0_4 : input is [up(X1_3), X0_0, X0_1, X0_2, X0_3]
        self.up4    = UpsamplingLayer(d1, d0, is_deconv=True, up_kwargs=deconv_kwargs)
        self.skip4  = SkipPathways(in_dim=d0 + d1, out_dim=d0, lower_dim=d1, path_length=3)
        self.conv4  = ConvBlock(5 * d0, d0, num_conv=2, conv_kwargs=conv_kwargs)

        # for deep supervision: one 1x1 head per full-resolution decoder node
        self.dsp_l1 = nn.Conv2d(d0, args.out_dim, kernel_size=1, stride=1, padding=0)
        self.dsp_l2 = nn.Conv2d(d0, args.out_dim, kernel_size=1, stride=1, padding=0)
        self.dsp_l3 = nn.Conv2d(d0, args.out_dim, kernel_size=1, stride=1, padding=0)
        self.dsp_l4 = nn.Conv2d(d0, args.out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):
        """Decode the encoder features into four prediction maps.

        Args:
            enc_out: deepest encoder feature (X4_0); one upsampling step must
                bring it to the resolution of ``stage_outputs[-1]``.
            stage_outputs: encoder features X0_0..X3_0, shallowest first.

        Returns:
            ``(out_l1, out_l2, out_l3, out_l4)``: predictions taken from the
            full-resolution nodes X0_1, X0_2, X0_3 and X0_4.
        """
        x3_0 = stage_outputs[-1]
        x3_1 = self.conv1(torch.cat((self.up1(enc_out), x3_0), dim=1))

        # Each skip pathway returns the concatenation of all its nodes (fed
        # to the last node of the level) and the list of individual nodes
        # (fed to the pathway of the level above).
        x2_0 = stage_outputs[-2]
        x2_dense, x2_j = self.skip2(x2_0, [x3_0])
        x2_2 = self.conv2(torch.cat((self.up2(x3_1), x2_dense), dim=1))

        x1_0 = stage_outputs[-3]
        x1_dense, x1_j = self.skip3(x1_0, x2_j)
        x1_3 = self.conv3(torch.cat((self.up3(x2_2), x1_dense), dim=1))

        x0_0 = stage_outputs[-4]
        x0_dense, x0_j = self.skip4(x0_0, x1_j)
        x0_4 = self.conv4(torch.cat((self.up4(x1_3), x0_dense), dim=1))

        # x0_j[0] is the raw encoder feature X0_0; supervision starts at X0_1.
        out_l1 = self.dsp_l1(x0_j[1])
        out_l2 = self.dsp_l2(x0_j[2])
        out_l3 = self.dsp_l3(x0_j[3])
        out_l4 = self.dsp_l4(x0_4)

        return out_l1, out_l2, out_l3, out_l4

In [None]:
class UNetpp(nn.Module):
    """UNet++ model: an Encoder followed by a nested-skip Decoder.

    Args:
        args: optional configuration object handed unchanged to both the
            Encoder and the Decoder; ``None`` lets each use its own defaults.
    """

    # The base class must be nn.Module; `nn.modules` is a package, and
    # subclassing it fails as soon as the class statement is executed.
    def __init__(self, args=None) -> None:
        super(UNetpp, self).__init__()

        self.encoder = Encoder(args)
        self.decoder = Decoder(args)

    def forward(self, inputs):
        """Return the four deep-supervision predictions for ``inputs``."""
        enc_out, stage_outputs = self.encoder(inputs)

        out_l1, out_l2, out_l3, out_l4 = self.decoder(enc_out, stage_outputs)

        # The four outputs are returned separately so the caller decides how
        # to combine them (e.g. averaging them for inference).
        return out_l1, out_l2, out_l3, out_l4
