## Training.ipynb

---

Trains a Fully Convolutional Network (FCN) with ResNet-50 as the feature extractor.<br>
The network is trained on a labeled binary dataset of agricultural irrigation ponds.<br>
We used the [Google Maps Static API](https://developers.google.com/maps/documentation/maps-static/intro) 
as the source of high-resolution satellite imagery.<br>


Required third-party libraries:
    
    * keras
    * numpy
    * seaborn
    * matplotlib

Required custom modules:
    
    * ./functions/image
    * ./functions/load_data
    * ./functions/utils
    * ./neural_network/resnet
    * ./neural_network/metrics


In [None]:
%matplotlib inline

# //-------------------------------------------------------------\\

# Required libraries
import os
import sys
import time
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from keras.callbacks import TensorBoard
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.callbacks import LearningRateScheduler
from keras.optimizers import RMSprop
from keras.optimizers import Adam

# Custom modules
sys.path.append('../functions')
from utils import need_time
from utils import save_pickle
from utils import files_paths
from load_data import DataGenerator
from image import display_image_and_labels
from image import display_image_labels_and_prediction

sys.path.append('../neural_network')
import resnet
from metrics import dice_coeff
from metrics import dice_loss
from metrics import crossentropy_dice_loss

# //-------------------------------------------------------------\\

# Seed NumPy's global RNG (presumably what the data generator's shuffling and
# augmentation draw from — confirm). `np.random.seed` returns None, so the value
# is kept in its own variable; it is recorded in the saved training parameters.
# NOTE(review): this does not seed the Keras/backend RNG used for weight init.
seed = 1234
np.random.seed(seed)
# Use seaborn's 'darkgrid' theme for the figures drawn in this notebook.
sns.set_style(style='darkgrid')

# //-------------------------------------------------------------\\

In [None]:
# Load the files paths
# //-------------------------------------------------------------\\

# Input parameters

# Root folders of each split
train_images_path = '../Data/Images/Train_Images'
validation_images_path = '../Data/Images/Validation_Images'

# Every sample is a pair of files: an image ending in `data_suffix` and its
# labels ending in `labels_suffix`, with the same leading name.
data_suffix = '_data.png'
labels_suffix = '_labels.png'

# //-------------------------------------------------------------\\


def _labels_path(im_path):
    """Return the labels path paired with an image path (suffix swapped)."""
    return im_path.replace(data_suffix, labels_suffix)


# Collect the image files of each split (nested_carpets=True presumably makes
# files_paths descend into sub-folders — confirm) and pair them with labels.
train_ims_paths = files_paths(train_images_path,
                              nested_carpets=True, exts=data_suffix)
train_labels_paths = [_labels_path(im_path) for im_path in train_ims_paths]

validation_ims_paths = files_paths(validation_images_path,
                                   nested_carpets=True, exts=data_suffix)
validation_labels_paths = [_labels_path(im_path)
                           for im_path in validation_ims_paths]

# //-------------------------------------------------------------\\

In [None]:
# Training parameters

# //-------------------------------------------------------------\\

# Input parameters

# Pre-processing steps (by name) and their parameters, forwarded to DataGenerator
im_prep_funcs = ['crop', 'scaling']
im_prep_params = dict(crop_size=20)

labels_prep_funcs = ['crop', 'pick channels']
labels_prep_params = dict(crop_size=20, idxs_channel=0)

# Augmentations; only the training generator is given these
aug_funcs = ['flip', 'rotate']
aug_params = dict(p=0.5)

# //-------------------------------------------------------------\\

# Training parameters

n_classes = 2
shuffle = True
epochs, steps_per_epoch = 1, 1
train_batch_size, validation_batch_size, validation_steps = 1, 1, 1

# Per-class weights; when not None the model is compiled with
# sample_weight_mode='temporal'
class_weights = {0: 0.5, 1: 1.5}

# NOTE(review): 1.0 is unusually high for RMSprop/Adam; the LearningRateScheduler
# callback (resnet.lr_scheduler_function) may override it — confirm.
learning_rate = 1.0

# One of: 'RMSprop', 'Adam'
optimizer = 'RMSprop'

# One of: 'crossentropy dice loss', 'dice loss'
loss = 'crossentropy dice loss'

# 'dice_coeff' tracks the Dice coefficient as a metric; any other value
# (e.g. '[]') compiles the model without extra metrics
accuracy = '[]'

# EarlyStopping patience: epochs without val_loss improvement before stopping
early_stop_steps = 4

# Initialise ResNet-50 with ImageNet-pretrained weights
use_pretreining_imagent = True

# Freeze the pretrained layers during training
freeze = False

# //-------------------------------------------------------------\\

# Outputs: create a run folder named after the current local time
t = time.localtime()
base_name = '{}-{}-{}_{}_{}_{}'.format(t.tm_year,
                                       t.tm_mon,
                                       t.tm_mday,
                                       t.tm_hour,
                                       t.tm_min,
                                       t.tm_sec)

base_name = '../neural_network/Model/{}'.format(base_name)

# makedirs also creates missing parents (e.g. ../neural_network/Model on a fresh
# checkout). With exist_ok=True it still creates 'log' when the run folder
# already exists, for example after two runs started within the same second.
os.makedirs(os.path.join(base_name, 'log'), exist_ok=True)

path_save_input_params = os.path.join(base_name, 'input_params.pkl')
path_save_weights = os.path.join(base_name, 'weights.h5')
path_tensorboard_log = os.path.join(base_name, 'log', 'log_out')

# //-------------------------------------------------------------\\

def _generator_parameters(ims_paths, labels_paths, batch_size, steps, **extra):
    """Return the DataGenerator keyword arguments for one data split.

    Class weights, shuffling, class count and pre-processing are shared by both
    splits; `extra` carries split-specific options (here: augmentation).
    """
    return {'ims_paths': ims_paths,
            'labels_paths': labels_paths,
            'class_weights': class_weights,
            'batch_size': batch_size,
            'steps_per_epoch': steps,
            'shuffle': shuffle,
            'n_classes': n_classes,
            'im_prep_funcs': im_prep_funcs,
            'im_prep_params': im_prep_params,
            'labels_prep_funcs': labels_prep_funcs,
            'labels_prep_params': labels_prep_params,
            **extra}


# Augmentation is requested for the training split only
input_train_parameters = _generator_parameters(train_ims_paths,
                                               train_labels_paths,
                                               train_batch_size,
                                               steps_per_epoch,
                                               aug_funcs=aug_funcs,
                                               aug_params=aug_params)

input_validation_parameters = _generator_parameters(validation_ims_paths,
                                                    validation_labels_paths,
                                                    validation_batch_size,
                                                    validation_steps)

# //-------------------------------------------------------------\\

In [None]:
# //-------------------------------------------------------------\\

# Build one generator per split from the parameter dictionaries
train_generator = DataGenerator(**input_train_parameters)
validation_generator = DataGenerator(**input_validation_parameters)

# Header line, then the generator's own printed summary
print('\nTrain set information:', train_generator, sep='\n')
print('\nValidation set information:', validation_generator, sep='\n')

# //-------------------------------------------------------------\\

In [None]:
# Display one batch of samples from the training set

# //-------------------------------------------------------------\\

# The generator yields either (X, y) or (X, y, sample_weights). Only the
# images and labels are needed here, so unpack positionally instead of
# re-guessing the generator's rule from `class_weights`.
X, y, *_ = train_generator[0]

# Reshape the labels to (batch, h, w, n_classes) once, outside the loop. The
# -1 batch dimension stays correct even for a batch shorter than batch_size.
y_maps = y.reshape(-1,
                   train_generator.h,
                   train_generator.w,
                   train_generator.n_classes)

for i in range(X.shape[0]):
    # First three input channels as the image, class-1 channel as the labels
    display_image_and_labels(X[i, ..., :3], y_maps[i, ..., 1], colormap='gray')

# //-------------------------------------------------------------\\

In [None]:
# Create the FCN model

# //-------------------------------------------------------------\\

input_shape = [train_generator.h,
               train_generator.w,
               train_generator.channels]

# Build the model
model = resnet.ResNet50_FCN(input_shape, n_classes,
                            use_pretreining_imagent, freeze)

# Optimizer. An unsupported name raises here instead of surfacing later as a
# NameError on `optimizer_function`.
if optimizer == 'RMSprop':
    optimizer_function = RMSprop(lr=learning_rate)
elif optimizer == 'Adam':
    optimizer_function = Adam(lr=learning_rate)
else:
    raise ValueError("Unknown optimizer {!r}; expected 'RMSprop' or 'Adam'"
                     .format(optimizer))

# Loss function (same fail-fast handling as the optimizer)
if loss == 'crossentropy dice loss':
    loss_function = crossentropy_dice_loss
elif loss == 'dice loss':
    loss_function = dice_loss
else:
    raise ValueError("Unknown loss {!r}; expected 'crossentropy dice loss' "
                     "or 'dice loss'".format(loss))

# Metrics: only 'dice_coeff' adds a metric; any other value compiles without one
accuracy_function = [dice_coeff] if accuracy == 'dice_coeff' else []

# Compile the model. With class weights, the generator presumably yields
# per-pixel sample weights. sample_weight_mode='temporal' is what lets Keras
# accept 2-D (batch, timesteps) sample weights.
# NOTE(review): this checks `is not None`, so an empty dict still enables
# temporal mode — confirm it matches when DataGenerator yields sample weights.
compile_options = {'optimizer': optimizer_function,
                   'loss': loss_function,
                   'metrics': accuracy_function}
if class_weights is not None:
    compile_options['sample_weight_mode'] = 'temporal'
model.compile(**compile_options)

# Training callbacks:
#  - stop once val_loss has not improved by at least 1e-4 for
#    `early_stop_steps` epochs
#  - learning-rate schedule from resnet.lr_scheduler_function
#  - save only the weights with the best val_loss
#  - TensorBoard logs inside the run folder
callbacks = [EarlyStopping(monitor='val_loss',
                           patience=early_stop_steps,
                           verbose=1,
                           min_delta=1e-4),

             LearningRateScheduler(resnet.lr_scheduler_function),

             ModelCheckpoint(monitor='val_loss',
                             filepath=path_save_weights,
                             save_best_only=True,
                             save_weights_only=True),

             TensorBoard(log_dir=path_tensorboard_log)]

model.summary()

# //-------------------------------------------------------------\\

In [None]:
# Save training parameters

# //-------------------------------------------------------------\\

# Settings that identify and reproduce this run, written to input_params.pkl
# in the run folder next to the weights.
params = dict(
    seed=seed,

    # Image directory
    train_images_path=train_images_path,
    validation_images_path=validation_images_path,
    data_suffix=data_suffix,
    labels_suffix=labels_suffix,

    # Image pre-processing
    im_prep_params=im_prep_params,
    im_prep_funcs=im_prep_funcs,
    labels_prep_params=labels_prep_params,
    labels_prep_funcs=labels_prep_funcs,
    aug_params=aug_params,
    aug_funcs=aug_funcs,

    # Training parameters
    n_train_files=train_generator.n_files,
    n_validation_files=validation_generator.n_files,
    input_shape=input_shape,
    n_classes=n_classes,

    epochs=epochs,
    shuffle=shuffle,
    steps_per_epoch=steps_per_epoch,
    train_batch_size=train_batch_size,

    validation_batch_size=validation_batch_size,
    validation_steps=validation_steps,

    learning_rate=learning_rate,
    optimizer=optimizer,
    loss=loss,
    accuracy=accuracy,
    class_weights=class_weights,

    early_stop_steps=early_stop_steps,
    use_pretreining_imagent=use_pretreining_imagent,
    freeze=freeze,

    # Outputs
    output_identifier=base_name,
    weights_output=path_save_weights,
)

save_pickle(params, path_save_input_params)

print('Input parameters saved at: {}'.format(path_save_input_params))

# //-------------------------------------------------------------\\

In [None]:
# Training the model

# //-------------------------------------------------------------\\
# Wall-clock timestamps around the fit, handed to need_time for reporting.
# steps_per_epoch is not passed; presumably Keras takes it from
# len(train_generator) — confirm.
train_start = time.time()

hist = model.fit_generator(
    generator=train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

train_end = time.time()
need_time(train_start, train_end)
# //-------------------------------------------------------------\\

In [None]:
# Plotting the metrics from training process

# //-------------------------------------------------------------\\

history = hist.history

# The Dice keys only exist when the model was compiled with the dice_coeff
# metric (accuracy == 'dice_coeff'). With accuracy = '[]' they are absent, so
# the Dice panel is drawn only when both keys are present.
show_dice = 'dice_coeff' in history and 'val_dice_coeff' in history
n_panels = 2 if show_dice else 1

fig, ax = plt.subplots(1, n_panels, squeeze=False)
ax = ax[0]  # 1-D array of axes, as with the original 1x2 layout
fig.set_size_inches(6 * n_panels, 6)

# Label the loss panel with the configured loss rather than a fixed name.
# Keras History records one value per epoch.
ax[0].plot(history['loss'], color='blue', label='Training')
ax[0].plot(history['val_loss'], color='red', label='Validation')
ax[0].set_ylabel('Loss ({})'.format(loss))
ax[0].set_xlabel('Epoch')
ax[0].legend(loc='best')

if show_dice:
    ax[1].plot(history['dice_coeff'], color='blue', label='Training')
    ax[1].plot(history['val_dice_coeff'], color='red', label='Validation')
    ax[1].set_ylabel('Dice coefficient')
    ax[1].set_xlabel('Epoch')
    ax[1].legend(loc='best')

plt.show()

# //-------------------------------------------------------------\\