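### U-Net
Ronneberger, Fischer & Brox (2015), "U-Net: Convolutional Networks for Biomedical Image Segmentation": https://arxiv.org/pdf/1505.04597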

In [30]:
import torch
import torch.nn as nn
import torchvision

In [31]:
def doubleConv(in_channels, out_channels):
    """Return two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    Layer order: Conv2d(in_channels -> out_channels), ReLU,
    Conv2d(out_channels -> out_channels), ReLU. With no padding and stride 1,
    each convolution removes 2 pixels from height and from width, so the whole
    block shrinks the spatial size by 4.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the four layers.
    """
    layers = []
    # The first conv reads `in_channels`; the second reads the first's output.
    for source_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(source_channels, out_channels, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def crop(tensor, target_tensor):
    """Center-crop ``tensor`` to the height and width of ``target_tensor``.

    Used to trim an encoder feature map so it can be concatenated with the
    (smaller) upsampled decoder feature map.

    Args:
        tensor: Tensor to crop; its last two dims are treated as (H, W).
        target_tensor: 4-D tensor whose last two dims give the target size
            (its size is unpacked into exactly four values).

    Returns:
        The result of ``torchvision.transforms.functional.center_crop``
        applied with the target (H, W).
    """
    _, _, target_h, target_w = target_tensor.size()
    center_crop = torchvision.transforms.functional.center_crop
    return center_crop(tensor, [target_h, target_w])


In [34]:
class UNet(nn.Module):
    """U-Net for dense (per-pixel) prediction, after Ronneberger et al. (2015).

    Encoder: four levels of ``doubleConv`` followed by 2x2 max-pooling, with
    the channel count going 64 -> 128 -> 256 -> 512, then a 1024-channel
    bottleneck ``doubleConv``.
    Decoder: four levels of 2x2 stride-2 transposed convolution (halving the
    channels), concatenation with the center-cropped encoder features of the
    same level, and ``doubleConv``.
    A final 1x1 convolution maps the 64 decoder channels to ``num_classes``
    output channels.

    Because ``doubleConv`` uses unpadded 3x3 convolutions, the output is
    spatially smaller than the input. With the default arguments, a
    (1, 1, 572, 572) input gives a (1, 2, 388, 388) output.

    Args:
        in_channels: Channels of the input image. Defaults to 1.
        num_classes: Channels of the output map. Defaults to 2.
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()

        self.maxpool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path.
        self.down_conv_1 = doubleConv(in_channels, 64)
        self.down_conv_2 = doubleConv(64, 128)
        self.down_conv_3 = doubleConv(128, 256)
        self.down_conv_4 = doubleConv(256, 512)
        self.down_conv_5 = doubleConv(512, 1024)  # bottleneck

        # Expanding path. Each up_conv_k consumes the concatenation of
        # up_trans_k's output and an equally wide skip map, so its input
        # width is twice its output width.
        self.up_conv_1 = doubleConv(1024, 512)
        self.up_conv_2 = doubleConv(512, 256)
        self.up_conv_3 = doubleConv(256, 128)
        self.up_conv_4 = doubleConv(128, 64)

        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        # 1x1 convolution to class scores, as in the paper. Use nn.Conv2d, not
        # a stride-1 ConvTranspose2d: PyTorch's default init takes fan-in from
        # weight dim 1, which for a transposed conv is the *output* channel
        # count, so that form starts with weights sized for 2 inputs instead
        # of 64. (The stored `out.weight` shape is (num_classes, 64, 1, 1).)
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input batch, presumably (N, in_channels, H, W). Feature maps
                must be 4-D for ``crop``. Sizes such as 572 (the paper's tile
                size) keep every pooled map even-sized. Otherwise MaxPool2d
                floors the size and the skip crops no longer line up exactly.

        Returns:
            Raw per-pixel scores (no activation applied) with ``num_classes``
            channels and a spatial size smaller than the input's.
        """
        # Encoder: keep each level's pre-pooling features for the skips.
        skip1 = self.down_conv_1(x)
        skip2 = self.down_conv_2(self.maxpool_2x2(skip1))
        skip3 = self.down_conv_3(self.maxpool_2x2(skip2))
        skip4 = self.down_conv_4(self.maxpool_2x2(skip3))
        bottleneck = self.down_conv_5(self.maxpool_2x2(skip4))

        # Decoder: walk back up, pairing each level with its encoder skip.
        y = self._decode_level(bottleneck, skip4, self.up_trans_1, self.up_conv_1)
        y = self._decode_level(y, skip3, self.up_trans_2, self.up_conv_2)
        y = self._decode_level(y, skip2, self.up_trans_3, self.up_conv_3)
        y = self._decode_level(y, skip1, self.up_trans_4, self.up_conv_4)

        return self.out(y)

    @staticmethod
    def _decode_level(x, skip, up_trans, up_conv):
        """Up-convolve ``x``, append the center-cropped ``skip``, then double-conv."""
        upsampled = up_trans(x)
        # The encoder map is larger (unpadded convs), so crop it to match;
        # the upsampled features come first in the channel concatenation.
        merged = torch.cat([upsampled, crop(skip, upsampled)], dim=1)
        return up_conv(merged)

In [35]:
# Smoke test: push one random single-channel 572x572 tile through a freshly
# constructed (untrained) network and print the resulting output shape.
image = torch.rand(1, 1, 572, 572)
model = UNet()
preds = model(image)
print(preds.size())

torch.Size([1, 2, 388, 388])
