## Import the Packages

In [None]:
from keras.models import Sequential, Model, load_model
from keras.layers import Dropout, Reshape, Activation, Conv3D, Input, MaxPooling3D, BatchNormalization, Flatten, Dense, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.optimizers import SGD, Adam, RMSprop
from keras.layers.merge import concatenate
import matplotlib.pyplot as plt
import keras.backend as K
import tensorflow as tf
from tqdm import tqdm
import numpy as np
import pickle
import os, cv2
from preprocessing3D import parse_annotation, BatchGenerator

from keras.backend.tensorflow_backend import set_session

# GPU-related environment variables must be exported BEFORE the first
# tf.Session is created: TensorFlow reads them when it initialises its GPU
# devices, so setting them after set_session(...) is too late to take effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["TF_CUDA_HOST_MEM_LIMIT_IN_MB"] = "120000"
# TensorFlow parses this flag as lowercase "true"/"false"; an upper-case value
# is reported as unparsable and ignored.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

config = tf.ConfigProto()
# NOTE(review): a fraction above 1.0 is documented in TF 1.x GPUOptions as
# enabling CUDA unified memory (GPU memory oversubscription) -- confirm that
# this is the intent of 8.0 together with the host memory limit above.
config.gpu_options.per_process_gpu_memory_fraction = 8.0
set_session(tf.Session(config=config))

tf.logging.set_verbosity(tf.logging.INFO)

## Define the Parameters

Change the parameters according to your dataset!

In [None]:
# Class names the detector is trained on.
LABELS = ['NGpair']

# Size of the input volume (depth, height, width).
# NOTE: when changing the image size, change it as well in preprocessing3D and utils3d.
IMAGE_Z = 64
IMAGE_H = 416
IMAGE_W = 416

# Size of the output grid (depth, height, width).
GRID_Z = 2
GRID_H = 13
GRID_W = 13

# Number of anchor boxes per grid cell and number of classes.
BOX = 5
CLASS = len(LABELS)
CLASS_WEIGHTS = np.ones(CLASS, dtype='float32')

# Detection thresholds.
OBJ_THRESHOLD = 0.5
NMS_THRESHOLD = 0.45

# Anchor priors: one row of three values per anchor box
# (the loss reshapes this flat list to [..., BOX, 3]).
ANCHORS = [
    0.56, 0.61, 0.05,
    0.71, 1.17, 0.06,
    1.07, 1.50, 0.1,
    1.08, 0.82, 0.08,
    1.65, 1.05, 0.1,
]
#ANCHORS          = [0.4, 0.4, 0.4,    0.6, 0.6, 0.6]

# Weights of the individual loss terms.
NO_OBJECT_SCALE = 1.0
OBJECT_SCALE = 5.0
COORD_SCALE = 1.0
CLASS_SCALE = 1.0

# Training batch size (was 16), number of warm-up batches, and the number of
# ground-truth box slots reserved per sample.
BATCH_SIZE = 1
WARM_UP_BATCHES = 0
TRUE_BOX_BUFFER = 50

## Define the Training and Validation Directories

Change the training and validation directories!

In [None]:
# Training split: image volumes and their annotation files.
train_image_folder, train_annot_folder = (
    '/dev/shm/datasetyolo/train/images/',
    '/dev/shm/datasetyolo/train/annot/',
)

# Validation split: image volumes and their annotation files.
valid_image_folder, valid_annot_folder = (
    '/dev/shm/datasetyolo/val/images/',
    '/dev/shm/datasetyolo/val/annot/',
)

## Construct the Network

In [None]:
def _conv_bn_leaky(tensor, filters, index):
    """Apply Conv3D (3x3x3, stride 1, no bias) -> BatchNormalization -> LeakyReLU(0.1).

    The convolution and normalisation layers are named 'conv_<index>' and
    'norm_<index>' respectively.
    """
    tensor = Conv3D(filters, (3, 3, 3), strides=(1, 1, 1), padding='same',
                    name='conv_' + str(index), use_bias=False)(tensor)
    tensor = BatchNormalization(name='norm_' + str(index))(tensor)
    return LeakyReLU(alpha=0.1)(tensor)


