# Imports

In [None]:
import sys
#sys.path.append(r'D:\Programming\3rd_party\keras')

In [None]:
import sys
from imp import reload
import numpy as np
import keras

from keras.models import Model, load_model
from keras.layers import Input,Dropout,BatchNormalization,Activation,Add
from keras.layers.core import Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from keras import backend as K

import tensorflow as tf

# Load data

In [None]:
import load_data
load_data = reload(load_data)
import keras_unet_divrikwicky_model
keras_unet_divrikwicky_model = reload(keras_unet_divrikwicky_model)

In [None]:
# Forwarded unchanged to load_data.LoadData below; 0 is the "off" setting.
DEV_MODE_RANGE = 0 # off

In [None]:
# Load the training set through the project-local load_data module.
# to_gray=False -- presumably keeps the images multi-channel (the generator
# below allocates 3-channel batches); verify against load_data.LoadData.
train_df = load_data.LoadData(train_data = True, DEV_MODE_RANGE = DEV_MODE_RANGE, to_gray = False)

In [None]:
# Sanity check: display the array shape of the first loaded image.
train_df.images[0].shape

In [None]:
# Fold number passed to load_data.SplitTrainData and embedded in the model
# file names built further down.
test_fold_no = 1

In [None]:
# Split the loaded data into training and validation images/masks for the
# selected fold, then display the four shapes as a sanity check.
train_images, train_masks, validate_images, validate_masks = load_data.SplitTrainData(train_df, test_fold_no)
train_images.shape, train_masks.shape, validate_images.shape, validate_masks.shape

# Reproducibility setup:

In [None]:
# Seed every random-number source used in this notebook (Python `random`,
# NumPy, TensorFlow) and register an explicit TF session with Keras.
import random as rn

kSeed = 241075  # single seed shared by all generators seeded below

import os
# NOTE(review): hash randomization is normally fixed when the interpreter
# starts, so setting this from inside an already-running kernel may not affect
# the current process -- confirm, or export PYTHONHASHSEED before launch.
os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(kSeed)
rn.seed(kSeed)

# The single-threaded session config is left disabled here, so the session
# is created with TensorFlow's default threading configuration.
#session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
tf.set_random_seed(kSeed)
#sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
sess = tf.Session(graph=tf.get_default_graph())
# Make Keras use the session created above.
K.set_session(sess)

# IOU metric

In [None]:
# IoU cut-offs from 0.5 to 0.95 in steps of 0.05; iou_metric scores a
# prediction as the fraction of these cut-offs that its IoU reaches.
thresholds = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])

def iou(img_true, img_pred):
    """Return the intersection-over-union of two masks.

    A pixel counts as "set" when its value is greater than zero.  The
    ground truth must be either a 3-D array whose last axis has length 1,
    or a 2-D array whose last axis has a length other than 1; anything
    else trips the assertion.

    Returns 0 when the union is empty, otherwise intersection / union.
    """
    rank = len(img_true.shape)
    single_channel = img_true.shape[-1] == 1
    assert (single_channel and rank == 3) or (not single_channel and rank == 2)

    # Positive product -> pixel set in both masks; positive sum -> set in
    # at least one of them.
    overlap = ((img_true * img_pred) > 0).sum()
    union = ((img_true + img_pred) > 0).sum()
    return union if union == 0 else overlap / union

def iou_metric(img_true, img_pred):
    """Score one prediction against its ground-truth mask.

    The prediction is binarised at 0.5 first.  If both the ground truth
    and the binarised prediction sum to zero the score is 1; otherwise it
    is the fraction of the module-level ``thresholds`` that the IoU of the
    two masks reaches.
    """
    binary_pred = img_pred > 0.5  # binarisation added by sgx 20180728
    if img_true.sum() == 0 and binary_pred.sum() == 0:
        return 1
    return (thresholds <= iou(img_true, binary_pred)).mean()

def iou_metric_batch(y_true_in, y_pred_in):
    """Return the mean of ``iou_metric`` over every sample of a batch.

    Samples are paired by position; the batch size is taken from
    ``y_true_in``.
    """
    per_sample = [
        iou_metric(y_true_in[k], y_pred_in[k])
        for k in range(len(y_true_in))
    ]
    return np.mean(per_sample)

# Keras metric adapter
def my_iou_metric(label, pred):
    """Expose ``iou_metric_batch`` as a Keras metric.

    The NumPy implementation is wrapped with ``tf.py_func`` so it can be
    evaluated inside the TensorFlow graph; the result is a float64 tensor.
    """
    return tf.py_func(iou_metric_batch, [label, pred], tf.float64)


# Data generator

In [None]:
# Normalisation statistics for the augmentation pipelines: the average of the
# per-image means and the average of the per-image standard deviations over
# the training images (note: the latter is not the pooled dataset std).
mean_val = np.mean(train_images.apply(np.mean))
mean_std = np.mean(train_images.apply(np.std))
mean_val, mean_std 

In [None]:
sys.path.insert(1, '../3rd_party/albumentations')
sys.path.insert(1, '../3rd_party/imgaug')
import albumentations

