In [0]:
import tensorflow as tf
import numpy as np
import random
import cv2
import math
import matplotlib.pyplot as plt

In [0]:
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Input, BatchNormalization, Activation, Add
from tensorflow.keras.layers import AveragePooling2D, Flatten, Conv2DTranspose, Concatenate, MaxPool2D, SeparableConv2D

In [0]:
def _parse_image_function(example_proto):
    """Mapping function for parsing images and annotations from the tfrecord files.
       Decodes, pads, resizes and randomly flips each image/annotation pair before training.
        Args:
            example_proto: a single serialized element from the dataset
        Returns:
            image: float32 tensor of shape (192,192,3), pixel values scaled by 1/255
            annotation: float32 tensor of shape (192,192,1), pixel values not rescaled
    """

    #This describes the structure of each element in the dataset
    #('height' and 'width' are parsed but not used below; the frame size is hard-coded to 360x640)
    image_feature_description = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'mask_raw': tf.FixedLenFeature([], tf.string)
        }

    #The features are extracted to a dictionary
    feature = tf.parse_single_example(example_proto, image_feature_description)

    #Images and annotations are decoded and cast to float
    image = tf.image.decode_jpeg(feature['image_raw'])
    image = tf.cast(image, tf.float32) / 255.0
    annotation = tf.image.decode_png(feature['mask_raw'], channels=1)
    annotation = tf.cast(annotation, tf.float32) / 1.0

    #Static shapes are fixed so that the padding/resizing below know the input size
    image = tf.reshape(image, (360,640,3))
    annotation = tf.reshape(annotation, (360,640,1))

    #The 360x640 frame is placed 140 rows down inside a zero-filled 640x640 square
    image = tf.image.pad_to_bounding_box(image, 140, 0, 640, 640)
    annotation = tf.image.pad_to_bounding_box(annotation, 140, 0, 640, 640)

    #Images and Annotations are resized to (192,192)
    image = tf.image.resize(image, size=(192,192))
    annotation = tf.image.resize(annotation, size=(192,192))

    #Randomly flips images and annotations.
    #The decision must be a tensor: Dataset.map traces this function once, so a
    #Python-side random.random() would be evaluated a single time and the same
    #choice would then be applied to every element of the dataset.
    do_flip = tf.random.uniform([]) > 0.5
    image, annotation = tf.cond(
        do_flip,
        lambda: (tf.image.flip_left_right(image), tf.image.flip_left_right(annotation)),
        lambda: (image, annotation))

    return image, annotation

In [0]:
#Defines the batch size of the datasets (number of image/annotation pairs per batch);
#also used to derive steps_per_epoch when fitting the model
BATCH_SIZE = 8

In [0]:
#Decodes and maps the tfrecord files into Dataset objects for training, validation and evaluation

def _batched_pipeline(parsed_dataset, shuffle_buffer):
    """Turns a parsed dataset into an endlessly repeating stream of shuffled batches.
        Args:
            parsed_dataset: dataset of (image, annotation) pairs
            shuffle_buffer: number of elements held in the shuffle buffer
        Returns:
            the dataset shuffled, repeated, batched by BATCH_SIZE and prefetched
    """
    pipeline = parsed_dataset.shuffle(buffer_size=shuffle_buffer)
    pipeline = pipeline.repeat()
    pipeline = pipeline.batch(BATCH_SIZE)
    return pipeline.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

#NOTE(review): the shuffle buffer sizes presumably match the number of records in each file - confirm
train_dataset = tf.data.TFRecordDataset(['train.tfrecords']).map(_parse_image_function)
ds_train = _batched_pipeline(train_dataset, shuffle_buffer=584)

val_dataset = tf.data.TFRecordDataset('val.tfrecords').map(_parse_image_function)
ds_val = _batched_pipeline(val_dataset, shuffle_buffer=32)

test_dataset = tf.data.TFRecordDataset('test.tfrecords').map(_parse_image_function)
ds_test = _batched_pipeline(test_dataset, shuffle_buffer=34)

