In [None]:
# load all necessary packages
import os
import random
import sys

import cv2
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

sys.path.append('../modeling')
import segmentation_models as sm
import tensorflow
from train import build_unet

# Select tf.keras as the backend for segmentation_models; its losses and
# metrics are used further down together with a tensorflow.keras optimizer.
sm.set_framework('tf.keras')
# Last expression of the cell, so the notebook displays its return value
# (presumably the name of the active framework -- confirm against the sm docs).
sm.framework()


In [None]:
#Define a function to perform additional preprocessing after datagen.
#For example, scale images, convert masks to categorical, etc. 
def preprocess_data(img, mask, num_class):
    """Scale an image batch and one-hot encode the matching mask batch.

    Args:
        img: Image batch; every value is divided by 255 (assumes 8-bit
            intensities in 0..255 -- TODO confirm for other inputs).
        mask: Mask batch holding one class label per pixel.  Integer-valued
            labels that already lie in ``[0, num_class)`` are used directly as
            class indices, so a class keeps its channel even when some other
            class is absent from the batch.
        num_class: Number of segmentation classes (size of the one-hot axis).

    Returns:
        Tuple ``(img, mask)``: the scaled images and the float32 one-hot mask.
        A trailing singleton channel axis of the mask is replaced by the
        class axis, e.g. ``(n, h, w, 1) -> (n, h, w, num_class)``.

    Raises:
        IndexError: If the fallback encoding (see below) finds more than
            ``num_class`` distinct values in the batch.

    Note:
        Masks whose values are not valid class indices (e.g. 0/85/170/255)
        fall back to rank encoding of the values present in the batch.  That
        fallback is batch dependent: it is only correct when every batch
        contains every class.
    """
    # Scale images (this could also be done inside ImageDataGenerator).
    img = img / 255.

    mask = np.asarray(mask)
    labels = mask.astype(np.int64)
    already_class_indices = mask.size == 0 or (
        np.array_equal(labels, mask)
        and labels.min() >= 0
        and labels.max() < num_class
    )
    if not already_class_indices:
        # Rank-encode the sorted distinct values, as a per-batch label
        # encoder would do.
        _, inverse = np.unique(mask, return_inverse=True)
        labels = inverse.reshape(mask.shape)

    # A single-channel mask carries one label per pixel: drop that axis so the
    # class axis takes its place.
    if labels.ndim > 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]

    # Row k of the identity matrix is the one-hot vector of class k.
    mask = np.eye(num_class, dtype=np.float32)[labels]

    return (img, mask)

#Define the generator.
#We are not doing any rotation or zoom to make sure mask values are not interpolated.
#It is important to keep pixel values in mask as 0, 1, 2, 3, .....
def trainGenerator(train_img_path, train_mask_path, num_class,
                   batch_size=2, seed=24, target_size=(512, 512)):
    """Yield preprocessed (image, mask) batches read from two directories.

    Images and masks are read as grayscale by two separate
    ``ImageDataGenerator`` flows that share the same flip settings and the
    same ``seed``, with the aim of applying identical flips to both.  Each
    batch pair is passed through ``preprocess_data`` before it is yielded.

    Args:
        train_img_path: Directory passed to ``flow_from_directory`` for images.
        train_mask_path: Directory passed to ``flow_from_directory`` for masks.
        num_class: Number of classes forwarded to ``preprocess_data``.
        batch_size: Samples per yielded batch.  It is a parameter rather than
            a notebook global so the generator can be consumed before the
            configuration cell has run; pass the configured value explicitly
            if it differs from the default.
        seed: Random seed shared by the image and the mask flow.
        target_size: ``(height, width)`` every image and mask is resized to.

    Yields:
        Tuple ``(img, mask)`` as returned by ``preprocess_data``.
    """
    img_data_gen_args = dict(horizontal_flip=True,
                             vertical_flip=True,
                             fill_mode='reflect')

    image_datagen = ImageDataGenerator(**img_data_gen_args)
    mask_datagen = ImageDataGenerator(**img_data_gen_args)

    # Both flows must be configured identically (seed included), otherwise
    # images and masks would be shuffled and flipped differently.
    flow_args = dict(class_mode=None,
                     color_mode='grayscale',
                     target_size=target_size,
                     batch_size=batch_size,
                     seed=seed)

    image_generator = image_datagen.flow_from_directory(train_img_path, **flow_args)
    mask_generator = mask_datagen.flow_from_directory(train_mask_path, **flow_args)

    for img, mask in zip(image_generator, mask_generator):
        yield preprocess_data(img, mask, num_class)

In [None]:
# Training split: image/mask directories and the batch generator built on them.
train_img_path, train_mask_path = (
    '../data/data_train/train/images/',
    '../data/data_train/train/masks/',
)
train_img_gen = trainGenerator(train_img_path, train_mask_path, num_class=4)

# Validation split: same generator function, pointed at the val directories.
val_img_path, val_mask_path = (
    '../data/data_train/val/images/',
    '../data/data_train/val/masks/',
)
val_img_gen = trainGenerator(val_img_path, val_mask_path, num_class=4)

