# Model Architecture setup
This notebook gives a quick explanation of the architecture. The final architecture will live in a Python script.

In [1]:
import torch
from torch import nn

The UNet architecture consists of 3 main parts:
- Double convolution
- downsampling
- upsampling

![architecture schematic](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png) 

We are going the 2D route, and I believe our data is greyscale, so the dimensions should be:
- input: [N, 1, H, W]
- the first double convolution differs from the later ones: instead of doubling the number of channels, it maps the single input channel to 64 channels
- after double conv: [N, 64, H-4, W-4]
- after downsampling: [N, 64, (H-4)/2, (W-4)/2]

The general idea:
- each double conv: doubles the number of channels and, because each unpadded 3×3 convolution kernel trims 2 pixels per spatial dimension, reduces H and W by 4 (H-4 and W-4)

In [2]:
# Class holding the double convolution
class DoubleConv(nn.Module):
    """Two stacked ``conv -> norm -> ReLU`` stages, the basic U-Net building block.

    With the defaults (3x3 kernel, no padding) each convolution trims 2 pixels
    from the height and the width, so the block maps ``(N, in_channels, H, W)``
    to ``(N, out_channels, H - 4, W - 4)``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        use_norm: If True, apply ``nn.BatchNorm2d`` after each convolution;
            otherwise ``nn.Identity`` is used as a no-op placeholder.
        kernel_size: Side length of the square convolution kernel.
        padding: Zero-padding applied on each side by both convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_norm: bool = True,
        kernel_size: int = 3,
        padding: int = 0,
    ) -> None:
        super().__init__()

        # For now we are just considering 2D input.
        # Thus the expected input dimensions are [batch_size, channels, height, width],
        # or [channels, height, width] when not using batches.
        # In PyTorch convention: (N, C, H, W) or (C, H, W).

        # nn.Identity just returns its input (and ignores constructor arguments),
        # so it is a drop-in replacement for the normalization layer when
        # normalization is disabled.
        # TODO: find out what batchnorm does exactly
        # TODO: find out how exactly relu works
        conv = nn.Conv2d
        norm = nn.BatchNorm2d if use_norm else nn.Identity
        activation_function = nn.ReLU

        # Use the constructor arguments directly: nothing is stored on `self`,
        # and nn.Conv2d requires an explicit kernel_size.
        layers = [
            conv(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            norm(out_channels),
            activation_function(inplace=True),
            conv(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            norm(out_channels),
            activation_function(inplace=True),
        ]

        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        return self.double_conv(x)

In [3]:
# Class holding the downsampling
class DownSample(nn.Module):
    """Halve the spatial resolution with a 2x2 max-pool.

    Args:
        in_channels: Accepted for interface symmetry with the other blocks;
            not used, since max-pooling has no learnable parameters.
        out_channels: Accepted for interface symmetry; not used, since
            max-pooling leaves the channel count unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:
        # `super()` must be called to obtain the proxy object before `__init__`.
        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be halved
        self.downsample = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )

    # Defined at class level (not nested in __init__) so nn.Module can dispatch to it.
    def forward(self, x):
        """Return ``x`` max-pooled with a 2x2 window and stride 2."""
        return self.downsample(x)

In [4]:
# Class holding the upsampling
class UpSample(nn.Module):
    """Double the spatial resolution with a 2x2 transposed convolution.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:
        # `super()` must be called to obtain the proxy object before `__init__`.
        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be doubled
        self.upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2
        )

    # Defined at class level (not nested in __init__) so nn.Module can dispatch to it.
    def forward(self, x):
        """Return ``x`` upsampled by a factor of 2 in height and width."""
        return self.upsample(x)

In [6]:
# Class holding the entire model:
class SegmentUnet(nn.Module):
    """Skeleton of the U-Net segmentation model (work in progress).

    Only the input double convolution is constructed so far; the encoder and
    decoder paths are still to be written and no ``forward`` is defined yet.

    Args:
        in_channels: Number of channels of the input image (1 for greyscale).
        base_channels: Number of channels produced by the input double
            convolution.
        use_norm: Forwarded to ``DoubleConv`` as its ``use_norm`` argument.
    """

    def __init__(
        self,
        in_channels: int = 1,
        base_channels: int = 64,
        use_norm: bool = True,
    ) -> None:
        super().__init__()

        # Define the input layer
        # The first block does not double the channel count; it maps the raw
        # input channels straight to `base_channels`.
        # NOTE(review): `layers` is a plain local list, so its modules are not
        # registered on this nn.Module until they are wrapped (e.g. in
        # nn.Sequential / nn.ModuleList) and assigned to `self`.
        layers = [DoubleConv(in_channels, base_channels, use_norm)]

        # Encoder path

        # Decoder path