# Imports

In [None]:
from image_gen import ImageDataGenerator
from load_data import loadData
from build_model import build_UNet2D_4L
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.callbacks import ModelCheckpoint
import glob
import os
import h5py
import util
import numpy as np

# Configuration

In [None]:
# Fraction of each class held out for validation; keep it within [0.0, 1.0]
valid_portion = 0.2
# Number of samples per batch handed to the data generators
batch_size = 16
# Class names mapped to integer ids; the names double as prefixes of the
# '<name>_imgs' / '<name>_masks' datasets read from the HDF5 file
classes = {'neg': 0, 'pos': 1}
# HDF5 file holding the training images and masks
matrice_file = '/root/workspace/data/matrice_train.h5'
# NOTE(review): not referenced in the visible cells; presumably stored
# normalization values — confirm before relying on it
mean_file = '/root/workspace/data/values.npy'

# Function definitions

In [None]:
def getTimestamp():
    """
        Build a timestamp string from the current local time.

        :return: the time formatted as ``YYYYMMDD_HHMMSS``
    """
    from datetime import datetime as _dt
    now = _dt.now()
    return now.strftime("%Y%m%d_%H%M%S")

In [None]:
def save_training_history(info, history):
    """
        Save loss/accuracy plots and a JSON log of a training history.

        Three files are written into the directory ``./<info>``, which is
        created if missing. All of them share one timestamp suffix:
          - ``loss_history.<timestamp>.jpg``: training vs validation loss
          - ``acc_history.<timestamp>.jpg``: training vs validation accuracy
          - ``log.<timestamp>.json``: the raw per-epoch metric values

        :param info: name of the output directory, relative to the cwd
        :param history: object whose ``history`` attribute is a dict mapping
            metric names (``loss``, ``val_loss``, ``acc``/``accuracy``,
            ``val_acc``/``val_accuracy``) to per-epoch value lists, e.g. the
            return value of ``fit_generator``
    """
    import json
    import os
    import matplotlib.pyplot as plt

    hist = history.history
    # list all data in history
    print(hist.keys())

    out_dir = os.path.join('.', info)
    os.makedirs(out_dir, exist_ok=True)
    # A single timestamp keeps all artifacts of one run grouped together
    timestamp = getTimestamp()

    # Older Keras logs the 'accuracy' metric as 'acc'; newer versions keep 'accuracy'
    acc_key = 'acc' if 'acc' in hist else 'accuracy'

    # (metric key in history, axis/title label, output file prefix)
    plots = (('loss', 'loss', 'loss_history'),
             (acc_key, 'acc', 'acc_history'))
    for key, label, prefix in plots:
        # A fresh figure per metric: reusing the current figure would draw the
        # accuracy curves on top of the loss curves
        fig = plt.figure()
        plt.plot(hist[key])
        plt.plot(hist['val_' + key])
        plt.title('model ' + label)
        plt.ylabel(label)
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        fig.savefig(os.path.join(out_dir, prefix + '.' + timestamp + '.jpg'))
        plt.close(fig)  # free the figure; nothing is shown interactively

    # Metric values may be numpy scalars (e.g. float32), which json cannot
    # serialize, so convert them to plain Python floats first
    serializable = {k: [float(v) for v in vals] for k, vals in hist.items()}
    with open(os.path.join(out_dir, 'log.' + timestamp + '.json'), 'w') as fp:
        json.dump(serializable, fp, indent=True)

# Code execution

In [None]:
# Get the data
# Create the (initially empty) sets for training and validation.
# NOTE(review): these seed arrays fix the expected image/mask shapes to
# (256, 256, 3) / (256, 256, 1) and are float64, so concatenation upcasts the
# loaded data to float64 — confirm both are intended.
X_train, y_train = np.array([]).reshape((0,256,256,3)), np.array([]).reshape((0,256,256,1))
X_val, y_val = np.array([]).reshape((0,256,256,3)), np.array([]).reshape((0,256,256,1))

# Add the data of every class to the training and validation sets.
# Each class is split on its own: its first (1 - valid_portion) fraction goes
# to training and the remainder to validation (no shuffling), so both sets keep
# the per-class proportions.
# The context manager closes the HDF5 file even if reading a dataset fails.
with h5py.File(matrice_file, 'r') as h5f:
    for key in classes:
        imgs = h5f[key + '_imgs'][:]
        masks = h5f[key + '_masks'][:]
        n_train = int(imgs.shape[0] * (1 - valid_portion))
        X_train = np.concatenate((X_train, imgs[:n_train]))
        y_train = np.concatenate((y_train, masks[:n_train]))
        X_val = np.concatenate((X_val, imgs[n_train:]))
        y_val = np.concatenate((y_val, masks[n_train:]))

In [None]:
# Generate the model
# The network input shape is the per-sample shape of the training images
# (every axis after the sample axis).
inp_shape = X_train.shape[1:]
UNet = build_UNet2D_4L(inp_shape)
UNet.compile(loss='binary_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])

# Render the architecture, including layer output shapes, to model.png
plot_model(UNet, to_file='model.png', show_shapes=True)

In [None]:
# Data Augmentation
# TODO : Put the parameters in config
# Training generator: random geometric transforms with reflected padding.
train_gen = ImageDataGenerator(
    # geometric transforms
    rotation_range=180,
    width_shift_range=0.3,
    height_shift_range=0.3,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    # value handling
    # NOTE(review): rescale=1. is presumably a no-op multiplier, and cval is
    # presumably ignored unless fill_mode='constant' (as in Keras) — confirm
    # against image_gen
    rescale=1.,
    fill_mode='reflect',
    cval=0,
)

# Validation generator: no augmentation, only the rescale setting
test_gen = ImageDataGenerator(rescale=1.)

In [None]:
# TODO : Put the parameters in config
model_file_format = 'model.{epoch:03d}.hdf5'

# Write a model checkpoint every 10 epochs
checkpointer = ModelCheckpoint(model_file_format, period=10)

# Number of batches that covers each set exactly once (ceiling division)
steps_per_epoch = (X_train.shape[0] + batch_size - 1) // batch_size
validation_steps = (X_val.shape[0] + batch_size - 1) // batch_size

# NOTE(review): segmentation training relies on image_gen's generator applying
# the same random transform to the masks as to the images — confirm.
# The validation flow gets the same batch_size that validation_steps is based
# on; without it the generator's default batch size (presumably 32, as in
# Keras) makes validation draw roughly twice the available data per epoch.
history = UNet.fit_generator(train_gen.flow(X_train, y_train, batch_size),
                            steps_per_epoch=steps_per_epoch,
                            epochs=100,
                            callbacks=[checkpointer],
                            validation_data=test_gen.flow(X_val, y_val, batch_size),
                            validation_steps=validation_steps)

save_training_history("graphs", history)