In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from enum import Enum


class UNet(nn.Module):
    """UNet implementation from https://arxiv.org/abs/1505.04597

    The constructor currently only records its configuration: no layers are
    built here and no ``forward`` method is defined in this class yet.

    Args:
        n_classes: Number of output channels of the network.
        depth: The depth of the network.

    Attributes:
        n_classes: The ``n_classes`` value passed to the constructor.
        depth: The ``depth`` value passed to the constructor.
    """
    def __init__(self, n_classes: int, depth: int):
        super().__init__()
        # Keep the configuration instead of silently discarding it.
        self.n_classes = n_classes
        self.depth = depth
        
        
class DoubleConvBlock(nn.Module):
    """Two stacked 2D convolutions, each optionally followed by ReLU and BatchNorm.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    maps ``out_channels`` to ``out_channels``. Within each stage the order is
    convolution, then ReLU (if enabled), then batch normalisation (if enabled).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        padding: Padding applied by both convolutions.
        has_relu: Whether to add an ``nn.ReLU`` after each convolution.
        has_batch_norm: Whether to add an ``nn.BatchNorm2d`` after each
            convolution (after the ReLU when both are enabled).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 has_relu: bool = True,
                 has_batch_norm: bool = True):
        super().__init__()
        conv_layers = []

        def add_conv_layer(stage_in_channels: int) -> None:
            """Append one conv stage reading ``stage_in_channels`` channels."""
            conv_layers.append(nn.Conv2d(stage_in_channels,
                                         out_channels,
                                         kernel_size,
                                         padding=padding))
            if has_relu:
                conv_layers.append(nn.ReLU())

            if has_batch_norm:
                conv_layers.append(nn.BatchNorm2d(out_channels))

        add_conv_layer(in_channels)
        # The second convolution consumes the output of the first one, so its
        # input width is ``out_channels`` rather than ``in_channels``.
        add_conv_layer(out_channels)
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv_layers(x)

class DownLayer(nn.Module):
    """Downsampling stage: a ``DoubleConvBlock`` followed by 2D max pooling.

    Args:
        in_channels: Forwarded to ``DoubleConvBlock`` as its input channels.
        out_channels: Forwarded to ``DoubleConvBlock`` as its output channels.
        kernel_size: Forwarded to ``DoubleConvBlock``.
        padding: Forwarded to ``DoubleConvBlock``.
        has_relu: Forwarded to ``DoubleConvBlock``.
        has_batch_norm: Forwarded to ``DoubleConvBlock``.
        pooling_kernel_size: Kernel size given to ``nn.MaxPool2d``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 has_relu: bool = True,
                 has_batch_norm: bool = True,
                 pooling_kernel_size: int = 2):
        super().__init__()
        self.layers = nn.Sequential(
            DoubleConvBlock(in_channels,
                            out_channels,
                            kernel_size,
                            padding,
                            has_relu,
                            has_batch_norm),
            nn.MaxPool2d(pooling_kernel_size),
        )

    def forward(self, x):
        """Apply the convolution block and then the pooling to ``x``."""
        return self.layers(x)


class UpConvMode(Enum):
    """Upsampling strategies that ``UpLayer`` can be built with."""
    # UpLayer builds a single ``nn.ConvTranspose2d`` for this mode.
    CONV_TRANSPOSE = 1
    # UpLayer builds a bilinear ``nn.Upsample`` followed by an ``nn.Conv2d``
    # for this mode.
    UPSAMPLE = 2
    
class InvalidUpConvModeError(Exception):
    """Raised by ``UpLayer`` when ``up_conv_mode`` is not a handled ``UpConvMode``."""

class UpLayer(nn.Module):
    """Upsampling stage selected by ``up_conv_mode``.

    With ``UpConvMode.CONV_TRANSPOSE`` the layer is a single
    ``nn.ConvTranspose2d``. With ``UpConvMode.UPSAMPLE`` it is a bilinear
    ``nn.Upsample`` followed by an ``nn.Conv2d``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the layer.
        kernel_size: Kernel size of the (transposed) convolution.
        padding: Padding of the (transposed) convolution.
        has_relu: Accepted but not used by this layer at present.
        has_batch_norm: Accepted but not used by this layer at present.
        stride: Stride of the transposed convolution; only used in
            ``CONV_TRANSPOSE`` mode.
        scale_factor: Scale factor of the upsampling; only used in
            ``UPSAMPLE`` mode.
        up_conv_mode: Which ``UpConvMode`` strategy to build.

    Raises:
        InvalidUpConvModeError: If ``up_conv_mode`` is neither
            ``UpConvMode.CONV_TRANSPOSE`` nor ``UpConvMode.UPSAMPLE``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 has_relu: bool,
                 has_batch_norm: bool,
                 stride: int,
                 scale_factor: int,
                 up_conv_mode: "UpConvMode"):
        super().__init__()
        layers = []

        if up_conv_mode == UpConvMode.CONV_TRANSPOSE:
            layers.append(nn.ConvTranspose2d(in_channels,
                                             out_channels,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding))

        elif up_conv_mode == UpConvMode.UPSAMPLE:
            layers.append(nn.Upsample(mode='bilinear', scale_factor=scale_factor))
            layers.append(nn.Conv2d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    padding=padding))

        else:
            # Name the offending value so a misconfiguration is easy to trace.
            raise InvalidUpConvModeError(
                f"Unsupported up_conv_mode: {up_conv_mode!r}")

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the configured upsampling layers to ``x``."""
        return self.layers(x)
 

