In [None]:
%matplotlib inline
# import brine
# import cv2
import numpy as np
# from model.augmentations import randomHueSaturationValue, randomShiftScaleRotate, randomHorizontalFlip
# import model.u_net as unet
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
import matplotlib.pyplot as plt
from keras.preprocessing.image import img_to_array

from keras.losses import binary_crossentropy
import keras.backend as K

from keras.models import Model
from keras.layers import Input, concatenate, Conv3D, MaxPooling3D, Activation, UpSampling3D, BatchNormalization
from keras.optimizers import RMSprop

In [None]:
def dice_loss(y_true, y_pred):
    """Dice loss for segmentation: ``1 - dice_coeff(y_true, y_pred)``.

    Smaller is better; it reaches 0 when the Dice coefficient is 1.
    """
    return 1 - dice_coeff(y_true, y_pred)

In [None]:
def bce_dice_loss(y_true, y_pred):
    """Combined segmentation loss: binary cross-entropy + Dice loss, equally weighted.

    NOTE(review): Keras' ``binary_crossentropy`` typically reduces only the last
    axis while ``dice_loss`` is a scalar, so the Dice term is broadcast onto the
    BCE tensor -- confirm for the Keras version in use.
    """
    bce_term = binary_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

In [None]:
def dice_coeff(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient between a target mask and a prediction.

    Both tensors are flattened, so the score is computed over the whole batch
    at once rather than per sample and then averaged.

    Args:
        y_true: ground-truth mask tensor. It is cast to ``y_pred``'s dtype first,
            so integer/boolean masks work instead of failing on the multiply.
        y_pred: predicted probabilities, with the same number of elements as
            ``y_true``.
        smooth: additive smoothing in numerator and denominator. It avoids
            division by zero when both masks are empty (the score is then 1).
            Defaults to 1, matching the previous hard-coded value.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Cast so that e.g. uint8 masks can be multiplied with float predictions.
    y_true_f = K.flatten(K.cast(y_true, K.dtype(y_pred)))
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return score

In [None]:
def _conv_bn_relu_stack(x, filters, n_convs, kernel_size):
    """Apply ``n_convs`` x [Conv3D(padding='same') -> BatchNormalization -> ReLU] to ``x``."""
    for _ in range(n_convs):
        x = Conv3D(filters, kernel_size, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def get_unet_128(input_shape=(128, 128, 128, 1),
                 num_classes=1):
    """Build and compile a 3D U-Net with sigmoid outputs.

    Architecture:
      * Encoder: four levels with 64, 128, 256 and 512 filters. Each level is two
        conv-BN-ReLU units followed by 2x2x2 max pooling.
      * Bottleneck: two 1024-filter conv-BN-ReLU units.
      * Decoder: four levels (512, 256, 128, 64 filters). Each level does 2x
        upsampling, concatenates the matching encoder output on axis 4
        (channels-last is assumed), then applies three conv-BN-ReLU units.
      * Head: a 1x1x1 Conv3D with ``num_classes`` sigmoid channels.

    Layers are created in the same order as the original hand-unrolled version,
    so auto-generated layer names are unchanged.

    Args:
        input_shape: ``(depth, height, width, channels)``. Every spatial size
            must be divisible by 16 (four 2x poolings) or be ``None``.
        num_classes: number of sigmoid output channels.

    Returns:
        A Keras ``Model`` compiled with ``RMSprop(lr=0.0001)``, ``bce_dice_loss``
        and the ``dice_coeff`` metric.

    Raises:
        ValueError: if ``input_shape`` is not rank 4, or a spatial size is not
            divisible by 16. Without this check the skip-connection
            concatenation would fail later with an opaque shape mismatch.
    """
    filters_per_level = (64, 128, 256, 512)
    divisor = 2 ** len(filters_per_level)

    # Validate before touching Keras so shape errors are reported clearly.
    if len(input_shape) != 4:
        raise ValueError(
            'input_shape must be (depth, height, width, channels), got %r' % (input_shape,))
    for size in input_shape[:-1]:
        if size is not None and size % divisor != 0:
            raise ValueError(
                'spatial size %r in input_shape %r is not divisible by %d'
                % (size, input_shape, divisor))

    conv_kern_size = (3, 3, 3)
    pool_kern_size = (2, 2, 2)
    sample_kern_size = (2, 2, 2)
    stride = (2, 2, 2)

    inputs = Input(shape=input_shape)

    # Encoder: keep each level's pre-pooling features for the skip connections.
    x = inputs
    skips = []
    for filters in filters_per_level:
        x = _conv_bn_relu_stack(x, filters, 2, conv_kern_size)
        skips.append(x)
        x = MaxPooling3D(pool_kern_size, strides=stride)(x)

    # Bottleneck at 1/16 of the input resolution.
    x = _conv_bn_relu_stack(x, 1024, 2, conv_kern_size)

    # Decoder: upsample, concatenate the skip (skip first, as before), refine.
    for filters, skip in zip(reversed(filters_per_level), reversed(skips)):
        x = UpSampling3D(sample_kern_size)(x)
        x = concatenate([skip, x], axis=4)
        x = _conv_bn_relu_stack(x, filters, 3, conv_kern_size)

    classify = Conv3D(num_classes, (1, 1, 1), activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=classify)

    model.compile(optimizer=RMSprop(lr=0.0001), loss=bce_dice_loss, metrics=[dice_coeff])

    return model

In [None]:
# Build the default network (128^3 single-channel input, one sigmoid output
# channel) and print its layer-by-layer summary.
model = get_unet_128(input_shape=(128, 128, 128, 1), num_classes=1)
model.summary()