### Model

In [3]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import (
    Conv2D,
    Conv3D,
    GlobalAveragePooling3D,
    GlobalMaxPooling3D,
    BatchNormalization, 
    Activation, 
    add,
    Reshape,
    multiply,
    Lambda,
    concatenate,
    Input,
    Permute,
    Dropout
)
from keras.models import Model
import keras.backend as K
from keras import layers, regularizers

physical_devices = tf.config.list_physical_devices('GPU')

# Restrict TensorFlow to the first GPU and let its memory grow on demand
# instead of reserving all of it up front. These calls must run before the
# GPU runtime is initialised (i.e. before any op executes on it).
# On a machine without a GPU the list is empty: skip the configuration and
# run on CPU instead of crashing with an IndexError.
if physical_devices:
    tf.config.experimental.set_visible_devices(devices=physical_devices[0], device_type='GPU')
    tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

In [4]:
# All layers below assume channels-last tensors, i.e. (batch, H, W, T, C)
# for the 3-D convolutions.
keras.backend.set_image_data_format('channels_last')

#### Attention Blocks

In [6]:
# Channels-last layout, for 3-D (H, W, T) feature maps.
class ChannelAttention(layers.Layer):
    """CBAM-style channel attention gate.

    inputs: (None, height, width, time, channel)

    Global average pooling and global max pooling over (height, width, time)
    each produce one descriptor per channel. Both descriptors pass through the
    same two 1x1x1 convolutions: a ReLU bottleneck with ``channels // ratio``
    filters (at least 1), then a linear projection back to ``channels``. The
    two results are summed and squashed with a sigmoid, and the resulting
    per-channel gate is multiplied with the inputs.

    Args:
        channels: number of channels of ``inputs``. The gate has this many
            channels, so it must match the input's last axis.
        ratio: reduction ratio of the bottleneck convolution.
    """
    def __init__(self,channels,ratio=16,**kwargs):
        super(ChannelAttention,self).__init__(**kwargs)
        self.avg = GlobalAveragePooling3D()
        self.max = GlobalMaxPooling3D()
        self.channels = channels
        self.ratio = ratio
        # With channels < ratio the integer division is 0, which Conv3D
        # rejects; keep at least one bottleneck filter.
        reduced = max(1, self.channels // self.ratio)
        self.conv1 = Conv3D(reduced,kernel_size=1,strides=1,padding='same',
                            activation='relu',use_bias=True,bias_initializer='zeros',
                            kernel_initializer='he_normal',name = 'CA_Conv1')
        self.conv2 = Conv3D(self.channels,kernel_size=1,strides=1,padding='same',use_bias=True,
                            bias_initializer='zeros',kernel_initializer='he_normal',name = 'CA_Conv2')

    def call(self,inputs):
        avg_pooled = self.avg(inputs)  # (None, channel)
        max_pooled = self.max(inputs)  # (None, channel)
        # Restore a 5-D shape (None, 1, 1, 1, channel) so the 1x1x1 convs apply.
        avg_pooled = Reshape((1,1,1,avg_pooled.shape[1]))(avg_pooled)
        max_pooled = Reshape((1,1,1,max_pooled.shape[1]))(max_pooled)
        # conv1/conv2 are shared between the two descriptors.
        avg_out = self.conv2(self.conv1(avg_pooled))
        max_out = self.conv2(self.conv1(max_pooled))
        summed = add([avg_out,max_out])
        gate = Activation('sigmoid')(summed)
        # The gate broadcasts over height, width and time.
        return multiply([inputs,gate])

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'channels': self.channels,
            'ratio': self.ratio
        })
        return config