In [0]:
def segmentationModel():
    """Function to define the entire encoder-decoder structure for the model.
       The encoder is a stack of residual bottleneck blocks with three learned
       stride-2 downsampling steps, followed by parallel dilated-convolution branches.
       The decoder upsamples three times with transposed convolutions and uses one
       skip connection from the encoder.
        Args:
            None
        Returns:
            model: the (uncompiled) keras model for semantic segmentation;
                   input is (192,192,3), output is a (192,192,1) sigmoid map
    """
    
    #input layer
    #NOTE: the local name `input` shadows the Python builtin inside this function only
    input = Input(shape=(192,192,3), name='input')
    #4 pixels of zero padding on every side: 192 -> 200
    zero_pad = ZeroPadding2D((4,4))(input)
    
    ############   ENCODER   #######################
    
    #normal convolution layer
    conv_0 = Conv2D(64, (3,3), strides=(1,1), padding='same', name='conv_0')(zero_pad)
    bn_0 = BatchNormalization(axis=3, name='bn_0')(conv_0)
    actv_0 = Activation('relu')(bn_0)
    #strided 2x2 separable convolution is used for downsampling instead of a pooling layer: 200 -> 100
    pool_0 = SeparableConv2D(64, (2,2), strides=(2,2), padding='valid')(actv_0)
    
    #first convolution block
    #main path: 3x3 'valid' conv (100 -> 98), dilated 3x3 conv, 1x1 expansion to 256 channels
    conv_1_a = Conv2D(64, (3,3), strides=(1,1), padding='valid', name='conv_1_a')(pool_0)
    bn_1_a = BatchNormalization(axis=3, name='bn_1_a')(conv_1_a)
    actv_1_a = Activation('elu')(bn_1_a)
    conv_1_b = Conv2D(64, (3,3), strides=(1,1), padding='same', dilation_rate=(3,3), name='conv_1_b')(actv_1_a)
    bn_1_b = BatchNormalization(axis=3, name='bn_1_b')(conv_1_b)
    actv_1_b = Activation('elu')(bn_1_b)
    conv_1_c = Conv2D(256, (1,1), strides=(1,1), padding='valid', name='conv_1_c')(actv_1_b)
    bn_1_c = BatchNormalization(axis=3, name='bn_1_c')(conv_1_c)
    #shortcut path: also 3x3 'valid' so its output (98x98x256) matches the main path for the Add
    conv_1_s = Conv2D(256, (3,3), strides=(1,1), padding='valid', name='conv_1_s')(pool_0)
    bn_1_s = BatchNormalization(axis=3, name='bn_1_s')(conv_1_s)
    actv_1_c = Add()([bn_1_c, bn_1_s])
    actv_1_c = Activation('relu')(actv_1_c)
    #Two Identity blocks
    #each is a 1x1 -> 3x3 -> 1x1 bottleneck whose output is added to the block's unchanged input
    conv_2_a = Conv2D(64, (1,1), strides=(1,1), padding='valid', name='conv_2_a')(actv_1_c)
    bn_2_a = BatchNormalization(axis=3, name='bn_2_a')(conv_2_a)
    actv_2_a = Activation('elu')(bn_2_a)
    conv_2_b = Conv2D(64, (3,3), strides=(1,1), padding='same', name='conv_2_b')(actv_2_a)
    bn_2_b = BatchNormalization(axis=3, name='bn_2_b')(conv_2_b)
    actv_2_b = Activation('elu')(bn_2_b)
    conv_2_c = Conv2D(256, (1,1), strides=(1,1), padding='valid', name='conv_2_c')(actv_2_b)
    bn_2_c = BatchNormalization(axis=3, name='bn_2_c')(conv_2_c)
    actv_2_c = Add()([bn_2_c, actv_1_c])
    actv_2_c = Activation('relu')(actv_2_c)
    #-----------------------------------------------------------------------------------
    conv_3_a = Conv2D(64, (1,1), strides=(1,1), padding='valid', name='conv_3_a')(actv_2_c)
    bn_3_a = BatchNormalization(axis=3, name='bn_3_a')(conv_3_a)
    actv_3_a = Activation('elu')(bn_3_a)
    conv_3_b = Conv2D(64, (3,3), strides=(1,1), padding='same', name='conv_3_b')(actv_3_a)
    bn_3_b = BatchNormalization(axis=3, name='bn_3_b')(conv_3_b)
    actv_3_b = Activation('elu')(bn_3_b)
    conv_3_c = Conv2D(256, (1,1), strides=(1,1), padding='valid', name='conv_3_c')(actv_3_b)
    bn_3_c = BatchNormalization(axis=3, name='bn_3_c')(conv_3_c)
    actv_3_c = Add()([bn_3_c, actv_2_c])
    actv_3_c = Activation('relu')(actv_3_c)
    #second downsampling step: 98 -> 49
    pool_1 = SeparableConv2D(256, (2,2), strides=(2,2), padding='valid')(actv_3_c)
    #-----------------------------------------------------------------------------------
    
    #Second convolution block
    #same layout as the first block with 128/512 channels; the 'valid' convs shrink 49 -> 47
    conv_4_a = Conv2D(128, (3,3), strides=(1,1), padding='valid', name='conv_4_a')(pool_1)
    bn_4_a = BatchNormalization(axis=3, name='bn_4_a')(conv_4_a)
    actv_4_a = Activation('elu')(bn_4_a)
    conv_4_b = Conv2D(128, (3,3), strides=(1,1), padding='same', dilation_rate=(3,3), name='conv_4_b')(actv_4_a)
    bn_4_b = BatchNormalization(axis=3, name='bn_4_b')(conv_4_b)
    actv_4_b = Activation('elu')(bn_4_b)
    conv_4_c = Conv2D(512, (1,1), strides=(1,1), padding='valid', name='conv_4_c')(actv_4_b)
    bn_4_c = BatchNormalization(axis=3, name='bn_4_c')(conv_4_c)
    conv_4_s = Conv2D(512, (3,3), strides=(1,1), padding='valid', name='conv_4_s')(pool_1)
    bn_4_s = BatchNormalization(axis=3, name='bn_4_s')(conv_4_s)
    actv_4_c = Add()([bn_4_c, bn_4_s])
    actv_4_c = Activation('relu')(actv_4_c)
    #Three Identity blocks
    conv_5_a = Conv2D(128, (1,1), strides=(1,1), padding='valid', name='conv_5_a')(actv_4_c)
    bn_5_a = BatchNormalization(axis=3, name='bn_5_a')(conv_5_a)
    actv_5_a = Activation('elu')(bn_5_a)
    conv_5_b = Conv2D(128, (3,3), strides=(1,1), padding='same', name='conv_5_b')(actv_5_a)
    bn_5_b = BatchNormalization(axis=3, name='bn_5_b')(conv_5_b)
    actv_5_b = Activation('elu')(bn_5_b)
    conv_5_c = Conv2D(512, (1,1), strides=(1,1), padding='valid', name='conv_5_c')(actv_5_b)
    bn_5_c = BatchNormalization(axis=3, name='bn_5_c')(conv_5_c)
    actv_5_c = Add()([bn_5_c, actv_4_c])
    actv_5_c = Activation('relu')(actv_5_c)
    #-----------------------------------------------------------------------------------
    conv_6_a = Conv2D(128, (1,1), strides=(1,1), padding='valid', name='conv_6_a')(actv_5_c)
    bn_6_a = BatchNormalization(axis=3, name='bn_6_a')(conv_6_a)
    actv_6_a = Activation('elu')(bn_6_a)
    conv_6_b = Conv2D(128, (3,3), strides=(1,1), padding='same', name='conv_6_b')(actv_6_a)
    bn_6_b = BatchNormalization(axis=3, name='bn_6_b')(conv_6_b)
    actv_6_b = Activation('elu')(bn_6_b)
    conv_6_c = Conv2D(512, (1,1), strides=(1,1), padding='valid', name='conv_6_c')(actv_6_b)
    bn_6_c = BatchNormalization(axis=3, name='bn_6_c')(conv_6_c)
    actv_6_c = Add()([bn_6_c, actv_5_c])
    actv_6_c = Activation('relu')(actv_6_c)
    #-----------------------------------------------------------------------------------
    conv_7_a = Conv2D(128, (1,1), strides=(1,1), padding='valid', name='conv_7_a')(actv_6_c)
    bn_7_a = BatchNormalization(axis=3, name='bn_7_a')(conv_7_a)
    actv_7_a = Activation('elu')(bn_7_a)
    conv_7_b = Conv2D(128, (3,3), strides=(1,1), padding='same', name='conv_7_b')(actv_7_a)
    bn_7_b = BatchNormalization(axis=3, name='bn_7_b')(conv_7_b)
    actv_7_b = Activation('elu')(bn_7_b)
    conv_7_c = Conv2D(512, (1,1), strides=(1,1), padding='valid', name='conv_7_c')(actv_7_b)
    bn_7_c = BatchNormalization(axis=3, name='bn_7_c')(conv_7_c)
    actv_7_c = Add()([bn_7_c, actv_6_c])
    actv_7_c = Activation('relu')(actv_7_c)
    #third downsampling step: 47 -> 23 ('valid' with stride 2 drops the last row and column);
    #actv_7_c (47x47x512) is reused later as the decoder's skip connection
    pool_2 = SeparableConv2D(512, (2,2), strides=(2,2), padding='valid')(actv_7_c)
    #-----------------------------------------------------------------------------------
    
    #third convolution block
    #unlike the first two blocks every conv here uses 'same' padding, so the size stays 23x23
    conv_8_a = Conv2D(256, (3,3), strides=(1,1), padding='same', name='conv_8_a')(pool_2)
    bn_8_a = BatchNormalization(axis=3, name='bn_8_a')(conv_8_a)
    actv_8_a = Activation('elu')(bn_8_a)
    conv_8_b = Conv2D(256, (3,3), strides=(1,1), padding='same', dilation_rate=(3,3), name='conv_8_b')(actv_8_a)
    bn_8_b = BatchNormalization(axis=3, name='bn_8_b')(conv_8_b)
    actv_8_b = Activation('elu')(bn_8_b)
    conv_8_c = Conv2D(1024, (1,1), strides=(1,1), padding='valid', name='conv_8_c')(actv_8_b)
    bn_8_c = BatchNormalization(axis=3, name='bn_8_c')(conv_8_c)
    conv_8_s = Conv2D(1024, (3,3), strides=(1,1), padding='same', name='conv_8_s')(pool_2)
    bn_8_s = BatchNormalization(axis=3, name='bn_8_s')(conv_8_s)
    actv_8_c = Add()([bn_8_c, bn_8_s])
    actv_8_c = Activation('relu')(actv_8_c)
    #Two Identity blocks
    conv_9_a = Conv2D(256, (1,1), strides=(1,1), padding='valid', name='conv_9_a')(actv_8_c)
    bn_9_a = BatchNormalization(axis=3, name='bn_9_a')(conv_9_a)
    actv_9_a = Activation('elu')(bn_9_a)
    conv_9_b = Conv2D(256, (3,3), strides=(1,1), padding='same', name='conv_9_b')(actv_9_a)
    bn_9_b = BatchNormalization(axis=3, name='bn_9_b')(conv_9_b)
    actv_9_b = Activation('elu')(bn_9_b)
    conv_9_c = Conv2D(1024, (1,1), strides=(1,1), padding='valid', name='conv_9_c')(actv_9_b)
    bn_9_c = BatchNormalization(axis=3, name='bn_9_c')(conv_9_c)
    actv_9_c = Add()([bn_9_c, actv_8_c])
    actv_9_c = Activation('relu')(actv_9_c)
    #-----------------------------------------------------------------------------------
    conv_10_a = Conv2D(256, (1,1), strides=(1,1), padding='valid', name='conv_10_a')(actv_9_c)
    bn_10_a = BatchNormalization(axis=3, name='bn_10_a')(conv_10_a)
    actv_10_a = Activation('elu')(bn_10_a)
    conv_10_b = Conv2D(256, (3,3), strides=(1,1), padding='same', name='conv_10_b')(actv_10_a)
    bn_10_b = BatchNormalization(axis=3, name='bn_10_b')(conv_10_b)
    actv_10_b = Activation('elu')(bn_10_b)
    conv_10_c = Conv2D(1024, (1,1), strides=(1,1), padding='valid', name='conv_10_c')(actv_10_b)
    bn_10_c = BatchNormalization(axis=3, name='bn_10_c')(conv_10_c)
    actv_10_c = Add()([bn_10_c, actv_9_c])
    actv_10_c = Activation('relu')(actv_10_c)
    #one row/column of zeros is added at the bottom/right: 23 -> 24, so that the three
    #stride-2 upsampling steps of the decoder end at exactly 192 (24 -> 48 -> 96 -> 192)
    actv_10_c = ZeroPadding2D(((0, 1), (0, 1)))(actv_10_c)
    #-----------------------------------------------------------------------------------
    
    #Atrous Convolutions
    #four parallel branches over the same 24x24x1024 features:
    #  a: 1x1 conv followed by 3x3 average pooling (stride 1)
    #  b, c, d: 3x3 convs with dilation rates 3, 5 and 7
    conv_final_a = Conv2D(256, (1,1), strides=(1,1), padding='same', name='conv_final_a')(actv_10_c)
    conv_final_a = AveragePooling2D((3,3), strides=(1,1), padding='same')(conv_final_a)
    bn_final_a = BatchNormalization(axis=3, name='bn_final_a')(conv_final_a)
    conv_final_b = Conv2D(256, (3,3), strides=(1,1), padding='same', dilation_rate=(3,3), name='conv_final_b')(actv_10_c)
    bn_final_b = BatchNormalization(axis=3, name='bn_final_b')(conv_final_b)
    conv_final_c = Conv2D(256, (3,3), strides=(1,1), padding='same', dilation_rate=(5,5), name='conv_final_c')(actv_10_c)
    bn_final_c = BatchNormalization(axis=3, name='bn_final_c')(conv_final_c)
    conv_final_d = Conv2D(256, (3,3), strides=(1,1), padding='same', dilation_rate=(7,7), name='conv_final_d')(actv_10_c)
    bn_final_d = BatchNormalization(axis=3, name='bn_final_d')(conv_final_d)
    
    #the branches are concatenated along channels (4 x 256) and fused down to 512 channels
    actv_final = Concatenate()([bn_final_a, bn_final_b, bn_final_c, bn_final_d])
    actv_final = Conv2D(1024, (3,3), strides=(1,1), padding='same', activation='elu')(actv_final)
    actv_final = Conv2D(512, (1,1), strides=(1,1), padding='valid')(actv_final)
    actv_final = Activation('relu')(actv_final)
    
    ###############   DECODER   ####################
    
    #first upsampling step: 24 -> 48
    conv_up_a = Conv2DTranspose(512, kernel_size=(3,3), strides=(2, 2), padding='same', activation='relu')(actv_final)
    #skip connection: the 47x47 encoder features are padded to 48x48 so they can be
    #concatenated with conv_up_a (actv_7_c is rebound to the padded tensor here)
    actv_7_c = ZeroPadding2D(((0, 1), (0, 1)))(actv_7_c)
    conv_join_a = Concatenate()([actv_7_c, conv_up_a])
    conv_join_a = Conv2D(512, (3,3), strides=(1,1), padding='same', activation='elu')(conv_join_a)
    conv_join_a = Conv2D(256, (1,1), strides=(1,1), padding='same', activation='elu')(conv_join_a)
    conv_join_a = BatchNormalization(axis=3)(conv_join_a)
    
    #second upsampling step: 48 -> 96 (no skip connection at this level)
    conv_up_b = Conv2DTranspose(256, kernel_size=(3,3), strides=(2, 2), padding='same', activation='relu')(conv_join_a)
    conv_join_b = Conv2D(256, (3,3), strides=(1,1), padding='same', activation='elu')(conv_up_b)
    conv_join_b = Conv2D(128, (1,1), strides=(1,1), padding='same', activation='elu')(conv_join_b)
    conv_join_b = BatchNormalization(axis=3)(conv_join_b)
    
    #third upsampling step: 96 -> 192, then a 1-channel sigmoid gives a per-pixel value in (0,1)
    output = Conv2DTranspose(128, kernel_size=(3,3), strides=(2, 2), padding='same', activation='relu')(conv_join_b)    
    output = Conv2D(128, (3,3), strides=(1,1), padding='same', activation='relu')(output)
    output = Conv2D(1, (1,1), strides=(1,1), padding='same', activation='sigmoid')(output)
    
    model = Model(inputs=input, outputs=output, name='laneModel')    
    return model
    

