## Import Libraries

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import cv2
import os

import skimage.transform
import datetime as dt

from sklearn.model_selection import train_test_split

# Record the wall-clock start time so the total run time can be printed
# once the notebook has finished.
start = dt.datetime.now()

## Initialize

In [None]:
# Spatial size (in pixels) that every frame and label mask is resized to,
# and the number of colour channels the network input expects.
img_height = 320
img_width = 320
img_channels = 3

# Training schedule.
batch_size = 12
epochs = 2

# Optional path to previously saved weights; None means nothing is loaded.
prevCheckpoint = None

# Root directory under which the dataset files are looked up.
root = r'../input'


In [None]:
# Load the ground-truth list file. Fields are whitespace separated; column 0
# is used below as the frame path and column 1 as the label-mask path.
# `sep=r'\s+'` is the documented replacement for the deprecated
# `delim_whitespace=True` and parses the file identically.
df = pd.read_csv(os.path.join(root, 'culanelist/list/train_gt.txt'), sep=r'\s+', header=None)

# This example only trains on the driver_161_90 frames of CULane.
df = df[df[0].str.contains('driver_161_90')].reset_index(drop=True)

# Rewrite the list-relative paths so they point at the local dataset folders.
df[0] = df[0].replace({'/driver_161_90frame': os.path.join(root, 'culane/driver_161_90frame')}, regex=True)
df[1] = df[1].replace({'/laneseg_label_w16/driver_161_90frame': os.path.join(root, 'culane/driver_161_90frame_labels')}, regex=True)

# Hold out 10% of the rows for validation; the fixed seed keeps the split
# reproducible between runs.
train_df, valid_df = train_test_split(df, test_size=0.1, random_state=18)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)

In [None]:
# Pull the frame-path column (0) and label-path column (1) out of each split
# as independent copies; the generators index into these by row number.
train_dir, train_dir_gt = train_df[0].copy(), train_df[1].copy()
valid_dir, valid_dir_gt = valid_df[0].copy(), valid_df[1].copy()

# Sample counts drive batch slicing and the steps-per-epoch computation.
n_train_samples = len(train_dir)
n_valid_samples = len(valid_dir)

print('Training samples: {}'.format(n_train_samples))
print('Valid samples: {}'.format(n_valid_samples))

## Generator

In [None]:
# Train data generator

def train_generator():
    """Yield training batches forever.

    Walks the training paths in order, `batch_size` rows at a time, and
    restarts from the beginning once the list is exhausted. The last batch of
    a pass may be smaller than `batch_size`.

    Yields:
        tuple: ``(x_batch, y_batch)`` where ``x_batch`` stacks the frames
        resized to ``(img_height, img_width)`` and ``y_batch`` stacks the
        label masks resized to ``(img_height, img_width, 1)`` and one-hot
        encoded into 5 classes.

    Raises:
        OSError: if a frame or label file cannot be read by ``cv2.imread``.
    """
    while True:
        for batch_start in range(0, n_train_samples, batch_size):
            batch_end = min(batch_start + batch_size, n_train_samples)
            x_batch = []
            y_batch = []

            for idx in range(batch_start, batch_end):
                # cv2.imread signals failure by returning None instead of
                # raising, so check explicitly and report the offending path.
                frame = cv2.imread(train_dir[idx])
                if frame is None:
                    raise OSError('Could not read training image: {}'.format(train_dir[idx]))

                # order=0 (nearest neighbour) with preserve_range keeps the
                # original value range instead of rescaling to [0, 1].
                frame = skimage.transform.resize(frame, (img_height, img_width), preserve_range=True,
                                                 anti_aliasing=False, order=0)
                x_batch.append(frame)

                mask = cv2.imread(train_dir_gt[idx])
                if mask is None:
                    raise OSError('Could not read training label: {}'.format(train_dir_gt[idx]))

                # Nearest-neighbour resizing is required here so that class
                # ids are never blended into non-existent labels.
                mask = skimage.transform.resize(mask, (img_height, img_width, 1), preserve_range=True,
                                                anti_aliasing=False, order=0)
                y_batch.append(mask)

            # to_categorical already returns a NumPy array, so no extra
            # conversion is needed for the labels.
            y_batch = tf.keras.utils.to_categorical(y_batch, num_classes=5)

            yield (np.array(x_batch), y_batch)