# Network inputs: the image volume and the buffer of ground-truth boxes
# (six values per box; the loss reads them as centre [0:3] and size [3:6]).
input_image = Input(shape=(IMAGE_Z, IMAGE_H, IMAGE_W, 3))
true_boxes = Input(shape=(1, 1, 1, 1, TRUE_BOX_BUFFER, 6))
dropout_rate = 0.0

# Layers 1 - 5: convolution block followed by a 2x2x2 max-pooling that halves
# every spatial dimension.
x = input_image
for layer_index, n_filters in enumerate((16, 32, 64, 128, 256), start=1):
    x = _conv_bn_leaky(x, n_filters, layer_index)
    x = MaxPooling3D(pool_size=(2, 2, 2))(x)

# Layer 6: convolution block followed by a stride-1 pooling that keeps the grid size.
x = _conv_bn_leaky(x, 512, 6)
x = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding='same')(x)

# Layers 7 - 8: convolution blocks without pooling.
for layer_index in (7, 8):
    x = _conv_bn_leaky(x, 1024, layer_index)

# Detection head: per grid cell and anchor box, 6 box values + 1 confidence + CLASS scores.
values_per_box = 6 + 1 + CLASS
output = Conv3D(BOX * values_per_box,
                (1, 1, 1),
                strides=(1, 1, 1),
                padding='same',
                name='DetectionLayer',
                kernel_initializer='lecun_normal')(x)
output = Reshape((GRID_Z, GRID_H, GRID_W, BOX, values_per_box))(output)

# Small hack so that true_boxes is registered as a model input when Keras
# builds the model; see https://github.com/fchollet/keras/issues/2790
output = Lambda(lambda args: args[0])([output, true_boxes])

model = Model([input_image, true_boxes], output)

model.summary(positions=[0.2, 0.5, 0.6, 0.8, 1.0])

## Define the Loss Function

In [None]:
def _box_iou_3d(pred_xy, pred_wh, true_xy, true_wh):
    """Element-wise IoU of axis-aligned 3D boxes given as centre / size tensors.

    The arguments must be broadcastable against each other; their last axis
    holds the three box coordinates. The result has that last axis removed.
    """
    pred_half = pred_wh / 2.
    pred_mins = pred_xy - pred_half
    pred_maxes = pred_xy + pred_half

    true_half = true_wh / 2.
    true_mins = true_xy - true_half
    true_maxes = true_xy + true_half

    intersect_mins = tf.maximum(pred_mins, true_mins)
    intersect_maxes = tf.minimum(pred_maxes, true_maxes)
    intersect_wh = tf.maximum(intersect_maxes - intersect_mins, 0.)
    intersect_vol = intersect_wh[..., 0] * intersect_wh[..., 1] * intersect_wh[..., 2]

    true_vol = true_wh[..., 0] * true_wh[..., 1] * true_wh[..., 2]
    pred_vol = pred_wh[..., 0] * pred_wh[..., 1] * pred_wh[..., 2]

    union_vol = pred_vol + true_vol - intersect_vol
    return tf.truediv(intersect_vol, union_vol)


