# Segmenter_Granules. Core model definition only. April 17, 2024.

# In [19]:
def unet_model(input_size=(1024, 1024, 1), n_filters=4, n_classes=2):
    """
    Build an 8-level U-Net model for per-pixel classification.

    Every encoder level doubles the filter count, starting from ``n_filters``:
    level ``i`` uses ``n_filters * 2**i`` filters. The decoder mirrors this
    back down to ``n_filters``. With the default ``n_filters=4`` the encoder
    widths are 4, 8, 16, ..., 512.

    Arguments:
        input_size -- shape of a single input image, (height, width, channels).
            NOTE(review): seven of the eight conv_block calls use pooling, so
            height and width presumably must be divisible by 2**7 = 128 for the
            skip connections to line up. Confirm against conv_block.
        n_filters -- filter count of the shallowest encoder level and of the
            final 3x3 convolution. Each deeper level doubles it.
        n_classes -- number of output channels of the final 1x1 convolution.

    Returns:
        model -- tf.keras.Model from ``inputs`` to an ``n_classes``-channel
            output. No activation is applied to that output (raw logits).
    """
    inputs = Input(input_size)

    # Contracting path (encoding). Each conv_block result is indexed as
    # [0] -> input to the next level and [1] -> skip tensor for the decoder.
    cblock0 = conv_block(inputs, n_filters)
    cblock1 = conv_block(cblock0[0], n_filters * 2)
    cblock2 = conv_block(cblock1[0], n_filters * 4)
    cblock3 = conv_block(cblock2[0], n_filters * 8)
    cblock4 = conv_block(cblock3[0], n_filters * 16)
    cblock5 = conv_block(cblock4[0], n_filters * 32)
    # The two deepest levels pass a dropout probability of 0.3.
    cblock6 = conv_block(cblock5[0], n_filters * 64, 0.3)
    # Bottleneck: no pooling, so its [0] output feeds the decoder directly.
    cblock7 = conv_block(cblock6[0], n_filters * 128, 0.3, max_pooling=False)

    # Expanding path (decoding). Each block combines the previous output with
    # the skip tensor from the encoder level at the same depth. The filter
    # count halves at every step back down to n_filters, mirroring the encoder.
    ublock8 = upsampling_block(cblock7[0], cblock6[1], n_filters * 64)
    ublock9 = upsampling_block(ublock8, cblock5[1], n_filters * 32)
    ublock10 = upsampling_block(ublock9, cblock4[1], n_filters * 16)
    ublock11 = upsampling_block(ublock10, cblock3[1], n_filters * 8)
    ublock12 = upsampling_block(ublock11, cblock2[1], n_filters * 4)
    ublock13 = upsampling_block(ublock12, cblock1[1], n_filters * 2)
    ublock14 = upsampling_block(ublock13, cblock0[1], n_filters)

    conv9 = Conv2D(n_filters, 3, activation='relu', padding='same',
                   kernel_initializer='he_normal')(ublock14)
    # 1x1 projection to one channel per class, deliberately left linear.
    # NOTE(review): presumably trained with a from_logits=True loss; confirm.
    conv10 = Conv2D(n_classes, 1, padding='same')(conv9)

    model = tf.keras.Model(inputs=inputs, outputs=conv10)
    return model