In [None]:
def valid_generator():
    """Yield validation batches forever.

    Walks the validation paths in order, `batch_size` rows at a time, and
    restarts from the beginning once the list is exhausted. The last batch of
    a pass may be smaller than `batch_size`.

    Yields:
        tuple: ``(x_batch, y_batch)`` where ``x_batch`` stacks the frames
        resized to ``(img_height, img_width)`` and ``y_batch`` stacks the
        label masks resized to ``(img_height, img_width, 1)`` and one-hot
        encoded into 5 classes.

    Raises:
        OSError: if a frame or label file cannot be read by ``cv2.imread``.
    """
    while True:
        for batch_start in range(0, n_valid_samples, batch_size):
            batch_end = min(batch_start + batch_size, n_valid_samples)
            x_batch = []
            y_batch = []

            for idx in range(batch_start, batch_end):
                # cv2.imread signals failure by returning None instead of
                # raising, so check explicitly and report the offending path.
                frame = cv2.imread(valid_dir[idx])
                if frame is None:
                    raise OSError('Could not read validation image: {}'.format(valid_dir[idx]))

                # order=0 (nearest neighbour) with preserve_range keeps the
                # original value range instead of rescaling to [0, 1].
                frame = skimage.transform.resize(frame, (img_height, img_width), preserve_range=True,
                                                 anti_aliasing=False, order=0)
                x_batch.append(frame)

                mask = cv2.imread(valid_dir_gt[idx])
                if mask is None:
                    raise OSError('Could not read validation label: {}'.format(valid_dir_gt[idx]))

                # Nearest-neighbour resizing is required here so that class
                # ids are never blended into non-existent labels.
                mask = skimage.transform.resize(mask, (img_height, img_width, 1), preserve_range=True,
                                                anti_aliasing=False, order=0)
                y_batch.append(mask)

            # to_categorical already returns a NumPy array, so no extra
            # conversion is needed for the labels.
            y_batch = tf.keras.utils.to_categorical(y_batch, num_classes=5)

            yield (np.array(x_batch), y_batch)


## Metrics and Loss Function

In [None]:
def precision(y_true, y_pred, numLabels=5):
    """Return the soft precision of each class over the whole batch.

    "Soft" means the raw prediction scores are used as weights instead of
    thresholded labels.

    Args:
        y_true: 4-D one-hot ground truth; the permutation below assumes the
            class axis is last, i.e. (batch, height, width, classes).
        y_pred: Predicted scores with the same layout as ``y_true``.
        numLabels: Unused; kept so the signature matches the other metrics.

    Returns:
        1-D tensor with one precision value per class.
    """
    # Bring the class axis to the front so that flattening leaves one row
    # per class containing every pixel of every sample.
    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))

    y_true_current = K.batch_flatten(y_true)
    y_pred_current = K.batch_flatten(y_pred)

    true_pos = K.sum(y_true_current * y_pred_current, 1)
    false_pos = K.sum((1 - y_true_current) * y_pred_current, 1)

    # epsilon keeps the ratio finite (0 instead of NaN) when a class has no
    # predicted mass at all.
    return true_pos / (true_pos + false_pos + K.epsilon())

def recall(y_true, y_pred, numLabels=5):
    """Return the soft recall of each class over the whole batch.

    "Soft" means the raw prediction scores are used as weights instead of
    thresholded labels.

    Args:
        y_true: 4-D one-hot ground truth; the permutation below assumes the
            class axis is last, i.e. (batch, height, width, classes).
        y_pred: Predicted scores with the same layout as ``y_true``.
        numLabels: Unused; kept so the signature matches the other metrics.

    Returns:
        1-D tensor with one recall value per class.
    """
    # Bring the class axis to the front so that flattening leaves one row
    # per class containing every pixel of every sample.
    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))

    y_true_current = K.batch_flatten(y_true)
    y_pred_current = K.batch_flatten(y_pred)

    true_pos = K.sum(y_true_current * y_pred_current, 1)
    false_neg = K.sum(y_true_current * (1 - y_pred_current), 1)

    # true_pos + false_neg equals the number of ground-truth pixels of the
    # class; it is exactly 0 when the class is absent from the batch, so
    # epsilon is needed to return 0 instead of NaN.
    return true_pos / (true_pos + false_neg + K.epsilon())


