In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
def Upsampling(inputs, factor=2, filters=256):
    """Sub-pixel (pixel-shuffle) upsampling by an integer factor.

    A 3x3 convolution expands the input to ``filters * factor**2`` channels, then
    ``tf.nn.depth_to_space`` rearranges those channels into a map that is
    ``factor`` times larger in height and width with ``filters`` channels.

    Args:
        inputs: 4-D feature map; assumes channels-last (batch, H, W, C).
        factor: spatial upscaling factor (block size for depth_to_space).
        filters: number of channels in the upsampled output (default 256,
            matching the rest of this network).

    Returns:
        Tensor of shape (batch, factor*H, factor*W, filters).
    """
    # depth_to_space consumes factor**2 channels per output channel, so expand first.
    expanded = tf.keras.layers.Conv2D(filters=filters * (factor ** 2), kernel_size=3,
                                      kernel_initializer='he_normal', padding="same")(inputs)
    return tf.nn.depth_to_space(expanded, block_size=factor)

In [None]:
def ChannelAttention(inputs, filters=256, reduction=16):
    """Compute per-channel attention weights from pooled channel descriptors.

    Two global descriptors are taken from ``inputs``: adaptive average pooling
    and adaptive max pooling, each to a 1x1 map. Each descriptor goes through its
    own bottleneck of two 1x1 convs (squeeze -> ReLU -> excite). The two results
    are summed and passed through a sigmoid.

    Note: this returns only the attention *weights*; it does not multiply them
    onto ``inputs``. The caller is responsible for applying them.

    Args:
        inputs: 4-D feature map; assumes channels-last (batch, H, W, C).
        filters: number of attention channels produced (default 256). This should
            equal C when the weights are multiplied back onto ``inputs``.
        reduction: bottleneck reduction ratio; the hidden width is
            ``filters // reduction`` (at least 1).

    Returns:
        Sigmoid weights of shape (batch, 1, 1, filters).
    """
    hidden = max(filters // reduction, 1)

    def _bottleneck(pooled):
        # Squeeze -> ReLU -> excite. Without the non-linearity the two 1x1 convs
        # would collapse into a single rank-limited linear map.
        squeezed = tf.keras.layers.Conv2D(filters=hidden, kernel_size=1, strides=1,
                                          kernel_initializer='he_normal', padding='same',
                                          activation='relu')(pooled)
        return tf.keras.layers.Conv2D(filters=filters, kernel_size=1, strides=1,
                                      kernel_initializer='he_normal', padding='same')(squeezed)

    avg_descriptor = tfa.layers.AdaptiveAveragePooling2D(1)(inputs)
    max_descriptor = tfa.layers.AdaptiveMaxPooling2D(1)(inputs)

    # Each call to _bottleneck builds fresh layers, so the two branches keep separate weights.
    logits = tf.keras.layers.Add()([_bottleneck(avg_descriptor), _bottleneck(max_descriptor)])
    return tf.keras.layers.Activation('sigmoid')(logits)

In [None]:
def RCAB(inputs):
    """Residual block with channel attention followed by a per-pixel spatial gate.

    Steps:
      1. conv1 (3x3, ReLU) extracts features from ``inputs``.
      2. Channel-attention weights from ``ChannelAttention`` are multiplied onto
         those features.
      3. conv2 (3x3, ReLU) on the attended features feeds a channel-wise max and a
         sigmoid, giving a (batch, H, W, 1) spatial gate.
      4. The gated features are added back onto ``inputs``.

    Args:
        inputs: 4-D feature map; presumably channels-last with 256 channels so the
            residual Add lines up with the 256-filter convolutions.

    Returns:
        ``inputs + attended * gate``, same shape as ``inputs``.
    """
    conv1 = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same',
                                   kernel_initializer='he_normal', activation='relu')(inputs)
    # ChannelAttention yields (batch, 1, 1, 256) weights only. Apply them to conv1,
    # since the weights alone carry no spatial information.
    ca = tf.keras.layers.Multiply()([conv1, ChannelAttention(inputs=conv1)])

    conv2 = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same',
                                   kernel_initializer='he_normal', activation='relu')(ca)
    # Spatial gate: strongest response across channels at each pixel, squashed to (0, 1).
    pool = tf.math.reduce_max(conv2, axis=-1, keepdims=True)
    gate = tf.keras.layers.Activation('sigmoid')(pool)

    mul = tf.keras.layers.Multiply()([ca, gate])
    return tf.keras.layers.Add()([inputs, mul])

In [None]:
def ResidualGroup(inputs, n_blocks=3):
    """A chain of RCAB blocks and a 3x3 conv, wrapped in a skip connection.

    Args:
        inputs: 4-D feature map; presumably channels-last with 256 channels so the
            final Add matches the 256-filter convolution.
        n_blocks: number of RCAB blocks chained before the tail conv (default 3).

    Returns:
        ``inputs + conv(RCAB^n_blocks(inputs))``, same shape as ``inputs``.
    """
    features = inputs
    for _ in range(n_blocks):
        features = RCAB(features)
    features = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1,
                                      kernel_initializer='he_normal', padding='same')(features)
    # Group-level skip connection around the whole RCAB chain.
    return tf.keras.layers.Add()([inputs, features])

