# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from enum import Enum

class UNet(nn.Module):
    """UNet implementation from https://arxiv.org/abs/1505.04597

    Structure: an input DoubleConvBlock, ``depth`` encoder stages (DownLayer),
    ``depth`` mirrored decoder stages (UpLayer), and a final 1x1 convolution.
    Each encoder stage multiplies the channel count by ``channel_scale_factor``
    and halves the spatial size. Each decoder stage fuses the encoder feature
    map of matching resolution as a skip connection.

    Args:
        in_channels: Number of input channels to the network.
        n_classes: Number of output channels to the network.
        depth: Number of encoder (and decoder) stages. Must be >= 0.
        start_channels: Channels produced by the input block; each encoder
            stage multiplies this by ``channel_scale_factor``.
        channel_scale_factor: Channel multiplier between consecutive encoder
            stages.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    def __init__(self,
                 in_channels: int,
                 n_classes: int,
                 depth: int,
                 start_channels: int = 64,
                 channel_scale_factor: int = 2):
        super().__init__()
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        # Channel count at every resolution level, finest first.
        level_channels = [start_channels]
        for _ in range(depth):
            level_channels.append(level_channels[-1] * channel_scale_factor)
        # (finer, coarser) channel pairs, one per encoder stage.
        stage_pairs = list(zip(level_channels[:-1], level_channels[1:]))

        # Lifts the input to start_channels at full resolution; its output is
        # the finest skip connection.
        self.input_block = DoubleConvBlock(in_channels, start_channels, 3, 1, True, True)

        # Spatial down/up-sampling factors are left at the layers' default (2)
        # so encoder and decoder resolutions stay in lock-step; the channel
        # multiplier is a separate concern handled by level_channels.
        self.down_layers = nn.ModuleList(
            DownLayer(finer, coarser) for finer, coarser in stage_pairs
        )
        # Decoder runs from the coarsest level back to the finest.
        self.up_layers = nn.ModuleList(
            UpLayer(coarser, finer) for finer, coarser in reversed(stage_pairs)
        )

        # 1x1 convolution projecting start_channels to n_classes. The
        # attribute name is kept from the original for compatibility.
        self.conv1d = nn.Conv2d(start_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the network on a 4-D ``(N, in_channels, H, W)`` batch.

        Returns un-normalized per-pixel scores of shape
        ``(N, n_classes, H', W')``. ``H' == H`` and ``W' == W`` when both are
        divisible by ``2 ** depth``. Otherwise the decoder center-crops the
        skip connections and the output is smaller.
        """
        out = self.input_block(x)

        # Keep each encoder input: it is the skip connection for the decoder
        # stage at the same resolution.
        skips = []
        for down in self.down_layers:
            skips.append(out)
            out = down(out)

        for up, skip in zip(self.up_layers, reversed(skips)):
            out = up(out, skip)

        return self.conv1d(out)

        
class DoubleConvBlock(nn.Module):
    """Two stacked convolutions, each optionally followed by ReLU and BatchNorm.

    Layer order per convolution is Conv2d -> ReLU (if ``has_relu``) ->
    BatchNorm2d (if ``has_batch_norm``). The first convolution maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of both convolutions' outputs.
        kernel_size: Kernel size of both convolutions.
        padding: Zero-padding of both convolutions. With ``kernel_size=3``,
            ``padding=1`` preserves the spatial size.
        has_relu: Whether to add a ReLU after each convolution.
        has_batch_norm: Whether to add a BatchNorm2d after each convolution.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 has_relu: bool,
                 has_batch_norm: bool):
        super().__init__()
        conv_layers = []

        def add_conv_layer(in_c, out_c):
            # padding must be forwarded, otherwise every conv shrinks H and W.
            conv_layers.append(nn.Conv2d(in_c, out_c, kernel_size, padding=padding))
            if has_relu:
                conv_layers.append(nn.ReLU())
            if has_batch_norm:
                conv_layers.append(nn.BatchNorm2d(out_c))

        add_conv_layer(in_channels, out_channels)
        add_conv_layer(out_channels, out_channels)
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        return self.conv_layers(x)

class DownLayer(nn.Module):
    """Encoder stage: a DoubleConvBlock followed by 2-D max-pooling.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the double convolution.
        kernel_size: Kernel size of both convolutions.
        padding: Zero-padding of both convolutions.
        has_relu: Whether the double convolution uses ReLU.
        has_batch_norm: Whether the double convolution uses BatchNorm.
        scale_factor: Max-pool kernel size. Its stride defaults to the kernel
            size, so this is the spatial downsampling factor.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 has_relu: bool = True,
                 has_batch_norm: bool = True,
                 scale_factor: int = 2):
        super().__init__()
        double_conv = DoubleConvBlock(in_channels,
                                      out_channels,
                                      kernel_size=kernel_size,
                                      padding=padding,
                                      has_relu=has_relu,
                                      has_batch_norm=has_batch_norm)
        downsample = nn.MaxPool2d(scale_factor)
        self.layers = nn.Sequential(double_conv, downsample)

    def forward(self, x):
        # Convolve at the incoming resolution, then pool.
        return self.layers(x)



class UpLayer(nn.Module):
    """Decoder stage: upsample, fuse with a skip connection, refine.

    The coarse input is upsampled bilinearly by ``scale_factor``, then
    projected from ``in_channels`` to ``out_channels`` by a convolution. The
    skip feature map is center-cropped to the upsampled spatial size and
    concatenated with it along the channel axis. The skip map is expected to
    have ``out_channels`` channels. A DoubleConvBlock then maps the
    ``2 * out_channels`` result back to ``out_channels``.

    Args:
        in_channels: Channels of the coarse (decoder) input ``up_x``.
        out_channels: Channels of the skip input ``down_x`` and of the output.
        kernel_size: Kernel size of every convolution in the layer.
        padding: Zero-padding of every convolution in the layer.
        has_relu: Whether the double convolution uses ReLU.
        has_batch_norm: Whether the double convolution uses BatchNorm.
        scale_factor: Spatial upsampling factor.
        stride: Unused; kept for backward compatibility of the signature.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 has_relu: bool = True,
                 has_batch_norm: bool = True,
                 scale_factor: int = 2,
                 stride: int = 2):
        super().__init__()
        self.up_layers = nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=scale_factor, align_corners=False),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
        )
        # After concatenation: out_channels from the upsampled path plus
        # out_channels from the skip path.
        self.conv_layers = DoubleConvBlock(2 * out_channels,
                                           out_channels,
                                           kernel_size,
                                           padding,
                                           has_relu,
                                           has_batch_norm)

    def forward(self, up_x, down_x):
        """Fuse the coarse map ``up_x`` with the skip map ``down_x``.

        Raises:
            ValueError: If ``down_x`` is spatially smaller than the upsampled
                ``up_x`` and so cannot be cropped to match.
        """
        up_out = self.up_layers(up_x)
        h_up, w_up = up_out.shape[2:]
        h_down, w_down = down_x.shape[2:]
        if h_down < h_up or w_down < w_up:
            raise ValueError(
                f"skip map {h_down}x{w_down} is smaller than upsampled map {h_up}x{w_up}"
            )
        # Integer offsets for a center crop; float offsets are not valid
        # slice indices.
        top = (h_down - h_up) // 2
        left = (w_down - w_up) // 2
        crop_out = down_x[:, :, top:top + h_up, left:left + w_up]
        # dim=1 is the channel axis; the default dim=0 would stack batches.
        concat_out = torch.cat([up_out, crop_out], dim=1)
        return self.conv_layers(concat_out)