# Imports

In [None]:
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.callbacks import ModelCheckpoint
from keras.utils.np_utils import to_categorical
import build_model
from util import mkdirs
from image_gen import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
import sys
import h5py
import configparser
import ast

# Configuration

In [None]:
# TODO: maybe put this in the configuration file

# Select the UNet builder from build_model used for this experiment.
model_func = getattr(build_model, "build_UNet_softmax")

In [None]:
config = configparser.RawConfigParser(interpolation=configparser.ExtendedInterpolation())
# read() silently skips files it cannot open and returns the list of files it
# parsed; fail here with a clear message instead of a later NoSectionError.
config_file = 'cytonet.cfg'
if not config.read(config_file):
    raise FileNotFoundError("Configuration file not found: " + config_file)
section = 'training'

In [None]:
def _option_or_default(option, default_section, default_option=None):
    """
        Return ``option`` from the training section when it is set there,
        otherwise ``default_option`` (``option`` itself if omitted) read
        from ``default_section``.
    """
    if config.has_option(section, option):
        return config.get(section, option)
    return config.get(default_section, default_option or option)


valid_portion = config.getfloat(section, 'valid_portion')
batch_size = config.getint(section, 'batch_size')
# classes is stored as a Python literal and may be overridden per section
classes = ast.literal_eval(_option_or_default('classes', 'general'))
# defaults to the output file written by the saving step
matrice_file = _option_or_default('matrice_file', 'saving', 'output_file')
experiment_folder = config.get('general', 'experiment_folder')
patch_size = config.getint('general', 'patch_size')
nb_classes = len(classes)

# Data augmentation variables
data_augmentation = config.getboolean(section, 'data_augmentation')
rotation_range = config.getint(section, 'rotation_range')
width_shift_range, height_shift_range, rescale, zoom_range = (
    config.getfloat(section, name)
    for name in ('width_shift_range', 'height_shift_range', 'rescale', 'zoom_range')
)

horizontal_flip, vertical_flip = (
    config.getboolean(section, name)
    for name in ('horizontal_flip', 'vertical_flip')
)

# Creating necessary folders

In [None]:
# creating folders
# os.path.join builds the same paths the later cells use, whether or not
# experiment_folder ends with a separator; the trailing "" keeps the
# trailing slash of the previous "model/"-style paths.
for sub_folder in ("model", "matrices", "graphs"):
    mkdirs(os.path.join(experiment_folder, sub_folder, ""), 0o777)

# Function definitions

In [None]:
def getTimestamp():
    """
        Return the current local time formatted as ``YYYYmmdd_HHMMSS``.
    """
    from datetime import datetime
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")

In [None]:
def save_training_history(info, history):
    """
        Save the training history of a model into the folder ``info``.

        Writes three files that share a single timestamp suffix:
          - loss_history.<timestamp>.jpg : training vs validation loss per epoch
          - acc_history.<timestamp>.jpg  : training vs validation categorical accuracy
          - log.<timestamp>.json         : the raw ``history.history`` dict

        ``history`` is the object returned by ``fit``/``fit_generator``; its
        ``history`` dict must contain 'loss', 'val_loss',
        'categorical_accuracy' and 'val_categorical_accuracy'.
    """
    import json

    # Make sure the target folder exists so savefig/open do not fail.
    os.makedirs(info, exist_ok=True)
    # One timestamp for every artifact: calling getTimestamp() per file could
    # give the plots and the log different suffixes across a second boundary.
    stamp = getTimestamp()
    logs = history.history

    # list all data in history
    print(logs.keys())

    # Same plot for each tracked metric; 'test' in the legend is the
    # validation split.
    for metric, title, ylabel, prefix in (
        ('loss', 'model loss', 'loss', 'loss_history'),
        ('categorical_accuracy', 'model acc', 'acc', 'acc_history'),
    ):
        plt.clf()
        plt.plot(logs[metric])
        plt.plot(logs['val_' + metric])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
        plt.gcf().savefig(os.path.join(info, prefix + '.' + stamp + '.jpg'))

    # history to json file
    with open(os.path.join(info, 'log.' + stamp + '.json'), 'w') as fp:
        json.dump(logs, fp, indent=True)

# Code execution

In [None]:
# Get the data from the h5 file created after extraction + saving.
# For every class, the first (1 - valid_portion) of its patches go to the
# training set and the remainder to the validation set.


def _one_hot_masks(masks, n_classes):
    """Convert label masks to one-hot arrays of shape (N, H, W, n_classes)."""
    one_hot = to_categorical(masks, num_classes=n_classes)
    # Depending on the Keras version, to_categorical returns either a
    # flattened (N*H*W, C) array or one shaped like the input plus a class
    # axis; reshaping to an explicit target shape works for both.
    return one_hot.reshape(masks.shape[:3] + (n_classes,))


img_shape = (0, patch_size, patch_size, 3)
mask_shape = (0, patch_size, patch_size, nb_classes)

