**U-Net: Convolutional Networks for Biomedical Image Segmentation**   
*Olaf Ronneberger, Philipp Fischer, Thomas Brox*   
[[arXiv]](https://arxiv.org/abs/1505.04597)   
MICCAI 2015 

In [54]:
import torch
import torch.nn as nn

from easydict import EasyDict as edict

# Hyper-parameter container shared by the encoder and decoder definitions.
args = edict()

# Network dimensions.
args.in_dim = 1       # channels of the input image
args.init_dim = 64    # width of the first encoder stage
args.enc_depth = 5    # number of encoder stages
# Channel width doubles at every stage: [64, 128, 256, 512, 1024]
args.net_dim = [args.init_dim << level for level in range(args.enc_depth)]
args.out_dim = 2      # channels of the output map


In [50]:
# encoder
class ContractingPath(nn.Module):
    """Contracting (encoder) path of U-Net.

    Five stages of two unpadded 3x3 convolutions, each followed by ReLU.
    Stages 2-5 begin with a 2x2 max pooling that halves the spatial size.
    Because the convolutions are unpadded, every stage trims 4 pixels from
    each spatial dimension (e.g. a 572x572 input yields a 28x28 output).

    Args:
        args: Optional configuration object exposing
            ``net_dim`` (sequence of at least five channel widths) and the
            number of input channels as ``input_dim`` or, alternatively,
            ``in_dim``. When ``None``, the defaults are a 1-channel input and
            ``net_dim = [64, 128, 256, 512, 1024]``.

    Raises:
        AttributeError: If ``args`` is given but defines neither
            ``input_dim`` nor ``in_dim``.
    """

    def __init__(self, args=None) -> None:
        super(ContractingPath, self).__init__()

        if args is None:
            in_dim = 1  # gray-scale image
            net_dim = [64, 128, 256, 512, 1024]
        else:
            in_dim = self._read_input_dim(args)
            net_dim = args.net_dim

        # Attribute names and the layer order inside each Sequential are kept
        # stable so the parameter names in ``state_dict()`` do not change.
        self.conv1 = self._stage(in_dim, net_dim[0], pool=False)
        self.conv2 = self._stage(net_dim[0], net_dim[1], pool=True)
        self.conv3 = self._stage(net_dim[1], net_dim[2], pool=True)
        self.conv4 = self._stage(net_dim[2], net_dim[3], pool=True)
        self.conv5 = self._stage(net_dim[3], net_dim[4], pool=True)

    @staticmethod
    def _read_input_dim(args):
        """Return the input channel count, accepting either config spelling."""
        # ``input_dim`` is checked first so existing configs behave as before;
        # ``in_dim`` is the alternative spelling accepted as a fallback.
        for key in ("input_dim", "in_dim"):
            if hasattr(args, key):
                return getattr(args, key)
        raise AttributeError(
            "args must define 'input_dim' (or 'in_dim') for ContractingPath"
        )

    @staticmethod
    def _stage(in_ch, out_ch, pool: bool) -> nn.Sequential:
        """Build one stage: optional 2x2 max pool + two unpadded conv3x3/ReLU."""
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        layers += [
            # padding=0: the contracting path uses unpadded convolutions
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the encoder.

        Args:
            x: Input tensor; its channel dimension must match the configured
                input channel count.

        Returns:
            Tuple ``(h5, layer_outputs)`` where ``h5`` is the output of the
            fifth stage and ``layer_outputs`` is ``[h1, h2, h3, h4]``, the
            outputs of the first four stages in order.
        """
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        h5 = self.conv5(h4)

        # The output of every double-conv stage except the fifth is returned
        # alongside h5 so a decoder can reuse it.
        layer_outputs = [h1, h2, h3, h4]

        return h5, layer_outputs


In [52]:
# decoder

class ExpansivePath(nn.Module):
    """Expansive (decoder) path of U-Net.

    Each of the four steps applies a 2x2 transposed convolution
    ("up-convolution") that doubles the spatial size and reduces the channel
    count, concatenates a centre-cropped encoder feature map along the
    channel dimension, and applies two unpadded 3x3 convolutions with ReLU.
    A final 1x1 convolution maps to ``out_dim`` channels.

    Args:
        args: Optional configuration object exposing ``out_dim`` (number of
            output channels) and ``net_dim`` (sequence of at least five
            channel widths, narrowest first). When ``None``, the defaults are
            ``out_dim = 2`` and ``net_dim = [64, 128, 256, 512, 1024]``.
    """

    def __init__(self, args=None) -> None:
        super(ExpansivePath, self).__init__()

        if args is None:
            out_dim = 2
            net_dim = [64, 128, 256, 512, 1024]
        else:
            out_dim = args.out_dim
            net_dim = args.net_dim

        # Attribute names and creation order are kept stable so parameter
        # names in ``state_dict()`` and seeded initialisation do not change.
        self.upConv1 = nn.ConvTranspose2d(net_dim[-1], net_dim[-2], kernel_size=2, stride=2)
        self.conv1 = self._fuse(net_dim[-2])

        self.upConv2 = nn.ConvTranspose2d(net_dim[-2], net_dim[-3], kernel_size=2, stride=2)
        self.conv2 = self._fuse(net_dim[-3])

        # Output width follows net_dim instead of a hard-coded 128.
        self.upConv3 = nn.ConvTranspose2d(net_dim[-3], net_dim[-4], kernel_size=2, stride=2)
        self.conv3 = self._fuse(net_dim[-4])

        self.upConv4 = nn.ConvTranspose2d(net_dim[-4], net_dim[-5], kernel_size=2, stride=2)
        self.conv4 = self._fuse(net_dim[-5])

        self.outConv = nn.Conv2d(net_dim[-5], out_dim, kernel_size=1, stride=1)

    @staticmethod
    def _fuse(width) -> nn.Sequential:
        """Build the double conv3x3/ReLU applied after a skip concatenation."""
        # The concatenated input carries the up-convolved features plus the
        # skip features, i.e. twice ``width`` channels (equal to the next
        # wider net_dim entry when widths double at every level).
        return nn.Sequential(
            nn.Conv2d(2 * width, width, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

    @staticmethod
    def _center_crop(skip, target):
        """Crop the last two dims of ``skip`` around the centre to match ``target``."""
        target_h, target_w = target.shape[-2:]
        skip_h, skip_w = skip.shape[-2:]
        if skip_h < target_h or skip_w < target_w:
            raise ValueError(
                f"skip feature map {skip_h}x{skip_w} is smaller than the "
                f"upsampled map {target_h}x{target_w}; cannot crop"
            )
        top = (skip_h - target_h) // 2
        left = (skip_w - target_w) // 2
        return skip[..., top:top + target_h, left:left + target_w]

    def forward(self, enc_out, layer_outputs):
        """Decode the bottleneck features into the output map.

        Args:
            enc_out: Bottleneck tensor with ``net_dim[-1]`` channels.
            layer_outputs: Sequence of encoder feature maps ordered from the
                shallowest to the deepest; the last four entries are consumed,
                deepest first.

        Returns:
            Tensor with ``out_dim`` channels produced by the final 1x1
            convolution.

        Raises:
            ValueError: If a skip feature map is spatially smaller than the
                upsampled map it must be concatenated with.
        """
        steps = (
            (self.upConv1, self.conv1),
            (self.upConv2, self.conv2),
            (self.upConv3, self.conv3),
            (self.upConv4, self.conv4),
        )

        h = enc_out
        for depth, (up_conv, fuse) in enumerate(steps, start=1):
            h = up_conv(h)
            # Unpadded convolutions make the encoder maps larger than the
            # upsampled ones, so the skip is centre-cropped before concat.
            cropped = self._center_crop(layer_outputs[-depth], h)
            h = torch.cat((h, cropped), dim=1)
            h = fuse(h)

        output = self.outConv(h)

        return output



In [53]:
class Unet(nn.Module):
    """U-Net assembled from a ContractingPath encoder and an ExpansivePath decoder.

    Args:
        args: Optional configuration object, passed unchanged to both the
            encoder and the decoder constructors.
    """

    def __init__(self, args=None) -> None:
        super().__init__()

        # Both halves receive the very same configuration object.
        self.encoder = ContractingPath(args)
        self.decoder = ExpansivePath(args)

    def forward(self, x):
        """Encode ``x``, then decode the result together with the encoder's intermediate outputs."""
        bottleneck, skips = self.encoder(x)
        return self.decoder(bottleneck, skips)