In [7]:
class SpatialAttention(layers.Layer):
    """Spatial attention gate.

    For an input laid out as (None, H, W, T, C), which is the layout the other
    attention blocks here assume, the mean and the max are taken jointly over
    axes 3 and 4. This gives two (None, H, W, 1, 1) maps, which are stacked on
    the last axis. A single sigmoid Conv3D with a (kernel_size, kernel_size, 1)
    kernel turns them into a gate, and the gate is broadcast-multiplied with
    the input.

    Args:
        kernel_size: spatial extent of the gating convolution (previously 7).
    """
    def __init__(self, kernel_size=5, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        # The kernel has size 1 along axis 3 (time), so it only mixes
        # information across the two spatial axes, e.g. (5, 5, 1).
        gate_kernel = (self.kernel_size, self.kernel_size, 1)
        self.conv1 = Conv3D(
            filters=1,
            kernel_size=gate_kernel,
            strides=1,
            padding='same',
            activation='sigmoid',
            kernel_initializer='he_normal',
            use_bias=False,
            name='SA_Conv1',
        )

    def call(self, inputs):
        # Reduce over axes [3, 4] (not just the last axis), so each descriptor
        # is (None, H, W, 1, 1) rather than (None, H, W, T, 1).
        pooled = [
            Lambda(lambda t: K.mean(t, axis=[3, 4], keepdims=True))(inputs),
            Lambda(lambda t: K.max(t, axis=[3, 4], keepdims=True))(inputs),
        ]
        stacked = concatenate(pooled, axis=-1)
        gate = self.conv1(stacked)
        return multiply([inputs, gate])

    def get_config(self):
        base = super().get_config().copy()
        base.update({'kernel_size': self.kernel_size})
        return base


In [8]:
class TemporalAttention(layers.Layer):
    """Temporal attention gate: one learned weight per time step.

    input shape: (None, H, W, T, C). Time is axis 3.

    The input is permuted to (None, H, W, C, T), so global average and max
    pooling over (H, W, C) each yield one descriptor per time step, with shape
    (None, T). Both descriptors pass through the same two 1x1x1 convolutions
    (ReLU, then linear; both with ``channels`` filters). The outputs are
    permuted back to (None, 1, 1, channels, 1), summed, squashed with a sigmoid
    and multiplied with the input.

    Args:
        channels: length of the time axis T of the input. After the permute,
            this is the channel count seen by the convolutions.

    Raises:
        ValueError: if the input's time axis is statically known and differs
            from ``channels``. Without this check the gate would either fail to
            broadcast or silently broadcast to the wrong shape.
    """
    def __init__(self,channels,**kwargs):
        super(TemporalAttention,self).__init__(**kwargs)
        self.avg = GlobalAveragePooling3D()
        self.max = GlobalMaxPooling3D()
        self.channels = channels
        self.conv1 = Conv3D(self.channels,kernel_size=1,strides=1,padding='same',
                            activation='relu',use_bias=True,bias_initializer='zeros',
                            kernel_initializer='he_normal',name = 'TA_Conv1')
        self.conv2 = Conv3D(self.channels,kernel_size=1,strides=1,padding='same',use_bias=True,
                            bias_initializer='zeros',kernel_initializer='he_normal',name = 'TA_Conv2')

    def call(self,inputs):
        time_steps = inputs.shape[3]
        if time_steps is not None and time_steps != self.channels:
            raise ValueError(
                'TemporalAttention expects {} time steps on axis 3, got input shape {}'
                .format(self.channels, inputs.shape))
        # (None, H, W, T, C) -> (None, H, W, C, T): time becomes the last axis.
        x = Permute((1,2,4,3))(inputs)
        avg_pooled = self.avg(x)  # (None, T)
        max_pooled = self.max(x)  # (None, T)
        avg_pooled = Reshape((1,1,1,avg_pooled.shape[1]))(avg_pooled)  # (None, 1, 1, 1, T)
        max_pooled = Reshape((1,1,1,max_pooled.shape[1]))(max_pooled)
        # 1x1x1 kernels keep time on the last axis: (None, 1, 1, 1, channels).
        avg_out = self.conv2(self.conv1(avg_pooled))
        max_out = self.conv2(self.conv1(max_pooled))
        # Move the per-step weights back onto axis 3: (None, 1, 1, channels, 1).
        avg_out = Permute((1,2,4,3))(avg_out)
        max_out = Permute((1,2,4,3))(max_out)
        summed = add([avg_out,max_out])
        gate = Activation('sigmoid')(summed)
        return multiply([inputs,gate])

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'channels': self.channels
        })
        return config
        