def custom_loss(y_true, y_pred):
    """YOLO-style detection loss extended to three spatial dimensions.

    The loss is the sum of a box-centre term, a box-size term, a confidence
    term and a classification term, each normalised by the number of entries
    selected by its mask. It reads the module-level constants GRID_Z/H/W, BOX,
    BATCH_SIZE, ANCHORS, CLASS_WEIGHTS, the *_SCALE weights and
    WARM_UP_BATCHES, as well as the global `true_boxes` input tensor.

    Args:
        y_true: ground-truth tensor; its last axis is read as [0:3] box
            centre, [3:6] box size and [6] objectness.
            (assumed layout (batch, GRID_Z, GRID_H, GRID_W, BOX, ...) -- TODO
            confirm against BatchGenerator)
        y_pred: raw network output; last axis is read as [0:3] centre logits,
            [3:6] log-size, [6] confidence logit and [7:] class logits.

    Returns:
        Scalar loss tensor, wrapped in tf.Print ops that log the individual
        loss terms and the current / average recall on every evaluation.
    """
    mask_shape = tf.shape(y_true)[:5]
    anchors = np.reshape(ANCHORS, [1, 1, 1, 1, BOX, 3])

    # Per-cell grid offsets, each of shape (1, GRID_Z, GRID_H, GRID_W, 1, 1).
    cell_x = tf.to_float(tf.reshape(tf.tile(tf.range(GRID_W), [GRID_H*GRID_Z]), (1, GRID_Z, GRID_H, GRID_W, 1, 1)))
    cell_y = tf.to_float(tf.reshape(tf.tile(tf.range(GRID_H), [GRID_W*GRID_Z]), (1, GRID_Z, GRID_W, GRID_H, 1, 1)))
    cell_y = tf.transpose(cell_y, (0, 1, 3, 2, 4, 5))
    cell_z = tf.to_float(tf.reshape(tf.tile(tf.range(GRID_Z), [GRID_H*GRID_W]), (1, GRID_W, GRID_H, GRID_Z, 1, 1)))
    cell_z = tf.transpose(cell_z, (0, 3, 2, 1, 4, 5))

    cell_grid = tf.tile(tf.concat([cell_x, cell_y, cell_z], -1), [BATCH_SIZE, 1, 1, 1, BOX, 1])

    conf_mask = tf.zeros(mask_shape)

    # Running counters: number of evaluated batches and accumulated recall.
    seen = tf.Variable(0.)
    total_recall = tf.Variable(0.)

    # ---- Adjust prediction ----
    # box centre: sigmoid offset inside the cell plus the cell position
    pred_box_xy = tf.sigmoid(y_pred[..., :3]) + cell_grid
    # box size: exponential scaling of the anchor priors
    pred_box_wh = tf.exp(y_pred[..., 3:6]) * anchors
    pred_box_conf = tf.sigmoid(y_pred[..., 6])
    pred_box_class = y_pred[..., 7:]

    # ---- Adjust ground truth ----
    true_box_xy = y_true[..., 0:3]
    true_box_wh = y_true[..., 3:6]

    # target confidence: IoU between each predictor and its own ground-truth box
    true_box_conf = _box_iou_3d(pred_box_xy, pred_box_wh, true_box_xy, true_box_wh) * y_true[..., 6]

    # single-class setup: every box is assigned class index 0
    true_box_class = tf.to_int64(0 * y_true[..., 6])   # was int32

    # ---- Determine the masks ----
    # coordinate mask: simply the position of the ground truth boxes (the predictors)
    coord_mask = tf.expand_dims(y_true[..., 6], axis=-1) * COORD_SCALE

    # confidence mask: penalize predictors + penalize boxes with low IOU.
    # Each predicted box is compared against every box in the true_boxes buffer
    # (extra axis 5) and its best IoU is kept.
    buffer_ious = _box_iou_3d(tf.expand_dims(pred_box_xy, 5),
                              tf.expand_dims(pred_box_wh, 5),
                              true_boxes[..., 0:3],
                              true_boxes[..., 3:6])
    best_ious = tf.reduce_max(buffer_ious, axis=5)

    # penalize the confidence of the boxes, which have IOU with some ground truth box < 0.6
    conf_mask = conf_mask + tf.to_float(best_ious < 0.6) * (1 - y_true[..., 6]) * NO_OBJECT_SCALE

    # penalize the confidence of the boxes, which are responsible for corresponding ground truth box
    conf_mask = conf_mask + y_true[..., 6] * OBJECT_SCALE

    # class mask: simply the position of the ground truth boxes (the predictors)
    class_mask = y_true[..., 6] * tf.gather(CLASS_WEIGHTS, true_box_class) * CLASS_SCALE

    # ---- Warm-up training ----
    # During the first WARM_UP_BATCHES evaluations, cells without a box are
    # given the cell centre and the anchor size as regression targets.
    no_boxes_mask = tf.to_float(coord_mask < COORD_SCALE/2.)
    seen = tf.assign_add(seen, 1.)

    true_box_xy, true_box_wh, coord_mask = tf.cond(tf.less(seen, WARM_UP_BATCHES),
                          lambda: [true_box_xy + (0.5 + cell_grid) * no_boxes_mask,
                                   true_box_wh + tf.ones_like(true_box_wh) * anchors * no_boxes_mask,
                                   tf.ones_like(coord_mask)],
                          lambda: [true_box_xy,
                                   true_box_wh,
                                   coord_mask])

    # ---- Finalize the loss ----
    nb_coord_box = tf.reduce_sum(tf.to_float(coord_mask > 0.0))
    nb_conf_box  = tf.reduce_sum(tf.to_float(conf_mask  > 0.0))
    nb_class_box = tf.reduce_sum(tf.to_float(class_mask > 0.0))

    loss_xy    = tf.reduce_sum(tf.square(true_box_xy-pred_box_xy)     * coord_mask) / (nb_coord_box + 1e-6) / 2.
    loss_wh    = tf.reduce_sum(tf.square(true_box_wh-pred_box_wh)     * coord_mask) / (nb_coord_box + 1e-6) / 2.
    # logged only: masked mean of the predicted / true box sizes
    loss_wh_pred    = tf.reduce_sum(pred_box_wh     * coord_mask) / (nb_coord_box + 1e-6) / 2.
    loss_wh_true    = tf.reduce_sum(true_box_wh     * coord_mask) / (nb_coord_box + 1e-6) / 2.
    loss_conf  = tf.reduce_sum(tf.square(true_box_conf-pred_box_conf) * conf_mask)  / (nb_conf_box  + 1e-6) / 2.
    loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=true_box_class, logits=pred_box_class)
    loss_class = tf.reduce_sum(loss_class * class_mask) / (nb_class_box + 1e-6)

    loss = loss_xy + loss_wh + loss_conf + loss_class

    # ---- Recall bookkeeping and logging ----
    nb_true_box = tf.reduce_sum(y_true[..., 6])
    nb_pred_box = tf.reduce_sum(tf.to_float(true_box_conf > 0.5) * tf.to_float(pred_box_conf > 0.3))

    current_recall = nb_pred_box/(nb_true_box + 1e-6)
    total_recall = tf.assign_add(total_recall, current_recall)

    loss = tf.Print(loss, [tf.zeros((1))], message='Dummy Line \t', summarize=1000)
    loss = tf.Print(loss, [loss_xy], message='Loss XY \t', summarize=1000)
    loss = tf.Print(loss, [loss_wh], message='Loss WH \t', summarize=1000)
    loss = tf.Print(loss, [loss_wh_pred], message='Loss WH pred\t', summarize=1000)
    loss = tf.Print(loss, [loss_wh_true], message='Loss WH true\t', summarize=1000)
    loss = tf.Print(loss, [loss_conf], message='Loss Conf \t', summarize=1000)
    loss = tf.Print(loss, [loss_class], message='Loss Class \t', summarize=1000)
    loss = tf.Print(loss, [loss], message='Total Loss \t', summarize=1000)
    loss = tf.Print(loss, [current_recall], message='Current Recall \t', summarize=1000)
    loss = tf.Print(loss, [total_recall/seen], message='Average Recall \t', summarize=1000)
    return loss

