# Artículo Científico - Deep Residual U-Net

# In [5]:
import tensorflow as tf

def self_attention(input_features, guide_features, inter_channels):
    """Re-weight ``input_features`` with a spatial gate computed from ``guide_features``.

    Despite the name, this is an additive attention gate rather than
    self-attention: the attention coefficients are derived from a second
    (guide) tensor.

    Args:
        input_features: Feature tensor to be gated (the encoder skip tensor when
            called from ``upsample_block`` in this file).
        guide_features: Gating tensor. It is expected to have half the spatial
            size of ``input_features``, so that the stride-2 projection of the
            input lines up with it for the element-wise add.
        inter_channels: Channel count of the two intermediate 1x1 projections.

    Returns:
        Gated features with the same channel count as ``input_features``,
        passed through a 1x1 convolution and batch normalization.
    """
    layers = tf.keras.layers
    channels = input_features.shape[3]

    # Project both tensors to a common width at the guide's resolution.
    projected_input = layers.Conv2D(inter_channels, kernel_size=1, strides=2, padding="same")(input_features)
    projected_guide = layers.Conv2D(inter_channels, kernel_size=1, padding="same")(guide_features)

    # Additive attention: relu(sum) -> one channel -> sigmoid weights in (0, 1).
    summed = layers.add([projected_input, projected_guide])
    rectified = layers.Activation('relu')(summed)
    logits = layers.Conv2D(1, 1, padding="same")(rectified)
    weights = layers.Activation('sigmoid')(logits)

    # Return the gate to the input resolution; its single channel is
    # multiplied against every channel of the input.
    weights = layers.UpSampling2D(size=(2, 2))(weights)
    gated = layers.multiply([weights, input_features])

    mixed = layers.Conv2D(channels, 1, padding="same")(gated)
    return layers.BatchNormalization()(mixed)

def convolution_block(input_features, num_filters, dropout_rate=0.5, use_batch_norm=True):
    """Apply two 3x3 convolutions with a residual shortcut around the second.

    Layout: conv(relu) -> [BN] -> conv(relu) -> [BN] -> [Dropout] -> add -> relu.
    The shortcut added back is the output of the first convolution, taken
    before its batch normalization.

    Args:
        input_features: 4-D feature tensor (presumably channels-last, the Keras
            default -- confirm against the data pipeline).
        num_filters: Number of filters in both convolutions, i.e. the output
            channel count.
        dropout_rate: Dropout rate applied before the residual add; a falsy
            value (0 / None) disables dropout.
        use_batch_norm: If True, batch-normalize after each convolution.

    Returns:
        Tensor with the spatial size of ``input_features`` (stride 1, "same"
        padding) and ``num_filters`` channels.
    """
    shortcut = tf.keras.layers.Conv2D(num_filters, 3, activation="relu", padding="same")(input_features)

    # Start from the shortcut so ``x`` is bound even when batch norm is disabled
    # (previously use_batch_norm=False raised UnboundLocalError).
    x = shortcut
    if use_batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Conv2D(num_filters, 3, activation="relu", padding="same")(x)
    if use_batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    if dropout_rate:
        x = tf.keras.layers.Dropout(dropout_rate)(x)

    x = tf.keras.layers.add([x, shortcut])
    x = tf.keras.layers.Activation("relu")(x)

    return x

def downsample_block(x, num_filters, dropout_rate=0.5, use_batch_norm=True):
    """Run one encoder stage: a residual convolution block, then 2x2 max pooling.

    Args:
        x: Input feature tensor for this stage.
        num_filters: Filters used by the convolution block.
        dropout_rate: Forwarded to ``convolution_block``.
        use_batch_norm: Forwarded to ``convolution_block``.

    Returns:
        ``(pooled, residual_connection)``: the max-pooled tensor (spatial size
        halved) to feed the next stage, and the pre-pooling features to be used
        as the decoder's skip connection.
    """
    residual_connection = convolution_block(x, num_filters, dropout_rate=dropout_rate, use_batch_norm=use_batch_norm)
    x = tf.keras.layers.MaxPool2D((2, 2), strides=(2, 2))(residual_connection)

    return x, residual_connection

def upsample_block(x, num_filters, skip_connection, dropout_rate=0.5, use_batch_norm=True):
    """Run one decoder stage: gate the skip, upsample ``x``, concatenate, convolve.

    Args:
        x: Decoder features from the previous (coarser) stage. They serve as the
            attention guide, and a stride-2 transposed convolution upsamples them.
        num_filters: Channel width for the attention projections, the transposed
            convolution and the convolution block.
        skip_connection: Encoder features, expected at twice the spatial size of
            ``x`` so that the upsampled ``x`` and the gated skip can be
            concatenated.
        dropout_rate: Dropout rate for this stage; falsy disables dropout.
        use_batch_norm: Whether to apply batch normalization in this stage.

    Returns:
        Tensor at the spatial size of ``skip_connection`` with ``num_filters``
        channels.
    """
    attention = self_attention(skip_connection, x, num_filters)
    x = tf.keras.layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same", activation="relu")(x)
    x = tf.keras.layers.Concatenate()([x, attention])
    # Forward the stage settings so that callers' dropout / batch-norm choices
    # also reach the convolution block, as downsample_block already does.
    x = convolution_block(x, num_filters, dropout_rate=dropout_rate, use_batch_norm=use_batch_norm)
    # NOTE(review): dropout and batch norm are applied both inside
    # convolution_block and again here. This is kept so the default
    # architecture (and existing weights) stay unchanged.
    if dropout_rate:
        x = tf.keras.layers.Dropout(dropout_rate)(x)

    if use_batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)

    return x
    

# Attention-gated residual U-Net: 256x256x3 input, single-channel sigmoid output.
base_filters = 32  # channels of the first encoder stage; each deeper stage doubles it

inputs = tf.keras.layers.Input(shape=(256, 256, 3))

# Encoder: four stages, keeping each stage's pre-pooling features as a skip.
x = inputs
skips = []
for stage in range(4):
    x, skip = downsample_block(x, base_filters * 2 ** stage)
    skips.append(skip)
skip_connection1, skip_connection2, skip_connection3, skip_connection4 = skips

# Bottleneck at the coarsest resolution (base_filters * 16 channels).
x = convolution_block(x, base_filters * 2 ** 4)

# Decoder: consume the skips from deepest to shallowest.
for stage in reversed(range(4)):
    x = upsample_block(x, base_filters * 2 ** stage, skips[stage])

outputs = tf.keras.layers.Conv2D(1, 3, activation="sigmoid", padding="same")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()