# Imports

In [None]:
from image_gen import ImageDataGenerator
from load_data import loadData
from build_model import build_UNet2D_4L
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.callbacks import ModelCheckpoint
import glob
import os
import h5py

# Configuration

In [None]:
# --- Dataset location and file naming -------------------------------------
path = "/root/workspace/data/mylungrgb/"  # directory holding the images
mask_format = "_mask"  # suffix that identifies mask files (extension excluded)
img_format = ".png"  # file extension of the images

# --- Split and batching ----------------------------------------------------
valid_portion = 0.2  # fraction used for validation; keep it within 0.0 .. 1.0
batch_size = 16  # number of samples per batch

# Functions definition

In [None]:
def getTimestamp():
    """
        Return the current local date and time as a ``YYYYmmdd_HHMMSS`` string,
        suitable for embedding in file names.
    """
    from datetime import datetime
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")

In [None]:
def _plot_history_metric(hist, key, label, out_path):
    """
        Plot the train (``key``) and validation (``'val_' + key``) curves found
        in ``hist`` on a fresh figure and save it to ``out_path``.
    """
    import matplotlib.pyplot as plt

    # A new figure per metric: reusing the current pyplot figure would draw
    # this metric on top of whatever was plotted before.
    fig, ax = plt.subplots()
    legend = []
    if key in hist:
        ax.plot(hist[key])
        legend.append('train')
    if 'val_' + key in hist:
        ax.plot(hist['val_' + key])
        legend.append('validation')
    ax.set_title('model ' + label)
    ax.set_ylabel(label)
    ax.set_xlabel('epoch')
    if legend:
        ax.legend(legend, loc='upper left')
    fig.savefig(out_path)
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory


def save_training_history(info, history):
    """
        Save the training history of a Keras model into the ``./<info>/`` folder.

        Writes, all tagged with the same timestamp ``<ts>``:
          - ``loss_history.<ts>.jpg``: train/validation loss per epoch,
          - ``acc_history.<ts>.jpg``: train/validation accuracy per epoch
            (only if an accuracy metric was recorded),
          - ``log.<ts>.json``: the raw ``history.history`` dict.

        Parameters:
            info: name of the output folder, relative to the working directory;
                created if it does not exist.
            history: object returned by ``fit``/``fit_generator``; only its
                ``history`` dict (metric name -> per-epoch values) is read.
    """
    import json
    import os

    out_dir = os.path.join('.', info)
    os.makedirs(out_dir, exist_ok=True)

    # One timestamp for every artefact of this call, so the files can be matched up.
    stamp = getTimestamp()
    hist = history.history

    # list all data in history
    print(hist.keys())

    # summarize history for loss
    _plot_history_metric(hist, 'loss', 'loss',
                         os.path.join(out_dir, 'loss_history.' + stamp + '.jpg'))

    # summarize history for accuracy; the key is 'acc' in older Keras releases
    # and 'accuracy' in newer ones when compiled with metrics=['accuracy'].
    acc_key = 'acc' if 'acc' in hist else 'accuracy'
    if acc_key in hist or 'val_' + acc_key in hist:
        _plot_history_metric(hist, acc_key, 'acc',
                             os.path.join(out_dir, 'acc_history.' + stamp + '.jpg'))

    # history to json file; values may be numpy scalars, which json cannot
    # serialize, so convert them to plain floats first.
    serializable = {k: [float(v) for v in values] for k, values in hist.items()}
    with open(os.path.join(out_dir, 'log.' + stamp + '.json'), 'w') as fp:
        json.dump(serializable, fp, indent=1)

# Code execution

In [None]:
# Get the data: images and masks are read fully into memory from the HDF5 file
# stored in the configured `path` directory. The `with` block closes the file
# even if a read fails.
with h5py.File(os.path.join(path, 'matrice_train.h5'), 'r') as h5f:
    X = h5f['imgs'][:]
    print(X.shape)
    y = h5f['masks'][:]
    print(y.shape)

if X.shape[0] != y.shape[0]:
    raise ValueError("imgs and masks have different sample counts: %d vs %d"
                     % (X.shape[0], y.shape[0]))

# Hold out the last `valid_portion` of the samples (file order, no shuffling)
# as the validation set.
n_train = int(X.shape[0] * (1 - valid_portion))
X_train, y_train = X[:n_train], y[:n_train]
X_val, y_val = X[n_train:], y[n_train:]

In [None]:
# Generate the model: build the 4-level 2D U-Net for the per-sample input
# shape (batch axis dropped), compile it, then render its architecture.
inp_shape = X_train.shape[1:]
UNet = build_UNet2D_4L(inp_shape)
UNet.compile(loss='binary_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])

plot_model(UNet, to_file='model.png', show_shapes=True)

In [None]:
# Data Augmentation
# TODO : Put the parameters in config
# Training generator: random geometric transforms; border pixels are filled by
# reflection.
train_gen = ImageDataGenerator(
    # geometric transforms
    rotation_range=180,
    width_shift_range=0.3,
    height_shift_range=0.3,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    # filling and intensity
    fill_mode='reflect',
    cval=0,
    rescale=1.,
)

# Validation generator: no augmentation, only the (identity) rescale.
test_gen = ImageDataGenerator(rescale=1.)

In [None]:
# Train the model
# TODO : Put the parameters in config
model_file_format = 'model.{epoch:03d}.hdf5'

# Write a checkpoint file every 10 epochs.
checkpointer = ModelCheckpoint(model_file_format, period=10)

# Ceil-division so the final partial batch of each set is also consumed.
steps_per_epoch = (X_train.shape[0] + batch_size - 1) // batch_size
validation_steps = (X_val.shape[0] + batch_size - 1) // batch_size

# The validation flow must use the same batch_size that validation_steps was
# computed with. Without it, flow() falls back to its own default (presumably
# 32, as in Keras). Each validation pass would then run past the end of X_val
# and count some samples twice.
history = UNet.fit_generator(train_gen.flow(X_train, y_train, batch_size),
                             steps_per_epoch=steps_per_epoch,
                             epochs=100,
                             callbacks=[checkpointer],
                             validation_data=test_gen.flow(X_val, y_val, batch_size),
                             validation_steps=validation_steps)

# Make sure the output folder exists before the history plots are written.
os.makedirs("graphs", exist_ok=True)
save_training_history("graphs", history)