## Define Training Parameters

In [None]:
# Settings handed to BatchGenerator. Every entry is taken from the parameter
# cell so that the generator stays in sync with the network: in particular the
# true_boxes model input is sized from TRUE_BOX_BUFFER, so the generator must
# use the same value rather than a separately hard-coded one.
generator_config = {
    'IMAGE_Z'         : IMAGE_Z,
    'IMAGE_H'         : IMAGE_H,
    'IMAGE_W'         : IMAGE_W,
    'GRID_Z'          : GRID_Z,
    'GRID_H'          : GRID_H,
    'GRID_W'          : GRID_W,
    'BOX'             : BOX,
    'LABELS'          : LABELS,
    'CLASS'           : CLASS,
    'ANCHORS'         : ANCHORS,
    'BATCH_SIZE'      : BATCH_SIZE,
    'TRUE_BOX_BUFFER' : TRUE_BOX_BUFFER,
}

def normalize(image):
    """Return `image` divided by 255 (maps 8-bit intensities onto [0, 1])."""
    scaled = image / 255.
    return scaled

# Training data: parse the annotations, then serve shuffled, un-jittered batches.
train_imgs, seen_train_labels = parse_annotation(
    train_annot_folder,
    train_image_folder,
    labels=LABELS,
)
train_batch = BatchGenerator(
    train_imgs,
    generator_config,
    norm=normalize,
    jitter=False,
    shuffle=True,
)

