In [1]:
# From https://idiotdeveloper.com/unet-segmentation-with-pretrained-mobilenetv2-as-encoder/
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"    
import sys
sys.path.insert(0,'../')

from src.utils.kerasDataLoader import DataGenerator
import src.utils.keras_losses as Loss

import numpy as np
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization
from tensorflow.keras.layers import UpSampling2D, Input, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.applications import MobileNetV2, mobilenet_v2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.metrics import Recall, Precision

from tensorflow.keras import backend as K
import numpy as np

In [6]:
if os.path.isfile('models/weights/mobilnetweights.h5'):
    MODELPATH = 'models/weights/mobilnetweights.h5'
else:
    if not os.path.isdir('models/weights/'):
        os.mkdir('models/weights/')
    mobilenet = tf.keras.applications.mobilenet_v2.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    mobilenet.save_weights('models/weights/mobilenetweights.h5')
    MODELPATH = 'models/weights/mobilnetweights.h5'

In [7]:
try:
    from albumentations import (
        HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
        Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
        IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
        IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose
    )

    def augmentations(p=0.5):
        return Compose([
            Flip(),
            OneOf([
                GaussNoise(),
            ], p=0.2),
            OneOf([
                MotionBlur(p=0.2),
                MedianBlur(blur_limit=3, p=0.1),
                Blur(blur_limit=3, p=0.1),
            ], p=0.2),
            OneOf([
                CLAHE(clip_limit=2),
                IAASharpen(),
                IAAEmboss(),
                RandomBrightnessContrast(),
            ], p=0.3),
            HueSaturationValue(p=0.3),
        ], p=p)
    
except:
        augmentations = None


In [10]:
def make_model(image_size, n_classes = 21, MODELPATH=MODELPATH):
    inputs = Input(shape=(*image_size, 3), name=MODELPATH)
    preproc_layer = tf.keras.layers.Lambda(mobilenet_v2.preprocess_input, name="input_image")(inputs) # Preprocessing function

    encoder = MobileNetV2(input_tensor=preproc_layer, weights=None, include_top=False, alpha=1.0)
    skip_connection_names = ["input_image", 
                             "block_1_expand_relu", 
                             "block_3_expand_relu", 
                             "block_6_expand_relu"]
    
    encoder_output = encoder.get_layer("block_13_expand_relu").output
    
    f = [16, 32, 48, 64]
    x = encoder_output
    
    for i in range(1, len(skip_connection_names)+1, 1):
        x_skip = encoder.get_layer(skip_connection_names[-i]).output
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, x_skip])
        
        x = Conv2D(f[-i], (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        
        x = Conv2D(f[-i], (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        
    x = Conv2D(n_classes, (1, 1), padding="same")(x)
    x = tf.keras.layers.Softmax(axis=-1)(x)
    
    model = Model(inputs, x)
    return model

In [11]:
model = make_model(image_size=(224,224))

opt = tf.keras.optimizers.Adam(0.001)

metrics = [Loss.dice_coef, Recall(), Precision()]
model.compile(loss=Loss.FocalLoss, optimizer=opt, metrics=metrics)

callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=False),
    ModelCheckpoint('models/kerasUnet', monitor='val_loss', verbose=1, save_best_only=True)
    
]

In [12]:
train_dataset = DataGenerator(batch_size=16, augmentation=augmentations, preprocessing=None)
valid_dataset = DataGenerator(step='valid', shuffle=False, preprocessing=None)

HBox(children=(HTML(value='Loading images'), FloatProgress(value=0.0, max=663.0), HTML(value='')))




HBox(children=(HTML(value='Loading images'), FloatProgress(value=0.0, max=83.0), HTML(value='')))




In [None]:
model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=20,
    callbacks=callbacks
)

Epoch 1/20
Epoch 00001: val_loss improved from inf to 0.03272, saving model to models/kerasUnet
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: models/kerasUnet/assets
Epoch 2/20