In [None]:
class Conv2D_Block(tf.keras.Sequential):
    """Stack of ``Conv2D -> [BatchNormalization] -> Activation`` units.

    The unit is repeated ``num_conv_layers`` times and is optionally followed
    by a single dropout layer. Every convolution uses ``'same'`` padding.

    Args:
        num_channels: Number of filters of every convolution.
        num_conv_layers: How many convolution units to stack.
        kernel_size: Kernel size of every convolution.
        nonlinearity: Activation applied at the end of each unit.
        use_batchnorm: Insert a ``BatchNormalization`` layer between each
            convolution and its activation.
        use_dropout: Append one dropout layer after the last unit.
        dropout_rate: Rate of that dropout layer.
        use_spatial_dropout: Use ``SpatialDropout2D`` rather than ``Dropout``
            (only relevant when ``use_dropout`` is true).
        data_format: ``'channels_last'`` or ``'channels_first'``.
        **kwargs: Forwarded to ``tf.keras.Sequential``.
    """

    def __init__(self,
                 num_channels,
                 num_conv_layers=2,
                 kernel_size=(3,3),
                 nonlinearity='relu',
                 use_batchnorm = False,
                 use_dropout = False,
                 dropout_rate = 0.25, 
                 use_spatial_dropout = True,
                 data_format='channels_last',
                 **kwargs):

        super(Conv2D_Block, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.num_conv_layers = num_conv_layers
        self.kernel_size = kernel_size
        self.nonlinearity = nonlinearity
        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.use_spatial_dropout = use_spatial_dropout
        self.data_format = data_format

        # Batch normalisation must act on the channel axis, whose position
        # depends on the data format (the previous hard-coded -1 was only
        # right for 'channels_last').
        channel_axis = 1 if self.data_format == 'channels_first' else -1

        for _ in range(self.num_conv_layers):
            self.add(tf.keras.layers.Conv2D(self.num_channels,
                                            self.kernel_size,
                                            padding='same',
                                            data_format=self.data_format))
            if self.use_batchnorm:
                self.add(tf.keras.layers.BatchNormalization(axis=channel_axis,
                                                            momentum=0.95,
                                                            epsilon=0.001))
            self.add(tf.keras.layers.Activation(self.nonlinearity))

        if self.use_dropout:
            if self.use_spatial_dropout:
                # Spatial dropout drops whole feature maps, so it also needs
                # to know where the channel axis is.
                self.add(tf.keras.layers.SpatialDropout2D(rate=self.dropout_rate,
                                                          data_format=self.data_format))
            else:
                self.add(tf.keras.layers.Dropout(rate=self.dropout_rate))

    def call(self, inputs, training=False):
        """Run the stacked layers on ``inputs``, forwarding ``training``."""

        outputs = super(Conv2D_Block, self).call(inputs, training=training)

        return outputs

class Up_Conv2D(tf.keras.Sequential):
    """Upsampling unit: upsample, optionally batch-normalise, then activate.

    Upsampling is done either with a single strided ``Conv2DTranspose`` or
    with ``UpSampling2D`` followed by a ``Conv2D``; both use ``'same'``
    padding.

    Args:
        num_channels: Number of filters of the (transposed) convolution.
        kernel_size: Kernel size of the (transposed) convolution.
        nonlinearity: Activation applied as the last layer.
        use_batchnorm: Insert a ``BatchNormalization`` layer before the
            activation.
        use_transpose: Use ``Conv2DTranspose`` instead of
            ``UpSampling2D`` + ``Conv2D``.
        strides: Upsampling factor; used as the transposed-convolution
            strides or as the ``UpSampling2D`` size.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        **kwargs: Forwarded to ``tf.keras.Sequential``.
    """

    def __init__(self, 
                 num_channels,
                 kernel_size=(2,2),
                 nonlinearity='relu',
                 use_batchnorm = False,
                 use_transpose = False,
                 strides=(2,2),
                 data_format='channels_last',
                 **kwargs):

        super(Up_Conv2D, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.nonlinearity = nonlinearity
        self.use_batchnorm = use_batchnorm
        self.use_transpose = use_transpose
        self.strides = strides
        self.data_format = data_format

        if self.use_transpose:
            self.add(tf.keras.layers.Conv2DTranspose(self.num_channels,
                                                     self.kernel_size,
                                                     padding='same',
                                                     strides=self.strides,
                                                     data_format=self.data_format))
        else:
            # UpSampling2D has to be told the data format too, otherwise it
            # resizes the wrong axes for 'channels_first' inputs.
            self.add(tf.keras.layers.UpSampling2D(size=self.strides,
                                                  data_format=self.data_format))
            self.add(tf.keras.layers.Conv2D(self.num_channels,
                                            self.kernel_size,
                                            padding='same',
                                            data_format=self.data_format))
        if self.use_batchnorm:
            # Normalise over the channel axis, wherever the data format puts it.
            channel_axis = 1 if self.data_format == 'channels_first' else -1
            self.add(tf.keras.layers.BatchNormalization(axis=channel_axis,
                                                        momentum=0.95,
                                                        epsilon=0.001))
        self.add(tf.keras.layers.Activation(self.nonlinearity))

    def call(self, inputs, training=False):
        """Run the stacked layers on ``inputs``, forwarding ``training``."""

        outputs = super(Up_Conv2D, self).call(inputs, training=training)

        return outputs

class Recurrent_block(tf.keras.Model):
    """Recurrent convolution: one shared conv unit applied ``t + 1`` times.

    The same ``Conv2D_Block`` (one convolution with batch normalisation) is
    first applied to the input and then ``t`` more times to the sum of the
    input and the previous result.

    Args:
        num_channels: Number of filters of the shared convolution. The input
            must already have this many channels, because it is added to the
            convolution output.
        kernel_size: Kernel size of the shared convolution.
        strides: Accepted for signature compatibility; not used.
        padding: Accepted for signature compatibility; not used (the wrapped
            ``Conv2D_Block`` always pads with ``'same'``).
        activation: Activation of the shared convolution unit.
        t: Number of recurrent steps; must be at least 1.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        **kwargs: Forwarded to ``tf.keras.Model``.

    Raises:
        ValueError: If ``t`` is smaller than 1.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=(3,3),
                 strides=(1,1),
                 padding='same',
                 activation='relu',
                 t=2,
                 data_format='channels_last',
                 **kwargs):

        super(Recurrent_block,self).__init__(**kwargs)
        if t < 1:
            raise ValueError(f"t must be >= 1, got {t}")
        self.t = t
        self.num_channels = num_channels
        self.conv = Conv2D_Block(num_channels,
                                 num_conv_layers=1,
                                 kernel_size=kernel_size,
                                 nonlinearity=activation,
                                 use_batchnorm=True,
                                 use_dropout=False,
                                 dropout_rate=0.0,
                                 use_spatial_dropout=False,
                                 data_format=data_format)

    def call(self, x, training=False):
        """Apply the recurrent convolution to ``x`` and return the last state."""
        # Initial state from the input alone.
        x1 = self.conv(x, training=training)
        # Feed every new state back in. The previous code stored the result
        # in a separate variable and never updated x1, so all steps produced
        # the same value regardless of t.
        for _ in range(self.t):
            x1 = self.conv(x + x1, training=training)
        return x1
    
        
class RRCNN_block(tf.keras.Model):
    """Recurrent residual block: projection conv + two recurrent blocks + skip.

    The input is first projected to ``num_channels`` channels by a
    convolution (1x1 by default). The projection is then passed through two
    ``Recurrent_block`` layers and added back to their output.

    Args:
        num_channels: Number of channels produced by the block.
        kernel_size: Kernel size of the projection convolution.
        strides: Strides of the projection convolution.
        padding: Padding of the projection convolution.
        nonlinearity: Activation used inside the recurrent blocks.
        t: Number of recurrent steps of each recurrent block.
        data_format: ``'channels_last'`` or ``'channels_first'``.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=(1,1),
                 strides=(1,1),
                 padding='same',
                 nonlinearity='relu',
                 t=2,
                 data_format='channels_last',
                 **kwargs):

        super(RRCNN_block,self).__init__(**kwargs)

        # Forward activation and data format so the recurrent blocks agree
        # with the projection convolution below.
        self.RCNN = tf.keras.Sequential([
            Recurrent_block(num_channels,
                            activation=nonlinearity,
                            t=t,
                            data_format=data_format),
            Recurrent_block(num_channels,
                            activation=nonlinearity,
                            t=t,
                            data_format=data_format)])

        self.Conv_1x1 = tf.keras.layers.Conv2D(num_channels,
                                               kernel_size,
                                               strides=strides,
                                               padding=padding,
                                               data_format=data_format)

    def call(self, x, training=False):
        """Return ``projection(x) + recurrent_blocks(projection(x))``."""
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x, training=training)
        output = x+x1
        return output

    #--------------------------------------------------------------------------------------#

