# Road Segmentation with FCN-8 (GPU version)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/anjithaap/Road-detection-and-segmentation/blob/master/notebooks/GPU/FCN-8.ipynb)

## Download Dataset

In [None]:
# Fetch a fresh copy of the dataset archive from Google Drive and unpack it.
!rm -rf Dataset Trained_Model                   # Remove any previous dataset and trained-model directories
!pip install -U gdown --pre >/dev/null          # Install gdown to download files from Google Drive
!gdown 1u4WJLjYrbZHwdvFOHQXJqDTtco6F5hJ-        # Download Dataset.zip by its Google Drive file ID
!unzip -q Dataset.zip; rm Dataset.zip           # Extract the dataset archive, then delete the zip

## Prepare Dataset

In [None]:
import os
import cv2
import numpy as np
from skimage.io import imread
from sklearn.model_selection import train_test_split

IMAGES_PATH = 'Dataset/Images/'
MASKS_PATH  = 'Dataset/Masks/'
TEST_PATH   = 'Dataset/Test_Images/'

# Number of images to use (Larger the number, more RAM required)
N_IMAGES = 1500

# Imread each image and save to an array

sat_imgs = os.listdir(IMAGES_PATH)
msk_imgs = os.listdir(MASKS_PATH)
sat_imgs.sort(), msk_imgs.sort()

images = []
for image in (sat_imgs[:N_IMAGES]):
    data = imread(IMAGES_PATH + image)
    images.append(data)

masks = []
for mask in (msk_imgs[:N_IMAGES]):
    data = imread(MASKS_PATH + mask)
    data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
    data = np.expand_dims(data, axis=-1)
    masks.append(data)

images = np.stack(images)
masks = np.stack(masks)


train_images, test_images, train_masks, test_masks = train_test_split(images, masks, test_size=0.2, random_state=10)
del images, masks
print("Training Set")
print(train_images.shape)
print(train_masks.shape)
print("\n")
print("Testing set")
print(test_images.shape)
print(test_masks.shape)

!rm -rf epochs Trained_Model; mkdir epochs

## Define Loss and Metric Functions

In [None]:
from keras import backend as K
def iou_coef(y_true, y_pred, smooth=1):
    """Return the batch-averaged soft IoU (Jaccard index).

    Overlap and totals are summed over axes 1-3 of every sample, so rank-4
    tensors are expected. ``smooth`` is added to both numerator and
    denominator to keep the ratio defined for empty masks. The per-sample
    scores are then averaged over axis 0.
    """
    reduce_axes = [1, 2, 3]
    overlap = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)
    combined = K.sum(y_true, reduce_axes) + K.sum(y_pred, reduce_axes)
    per_sample = (overlap + smooth) / (combined - overlap + smooth)
    return K.mean(per_sample, axis=0)

def dice_coef(y_true, y_pred, smooth = 1):
    """Return the soft Dice coefficient over the whole flattened batch.

    Both tensors are flattened, and Dice = (2*|A.B| + smooth) / (|A| + |B| + smooth).
    ``smooth`` keeps the value defined when both inputs are empty.
    """
    flat_true, flat_pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    denominator = K.sum(flat_true) + K.sum(flat_pred) + smooth
    return (2. * overlap + smooth) / denominator

