In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
def Upsampling(inputs, factor=2, filters=64):
    """Sub-pixel (pixel-shuffle) upsampling by an integer ``factor``.

    A 3x3 convolution expands the channels to ``filters * factor**2`` and
    ``tf.nn.depth_to_space`` rearranges those channels into a spatial grid that
    is ``factor`` times larger in height and width, leaving ``filters`` channels.

    Args:
        inputs: 4-D feature tensor; assumes channels-last (NHWC) layout, which is
            what ``tf.nn.depth_to_space`` uses by default — confirm if the Keras
            image data format is changed.
        factor: spatial upscaling factor (the depth_to_space block size).
        filters: number of output channels after the shuffle. Defaults to 64,
            the value previously hard-coded here.

    Returns:
        Tensor with height/width multiplied by ``factor`` and ``filters`` channels.
    """
    # Produce exactly filters * factor**2 channels so depth_to_space can fold
    # factor x factor groups of them into space and leave `filters` channels.
    expanded = tf.keras.layers.Conv2D(filters=filters * (factor ** 2), kernel_size=3,
                                      padding="same", kernel_initializer='he_normal')(inputs)
    return tf.nn.depth_to_space(expanded, block_size=factor)

In [None]:
def _residual_unit(inputs, filters):
    """Apply one of the three dense residual units used inside ``drlm``.

    Computes ``relu(inputs + conv3x3(relu(conv3x3(inputs))))`` and concatenates
    that result onto ``inputs`` along the last axis, so the returned tensor has
    twice as many channels as ``inputs``.

    Args:
        inputs: feature tensor whose last-axis size must equal ``filters``
            (required by the residual addition).
        filters: channel count of both 3x3 convolutions.

    Returns:
        ``concat([inputs, activated_sum], axis=-1)`` with ``2 * filters`` channels.
    """
    branch = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, strides=1,
                                    padding='same', kernel_initializer='he_normal')(inputs)
    branch = tf.keras.layers.Activation('relu')(branch)
    branch = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, strides=1,
                                    padding='same', kernel_initializer='he_normal')(branch)
    merged = inputs + branch
    merged = tf.keras.layers.Activation('relu')(merged)
    return tf.keras.layers.Concatenate(axis=-1)([inputs, merged])


def drlm(inputs, filters=64):
    """DRLM block: three dense residual units, 1x1 fusion, then channel attention.

    (Presumably the "dense residual Laplacian module" of DRLN — name-based guess.)

    Args:
        inputs: 4-D feature tensor with exactly ``filters`` channels; the first
            residual unit adds it element-wise to a ``filters``-channel conv output.
        filters: base channel width. Defaults to 64, the value previously
            hard-coded throughout this block.

    Returns:
        Tensor with the spatial size of ``inputs`` and ``filters`` channels: the
        1x1-compressed dense features multiplied by a sigmoid attention gate.
    """
    # Each unit doubles the channel count via concatenation.
    dense1 = _residual_unit(inputs, filters)          # 2 * filters channels
    dense2 = _residual_unit(dense1, filters * 2)      # 4 * filters channels
    dense3 = _residual_unit(dense2, filters * 4)      # 8 * filters channels

    # 1x1 conv fuses the 8 * filters dense features back down to `filters`.
    compression = tf.keras.layers.Conv2D(filters=filters, kernel_size=1, strides=1,
                                         padding='same', kernel_initializer='he_normal')(dense3)

    # Channel attention. GlobalAveragePooling2D(keepdims=True) computes the same
    # per-channel spatial mean as tfa AdaptiveAveragePooling2D(1) without relying
    # on the end-of-life tensorflow_addons package; output is (batch, 1, 1, filters).
    pool = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)(compression)
    # NOTE(review): the pooled map is 1x1, so with 'same' padding each dilated
    # 3x3 kernel only multiplies its centre tap with real data; the three
    # branches therefore behave like 1x1 convs. Kept to preserve the architecture.
    branches = [
        tf.keras.layers.Conv2D(filters=filters // 4, kernel_size=3, strides=1, dilation_rate=rate,
                               padding='same', kernel_initializer='he_normal',
                               activation='relu')(pool)
        for rate in (3, 5, 7)
    ]
    concat = tf.keras.layers.Concatenate(axis=-1)(branches)
    gate = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, strides=1,
                                  padding='same', kernel_initializer='he_normal')(concat)
    gate = tf.keras.layers.Activation('sigmoid')(gate)
    # The (batch, 1, 1, filters) gate broadcasts over every spatial position.
    return compression * gate

def _drlm_group(block_input, concat_base, n_blocks):
    """Chain ``n_blocks`` DRLMs with dense concatenation and 3x3 compression.

    The first ``drlm`` consumes ``block_input``; every later ``drlm`` consumes
    the output of the preceding 3x3 compression conv. Each ``drlm`` output is
    appended to a running concatenation that starts from ``concat_base``, and
    that running concatenation is compressed back to 64 channels after every
    step.

    Args:
        block_input: 64-channel tensor fed to the first DRLM of the group.
        concat_base: 64-channel tensor that seeds the dense concatenation.
        n_blocks: number of DRLMs in the group.

    Returns:
        Output of the group's last 3x3 compression conv (64 channels).
    """
    features = concat_base
    current = block_input
    for _ in range(n_blocks):
        block_out = drlm(current)
        features = tf.keras.layers.Concatenate(axis=-1)([features, block_out])
        current = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=1,
                                         padding='same', kernel_initializer='he_normal')(features)
    return current


def model():
    """Build the DRLN-style super-resolution Keras model.

    Pipeline:
    - Rescale RGB input by 1/255 and apply a 3x3 head conv to 64 channels.
    - Run six DRLM groups of (3, 3, 3, 3, 4, 4) blocks. Each group has a
      residual connection from its input.
    - Add a long skip from the head conv.
    - Apply two x2 sub-pixel upsampling stages (4x overall).
    - Apply a 3x3 conv to 3 channels, then rescale by 255.

    Returns:
        ``tf.keras.Model`` mapping ``(batch, H, W, 3)`` images (any H, W) to
        ``(batch, 4H, 4W, 3)`` outputs.
    """
    image = tf.keras.layers.Input((None, None, 3))
    # Keep the Input tensor under its own name: the Model must be built from the
    # Input itself. Passing the rescaled tensor as `inputs` drops the Rescaling
    # layer from the model (or is rejected, depending on the Keras version).
    scaled = tf.keras.layers.Rescaling(scale=(1.0 / 255))(image)

    head = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=1,
                                  padding='same', kernel_initializer='he_normal')(scaled)

    # Group 1 uses the head output both as DRLM input and concatenation seed.
    # Each later group takes the previous residual sum as DRLM input but seeds
    # its concatenation with the previous group's last compression conv.
    residual = head
    last_conv = head
    for n_blocks in (3, 3, 3, 3, 4, 4):
        last_conv = _drlm_group(residual, last_conv, n_blocks)
        residual = last_conv + residual
    # Long skip connection from the head features.
    body = residual + head

    x = Upsampling(inputs=body)
    x = Upsampling(inputs=x)

    x = tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1,
                               padding='same', kernel_initializer='he_normal')(x)
    outputs = tf.keras.layers.Rescaling(scale=255)(x)

    return tf.keras.Model(inputs=image, outputs=outputs)