In [1]:
# import the necessary packages
import import_ipynb
import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torch.nn import BatchNorm2d
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

importing Jupyter notebook from config.ipynb


In [2]:
class conv(Module):
    """Double convolution block: (Conv2d -> BatchNorm2d -> ReLU) applied twice.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 3x3 kernels with padding=1 (default stride 1) leave height/width unchanged.
        same_size = dict(kernel_size=3, padding=1)
        self.conv1 = Conv2d(in_channels, out_channels, **same_size)
        self.bn1 = BatchNorm2d(out_channels)
        self.conv2 = Conv2d(out_channels, out_channels, **same_size)
        self.bn2 = BatchNorm2d(out_channels)
        # A single ReLU module is reused after each batch-norm.
        self.relu = ReLU()

    def forward(self, images):
        """Run both conv/batch-norm/ReLU stages and return the resulting feature map."""
        first_stage = self.relu(self.bn1(self.conv1(images)))
        second_stage = self.relu(self.bn2(self.conv2(first_stage)))
        return second_stage

In [3]:
class Encoder(Module):
    """Contracting-path stage: a double-conv block followed by 2x2 max-pooling.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the double-conv block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = conv(in_channels, out_channels)
        # Window 2, stride 2: halves height and width (rounding down).
        self.pool = MaxPool2d(kernel_size=2, stride=2)

    def forward(self, images):
        """Return ``(features, pooled)``.

        ``features`` is the double-conv output before pooling (callers use it as
        the skip connection); ``pooled`` is the same tensor after max-pooling.
        """
        features = self.conv(images)
        return features, self.pool(features)

In [10]:
class Decoder(Module):
    """Expanding-path stage: upsample, merge with the skip connection, double-conv.

    Args:
        in_channels: channels of the incoming (lower-resolution) tensor.
        out_channels: channels produced by this stage. The double-conv block is
            built for ``in_channels`` inputs, so the skip tensor given to
            ``forward`` must carry ``in_channels - out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=2 with stride=2 doubles height and width.
        self.upconv = ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = conv(in_channels, out_channels)

    def forward(self, images, prev):
        """Upsample ``images``, concatenate ``prev`` on the channel axis, then convolve.

        Args:
            images: tensor coming from the previous (deeper) stage.
            prev: skip-connection tensor from the matching encoder stage.

        Returns:
            The output of the double-conv block, at the spatial size of ``prev``.
        """
        x = self.upconv(images)
        # Tensors are laid out as (batch, channels, height, width).
        # Max-pooling rounds odd sizes down, so after upsampling x can be smaller
        # than the skip tensor; zero-pad it so the concatenation below does not
        # fail with a size mismatch. When the sizes already agree nothing changes.
        diff_h = prev.size(2) - x.size(2)
        diff_w = prev.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            # F.pad takes pads for the last dimension first: (left, right, top, bottom).
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((x, prev), 1)
        return self.conv(x)

In [11]:
class UNet(Module):
    """U-Net: four Encoder stages, a bottleneck, four Decoder stages and a 1x1 head.

    Feature widths are fixed at 64-128-256-512 with a 1024-channel bottleneck.
    Every encoder halves height and width with a 2x2 max-pool and every decoder
    doubles them again, so inputs whose height and width are multiples of 16
    come back at their original spatial size.

    Args:
        in_channels: channels of the input images. Defaults to 3, the value
            that used to be hard-coded.
        out_channels: channels of the predicted mask. Defaults to 1, the value
            that used to be hard-coded.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Contracting path.
        self.e1 = Encoder(in_channels, 64)
        self.e2 = Encoder(64, 128)
        self.e3 = Encoder(128, 256)
        self.e4 = Encoder(256, 512)

        # Bottleneck (no pooling afterwards).
        self.x5 = conv(512, 1024)

        # Expanding path; each stage consumes the matching encoder's features.
        self.d1 = Decoder(1024, 512)
        self.d2 = Decoder(512, 256)
        self.d3 = Decoder(256, 128)
        self.d4 = Decoder(128, 64)

        # 1x1 convolution mapping the 64 feature channels to the mask channels.
        self.output = Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, images):
        """Return the mask predicted for ``images``.

        The result is the raw output of the final 1x1 convolution; no
        activation (sigmoid/softmax) is applied here.
        """
        # x* are pre-pool features kept for the skip connections, p* are pooled.
        x1, p1 = self.e1(images)
        x2, p2 = self.e2(p1)
        x3, p3 = self.e3(p2)
        x4, p4 = self.e4(p3)

        x5 = self.x5(p4)

        # Skip connections are consumed deepest-first.
        d1 = self.d1(x5, x4)
        d2 = self.d2(d1, x3)
        d3 = self.d3(d2, x2)
        d4 = self.d4(d3, x1)

        output_mask = self.output(d4)

        return output_mask

In [12]:
if __name__ == '__main__':
    # Smoke test: print a layer-by-layer summary of the network for a
    # 3x256x256 input. The guard keeps this from running (and from requiring
    # torchsummary) when the file is imported rather than executed.
    from torchsummary import summary

    # Use the GPU when one is present; otherwise fall back to the CPU instead
    # of failing inside .cuda().
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = UNet().to(device)
    # Pass the device explicitly so the dummy input summary() builds is created
    # on the same device as the model.
    summary(net, (3, 256, 256), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,792
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,928
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
              conv-7         [-1, 64, 256, 256]               0
         MaxPool2d-8         [-1, 64, 128, 128]               0
           Encoder-9  [[-1, 64, 256, 256], [-1, 64, 128, 128]]               0
           Conv2d-10        [-1, 128, 128, 128]          73,856
      BatchNorm2d-11        [-1, 128, 128, 128]             256
             ReLU-12        [-1, 128, 128, 128]               0
           Conv2d-13        [-1, 128, 128, 128]         147,584
      BatchNorm2d-14    