In [0]:
#Build the encoder-decoder segmentation network (it is compiled in a later cell)
model = segmentationModel()

In [0]:
def step_decay(epoch):
    """Function to define the decaying of learning rate in a step-wise manner.
       Starting from 0.001, the rate is multiplied by 0.8 once for every full
       group of 2 in (1 + epoch), i.e. it drops at epochs 1, 3, 5, ...
        Args:
            epoch: the current (0-based) epoch that the model is training on
        Returns:
            lrate: the corresponding learning rate for the epoch
    """

    #The body is indented with spaces like the docstring; mixing tabs and spaces
    #inside one block makes Python reject the function definition.
    initial_lrate = 0.001
    drop = 0.8
    epochs_drop = 2
    lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
    return lrate

#Learning rate with decay defined by the above function:
#the callback calls step_decay to obtain the learning rate for each epoch
lrate = LearningRateScheduler(step_decay)

#Callback to stop training after accuracy reaches 99.5%
class myCallBack(tf.keras.callbacks.Callback):
    """Keras callback that requests early termination of training once the
       training accuracy reported at the end of an epoch is at least 0.995.
    """
    def on_epoch_end(self, epoch, logs=None):
        """Checks the epoch's training accuracy and stops training when the target is reached.
            Args:
                epoch: index of the epoch that just finished (unused)
                logs: dictionary of metric results for the epoch, may be None
            Returns:
                None
        """
        #None instead of a mutable {} default, which would be shared between calls
        logs = logs or {}
        #The accuracy metric is logged as 'acc' or 'accuracy' depending on the Keras version
        accuracy = logs.get('acc', logs.get('accuracy'))
        #A missing metric must not be compared with a float (TypeError on Python 3)
        if accuracy is not None and accuracy >= 0.995:
            self.model.stop_training = True