class R2U_Net(tf.keras.Model):
    """U-Net built from recurrent residual blocks (R2U-Net).

    See "Recurrent Residual Convolutional Neural Network based on U-Net
    (R2U-Net) for Medical Image Segmentation", https://arxiv.org/abs/1802.06955.

    The encoder has five ``RRCNN_block`` stages separated by 2x2 max pooling;
    the decoder upsamples four times, concatenating the matching encoder
    output before each decoder ``RRCNN_block``. The final 1x1 convolution has
    no activation. All layers use their default (channels-last) data format.

    Args:
        num_channels: Number of channels of the first encoder stage; deeper
            stages use 2x, 4x, 8x and 16x this value.
        output_ch: Number of channels of the final 1x1 convolution.
        t: Number of recurrent steps in every recurrent block.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 num_channels,
                 output_ch=1,
                 t=2,
                 **kwargs):

        super(R2U_Net,self).__init__(**kwargs)

        # Max pooling has no weights, so one layer instance is shared by all
        # encoder stages instead of building a new layer on every call.
        self.Maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

        self.RRCNN1 = RRCNN_block(num_channels=num_channels,t=t)
        self.RRCNN2 = RRCNN_block(num_channels=num_channels*2,t=t)
        self.RRCNN3 = RRCNN_block(num_channels=num_channels*4,t=t)
        self.RRCNN4 = RRCNN_block(num_channels=num_channels*8,t=t)
        self.RRCNN5 = RRCNN_block(num_channels=num_channels*16,t=t)

        self.Up5 = Up_Conv2D(num_channels=num_channels*8)
        self.Up_RRCNN5 = RRCNN_block(num_channels=num_channels*8,t=t)

        self.Up4 = Up_Conv2D(num_channels=num_channels*4)
        self.Up_RRCNN4 = RRCNN_block(num_channels=num_channels*4,t=t)

        self.Up3 = Up_Conv2D(num_channels=num_channels*2)
        self.Up_RRCNN3 = RRCNN_block(num_channels=num_channels*2,t=t)

        self.Up2 = Up_Conv2D(num_channels=num_channels)
        self.Up_RRCNN2 = RRCNN_block(num_channels=num_channels,t=t)

        self.Conv_1x1 = tf.keras.layers.Conv2D(output_ch,kernel_size=(1,1),strides=(1,1),padding='same')

    def call(self,x, training=False):
        """Run the network on ``x`` and return the final 1x1 conv output.

        NOTE(review): the skip concatenations need matching spatial sizes,
        which presumably requires input height/width divisible by 16.
        """
        # encoding path
        x1 = self.RRCNN1(x, training=training)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2, training=training)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3, training=training)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4, training=training)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5, training=training)

        # decoding + concat path
        d5 = self.Up5(x5, training=training)
        d5 = tf.keras.layers.concatenate([x4,d5])
        d5 = self.Up_RRCNN5(d5, training=training)

        d4 = self.Up4(d5, training=training)
        d4 = tf.keras.layers.concatenate([x3,d4])
        d4 = self.Up_RRCNN4(d4, training=training)

        d3 = self.Up3(d4, training=training)
        d3 = tf.keras.layers.concatenate([x2,d3])
        d3 = self.Up_RRCNN3(d3, training=training)

        d2 = self.Up2(d3, training=training)
        d2 = tf.keras.layers.concatenate([x1,d2])
        d2 = self.Up_RRCNN2(d2, training=training)

        output = self.Conv_1x1(d2)

        return output

    #--------------------------------------------------------------------------------------#
    
    class R2AttU_Net(tf.keras.Model):
    """Tensorflow 2 Implementation of 'U-Net: Convolutional Networks for Biomedical Image Segmentation'
    https://arxiv.org/pdf/1804.03999.pdf. """

    def __init__(self, 
                 num_channels,
                 num_classes,
                 num_conv_layers=1,
                 kernel_size=(3,3),
                 strides=(1,1),
                 pool_size=(2,2),
                 use_bias=False,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm = True,
                 use_transpose = True,
                 data_format='channels_last',
                 **kwargs):

        super(R2AttU_Net, self).__init__(**kwargs)

        self.RRCNN1 = RRCNN_block(num_channels=num_channels,t=t)
        self.RRCNN2 = RRCNN_block(num_channels=num_channels*2,t=t)
        self.RRCNN3 = RRCNN_block(num_channels=num_channels*4,t=t)
        self.RRCNN4 = RRCNN_block(num_channels=num_channels*8,t=t)
        self.RRCNN5 = RRCNN_block(num_channels=nm_channels*16,t=t)

        self.up_conv_1 = Up_Conv2D(num_channels*8,(3,3),nonlinearity,use_batchnorm=True,data_format=data_format)
        self.up_conv_2 = Up_Conv2D(num_channels*4,(3,3),nonlinearity,use_batchnorm=True,data_format=data_format)
        self.up_conv_3 = Up_Conv2D(num_channels*2,(3,3),nonlinearity,use_batchnorm=True,data_format=data_format)
        self.up_conv_4 = Up_Conv2D(num_channels,(3,3),nonlinearity,use_batchnorm=True,data_format=data_format)

        self.a1 = Attention_Gate(num_channels*8,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a2 = Attention_Gate(num_channels*4,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a3 = Attention_Gate(num_channels*2,(1,1),nonlinearity,padding,strides,use_bias,data_format)
        self.a4 = Attention_Gate(num_channels,(1,1),nonlinearity,padding,strides,use_bias,data_format)

        self.u1 = RRCNN_block(num_channels=num_channels*8,t=t)
        self.u2 = RRCNN_block(num_channels=num_channels*4,t=t)
        self.u3 = RRCNN_block(num_channels=num_channels*2,t=t)
        self.u4 = RRCNN_block(num_channels=num_channels,t=t)
        
    def call(self, inputs):

        #ENCODER PATH
        x1 = self.RRCNN1(inputs)
        
        pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x1)
        x2 = self.RRCNN2(pool1)
        
        pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x2)
        x3 = self.RRCNN3(pool2)
        
        pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x3)
        x4 = self.RRCNN4(pool3)
        
        pool4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x4)
        x5 = self.RRCNN5(pool4)

        #DECODER PATH
        up4 = self.up_conv_1(x5)
        a1 = self.a1(up4, x4)
        y1 = tf.keras.layers.concatenate([a1, up4])
        y1 = self.u1(y1)

        up5 = self.up_conv_2(y1)
        a2 = self.a2(up5, x3)
        y2 = tf.keras.layers.concatenate([a2, up5])
        y2 = self.u2(y2)

        up6 = self.up_conv_3(y2)
        a3 = self.a3(up6, x2)
        y3 = tf.keras.layers.concatenate([a3, up6])
        y3 = self.u3(y3)

        up7 = self.up_conv_4(y3)
        a4 = self.a4(up7, x1)
        y4 = tf.keras.layers.concatenate([a4, up7])
        y4 = self.u4(y4)

        output = self.conv_1x1(y4)

        return output

