<a href="https://colab.research.google.com/github/Howl06/classify_project_final/blob/main/02_ResNet%2BUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tensorflow import keras 

### Config

In [None]:
# Encoder input shape as (height, width, channels): 224x224 spatial size, 3 colour channels.
IMG_SIZE = (224, 224) + (3,)

In [None]:
# Pretrained ResNet50 backbone with ImageNet weights and no classification head
# (include_top=False); its intermediate layers supply the U-Net skip connections.
# NOTE(review): ImageNet weights were trained on inputs transformed by
# keras.applications.resnet50.preprocess_input — confirm the data pipeline applies it.
encoder = keras.applications.ResNet50(
    input_shape=IMG_SIZE,
    weights="imagenet",
    include_top=False,
)

In [None]:
# Print each backbone layer's name, output shape, parameter count and inbound connections.
encoder.summary()

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 230, 230, 3)  0           ['input_3[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 112, 112, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                           

In [None]:
"""
(112, 112) conv1_relu 
(56, 56)   conv2_block3_out
(28, 28)   conv3_block4_out
(14, 14)   conv4_block6_out
(7, 7)     conv5_block3_out
"""

'\n(112, 112) conv1_relu \n(56, 56)   conv2_block3_out\n(28, 28)   conv3_block4_out\n(14, 14)   conv4_block6_out\n(7, 7)     conv5_block3_out\n'

In [None]:
def decoder_block(x, skip, filters, kernel_size=3, strides=2, activation="relu"):
    """Upsample ``x`` with a transposed convolution, then concatenate ``skip``.

    Args:
        x: Decoder feature map to upsample.
        skip: Encoder feature map concatenated onto the upsampled ``x``. With
            ``padding="same"`` the transposed convolution multiplies the spatial
            size of ``x`` by ``strides``. ``skip`` must therefore have exactly that
            spatial size (e.g. 7x7 -> 14x14 with the default ``strides=2``),
            otherwise the concatenation raises.
        filters: Number of output channels of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution (default 3).
        strides: Upsampling factor of the transposed convolution (default 2).
        activation: Activation applied after the transposed convolution
            (default ``"relu"``).

    Returns:
        The concatenation of the upsampled ``x`` and ``skip`` along the last
        axis. It has ``filters`` plus ``skip``'s channel count channels.
    """
    upsampled = keras.layers.Conv2DTranspose(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=activation,
    )(x)
    # axis=-1 concatenates along channels, assuming channels-last tensors (Keras default).
    return keras.layers.concatenate([upsampled, skip], axis=-1)

In [None]:
def build_unet(encoder, num_classes=1, output_activation="sigmoid"):
    """Build a U-Net whose encoder is a pretrained Keras ResNet50.

    Five ResNet50 activations are tapped as skip connections. For a 224x224 input
    their spatial sizes are 112, 56, 28, 14 and 7. Four decoder stages each double
    the spatial size and concatenate one skip. A final transposed convolution
    restores the input resolution, and a 1x1 convolution head produces the output.

    Args:
        encoder: A ``keras.applications.ResNet50`` model built with
            ``include_top=False``. It must contain the layer names used below.
        num_classes: Number of output channels of the 1x1 head. The default 1
            gives a single-channel mask.
        output_activation: Activation of the output head (default
            ``"sigmoid"``). For ``num_classes > 1`` with mutually exclusive
            classes, ``"softmax"`` is the usual choice.

    Returns:
        A ``keras.Model`` named ``"ResNet50_U-Net"`` that maps ``encoder.inputs``
        to a ``num_classes``-channel map. For a 224x224 input this is
        ``(batch, 224, 224, num_classes)``.
        NOTE(review): other input sizes presumably need height/width divisible
        by 32 for the skip shapes to line up — confirm before relying on it.

    Raises:
        ValueError: If ``num_classes`` is not a positive integer, or if
            ``encoder`` lacks one of the expected layer names (raised by
            ``Model.get_layer``).
    """
    if not isinstance(num_classes, int) or num_classes < 1:
        raise ValueError(f"num_classes must be a positive integer, got {num_classes!r}")

    # Encoder: ResNet50 activations reused as skip connections (sizes for a 224x224 input).
    s0 = encoder.get_layer("conv1_relu").output        # 112x112
    s1 = encoder.get_layer("conv2_block3_out").output  # 56x56
    s2 = encoder.get_layer("conv3_block4_out").output  # 28x28
    s3 = encoder.get_layer("conv4_block6_out").output  # 14x14
    s4 = encoder.get_layer("conv5_block3_out").output  # 7x7, deepest feature map

    # Decoder: each stage doubles the spatial size and concatenates one skip.
    d1 = decoder_block(s4, s3, 128)  # 14x14
    d2 = decoder_block(d1, s2, 64)   # 28x28
    d3 = decoder_block(d2, s1, 32)   # 56x56
    d4 = decoder_block(d3, s0, 16)   # 112x112

    # The shallowest tapped layer (conv1_relu) is already at half resolution, so this
    # last upsampling step back to the input size has no skip connection.
    d5 = keras.layers.Conv2DTranspose(
        filters=8,
        kernel_size=3,
        strides=2,
        padding="same",
        activation="relu",
    )(d4)  # 224x224

    # Output head: per-pixel 1x1 convolution.
    outputs = keras.layers.Conv2D(
        num_classes, 1, padding="same", activation=output_activation
    )(d5)

    return keras.models.Model(encoder.inputs, outputs, name="ResNet50_U-Net")

In [None]:
# Attach the U-Net decoder to the pretrained ResNet50 encoder.
model = build_unet(encoder=encoder)

In [None]:
import numpy as np

# Smoke test: push one random image through the freshly built model and check that
# the output is a single-channel map at the input resolution.
rng = np.random.default_rng(0)  # seeded so the check is reproducible
imgs = rng.standard_normal((1, *IMG_SIZE), dtype=np.float32)
output = model(imgs)

print(output.shape)
assert tuple(output.shape) == (1, *IMG_SIZE[:2], 1), f"unexpected output shape {output.shape}"

(1, 224, 224, 1)


In [None]:
# Print the assembled U-Net's layers. The first rows repeat the ResNet50 encoder's
# layers because the U-Net is built on the encoder's inputs and intermediate outputs.
model.summary()

Model: "ResNet50_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 230, 230, 3)  0           ['input_3[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 112, 112, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                     