In [14]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ReLU, Conv2D, Conv2DTranspose, Input, MaxPooling2D, add
from tensorflow.keras.utils import plot_model

In [None]:
def get_model(size: tuple, out_channels: int = 3):
    """Build a small FCN-style encoder/decoder with two skip-connection fusions.

    Encoder: five stages of Conv2D(3x3, "same") -> ReLU -> MaxPooling2D(2x2)
    with 8, 16, 32, 64 and 64 filters, so H and W are reduced by 2**5 = 32.
    After stages 3, 4 and 5 a side branch (3x3 conv to ``out_channels`` maps
    + ReLU) is taken at 1/8, 1/16 and 1/32 resolution.

    Decoder: the 1/32 branch is upsampled x2 (+ReLU) and added to the 1/16
    branch; that sum is upsampled x2 (+ReLU) and added to the 1/8 branch; the
    result is upsampled x8 back to the input resolution. The final layer has
    no activation.

    All side-branch convs and transposed convs use ``padding="same"``. With
    the default "valid" padding, each branch lost 2 px and the output of a
    256x256 input was 248x248 instead of 256x256.

    Args:
        size: Input shape passed to ``Input``, e.g. ``(256, 256, 3)``. The
            first two entries (height, width) must each be ``None`` or a
            multiple of 32 so the upsampled maps line up with the skips.
        out_channels: Number of channels of every side branch and of the
            output map. Defaults to 3, the value previously hard-coded.

    Returns:
        A Keras ``Model`` whose output has the same height and width as the
        input and ``out_channels`` channels.

    Raises:
        ValueError: If a known height/width in ``size`` is not a multiple of 32.
    """
    # Five 2x2 max-pools shrink H and W by 32. The decoder only re-aligns with
    # the skip branches (and restores the input size) for multiples of 32.
    for dim in tuple(size)[:2]:
        if dim is not None and dim % 32 != 0:
            raise ValueError(
                f"get_model: height/width in size={size!r} must be multiples of 32"
            )

    x = Input(shape=size)

    def _conv_block(_input, filters, kernel_size=(3, 3)):
        # Conv ("same" keeps H, W) -> ReLU -> 2x2 max-pool (halves H, W).
        _y = Conv2D(filters=filters, kernel_size=kernel_size, padding="same")(_input)
        _y = ReLU()(_y)
        _y = MaxPooling2D(pool_size=(2, 2))(_y)
        return _y

    def _side_branch(_input):
        # 3x3 conv to `out_channels` maps + ReLU, spatial size preserved.
        _y = Conv2D(filters=out_channels, kernel_size=(3, 3), padding="same")(_input)
        return ReLU()(_y)

    def _upsample(_input, factor):
        # Transposed conv, kernel 2*factor and stride `factor`. With "same"
        # padding the output is exactly `factor` times the input H and W.
        return Conv2DTranspose(
            filters=out_channels,
            kernel_size=(2 * factor, 2 * factor),
            strides=factor,
            padding="same",
        )(_input)

    # Encoder stages 1-3 -> 1/8 resolution.
    y = _conv_block(x, 8)
    y = _conv_block(y, 16)
    y = _conv_block(y, 32)
    skip_8 = _side_branch(y)     # 1/8 resolution
    # Stage 4 -> 1/16 resolution.
    y = _conv_block(y, 64)
    skip_16 = _side_branch(y)    # 1/16 resolution
    # Stage 5 -> 1/32 resolution.
    y = _conv_block(y, 64)
    score_32 = _side_branch(y)   # 1/32 resolution

    # Fusion: 1/32 -> 1/16, add the stage-4 branch.
    fused = ReLU()(_upsample(score_32, 2))
    fused = add([fused, skip_16])
    # Fusion: 1/16 -> 1/8, add the stage-3 branch.
    fused = ReLU()(_upsample(fused, 2))
    fused = add([fused, skip_8])
    # 1/8 -> full input resolution; no activation on the output.
    out = _upsample(fused, 8)

    return Model(x, out)

In [16]:
# Instantiate the network for 256x256x3 inputs and print its layer table.
model = get_model(size=(256, 256, 3))
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 256, 256, 8)  224         input_4[0][0]                    
__________________________________________________________________________________________________
re_lu_20 (ReLU)                 (None, 256, 256, 8)  0           conv2d_17[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_11 (MaxPooling2D) (None, 128, 128, 8)  0           re_lu_20[0][0]                   
____________________________________________________________________________________________

In [None]:
# Render the architecture diagram, including tensor shapes, to change_fcn.png.
plot_model(
    model,
    to_file="change_fcn.png",
    show_shapes=True,
)