def f1_score(y_true, y_pred, numLabels=5):
    """Return the F1 score built from class-averaged precision and recall.

    The per-class precision and recall values are summed and divided by
    ``numLabels``; the harmonic mean of those two averages is returned.

    Args:
        y_true: Ground-truth tensor forwarded to ``precision`` and ``recall``.
        y_pred: Prediction tensor forwarded to ``precision`` and ``recall``.
        numLabels: Divisor used to average the per-class values.

    Returns:
        Scalar tensor holding the F1 score.
    """
    mean_precision = K.sum(precision(y_true, y_pred)) / numLabels
    mean_recall = K.sum(recall(y_true, y_pred)) / numLabels

    # epsilon guards the division when both averages are zero.
    ratio = (mean_precision * mean_recall) / (mean_precision + mean_recall + K.epsilon())
    return 2 * ratio

In [None]:
# Reference: https://github.com/nabsabraham/focal-tversky-unet/issues/3
def class_tversky(y_true, y_pred):
    """Return the Tversky index of each class over the whole batch.

    Args:
        y_true: 4-D one-hot ground truth; the permutation below assumes the
            class axis is last, i.e. (batch, height, width, classes).
        y_pred: Predicted scores with the same layout as ``y_true``.

    Returns:
        1-D tensor with one Tversky index per class.
    """
    smooth = 1      # additive smoothing on numerator and denominator
    alpha = 0.25    # weight of false negatives; false positives get 1 - alpha

    # Class axis first, then flatten: one row per class.
    truth = K.batch_flatten(K.permute_dimensions(y_true, (3, 1, 2, 0)))
    pred = K.batch_flatten(K.permute_dimensions(y_pred, (3, 1, 2, 0)))

    tp = K.sum(truth * pred, 1)
    fn = K.sum(truth * (1 - pred), 1)
    fp = K.sum((1 - truth) * pred, 1)

    numerator = tp + smooth
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return numerator / denominator

# Channel-sensitive loss function.
def focal_tversky_loss_c(y_true, y_pred):
    """Return the focal Tversky loss summed over all classes.

    Each class contributes ``(1 - tversky_index) ** gamma`` with
    ``gamma = 0.75``; the contributions are added into a single scalar.
    """
    gamma = 0.75
    tversky_per_class = class_tversky(y_true, y_pred)
    focal_terms = K.pow(1 - tversky_per_class, gamma)
    return K.sum(focal_terms)

## UNet Model

In [None]:
def Conv2D_layer(prev_layer, f, k=(3, 3), activation='relu', kernel_initializer='he_normal', padding='same'):
    """Apply a 2-D convolution to ``prev_layer`` and return its output.

    Args:
        prev_layer: Tensor the convolution is applied to.
        f: Number of filters.
        k: Kernel size.
        activation: Activation passed to the layer.
        kernel_initializer: Initializer for the kernel weights.
        padding: Padding mode passed to the layer.
    """
    conv = layers.Conv2D(
        filters=f,
        kernel_size=k,
        activation=activation,
        kernel_initializer=kernel_initializer,
        padding=padding,
    )
    return conv(prev_layer)

def down_block(in_layer, f, dropout=None, pool=True):
    """Build one encoder stage: conv -> (dropout) -> conv -> batch norm.

    Args:
        in_layer: Input tensor of the stage.
        f: Number of filters used by both convolutions.
        dropout: Rate for the spatial dropout inserted between the two
            convolutions, or None to skip it.
        pool: Whether to also return a 2x2 max-pooled version of the output.

    Returns:
        ``(c, p)`` when ``pool`` is truthy, where ``c`` is the stage output
        (used for skip connections) and ``p`` its pooled version; otherwise
        just ``c``.
    """
    c = Conv2D_layer(prev_layer=in_layer, f=f)

    if dropout is not None:
        c = layers.SpatialDropout2D(dropout)(c)

    c = Conv2D_layer(prev_layer=c, f=f)
    c = layers.BatchNormalization()(c)

    # Stage without pooling: nothing to downsample.
    if not pool:
        return c

    p = layers.MaxPooling2D((2, 2))(c)
    return c, p