In [None]:
# Pull one batch from the training generator.
x, y = next(train_img_gen)

In [None]:
# Pull one batch from the validation generator (rebinds x and y).
x, y = next(val_img_gen)

In [None]:
# Show the image-batch shape and the mask-batch shape, one per line.
print(x.shape, y.shape, sep='\n')

In [None]:
# Sanity check: the generator runs and each image lines up with its mask.
x, y = next(train_img_gen)

for i in range(1):
    image = x[i, :, :, 0]
    # Collapse the one-hot class axis back to a class-index map for display.
    mask = y[i].argmax(axis=2)
    plt.subplot(121)
    plt.imshow(image, cmap='gray')
    plt.subplot(122)
    plt.imshow(mask, cmap='gray')
    plt.show()

In [None]:
# Same visual check for one batch of the validation generator.
x_val, y_val = next(val_img_gen)

for i in range(1):
    image = x_val[i, :, :, 0]
    # Collapse the one-hot class axis back to a class-index map for display.
    mask = y_val[i].argmax(axis=2)
    plt.subplot(121)
    plt.imshow(image, cmap='gray')
    plt.subplot(122)
    plt.imshow(mask, cmap='gray')
    plt.show()

In [None]:
# Training configuration: hyper-parameters, step counts, input shape,
# optimizer, loss and metrics.
seed = 24
batch_size = 2
n_classes = 4
epochs = 5
LR = 0.005  # default value: 0.001

# Batches per epoch, derived from the number of directory entries per split.
num_train_imgs = len(os.listdir('../data/data_train/train/images/train'))
num_val_images = len(os.listdir('../data/data_train/val/images/val/'))
steps_per_epoch = num_train_imgs // batch_size
val_steps_per_epoch = num_val_images // batch_size

# The network input shape is read from `x`, a sample batch drawn from the
# generators in an earlier cell.
IMG_HEIGHT = x.shape[1]
IMG_WIDTH = x.shape[2]
IMG_CHANNELS = x.shape[3]
input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

optim = tensorflow.keras.optimizers.Adam(LR)
model_type = 'StdUnet'

# Segmentation models losses can be combined together by '+' and scaled by
# an integer or float factor.
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)
loss_name = 'diceplusfocal'

# Segmentation models metrics reported during training.  This is the only
# assignment to `metrics`; a MeanIoU list that used to be assigned first was
# overwritten before it was ever used.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [None]:
# build_unet is imported from ../modeling/train.py; it receives the input
# shape and the number of classes.
model = build_unet(input_shape, n_classes)
# The loss and the optimizer (and with it the learning rate) are chosen here.
# Metrics go in as ONE flat list: Keras documents nested lists as the
# per-output form for multi-output models, and this model is compiled with a
# single loss.
model.compile(optimizer=optim, loss=total_loss, metrics=['accuracy', *metrics])
model.summary()

In [None]:
# Early stopping driven by the validation IoU score.
early_stopping = EarlyStopping(
    monitor='val_iou_score',  # quantity to watch
    mode='max',               # larger values count as improvement
    baseline=0.5,
    min_delta=0.0001,         # smaller absolute changes count as no improvement
    patience=5,               # epochs without improvement before stopping
    verbose=1,
)

# Checkpointing: keep only the best full model (not just the weights),
# judged by the same validation IoU score and checked once per epoch.
model_checkpoint = ModelCheckpoint(
    filepath='../models/checkpoints/',
    monitor='val_iou_score',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    save_freq='epoch',
    options=None,
    initial_value_threshold=None,
    verbose=1,
)

In [None]:
# Train from the generators; both callbacks are evaluated during training.
history = model.fit(
    train_img_gen,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_img_gen,
    validation_steps=val_steps_per_epoch,
    callbacks=[early_stopping, model_checkpoint],
    verbose=1,
)

In [None]:
# Save the trained model for future use.
# File name scheme: modeltype_lossfunctions_nrofepochs_batchsize_lr.hdf5
# The name is built from the configuration values model_type, loss_name,
# epochs, batch_size and LR as they are bound when this cell runs.
model_name = f'{model_type}_{loss_name}_epochs{epochs}_batchsize{batch_size}_learningrate{LR}'
# Written into the ../models directory, relative to the notebook's working directory.
model.save(f'../models/{model_name}.hdf5')

In [None]:
# Plot the training and validation loss and accuracy for each epoch.
loss = history.history['loss']
val_loss = history.history['val_loss']
# Dedicated name for the x axis: rebinding `epochs` here would replace the
# integer setting that the model file name is built from.
epoch_numbers = range(1, len(loss) + 1)

plt.plot(epoch_numbers, loss, 'y', label='Training loss')
plt.plot(epoch_numbers, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

plt.plot(epoch_numbers, acc, 'y', label='Training Accuracy')
plt.plot(epoch_numbers, val_acc, 'r', label='Validation Accuracy')
plt.title('Training and validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
# Save BEFORE plt.show(): once the figure has been shown (notably with the
# inline notebook backend) it is no longer the current figure, so saving
# afterwards writes an empty image.
plt.savefig(f'../models/{model_name}.jpg', dpi=150)
plt.show()