<a href="https://colab.research.google.com/github/SonnetSaif/U-NET-from-scratch_PyTorch/blob/main/U_NET_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
from torchvision import transforms

In [12]:
class DoubleConv(nn.Module):
  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(in_c, out_c, kernel_size=3),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_c, out_c, kernel_size=3),
      nn.ReLU(inplace=True),
    )

  def forward(self, x):
    out = self.conv(x)
    return out

In [25]:
class UNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.down_conv1 = DoubleConv(1, 64)
    self.down_conv2 = DoubleConv(64, 128)
    self.down_conv3 = DoubleConv(128, 256)
    self.down_conv4 = DoubleConv(256, 512)
    self.down_conv5 = DoubleConv(512, 1024)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    self.up_transpose1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.up_conv1 = DoubleConv(1024, 512)
    self.up_transpose2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.up_conv2 = DoubleConv(512, 256)
    self.up_transpose3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.up_conv3 = DoubleConv(256, 128)
    self.up_transpose4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.up_conv4 = DoubleConv(128, 64)

    self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=2)

  def forward(self, image):
    #encoder
    x1 = self.down_conv1(image)
    print(x1.size())
    x2 = self.max_pool(x1)
    # print(x2.size())
    x3 = self.down_conv2(x2)
    # print(x3.size())
    x4 = self.max_pool(x3)
    # print(x4.size())
    x5 = self.down_conv3(x4)
    # print(x5.size())
    x6 = self.max_pool(x5)
    # print(x6.size())
    x7 = self.down_conv4(x6)
    # print(x7.size())
    x8 = self.max_pool(x7)
    # print(x8.size())
    x9 = self.down_conv5(x8)
    print(x9.size())

    #decoder
    x10 = self.up_transpose1(x9)
    transform = transforms.CenterCrop(x10.size()[2])
    x7_trans = transform(x7)
    x11 = self.up_conv1(torch.cat([x10, x7_trans], 1))

    x12 = self.up_transpose2(x11)
    transform = transforms.CenterCrop(x12.size()[2])
    x5_trans = transform(x5)
    x13 = self.up_conv2(torch.cat([x12, x5_trans], 1))

    x14 = self.up_transpose3(x13)
    transform = transforms.CenterCrop(x14.size()[2])
    x3_trans = transform(x3)
    x15 = self.up_conv3(torch.cat([x14, x3_trans], 1))

    x16 = self.up_transpose4(x15)
    transform = transforms.CenterCrop(x16.size()[2])
    x1_trans = transform(x1)
    x17 = self.up_conv4(torch.cat([x16, x1_trans], 1))

    out = self.out(x17)
    print(out.size())
    return out

In [28]:
if __name__=="__main__":
  image = torch.rand(1, 1, 572, 572)
  model = UNet()
  model(image)

torch.Size([1, 64, 568, 568])
torch.Size([1, 1024, 28, 28])
torch.Size([1, 2, 387, 387])
