In [1]:
import torch
import torch.nn as nn
from torch import optim
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize, RandomRotation, Resize
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import math

In [2]:
class DoubleConv(nn.Module):
    """Two stacked unpadded ("valid") convolutions, each followed by ReLU then BatchNorm.

    Because no padding is used, each convolution trims ``kernel_size - 1``
    pixels from each spatial dimension, so the block shrinks H and W by
    ``2 * (kernel_size - 1)`` in total (4 pixels for the default 3x3 kernel).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
        kernel_size: square kernel size shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
            ]
        # Kept as a single Sequential named `conv` so state_dict keys are conv.0 ... conv.5.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the (conv -> ReLU -> BN) x 2 stack on ``x``."""
        features = self.conv(x)
        return features

In [3]:
def crop_and_concat(tens1, tens2):
    """Center-crop ``tens1`` to ``tens2``'s spatial size and concatenate on channels.

    Used for U-Net skip connections: with unpadded convolutions the encoder
    map (``tens1``) is spatially larger than the upsampled decoder map
    (``tens2``), so it is center-cropped before being stacked with it.

    Args:
        tens1: 4-D tensor ``(N, C1, H1, W1)`` to be cropped; must satisfy
            ``H1 >= H2`` and ``W1 >= W2``.
        tens2: 4-D tensor ``(N, C2, H2, W2)`` giving the target spatial size.

    Returns:
        Tensor of shape ``(N, C1 + C2, H2, W2)`` with the cropped ``tens1``
        channels first, followed by ``tens2``.

    Raises:
        ValueError: if ``tens1`` is smaller than ``tens2`` in height or width.
    """
    h1, w1 = tens1.shape[-2:]
    h2, w2 = tens2.shape[-2:]
    if h1 < h2 or w1 < w2:
        raise ValueError(
            f"cannot crop tensor of spatial size {(h1, w1)} to larger size {(h2, w2)}"
        )
    # Height and width are handled independently, so non-square maps work.
    # Cropping by "offset + target length" (instead of the symmetric
    # `diff:size-diff`) yields exactly (h2, w2) even when the size difference
    # is odd; for even differences this is the same centered crop.
    top = (h1 - h2) // 2
    left = (w1 - w2) // 2
    cropped = tens1[:, :, top:top + h2, left:left + w2]
    return torch.cat([cropped, tens2], dim=1)

In [4]:
class UNet(nn.Module):
    """U-Net with unpadded 3x3 convolutions, BatchNorm, and cropped skip connections.

    Architecture: four encoder stages (double conv + 2x2 max-pool), a
    two-conv bottleneck, four decoder stages (2x2 transposed conv, crop-and-
    concatenate with the matching encoder map, double conv) and a final 1x1
    convolution.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels produced by the final 1x1 conv.
        input_padding: zero padding applied by the very first convolution.
            The default 174 turns a 224x224 input into a padded 572x572 one
            (224 + 2 * 174), which produces a 388x388 output. Set it to 0 to
            feed inputs that are already sized for valid convolutions.
    """

    def __init__(self, in_channels=3, out_channels=1, input_padding=174):
        super(UNet, self).__init__()
        self.max_pool = nn.MaxPool2d(2, 2)
        # encoder
        self.down_1 = nn.Sequential(
            # Previously hard-coded to 3 input channels, ignoring `in_channels`.
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=input_padding),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
        )
        self.down_2 = DoubleConv(64, 128)
        self.down_3 = DoubleConv(128, 256)
        self.down_4 = DoubleConv(256, 512)
        # bottleneck (ReLU is applied functionally in forward so that the
        # parameter names / state_dict keys stay conv_1.* and conv_2.*)
        self.conv_1 = nn.Conv2d(512, 1024, 3)
        self.conv_2 = nn.Conv2d(1024, 1024, 3)
        # decoder: each transposed conv halves channels and doubles H/W; the
        # following DoubleConv sees skip channels + upsampled channels.
        self.up_conv_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_1 = DoubleConv(1024, 512)
        self.up_conv_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_2 = DoubleConv(512, 256)
        self.up_conv_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_3 = DoubleConv(256, 128)
        self.up_conv_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_4 = DoubleConv(128, 64)

        # output: previously hard-coded to 1 channel, ignoring `out_channels`.
        self.one_by_one_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        """Map an image batch ``(N, in_channels, H, W)`` to per-pixel outputs.

        Returns a tensor with ``out_channels`` channels whose spatial size is
        smaller than the padded input because every 3x3 conv is unpadded.
        """
        # encoder (contracting path): keep each pre-pool map for the skips
        skip_1 = self.down_1(x)
        skip_2 = self.down_2(self.max_pool(skip_1))
        skip_3 = self.down_3(self.max_pool(skip_2))
        skip_4 = self.down_4(self.max_pool(skip_3))

        # bottleneck: ReLU after each conv, like every other 3x3 conv in the
        # network (previously the two convs were purely linear)
        bottom = torch.relu(self.conv_1(self.max_pool(skip_4)))
        bottom = torch.relu(self.conv_2(bottom))

        # decoder (expanding path): upsample, crop-and-concat skip, double conv
        up = self.up_1(crop_and_concat(skip_4, self.up_conv_1(bottom)))
        up = self.up_2(crop_and_concat(skip_3, self.up_conv_2(up)))
        up = self.up_3(crop_and_concat(skip_2, self.up_conv_3(up)))
        up = self.up_4(crop_and_concat(skip_1, self.up_conv_4(up)))

        # output: 1x1 conv to out_channels
        return self.one_by_one_conv(up)

In [9]:

# Smoke test: push one random 224x224 RGB image through the network and check
# the output size (224 + 2*174 = 572 padded input -> 388x388 output).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weights and input across re-runs
model = UNet().to(DEVICE)
x = torch.randn((1, 3, 224, 224), device=DEVICE)
# Shape check only: no_grad avoids keeping every activation for backprop.
with torch.no_grad():
    output = model(x)
assert output.shape == (1, 1, 388, 388), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)


torch.Size([1, 1, 388, 388])


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
