# Unet
source: https://amaarora.github.io/2020/09/13/unet.html

<img src="https://i.imgur.com/LQORH9i.png" alt="drawing" width="500"/>


In [37]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential, layers
# import tensorflow_addons as tfa

In [38]:
import warnings
# Silence every warning for the rest of the session (keeps cell output clean,
# but also hides deprecation notices).
warnings.filterwarnings("ignore")

In [39]:
# Notebook-wide configuration constants.
# BATCH_SIZE and NUM_LABELS are not referenced by the cells visible here.
BATCH_SIZE, NUM_LABELS = 32, 1
# Input resolution used for the dummy tensors in the demo cells (square).
WIDTH = HEIGHT = 512

## ConvBlock
- 加入Instance Norm(目前程式碼中相關的 tensorflow_addons 呼叫皆已註解，尚未啟用)。
- <img src="https://miro.medium.com/max/983/1*p84Hsn4-e60_nZPllkxGZQ.png" width="50%">

> 上圖為一整個batch的feature-map。輸入6張圖片，輸入6chs, 輸出也是6chs(C方向看進去是channel, N方向看進去是圖片)

In [41]:
class convBlock(layers.Layer):
    """Two consecutive Conv2D + ReLU stages with the same number of filters.

    Args:
        out_ch: Number of filters produced by each convolution.
        padding: Padding mode forwarded to both ``Conv2D`` layers
            (``'same'`` or ``'valid'``).
        kernel_size: Kernel size forwarded to both ``Conv2D`` layers.
    """

    def __init__(self, out_ch, padding='same', kernel_size=3):
        super().__init__()
        # ``self.padding`` is kept exactly as before for anyone reading the
        # attribute: the per-side pad width for 'same', otherwise the raw value.
        if padding == 'same':
            self.padding = (kernel_size - 1) // 2
        else:
            self.padding = padding

        # Forward the constructor arguments; they used to be ignored in favour
        # of a hard-coded 3x3 kernel with 'same' padding.
        self.conv_1 = layers.Conv2D(out_ch, kernel_size,
                                    strides=(1, 1), padding=padding)
        # A single stateless ReLU layer is shared by both stages.
        self.relu = layers.Activation('relu')

        self.conv_2 = layers.Conv2D(out_ch, kernel_size,
                                    strides=(1, 1), padding=padding)

    def call(self, input, training=None):
        """Apply conv -> ReLU -> conv -> ReLU to ``input``.

        ``training`` is accepted for Keras API compatibility and is not used.
        The parameter name ``input`` is kept so keyword callers keep working.
        """
        x = self.conv_1(input)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.relu(x)
        return x

9:9: E731 do not assign a lambda expression, use a def


In [42]:
# Smoke test: push one all-zero RGB image through a 64-filter convBlock and
# show the resulting shape.
inputs = np.zeros(shape=(1, HEIGHT, WIDTH, 3), dtype=np.float32)
block = convBlock(64)
block(inputs).shape

TensorShape([1, 512, 512, 64])

## Encoder(DownStream)
將影像進行編碼，過程中解析度會縮小(maxpooling、convolution)

In [8]:
class Encoder(layers.Layer):
    """Down-sampling path: a chain of convBlocks separated by 2x2 max pooling.

    Args:
        chs: Output filter count of each convBlock, from the first
            (highest-resolution) stage to the last.
        padding: Padding mode of the max-pooling layer. It is not forwarded
            to the convBlocks, which use their own default.
    """

    def __init__(self, chs=(32, 64, 128, 256, 512), padding='same'):
        super().__init__()
        self.FPN_enc_ftrs = [convBlock(ch) for ch in chs]
        self.pool = layers.MaxPooling2D(pool_size=(2, 2),
                                        strides=(2, 2), padding=padding)

    def call(self, x, training=None):
        """Return the output of every convBlock, ordered first stage to last.

        ``training`` is accepted for Keras API compatibility and is not used.
        """
        features = []
        last = len(self.FPN_enc_ftrs) - 1
        for idx, block in enumerate(self.FPN_enc_ftrs):
            x = block(x)
            features.append(x)
            # Pool only between stages: the pooled result of the final stage
            # was never used, so computing it was wasted work.
            if idx < last:
                x = self.pool(x)
        return features