#Instance of the early-stopping callback defined above
callback = myCallBack()

#Final list of callbacks passed to model.fit: learning rate schedule, then the accuracy-based stop
callbacks_list = [lrate, callback]

In [0]:
#Compile the model with Adam optimizer and crossentropy loss.
#Adam is created with a learning rate of 0.0; the LearningRateScheduler in
#callbacks_list supplies the rate from step_decay for every epoch.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

#Run the model on the train dataset and check its performance on the validation dataset.
#Both datasets repeat forever, so the epoch length is set through steps_per_epoch and
#validation uses a single batch (validation_steps=1) per epoch.
history = model.fit(
    ds_train.make_one_shot_iterator(),
    steps_per_epoch=600//BATCH_SIZE,
    epochs=50,
    validation_data=ds_val.make_one_shot_iterator(),
    validation_steps=1,
    callbacks=callbacks_list,
    verbose=1,
)

In [0]:
#Visualization for model accuracy and loss

def _plot_history_curves(metric_key, title, ylabel):
    """Plots the training and validation curves of one metric over the epochs.
        Args:
            metric_key: key of the training metric in history.history;
                        the validation curve is read from 'val_' + metric_key
            title: title of the figure
            ylabel: label of the y axis
        Returns:
            None
    """
    plt.plot(history.history[metric_key])
    plt.plot(history.history['val_' + metric_key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    #The second curve comes from the validation data given to model.fit, not from the test set
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()

#The accuracy metric is stored as 'acc' or 'accuracy' depending on the Keras version
accuracy_key = 'acc' if 'acc' in history.history else 'accuracy'
_plot_history_curves(accuracy_key, 'Model accuracy', 'Accuracy')
_plot_history_curves('loss', 'Model loss', 'Loss')

In [0]:
#Testing the model on the test dataset.
#The iterator already yields batches of BATCH_SIZE, so batch_size is not passed
#(it must not be specified for dataset inputs); `steps` alone sets how many
#batches are evaluated. ds_test repeats, so 20 steps may revisit the same records.
model.evaluate(ds_test.make_one_shot_iterator(), steps=20, verbose=1)

In [0]:
#Save the trained model as an HDF5 (.h5) file; the relative path places it in the current working directory
model.save('lanes.h5')