def soft_dice_loss(y_true, y_pred):
    """Loss to minimise: one minus the soft Dice coefficient."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

## Define FCN-8 Neural Network

In [None]:
# Import Libraries
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, Dropout, Activation, BatchNormalization, add
from keras.models import Model

def FCN8():
    """Build an FCN-8 style binary segmentation network.

    Encoder: five convolution blocks (with BatchNormalization and Dropout),
    each ending in 2x2 max pooling. A 512x512 input therefore shrinks
    512 -> 256 -> 128 -> 64 -> 32 -> 16. Two convolutional replacements for
    the fully connected layers follow at 16x16.

    Decoder (FCN-8 skip scheme):
    - upsample 2x to 32x32 and add a 1x1 score map of the Pooling-4 output;
    - upsample 2x to 64x64 and add a 1x1 score map of the Pooling-3 output;
    - upsample three more times (stride-2 transposed convolutions) back to 512x512.

    Previously the Pooling-4 score map was computed and then overwritten by a
    max-pooled Pooling-3 map (and likewise Pooling-3 by Pooling-2), so the
    intended FCN-8 skips were never used.

    Returns:
        An uncompiled ``Model`` mapping a (512, 512, 3) input to a
        (512, 512, 1) sigmoid output.
    """
    img_input = Input(shape=(512, 512, 3))

    # Block 1 (512x512 -> 256x256)
    x = Conv2D(64, 3, activation='relu', name='Block-1_Conv-1', padding='same') (img_input)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(64, 3, activation='relu', name='Block-1_Conv-2', padding='same') (x)
    x = BatchNormalization() (x)
    x = MaxPooling2D(2, strides=2, name='Pooling-1') (x)

    # Block 2 (256x256 -> 128x128)
    x = Conv2D(128, 3, activation='relu', name='Block-2_Conv-1', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(128, 3, activation='relu', name='Block-2_Conv-2', padding='same') (x)
    x = BatchNormalization() (x)
    x = MaxPooling2D(2, strides=2, name='Pooling-2') (x)

    # Block 3 (128x128 -> 64x64)
    x = Conv2D(256, 3, activation='relu', name='Block-3_Conv-1', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(256, 3, activation='relu', name='Block-3_Conv-2', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(256, 3, activation='relu', name='Block-3_Conv-3', padding='same') (x)
    x = BatchNormalization() (x)
    x = MaxPooling2D(2, strides=2, name='Pooling-3') (x)
    # 1x1 score map of the Pooling-3 output (64x64x1), fused by the second skip.
    skip3 = Conv2D(1, 1, kernel_initializer='he_normal', name='S-3') (x)

    # Block 4 (64x64 -> 32x32)
    x = Conv2D(512, 3, activation='relu', name='Block-4_Conv-1', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(512, 3, activation='relu', name='Block-4_Conv-2', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(512, 3, activation='relu', name='Block-4_Conv-3', padding='same') (x)
    x = BatchNormalization() (x)
    x = MaxPooling2D(2, strides=2, name='Pooling-4') (x)
    # 1x1 score map of the Pooling-4 output (32x32x1), fused by the first skip.
    skip4 = Conv2D(1, 1, kernel_initializer='he_normal', name='S-4') (x)

    # Block 5 (32x32 -> 16x16)
    x = Conv2D(512, 3, activation='relu', name='Block-5_Conv-1', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(512, 3, activation='relu', name='Block-5_Conv-2', padding='same') (x)
    x = BatchNormalization() (x); x = Dropout(0.2) (x)
    x = Conv2D(512, 3, activation='relu', name='Block-5_Conv-3', padding='same') (x)
    x = BatchNormalization() (x)
    x = MaxPooling2D(2, strides=2, name='Pooling-5') (x)

    # Convolutional stand-ins for the fully connected layers (stay at 16x16).
    x = Conv2D(4096 , (7, 7) , activation='relu' , name='Fully-Connected-1', padding='same') (x)
    x = Conv2D(4096 , (1, 1) , activation='relu' , name='Fully-Connected-2', padding='same') (x)

    # Skip connections. The 1-channel score maps are broadcast across the
    # channels of x by `add`, exactly as the original wiring relied on.
    x = Conv2DTranspose(512, kernel_size=2, name='Upsample_2x', strides=2) (x)    # 16 -> 32
    add4 = add([skip4, x])

    x = Conv2DTranspose(256, kernel_size=2, name='Upsample_4x', strides=2) (add4)  # 32 -> 64
    add3 = add([skip3, x])

    # Remaining 8x upsampling: 64 -> 128 -> 256 -> 512.
    x = Conv2DTranspose(128, kernel_size=2, kernel_initializer='he_normal', name='Upsample_8x', strides=2) (add3)
    x = Conv2DTranspose( 64, kernel_size=2, kernel_initializer='he_normal', name='Upsample_16x', strides=2) (x)
    x = Conv2DTranspose( 32, kernel_size=2, kernel_initializer='he_normal', name='Upsample_32x', strides=2) (x)

    x = Conv2D(1, 1, kernel_initializer='he_normal') (x)
    # NOTE(review): dropout applied to the single output logit channel is
    # unusual (it zeroes half the logits during training) - confirm intent.
    x = Dropout(0.5) (x)

    x = (Activation('sigmoid'))(x)
    model = Model(img_input, x)
    return model

## Create and Compile Model

In [None]:
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from IPython.display import clear_output
from skimage.io import imread
import tensorflow as tf

# Where the best checkpoint is written during training.
model_path = "./Trained_Model/Road_Model.h5"

# Fresh, uncompiled FCN-8 network.
model = FCN8()

## Checkpoint callback: save the model whenever val_loss reaches a new minimum
checkpointer = ModelCheckpoint(
    model_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    verbose=1,
)



## Code block to show predictions at each epoch

def show_predictions(epoch):
    """Plot predictions of the global ``model`` for three fixed test tiles.

    Each row shows an input tile next to the predicted mask. The figure is
    saved to ``epochs/<epoch>.png`` (the ``epochs`` directory must already
    exist) and then displayed.

    Args:
        epoch: Epoch number used in the caption and in the output file name.
    """
    # matplotlib.pyplot is not imported by any cell that runs before
    # training, so the callback previously failed with NameError on `plt`.
    import matplotlib.pyplot as plt

    test_paths = [
        'Dataset/Test_Images/11740_sat.jpg',
        'Dataset/Test_Images/112348_sat.jpg',
        'Dataset/Test_Images/115172_sat.jpg',
    ]
    n_rows = len(test_paths)

    f = plt.figure(figsize = (8, 10))
    f.suptitle(f'Epoch: {epoch}', x=0.5, y=0.02)

    for row, path in enumerate(test_paths):
        # Read each tile once; it is reused for display and for prediction.
        image = imread(path)
        prediction = model.predict(np.asarray([image]), verbose=1)[0][:,:,0]

        f.add_subplot(n_rows, 2, 2 * row + 1)
        plt.imshow(image, cmap='gray')
        if row == 0:
            plt.title("Input Image")
        plt.axis('off')

        f.add_subplot(n_rows, 2, 2 * row + 2)
        plt.imshow(prediction, cmap='gray')
        if row == 0:
            plt.title("Predicted Image")
        plt.axis('off')

    plt.savefig(f'epochs/{epoch}.png')
    plt.show()


class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the prediction preview after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Keras epoch indices start at 0; the preview is numbered from 1.
        display_epoch = epoch + 1
        clear_output(wait=True)
        show_predictions(display_epoch)

In [None]:
# Training hyper-parameters.
EPOCHS = 100
LEARNING_RATE = 0.0001
BATCH_SIZE = 56

# Adam optimiser minimising the soft Dice loss; 'accuracy' is tracked as a metric.
adam = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=adam,
    loss=soft_dice_loss,
    metrics=['accuracy'],
)

### Start model training

In [None]:
# Scale the grayscale masks to [0, 1] targets for the sigmoid output.
# `train_masks / 255` would allocate a float64 array (8x the size of the
# masks, assuming uint8 masks), although Keras trains in float32 by default.
# Converting to float32 and dividing in place halves that peak memory.
train_targets = train_masks.astype(np.float32)
train_targets /= 255

history = model.fit(train_images,
                    train_targets,
                    validation_split = 0.1,
                    epochs = EPOCHS,
                    batch_size = BATCH_SIZE,
                    callbacks = [checkpointer, DisplayCallback()])

# Release the float copy once training is done.
del train_targets

## Plot training history

In [None]:
from matplotlib.pyplot import figure
# `plt` is otherwise only imported in a later cell, so without this the cell
# fails with NameError when the notebook is run top to bottom.
import matplotlib.pyplot as plt

# Side-by-side training curves: accuracy on the left, loss on the right.
history_fig = plt.figure(figsize=(20,5))

accuracy = history_fig.add_subplot(1,2,1)
accuracy.plot(history.history['accuracy'])
accuracy.plot(history.history['val_accuracy'])
accuracy.set_ylabel('Accuracy')
accuracy.set_xlabel('Epoch')
accuracy.legend(['Training Accuracy', 'Validation accuracy'], loc='upper right')
accuracy.set_title("Epoch Accuracy")

loss = history_fig.add_subplot(1,2,2)
loss.plot(history.history['loss'])
loss.plot(history.history['val_loss'])
loss.set_ylabel('Loss')
loss.set_xlabel('Epoch')
loss.legend(['Training loss', 'Validation loss'], loc='upper right')
loss.set_title("Epoch Loss")

## Get predictions from Trained Model
> Load the best checkpoint from `Trained_Model/Road_Model.h5` and visualise its predictions on the test split

In [None]:
import random
from skimage.io import imshow
from keras.models import load_model
from matplotlib import pyplot as plt


# Reload the best checkpoint written during training; the custom loss/metric
# functions must be passed so the saved model can be deserialised.
model = load_model("./Trained_Model/Road_Model.h5", custom_objects={'soft_dice_loss': soft_dice_loss, 'iou_coef': iou_coef})
predictions = model.predict(test_images, verbose=1)

# Binarise the sigmoid outputs: probability > thresh_val -> 1, else 0.
thresh_val = 0.1
predicton_threshold = (predictions > thresh_val).astype(np.uint8)

num_samples = 3

# One row per random test sample: image | ground truth | prediction |
# thresholded prediction (the 4th column used to be allocated but left empty).
f = plt.figure(figsize = (12, 10))
for row in range(num_samples):
    # randrange excludes its upper bound; randint(0, len(predictions)) could
    # return len(predictions) and raise an IndexError.
    ix = random.randrange(len(predictions))
    panels = (
        (test_images[ix], "Image"),
        (test_masks[ix][:,:,0], "Ground Truth"),
        (predictions[ix][:,:,0], "Prediction"),
        # Scale the 0/1 mask to 0/255 so it displays as black/white.
        (predicton_threshold[ix][:,:,0] * 255, "Thresholded Prediction"),
    )
    for col, (panel, title) in enumerate(panels):
        f.add_subplot(num_samples, len(panels), row * len(panels) + col + 1)
        imshow(panel)
        plt.title(title)
        plt.axis('off')

plt.show()

# Single-tile demo: input, ground truth and prediction side by side.
# NOTE(review): Dataset/Images is also the training pool, so this tile may
# have been seen during training - confirm before treating it as held out.
test_path = 'Dataset/Images/100892_sat.jpg'
mask_path = 'Dataset/Masks/100892_mask.png'
sample_image = imread(test_path)
test_img  = np.asarray([sample_image])

f = plt.figure(figsize = (12, 10))
f.add_subplot(1,3,1)
imshow(sample_image)
plt.title("Input Image")
plt.axis('off')

f.add_subplot(1,3,2)
imshow(imread(mask_path))
plt.title("Ground Truth")
plt.axis('off')

f.add_subplot(1,3,3)
imshow(model.predict(test_img, verbose=1)[0][:,:,0])
plt.title("Predicted Image")
plt.axis('off')

plt.show()

## Generate epoch prediction video

In [None]:
!rm -rf *.mp4 *.avi
import cv2
import os

video_name = 'video.avi'
epochs = os.listdir('epochs')
frame = cv2.imread('epochs/1.png')
height, width, layers = frame.shape
video = cv2.VideoWriter(video_name, 0, 15, (width,height))
for i in range(1, len(epochs)):
    video.write(cv2.imread(f'epochs/{i}.png'))
cv2.destroyAllWindows()
video.release()

!ffmpeg -i video.avi -c:v copy -c:a copy output.mp4
!ffmpeg -i output.mp4 -vcodec libx265 -crf 28 Epochs.mp4
!rm output.mp4
!printf "Video file ready : Epochs.mp4"