In [None]:
class Up_Conv2D(tf.keras.layers.Layer):
    """Decoder upsampling block: spatial upsampling + convolution.

    Two modes are available:
      * ``use_transpose=False`` (default): ``UpSampling2D`` by ``strides``
        followed by a ``'same'``-padded ``Conv2D``.
      * ``use_transpose=True``: a single ``'same'``-padded
        ``Conv2DTranspose`` with ``strides``.
    Either path is optionally followed by ``BatchNormalization`` and then
    the activation.

    Args:
        num_channels: Number of output filters.
        kernel_size: Convolution kernel size.
        nonlinearity: Activation passed to ``tf.keras.layers.Activation``.
        use_batchnorm: Apply batch normalization before the activation.
        use_transpose: Use a transposed convolution instead of
            upsample + conv.
        strides: Upsampling factor, used by both modes (transpose-conv
            strides, or the ``UpSampling2D`` size).
        data_format: ``'channels_last'`` or ``'channels_first'``.
        name: Layer name.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=(3,3),
                 nonlinearity='relu',
                 use_batchnorm=False,
                 use_transpose=False,
                 strides=(2,2),
                 data_format='channels_last',
                 name='upsampling_convolution_block'):

        super(Up_Conv2D, self).__init__(name=name)

        # BatchNormalization must normalize the channel axis, which depends
        # on data_format (axis=-1 is only correct for channels_last).
        channel_axis = 1 if data_format == 'channels_first' else -1

        self.use_batchnorm = use_batchnorm
        self.use_transpose = use_transpose
        # Upsample by the same factor as the transposed-conv path so both
        # modes yield the same output size.
        self.upsample = tf.keras.layers.UpSampling2D(size=strides, data_format=data_format)
        self.conv = tf.keras.layers.Conv2D(num_channels, kernel_size, padding='same', data_format=data_format)
        self.conv_transpose = tf.keras.layers.Conv2DTranspose(num_channels, kernel_size, padding='same', strides=strides, data_format=data_format)
        self.batch_norm = tf.keras.layers.BatchNormalization(axis=channel_axis)
        self.activation = tf.keras.layers.Activation(nonlinearity)

    def call(self, inputs):
        """Upsample ``inputs`` and return the (optionally normalized) activated map."""
        if self.use_transpose:
            x = self.conv_transpose(inputs)
        else:
            x = self.conv(self.upsample(inputs))
        if self.use_batchnorm:
            x = self.batch_norm(x)
        return self.activation(x)

class Recurrent_block(tf.keras.layers.Layer):
    """Recurrent convolutional unit used by R2U-Net.

    One Conv2D -> BatchNormalization -> Activation stack (a single set of
    weights) is applied ``t + 1`` times. It is applied first to the input
    ``x``, then ``t`` more times to ``x + previous_output``. This feeds each
    step's output back in, combined with the block input.

    The input must already have ``ch_out`` channels and keep its spatial
    size through the conv (``strides=(1,1)``, ``padding='same'``), because
    it is added element-wise to the stack's output.

    Args:
        ch_out: Number of filters (and required input channels).
        kernel_size: Convolution kernel size.
        strides: Convolution strides; must keep spatial size, see above.
        padding: Convolution padding.
        nonlinearity: Activation passed to ``tf.keras.layers.Activation``.
        t: Number of recurrent steps after the initial convolution.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        name: Layer name. Give distinct names when placing several of these
            in one ``tf.keras.Sequential``.
    """
    def __init__(self,
                 ch_out,
                 kernel_size=(3,3),
                 strides=(1,1),
                 padding='same',
                 nonlinearity='relu',
                 t=2,
                 data_format='channels_last',
                 name="Recurrent_block"):

        super(Recurrent_block, self).__init__(name=name)
        self.t = t
        self.ch_out = ch_out
        # Normalize over the channel axis implied by data_format.
        channel_axis = 1 if data_format == 'channels_first' else -1
        self.conv = tf.keras.Sequential([
            tf.keras.layers.Conv2D(ch_out, kernel_size, strides=strides,
                                   padding=padding, data_format=data_format),
            tf.keras.layers.BatchNormalization(axis=channel_axis),
            tf.keras.layers.Activation(nonlinearity)])

    def call(self, x):
        """Run the shared conv stack recurrently and return the last output."""
        x1 = self.conv(x)
        for _ in range(self.t):
            # Update x1 each step so the recurrence actually progresses.
            x1 = self.conv(x + x1)
        return x1
    
        
class RRCNN_block(tf.keras.layers.Layer):
    """Recurrent residual convolutional block (RRCNN) used by R2U-Net.

    A convolution (1x1 by default) first projects the input to ``ch_out``
    channels. Two stacked ``Recurrent_block``s then process that projection,
    and their result is added back to it as a residual connection.

    Args:
        ch_out: Number of output channels.
        kernel_size: Kernel size of the projection convolution.
        strides: Strides of the projection convolution.
        padding: Padding of the projection convolution.
        nonlinearity: Activation used inside the recurrent blocks.
        t: Recurrent steps per ``Recurrent_block``.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        name: Layer name; also used as a prefix for the recurrent sub-blocks.
    """
    def __init__(self,
                 ch_out,
                 kernel_size=(1,1),
                 strides=(1,1),
                 padding='same',
                 nonlinearity='relu',
                 t=2,
                 data_format='channels_last',
                 name="RRCNN_block"):

        super(RRCNN_block, self).__init__(name=name)

        # Both recurrent blocks live in one Sequential, which requires its
        # layers to have unique names.
        self.RCNN = tf.keras.Sequential([
            Recurrent_block(ch_out, nonlinearity=nonlinearity, t=t,
                            data_format=data_format,
                            name=f"{name}_recurrent_1"),
            Recurrent_block(ch_out, nonlinearity=nonlinearity, t=t,
                            data_format=data_format,
                            name=f"{name}_recurrent_2")])

        self.Conv_1x1 = tf.keras.layers.Conv2D(ch_out, kernel_size, strides=strides,
                                               padding=padding, data_format=data_format)

    def call(self, x):
        """Project ``x`` to ``ch_out`` channels and add the recurrent path to it."""
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1

    #--------------------------------------------------------------------------------------#

class R2U_Net(tf.keras.Model):
    """Recurrent Residual U-Net (R2U-Net).

    The model is a four-level encoder/decoder. The encoder uses RRCNN
    blocks separated by 2x2 max-pooling. The decoder upsamples
    (``Up_Conv2D``), concatenates the matching encoder feature map, and
    refines the result with another RRCNN block. A final 1x1 convolution
    produces the output.

    The input height and width should be divisible by 16 (four 2x2
    poolings). Otherwise the upsampled maps will not match the skip
    connections in size.

    Args:
        num_channels: Filters at the first level. This doubles at each deeper
            level, reaching x16 at the bottleneck.
        output_ch: Output channels of the final 1x1 convolution. No
            activation is applied, so the outputs are logits.
        t: Recurrent steps per ``Recurrent_block``.
    """
    def __init__(self,num_channels,output_ch=1,t=2):

        super(R2U_Net, self).__init__()

        # Pooling is stateless, so one instance serves every encoder level
        # instead of building new layers on each forward pass.
        self.Maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

        self.RRCNN1 = RRCNN_block(ch_out=num_channels, t=t, name="RRCNN1")
        self.RRCNN2 = RRCNN_block(ch_out=num_channels*2, t=t, name="RRCNN2")
        self.RRCNN3 = RRCNN_block(ch_out=num_channels*4, t=t, name="RRCNN3")
        self.RRCNN4 = RRCNN_block(ch_out=num_channels*8, t=t, name="RRCNN4")
        self.RRCNN5 = RRCNN_block(ch_out=num_channels*16, t=t, name="RRCNN5")

        # Up_Conv2D's filter argument is `num_channels` (not `ch_out`).
        self.Up5 = Up_Conv2D(num_channels=num_channels*8, name="Up5")
        self.Up_RRCNN5 = RRCNN_block(ch_out=num_channels*8, t=t, name="Up_RRCNN5")

        self.Up4 = Up_Conv2D(num_channels=num_channels*4, name="Up4")
        self.Up_RRCNN4 = RRCNN_block(ch_out=num_channels*4, t=t, name="Up_RRCNN4")

        self.Up3 = Up_Conv2D(num_channels=num_channels*2, name="Up3")
        self.Up_RRCNN3 = RRCNN_block(ch_out=num_channels*2, t=t, name="Up_RRCNN3")

        self.Up2 = Up_Conv2D(num_channels=num_channels, name="Up2")
        self.Up_RRCNN2 = RRCNN_block(ch_out=num_channels, t=t, name="Up_RRCNN2")

        self.Conv_1x1 = tf.keras.layers.Conv2D(output_ch, kernel_size=(1,1), strides=(1,1), padding='same')

    def call(self, x):
        """Return logits with ``output_ch`` channels at the input resolution."""
        # encoding path
        x1 = self.RRCNN1(x)
        x2 = self.RRCNN2(self.Maxpool(x1))
        x3 = self.RRCNN3(self.Maxpool(x2))
        x4 = self.RRCNN4(self.Maxpool(x3))
        x5 = self.RRCNN5(self.Maxpool(x4))

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = tf.keras.layers.concatenate([x4, d5])
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = tf.keras.layers.concatenate([x3, d4])
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = tf.keras.layers.concatenate([x2, d3])
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = tf.keras.layers.concatenate([x1, d2])
        d2 = self.Up_RRCNN2(d2)

        return self.Conv_1x1(d2)

    #--------------------------------------------------------------------------------------#

class R2AttU_Net(tf.keras.Model):
    """TensorFlow 2 implementation of the Recurrent Residual Attention U-Net.

    This model combines two published designs:
      * the recurrent residual (RRCNN) blocks of R2U-Net
        (https://arxiv.org/abs/1802.06955);
      * the attention gates of "Attention U-Net: Learning Where to Look for
        the Pancreas" (https://arxiv.org/pdf/1804.03999.pdf).

    The encoder has four pooling steps. At each decoder level the model:
      1. upsamples;
      2. gates the matching encoder feature map with ``Attention_Gate``;
      3. concatenates the gated map with the upsampled map;
      4. refines the result with an RRCNN block.
    A final 1x1 convolution produces ``num_classes`` output channels with no
    activation, so the outputs are logits.

    Args:
        num_channels: Filters at the first level. This doubles at each deeper
            level, reaching x16 at the bottleneck.
        num_classes: Output channels of the final 1x1 convolution.
        num_conv_layers: Currently unused; kept for API compatibility.
        kernel_size: Kernel size of the decoder upsampling convolutions.
        strides: Passed to each ``Attention_Gate``.
        pool_size: Encoder max-pooling window. It is also used as the decoder
            upsampling factor, so each decoder step undoes one pooling step.
        use_bias: Passed to each ``Attention_Gate``.
        padding: Passed to each ``Attention_Gate``.
        nonlinearity: Activation for the RRCNN, upsampling and attention blocks.
        use_batchnorm: Batch normalization in the upsampling blocks.
        use_transpose: Use transposed convolutions for upsampling.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        name: Model name.
        t: Recurrent steps per ``Recurrent_block``.

    NOTE(review): ``Attention_Gate`` is defined outside this section. Its
    constructor and ``__call__(gating, skip)`` argument order are assumed
    from the original calls; confirm against its definition.
    """

    def __init__(self,
                 num_channels,
                 num_classes,
                 num_conv_layers=1,
                 kernel_size=(3,3),
                 strides=(1,1),
                 pool_size=(2,2),
                 use_bias=False,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_transpose=True,
                 data_format='channels_last',
                 name="R2AttentionUNet",
                 t=2):

        super(R2AttU_Net, self).__init__(name=name)

        # Skip connections are concatenated along the channel axis.
        self.channel_axis = 1 if data_format == 'channels_first' else -1

        # Stateless pooling, built once and reused at every encoder level.
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=pool_size, data_format=data_format)

        def rrcnn(filters, block_name):
            # RRCNN block wired with this model's shared settings.
            return RRCNN_block(ch_out=filters, nonlinearity=nonlinearity, t=t,
                               data_format=data_format, name=block_name)

        def up(filters, block_name):
            # Upsampling block honouring the constructor's options.
            return Up_Conv2D(filters, kernel_size, nonlinearity,
                             use_batchnorm=use_batchnorm,
                             use_transpose=use_transpose,
                             strides=pool_size,
                             data_format=data_format,
                             name=block_name)

        self.RRCNN1 = rrcnn(num_channels, "RRCNN1")
        self.RRCNN2 = rrcnn(num_channels*2, "RRCNN2")
        self.RRCNN3 = rrcnn(num_channels*4, "RRCNN3")
        self.RRCNN4 = rrcnn(num_channels*8, "RRCNN4")
        self.RRCNN5 = rrcnn(num_channels*16, "RRCNN5")

        self.up_conv_1 = up(num_channels*8, "up_conv_1")
        self.up_conv_2 = up(num_channels*4, "up_conv_2")
        self.up_conv_3 = up(num_channels*2, "up_conv_3")
        self.up_conv_4 = up(num_channels, "up_conv_4")

        self.a1 = Attention_Gate(num_channels*8,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a2 = Attention_Gate(num_channels*4,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a3 = Attention_Gate(num_channels*2,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a4 = Attention_Gate(num_channels,(1,1),nonlinearity,padding,strides,use_bias,data_format)

        self.u1 = rrcnn(num_channels*8, "u1")
        self.u2 = rrcnn(num_channels*4, "u2")
        self.u3 = rrcnn(num_channels*2, "u3")
        self.u4 = rrcnn(num_channels, "u4")

        # Output head; call() relies on this layer (it was previously missing).
        self.conv_1x1 = tf.keras.layers.Conv2D(num_classes, (1, 1), padding='same',
                                               data_format=data_format, name="conv_1x1")

    def call(self, inputs):
        """Return per-pixel logits with ``num_classes`` channels."""
        axis = self.channel_axis

        # ENCODER PATH
        x1 = self.RRCNN1(inputs)
        x2 = self.RRCNN2(self.pool(x1))
        x3 = self.RRCNN3(self.pool(x2))
        x4 = self.RRCNN4(self.pool(x3))
        x5 = self.RRCNN5(self.pool(x4))

        # DECODER PATH: upsample, gate the matching skip, concatenate, refine.
        up4 = self.up_conv_1(x5)
        a1 = self.a1(up4, x4)
        y1 = self.u1(tf.keras.layers.concatenate([a1, up4], axis=axis))

        up5 = self.up_conv_2(y1)
        a2 = self.a2(up5, x3)
        y2 = self.u2(tf.keras.layers.concatenate([a2, up5], axis=axis))

        up6 = self.up_conv_3(y2)
        a3 = self.a3(up6, x2)
        y3 = self.u3(tf.keras.layers.concatenate([a3, up6], axis=axis))

        up7 = self.up_conv_4(y3)
        a4 = self.a4(up7, x1)
        y4 = self.u4(tf.keras.layers.concatenate([a4, up7], axis=axis))

        return self.conv_1x1(y4)