In [9]:
# Smoke test: encode one all-zero RGB image and list the shape of every
# feature map the encoder returns.
inputs = np.zeros(shape=(1, HEIGHT, WIDTH, 3), dtype=np.float32)
encoder = Encoder()
features = encoder(inputs)
for f in features:
    print(f.shape)

## Decoder(UpStream)
將編碼還原成影像，過程中解析度會放大直到回復成輸入影像解析度(transposed Convolution)。
- 將編碼還原成影像是因為影像分割是pixel-wise的精度進行預測，解析度被還原後，就可以知道指定pixel位置所對應的類別
- 類別資訊通常用feature-map的channels(chs)去劃分，一個channel代表一個class
- 有許多UNet模型架構會有輸入572x572，但輸出只有388x388的情況，是因為他們沒有對卷積過程做padding，導致解析度自然下降。最後只要把mask resize到388x388就能繼續計算loss。

### Transposed Conv and UpsampleConv
<img src="https://i.imgur.com/eIIJxre.png" alt="drawing" width="300"/>
<img src="https://i.imgur.com/uLo7icF.png" alt="drawing" width="300"/>

Transposed Conv 
- 透過上面的操作做轉置卷積，feature-map上的數值會作為常數與kernel相乘
- 會導致Checkerboard Artifacts(棋盤格效應)

UpsampleConv
- 先做上採樣(Upsample/ Unpooling)
- 然後作卷積(padding = same)
<!-- #### 替代方案 UpSampling(Unpooling)+Convolution -->


In [9]:
# With k=2, s=2 and output_padding=0, a transposed convolution
# (ConvTranspose2d in PyTorch terms) turns a 28x28 image into 56x56.

x = np.zeros((1, 28, 28, 3), dtype=np.float32)
x = layers.Conv2DTranspose(30, kernel_size=(2, 2),
                           strides=(2, 2), padding='valid')(x)
x.shape

TensorShape([1, 56, 56, 30])

In [10]:
class UpSampleConvs(layers.Layer):
    """Upsample-then-convolve block: 2x UpSampling2D, 3x3 Conv2D, ReLU.

    Args:
        out_ch: Number of filters of the convolution.
        padding: Padding mode forwarded to the convolution.

    An InstanceNormalization step (tensorflow_addons) between the convolution
    and the ReLU is currently disabled.
    """

    def __init__(self, out_ch, padding='same'):
        super().__init__()
        self.conv = layers.Conv2D(
            filters=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=padding,
        )
        self.relu = layers.Activation('relu')
        self.upSample = layers.UpSampling2D(size=2)

    def call(self, x):
        """Upsample ``x`` by 2, convolve it, and apply ReLU."""
        upsampled = self.upSample(x)
        return self.relu(self.conv(upsampled))

In [11]:
# Smoke test: a 28x28 input with 3 channels goes through a 30-filter
# UpSampleConvs block; the printed shape shows the resulting size.
x = np.zeros((1, 28, 28, 3), dtype=np.float32)
x = UpSampleConvs(30)(x)
print(x.shape)

(1, 56, 56, 30)


### decoder(上採樣) module