def up_block(in_layer, concat_layer, f, dropout=None):
    """Build one decoder stage: upsample, merge the skip, then convolve.

    Args:
        in_layer: Tensor coming from the previous (coarser) stage.
        concat_layer: Skip-connection tensor concatenated after upsampling.
        f: Number of filters for the transposed and regular convolutions.
        dropout: Rate for the spatial dropout inserted between the two
            convolutions, or None to skip it.

    Returns:
        Output tensor of the stage.
    """
    # The stride-2 transposed convolution doubles the spatial resolution so
    # the result can be concatenated with the skip tensor.
    u = layers.Conv2DTranspose(f, (2, 2), strides=(2, 2), padding='same')(in_layer)
    u = layers.concatenate([u, concat_layer])
    c = Conv2D_layer(prev_layer=u, f=f)

    if dropout is not None:
        # Use the caller-supplied rate (previously hard-coded to 0.4, which
        # silently ignored the argument).
        c = layers.SpatialDropout2D(dropout)(c)

    c = Conv2D_layer(prev_layer=c, f=f)
    c = layers.BatchNormalization()(c)

    return c

In [None]:

# Network input: one frame at the configured training resolution.
inputs = layers.Input((img_height, img_width, img_channels))

# Divide the input by 255 inside the graph so callers can feed raw pixel values.
s = layers.Lambda(lambda x: x / 255)(inputs)

# ---- Encoder: filters double at every stage while pooling halves the size.
c1, p1 = down_block(s, 16, dropout=0.1)
c2, p2 = down_block(p1, 32, dropout=0.1)
c3, p3 = down_block(p2, 64, dropout=0.2)
c4, p4 = down_block(p3, 128, dropout=0.3)
c5, p5 = down_block(p4, 256, dropout=0.4)

# ---- Bottleneck: last stage, no pooling.
c6 = down_block(p5, 512, dropout=0.5, pool=False)

# ---- Decoder: each stage upsamples and merges the matching encoder output.
c7 = up_block(c6, c5, 256, dropout=0.4)
c8 = up_block(c7, c4, 128, dropout=0.3)
c9 = up_block(c8, c3, 64, dropout=0.2)
c10 = up_block(c9, c2, 32, dropout=0.1)
c11 = up_block(c10, c1, 16, dropout=0.1)

# 1x1 convolution producing 5 output channels with a sigmoid activation.
output = Conv2D_layer(c11, 5, k=(1, 1), activation='sigmoid')

model = tf.keras.Model(inputs=[inputs], outputs=[output])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5, clipnorm=1.0),
    loss=focal_tversky_loss_c,
    metrics=[tf.keras.metrics.CategoricalAccuracy(), precision, recall, f1_score],
)
model.summary()

## Training

In [None]:
# Model checkpoint: with save_best_only=True the file is only overwritten
# when the monitored quantity (left at the callback's default) improves.
checkpointer = tf.keras.callbacks.ModelCheckpoint(
    'model_for_CuLane_UNET.h5',
    save_best_only=True,
    verbose=1,
)

# Callbacks handed to model.fit: early stopping on validation loss, a CSV
# log that is appended to across runs, and the checkpoint writer above.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5),
    tf.keras.callbacks.CSVLogger("history.csv", append=True),
    checkpointer,
]

In [None]:
# Optionally resume from earlier weights: they are loaded into the model only
# when prevCheckpoint has been set to a path.
if prevCheckpoint is not None:
    model.load_weights(prevCheckpoint)

In [None]:
# Train from the endless generators. Because they never terminate, the number
# of batches per epoch must be given explicitly; integer division means only
# full batches are counted.
model.fit(
    train_generator(),
    validation_data=valid_generator(),
    epochs=epochs,
    steps_per_epoch=n_train_samples // batch_size,
    validation_steps=n_valid_samples // batch_size,
    callbacks=callbacks,
)

In [None]:
# Save the model under a fixed file name (the name does not follow `epochs`).
model.save('Unet_1_epoch.h5')

In [None]:
# Report the wall-clock time elapsed since `start` was recorded.
end = dt.datetime.now()
elapsed = end - start
print('Total time: ' + str(elapsed))