In [1]:
import cv2
import time
import torch
import numpy as np
from torch import nn

In [13]:
class TwoConvolutions(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Args:
        in_channels: number of channels expected by the first convolution.
        out_channels: number of channels produced by both convolutions.

    Both convolutions use stride 1 and a padding of one pixel on every side,
    so the height and width of the input are preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Settings shared by both convolutions.
        conv_settings = dict(kernel_size = 3, stride = 1, padding = (1,1))

        # Layers are created in the order conv, ReLU, conv, ReLU; the first
        # convolution maps in_channels -> out_channels, the second keeps
        # out_channels.
        layers = []
        for conv_inputs in (in_channels, out_channels):
            layers.append(nn.Conv2d(in_channels = conv_inputs, out_channels = out_channels, **conv_settings))
            layers.append(nn.ReLU())

        self.block = nn.Sequential(*layers)

    def forward(self, input_):
        """Run ``input_`` through both convolution + ReLU pairs and return the result."""
        return self.block(input_)

In [14]:
class EncoderBlock(nn.Module):
    """One encoder stage: a double convolution followed by 2x2 max pooling.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.TwoConvolutions = TwoConvolutions(in_channels, out_channels)
        # Kernel 2 / stride 2 pooling: halves height and width.
        self.max_pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)

    def forward(self, image):
        """Return ``(pooled, unpooled)`` features for ``image``.

        The second element is the convolution output before pooling; it is
        returned so a caller can reuse it as a skip connection.
        """
        unpooled = self.TwoConvolutions(image)
        pooled = self.max_pooling(unpooled)

        return pooled, unpooled

In [15]:
class DecoderBlock(nn.Module):
    """One decoder stage: upsample, merge with a skip tensor, then refine.

    Args:
        in_channels: channels of the tensor coming from the previous stage.
        out_channels: channels produced by this stage. The skip tensor given
            to ``forward`` must also have ``out_channels`` channels, because
            the concatenated tensor is fed to a double convolution that
            expects ``out_channels * 2`` input channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Kernel 2 / stride 2 transposed convolution: doubles height and width.
        self.upConvolution = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
        self.TwoConvolutions = TwoConvolutions(out_channels * 2, out_channels)

    def forward(self, input_, skip_input):
        """Upsample ``input_``, concatenate it with ``skip_input`` and convolve.

        Args:
            input_: tensor from the previous stage; channels are on dim 1.
            skip_input: skip-connection tensor with ``out_channels`` channels.

        Returns:
            The refined feature tensor with ``out_channels`` channels and the
            same height and width as ``skip_input``.
        """
        features = self.upConvolution(input_)

        # A skip tensor with odd height/width cannot be matched exactly by a
        # stride-2 downsample followed by this stride-2 upsample (the result is
        # one pixel short), and torch.cat would then raise. Align the upsampled
        # features to the skip tensor: positive amounts zero-pad, negative
        # amounts crop. When the sizes already agree nothing is changed.
        pad_h = skip_input.shape[-2] - features.shape[-2]
        pad_w = skip_input.shape[-1] - features.shape[-1]
        if pad_h != 0 or pad_w != 0:
            # Pad order for the last two dims is (left, right, top, bottom).
            features = nn.functional.pad(
                features,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )

        features = torch.cat([features, skip_input], dim = 1)
        features = self.TwoConvolutions(features)

        return features

In [17]:
class U_Net(nn.Module):
    """Encoder / bottleneck / decoder network with skip connections.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the returned mask.
        depth: number of encoder stages; the decoder has the same number.

    Stage widths start at 64 and double at every level
    (64, 128, ..., 64 * 2 ** depth). After construction ``self.channels``
    holds these widths ordered from widest to narrowest.
    """

    def __init__(self, in_channels, out_channels, depth):
        super().__init__()

        stage_widths = [64 * (2 ** level) for level in range(depth + 1)]
        down_path = [in_channels] + stage_widths

        # Encoder: in_channels -> 64 -> 128 -> ... one block per level.
        self.Encoder = nn.ModuleList()
        for level in range(depth):
            self.Encoder.append(EncoderBlock(down_path[level], down_path[level + 1]))

        # Bottleneck maps the deepest encoder width to the widest stage width.
        self.Bottleneck = TwoConvolutions(down_path[depth], down_path[depth + 1])

        # Decoder walks the same widths in the opposite direction.
        self.channels = stage_widths[::-1]

        self.Decoder = nn.ModuleList()
        for level in range(depth):
            self.Decoder.append(DecoderBlock(self.channels[level], self.channels[level + 1]))

        self.FinalConvolution = nn.Conv2d(in_channels = self.channels[-1], out_channels = out_channels, kernel_size = 3, stride = 1, padding = (1,1))

    def forward(self, image):
        """Compute the mask for ``image``.

        Each encoder block yields pooled features plus a skip tensor; the skip
        tensors are consumed by the decoder blocks in reverse order (the last
        encoder's skip goes to the first decoder).
        """
        skips = []
        features = image

        for encode in self.Encoder:
            features, skip = encode(features)
            skips.append(skip)

        features = self.Bottleneck(features)

        for position, decode in enumerate(self.Decoder):
            # Negative index: most recently stored skip first.
            features = decode(features, skips[-(position + 1)])

        return self.FinalConvolution(features)