#### Notes

This notebook demonstrates how to train a Fully Convolutional Network (FCN); for more tutorials, please see Qubvel's [repo](https://github.com/qubvel/segmentation_models).

"Although the dense labels created by Fast-MSS could have been used to classify the 3-D reconstructed model directly, they were also used as training data with a deep learning semantic segmentation algorithm to produce a FCN. The major advantage of a FCN is its ability to generalize to images collected from domains that are similar to those on which it was trained. A researcher could obtain dense labels from an FCN given images collected from the same or similar habitats that it was previously trained on without having to perform any of the previous steps in the workflow (steps B-G). Thus, the objective of this workflow was not just to obtain a set of dense labels for every still image, but rather to acquire a deep learning semantic segmentation model that could create dense labels automatically for datasets collected in the future.

This study experimented with five different FCNs to understand how the size of the network affected the classification accuracy. Each FCN used an encoder from the EfficientNet series (Tan and Le, 2019) and was used to create an additional set of dense labels for every image in the dataset; these and the set created by Fast-MSS were validated and compared against the ground-truth dense labels that were manually created for the test set."

...

"For the task of semantic segmentation this study experimented with five different FCNs, all of which used the U-Net architecture and were equipped with one of the five smallest encoders within the EfficientNet family (i.e., B0 through B4, see Supplementary Information 4 for more information). All models were implemented in Python using the Segmentation Models library (Yakubovskiy, 2019).

When training the FCNs, the error was calculated using the soft-Jaccard loss function, which acted as a differentiable proxy that attempted to maximize the Intersection-over-Union metric (Berman et al., 2018). Parameters were updated via backpropagation using the Adam optimizer with an initial learning rate of 10⁻⁴, which decreased using the same settings as described before. After 20 epochs, the weights from the epoch with the lowest validation loss were archived. All deep learning models were trained on a PC equipped with a NVIDIA GTX 1080 Ti GPU and an Intel i7-8700 CPU, using the Keras deep learning framework and the Tensorflow numerical computational library; for more information see Supplementary Information 4."
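
The soft-Jaccard loss mentioned in the passage above can be illustrated with a small NumPy sketch. This is not the Segmentation Models implementation: the function name `soft_jaccard_loss` and the `eps` smoothing term are assumptions made for illustration. The idea is to treat predicted class probabilities as soft set memberships, so that intersection and union stay differentiable.

```python
import numpy as np

def soft_jaccard_loss(y_true, y_pred, eps=1e-7):
    """Illustrative soft-Jaccard loss (hypothetical helper, not the library's code).

    y_true: one-hot ground truth, shape (H, W, C)
    y_pred: predicted class probabilities, shape (H, W, C)
    """
    y_true = y_true.astype(np.float64)
    y_pred = y_pred.astype(np.float64)
    # Soft intersection and union, computed per class over all pixels
    intersection = np.sum(y_true * y_pred, axis=(0, 1))
    union = np.sum(y_true + y_pred, axis=(0, 1)) - intersection
    # eps (an assumed smoothing constant) avoids division by zero for absent classes
    iou_per_class = (intersection + eps) / (union + eps)
    # Minimizing 1 - IoU maximizes the Intersection-over-Union
    return 1.0 - np.mean(iou_per_class)

# A 2x2 one-hot mask with two classes
truth = np.eye(2)[np.array([[0, 1], [1, 0]])]
perfect = soft_jaccard_loss(truth, truth)        # prediction equals the truth
wrong = soft_jaccard_loss(truth, 1.0 - truth)    # every pixel assigned to the other class
```

With these toy inputs, a perfect prediction yields a loss close to 0 and a completely wrong one yields a loss close to 1.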

![alt text](../Figures/getting_dense_labels.png)

In [None]:
import os
import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import segmentation_models as sm
from segmentation_models.losses import *
from segmentation_models.metrics import *
sm.set_framework('keras')

import keras
import keras.backend as K
from keras.callbacks import *
from keras.utils import to_categorical
from keras.optimizers import Adam
keras.backend.set_image_data_format('channels_last')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


import imgaug.augmenters as iaa

In [None]:
# labels for each class category of interest used for Pierce et al., 2021

class_categories = {'Branching' : 0, 
                      'Fish' : 1, 
                      'Massive' : 2,
                      'Not Massive' : 3,
                      'Substrate' : 4,
                      'Target' : 5,
                      'Water' : 6}

In [None]:
path = "Data\\"

images = sorted(glob.glob(path + "images\\*.png"))
masks = sorted(glob.glob(path + "dense\\*.png"))
points = sorted(glob.glob(path + "sparse\\*.csv"))

data = pd.DataFrame(list(zip(images, masks, points)), columns = ['Images', 'Masks', 'Points'])

In [None]:
train, valid = train_test_split(data, test_size = .1)

train.reset_index(drop = True, inplace = True)
valid.reset_index(drop = True, inplace = True)

len(train), len(valid)

In [None]:
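def colorize_prediction(pred):
    # Maps each class index in a 2-D label map to an RGB color scaled to [0, 1].
    # NOTE: `cp` is assumed to be a color palette indexed by class ID that holds
    # 0-255 RGB values (e.g., a NumPy array); it is not defined in this notebook.
    colored_mask = np.zeros(shape = (pred.shape[0], pred.shape[1], 3))

    for class_id in np.unique(pred):
        colored_mask[pred == class_id] = cp[class_id] / 255.0

    return colored_mask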


# Image data generator class
class DataGenerator(keras.utils.Sequence):
    
    def __init__(self, dataframe, batch_size, augment, n_classes = len(class_categories)):
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.augment = augment
    
    # Steps per epoch    
    def __len__(self):
        return len(self.dataframe) // self.batch_size
    
    # Shuffles and resets the index at the end of each epoch
    def on_epoch_end(self):
        self.dataframe = self.dataframe.sample(frac = 1).reset_index(drop = True)
    
    # Generates one batch of data, feeds to training
    # (preprocess_input, the augmenters, and the resizers are defined in later cells)
    def __getitem__(self, index):
        
        processed_images = []
        processed_masks = []
        
        # First dataframe row that belongs to this batch
        start = index * self.batch_size
        
        for row in range(start, start + self.batch_size):

            the_image = plt.imread(self.dataframe['Images'][row])
            the_mask = plt.imread(self.dataframe['Masks'][row]).astype('uint8')
            one_hot_mask = to_categorical(the_mask, self.n_classes)
            
            if self.augment:
                
                processed_image = augs_for_images(image = the_image)
                processed_mask = augs_for_masks(image = one_hot_mask)
         
            else:
                # Resize only, no augmentations
                processed_image = resize_for_images(image = the_image)
                processed_mask = resize_for_masks(image = one_hot_mask)

            processed_images.append(preprocess_input(processed_image))
            processed_masks.append(processed_mask)

        batch_x = np.array(processed_images)
        batch_y = np.array(processed_masks)
        
        return (batch_x, batch_y)


In [None]:
# Parameters for training      
batch_size = 1
num_epochs = 100

steps_per_epoch_train = len(train) // batch_size; print(steps_per_epoch_train)
steps_per_epoch_valid = len(valid) // batch_size; print(steps_per_epoch_valid)

train_gen = DataGenerator(train, batch_size = batch_size, augment = True) 
valid_gen = DataGenerator(valid, batch_size = batch_size, augment = False)

height, width = 736, 1280 

In [None]:
# Augmentation methods
augs_for_images = iaa.Sequential([iaa.Resize(size = {'height' : height, 'width' : width}, interpolation = 'linear',
                                            random_state = 5),
                                  iaa.Fliplr(0.25, random_state = 1),
                                  iaa.Flipud(0.25, random_state = 2),
                                  iaa.Rot90([1, 2, 3, 4], True, random_state = 3)
                       ])


augs_for_masks = iaa.Sequential([iaa.Resize(size = {'height' : height, 'width' : width}, interpolation = 'nearest',
                                           random_state = 5),
                                  iaa.Fliplr(0.25, random_state = 1),
                                  iaa.Flipud(0.25, random_state = 2),
                                  iaa.Rot90([1, 2, 3, 4], True, random_state = 3)
                                ])



resize_for_images = iaa.Sequential([
     iaa.Resize(size = {'height' : height, 'width' : width}, interpolation = 'linear', random_state = 1),
])

resize_for_masks = iaa.Sequential([
     iaa.Resize(size = {'height' : height, 'width' : width}, interpolation = 'nearest', random_state = 1),
])
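
The image and mask pipelines above are kept in step by giving matching augmenters the same seed. A minimal sketch of the underlying idea, written in plain NumPy rather than imgaug: draw the random parameters once and apply them to both arrays, so the geometry cannot drift apart. The helper `augment_pair` and its `rng` argument are hypothetical names and not part of this notebook.

```python
import numpy as np

def augment_pair(image, mask, rng, flip_prob=0.25):
    """Apply the same random flips and 90-degree rotation to an image and its mask."""
    # Draw every random decision once, then reuse it for both arrays
    flip_lr = rng.random() < flip_prob
    flip_ud = rng.random() < flip_prob
    k = int(rng.integers(1, 5))  # number of 90-degree turns, 1 through 4

    def transform(arr):
        if flip_lr:
            arr = arr[:, ::-1]
        if flip_ud:
            arr = arr[::-1, :]
        # Rotate in the spatial plane only; any channel axis is left untouched
        return np.rot90(arr, k=k, axes=(0, 1))

    return transform(image), transform(mask)

rng = np.random.default_rng(0)
image = np.arange(2 * 3 * 3).reshape(2, 3, 3)
mask = image[..., 0] // 3  # label map derived from the image, so alignment can be checked
aug_image, aug_mask = augment_pair(image, mask, rng)
```

Unlike `iaa.Rot90` with `keep_size` enabled, this sketch does not resize the result back, so an odd number of turns swaps height and width.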

In [None]:
BACKBONE = 'efficientnetb0'
preprocess_input = sm.get_preprocessing(BACKBONE) 

model = sm.Unet(input_shape = (None, None, 3), 
                backbone_name = BACKBONE, 
                encoder_weights = 'noisy-student', # 'imagenet'
                activation = 'softmax', 
                classes = len(list(class_categories)),
                encoder_freeze = True,
                decoder_use_batchnorm = True)


metrics = ['accuracy', iou_score, precision, recall]

model.compile(optimizer = Adam(lr = .001), 
              loss = [cce_jaccard_loss], 
              metrics = metrics)

In [None]:
# exist_ok = True lets this cell be re-run without raising an error
os.makedirs("weights\\", exist_ok = True) 

hollabackgirl = [
                 ReduceLROnPlateau(monitor = 'val_loss', factor = .65, patience = 2, verbose = 1),
                 ModelCheckpoint(filepath = 'weights\\model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5', 
                                 monitor='val_loss', save_weights_only = True, 
                                 save_best_only = True, verbose = 1),
                ]
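
To see what `ReduceLROnPlateau(factor = .65, patience = 2)` does to the learning rate, the simplified simulation below may help. It is a sketch of the general reduce-on-plateau idea, not Keras's implementation: it ignores options such as `cooldown` and `min_lr`, and both the function `simulate_reduce_on_plateau` and the example loss values are made up for illustration.

```python
def simulate_reduce_on_plateau(val_losses, initial_lr, factor=0.65, patience=2):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best:
            # Validation loss improved: remember it and reset the counter
            best = loss
            wait = 0
        else:
            # No improvement: count the epoch, and shrink the rate once
            # `patience` epochs in a row have failed to improve
            wait += 1
            if wait >= patience:
                lr *= factor
                wait = 0
        history.append(lr)
    return history

# Hypothetical validation losses that stop improving after the second epoch
lrs = simulate_reduce_on_plateau([1.0, 0.9, 0.95, 0.96, 0.97, 0.98], initial_lr=1e-3)
```

In this toy run the rate stays at its initial value for the first three epochs, then is multiplied by 0.65 after each pair of epochs without improvement.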

In [None]:
history = model.fit_generator(generator = train_gen, 
                              steps_per_epoch = steps_per_epoch_train, 
                              epochs = num_epochs, 
                              validation_data = valid_gen,
                              validation_steps = steps_per_epoch_valid,
                              verbose = 1,
                              callbacks = hollabackgirl)

In [None]:
print(history.history.keys())

plt.figure(figsize= (10, 5))

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.plot(np.argmin(history.history["val_loss"]), 
         np.min(history.history["val_loss"]), 
         marker = "x", color = "b", label = "best model")
plt.title("Training Loss")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.show()

plt.figure(figsize= (10, 5))
plt.plot(history.history["precision"], label="precision")
plt.plot(history.history["val_precision"], label="val_precision")
plt.title("Training Precision")
plt.xlabel("Epoch #")
plt.ylabel("Precision")
plt.legend(loc="upper right")
plt.show()

plt.figure(figsize= (10, 5))
plt.plot(history.history["recall"], label="recall")
plt.plot(history.history["val_recall"], label="val_recall")
plt.title("Training Recall")
plt.xlabel("Epoch #")
plt.ylabel("Recall")
plt.legend(loc="upper right")

plt.show()

In [None]:
model.load_weights('weights\\path_to_best_weights.h5')

In [None]:
# Making predictions with the trained model

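for i in range(5):
    
    # Each batch is shaped (batch, height, width, channels); use the first sample
    images, masks = valid_gen[i]
    predictions = model.predict(images)
    
    # Collapse the class axis so both masks hold one class index per pixel
    mask = np.argmax(masks[0], axis = -1).astype("uint8")
    prediction = np.argmax(predictions[0], axis = -1).astype("uint8")
    
    plt.figure(figsize=(20, 20))
    plt.subplot(1, 3, 1)
    # The image has passed through preprocess_input, so its colors may look off
    plt.imshow(images[0])
    plt.subplot(1, 3, 2)
    plt.imshow(colorize_prediction(mask))
    plt.subplot(1, 3, 3)
    plt.imshow(colorize_prediction(prediction))
    plt.show()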
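
Beyond visual inspection, a prediction can be scored against its ground-truth mask. The sketch below computes per-class Intersection-over-Union from two integer label maps with NumPy. `per_class_iou` is a hypothetical helper, and reporting classes absent from both maps as NaN is an assumption made here, not a convention taken from the notebook.

```python
import numpy as np

def per_class_iou(y_true, y_pred, n_classes):
    """Per-class IoU between two 2-D label maps (illustrative helper)."""
    ious = np.full(n_classes, np.nan)
    for c in range(n_classes):
        true_c = y_true == c
        pred_c = y_pred == c
        union = np.logical_or(true_c, pred_c).sum()
        # Leave NaN when the class appears in neither map
        if union > 0:
            ious[c] = np.logical_and(true_c, pred_c).sum() / union
    return ious

# Toy 2x2 label maps; class 2 is absent from both
truth = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
ious = per_class_iou(truth, pred, n_classes=3)
mean_iou = np.nanmean(ious)  # average over the classes that are present
```

For the notebook's data, the same function could be called on the `mask` and `prediction` arrays from the loop above with `n_classes = len(class_categories)`.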