# Practical Session 8 - SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation - TensorFlow


In this practical session, you will learn how to build and run a neural network, more precisely SegNet, which is used to segment images containing many different objects.
You saw the original paper by [Badrinarayanan et al., 2016](https://arxiv.org/pdf/1511.00561.pdf) during the lecture.  
Here you will implement SegNet and run it on the CamVid dataset, as in the paper.


The original SegNet implementation, with demos, article references, etc., is available on GitHub: https://github.com/alexgkendall/SegNet-Tutorial

This notebook is mainly based on this [GitHub notebook](https://github.com/advaitsave/Multiclass-Semantic-Segmentation-CamVid/blob/master/Multiclass_Semantic_Segmentation_using_VGG_16_SegNet.ipynb).   
Since SegNet is an encoder-decoder and several well-known encoders already exist (VGG, ResNet), it is possible, for example, to build the encoder with the VGG architecture and then load pre-trained weights for this encoder. With this approach you only have to learn the decoder (the second part of the SegNet network).  
You can also do transfer learning, which consists of:
- taking a full network pre-trained on a first dataset for a first problem;
- training it on a new dataset for a related problem, and checking whether it adapts well.

**Here we do not directly use a pre-trained network (or encoder) to build SegNet. Instead, we provide saved weights from our own training sessions of this network, after a certain number of epochs. On your own, you can train the model for fewer or more epochs (training is very long), resume training after loading the weights, modify the network and observe the effect on predictions, study the preprocessing steps, and think about the differences with the paper.**


In [None]:
### TO USE IN GOOGLE COLAB
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

In [None]:
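# Import everything from tensorflow.keras: mixing the standalone `keras` package or the
# private `tensorflow.python.keras` module with tf.keras models can cause incompatibilities
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping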

# Importing the CamVid dataset

In [None]:
def _read_to_tensor(fname, output_height=224, output_width=224, normalize_data=False):
    '''Function to read images from given image file path, and provide resized images as tensors
        Inputs: 
            fname - image file path
            output_height - required output image height
            output_width - required output image width
            normalize_data - if True, rescale pixel values from [0, 255] to roughly [-1, 1] via (x - 128) / 128
        Output: Processed image tensors
    '''
    
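    # Read the image as a tensor; decode_image detects the file format (JPEG, PNG, ...),
    # channels=3 forces RGB and expand_animations=False guarantees a 3-D tensor for resizing
    img_strings = tf.io.read_file(fname)
    imgs_decoded = tf.io.decode_image(img_strings, channels=3, expand_animations=False)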
    
    # Resize the image
    output = tf.image.resize(imgs_decoded, [output_height, output_width])
    
    # Normalize if required
    if normalize_data:
        output = (output - 128) / 128
    return output
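For reference, the two scalings that appear in this notebook map 8-bit pixel values to different ranges. `read_images` calls `_read_to_tensor` with the default `normalize_data=False`, so frames are in practice scaled later by `ImageDataGenerator(rescale=1./255.)`. A quick NumPy check of both formulas (the three sample values are arbitrary):

```python
import numpy as np

# Arbitrary sample pixel values of an 8-bit image (range [0, 255])
pixels = np.array([0.0, 128.0, 255.0])

# Normalization applied by _read_to_tensor(normalize_data=True): roughly [-1, 1)
centered = (pixels - 128) / 128      # -1.0, 0.0, 0.9921875

# Rescaling applied by ImageDataGenerator(rescale=1./255.): [0, 1]
rescaled = pixels * (1. / 255.)      # 0.0, ~0.502, 1.0

print(centered)
print(rescaled)
```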

In [None]:
img_dir = './data/CamSeq01/'
# img_dir = '/content/drive/My Drive/X/MAA309/data/CamSeq01/' # for google colab

# Required image dimensions
output_height = 224
output_width = 224


## Reading frames and masks


In [None]:
def read_images(img_dir):
    '''Function to get all image directories, read images and masks in separate tensors
        Inputs: 
            img_dir - file directory
        Outputs 
            frame_tensors, masks_tensors, frame files list, mask files list
    '''
    
    # Get the file names list from provided directory
    file_list = [f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))]
    
    # Separate frame and mask files lists, exclude unnecessary files
    frames_list = [file for file in file_list if ('_L' not in file) and ('txt' not in file) and ('.D' not in file)]
    masks_list = [file for file in file_list if ('_L' in file) and ('txt' not in file) and ('.D' not in file)]
    
    frames_list.sort()
    masks_list.sort()
    
    print('{} frame files found in the provided directory.'.format(len(frames_list)))
    print('{} mask files found in the provided directory.'.format(len(masks_list)))
    
    # Create file paths from file names
    frames_paths = [os.path.join(img_dir, fname) for fname in frames_list]
    masks_paths = [os.path.join(img_dir, fname) for fname in masks_list]
    
    # Create dataset of tensors
    frame_data = tf.data.Dataset.from_tensor_slices(frames_paths)
    masks_data = tf.data.Dataset.from_tensor_slices(masks_paths)
    
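    # Read images into the tensor dataset (frames use the default bilinear resizing)
    frame_tensors = frame_data.map(_read_to_tensor)
    # Masks are resized with nearest-neighbour interpolation: bilinear interpolation would blend
    # label colours at object boundaries and create colours that match no class in label_colors.txt
    masks_tensors = masks_data.map(lambda fname: tf.image.resize(
        tf.io.decode_image(tf.io.read_file(fname), channels=3, expand_animations=False),
        [output_height, output_width], method='nearest'))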
    
    print('Completed importing {} frame images from the provided directory.'.format(len(frames_list)))
    print('Completed importing {} mask images from the provided directory.'.format(len(masks_list)))
    
    return frame_tensors, masks_tensors, frames_list, masks_list

frame_tensors, masks_tensors, frames_list, masks_list = read_images(img_dir)
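The frame/mask separation in `read_images` relies only on file names: masks carry an `_L` suffix, while `label_colors.txt` and macOS `.DS_Store` files are filtered out. The sketch below replays the same filters on a hypothetical directory listing (the file names only mimic the CamSeq01 naming convention and are assumptions). It also shows why sorting both lists is enough to align frame `i` with mask `i`.

```python
# Hypothetical directory listing: each frame has a mask with the "_L" suffix
file_list = ['0016E5_07959.png', '0016E5_07959_L.png',
             '0016E5_07961.png', '0016E5_07961_L.png',
             'label_colors.txt', '.DS_Store']

# Same filters as in read_images
frames_list = sorted(f for f in file_list if '_L' not in f and 'txt' not in f and '.D' not in f)
masks_list = sorted(f for f in file_list if '_L' in f and 'txt' not in f and '.D' not in f)

print(frames_list)
print(masks_list)

# After sorting, frame i and mask i refer to the same image
for frame_name, mask_name in zip(frames_list, masks_list):
    assert mask_name == frame_name.replace('.png', '_L.png')
```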

In [None]:
# Make an iterator to extract images from the tensor dataset

frame_batches = tf.compat.v1.data.make_one_shot_iterator(frame_tensors)  
mask_batches = tf.compat.v1.data.make_one_shot_iterator(masks_tensors)

In [None]:
n_images_to_show = 5

for i in range(n_images_to_show):
    
    # Get the next image from iterator
    frame = frame_batches.get_next().numpy().astype(np.uint8)
    mask = mask_batches.get_next().numpy().astype(np.uint8)
    
    #Plot the corresponding frames and masks
    fig = plt.figure(figsize=(20,7))
    fig.add_subplot(1,2,1)
    plt.imshow(frame)
    fig.add_subplot(1,2,2)
    plt.imshow(mask)
    plt.show()

In [None]:
DATA_PATH = './data/CamSeq01/'
# DATA_PATH = '/content/drive/My Drive/X/MAA309/data/CamSeq01/'

# Create folders to hold images and masks

folders = ['train_frames/train', 'train_masks/train', 'val_frames/val', 'val_masks/val','frames/','masks/']


for folder in folders:
    # exist_ok=True lets this cell be re-run without raising an error
    os.makedirs(os.path.join(DATA_PATH, folder), exist_ok=True)


In [None]:
def generate_image_folder_structure(frames, masks, frames_list, masks_list):
    '''Function to save images in the appropriate folder directories 
        Inputs :
        -----------
        frames - frame tensor dataset
        masks - mask tensor dataset
        frames_list - frame file names
        masks_list - mask file names
    '''
    # Create iterators for frames and masks
    frame_batches = tf.compat.v1.data.make_one_shot_iterator(frames)
    mask_batches = tf.compat.v1.data.make_one_shot_iterator(masks)
    
    # The last 20% of the sorted file lists go to validation, the rest to training.
    # The iterators follow the same sorted order, so the train split must be consumed first.
    n_val = round(0.2 * len(frames_list))
    splits = [('train', frames_list[:-n_val], masks_list[:-n_val]),
              ('val', frames_list[-n_val:], masks_list[-n_val:])]
    
    for dir_name, split_frames, split_masks in splits:
        for frame_name, mask_name in zip(split_frames, split_masks):
            # Convert tensors to numpy arrays, then to PIL images
            frame = Image.fromarray(frame_batches.get_next().numpy().astype(np.uint8))
            mask = Image.fromarray(mask_batches.get_next().numpy().astype(np.uint8))
            
            # Save frames and masks to the split-specific directories...
            frame.save(os.path.join(DATA_PATH, dir_name + '_frames', dir_name, frame_name))
            mask.save(os.path.join(DATA_PATH, dir_name + '_masks', dir_name, mask_name))
            
            # ...and to the directories holding the full dataset
            frame.save(os.path.join(DATA_PATH, 'frames', frame_name))
            mask.save(os.path.join(DATA_PATH, 'masks', mask_name))
    
    print("Saved {} frames to directory {}".format(len(frames_list), DATA_PATH))
    print("Saved {} masks to directory {}".format(len(masks_list), DATA_PATH))
    
generate_image_folder_structure(frame_tensors, masks_tensors, frames_list, masks_list)
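Because `frames_list` and `masks_list` are sorted and CamSeq01 is a video sequence, this train/validation split is chronological rather than random: the validation set is the last 20% of the sequence, so validation frames are temporally close to one another. The sketch below uses 101 hypothetical file names to show which files end up in each split. Note that `round(0.2 * n)` is 0 for very small `n` (e.g. `n = 2`), in which case `[:-0]` returns an empty training list.

```python
# Hypothetical sorted file names standing in for frames_list (101 is only an example count)
frames_list = ['frame_{:03d}.png'.format(i) for i in range(101)]

n_val = round(0.2 * len(frames_list))   # round(20.2) -> 20
train_files = frames_list[:-n_val]      # first 81 files of the sequence
val_files = frames_list[-n_val:]        # last 20 files of the sequence

print(len(train_files), len(val_files))
print(train_files[-1], val_files[0])
```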




# Create useful label and code conversion dictionaries

These will be used for:

- One hot encoding the mask labels for model training
- Decoding the predicted labels for interpretation and visualization



## Function to parse the file "label_colors.txt" which contains the class definitions
It returns a list of colors and a list of the corresponding object names.

In [None]:
def parse_code(l):
    '''Function to parse lines in a text file, returns separated elements (label codes and names in this case)
    '''
    if len(l.strip().split("\t")) == 2:
        a, b = l.strip().split("\t")
        return tuple(int(i) for i in a.split(' ')), b
    else:
        a, b, c = l.strip().split("\t")
        return tuple(int(i) for i in a.split(' ')), c
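`parse_code` handles two line layouts: a colour and a name separated by one tab, or by two tabs (the middle field is then empty, presumably due to alignment in `label_colors.txt`). Since the name is always the last tab-separated field, both branches can be collapsed into one. `parse_code_line` below is a hypothetical compact variant, not part of the original notebook, and the two sample lines only illustrate the format:

```python
def parse_code_line(l):
    '''Hypothetical compact variant of parse_code: keep the colour (first field)
    and the class name (last field), whatever the number of tabs in between.'''
    fields = l.strip().split("\t")
    rgb = tuple(int(i) for i in fields[0].split(' '))
    return rgb, fields[-1]

print(parse_code_line("64 128 64\tAnimal\n"))      # one tab
print(parse_code_line("192 0 128\t\tArchway\n"))   # two tabs: empty middle field
```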



In [None]:
with open(os.path.join(img_dir, "label_colors.txt")) as f:
    label_codes, label_names = zip(*[parse_code(l) for l in f])
label_codes, label_names = list(label_codes), list(label_names)
label_codes[:5], label_names[:5]

## Define functions for one-hot encoding RGB labels and decoding encoded predictions


In [None]:
code2id = {v:k for k,v in enumerate(label_codes)}
id2code = {k:v for k,v in enumerate(label_codes)}
name2id = {v:k for k,v in enumerate(label_names)}
id2name = {k:v for k,v in enumerate(label_names)}

In [None]:
def rgb_to_onehot(rgb_image, colormap = id2code):
    '''Function to one hot encode RGB mask labels
        Inputs: 
            rgb_image - image matrix (eg. 256 x 256 x 3 dimension numpy ndarray)
            colormap - dictionary mapping label id to RGB color (e.g. id2code)
        Output: One hot encoded image of dimensions (height x width x num_classes) where num_classes = len(colormap)
    '''
    num_classes = len(colormap)
    shape = rgb_image.shape[:2]+(num_classes,)
    encoded_image = np.zeros( shape, dtype=np.int8 )
    for i, cls in enumerate(colormap):
        encoded_image[:,:,i] = np.all(rgb_image.reshape( (-1,3) ) == colormap[i], axis=1).reshape(shape[:2])
    return encoded_image


def onehot_to_rgb(onehot, colormap = id2code):
    '''Function to decode encoded mask labels
        Inputs: 
            onehot - one hot encoded image matrix (height x width x num_classes)
            colormap - dictionary mapping label id to RGB color (e.g. id2code)
        ------------

        Output: Decoded RGB image (height x width x 3) 
    '''
    single_layer = np.argmax(onehot, axis=-1)
    output = np.zeros( onehot.shape[:2]+(3,) )
    for k in colormap.keys():
        output[single_layer==k] = colormap[k]
    return np.uint8(output)
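`rgb_to_onehot` compares the image with one class colour at a time. The same encoding can be written with NumPy broadcasting, comparing every pixel with every colour at once. The sketch below (`rgb_to_onehot_vectorized` and the two-colour map are hypothetical) also exposes an edge case: a pixel whose colour matches no class gets an all-zero one-hot vector, and `np.argmax` then returns 0, so `onehot_to_rgb` would draw it with the colour of class 0. Such colours appear, for instance, when masks are resized with bilinear rather than nearest-neighbour interpolation.

```python
import numpy as np

# Hypothetical 2-class colour map (label id -> RGB colour), same layout as id2code
colormap = {0: (64, 128, 64), 1: (128, 0, 0)}
colors = np.array([colormap[k] for k in sorted(colormap)])   # shape (num_classes, 3)

def rgb_to_onehot_vectorized(rgb_image, colors):
    '''One-hot encode an (H, W, 3) RGB mask against a (C, 3) array of class colours.
    Broadcasting compares (H, W, 1, 3) with (1, 1, C, 3) -> (H, W, C, 3);
    a pixel belongs to class c only if all 3 channels match colour c.'''
    matches = rgb_image[:, :, None, :] == colors[None, None, :, :]
    return np.all(matches, axis=-1).astype(np.int8)

# 1 x 3 image: a class-0 pixel, a class-1 pixel and a colour that belongs to no class
img = np.array([[[64, 128, 64], [128, 0, 0], [96, 64, 32]]], dtype=np.uint8)
onehot = rgb_to_onehot_vectorized(img, colors)

print(onehot[0].tolist())                   # [[1, 0], [0, 1], [0, 0]]
print(np.argmax(onehot, axis=-1).tolist())  # [[0, 1, 0]]: the unknown pixel falls back to class 0
```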


# Creating custom image data generators


In [None]:
# Normalizing only frame images, since masks contain label info
data_gen_args = dict(rescale=1./255.)
mask_gen_args = dict()

train_frames_datagen = ImageDataGenerator(**data_gen_args)
train_masks_datagen = ImageDataGenerator(**mask_gen_args)
val_frames_datagen = ImageDataGenerator(**data_gen_args)
val_masks_datagen = ImageDataGenerator(**mask_gen_args)
full_frames_datagen = ImageDataGenerator(**data_gen_args)
full_masks_datagen = ImageDataGenerator(**mask_gen_args)

seed = 1


## Custom image data generators for creating batches of frames and masks


In [None]:
def TrainAugmentGenerator(seed = 1, batch_size = 5):
    '''Train Image data generator
        ------------
        Inputs: 
        seed - seed provided to the flow_from_directory function to ensure aligned data flow
        batch_size - number of images to import at a time
        ------------
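        Output: generator yielding (frames, masks) batches: frames of shape (batch_size x 224 x 224 x 3) scaled to [0, 1], one-hot masks of shape (batch_size x 224 x 224 x num_classes)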
    '''
    train_image_generator = train_frames_datagen.flow_from_directory(DATA_PATH + 'train_frames/', batch_size = batch_size, seed = seed, target_size = (224, 224))
    train_mask_generator = train_masks_datagen.flow_from_directory(DATA_PATH + 'train_masks/', batch_size = batch_size, seed = seed, target_size = (224, 224))

    
    while True:
        X1i = train_image_generator.next()
        X2i = train_mask_generator.next()

        #One hot encoding RGB images
        mask_encoded = [rgb_to_onehot(X2i[0][x,:,:,:], id2code) for x in range(X2i[0].shape[0])]
        yield X1i[0], np.asarray(mask_encoded)

def ValAugmentGenerator(seed = 1, batch_size = 5):
    '''Validation Image data generator
    ------------
    Inputs: 
        seed - seed provided to the flow_from_directory function to ensure aligned data flow
        batch_size - number of images to import at a time
    ------------
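    Output: generator yielding (frames, masks) batches: frames of shape (batch_size x 224 x 224 x 3) scaled to [0, 1], one-hot masks of shape (batch_size x 224 x 224 x num_classes)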
    
    '''
    val_image_generator = val_frames_datagen.flow_from_directory(DATA_PATH + 'val_frames/', batch_size = batch_size, seed = seed, target_size = (224, 224))
    val_mask_generator = val_masks_datagen.flow_from_directory(DATA_PATH + 'val_masks/',batch_size = batch_size, seed = seed, target_size = (224, 224))


    while True:
        X1i = val_image_generator.next()
        X2i = val_mask_generator.next()
        
        #One hot encoding RGB images
        mask_encoded = [rgb_to_onehot(X2i[0][x,:,:,:], id2code) for x in range(X2i[0].shape[0])]
        
        yield X1i[0], np.asarray(mask_encoded)
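Frames and masks come from two independent `flow_from_directory` calls, so they stay paired only because both shuffle with the same `seed` over file lists of the same length and in the same sorted order. The NumPy sketch below illustrates that principle with hypothetical file names; it is not a copy of Keras's internal shuffling code.

```python
import numpy as np

# Hypothetical sorted file lists: frame i corresponds to mask i
frames = ['a.png', 'b.png', 'c.png', 'd.png', 'e.png']
masks = ['a_L.png', 'b_L.png', 'c_L.png', 'd_L.png', 'e_L.png']

def shuffled_order(n, seed):
    '''Seeded shuffle: the same (n, seed) pair always yields the same permutation.'''
    return np.random.RandomState(seed).permutation(n)

order_frames = shuffled_order(len(frames), seed=1)
order_masks = shuffled_order(len(masks), seed=1)

# Each frame is still paired with its own mask after shuffling
for i, j in zip(order_frames, order_masks):
    print(frames[i], masks[j])
```

With different seeds, or with a different number of files in the two directories, the permutations would generally differ and frames would be trained against the wrong masks.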



# Create the SegNet network

In [None]:
def Segnet(n_classes,input_height=224, input_width=224 , kernel=3):

    img_input = Input(shape=(input_height,input_width,3))

    x = Conv2D(64, (kernel, kernel), padding='same', name='block1_conv1', data_format='channels_last' )(img_input)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = Conv2D(64, (kernel, kernel), padding='same', name='block1_conv2', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', data_format='channels_last' )(x)
    f1 = x
    
    
    # Block 2
    x = Conv2D(128, (kernel, kernel), padding='same', name='block2_conv1', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = Conv2D(128, (kernel, kernel), padding='same', name='block2_conv2', data_format='channels_last' )(x)
    # Note: unlike in the other blocks, pooling is applied here before BatchNormalization and ReLU
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    f2 = x

    # Block 3
    x = Conv2D(256, (kernel, kernel), padding='same', name='block3_conv1', data_format='channels_last' )(x)
    x = ( BatchNormalization())(x)
    x = Activation('relu') (x)
    x = Conv2D(256, (kernel, kernel), padding='same', name='block3_conv2', data_format='channels_last' )(x)
    x = ( BatchNormalization())(x)
    x = Activation('relu') (x)
    x = Conv2D(256, (kernel, kernel), padding='same', name='block3_conv3', data_format='channels_last' )(x)
    x = ( BatchNormalization())(x)
    x = Activation('relu') (x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool', data_format='channels_last' )(x)
    f3 = x

    # Block 4
    x = Conv2D(512, (kernel, kernel), padding='same', name='block4_conv1', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = Conv2D(512, (kernel, kernel), padding='same', name='block4_conv2', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = Conv2D(512, (kernel, kernel), padding='same', name='block4_conv3', data_format='channels_last' )(x)
    x = BatchNormalization()(x)
    x = Activation('relu') (x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool', data_format='channels_last' )(x)
    f4 = x

    # Block 5
    x = Conv2D(512, (kernel, kernel), padding='same', name='block5_conv1', data_format='channels_last' )(x)
    x = BatchNormalization() (x)
    x = Activation('relu') (x)
    x = Conv2D(512, (kernel, kernel), padding='same', name='block5_conv2', data_format='channels_last' )(x)
    x = BatchNormalization() (x)
    x = Activation('relu') (x)
    x = Conv2D(512, (kernel, kernel), padding='same', name='block5_conv3', data_format='channels_last' )(x)
    x = BatchNormalization() (x)
    x = Activation('relu') (x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool', data_format='channels_last' )(x)
    f5 = x


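    # f4 is 14x14 for a 224x224 input; the 4 UpSampling2D blocks below bring it back to 224x224.
    # You can use f5 (7x7) instead, but then you need one more UpSampling2D block to reach 224x224.
    o = f4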

    o = ( UpSampling2D( (2,2), data_format='channels_last'))(o)
    o = ( Conv2D( 512, (kernel, kernel), padding='same', data_format='channels_last'))(o)
    o = ( BatchNormalization())(o)
   
    o = ( UpSampling2D( (2,2), data_format='channels_last'))(o)
    o = ( Conv2D( 256, (kernel, kernel), padding='same', data_format='channels_last')) (o)
    o = ( BatchNormalization())(o)
   
    o = ( UpSampling2D((2,2)  , data_format='channels_last' ) )(o)
    o = ( Conv2D( 128 , (kernel, kernel), padding='same' , data_format='channels_last' )) (o)
    o = ( BatchNormalization())(o)
 
    o = ( UpSampling2D((2,2)  , data_format='channels_last' ))(o)
    o = ( Conv2D( 64 , (kernel, kernel), padding='same'  , data_format='channels_last' )) (o)
    o = ( BatchNormalization())(o)

    o =  Conv2D( n_classes , (kernel, kernel) , padding='same', data_format='channels_last' )(o)
    o = (Activation('softmax')) (o)
    
    model = Model(img_input,o)

    return model
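The encoder halves the spatial size at each `MaxPooling2D`, and each decoder `UpSampling2D((2, 2))` doubles it, so the output matches the 224 x 224 masks only if the number of upsamplings equals the number of poolings behind the chosen feature map. A small arithmetic check (the helper names are illustrative, not part of the notebook):

```python
def feature_map_size(input_size, n_pool):
    '''Spatial size after n_pool 2x2 max-poolings with stride 2.'''
    return input_size // 2 ** n_pool

def decoder_output_size(feature_size, n_upsampling=4):
    '''Spatial size after n_upsampling UpSampling2D((2, 2)) layers.'''
    return feature_size * 2 ** n_upsampling

f4_size = feature_map_size(224, 4)   # 14
f5_size = feature_map_size(224, 5)   # 7

print(f4_size, decoder_output_size(f4_size))              # 14 224: matches the masks
print(f5_size, decoder_output_size(f5_size))              # 7 112: too small
print(decoder_output_size(f5_size, n_upsampling=5))       # 224 with a 5th upsampling
```

The same reasoning means the input height and width must be divisible by 16 (or by 32 when starting from f5) for the shapes to line up.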

In [None]:
## There are 32 possible classes.
model = Segnet(32)

In [None]:
model.summary()
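One notable difference with the paper: SegNet's decoder upsamples with the max-pooling indices memorised by the corresponding encoder layer, putting each value back at the position of its maximum and leaving zeros elsewhere, whereas the `Segnet` function above uses `UpSampling2D`, which copies each value into the whole 2 x 2 block. The NumPy sketch below contrasts both on a single 4 x 4 channel; it is a didactic illustration under these simplifying assumptions (one channel, even sizes), not a drop-in Keras layer. In TensorFlow, `tf.nn.max_pool_with_argmax` returns such indices and could serve as a starting point for an index-based unpooling layer.

```python
import numpy as np

def max_pool_with_indices(x):
    '''2x2, stride-2 max pooling on a single-channel (H, W) array with even H and W.
    Returns the pooled map and the flat index (into x) of each maximum.'''
    h, w = x.shape
    # Regroup pixels into (H/2, W/2) windows of 4 values each
    windows = x.reshape(h // 2, 2, w // 2, 2).transpose(0, 2, 1, 3).reshape(h // 2, w // 2, 4)
    arg = windows.argmax(axis=-1)                      # position 0..3 inside each window
    rows = np.arange(h // 2)[:, None] * 2 + arg // 2   # row of the maximum in x
    cols = np.arange(w // 2)[None, :] * 2 + arg % 2    # column of the maximum in x
    return windows.max(axis=-1), rows * w + cols

def max_unpool(pooled, indices, out_shape):
    '''SegNet-style unpooling: each value goes back to its recorded position, zeros elsewhere.'''
    out = np.zeros(out_shape, dtype=pooled.dtype)
    out.flat[indices.ravel()] = pooled.ravel()
    return out

x = np.array([[1, 5, 2, 0],
              [3, 4, 8, 1],
              [0, 2, 1, 1],
              [7, 1, 3, 6]])

pooled, idx = max_pool_with_indices(x)
unpooled = max_unpool(pooled, idx, x.shape)               # sparse map, as in the paper
upsampled = pooled.repeat(2, axis=0).repeat(2, axis=1)    # what UpSampling2D((2, 2)) does

print(pooled)
print(unpooled)
print(upsampled)
```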

# Compile model

In [None]:
# Categorical crossentropy loss since labels have been one hot encoded
model.compile(optimizer='adam', loss="categorical_crossentropy", metrics=['accuracy'])
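With one-hot masks and a softmax output, categorical cross-entropy is computed for every pixel over the class axis and then averaged; for a single pixel it reduces to minus the log of the probability predicted for the true class. A worked example for one pixel with three hypothetical classes and made-up probabilities:

```python
import numpy as np

# One pixel, 3 hypothetical classes: one-hot target and softmax prediction
y_true = np.array([0., 1., 0.])
y_pred = np.array([0.2, 0.7, 0.1])

# Categorical cross-entropy: -sum_c y_c * log(p_c) = -log(probability of the true class)
loss = -np.sum(y_true * np.log(y_pred))
print(round(float(loss), 4))   # 0.3567, i.e. -log(0.7)
```

Since every pixel contributes equally, frequent classes (road, sky, buildings) dominate this loss; the paper addresses this imbalance with median frequency balancing, which is not used here.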

In [None]:
## Stop training early if the validation loss has not decreased for 30 consecutive epochs (patience=30).
es = EarlyStopping(mode='min', monitor='val_loss', patience=30, verbose=1)
callbacks = [es]
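`patience` counts consecutive epochs without improvement of the monitored quantity. The function below is a simplified, hypothetical re-implementation (it ignores `min_delta`, `baseline` and the other options of the real callback) that returns how many epochs would run for a given list of validation losses. Note also that `EarlyStopping` keeps the weights of the last epoch by default (`restore_best_weights=False`), not those of the best epoch.

```python
def epochs_until_stop(val_losses, patience):
    '''Simplified sketch of EarlyStopping(monitor='val_loss', mode='min'):
    stop once val_loss has not improved on its best value for `patience` epochs.
    Returns the number of epochs actually run.'''
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)

# Hypothetical validation losses: the best value is reached at epoch 3
losses = [0.9, 0.7, 0.5, 0.6, 0.55, 0.52, 0.58]
print(epochs_until_stop(losses, patience=3))    # 6
print(epochs_until_stop(losses, patience=30))   # 7 (never triggers)
```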

# Train model or load weights

## Load weights of pre-trained model
**Here** you can load pre-trained weights into the model after building and compiling it. You can then train it further: the existing weights serve as the initialisation, so you can obtain results faster than when training from scratch.  
So skip the "Train model" part if you just want to see results.  
Or change `num_epochs` if you want to follow the evolution of training.

In [None]:
# model.load_weights(img_dir+"../model_camvid_weight_ep85.hdf5")
# model.evaluate(x=TrainAugmentGenerator(),steps=1)
# model.evaluate(x=ValAugmentGenerator(),steps=1)

## Train model

In [None]:
batch_size = 5
n_val = round(0.2*len(frames_list))  # same train/val split as in generate_image_folder_structure
steps_per_epoch = np.ceil(float(len(frames_list) - n_val) / float(batch_size))
validation_steps = np.ceil(float(n_val) / float(batch_size))
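The generators above loop forever, so `steps_per_epoch` and `validation_steps` are what define one epoch for Keras. A worked example of the computation, with a hypothetical count of 101 images and `batch_size = 5`:

```python
import math

n_images = 101    # hypothetical number of frame/mask pairs
batch_size = 5

n_val = round(0.2 * n_images)                                  # round(20.2) -> 20
steps_per_epoch = math.ceil((n_images - n_val) / batch_size)   # ceil(81 / 5) -> 17
validation_steps = math.ceil(n_val / batch_size)               # ceil(20 / 5) -> 4

# Using the unrounded 0.2 * n_images instead gives ceil(20.2 / 5) -> 5 validation steps,
# i.e. one extra batch drawn again from the (infinite) validation generator
validation_steps_unrounded = math.ceil(0.2 * n_images / batch_size)

print(steps_per_epoch, validation_steps, validation_steps_unrounded)   # 17 4 5
```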

In [None]:
# Train model
num_epochs = 200 ## adjust depending on whether you loaded pre-trained weights or not
# batch_size is set by the generators, so it is not passed to fit()
result = model.fit(x=TrainAugmentGenerator(batch_size=batch_size),
                   steps_per_epoch=int(steps_per_epoch),
                   validation_data=ValAugmentGenerator(batch_size=batch_size),
                   validation_steps=int(validation_steps),
                   epochs=num_epochs, callbacks=callbacks, verbose=1)

In [None]:
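# Number of epochs actually run (EarlyStopping may stop before num_epochs)
nb_epochs_trained = len(result.history['loss'])
model.save_weights(DATA_PATH+'../model_camvid_weight_ep{0}.hdf5'.format(nb_epochs_trained))

# If you trained on top of loaded pre-trained weights, the total number of epochs is
# (pre-trained epochs + nb_epochs_trained): adjust the file name accordingly.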

# Model Evaluation

## Accuracy and Loss plots
You cannot plot the evolution of loss and accuracy unless you trained the model in this session, since the plots use the `result` history returned by `model.fit`.  
If you only want to see the predictions of the network with loaded pre-trained weights, skip this part.


In [None]:
# Get actual number of epochs model was trained for
N = len(result.history['loss'])

# Plot the model evaluation history
plt.style.use("ggplot")
fig = plt.figure(figsize=(20,8))

ax1 = fig.add_subplot(1,2,1)
ax1.set_title("Losses")
ax1.plot(np.arange(0, N), result.history["loss"], label="train_loss")
ax1.plot(np.arange(0, N), result.history["val_loss"], label="val_loss")
ax1.set_ylim(0, 1)
ax1.set_xlabel("Epoch #")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")

ax2 = fig.add_subplot(1,2,2)
ax2.set_title("Accuracies")
ax2.plot(np.arange(0, N), result.history["accuracy"], label="train_accuracy")
ax2.plot(np.arange(0, N), result.history["val_accuracy"], label="val_accuracy")
ax2.set_ylim(0, 1)
ax2.set_xlabel("Epoch #")
ax2.set_ylabel("Accuracy")
ax2.legend(loc="lower left")

plt.show()


## Extract and display a batch of frames, predictions and masks


Here you can see the predictions of the SegNet network on the validation set:
- using only the loaded pre-trained weights, without further training: the segmentation is not very good yet, but it is a start;
- after further training (more epochs on top of the pre-trained weights).

In [None]:
training_gen = TrainAugmentGenerator()
testing_gen = ValAugmentGenerator(batch_size=20)

In [None]:
batch_img,batch_mask = next(testing_gen)
pred_all= model.predict(batch_img)
np.shape(pred_all)

In [None]:
for i in range(0, np.shape(pred_all)[0]):
    
    fig = plt.figure(figsize=(20,8))
    
    ax1 = fig.add_subplot(1,3,1)
    ax1.imshow(batch_img[i])
    ax1.set_title('Actual frame')
    ax1.grid(False)  # hide the ggplot-style grid over the image
    
    ax2 = fig.add_subplot(1,3,2)
    ax2.set_title('Ground truth labels')
    ax2.imshow(onehot_to_rgb(batch_mask[i],id2code))
    ax2.grid(False)
    
    ax3 = fig.add_subplot(1,3,3)
    ax3.set_title('Predicted labels')
    ax3.imshow(onehot_to_rgb(pred_all[i],id2code))
    ax3.grid(False)
    
    plt.show()
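Visual inspection can be complemented by numbers. The `accuracy` metric used during training is a global pixel accuracy, while the SegNet paper also reports class average accuracy and mean intersection over union (mIoU). The function below is a minimal NumPy sketch, not part of the original notebook, computing pixel accuracy and mIoU from a confusion matrix; classes that appear in neither the ground truth nor the prediction are ignored in the mean. The tiny label maps are made up for illustration.

```python
import numpy as np

def segmentation_scores(true_ids, pred_ids, num_classes):
    '''Global pixel accuracy and mean IoU from integer label maps of identical shape.'''
    t = true_ids.ravel()
    p = pred_ids.ravel()
    # Confusion matrix: rows = true class, columns = predicted class
    cm = np.bincount(t * num_classes + p, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    pixel_acc = np.trace(cm) / cm.sum()
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    present = union > 0                  # skip classes absent from both maps
    mean_iou = np.mean(intersection[present] / union[present])
    return pixel_acc, mean_iou

true_ids = np.array([[0, 0, 1], [1, 2, 2]])
pred_ids = np.array([[0, 1, 1], [1, 2, 0]])
acc, miou = segmentation_scores(true_ids, pred_ids, num_classes=3)
print(acc, miou)   # 4/6 pixels correct; IoUs 1/3, 2/3, 1/2 -> mean 0.5
```

To apply it to the batch above, convert both one-hot arrays to label maps first, e.g. `segmentation_scores(np.argmax(batch_mask, axis=-1), np.argmax(pred_all, axis=-1), len(id2code))`. Pixels whose mask colour matches no class are then counted as class 0, since `np.argmax` of an all-zero vector is 0.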