In [12]:
class Decoder(layers.Layer):
    """Up-sampling path: per stage, upsample, concatenate a skip, then convBlock.

    Args:
        chs: Output filter count of each decoder stage, in the order the
            stages are applied.
        padding: Padding mode forwarded to every UpSampleConvs and convBlock.
    """

    def __init__(self, chs=(256, 128, 64, 32), padding='same'):
        super().__init__()

        self.chs = chs
        self.padding = padding
        # Upsample followed by convolution, one per stage.
        self.upconvs = [UpSampleConvs(ch, padding=padding) for ch in chs]
        self.FPN_dec_ftrs = [convBlock(ch, padding=padding) for ch in chs]
        # Built once here instead of creating a fresh Concatenate layer on
        # every loop iteration of every call.
        self.concat = layers.Concatenate(axis=-1)

    def call(self, x, encoder_features):
        """Decode ``x`` using one entry of ``encoder_features`` per stage.

        ``encoder_features[i]`` is concatenated (channel axis) with the
        upsampled tensor of stage ``i``, so it must have the same spatial
        size. An ``IndexError`` is raised if there are fewer encoder features
        than stages.
        """
        stages = zip(self.upconvs, self.FPN_dec_ftrs)
        for i, (upconv, dec_block) in enumerate(stages):
            enc_ftrs = encoder_features[i]
            x = upconv(x)

            # Cropping is disabled; enable via self.crop(enc_ftrs, x) if the
            # skip and the upsampled tensor differ in spatial size.
            x = self.concat([x, enc_ftrs])
            x = dec_block(x)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the height and width of ``x``.

        Not called by ``call`` at present. Requires ``x`` to have a static
        4-D shape.
        """
        _, H, W, _ = x.shape
        enc_ftrs = layers.CenterCrop(H, W)(enc_ftrs)
        return enc_ftrs

In [13]:
# List the shape of each encoder feature map again, as a reference for the
# decoder inputs used in the next cell.
for i in features:
    print(i.shape)

(1, 512, 512, 32)
(1, 256, 256, 64)
(1, 128, 128, 128)
(1, 64, 64, 256)
(1, 32, 32, 512)


In [14]:
decoder = Decoder()
# NOTE: the bare expression below is not the last statement of the cell, so
# it displays nothing and has no effect.
decoder
# Dummy bottleneck tensor with the same shape as the deepest encoder feature
# map printed above: (1, 32, 32, 512).
x = np.zeros((1, HEIGHT//16, WIDTH//16, 512), dtype=np.float32)
# Skip connections: encoder features deepest-first, without the bottleneck.
print(decoder(x, features[::-1][1:]).shape)

(1, 512, 512, 32)


## Unet構建
結合encoder和decoder組成Unet。
- 在輸出層如果用softmax做多元分類問題預測的話，類別數量要+1(num_classes+background)

In [15]:
class UNet(Model):
    """U-Net assembled from an Encoder, a Decoder and a 1x1 convolution head.

    Args:
        enc_chs: Filter counts of the encoder stages; the last one is the
            bottleneck.
        dec_chs: Filter counts of the decoder stages. There can be at most
            ``len(enc_chs) - 1`` of them, one per available skip connection.
        num_class: Number of output channels of the 1x1 head.
        padding: Padding mode forwarded to the encoder, decoder and head.
        retain_dim: Optional ``(height, width)`` the output is resized to
            with nearest-neighbour interpolation.
        activation: Optional output activation; either a callable or the
            name of a Keras activation (e.g. ``'sigmoid'``).

    Raises:
        ValueError: If ``dec_chs`` has more stages than there are skips.
    """

    def __init__(self, enc_chs=(64, 128, 256, 512, 1024),
                 dec_chs=(512, 256, 128, 64),
                 num_class=1, padding='same',
                 retain_dim=None, activation=None):
        super().__init__()
        # Fail early with a clear message instead of an IndexError deep
        # inside the decoder on the first forward pass.
        if len(dec_chs) > len(enc_chs) - 1:
            raise ValueError(
                "dec_chs needs one skip connection per stage: got "
                f"{len(dec_chs)} decoder stages but only "
                f"{len(enc_chs) - 1} skip connections")
        self.encoder = Encoder(enc_chs, padding=padding)
        self.decoder = Decoder(dec_chs, padding=padding)
        self.head = layers.Conv2D(num_class, (1, 1),
                                  strides=(1, 1), padding=padding)
        self.retain_dim = retain_dim
        self.activation = activation
        # Resolve string names to callables once; callables pass through
        # unchanged. A falsy value means no activation, as before.
        self._activation_fn = (tf.keras.activations.get(activation)
                               if activation else None)

    def call(self, inputs):
        """Run the full encoder/decoder and return the head output."""
        # The deepest feature map is the decoder input; the remaining maps
        # are the skip connections, which the decoder consumes deepest-first.
        *skips, bottleneck = self.encoder(inputs)
        outputs = self.decoder(bottleneck, skips[::-1])
        outputs = self.head(outputs)

        if self.retain_dim:
            outputs = tf.image.resize(outputs,
                                      self.retain_dim,
                                      method='nearest')

        if self._activation_fn is not None:
            outputs = self._activation_fn(outputs)

        return outputs

In [16]:
# Smoke test of the full model. Keras tensors are NHWC and tf.image.resize
# takes (height, width), so HEIGHT comes first in both places.
unet = UNet(num_class=2, padding='same', retain_dim=(HEIGHT, WIDTH))
x = np.zeros((1, HEIGHT, WIDTH, 3), dtype=np.float32)
y_pred = unet(x)
print(y_pred.shape)

(1, 512, 512, 2)