In [None]:
# Side length images are resized to at the start of each augmentation pipeline.
augmented_image_size = 101
# Side length of the random crop that is finally fed to the network.
nn_image_size = 96

In [None]:
def basic_aug(p=1.):
    """Build the light augmentation pipeline.

    Steps, in order: resize to ``augmented_image_size``, horizontal flip,
    random crop to ``nn_image_size``, then normalisation with the
    module-level ``mean_val`` / ``mean_std``.

    ``p`` is forwarded to ``albumentations.Compose``.
    """
    steps = [
        albumentations.Resize(augmented_image_size, augmented_image_size),
        albumentations.HorizontalFlip(),
        albumentations.RandomCrop(nn_image_size, nn_image_size),
        albumentations.Normalize(mean = mean_val, std = mean_std, max_pixel_value = 1.0),
    ]
    return albumentations.Compose(steps, p=p)

In [None]:
def more_aug(p=1.):
    """Build the heavier augmentation pipeline.

    Geometric part: resize to ``augmented_image_size``, random scale
    (limit 0.04), horizontal flip, random crop to ``nn_image_size``.
    The image is then normalised with ``mean_val`` / ``mean_std`` and
    finally passed through blur, a small rotation (limit 5) and random
    brightness / contrast changes.

    ``p`` is forwarded to ``albumentations.Compose``.
    """
    geometric = [
        albumentations.Resize(augmented_image_size, augmented_image_size),
        albumentations.RandomScale(0.04),
        albumentations.HorizontalFlip(),
        albumentations.RandomCrop(nn_image_size, nn_image_size),
    ]
    normalise = [
        albumentations.Normalize(mean = mean_val, std = mean_std, max_pixel_value = 1.0),
    ]
    # NOTE(review): these transforms run on the already-normalised image --
    # confirm this ordering is intended.
    after_normalise = [
        albumentations.Blur(),
        albumentations.Rotate(limit=5),
        albumentations.RandomBrightness(),
        albumentations.RandomContrast(),
    ]
    return albumentations.Compose(geometric + normalise + after_normalise, p=p)

