In [49]:
import keras
import tensorflow as tf

### Variables

In [41]:
# Number of channels produced by the model's output layer (see output_layer()).
CLASSES = 2
# Two spatial dimensions of the model input; input_layer() appends a trailing axis of size 3.
SAMPLE_SIZE = (384, 384)
# NOTE(review): not referenced in the cells visible here — presumably the size
# predictions are resized to elsewhere in the notebook; confirm.
OUTPUT_SIZE = (768, 768)

### Define input neural network model

In [42]:
def input_layer():
    """Create the symbolic input of the network.

    Returns:
        A Keras ``Input`` tensor whose per-sample shape is ``SAMPLE_SIZE``
        followed by a trailing axis of size 3 (presumably the RGB channels —
        confirm against the data pipeline).
    """
    input_shape = (*SAMPLE_SIZE, 3)
    return keras.layers.Input(shape=input_shape)

### Describe encoder block format

In [43]:
def downsample_block(filters, size, batch_norm = True):
    """Build one encoder stage.

    The stage is a stride-2 convolution ('same' padding, no bias term,
    Glorot-normal kernel initialisation), optionally followed by batch
    normalisation, and finished with a LeakyReLU activation.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        batch_norm: When true, a BatchNormalization layer is inserted between
            the convolution and the activation.

    Returns:
        A ``keras.Sequential`` holding the stage's layers in order.
    """
    strided_conv = keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=keras.initializers.GlorotNormal(),
        use_bias=False,
    )

    stage_layers = [strided_conv]
    if batch_norm:
        stage_layers.append(keras.layers.BatchNormalization())
    stage_layers.append(keras.layers.LeakyReLU())

    return keras.Sequential(stage_layers)

### Format decoder block for neural network

In [44]:
def upsample_block(filters, size, dropout = False):
    """Build one decoder stage.

    The stage is a stride-2 transposed convolution ('same' padding, no bias
    term, Glorot-normal kernel initialisation) followed by batch
    normalisation, an optional Dropout(0.25), and a ReLU activation.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size of the transposed convolution.
        dropout: When true, a Dropout layer with rate 0.25 is placed between
            the batch normalisation and the activation.

    Returns:
        A ``keras.Sequential`` holding the stage's layers in order.
    """
    transposed_conv = keras.layers.Conv2DTranspose(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=keras.initializers.GlorotNormal(),
        use_bias=False,
    )

    stage_layers = [transposed_conv, keras.layers.BatchNormalization()]
    if dropout:
        stage_layers.append(keras.layers.Dropout(0.25))
    stage_layers.append(keras.layers.ReLU())

    return keras.Sequential(stage_layers)

### Define output neural network model

In [45]:
def output_layer(size):
    """Create the final layer of the network.

    Args:
        size: Kernel size of the transposed convolution.

    Returns:
        A stride-2 ``Conv2DTranspose`` layer ('same' padding, Glorot-normal
        kernel initialisation) with ``CLASSES`` output channels and a sigmoid
        activation, so every channel is squashed to (0, 1) independently.
    """
    glorot = keras.initializers.GlorotNormal()
    return keras.layers.Conv2DTranspose(
        CLASSES,
        size,
        strides=2,
        padding='same',
        kernel_initializer=glorot,
        activation='sigmoid',
    )

### Create layer stacks

In [46]:
inp_layer = input_layer()

# Encoder: seven stride-2 stages, all with kernel size 4.
# Only the first stage is built without batch normalisation.
downsample_stack = [downsample_block(64, 4, batch_norm = False)]
downsample_stack.extend(
    downsample_block(filters, 4) for filters in (128, 256, 512, 512, 512, 512)
)

# Decoder: six stride-2 stages, all with kernel size 4.
# The first three (512 filters each) are built with dropout enabled.
upsample_stack = [upsample_block(512, 4, dropout = True) for _ in range(3)]
upsample_stack.extend(
    upsample_block(filters, 4) for filters in (256, 128, 64)
)

out_layer = output_layer(4)

### Skip connection implementation

In [47]:
# Wire the encoder stages, decoder stages and output layer into one functional model.

# Every decoder stage is paired with one encoder output (all except the deepest).
# zip() would silently drop stages on a mismatch, so fail loudly instead.
if len(upsample_stack) != len(downsample_stack) - 1:
    raise ValueError(
        "upsample_stack must have exactly one block fewer than downsample_stack "
        f"(got {len(upsample_stack)} and {len(downsample_stack)})"
    )

x = inp_layer

# Run the encoder and remember the output of every stage.
downsample_skips = []
for block in downsample_stack:
    x = block(x)
    downsample_skips.append(x)

# The last encoder output is already held in `x` (the decoder's starting point),
# so it is excluded from the skips. The remaining outputs are visited deepest-first
# so that each decoder stage is concatenated with the encoder output of matching depth.
for up_block, skip in zip(upsample_stack, reversed(downsample_skips[:-1])):
    x = up_block(x)
    x = keras.layers.Concatenate()([x, skip])

# Keep the output tensor under its own name: rebinding `out_layer` to the tensor
# would destroy the layer object and make this cell fail when it is run again.
out_tensor = out_layer(x)

unet_like = keras.Model(inputs = inp_layer, outputs = out_tensor)

### Metrics and loss functions

In [48]:
def dice_mc_metric(a, b):
    """Mean per-channel smoothed Dice coefficient of two tensors.

    Both tensors are split along axis 3. For every pair of channels the score
    ``(2 * sum(a_c * b_c) + 1) / (sum(a_c + b_c) + 1)`` is computed over all
    remaining elements at once (the sums are not taken per sample), and the
    scores are averaged over the number of channel pairs that were compared.

    Args:
        a: Tensor with at least four axes; axis 3 is treated as the channel
            axis. Presumably the ground-truth masks — confirm with the caller.
        b: Tensor of the same layout as ``a``. Presumably the predictions.

    Returns:
        A scalar tensor holding the average Dice score.
    """
    channel_pairs = list(zip(tf.unstack(a, axis=3), tf.unstack(b, axis=3)))

    dice_sum = 0
    for a_channel, b_channel in channel_pairs:
        # The +1 terms smooth the ratio and keep it defined when both channels are all zeros.
        numerator = 2 * tf.math.reduce_sum(a_channel * b_channel) + 1
        denominator = tf.math.reduce_sum(a_channel + b_channel) + 1
        dice_sum += numerator / denominator

    # Average over the channels that were actually scored instead of a global
    # constant, so the result stays correct for any channel count.
    return dice_sum / len(channel_pairs)

def dice_mc_loss(a, b):
    """Dice loss: one minus the value of ``dice_mc_metric(a, b)``.

    Args:
        a: First tensor, forwarded unchanged to ``dice_mc_metric``.
        b: Second tensor, forwarded unchanged to ``dice_mc_metric``.

    Returns:
        ``1 - dice_mc_metric(a, b)``.
    """
    dice_score = dice_mc_metric(a, b)
    return 1 - dice_score

def dice_bce_mc_loss(a, b):
    """Combined loss: 0.3-weighted Dice loss plus binary cross-entropy.

    Args:
        a: Tensor passed as the first argument (``y_true`` in the Keras
            signature) of ``binary_crossentropy`` and of ``dice_mc_loss``.
        b: Tensor passed as the second argument (``y_pred``) of both.

    Returns:
        ``0.3 * dice_mc_loss(a, b) + binary_crossentropy(a, b)``.
    """
    dice_weight = 0.3
    weighted_dice = dice_weight * dice_mc_loss(a, b)
    cross_entropy = tf.keras.losses.binary_crossentropy(a, b)
    return weighted_dice + cross_entropy