**U-Net: Convolutional Networks for Biomedical Image Segmentation**   
*Olaf Ronneberger, Philipp Fischer, Thomas Brox*   
[[arXiv]](https://arxiv.org/abs/1505.04597)
MICCAI 2015 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [7]:
# First U-Net stage as a standalone module: two 3x3 convolutions, each
# followed by a ReLU. padding=0 means the convolutions are unpadded, so every
# convolution trims one pixel from each border (H x W -> (H-2) x (W-2)).
_stage1_layers = [
    nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0),   # 1 input channel -> 64 feature maps
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),  # 64 -> 64 feature maps
    nn.ReLU(),
]
conv1 = nn.Sequential(*_stage1_layers)



Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)


In [50]:
# encoder
class ContractingPath(nn.Module):
    """Contracting (encoder) half of U-Net.

    Five stages, each made of two unpadded 3x3 convolutions with ReLU.
    Stages 2-5 start with a 2x2, stride-2 max pooling. The channel count goes
    1 -> 64 -> 128 -> 256 -> 512 -> 1024. Because the convolutions are
    unpadded, every stage trims 4 pixels from each spatial dimension; a
    572x572 input yields a 28x28 map after the fifth stage.

    Args:
        args: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, args=None) -> None:
        super().__init__()

        # Stage 1 works on the single-channel (grayscale) input directly.
        self.conv1 = nn.Sequential(*self._double_conv(1, 64))

        # Stages 2-5: halve the resolution, then double the channel count.
        self.conv2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self._double_conv(64, 128),
        )
        self.conv3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self._double_conv(128, 256),
        )
        self.conv4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self._double_conv(256, 512),
        )
        self.conv5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self._double_conv(512, 1024),
        )

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return the layers of one stage: (3x3 unpadded conv, ReLU) twice."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Run the encoder.

        Args:
            x: Input tensor with one channel, shaped (batch, 1, H, W).

        Returns:
            A pair ``(h5, layer_outputs)``: the output of the fifth stage and
            a list with the outputs of stages 1-4 in that order. The fifth
            stage output is not part of the list.
        """
        layer_outputs = []
        h = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = stage(h)
            layer_outputs.append(h)

        return self.conv5(h), layer_outputs


In [52]:
# decoder

class ExpansivePath(nn.Module):
    """Expansive (decoder) half of U-Net.

    Four stages. Each stage applies a 2x2, stride-2 transposed convolution
    that doubles the spatial size and halves the channel count, concatenates
    the centre-cropped encoder feature map of the matching depth along the
    channel axis, and applies two unpadded 3x3 convolutions with ReLU. A
    final 1x1 convolution maps the 64 remaining channels to 2 output channels.
    """

    def __init__(self) -> None:
        super(ExpansivePath, self).__init__()

        # "Up-convolution": upsample by 2 and halve the number of channels.
        # The following double convolution takes twice as many input channels
        # because the encoder skip connection is concatenated first.
        self.upConv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

        self.upConv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

        self.upConv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

        self.upConv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

        # 1x1 convolution producing the 2 output channels.
        self.outConv = nn.Conv2d(64, 2, kernel_size=1, stride=1)

    @staticmethod
    def _center_crop(feature_map, height, width):
        """Return the central ``height`` x ``width`` window of ``feature_map``.

        Raises:
            ValueError: If ``feature_map`` is spatially smaller than the
                requested window.
        """
        top = (feature_map.shape[-2] - height) // 2
        left = (feature_map.shape[-1] - width) // 2
        if top < 0 or left < 0:
            raise ValueError(
                "skip feature map of size "
                f"{tuple(feature_map.shape[-2:])} is smaller than the "
                f"upsampled map of size {(height, width)}"
            )
        # The end index is start + size; using the size itself as the end
        # index would yield a window that is too small by `2 * start`.
        return feature_map[..., top:top + height, left:left + width]

    def forward(self, enc_out, layer_outputs):
        """Run the decoder.

        Args:
            enc_out: Deepest encoder feature map with 1024 channels.
            layer_outputs: Sequence of encoder feature maps ordered from
                shallowest to deepest; the last four entries are consumed,
                starting with the last one.

        Returns:
            Tensor with 2 channels: the raw output of the final 1x1
            convolution (no activation is applied).

        Raises:
            ValueError: If a skip feature map is spatially smaller than the
                upsampled map it has to be concatenated with.
        """
        stages = (
            (self.upConv1, self.conv1),
            (self.upConv2, self.conv2),
            (self.upConv3, self.conv3),
            (self.upConv4, self.conv4),
        )

        h = enc_out
        for depth, (up_conv, double_conv) in enumerate(stages, start=1):
            h = up_conv(h)
            # The encoder maps are larger than the upsampled ones because the
            # convolutions are unpadded, so crop them to the same size first.
            cropped = self._center_crop(
                layer_outputs[-depth], h.shape[-2], h.shape[-1]
            )
            h = torch.cat((h, cropped), dim=1)
            h = double_conv(h)

        output = self.outConv(h)

        return output



In [53]:
class Unet(nn.Module):
    """U-Net built from a ``ContractingPath`` encoder and an ``ExpansivePath`` decoder."""

    def __init__(self) -> None:
        super().__init__()

        self.encoder = ContractingPath()
        self.decoder = ExpansivePath()

    def forward(self, x):
        """Encode ``x``, then decode it using the encoder's intermediate outputs.

        Args:
            x: Input tensor passed unchanged to the encoder.

        Returns:
            Whatever the decoder returns for the encoder's final output and
            its list of intermediate feature maps.
        """
        bottleneck, skip_maps = self.encoder(x)
        return self.decoder(bottleneck, skip_maps)