In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [6]:
def _recurrent_conv_layer(x,filters,kernel_size=3,padding="SAME",strides=1,time_steps=2):
    """Recurrent convolution unit shared by every R2U-Net block.

    A feed-forward Conv2D followed by BatchNormalization produces ``feed``.
    A second Conv2D (``rec_conv``) is then unrolled ``time_steps + 1`` times
    with shared weights:

        state_0     = BN(relu(rec_conv(feed)))
        state_{k+1} = BN(relu(feed + rec_conv(state_k)))    k = 0..time_steps-1

    One BatchNormalization instance normalizes ``feed`` and every ``state``,
    so they all share its gamma/beta and moving statistics.

    Args:
        x: input tensor.
        filters: output channels of both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions. It must preserve spatial size
            (e.g. "SAME"); otherwise, for kernel_size > 1, the Add of ``feed``
            and ``rec_conv(state)`` sees mismatched shapes.
        strides: accepted for signature compatibility but NOT used; both
            convolutions run with their default stride of 1.
        time_steps: number of recurrent iterations after the initial step.

    Returns:
        The final recurrent state, a tensor with ``filters`` channels.
    """
    feed_conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, padding=padding)
    shared_bn = layers.BatchNormalization()
    feed = shared_bn(feed_conv(x))

    # Recurrent convolution: the same weights are reused at every step.
    rec_conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, padding=padding)
    merge = layers.Add()

    state = shared_bn(tf.nn.relu(rec_conv(feed)))
    for _ in range(time_steps):
        summed = merge([feed, rec_conv(state)])
        state = shared_bn(tf.nn.relu(summed))
    return state

In [7]:
def residual_recurrent_down_block(x,filters,kernel_size=3,padding="SAME",strides=1,time_steps=2):
    """Encoder stage: two recurrent conv units with a 1x1 residual shortcut, then max-pooling.

    Args:
        x: input tensor.
        filters: channel count of the recurrent units and of the shortcut projection.
        kernel_size, padding, strides, time_steps: forwarded unchanged to
            ``_recurrent_conv_layer`` (which, as written, ignores ``strides``).

    Returns:
        (c, p): ``c`` is the residual sum at the input resolution; ``p`` is
        ``c`` max-pooled with a 2x2 window and stride 2.
    """
    downsample = layers.MaxPool2D((2, 2), (2, 2))
    # 1x1 projection so the shortcut has `filters` channels. tf.constant_initializer()
    # defaults to 0 and the bias defaults to zeros, so the shortcut starts out
    # contributing all zeros and is learned from there.
    shortcut_proj = layers.Conv2D(
        filters=filters,
        kernel_size=1,
        padding="SAME",
        strides=1,
        kernel_initializer=tf.constant_initializer(),
    )
    merge = layers.Add()

    shortcut = shortcut_proj(x)
    body = x
    for _ in range(2):
        body = _recurrent_conv_layer(
            body, filters, kernel_size=kernel_size, padding=padding,
            strides=strides, time_steps=time_steps,
        )

    skip = merge([shortcut, body])
    return skip, downsample(skip)

def residual_recurrent_plateau_block(x,filters,kernel_size=3,padding="SAME",strides=1,time_steps=2):
    """Bottleneck stage: two recurrent conv units with a 1x1 residual shortcut, then 2x upsampling.

    Args:
        x: input tensor.
        filters: channel count of the recurrent units and of the shortcut projection.
        kernel_size, padding, strides, time_steps: forwarded unchanged to
            ``_recurrent_conv_layer`` (which, as written, ignores ``strides``).

    Returns:
        The residual sum upsampled by a 2x2, stride-2 Conv2DTranspose, giving
        twice the spatial size and ``filters // 2`` channels.
    """
    upsample = layers.Conv2DTranspose(filters // 2, kernel_size=(2, 2), strides=(2, 2))
    # Zero-initialized 1x1 shortcut projection (tf.constant_initializer() defaults to 0).
    shortcut_proj = layers.Conv2D(
        filters,
        kernel_size=1,
        padding="SAME",
        strides=1,
        kernel_initializer=tf.constant_initializer(),
    )
    merge = layers.Add()

    shortcut = shortcut_proj(x)
    body = x
    for _ in range(2):
        body = _recurrent_conv_layer(
            body, filters, kernel_size=kernel_size, padding=padding,
            strides=strides, time_steps=time_steps,
        )

    return upsample(merge([shortcut, body]))

def residual_recurrent_up_block(x,p,filters,kernel_size=3,padding="SAME",strides=1,time_steps=2):
    """Decoder stage: concatenate, two recurrent conv units with a 1x1 residual shortcut, then 2x upsampling.

    Args:
        x, p: tensors joined with ``Concatenate()`` (last axis); their other
            dimensions must match.
        filters: channel count of the recurrent units, the shortcut
            projection and the upsampled output.
        kernel_size, padding, strides, time_steps: forwarded unchanged to
            ``_recurrent_conv_layer`` (which, as written, ignores ``strides``).

    Returns:
        The residual sum upsampled by a 2x2, stride-2 Conv2DTranspose, giving
        twice the spatial size and ``filters`` channels.
    """
    # Zero-initialized 1x1 shortcut projection (tf.constant_initializer() defaults to 0).
    shortcut_proj = layers.Conv2D(
        filters,
        kernel_size=1,
        padding="SAME",
        strides=1,
        kernel_initializer=tf.constant_initializer(),
    )
    merge = layers.Add()

    joined = layers.Concatenate()([x, p])
    shortcut = shortcut_proj(joined)
    upsample = layers.Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2))

    body = joined
    for _ in range(2):
        body = _recurrent_conv_layer(
            body, filters, kernel_size=kernel_size, padding=padding,
            strides=strides, time_steps=time_steps,
        )

    return upsample(merge([shortcut, body]))