# Collect per-class chunks and concatenate once at the end (concatenating
# inside the loop copies the accumulated arrays on every iteration). The
# empty float64 seed arrays keep the previous result dtype and keep the final
# concatenation valid even when `classes` is empty.
X_train_parts, y_train_parts = [np.empty(img_shape)], [np.empty(mask_shape)]
X_val_parts, y_val_parts = [np.empty(img_shape)], [np.empty(mask_shape)]

# The context manager closes the HDF5 file even if a dataset is missing.
with h5py.File(matrice_file, 'r') as h5f:
    for key in classes:
        imgs = h5f[key + '_imgs'][:]
        masks = h5f[key + '_masks'][:]
        n_train = int(imgs.shape[0] * (1 - valid_portion))

        X_train_parts.append(imgs[:n_train])
        y_train_parts.append(_one_hot_masks(masks[:n_train], nb_classes))

        X_val_parts.append(imgs[n_train:])
        y_val_parts.append(_one_hot_masks(masks[n_train:], nb_classes))

X_train, y_train = np.concatenate(X_train_parts), np.concatenate(y_train_parts)
X_val, y_val = np.concatenate(X_val_parts), np.concatenate(y_val_parts)
# Release the per-class chunks now that the final arrays exist.
del X_train_parts, y_train_parts, X_val_parts, y_val_parts

In [None]:
# Build the network for the patch shape found in the training data and
# compile it for per-pixel multi-class classification.
inp_shape = X_train.shape[1:]
UNet = model_func(inp_shape, nb_classes)
UNet.compile(
    optimizer='adam',
    loss="categorical_crossentropy",
    metrics=['categorical_accuracy'],
    # NOTE(review): no sample weights are passed to fit_generator later in
    # this notebook; confirm whether temporal weighting is still needed.
    sample_weight_mode="temporal",
)

# Save a diagram of the architecture (with layer shapes) in the model folder.
model_diagram = os.path.join(experiment_folder, "model", "model.png")
plot_model(UNet, model_diagram, show_shapes=True)

In [None]:
# Record how the model was configured: loss, metrics and builder function.
# str() keeps this working if the loss is ever given as a callable.
conf_model = (
    "loss : " + str(UNet.loss) + '\n'
    + "metrics : " + str(UNet.metrics) + '\n'
    + "function :" + str(model_func) + '\n'
)
# The context manager closes (and flushes) the file even if write() fails.
with open(os.path.join(experiment_folder, 'config_value'), 'w') as f:
    f.write(conf_model)

In [None]:
# Data Augmentation
# The training generator receives the configured random transformations only
# when data_augmentation is enabled; the validation generator is always built
# with rescale=1. only.
if not data_augmentation:
    print("without data augmentation")
    train_gen_options = {'rescale': 1.}
else:
    print("with data augmentation")
    train_gen_options = {
        'rotation_range': rotation_range,
        'width_shift_range': width_shift_range,
        'height_shift_range': height_shift_range,
        # NOTE(review): the configured `rescale` is only applied here; the
        # non-augmented branch and the validation generator use 1. -- confirm
        # this is intended when rescale != 1.
        'rescale': rescale,
        'zoom_range': zoom_range,
        'horizontal_flip': horizontal_flip,
        'vertical_flip': vertical_flip,
        'fill_mode': 'reflect',
        'cval': 0,
    }

train_gen = ImageDataGenerator(nb_classes=nb_classes, **train_gen_options)
test_gen = ImageDataGenerator(nb_classes=nb_classes, rescale=1.)

In [None]:
# Saving summary into the model folder.
# Model.summary() prints the table itself and returns None, so it is called
# directly (print(UNet.summary()) appended a stray "None" line). The
# redirection is undone and the file closed even if summary() raises.
from contextlib import redirect_stdout

with open(os.path.join(experiment_folder, "model", "summary.txt"), 'w') as f, redirect_stdout(f):
    UNet.summary()

In [None]:
# Training length and checkpoint period may be set in the training section
# of the config; the fallbacks are the previously hard-coded values.
epochs = config.getint(section, 'epochs', fallback=100)
checkpoint_period = config.getint(section, 'checkpoint_period', fallback=10)

# Model weights are checkpointed every `checkpoint_period` epochs.
model_file_format = os.path.join(experiment_folder, "matrices", 'model.{epoch:03d}.hdf5')
checkpointer = ModelCheckpoint(model_file_format, period=checkpoint_period)

# ceil(n / batch_size) steps so every sample is seen once per epoch.
train_steps = (X_train.shape[0] + batch_size - 1) // batch_size
val_steps = (X_val.shape[0] + batch_size - 1) // batch_size

history = UNet.fit_generator(
    train_gen.flow(X_train, y_train, batch_size),
    steps_per_epoch=train_steps,
    epochs=epochs,
    callbacks=[checkpointer],
    # The validation flow must use the same batch size that val_steps is
    # computed from; otherwise flow() presumably falls back to its own
    # default (32 in Keras's ImageDataGenerator) and validation covers the
    # wrong number of samples.
    validation_data=test_gen.flow(X_val, y_val, batch_size),
    validation_steps=val_steps,
)

save_training_history(os.path.join(experiment_folder, "graphs"), history)