# In [None]:  (notebook cell marker left over from export)
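# 3D U-Net training with a fixed-batch-size data generator over liver-box cubes (overlapping cubes);
# NOTE: the model is currently compiled with focal_loss_v1, not a weighted dice loss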


# choose one specific GPU

import sys
# extend the module search path so that gpu_allocation (imported next) can be found;
# presumably the module lives in this directory — confirm
sys.path.append('/data/AlgProj/liuzhsh/deeplearning/')
from gpu_allocation import set_gpu

# number of GPUs requested for this run
num_gpu = 1
# NOTE(review): set_gpu is defined outside this file; presumably it reserves num_gpu
# device(s) out of gpu_list and has to run before the Keras/TensorFlow imports below
# so the selection takes effect — confirm against gpu_allocation
set_gpu(num_gpu, gpu_list = [0,1,2,3,4,5,6,7])


# import libs and initialization

from keras.layers import Input, merge, Conv3D, MaxPooling3D, UpSampling3D, Dropout, BatchNormalization, Activation
from keras.layers.merge import concatenate
from keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam
from keras import backend as K

import matplotlib.pyplot as plt
import SimpleITK as sitk
import tensorflow as tf
from glob import glob
import numpy as np
import argparse
import json
import os

# edge lengths of one training cube; data_generator uses them to size its batches
image_rows = image_cols = image_deps = 64

# intensity window: data_generator rescales [image_vmin, image_vmax] to [0, 1]
# and clamps everything outside of it
image_vmin, image_vmax = 0., 400.


# define loss functions

def dice_coef(y_true, y_pred, smooth=1e-3):
    """Return the smoothed Dice coefficient of two tensors.

    Both tensors are flattened, so the score is computed over every element
    at once (all samples of the batch together).

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor.
        smooth: constant added to numerator and denominator so the ratio
            stays defined when both tensors sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Return ``1 - dice_coef(y_true, y_pred)`` so that a better overlap gives a lower loss."""
    dice = dice_coef(y_true, y_pred)
    return 1. - dice

def focal_loss_v1(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Alpha-weighted focal loss for targets in which positives equal 1.

    Args:
        y_true: ground-truth tensor; elements equal to 1 are treated as
            positives, everything else as negatives.
        y_pred: predicted probabilities, same shape as ``y_true``.
        gamma: focusing exponent applied to ``1 - p_t``.
        alpha: weight of the positive elements; negatives get ``1 - alpha``.

    Returns:
        The element-wise focal loss summed along axis 1.
    """
    eps = K.epsilon()
    # keep predictions strictly inside (0, 1) so the logarithm below stays finite
    y_pred = K.clip(y_pred, eps, 1.0-eps)

    is_positive = K.equal(y_true, 1)

    # p_t: probability the model assigns to the true class of each element
    prob_true = tf.where(is_positive, y_pred, 1-y_pred)

    # alpha_t: alpha for positives, 1 - alpha for negatives
    alpha_pos = K.ones_like(y_true)*alpha
    alpha_t = tf.where(is_positive, alpha_pos, 1-alpha_pos)

    # modulating weight times cross entropy
    focal = alpha_t * K.pow((1-prob_true), gamma) * (-K.log(prob_true))

    return K.sum(focal, axis=1)

def focal_loss_v2(y_true, y_pred):
    """Focal loss with fixed ``gamma=0.75`` and ``alpha=0.25``, summed to a scalar.

    Elements whose target is 1 feed the positive term and elements whose
    target is 0 feed the negative term. In each term the other elements are
    replaced by a neutral placeholder (1 resp. 0) before the values are
    clipped to ``[1e-3, 0.999]``.

    Args:
        y_true: ground-truth tensor holding 0/1 labels.
        y_pred: predicted probabilities, same shape as ``y_true``.

    Returns:
        Scalar tensor: the sum of the positive and the negative focal terms.
    """
    gamma = 0.75
    alpha = 0.25

    # predictions at positive targets (placeholder 1 elsewhere), clipped for a finite log
    pos_prob = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pos_prob = K.clip(pos_prob, 1e-3, .999)

    # predictions at negative targets (placeholder 0 elsewhere), clipped for a finite log
    neg_prob = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    neg_prob = K.clip(neg_prob, 1e-3, .999)

    pos_term = K.sum(alpha * K.pow(1. - pos_prob, gamma) * K.log(pos_prob))
    neg_term = K.sum((1-alpha) * K.pow( neg_prob, gamma) * K.log(1. - neg_prob))
    return -pos_term-neg_term

#  build elegant unet3d model

def unet3d(input_shape, lr=1e-5, depth=5, n_base_filters=32, pool_size=(2, 2, 2), batch_normalization=True):
    """Build and compile a 3D U-Net with a single-channel sigmoid output.

    The encoder applies two ``conv_block`` layers per level and max-pools
    between levels; level ``d`` uses ``n_base_filters * 2**d`` filters. The
    decoder walks back up: ``up_conv``, concatenation with the encoder output
    of the same level along the last axis, then two ``conv_block`` layers.

    Args:
        input_shape: shape of one input sample (without the batch axis),
            passed to ``Input``. Channels are expected on the last axis,
            matching the ``axis=-1`` concatenation used here.
        lr: learning rate handed to the Adam optimizer.
        depth: number of resolution levels (encoder levels including the
            bottleneck).
        n_base_filters: number of filters at the first level.
        pool_size: pooling factor of the encoder and up-sampling factor of
            the decoder.
        batch_normalization: whether the convolution blocks, including the
            up-convolutions of the decoder, are followed by batch
            normalization.

    Returns:
        The model, compiled with ``focal_loss_v1`` as loss and ``dice_coef``
        as metric.
    """
    inputs = Input(input_shape)
    current_layer = inputs
    # encoder output of every level that is pooled; reused as skip connection
    skip_layers = []

    # contracting path
    for layer_depth in range(depth):
        n_filters = n_base_filters * (2 ** layer_depth)
        current_layer = conv_block(input_layer=current_layer, n_filters=n_filters,
                                   batch_normalization=batch_normalization)
        current_layer = conv_block(input_layer=current_layer, n_filters=n_filters,
                                   batch_normalization=batch_normalization)
        # the deepest level is the bottleneck: neither pooled nor used as a skip
        if layer_depth < depth - 1:
            skip_layers.append(current_layer)
            current_layer = MaxPooling3D(pool_size=pool_size)(current_layer)

    # expanding path, from the level just above the bottleneck up to level 0
    for layer_depth in range(depth - 2, -1, -1):
        # same filter count as the encoder at this level, computed directly
        # instead of being read from the private ``_keras_shape`` attribute
        n_filters = n_base_filters * (2 ** layer_depth)
        # forward batch_normalization so that disabling it affects the
        # up-convolutions too, not only the plain convolution blocks
        up_layer = up_conv(input_layer=current_layer, n_filters=n_filters, pool_size=pool_size,
                           batch_normalization=batch_normalization)
        concat = concatenate([up_layer, skip_layers[layer_depth]], axis=-1)
        current_layer = conv_block(input_layer=concat, n_filters=n_filters,
                                   batch_normalization=batch_normalization)
        current_layer = conv_block(input_layer=current_layer, n_filters=n_filters,
                                   batch_normalization=batch_normalization)

    # 1x1x1 convolution down to one channel of per-voxel probabilities
    final_layer = Conv3D(1, 1, activation='sigmoid')(current_layer)
    model = Model(inputs=inputs, outputs=final_layer)
    model.compile(optimizer=Adam(lr=lr), loss=focal_loss_v1, metrics=[dice_coef])
    return model

def conv_block(input_layer, n_filters, batch_normalization=True, kernel=3, padding='same', strides=(1, 1, 1)):
    """Apply Conv3D, optional batch normalization and ReLU to ``input_layer``.

    Args:
        input_layer: tensor to convolve.
        n_filters: number of convolution filters.
        batch_normalization: insert ``BatchNormalization`` over the last axis
            between the convolution and the activation when true.
        kernel: kernel size handed to ``Conv3D``.
        padding: padding mode handed to ``Conv3D``.
        strides: strides handed to ``Conv3D``.

    Returns:
        The activated output tensor.
    """
    convolution = Conv3D(n_filters, kernel, padding=padding, strides=strides)
    output = convolution(input_layer)
    if batch_normalization:
        output = BatchNormalization(axis=-1)(output)
    return Activation('relu')(output)

def up_conv(input_layer, n_filters, pool_size=(2, 2, 2), batch_normalization=True, kernel=2, padding='same', strides=(1, 1, 1)):
    """Up-sample ``input_layer`` and follow with Conv3D, optional batch normalization and ReLU.

    Args:
        input_layer: tensor to up-sample.
        n_filters: number of convolution filters.
        pool_size: up-sampling factor handed to ``UpSampling3D``.
        batch_normalization: insert ``BatchNormalization`` over the last axis
            between the convolution and the activation when true.
        kernel: kernel size handed to ``Conv3D``.
        padding: padding mode handed to ``Conv3D``.
        strides: strides handed to ``Conv3D``.

    Returns:
        The activated output tensor.
    """
    upsampled = UpSampling3D(size=pool_size)(input_layer)
    convolution = Conv3D(n_filters, kernel, padding=padding, strides=strides)
    output = convolution(upsampled)
    if batch_normalization:
        output = BatchNormalization(axis=-1)(output)
    return Activation('relu')(output)


# create data generator to yield batches of a fixed size

def data_generator(batch_size, image_path_list, mask_path_list, is_train):
    """Yield ``(images, masks)`` batches of exactly ``batch_size`` cubes, forever.

    Every pass walks once over the samples (in shuffled order when
    ``is_train`` is true), loads each image/mask pair with ``np.load``,
    rescales the image from ``[image_vmin, image_vmax]`` to ``[0, 1]`` and
    clamps values outside that range. Only full batches are produced: the
    ``len(image_path_list) % batch_size`` samples left at the end of a pass
    are skipped for that pass.

    Args:
        batch_size: number of cubes per batch; must be at least 1.
        image_path_list: paths of the cubic images (npy).
        mask_path_list: paths of the cubic masks (npy), index-aligned with
            ``image_path_list``.
        is_train: shuffle the sample order on every pass when true.

    Yields:
        Two float32 arrays of shape
        ``(batch_size, image_deps, image_rows, image_cols, 1)``.

    Raises:
        ValueError: on first iteration, if ``batch_size`` is smaller than 1,
            if there are fewer masks than images, or if there are fewer
            samples than ``batch_size``.
    """
    n_samples = len(image_path_list)

    if batch_size < 1:
        raise ValueError('batch_size must be at least 1, got %s' % batch_size)
    if len(mask_path_list) < n_samples:
        raise ValueError('got %d images but only %d masks' % (n_samples, len(mask_path_list)))
    if n_samples < batch_size:
        # without this check the loop below would spin forever without yielding
        raise ValueError('need at least batch_size=%d samples, got %d' % (batch_size, n_samples))

    batch_shape = (batch_size, image_deps, image_rows, image_cols, 1)

    while True:
        data_list = np.arange(n_samples)

        if is_train:
            np.random.shuffle(data_list)

        # step over full batches only; the incomplete tail of a pass is never loaded
        for start in range(0, n_samples - batch_size + 1, batch_size):
            image_generated = np.empty(batch_shape, dtype=np.float32)
            mask_generated = np.empty(batch_shape, dtype=np.float32)

            for cube_i, data_list_i in enumerate(data_list[start:start + batch_size]):
                image = np.load(image_path_list[data_list_i])
                mask = np.load(mask_path_list[data_list_i])

                # window the intensities to [0, 1]
                image = (image - image_vmin) / (image_vmax - image_vmin)
                image_generated[cube_i, :, :, :, 0] = np.clip(image, 0, 1)
                mask_generated[cube_i, :, :, :, 0] = mask

            yield image_generated, mask_generated


# train model

def main():
    """Parse the command line, train the 3D U-Net and save the training history.

    The learning rate, number of epochs, batch size and study number are read
    from the command line. The model with the best validation loss is
    checkpointed during training and the Keras history is written to a JSON
    file in the current working directory afterwards.

    Raises:
        ValueError: if the training or validation path list holds fewer
            samples than one batch.
    """
    ap = argparse.ArgumentParser()
    # NOTE: every option is required, so the ``default`` values are never used
    ap.add_argument('-l', '--lr', required=True, type=float, default=1e-5, help='learning rate')
    ap.add_argument('-e', '--epochs', required=True, type=int, default=100, help='epochs')
    ap.add_argument('-b', '--batchsize', required=True, type=int, default=4, help='batch size')
    ap.add_argument('-s', '--study', required=True, type=int, default=0, help='study number')
    args = vars(ap.parse_args())

    lr = args['lr']
    epochs = args['epochs']
    batch_size = args['batchsize']
    study = args['study']

    if batch_size < 1:
        ap.error('batch size must be a positive integer')

    print('unet3d_v10 study:%d, lr:%s, epochs:%d, batchsize:%d' % (study, lr, epochs, batch_size))

    # same axis order as the batches built by data_generator: (deps, rows, cols, channel)
    input_shape = (image_deps, image_rows, image_cols, 1)

    model = unet3d(input_shape, lr, depth=5, n_base_filters=32, pool_size=(2, 2, 2), batch_normalization=True)

    model_checkpoint = ModelCheckpoint('/data/AlgProj/guxl/model_unet3d/unet3d_study%d_lr%s_epochs%d_batchsize%d_v10.hdf5' % (study, lr, epochs, batch_size),
                                       monitor='val_loss',verbose=1, save_best_only=True)

    image_train_path_list = np.load('/data/AlgProj/guxl/3Dircadb1_test/cubic_data_path_list_overlap/image_train_path_list.npy')
    mask_train_path_list = np.load('/data/AlgProj/guxl/3Dircadb1_test/cubic_data_path_list_overlap/mask_train_path_list.npy')

    image_val_path_list = np.load('/data/AlgProj/guxl/3Dircadb1_test/cubic_data_path_list_overlap/image_val_path_list.npy')
    mask_val_path_list = np.load('/data/AlgProj/guxl/3Dircadb1_test/cubic_data_path_list_overlap/mask_val_path_list.npy')

    # data_generator only emits full batches, so one epoch is the number of
    # whole batches; floor division also keeps the step counts integral
    train_steps = len(image_train_path_list) // batch_size
    val_steps = len(image_val_path_list) // batch_size

    if train_steps < 1 or val_steps < 1:
        raise ValueError('batch size %d is larger than the training (%d) or validation (%d) set'
                         % (batch_size, len(image_train_path_list), len(image_val_path_list)))

    train_generator = data_generator(batch_size=batch_size,
                                     image_path_list=image_train_path_list,
                                     mask_path_list=mask_train_path_list,
                                     is_train=True)

    val_generator = data_generator(batch_size=batch_size,
                                   image_path_list=image_val_path_list,
                                   mask_path_list=mask_val_path_list,
                                   is_train=False)

    H = model.fit_generator(train_generator,
                            steps_per_epoch=train_steps,
                            epochs=epochs,
                            verbose=1,
                            callbacks=[model_checkpoint],
                            validation_data=val_generator,
                            validation_steps=val_steps)

    # the recorded metrics may be NumPy scalars, which json cannot serialize;
    # convert them to plain floats first
    history = {name: [float(value) for value in values] for name, values in H.history.items()}

    with open('history_unet3d_study%d_lr%s_epochs%d_batchsize%d.json' % (study, lr, epochs, batch_size), 'w')  as f:
        json.dump(history, f)

    print('unet3d_v10 study:%d, lr:%s, epochs:%d, batchsize:%d' % (study, lr, epochs, batch_size))


# execute main function

if __name__ == '__main__':
    # start training only when this file is executed as a script, not on import
    main()
