**U-Net: Convolutional Networks for Biomedical Image Segmentation**   
*Olaf Ronneberger, Philipp Fischer, Thomas Brox*   
[[arXiv]](https://arxiv.org/abs/1505.04597)
MICCAI 2015 

In [2]:
import torch
import torch.nn as nn



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class ConvBlock(nn.Module):
    """Conv -> norm -> (optional) dropout -> LeakyReLU building block.

    The convolution uses ``kernel_size=3``, ``stride=1``, ``padding=1`` and no
    bias, so the spatial size of the input is preserved (2D or 3D depending on
    ``conv_class``).

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        conv_class: convolution class, e.g. ``nn.Conv2d`` or ``nn.Conv3d``.
        dropout_class: dropout class, e.g. ``nn.Dropout2d`` or ``nn.Dropout3d``.
        drop_p: dropout probability; no dropout layer is created when
            ``drop_p <= 0``.
    """

    def __init__(self, in_dim, out_dim, conv_class, dropout_class, drop_p=0.0) -> None:
        super(ConvBlock, self).__init__()

        self.drop_p = drop_p
        self.conv_layer = conv_class(in_channels=in_dim, out_channels=out_dim, kernel_size=3, stride=1, padding=1, bias=False)
        # GroupNorm with a single group normalizes each sample over
        # (C, *spatial): a layer norm that works for any number of spatial dims.
        # nn.LayerNorm(out_dim) would normalize the LAST (spatial) axis instead
        # and fails whenever that axis is not exactly out_dim long.
        self.norm = nn.GroupNorm(num_groups=1, num_channels=out_dim)
        self.act = nn.LeakyReLU()

        self.drop = dropout_class(p=drop_p) if drop_p > 0 else None

        self._init_layer_weights()

    def _init_layer_weights(self):
        """Xavier-uniform initialize the convolution weight.

        Only the convolution is initialized: ``self`` and the activation have
        no ``weight``, and Xavier init is undefined for the 1-D norm weight
        (fan-in/fan-out need a tensor with at least 2 dims).
        """
        nn.init.xavier_uniform_(self.conv_layer.weight)

    def forward(self, x):
        x = self.conv_layer(x)
        x = self.norm(x)

        if self.drop is not None:
            x = self.drop(x)

        x = self.act(x)

        return x

In [None]:
from typing import Union, List, Tuple


class UNet_encoder(nn.Module):
    """Contracting path built from five stages of two ConvBlocks each.

    Stage 1 maps ``in_dim`` to ``hidden_dims[0]`` channels and stage ``k``
    (k = 2..5) maps ``hidden_dims[k-2]`` to ``hidden_dims[k-1]``. A max-pool
    with kernel 2 and stride 2 runs before every stage except the first.

    Args:
        in_dim: number of input channels.
        hidden_dims: channel widths; entries 0..4 are used.
        spatial_dim: 3 builds 3D layers, any other value builds 2D layers.
        drop_p: dropout probability handed to every ConvBlock.
    """

    def __init__(self, in_dim, hidden_dims:Union[List, Tuple], spatial_dim, drop_p=0.0) -> None:
        super(UNet_encoder, self).__init__()

        use_3d = spatial_dim == 3
        conv_cls = nn.Conv3d if use_3d else nn.Conv2d
        drop_cls = nn.Dropout3d if use_3d else nn.Dropout2d
        pool_cls = nn.MaxPool3d if use_3d else nn.MaxPool2d

        def two_blocks(c_in, c_out):
            # c_in -> c_out, then c_out -> c_out
            first = ConvBlock(in_dim=c_in, out_dim=c_out, conv_class=conv_cls, dropout_class=drop_cls, drop_p=drop_p)
            second = ConvBlock(in_dim=c_out, out_dim=c_out, conv_class=conv_cls, dropout_class=drop_cls, drop_p=drop_p)
            return nn.Sequential(first, second)

        self.pool = pool_cls(kernel_size=2, stride=2)
        self.layer1 = two_blocks(in_dim, hidden_dims[0])
        self.layer2 = two_blocks(hidden_dims[0], hidden_dims[1])
        self.layer3 = two_blocks(hidden_dims[1], hidden_dims[2])
        self.layer4 = two_blocks(hidden_dims[2], hidden_dims[3])
        self.layer5 = two_blocks(hidden_dims[3], hidden_dims[4])

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(bottleneck, skips)`` where ``skips`` maps ``'stage1'`` ..
            ``'stage4'`` to each stage's output (taken before the next pool)
            and ``bottleneck`` is the output of the fifth stage.
        """
        skips = {}
        named_stages = (
            ('stage1', self.layer1),
            ('stage2', self.layer2),
            ('stage3', self.layer3),
            ('stage4', self.layer4),
        )
        for position, (key, stage) in enumerate(named_stages):
            if position > 0:
                x = self.pool(x)
            x = stage(x)
            skips[key] = x

        bottleneck = self.layer5(self.pool(x))
        return bottleneck, skips

In [None]:
from typing import Union, List, Tuple


class UNet_decoder(nn.Module):
    """Expanding path that mirrors ``UNet_encoder``.

    Each of the four stages up-samples with a kernel-2, stride-2 transposed
    convolution, concatenates the matching encoder skip tensor on the channel
    axis and applies two ConvBlocks. A final 1x1 convolution maps
    ``hidden_dims[0]`` channels to ``out_dim``.

    Args:
        out_dim: number of output channels.
        hidden_dims: channel widths of the encoder stages; entries 0..4 are used.
        spatial_dim: 3 builds 3D layers, any other value builds 2D layers.
        drop_p: dropout probability forwarded to every ConvBlock.
    """

    def __init__(self, out_dim, hidden_dims:Union[List, Tuple], spatial_dim, drop_p=0.0) -> None:
        super(UNet_decoder, self).__init__()

        if spatial_dim == 3:
            conv_class = nn.Conv3d
            dropout_class = nn.Dropout3d
            upconv_class = nn.ConvTranspose3d
        else:
            conv_class = nn.Conv2d
            dropout_class = nn.Dropout2d
            upconv_class = nn.ConvTranspose2d

        def double_conv(dim):
            # The first block sees [up-sampled, skip] concatenated: 2 * dim channels.
            return nn.Sequential(
                ConvBlock(in_dim=dim * 2, out_dim=dim, conv_class=conv_class, dropout_class=dropout_class, drop_p=drop_p),
                ConvBlock(in_dim=dim, out_dim=dim, conv_class=conv_class, dropout_class=dropout_class, drop_p=drop_p),
            )

        self.upconv1 = upconv_class(in_channels=hidden_dims[4], out_channels=hidden_dims[3], kernel_size=2, stride=2)
        self.layer1 = double_conv(hidden_dims[3])

        self.upconv2 = upconv_class(in_channels=hidden_dims[3], out_channels=hidden_dims[2], kernel_size=2, stride=2)
        self.layer2 = double_conv(hidden_dims[2])

        self.upconv3 = upconv_class(in_channels=hidden_dims[2], out_channels=hidden_dims[1], kernel_size=2, stride=2)
        self.layer3 = double_conv(hidden_dims[1])

        self.upconv4 = upconv_class(in_channels=hidden_dims[1], out_channels=hidden_dims[0], kernel_size=2, stride=2)
        self.layer4 = double_conv(hidden_dims[0])

        self.fc = conv_class(in_channels=hidden_dims[0], out_channels=out_dim, kernel_size=1, stride=1)

    @staticmethod
    def _match_spatial(h, skip):
        """Zero-pad (or crop, for negative differences) ``h`` to ``skip``'s spatial shape.

        Pooling floors odd sizes in the encoder, so after a stride-2 up-sampling
        ``h`` can be one voxel short of the skip tensor along some axes; without
        this, ``torch.cat`` fails for inputs whose size is not a multiple of 16.
        Returns ``h`` unchanged when the spatial shapes already agree.
        """
        diffs = [s - t for s, t in zip(skip.shape[2:], h.shape[2:])]
        if not any(diffs):
            return h
        pad = []
        for d in reversed(diffs):  # F.pad lists the last dimension first
            pad.extend((d // 2, d - d // 2))
        return nn.functional.pad(h, pad)

    def forward(self, h, stage_outputs):
        """Decode bottleneck ``h`` using the encoder's ``stage_outputs`` dict.

        ``stage_outputs`` must provide the keys ``'stage1'`` .. ``'stage4'``.
        """
        steps = (
            (self.upconv1, self.layer1, 'stage4'),
            (self.upconv2, self.layer2, 'stage3'),
            (self.upconv3, self.layer3, 'stage2'),
            (self.upconv4, self.layer4, 'stage1'),
        )
        for upconv, layer, key in steps:
            skip = stage_outputs[key]
            h = self._match_spatial(upconv(h), skip)
            h = torch.cat([h, skip], dim=1)
            h = layer(h)

        return self.fc(h)

In [None]:
class UNet(nn.Module):
    """U-Net composed of ``UNet_encoder`` and ``UNet_decoder``.

    Args:
        input_dim: number of input channels.
        out_dim: number of output channels of the decoder's final 1x1 conv.
        hidden_dims: channel widths of the five encoder stages.
        spatial_dim: 2 builds 2D layers, 3 builds 3D layers.
        dropout_p: dropout probability forwarded to every ConvBlock.

    Raises:
        ValueError: if ``spatial_dim`` is not 2 or 3, or ``hidden_dims`` has
            fewer than five entries.
    """

    def __init__(self, input_dim, out_dim, hidden_dims:Union[Tuple, List], spatial_dim, dropout_p=0.0) -> None:
        super(UNet, self).__init__()
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if spatial_dim not in (2, 3):
            raise ValueError(f"spatial_dim must be 2 or 3, got {spatial_dim!r}")
        if hidden_dims is None or len(hidden_dims) < 5:
            raise ValueError("hidden_dims must provide at least 5 channel widths")

        self.encoder = UNet_encoder(in_dim=input_dim, hidden_dims=hidden_dims, spatial_dim=spatial_dim, drop_p=dropout_p)
        self.decoder = UNet_decoder(out_dim=out_dim, hidden_dims=hidden_dims, spatial_dim=spatial_dim, drop_p=dropout_p)

    def forward(self, x):
        """Encode ``x``, then decode it using the encoder's skip tensors."""
        enc_out, stage_outputs = self.encoder(x)
        return self.decoder(enc_out, stage_outputs)

In [50]:
# encoder
class ContractingPath(nn.Module):
    """Contracting (encoder) path of the 2D U-Net from the paper.

    Five stages, each made of two unpadded 3x3 convolutions followed by ReLU.
    Stages 2-5 start with a 2x2 max-pool (stride 2). Every unpadded
    convolution shrinks each spatial side by 2, so a 572x572 input yields a
    28x28 output.

    Args:
        args: object with attributes ``input_dim`` (number of input channels)
            and ``net_dim`` (the five stage widths). When ``None``, defaults to
            ``input_dim=1`` (single-channel / gray-scale input) and
            ``net_dim=[64, 128, 256, 512, 1024]``.
    """

    def __init__(self, args=None) -> None:
        super(ContractingPath, self).__init__()

        if args is None:
            # ``edict`` (easydict) was never imported; SimpleNamespace provides
            # the same attribute-style access from the standard library.
            from types import SimpleNamespace
            args = SimpleNamespace(input_dim=1, net_dim=[64, 128, 256, 512, 1024])

        dims = args.net_dim
        self.conv1 = self._stage(args.input_dim, dims[0], pool=False)
        self.conv2 = self._stage(dims[0], dims[1])
        self.conv3 = self._stage(dims[1], dims[2])
        self.conv4 = self._stage(dims[2], dims[3])
        self.conv5 = self._stage(dims[3], dims[4])

    @staticmethod
    def _stage(in_ch, out_ch, pool=True):
        """Build ``[MaxPool2d(2, 2)] + (Conv2d 3x3, padding=0 -> ReLU) x 2``."""
        # padding=0: the paper uses unpadded ("valid") convolutions, not zero padding.
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        layers += [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the contracting path.

        Returns:
            ``(h5, [h1, h2, h3, h4])``: the fifth stage's output and the outputs
            of the first four stages, which serve as skip connections (the
            fifth stage's output is not part of the list).
        """
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        h5 = self.conv5(h4)

        layer_outputs = [h1, h2, h3, h4]

        return h5, layer_outputs


In [52]:
# decoder

class ExpansivePath(nn.Module):
    """Expansive (decoder) path of the 2D U-Net from the paper.

    Each of the four stages up-samples with a 2x2, stride-2 transposed
    convolution ("up-convolution") that halves the channel count. It then
    concatenates the centre-cropped skip map from the contracting path along
    the channel axis and applies two unpadded 3x3 convolutions with ReLU. A
    final 1x1 convolution maps ``net_dim[-5]`` channels to ``out_dim`` channels.

    Args:
        args: object with attributes ``out_dim`` and ``net_dim`` (the five
            stage widths). When ``None``, defaults to ``out_dim=2`` and
            ``net_dim=[64, 128, 256, 512, 1024]``.
    """

    def __init__(self, args=None) -> None:
        super(ExpansivePath, self).__init__()

        if args is None:
            # ``edict`` (easydict) was never imported; SimpleNamespace provides
            # the same attribute-style access from the standard library.
            from types import SimpleNamespace
            args = SimpleNamespace(out_dim=2, net_dim=[64, 128, 256, 512, 1024])

        d = args.net_dim
        # Each conv stage sees [up-conv output, cropped skip] concatenated.
        # Both halves are assumed to have the same width (true when the skips
        # come from a ContractingPath built with the same net_dim), hence the
        # 2 * width input channels.
        self.upConv1 = nn.ConvTranspose2d(d[-1], d[-2], kernel_size=2, stride=2)
        self.conv1 = self._double_conv(2 * d[-2], d[-2])

        self.upConv2 = nn.ConvTranspose2d(d[-2], d[-3], kernel_size=2, stride=2)
        self.conv2 = self._double_conv(2 * d[-3], d[-3])

        self.upConv3 = nn.ConvTranspose2d(d[-3], d[-4], kernel_size=2, stride=2)
        self.conv3 = self._double_conv(2 * d[-4], d[-4])

        self.upConv4 = nn.ConvTranspose2d(d[-4], d[-5], kernel_size=2, stride=2)
        self.conv4 = self._double_conv(2 * d[-5], d[-5])

        self.outConv = nn.Conv2d(d[-5], args.out_dim, kernel_size=1, stride=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Two (unpadded 3x3 Conv2d -> ReLU) pairs; each conv trims 2 from every side length."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

    @staticmethod
    def _center_crop(skip, ref):
        """Crop the last two dims of ``skip`` around its centre to ``ref``'s height/width."""
        th, tw = ref.shape[-2:]
        sh, sw = skip.shape[-2:]
        top = (sh - th) // 2
        left = (sw - tw) // 2
        return skip[..., top:top + th, left:left + tw]

    def forward(self, enc_out, layer_outputs):
        """Decode ``enc_out`` using the contracting path's skip maps.

        Args:
            enc_out: bottleneck features from the contracting path.
            layer_outputs: skip maps; ``layer_outputs[-1]`` feeds the first
                decoder stage and ``layer_outputs[-4]`` the last.

        Returns:
            Output of the final 1x1 convolution. Because all convolutions are
            unpadded it is smaller than the network input (388x388 for a
            572x572 input).
        """
        stages = (
            (self.upConv1, self.conv1, layer_outputs[-1]),
            (self.upConv2, self.conv2, layer_outputs[-2]),
            (self.upConv3, self.conv3, layer_outputs[-3]),
            (self.upConv4, self.conv4, layer_outputs[-4]),
        )
        h = enc_out
        for up_conv, conv, skip in stages:
            h = up_conv(h)
            # The skip map is larger than h (unpadded convs); crop it to h's size.
            h = torch.cat((h, self._center_crop(skip, h)), dim=1)
            h = conv(h)

        return self.outConv(h)



In [53]:
class Unet(nn.Module):
    """Encoder/decoder pair: a ``ContractingPath`` followed by an ``ExpansivePath``.

    Args:
        args: configuration forwarded unchanged to both sub-networks; ``None``
            lets each of them fall back to its own defaults.
    """

    def __init__(self, args=None) -> None:
        super().__init__()
        self.encoder = ContractingPath(args)
        self.decoder = ExpansivePath(args)

    def forward(self, x):
        # Unpack the encoder's (output, skip list) pair straight into the decoder.
        return self.decoder(*self.encoder(x))