# U-Net Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DoubleConv(nn.Module):
    """Two (3x3 convolution -> BatchNorm -> ReLU) stages applied in sequence.

    With the default ``padding=0`` each convolution is unpadded ("valid"), as in
    the original U-Net. The block therefore shrinks height and width by 2 pixels
    per convolution (4 in total). Passing ``padding=1`` keeps the spatial size
    unchanged ("same" convolution), which a padded U-Net variant needs.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        padding: zero-padding applied by each 3x3 convolution (default 0,
            keyword-only).
    """

    def __init__(self, in_channels, out_channels, *, padding=0):
        super().__init__()
        # Note: each conv keeps its default bias even though the following
        # BatchNorm subtracts the per-channel mean, which cancels a constant bias.
        self.double_conv = nn.Sequential(
            # 1st conv layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 2nd conv layer
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv-BN-ReLU stages to ``x`` and return the result."""
        return self.double_conv(x)

In [None]:
# center_crop(feature, [h, w]) keeps the central h x w window of a feature map.
# NOTE: torch.nn.functional has no `center_crop`, so `F.center_crop(...)` raises
# AttributeError. torchvision provides it as
# torchvision.transforms.functional.center_crop. The same crop can be done with
# plain tensor slicing, which needs no extra dependency:
demo_feature = torch.arange(36, dtype=torch.float32).reshape(1, 1, 6, 6)
demo_h, demo_w = 4, 4
demo_top = (demo_feature.shape[-2] - demo_h) // 2
demo_left = (demo_feature.shape[-1] - demo_w) // 2
demo_cropped = demo_feature[..., demo_top:demo_top + demo_h, demo_left:demo_left + demo_w]
demo_cropped.shape  # torch.Size([1, 1, 4, 4])

In [None]:
class UNet(nn.Module):
    """U-Net built from unpadded ("valid") 3x3 convolutions.

    Structure:
    1. encoder (contracting path): four DoubleConv stages, each followed by a
       2x2 max-pool
    2. bottleneck between encoder and decoder
    3. decoder (expansive path): stride-2 transposed-conv upsampling, then the
       matching encoder feature is center-cropped (skip connection),
       concatenated along channels, and passed through a DoubleConv
    4. a final 1x1 convolution mapping to ``out_channels``

    Because the convolutions are unpadded, the output is spatially smaller than
    the input. For example, a 572x572 input gives a 418x418 output with the
    default ``up_kernel_size=4``, and 388x388 with ``up_kernel_size=2``. The
    input must be large enough that every stage keeps a positive size.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (e.g. classes).
        init_features: channel width of the first encoder stage; doubled at
            each deeper stage.
        up_kernel_size: kernel size of the stride-2 transposed convolutions
            (keyword-only). The default 4 keeps the original behaviour; 2 (as
            in the U-Net paper) upsamples by exactly 2x and uses less memory.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=64, *, up_kernel_size=4):
        super().__init__()

        features = init_features

        # Contracting path (Encoder)
        self.encoder1 = DoubleConv(in_channels, features)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = DoubleConv(features, features * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder3 = DoubleConv(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder4 = DoubleConv(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck between encoder and decoder
        self.bottleneck = DoubleConv(features * 8, features * 16)

        # Expanding path (Decoder). Each decoder DoubleConv takes twice the
        # channels of its upconv output because the cropped skip feature is
        # concatenated to it along the channel axis.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=up_kernel_size, stride=2
        )
        self.decoder4 = DoubleConv(features * 16, features * 8)

        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=up_kernel_size, stride=2
        )
        self.decoder3 = DoubleConv(features * 8, features * 4)

        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=up_kernel_size, stride=2
        )
        self.decoder2 = DoubleConv(features * 4, features * 2)

        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=up_kernel_size, stride=2
        )
        self.decoder1 = DoubleConv(features * 2, features)

        # Final convolution to map to the desired number of output channels
        self.final_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def _center_crop(feature, size):
        """Return the central ``size`` = (h, w) window of ``feature``'s last two dims.

        This is a dependency-free replacement for torchvision's ``center_crop``,
        which is not part of ``torch.nn.functional``. When the size difference
        is odd, the extra row/column is dropped from the bottom/right.

        Raises:
            ValueError: if the requested window is larger than the feature map.
        """
        height, width = feature.shape[-2:]
        target_h, target_w = int(size[0]), int(size[1])
        if target_h > height or target_w > width:
            raise ValueError(
                f"cannot center-crop a {height}x{width} feature map to {target_h}x{target_w}"
            )
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return feature[..., top:top + target_h, left:left + target_w]

    def _decode_stage(self, x, skip, upconv, decoder):
        """Upsample ``x``, concatenate the center-cropped ``skip``, then decode."""
        upsampled = upconv(x)
        cropped_skip = self._center_crop(skip, upsampled.shape[-2:])
        return decoder(torch.cat([upsampled, cropped_skip], dim=1))

    def forward(self, x):
        """Run the encoder, bottleneck and decoder, and return the 1x1-conv output."""
        # Contracting path (Encoder)
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Bottleneck between encoder and decoder
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Expanding path (Decoder):
        # upsample -> center crop skip -> concatenate -> double conv
        dec4 = self._decode_stage(bottleneck, enc4, self.upconv4, self.decoder4)
        dec3 = self._decode_stage(dec4, enc3, self.upconv3, self.decoder3)
        dec2 = self._decode_stage(dec3, enc2, self.upconv2, self.decoder2)
        dec1 = self._decode_stage(dec2, enc1, self.upconv1, self.decoder1)

        return self.final_conv(dec1)

## 직접 해보기
패딩(padding=1)을 주는 방식으로 U-Net을 설계해 보자. (출력 크기가 입력과 같아지는지, skip connection에서 center crop이 여전히 필요한지 등 어떤 차이가 있을지 유념하자!)