In [9]:
class RDAB(Model):
    """Residual dense attention block.

    Pipeline on ``inputs`` (channels-last, presumably (None, H, W, T, C);
    TODO confirm against callers):
      1. Six 3x3x3 ReLU convolutions with dense connectivity. Conv k receives
         the concatenation (last axis) of ``inputs`` and the outputs of
         convs 1..k-1.
      2. A 1x1x1 convolution that fuses ``inputs`` and all six outputs back to
         ``channels`` feature maps.
      3. Channel -> spatial -> temporal attention, applied in series.
      4. Residual addition with ``inputs``.

    Args:
        channels: filters of every convolution in the block. ChannelAttention
            is built for this width, and the result is added back to
            ``inputs``, so ``inputs`` must also have ``channels`` channels.
        t_dims: length of the time axis (axis 3), forwarded to
            TemporalAttention.
    """
    def __init__(self,channels,t_dims,**kwargs):
        super(RDAB,self).__init__(**kwargs)

        self.channels = channels
        self.t_dims = t_dims
        self.c1 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv1')
        self.c2 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv2')
        self.c3 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv3')
        self.c4 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv4')
        self.c5 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv5')
        self.c6 = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(3,3,3),activation='relu',padding='same',name = 'RDAB_Conv6')
        # Local feature fusion. Its width must equal the `channels` given to
        # ChannelAttention (the gate is multiplied with this tensor) and the
        # width of `inputs` (residual add). It was previously hard-coded to 64,
        # which only worked when channels == 64.
        self.c = Conv3D(filters=self.channels,kernel_initializer='he_normal',kernel_size=(1,1,1),padding='same',name = 'RDAB_Conv7')
        self.ca = ChannelAttention(self.channels)
        self.ta = TemporalAttention(self.t_dims)
        self.sa = SpatialAttention()

    def call(self,inputs):
        # Dense connectivity: `features` grows as [inputs, x1, x2, ...], and
        # each conv sees all of them concatenated on the channel axis.
        features = [inputs]
        for conv in (self.c1, self.c2, self.c3, self.c4, self.c5, self.c6):
            conv_in = inputs if len(features) == 1 else tf.concat(features, axis=-1)
            features.append(conv(conv_in))

        # Fuse the input and all six conv outputs.
        y = self.c(tf.concat(features, axis=-1))

        # The attention modules are applied in series.
        ca_out = self.ca(y)
        sa_out = self.sa(ca_out)
        ta_out = self.ta(sa_out)

        # Local residual connection.
        out = add([ta_out,inputs])
        return out

  


#### STARN

