<a href="https://colab.research.google.com/github/DaeSeokSong/image-processing/blob/feature%2FUnet-scar/Unet_Scar.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ※ Precautions
# 1. RawDataset_Processor 실행 후 UNet 실행

<br>

---

<br>

# 『Reference』

* *Paper*
>   [U-net](https://paperswithcode.com/paper/u-net-convolutional-networks-for-biomedical)
>
> 기존 CNN은 Single classification task에 사용되었지만,
> 
> biomedical image processing 분야에서는 한 이미지 내의 모든 pixel을 classification 하는 Semantic segmentation task가 중요하게 사용되었다.
>
> sliding window 방식을 사용하는 CNN 구조는 겹치는 patch마다 네트워크를 따로 실행하여 중복 연산이 많지만, U-Net은 이미 검증된(연산된) patch 영역은 다시 계산하지 않고 넘기기 때문에 보다 빠른 처리가 가능한 구조이다.
> 
> 적은 양의 데이터로도 data augmentation을 통해 잘 학습시킬 수 있다.
>   * [U-net++](https://paperswithcode.com/paper/unet-a-nested-u-net-architecture-for-medical)
>   * [ResUNet++](https://paperswithcode.com/paper/resunet-an-advanced-architecture-for-medical)

<br>

* *Lecture*
> * [UNet architecture by pytorch](https://89douner.tistory.com/300)
> * [Train method of UNet](https://toitoitoi79.tistory.com/97)



# 1.Development environment

## 1) Import

### 1-1) Library

In [None]:
# U-net
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms

from sklearn.model_selection import train_test_split

# Image processing
import cv2
import numpy as np
import matplotlib.pyplot as plt

from google.colab.patches import cv2_imshow
from google.colab import output

# ETC
import os
import time

from PIL import Image

### 1-2) Mount google drive

In [None]:
# Mount the user's Google Drive at /content/gdrive so the project folder
# (models, datasets) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Change the notebook's working directory to the project folder on Drive
# and list its contents (IPython magic / shell escape, Colab only).
%cd /content/gdrive/MyDrive/Models/GAN_Scar
!ls -al

/content/gdrive/MyDrive/Models/GAN_Scar
total 144
drwx------ 2 root root  4096 Aug  9 13:27  Dataset
-rw------- 1 root root 86402 Aug 13 09:16  Image_segmentation-Scar.ipynb
-rw------- 1 root root 39995 Aug 15 11:40 'UNet architecture.PNG'
-rw------- 1 root root 15687 Aug 16 07:36  Unet-Scar.ipynb


# 2.Train U-Net

## 1) Global variables

In [None]:
# ---- Paths ----
# Project root on Google Drive.
MODEL_PATH = "/content/gdrive/MyDrive/Models/GAN_Scar"

# Dataset sub-paths. They start with '/', so presumably they are appended to
# MODEL_PATH by string concatenation rather than os.path.join — TODO confirm
# at the use site (os.path.join would discard MODEL_PATH).
TRAIN_SET_PATH = "/Dataset/train"
TEST_SET_PATH = "/Dataset/test"

IMAGES_PATH = '/images'
LABELS_PATH = '/labels'

# ---- Training hyperparameters ----
LR = 1e-3        # learning rate
BATCH_SIZE = 4
EPOCH = 100      # number of training epochs

# Train on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## 2) Class

### 2-1) U-Net

#### 2-1-1) Architecture

<img src = "https://drive.google.com/uc?id=14CzAAaKv5v7pVfvugBRbD1xI4IuhmoyT"  width = 640>

#### 2-1-2) Build network

In [None]:
# Custom U-Net built on torch.nn.Module.
class UNet(nn.Module):
    """U-Net for semantic segmentation (Ronneberger et al., 2015).

    Every convolution block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU,
    so convolutions keep the spatial size and only the 2x2 max-pools /
    2x2 transposed convolutions change it.

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of output maps of the final 1x1 convolution,
            i.e. one score map per segmentation class (default 2).

    Input height and width must be divisible by 16: the contracting path
    halves the spatial size four times, and every skip connection
    concatenates two feature maps that must have the same height and width.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super(UNet, self).__init__()

        # Kernel size, stride, padding and bias are almost always the same,
        # so they get fixed defaults here.
        def ConvBatchReLU_2d(in_ch, out_ch, k_size=3, stride=1, padding=1, bias=True):
            """Return Conv2d -> BatchNorm2d -> ReLU as one nn.Sequential."""
            layers = [
                nn.Conv2d(in_channels=in_ch,
                          out_channels=out_ch,
                          kernel_size=k_size,
                          stride=stride,
                          padding=padding,
                          bias=bias),
                nn.BatchNorm2d(num_features=out_ch),
                nn.ReLU(),
            ]
            return nn.Sequential(*layers)

        """
        [Contracting path]
        >> Captures the context of the input image.
        """
        # enc == encoder; n_m == m-th layer of the n-th stage.
        self.enc1_1 = ConvBatchReLU_2d(in_ch=in_channels, out_ch=64)
        self.enc1_2 = ConvBatchReLU_2d(in_ch=64, out_ch=64)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = ConvBatchReLU_2d(in_ch=64, out_ch=128)
        self.enc2_2 = ConvBatchReLU_2d(in_ch=128, out_ch=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = ConvBatchReLU_2d(in_ch=128, out_ch=256)
        self.enc3_2 = ConvBatchReLU_2d(in_ch=256, out_ch=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = ConvBatchReLU_2d(in_ch=256, out_ch=512)
        self.enc4_2 = ConvBatchReLU_2d(in_ch=512, out_ch=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = ConvBatchReLU_2d(in_ch=512, out_ch=1024)

        """
        [Expansive path]
        >> Upsamples with high-channel feature maps for precise localization.
        >> Combines them with the feature maps of the shallower encoder layers.
        """
        # dec == decoder
        self.dec5_1 = ConvBatchReLU_2d(in_ch=1024, out_ch=512)

        # An up-conv with kernel_size=2, stride=2 exactly doubles height and
        # width, undoing the matching 2x2 MaxPool so its output has the same
        # spatial size as the encoder feature map it is concatenated with.
        self.unpool4 = nn.ConvTranspose2d(in_channels=512,
                                          out_channels=512,
                                          kernel_size=2,
                                          stride=2,
                                          padding=0,
                                          bias=True)

        # The first conv of each decoder stage receives the up-conv output
        # concatenated with the same-sized encoder feature map, so its input
        # has twice the channels of the mirrored encoder layer.
        self.dec4_2 = ConvBatchReLU_2d(in_ch=2 * 512, out_ch=512)
        self.dec4_1 = ConvBatchReLU_2d(in_ch=512, out_ch=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256,
                                          out_channels=256,
                                          kernel_size=2,
                                          stride=2,
                                          padding=0,
                                          bias=True)

        self.dec3_2 = ConvBatchReLU_2d(in_ch=2 * 256, out_ch=256)
        self.dec3_1 = ConvBatchReLU_2d(in_ch=256, out_ch=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128,
                                          out_channels=128,
                                          kernel_size=2,
                                          stride=2,
                                          padding=0,
                                          bias=True)

        self.dec2_2 = ConvBatchReLU_2d(in_ch=2 * 128, out_ch=128)
        self.dec2_1 = ConvBatchReLU_2d(in_ch=128, out_ch=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64,
                                          out_channels=64,
                                          kernel_size=2,
                                          stride=2,
                                          padding=0,
                                          bias=True)

        self.dec1_2 = ConvBatchReLU_2d(in_ch=2 * 64, out_ch=64)
        self.dec1_1 = ConvBatchReLU_2d(in_ch=64, out_ch=64)

        # 1x1 conv producing one score map per class.
        self.conv = nn.Conv2d(in_channels=64,
                              out_channels=out_channels,
                              kernel_size=1,
                              stride=1,
                              padding=0,
                              bias=True)

    # x == input image batch, (N, in_channels, H, W) with H and W divisible by 16.
    def forward(self, x):
        """Return per-pixel class scores of shape (N, out_channels, H, W)."""
        # Contracting path: each stage uses its own layers.
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        # Expansive path
        dec5_1 = self.dec5_1(enc5_1)

        """
        [Skip connection]
        >> Location information matters in semantic segmentation, so to avoid
        >> losing it, the earlier encoder feature map is concatenated
        >> (not added) with the upsampled decoder feature map.
        """
        # Concatenate along the channel axis.
        # dim=[0:batch, 1:channel, 2:height, 3:width]
        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.conv(dec1_1)

        return x

### 2-2) Pytorch

#### 2-2-1) Dataset

In [None]:
class ScarDataset(Dataset):
    """Paired scar-image / label-mask dataset stored as ``.npy`` files.

    Files in ``data_dir`` whose names start with ``'scar'`` are inputs and
    files whose names start with ``'label'`` are targets. Both lists are
    sorted by filename and paired by position, so the naming scheme must
    make the i-th scar file correspond to the i-th label file
    (presumably e.g. scar_000.npy / label_000.npy — TODO confirm).

    Each item is a dict ``{'scar': ..., 'label': ...}`` of float arrays
    divided by 255.0, with a trailing channel axis added to 2-D arrays.
    The optional ``transform`` is applied to that dict.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        lst_data = os.listdir(self.data_dir)

        # Sorting both lists pairs inputs and labels by filename order.
        self.lst_scar = sorted(f for f in lst_data if f.startswith('scar'))
        self.lst_label = sorted(f for f in lst_data if f.startswith('label'))

    def __len__(self):
        # NOTE(review): assumes one scar file per label file; a count mismatch
        # would silently mispair samples.
        return len(self.lst_label)

    def __getitem__(self, index):
        scar = np.load(os.path.join(self.data_dir, self.lst_scar[index]))
        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))

        # Scale by 255 (presumably 8-bit images -> [0, 1]).
        scar = scar / 255.0
        label = label / 255.0

        # Give 2-D (H, W) arrays a trailing channel axis -> (H, W, 1).
        # Bug fix: this previously tested the builtin `input` instead of `scar`.
        if scar.ndim == 2:
            scar = scar[:, :, np.newaxis]
        if label.ndim == 2:
            label = label[:, :, np.newaxis]

        data = {'scar': scar, 'label': label}

        if self.transform:
            data = self.transform(data)

        return data

#### 2-2-2) Transform

In [None]:
class ToTensor(object):
    """Turn a sample's (H, W, C) numpy arrays into (C, H, W) float32 tensors.

    Both the 'scar' image and the 'label' mask are converted; the returned
    dict has the keys 'scar' and 'label' in that order.
    """

    def __call__(self, data):
        image = data['scar']
        mask = data['label']

        def to_chw_tensor(arr):
            # Move the channel axis first, then cast; astype() returns a copy,
            # so non-contiguous (e.g. flipped) views are safe to wrap.
            chw = arr.transpose((2, 0, 1)).astype(np.float32)
            return torch.from_numpy(chw)

        converted_image = to_chw_tensor(image)
        converted_mask = to_chw_tensor(mask)
        return {'scar': converted_image, 'label': converted_mask}

class Normalization(object):
    """Standardize the 'scar' image as (x - mean) / std.

    The 'label' entry is passed through unchanged.

    Args:
        mean: value subtracted from the image (default 0.5).
        std: divisor applied after subtraction (default 0.5).
    """

    def __init__(self, mean=0.5, std=0.5):
        self.mean, self.std = mean, std

    def __call__(self, data):
        image = data['scar']
        mask = data['label']
        return {'scar': (image - self.mean) / self.std, 'label': mask}

class RandomFlip(object):
    """Randomly mirror a sample left-right and/or up-down.

    Each flip happens independently with probability 0.5, drawn from numpy's
    global RNG (left-right first, then up-down). The same flip is applied to
    'scar' and 'label' so image and mask stay aligned. np.fliplr / np.flipud
    flip axes 1 and 0, i.e. width and height for (H, W, C) arrays, so this
    transform expects numpy arrays (place it before ToTensor).
    """

    def __call__(self, data):
        image, mask = data['scar'], data['label']

        for flip in (np.fliplr, np.flipud):
            if np.random.rand() > 0.5:
                image, mask = flip(image), flip(mask)

        return {'scar': image, 'label': mask}

## 3) Function

### 3-1) Convenience func

In [None]:
def imshow_waitkey_enter(image):
    """Display `image` in the Colab cell output and wait for the user.

    The image is shown with cv2_imshow. After a 0.5 s pause the function
    blocks on an Enter key press, then clears the cell output. This is useful
    for stepping through images one at a time.
    """
    cv2_imshow(image)

    # Brief pause, presumably so the image renders before the prompt appears.
    time.sleep(0.5)

    prompt = "Please press the Enter key to proceed\n"
    input(prompt)
    output.clear()

## 4) Run

### 4-1) Prepare dataset

### 4-2) Model