# build_unet_mask_graph

<img src="https://i.imgur.com/1yQnLbK.jpg" width="100%">

In [None]:
def _conv_relu_bn(x, filters, name, train_bn):
    """Apply a per-ROI 3x3 'same' Conv2D with ReLU, then BatchNorm.

    The Conv2D is wrapped in TimeDistributed so it runs independently on every
    ROI of a [batch, num_rois, height, width, channels] tensor.
    """
    x = KL.TimeDistributed(
        KL.Conv2D(filters, (3, 3), padding='same', activation='relu',
                  kernel_initializer='he_normal'),
        name=name)(x)
    # Assumes BatchNorm keeps Keras' default axis=-1 (channels), so it can
    # consume the 5-D per-ROI tensor directly without TimeDistributed.
    return BatchNorm()(x, training=train_bn)


def build_unet_mask_graph(rois, feature_maps, image_meta,
                          pool_size, num_classes, train_bn=True):
    """Build a small U-Net-style mask head on top of ROI-aligned features.

    The head is a one-level encoder/decoder with a skip connection. It is
    followed by a second 2x upsampling stage and a 1x1 sigmoid classifier.

    Args:
        rois: ROI tensor passed to PyramidROIAlign.
        feature_maps: list of feature-map tensors appended to the
            PyramidROIAlign inputs.
        image_meta: image-metadata tensor passed to PyramidROIAlign.
        pool_size: side length P of the ROI-aligned crop. It must be even:
            the 2x2 max-pool followed by 2x upsampling has to restore P for
            the skip concatenation to line up.
        num_classes: number of output mask channels.
        train_bn: forwarded as the ``training`` argument of every BatchNorm
            layer. Keras semantics apply: None follows the learning phase,
            False uses the moving statistics (frozen), and True forces batch
            statistics even at inference.

    Returns:
        Tensor of shape [batch, num_rois, 2*P, 2*P, num_classes] holding
        per-pixel sigmoid outputs.

    Raises:
        ValueError: if ``pool_size`` is odd.

    NOTE(review): if training code selects trainable layers by a name regex
    (e.g. an ``mrcnn_.*`` "heads" pattern), the ``layerNN`` names below will
    not match it. Confirm against the caller.
    """
    if pool_size % 2:
        raise ValueError(
            "pool_size must be even so the pooled/upsampled branch matches "
            "the skip connection; got {}".format(pool_size))

    # ROI Align -> [batch, num_rois, P, P, channels]
    x = PyramidROIAlign([pool_size, pool_size],
                        name="roi_align_mask")([rois, image_meta] + feature_maps)

    # Encoder -> [batch, num_rois, P, P, 256]
    x = _conv_relu_bn(x, 256, 'layer11', train_bn)
    x = _conv_relu_bn(x, 256, 'layer12', train_bn)
    skip_connection = x
    # 2x2 max-pool -> [batch, num_rois, P/2, P/2, 256]
    x = KL.TimeDistributed(KL.MaxPooling2D())(x)

    # Bottleneck -> [batch, num_rois, P/2, P/2, 256]
    x = _conv_relu_bn(x, 256, 'layer21', train_bn)
    x = _conv_relu_bn(x, 512, 'layer22', train_bn)
    x = _conv_relu_bn(x, 256, 'layer23', train_bn)

    # Decoder: upsample back to P and merge with the encoder features.
    x = KL.TimeDistributed(KL.UpSampling2D())(x)
    # -> [batch, num_rois, P, P, 256 + 256]
    x = KL.Concatenate(axis=-1)([x, skip_connection])
    x = _conv_relu_bn(x, 256, 'layer31', train_bn)
    x = _conv_relu_bn(x, 128, 'layer32', train_bn)
    # UpSampling does not change channels -> [batch, num_rois, 2P, 2P, 128]
    x = KL.TimeDistributed(KL.UpSampling2D())(x)
    x = _conv_relu_bn(x, 256, 'layer41', train_bn)
    x = _conv_relu_bn(x, 256, 'layer42', train_bn)

    # Per-class mask logits squashed by sigmoid
    # -> [batch, num_rois, 2P, 2P, num_classes]
    x = KL.TimeDistributed(
        KL.Conv2D(num_classes, (1, 1), padding='same', activation='sigmoid'),
        name='layer43')(x)
    return x