# Model Architecture setup
This notebook will give a quick explaination of the archtecture. the final architecture will live in a pyhton script.

In [5]:
import torch
from torch import nn

The UNet architecture consist of 3 main parts:
- Double convolution
- downsampling
- upsampling

![architecture schematic](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png) 

We are going the 2d route and i believe our data is greyscale so the the dimensions should be:
- [N, 1, H, W]
- this step will be different from normal double conc because we are not doubling the amount of samples but just creating 64
- after double conv: [N, 64, H-4, W-4]
- after downsampling: [N, 64, H/2, W/2]

The general idea:
- each double conv: double samples and, due to convolutional kernal being 3*3, H-4 and W-4

In [24]:
# Class holding the double convolution
class DoubleConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_norm: bool = True,
    ) -> None:
        super().__init__()
        
        # For now we are just considering 2d input
        # Thus expected input dimensions is [batch_size, channels, height, width] 
        # Or when not using batches [channels, height, width]
        # Or in the convention of pytorch: (N, C, H, W) or (C, H, W)

        # nn.Identity just return it's input so it's used as a replacement for normalization if normalization is not used
        # TO DO: find out what batchnorm does exactly
        # TO DO: Find out how exactly relu works
        conv = nn.Conv2d
        norm = nn.BatchNorm2d if use_norm else nn.Identity 
        activation_function = nn.ReLU

        layers = [
            conv(in_channels, out_channels, 3),
            norm(out_channels),
            activation_function(inplace=True),
            conv(out_channels, out_channels, 3),
            norm(out_channels),
            activation_function(inplace=True)    
        ]

        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

In [29]:
# Class holding the downsampling
class DownSample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:

        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be halved
        self.downsample = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )

        def forward(self, x):
            return self.downsample(x)

In [33]:
# Class holding the upsampling
class UpSample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:

        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be doubled
        self.upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2
        )

        def forward(self, x):
            return self.upsample(x)

In [39]:
# Class holding the entire model:
class SegmentUnet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 3,
        first_layer_channel_count: int = 64,
        use_norm: bool = True
    ) -> None:
        super().__init__()
        
        # Define the input layer
        layers = [
            DoubleConv(
                in_channels, 
                first_layer_channel_count, 
                use_norm
            ),
            DownSample(
                in_channels=first_layer_channel_count,
                out_channels=first_layer_channel_count*2
            )
        ]

        current_channels = first_layer_channel_count*2
        # Encoder path
        for _ in range(depth-1): # minus one to account for input layer
            layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
            layers.append(DownSample(in_channels=current_channels, out_channels=out_channels*2))
            current_channels = current_channels*2 # double channel count each encoder block
        
        # Middle layer
        layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
        
        # Decoder path
        for _ in range(depth-1): # minus one to account for output layer
            # Keep in mind here that the double conv layers here gets both the output  
            # of upsample concatanated with the skip conncention.
            # So number of channels is doubled
            # We control this by concatanation in the second dimension 
            # the convention of pytorch: (N, C, H, W) or (C, H, W)
            
            layers.append(UpSample(in_channels=current_channels*2, out_channels=current_channels))
            layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
            current_channels = current_channels // 2

        # Output layer
        layers.append(UpSample(in_channels=current_channels*2, out_channels=current_channels))
        layers.append(DoubleConv(in_channels=current_channels, out_channels=out_channels, use_norm=use_norm))

        # Concatanate layers together
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        pass
        #for layer in self.layers[0:depth]

In [40]:
# Testing model creation:
model = SegmentUnet(
    in_channels=1,
    out_channels=2
)

model.layers

ModuleList(
  (0): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (1): DownSample(
    (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (3): DownSample(