In [None]:
import threading
class AlbuDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields augmented image/mask batches.

    Parameters
    ----------
    images, masks : indexable collections
        Looked up as ``images[i]`` / ``masks[i]`` for ``i`` in
        ``0 .. len(images) - 1``.  ``masks`` may be ``None``, in which case
        the batches carry ``None`` in place of the target array.
        NOTE(review): presumably arrays or Series with a default 0..n-1
        index -- verify against the caller.
    batch_size : int
        Maximum number of samples per batch; the final batch of an epoch is
        shorter when ``len(images)`` is not a multiple of it.
    nn_image_size : int
        Side length of the square images produced by the augmentation.
    shuffle : bool
        Re-shuffle the sample order on construction and after every epoch.
    aug_func : callable
        Zero-argument factory returning the augmentation pipeline; it is
        called once, in the constructor.
    """

    def __init__(self, images, masks, batch_size=32, nn_image_size = 96, shuffle=True, aug_func = basic_aug):
        """Store the configuration and prepare the first epoch's ordering."""
        self.images = images
        self.masks = masks
        self.batch_size = batch_size
        self.nn_image_size = nn_image_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.images))
        self.on_epoch_end()
        self.augmentation = aug_func()

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` pair."""
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]
        X, y = self.__data_generation(indexes)
        return X, y

    def on_epoch_end(self):
        """Shuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indexes):
        """Augment the samples at ``indexes`` and stack them into arrays.

        Returns ``X`` with shape ``(n, size, size, 3)`` and ``y`` with shape
        ``(n, size, size, 1)``, both float32, where ``n == len(indexes)``.
        ``y`` is ``None`` when the generator has no masks.
        """
        # Size the arrays by the actual number of samples: the last batch of
        # an epoch can be shorter than batch_size, and allocating batch_size
        # rows with np.empty would leave uninitialised rows in the output.
        n_samples = len(indexes)
        size = self.nn_image_size
        X = np.empty((n_samples, size, size, 3), dtype=np.float32)
        y = None
        if self.masks is not None:
            y = np.empty((n_samples, size, size, 1), dtype=np.float32)

        for i, index in enumerate(indexes):
            image = self.images[index]
            mask = None if self.masks is None else self.masks[index]
            aug_res = self.augmentation(image = image, mask = mask)
            X[i, ...] = aug_res['image']
            if y is not None:
                # Append the trailing channel axis to the 2-D mask.
                mask = aug_res['mask']
                y[i, ...] = mask.reshape(mask.shape[0], mask.shape[1], 1)

        return X, y

# model

In [None]:
sys.path.append('../3rd_party/segmentation_models')
import segmentation_models
segmentation_models = reload(segmentation_models)
from segmentation_models.utils import set_trainable


In [None]:
# Which architecture the selection cell below builds; also part of the model
# file names.
model_name = 'divrikwicky'
# Encoder backbone handed to segmentation_models for the FPN / Unet options.
backbone_name='inceptionv3'
# Batch size used by the validation generator.
test_batch_size = 50
# Number of input channels used in the segmentation_models input shape.
channels = 3

In [None]:
# Build the network chosen by `model_name`; `model` stays None for any other
# value.  The two segmentation_models variants use an ImageNet-pretrained
# encoder that starts out frozen.
# NOTE(review): the key 'FNN' selects the FPN architecture -- possibly a typo;
# left unchanged because model_name also feeds the model file names.
if model_name == 'FNN':
    model = segmentation_models.FPN(backbone_name=backbone_name, input_shape=(None, None, channels), encoder_weights='imagenet', freeze_encoder=True)
elif model_name == 'Unet':
    model = segmentation_models.Unet(backbone_name=backbone_name, input_shape=(None, None, channels), encoder_weights='imagenet', freeze_encoder=True)
elif model_name == 'divrikwicky':
    model = keras_unet_divrikwicky_model.CreateModel(nn_image_size)
    # The custom model has no selectable backbone, so blank the name that
    # is later embedded in the file paths.
    backbone_name = ''
else:
    model = None



In [None]:
# Both paths are derived from the same three identifying fields:
# model1_file is the stage-1 (x96) checkpoint path, model_out_file the path
# this notebook's checkpoints and CSV log are written under.
_name_fields = dict(model_name=model_name, backbone_name=backbone_name, test_fold_no=test_fold_no)
model1_file = 'models_1/{model_name}_{backbone_name}_{test_fold_no}_x96.model'.format(**_name_fields)
model_out_file = 'models_2/{model_name}_{backbone_name}_{test_fold_no}.model'.format(**_name_fields)
model_out_file

In [None]:
#model = load_model(model1_file, custom_objects={'my_iou_metric': my_iou_metric}) #, 'lavazs_loss': lavazs_loss

# Train

In [None]:
# Compile with binary cross-entropy and Adam; report accuracy and the custom
# thresholded-IoU metric (logged by Keras as `my_iou_metric` / `val_my_iou_metric`).
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc", my_iou_metric])

In [None]:
# Number of epochs in the first fit_generator call (before set_trainable).
epochs_warmup = 2
# Final epoch index for the second fit_generator call, which resumes at
# initial_epoch = epochs_warmup.
epochs = 250
# Batch size of the training generator.
batch_size = 20

In [None]:
# Training batches: shuffled every epoch, heavier `more_aug` pipeline.
train_gen = AlbuDataGenerator(train_images, train_masks, batch_size=batch_size, nn_image_size = nn_image_size,
                              shuffle=True, aug_func = more_aug)
# Validation batches: fixed order, light `basic_aug` pipeline.  Note that
# basic_aug still contains a random flip and a random crop, so validation
# inputs are not identical from one pass to the next.
val_gen = AlbuDataGenerator(validate_images, validate_masks, batch_size=test_batch_size, nn_image_size = nn_image_size,
                            shuffle=False, aug_func = basic_aug)

In [None]:
sys.path.append('../3rd_party/keras-tqdm')
from keras_tqdm import TQDMCallback, TQDMNotebookCallback

In [None]:
# All three stateful callbacks watch the validation IoU metric (higher is better).
monitor_args = dict(monitor='val_my_iou_metric', mode='max', verbose=1)

early_stopping = EarlyStopping(patience=50, **monitor_args)
model_checkpoint = ModelCheckpoint(model_out_file, save_best_only=True, **monitor_args)
reduce_lr = ReduceLROnPlateau(factor=0.2, patience=20, min_lr=0.00001, **monitor_args)

# Earlier, unused alternative kept for reference:
# def get_callbacks(filepath, patience=2):
#     es = EarlyStopping('val_loss', patience=patience, mode="min")
#     msave = ModelCheckpoint(filepath + '.hdf5', save_best_only=True)
#     csv_logger = CSVLogger(filepath+'_log.csv', separator=',', append=False)
#     return [es, msave, csv_logger]


def build_callbacks(append_log):
    """Return the callback list for one fit call.

    The early-stopping, checkpoint and LR-schedule callbacks are shared
    between calls; the progress bar and the CSV logger are created fresh.
    ``append_log`` selects whether the CSV log is appended to or overwritten.
    """
    return [
        early_stopping,
        model_checkpoint,
        reduce_lr,
        TQDMNotebookCallback(leave_inner=True),
        CSVLogger(model_out_file+'_log.csv', separator=',', append=append_log),
    ]


# Settings common to both training phases.  Validation runs three passes over
# the validation generator per epoch.
shared_fit_args = dict(
    validation_data=val_gen,
    validation_steps=len(val_gen)*3,
    workers=1,
    use_multiprocessing=False,
    verbose=0,
)

# Phase 1: warm-up epochs; starts a new CSV log.
history = model.fit_generator(train_gen,
                              epochs=epochs_warmup,
                              callbacks=build_callbacks(append_log=False),
                              **shared_fit_args)

set_trainable(model)

# Phase 2: continue from the warm-up epoch count up to `epochs`, appending to
# the same CSV log.
history = model.fit_generator(train_gen,
                              epochs=epochs,
                              initial_epoch=epochs_warmup,
                              callbacks=build_callbacks(append_log=True),
                              **shared_fit_args)