def residual_recurrent_output_layer(x,p,filters,kernel_size=3,padding="SAME",strides=1,time_steps=2,activation="relu",out_channels=1):
    """Final stage: concatenate, two recurrent conv units, then a 1x1 conv head.

    Unlike the other blocks, this stage has no 1x1 residual shortcut.

    Args:
        x, p: tensors joined with ``Concatenate()`` (last axis); their other
            dimensions must match.
        filters: channel count of the two recurrent units.
        kernel_size, padding, strides, time_steps: forwarded unchanged to
            ``_recurrent_conv_layer`` (which, as written, ignores ``strides``).
        activation: activation of the 1x1 head. The default "relu" preserves
            the original behavior. NOTE(review): relu is unbounded above; for
            binary masks trained with cross-entropy, "sigmoid" is the usual
            choice. Confirm the intended output range.
        out_channels: number of output channels of the head (default 1, as before).

    Returns:
        Output tensor with ``out_channels`` channels at the input resolution
        (a 1x1 "same" convolution does not change spatial size).
    """
    joined = layers.Concatenate()([x, p])
    head = layers.Conv2D(out_channels, (1, 1), padding="same", activation=activation)
    features = _recurrent_conv_layer(joined, filters, kernel_size, padding, strides, time_steps)
    features = _recurrent_conv_layer(features, filters, kernel_size, padding, strides, time_steps)
    return head(features)

In [8]:
def R2UNet(input_shape=(128, 128, 3), filters=(16, 32, 64, 128, 256)):
    """Assemble the recurrent residual U-Net (R2U-Net) as a Keras functional model.

    The previous version set ``inputs = None``, indexed an undefined ``f``
    (raising NameError) and returned ``None``. This version creates the Input,
    wires the blocks and returns the model.

    Args:
        input_shape: shape of one sample without the batch axis, channels
            last (the blocks concatenate along the last axis).
            NOTE(review): the default (128, 128, 3) is an assumption; confirm
            it against the training data. Known spatial dimensions must be
            divisible by 16. The encoder max-pools four times by 2, and each
            decoder stage concatenates a 2x-upsampled tensor with the
            matching encoder tensor, so their sizes must line up exactly.
        filters: five channel counts. filters[0..3] size the four encoder
            stages and the matching decoder stages. filters[4] sizes the
            plateau, which upsamples to filters[4] // 2 channels.
            NOTE(review): filters[4] = 256 is assumed; the original comments
            only pin the first four stages (16/32/64/128).

    Returns:
        keras.Model mapping the input to the output of
        ``residual_recurrent_output_layer``.

    Raises:
        ValueError: if ``filters`` does not have exactly five entries, or a
            known spatial dimension is not divisible by 16.
    """
    f = tuple(filters)
    if len(f) != 5:
        raise ValueError(f"filters must have exactly 5 entries, got {len(f)}")
    shape = tuple(input_shape)
    for dim in shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(f"spatial dimensions must be divisible by 16, got input_shape={shape}")

    inputs = layers.Input(shape=shape)
    # Encoder: each stage returns its pre-pool tensor (skip) and the pooled tensor.
    c1, p1 = residual_recurrent_down_block(inputs, f[0])
    c2, p2 = residual_recurrent_down_block(p1, f[1])
    c3, p3 = residual_recurrent_down_block(p2, f[2])
    c4, p4 = residual_recurrent_down_block(p3, f[3])
    # Plateau at 1/16 resolution; upsampled back to 1/8 to match c4.
    cn = residual_recurrent_plateau_block(p4, f[4])
    # Decoder: join with the matching encoder tensor, refine, upsample 2x.
    u1 = residual_recurrent_up_block(cn, c4, f[3])
    u2 = residual_recurrent_up_block(u1, c3, f[2])
    u3 = residual_recurrent_up_block(u2, c2, f[1])
    outputs = residual_recurrent_output_layer(u3, c1, f[0])

    return keras.models.Model(inputs, outputs, name="R2UNet")