# Validation data: same pipeline, but served in a fixed order.
valid_imgs, seen_valid_labels = parse_annotation(
    valid_annot_folder,
    valid_image_folder,
    labels=LABELS,
)
valid_batch = BatchGenerator(
    valid_imgs,
    generator_config,
    norm=normalize,
    jitter=False,
    shuffle=False,
)

# Set up the callbacks; both watch the validation loss and treat lower as better.
monitored_metric = 'val_loss'

# The very large patience means early stopping will not trigger in practice.
early_stop = EarlyStopping(
    monitor=monitored_metric,
    min_delta=0.001,
    patience=10000000,
    mode='min',
    verbose=1,
)

# Checked every epoch; the file is overwritten only when the monitored metric improves.
checkpoint = ModelCheckpoint(
    'weights_NGPAIRS_3D.h5',
    monitor=monitored_metric,
    verbose=1,
    save_best_only=True,
    mode='min',
    period=1,
)


# Number the TensorBoard run after the runs already present in the log root.
# On a fresh machine the log root may not exist yet; os.listdir would raise in
# that case, so a missing directory simply counts as "no previous runs".
log_root = os.path.expanduser('/logs/')
if os.path.isdir(log_root):
    previous_runs = [log for log in os.listdir(log_root) if 'ngpair_' in log]
else:
    previous_runs = []

tb_counter  = len(previous_runs) + 1
tensorboard = TensorBoard(log_dir=log_root + 'ngpair_' + '_' + str(tb_counter),
                          histogram_freq=0,
                          write_graph=False,
                          write_images=False)


# Adam optimizer with a fixed learning rate (no decay); the model is trained
# with the custom detection loss.
optimizer = Adam(
    lr=1.0e-4,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
    decay=0.0,
)
model.compile(optimizer=optimizer, loss=custom_loss)

## Perform Training

Change the number of epochs as needed. To continue from a previously trained model, uncomment the first lines of the cell below.

In [None]:
# Uncomment these lines after training the first model to start from its weights:
#pretrained_weights = load_model('weights3D'+ str(tb_counter-1) +'.h5', custom_objects={'custom_loss': custom_loss, 'tf': tf})
#pretrained_weights = pretrained_weights.get_weights()
#model.set_weights(pretrained_weights)

# File the final model is written to, numbered like the TensorBoard run.
final_model_path = 'weights3D' + str(tb_counter) + '.h5'

# One epoch is a full pass over the training generator; validation runs over
# the whole validation generator after every epoch.
model.fit_generator(
    generator=train_batch,
    steps_per_epoch=len(train_batch),
    epochs=100,
    verbose=1,
    validation_data=valid_batch,
    validation_steps=len(valid_batch),
    callbacks=[early_stop, checkpoint, tensorboard],
    max_queue_size=3,
)

model.save(final_model_path)