In [10]:
class STARN(Model):
    """Encoder-decoder RDN with CBAM-style attention blocks.

    Architecture (all Conv3D layers use 'same' padding):
      - stem: Conv3D(64) + ReLU, whose output is kept as ``residual_1``;
      - two encoder stages, each with two Conv3D(64)+BN+ReLU layers, whose
        outputs are kept as ``residual_2`` and ``residual_3``;
      - four RDAB blocks;
      - two decoder stages. Each first adds the matching encoder residual and
        then applies two Conv3D(64)+BN+ReLU layers;
      - head: add ``residual_1``, Conv3D(1)+BN+ReLU, temporal attention, mean
        over the time axis, Conv2D(1), sigmoid -> (None, H, W, 1).

    Args:
        input_shape: (H, W, T, features), without the batch dimension.
        num_G: width passed to the RDAB blocks. The encoder emits 64 channels
            and RDAB adds its output back to its input, so this presumably
            has to stay 64. TODO confirm before changing it.
    """
    def __init__(self,input_shape,num_G=64,**kwargs):
        super(STARN,self).__init__(**kwargs)
        self.num_G = num_G
        self.c_dims = input_shape[3]
        self.t_dims = input_shape[2]
        #self.channels = channels
        self.input_layer = Input(input_shape)
        self.conv1 = Conv3D(filters=64,kernel_size=(3,3,3),kernel_regularizer = regularizers.l2(5e-4),padding='same',name = 'STARN_Conv1')
        #,kernel_initializer='he_normal'
        # encoder
        self.Econv1 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Econv1')
        self.Ebn1 = BatchNormalization(momentum=0.8,name = 'STARN_Ebn1')
        self.Econv2 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Econv2')
        self.Ebn2 = BatchNormalization(momentum=0.8,name = 'STARN_Ebn2')
        self.Econv3 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Econv3')
        self.Ebn3 = BatchNormalization(momentum=0.8,name = 'STARN_Ebn3')
        self.Econv4 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Econv4')
        self.Ebn4 = BatchNormalization(momentum=0.8,name = 'STARN_Ebn4')

        #decoder
        self.Dconv1 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Dconv1')
        self.Dbn1 = BatchNormalization(momentum=0.8,name = 'STARN_Dbn1')
        self.Dconv2 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Dconv2')
        self.Dbn2 = BatchNormalization(momentum=0.8,name = 'STARN_Dbn2')
        self.Dconv3 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Dconv3')
        self.Dbn3 = BatchNormalization(momentum=0.8,name = 'STARN_Dbn3')
        self.Dconv4 = Conv3D(filters=64,kernel_size=(3,3,3),padding='same',name = 'STARN_Dconv4')
        self.Dbn4 = BatchNormalization(momentum=0.8,name = 'STARN_Dbn4')

        #RDN_CBAM Block
        self.Block1 = RDAB(self.num_G,self.t_dims)
        self.Block2 = RDAB(self.num_G,self.t_dims)
        self.Block3 = RDAB(self.num_G,self.t_dims)
        self.Block4 = RDAB(self.num_G,self.t_dims)
        #self.Block5 = RDAB(self.num_G,self.t_dims)
        #self.Block6 = RDAB(self.num_G,self.t_dims)

        #activation sigmoid or linear
        self.conv2 = Conv3D(filters=1,kernel_size=(3,3,3),padding='same',name = 'STARN_Conv2')
        self.conv3 = Conv2D(filters=1,kernel_size=(2,2),padding='same',name = 'STARN_Conv3')
        self.ta = TemporalAttention(self.t_dims)
        self.bn1 = BatchNormalization(momentum=0.8,name = 'STARN_Bn1')
        #self.drop1 = Dropout(0.2)
        #self.drop2 = Dropout(0.2)
        #self.drop3 = Dropout(0.2)
        # Trace `call` once on the symbolic Input. Then re-initialise this
        # object as a functional Model over that graph, so that summary() and
        # friends see every layer with concrete output shapes.
        self.out = self.call(self.input_layer)
        super(STARN,self).__init__(
            inputs = self.input_layer,
            outputs = self.out,
            **kwargs
        )

    def call(self,inputs,training = False):
        # NOTE: `training` is deliberately not forwarded to the BatchNormalization
        # layers. This method is traced once in __init__, and a fixed value
        # passed here would be baked into that graph.
        x = self.conv1(inputs)
        x = Activation(activation='relu')(x)
        residual_1 = x
        #encoder1
        out = self.Econv1(x)
        out = self.Ebn1(out)
        out = Activation(activation='relu')(out)
        out = self.Econv2(out)
        out = self.Ebn2(out)
        out = Activation(activation='relu')(out)
        residual_2 = out

        #encoder2
        out = self.Econv3(out)
        out = self.Ebn3(out)
        out = Activation(activation='relu')(out)
        out = self.Econv4(out)
        out = self.Ebn4(out)
        out = Activation(activation='relu')(out)
        residual_3 = out

        #RDN_CBAM block
        # without concate
        out = self.Block1(out)
        out = self.Block2(out)
        out = self.Block3(out)
        out = self.Block4(out)
        #out = self.Block5(out)
        #out = self.Block6(out)

        #decoder2
        out = add([out,residual_3])
        out = self.Dconv1(out)
        out = self.Dbn1(out)
        out = Activation(activation='relu')(out)
        out = self.Dconv2(out)
        out = self.Dbn2(out)
        out = Activation(activation='relu')(out)

        #decoder1
        out = add([out,residual_2])
        out = self.Dconv3(out)
        out = self.Dbn3(out)
        out = Activation(activation='relu')(out)
        out = self.Dconv4(out)
        out = self.Dbn4(out)
        out = Activation(activation='relu')(out)

        #final
        out = add([out,residual_1])
        out = self.conv2(out)
        out = self.bn1(out)
        out = Activation(activation='relu')(out)
        #shape: (None,H,W,T,1)
        out = self.ta(out)
        # Average over the time axis: (None,H,W,T,1) -> (None,H,W,1).
        out = Lambda(lambda x: K.mean(x, axis = [3], keepdims=False))(out)
        out = self.conv3(out)
        out = Activation(activation='sigmoid')(out)
        #shape: (None,H,W,1)

        return out

    def build(self, input_shape=None):
        """Re-initialise the functional graph from the stored Input and output tensors.

        `input_shape` is accepted and ignored so the signature stays compatible
        with `keras.Model.build(input_shape)`. The shape is already fixed by the
        Input created in __init__.

        NOTE(review): relies on the private Keras API `_init_graph_network`.
        """
         # Initialize the graph
        self._is_graph_network = True
        self._init_graph_network(
            inputs=self.input_layer,
            outputs=self.out)
    

In [9]:
# Build the model for inputs of shape (H, W, T, features) = (55, 125, 9, 14);
# Input() adds the batch dimension.
StarN2 = STARN(input_shape=(55, 125, 9, 14)) # shape (H, W, T, feature)

In [11]:
# Print the layer-by-layer summary (output shapes and parameter counts).
StarN2.summary()

Model: "starn_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 55, 125, 9,  0           []                               
                                 14)]                                                             
                                                                                                  
 STARN_Conv1 (Conv3D)           (None, 55, 125, 9,   24256       ['input_1[0][0]']                
                                64)                                                               
                                                                                                  
 activation (Activation)        (None, 55, 125, 9,   0           ['STARN_Conv1[0][0]']            
                                64)                                                         