In [None]:
def CAM(inputs):
    """Channel-spatial attention using a 3D convolution over (H, W, C).

    The feature map is treated as a single-channel 3D volume whose depth axis is
    the channel axis. A 3x3x3 Conv3D with one filter and a sigmoid then produces
    one attention value per (pixel, channel), mixing neighbouring pixels *and*
    neighbouring channels. The attention is scaled by a fixed 0.2 and applied as
    a residual modulation.

    Args:
        inputs: 4-D feature map; assumes channels-last (batch, H, W, C). Any C
            works, since the attention map has the same shape as ``inputs``.

    Returns:
        ``inputs + inputs * (0.2 * attention)``, same shape as ``inputs``.
    """
    # (B, H, W, C) -> (B, H, W, C, 1). Channels become a spatial axis for Conv3D.
    # Expanding at axis=1 instead would create a singleton depth axis, and the 3D
    # kernel would degenerate into a 2D conv that never mixes channels.
    volume = tf.expand_dims(inputs, axis=-1)
    attention = tf.keras.layers.Conv3D(filters=1, kernel_size=3, strides=1, padding='same',
                                       kernel_initializer='he_normal', activation='sigmoid')(volume)
    # Drop the single filter axis: back to (B, H, W, C), aligned with inputs.
    attention = tf.squeeze(attention, axis=-1)
    attention = tf.keras.layers.Lambda(lambda t: t * 0.2)(attention)

    modulated = tf.keras.layers.Multiply()([inputs, attention])
    return tf.keras.layers.Add()([inputs, modulated])

In [None]:
def LAM(inputs):
    """Layer attention over a stack of feature maps, fused by a 1x1 conv.

    Each of the N stacked maps is flattened to a row vector. The N x N
    correlation matrix between rows is softmax-normalised and scaled by a fixed
    0.2. Each row is then replaced by its attention-weighted combination of all
    rows and added back residually. Finally, the N maps are merged along the
    channel axis and projected to 256 channels.

    Args:
        inputs: 5-D tensor (batch, N, H, W, C). In ``model()`` this is built by
            stacking 4-D channels-last maps along axis 1.

    Returns:
        4-D tensor (batch, H, W, 256).
    """
    # Use static sizes where known (so the final Conv2D sees a defined channel
    # count) and fall back to dynamic sizes otherwise.
    static_shape = inputs.shape
    dynamic_shape = tf.shape(inputs)
    batch, n_layers, height, width, channels = [
        static_shape[k] if static_shape[k] is not None else dynamic_shape[k]
        for k in range(5)
    ]

    # One row per layer: (B, N, H*W*C).
    flat = tf.reshape(inputs, [batch, n_layers, -1])
    # Layer-to-layer correlation (B, N, N). Softmax runs over the last axis.
    # NOTE(review): raw dot products of very long rows may saturate the softmax.
    energy = tf.linalg.matmul(flat, flat, transpose_b=True)
    attention = tf.keras.layers.Activation('softmax')(energy)
    attention = tf.keras.layers.Lambda(lambda t: t * 0.2)(attention)
    attended = tf.linalg.matmul(attention, flat)
    attended = tf.reshape(attended, [batch, n_layers, height, width, channels])
    fused = tf.keras.layers.Add()([inputs, attended])

    # Move the layer axis next to channels before merging them. Reshaping
    # (B, N, H, W, C) straight to (B, H, W, N*C) would reinterpret the
    # row-major buffer and mix values from different layers and pixels.
    fused = tf.transpose(fused, perm=[0, 2, 3, 1, 4])
    feature = tf.reshape(fused, [batch, height, width, n_layers * channels])

    return tf.keras.layers.Conv2D(filters=256, kernel_size=1, strides=1,
                                  kernel_initializer='he_normal', padding='same')(feature)

In [None]:
def model(n_groups=10):
    """Build the x4 super-resolution network.

    Pipeline:
      * rescale input pixels by 1/255 (presumably 0-255 images), then a 3x3
        head conv to 256 channels;
      * a *chain* of ``n_groups`` residual groups, each consuming the previous
        group's output;
      * CAM on a 3x3 conv of the last group's output;
      * LAM over the stack of every group's output;
      * sum of CAM, LAM and the head features (global skip);
      * two x2 sub-pixel upsamplers, a 3x3 conv to 3 channels, and rescaling by 255.

    Args:
        n_groups: number of chained residual groups (default 10).

    Returns:
        tf.keras.Model mapping (batch, H, W, 3) to (batch, 4H, 4W, 3).

    Raises:
        ValueError: if ``n_groups`` < 1.
    """
    if n_groups < 1:
        raise ValueError(f"n_groups must be >= 1, got {n_groups}")

    inputs = tf.keras.layers.Input((None, None, 3))
    x = tf.keras.layers.Rescaling(scale=1.0 / 255)(inputs)
    x = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1,
                               kernel_initializer='he_normal', padding='same')(x)

    # Chain the groups so each one refines the previous output, and keep every
    # intermediate output for layer attention.
    group_outputs = []
    features = x
    for _ in range(n_groups):
        features = ResidualGroup(inputs=features)
        group_outputs.append(features)

    conv = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1,
                                  kernel_initializer='he_normal', padding='same')(group_outputs[-1])
    cam = CAM(inputs=conv)

    # Stack group outputs along a new axis 1: (B, n_groups, H, W, 256).
    expanded = [tf.expand_dims(g, axis=1) for g in group_outputs]
    # Concatenate requires at least two inputs.
    if len(expanded) == 1:
        stacked = expanded[0]
    else:
        stacked = tf.keras.layers.Concatenate(axis=1)(expanded)
    lam = LAM(inputs=stacked)

    x = tf.keras.layers.Add()([cam, lam, x])

    x = Upsampling(x)
    x = Upsampling(x)
    x = tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1,
                               kernel_initializer='he_normal', padding='same')(x)
    x = tf.keras.layers.Rescaling(scale=255)(x)
    return tf.keras.Model(inputs, x)