In [9]:
import torch
import torch.nn as nn

In [14]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2):
    super(ConvBlock, self).__init__()
    self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size),
                              nn.ReLU(inplace = True),
                              nn.Conv2d(out_channels, out_channels, kernel_size),
                              nn.ReLU(inplace = True))

  def forward(self, x):
    return self.conv(x)

In [78]:
def crop_img(tensor, target_tensor):
  target_size = target_tensor.size()[2]
  tensor_size = tensor.size()[2]
  delta = tensor_size - target_size
  delta = delta // 2
  return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

In [99]:
class Unet(nn.Module):
  def __init__ (self):
    super(Unet, self).__init__()

    self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    self.down_conv_1 = ConvBlock(1, 64)

    self.down_conv_2 = ConvBlock(64, 128)

    self.down_conv_3 = ConvBlock(128, 256)

    self.down_conv_4 = ConvBlock(256, 512)

    self.down_conv_5 = ConvBlock(512, 1024)

    self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2)

    self.up_conv_1 = ConvBlock(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2)

    self.up_conv_2 = ConvBlock(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2)

    self.up_conv_3 = ConvBlock(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2)

    self.up_conv_4 = ConvBlock(128, 64)

    self.one_conv = nn.Conv2d(64, 2, kernel_size = 1)

  def forward(self, images):
    #encoder
    x1 = self.down_conv_1(images)
    print(x1.shape)
    x2 = self.maxpool(x1)
    x3 = self.down_conv_2(x2)
    x4 = self.maxpool(x3)
    x5 = self.down_conv_3(x4)
    x6 = self.maxpool(x5)
    x7 = self.down_conv_4(x6)
    x8 = self.maxpool(x7)
    x9 = self.down_conv_5(x8)
    print(x8.shape)

    #decoder

    x = self.up_trans_1(x9)
    print(x.shape)
    y = crop_img(x7, x)
    x = self.up_conv_1(torch.cat([x, y], 1))
    print(x.shape)
    x = self.up_trans_2(x)
    y = crop_img(x5, x)
    x = self.up_conv_2(torch.cat([x, y], 1))
    x = self.up_trans_3(x)
    y = crop_img(x3, x)
    x = self.up_conv_3(torch.cat([x, y], 1))
    x = self.up_trans_4(x)
    y = crop_img(x1, x)
    x = self.up_conv_4(torch.cat([x, y], 1))
    print(x.shape)
    x = self.one_conv(x)
    print(x.shape)
    return x

In [100]:
model = Unet()

In [101]:
image = torch.rand((1, 1, 572, 572))

In [102]:
 m = model(image)

torch.Size([1, 64, 568, 568])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 52, 52])
torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])
