In [0]:
!pip install -q tensorflow-gpu==2.0.0-rc1

[K     |████████████████████████████████| 380.5MB 46kB/s 
[K     |████████████████████████████████| 501kB 43.2MB/s 
[K     |████████████████████████████████| 4.3MB 55.0MB/s 
[?25h

In [0]:
# Connect Google Drive and move into the project folder.
from google.colab import auth
auth.authenticate_user()

from google.colab import drive
drive.mount('/content/gdrive')

import os

# Absolute path under the mount point above. A relative path (./gdrive/...) only
# resolves while the working directory is /content, so re-running the cell would fail.
PROJECT_DIR = '/content/gdrive/My Drive/Colab Notebooks/Fully Convolutional Network for Semantic segmentation'
os.chdir(PROJECT_DIR)

In [0]:
import cv2
import numpy as np
import random
import csv
import time
import os
import glob

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPool2D, BatchNormalization, ReLU, Conv2DTranspose

# IMG_SIZE: side length (pixels) of the square crops used for training and testing.
# num_classes: initial class count; a later cell overwrites it with len(label_values)
#              read from CamVid/class_dict.csv.
IMG_SIZE, num_classes = 384, 32

In [0]:
# get data from directory
def prepare_data(dataset_dir):
    """Collect the sorted image and label file paths of a CamVid-style dataset.

    Looks for the sub-directories ``train``, ``train_labels``, ``test`` and
    ``test_labels`` under ``dataset_dir``. Each list is sorted independently,
    so images and labels are paired by position: their file names must sort
    in the same order.

    Args:
        dataset_dir: path to the dataset root.

    Returns:
        Tuple ``(train_input_names, train_output_names, test_input_names,
        test_output_names)`` of sorted path lists.
    """
    # glob.escape stops characters such as '[' in dataset_dir from acting as wildcards.
    root = glob.escape(dataset_dir)

    def _sorted_files(subdir):
        # Sorted list of every entry inside one sub-directory of the dataset root.
        return sorted(glob.glob(os.path.join(root, subdir, "*")))

    return (
        _sorted_files("train"),
        _sorted_files("train_labels"),
        _sorted_files("test"),
        _sorted_files("test_labels"),
    )

def load_image(path):
    """Read an image file and return it with its channels converted from BGR to RGB.

    Args:
        path: file path of the image.

    Returns:
        The decoded image array in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` returns
            None instead of raising for missing or unreadable files.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # same flag as -1
    if image is None:
        raise FileNotFoundError("Could not read image: %s" % path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def get_label_info(csv_path):
    """
    Retrieve the class names and label colour values for the selected dataset.
    Must be in CSV format: one header row, then rows of ``name, r, g, b``.
    # Arguments
        csv_path: The file path of the class dictionary

    # Returns
        Two lists: one for the class names and the other for the label values
        (each label value is a list of three ints). An empty file or a file
        with only a header yields two empty lists. Blank lines are skipped.
    """
    class_names = []
    label_values = []
    # newline='' is what the csv module documents for files handed to csv.reader.
    with open(csv_path, 'r', newline='') as csvfile:
        file_reader = csv.reader(csvfile, delimiter=',')
        next(file_reader, None)  # skip the header row (no-op on an empty file)
        for row in file_reader:
            if not row:  # csv.reader yields [] for blank lines, e.g. at end of file
                continue
            class_names.append(row[0])
            label_values.append([int(row[1]), int(row[2]), int(row[3])])
    return class_names, label_values

def one_hot_it(label, label_values):
    """
    Convert a colour-coded segmentation label image to one-hot format.

    Each pixel's colour is compared with every entry of ``label_values``; the
    result has one boolean channel per entry, True where the pixel's colour
    equals that entry on every channel.
    # Arguments
        label: The colour-coded label image, colour channels on the last axis
        label_values: Sequence of colours, one per class

    # Returns
        A boolean array with the same height and width as the input and a
        depth of len(label_values)
    """
    per_class_masks = [np.all(np.equal(label, colour), axis=-1) for colour in label_values]
    return np.stack(per_class_masks, axis=-1)

def reverse_one_hot(image):
    """
    Collapse a one-hot (or per-class score) array to class keys.

    Takes the index of the largest value along the last axis, so the result
    has the input's shape minus its last dimension.
    # Arguments
        image: Array whose last axis holds one value per class

    # Returns
        Integer array of class keys with the same height and width as the input
    """
    return np.argmax(image, axis=-1)

def colour_code_segmentation(image, label_values):
    """
    Paint an array of class keys with each class's colour for visualisation.
    # Arguments
        image: Array of class keys (cast to int before the lookup)
        label_values: Sequence of colours indexed by class key

    # Returns
        Array of shape image.shape + colour shape holding each pixel's colour
    """
    palette = np.array(label_values)
    class_keys = image.astype(int)
    return palette[class_keys]

# Randomly crop the image to a specific size. For data augmentation
def random_crop(image, label, crop_height, crop_width):
    """Cut the same random ``crop_height`` x ``crop_width`` window from an image and its label.

    Args:
        image: array indexed as (height, width, channels).
        label: array with the same height and width; 2-D or with trailing axes.
        crop_height, crop_width: window size in pixels.

    Returns:
        ``(image_crop, label_crop)``.

    Raises:
        Exception: if image and label differ in height/width, or the window
            does not fit inside the image.
    """
    if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]:
        raise Exception('Image and label must have the same dimensions!')

    img_h, img_w = image.shape[0], image.shape[1]
    if crop_width > img_w or crop_height > img_h:
        raise Exception('Crop shape (%d, %d) exceeds image dimensions (%d, %d)!' % (crop_height, crop_width, img_h, img_w))

    # Top-left corner; randint is inclusive, so the window may touch the far edges.
    left = random.randint(0, img_w - crop_width)
    top = random.randint(0, img_h - crop_height)
    rows = slice(top, top + crop_height)
    cols = slice(left, left + crop_width)
    # Slicing only the two spatial axes keeps any trailing label axes intact.
    return image[rows, cols, :], label[rows, cols]

# data augmentation
def data_augmentation(input_image, output_image, h_flip=True, v_flip=True,
                      brightness=0.1, rotation=90, crop_height=None, crop_width=None):
    """Jointly augment an image and its colour-coded label image.

    Applies, in order: a random crop, random horizontal and vertical flips
    (probability 1/2 each), a random brightness change (image only), and a
    random rotation.

    Args:
        input_image: the photo; assumed uint8 (the brightness table maps 0..255).
        output_image: the label image, with the same height and width as input_image.
        h_flip: randomly flip both images horizontally.
        v_flip: randomly flip both images vertically.
        brightness: maximum relative brightness change as a factor in [0, 1]
            (0.1 means up to +-10%); 0 disables it.
        rotation: maximum rotation angle in degrees; 0 disables it.
        crop_height, crop_width: crop size in pixels; None means IMG_SIZE.

    Returns:
        ``(input_image, output_image)`` after augmentation.

    Raises:
        Exception: propagated from random_crop when the crop does not fit or
            the shapes of the two images differ.
    """
    if crop_height is None:
        crop_height = IMG_SIZE
    if crop_width is None:
        crop_width = IMG_SIZE

    input_image, output_image = random_crop(input_image, output_image, crop_height, crop_width)

    if h_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 1)
        output_image = cv2.flip(output_image, 1)
    if v_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 0)
        output_image = cv2.flip(output_image, 0)
    if brightness:
        factor = 1.0 + random.uniform(-1.0 * brightness, brightness)
        # Clip before casting: an unclipped uint8 cast wraps values above 255
        # (e.g. 280 -> 24), which would turn bright pixels dark.
        table = np.clip(np.arange(256) * factor, 0, 255).astype(np.uint8)
        input_image = cv2.LUT(input_image, table)
    if rotation:
        angle = random.uniform(-1 * rotation, rotation)
        M = cv2.getRotationMatrix2D((input_image.shape[1] // 2, input_image.shape[0] // 2), angle, 1.0)
        # Bilinear interpolation for the photo; nearest for the label so that no
        # blended (non-class) colours are introduced.
        input_image = cv2.warpAffine(input_image, M, (input_image.shape[1], input_image.shape[0]), flags=cv2.INTER_LINEAR)
        output_image = cv2.warpAffine(output_image, M, (output_image.shape[1], output_image.shape[0]), flags=cv2.INTER_NEAREST)

    return input_image, output_image

In [0]:
def compute_global_accuracy(pred, label):
    """Return the fraction of positions at which ``pred`` equals ``label``.

    Args:
        pred: 1-D array-like of predicted class ids.
        label: 1-D array-like of ground-truth class ids, same shape as ``pred``.

    Returns:
        float: number of matching positions divided by ``len(label)``.

    Raises:
        ValueError: if ``pred`` and ``label`` have different shapes.
        ZeroDivisionError: if ``label`` is empty.
    """
    pred = np.asarray(pred)
    label = np.asarray(label)
    if pred.shape != label.shape:
        raise ValueError("pred and label must have the same shape, got %s and %s" % (pred.shape, label.shape))
    # Vectorised comparison; a per-element Python loop is very slow on full images.
    matches = np.count_nonzero(pred == label)
    return float(matches) / float(len(label))

# Compute the class-specific segmentation accuracy
def compute_class_accuracies(pred, label, num_classes):
    """Per-class accuracy: correctly predicted pixels of class c / ground-truth pixels of class c.

    Args:
        pred: array-like of predicted class ids (flattened internally).
        label: array-like of non-negative integer ground-truth ids, same size as ``pred``.
        num_classes: length of the returned list; ids >= num_classes are ignored.

    Returns:
        List of ``num_classes`` floats. A class that never occurs in ``label``
        gets 1.0 instead of NaN, which inflates averages over classes.

    Raises:
        ValueError: if the sizes differ or ``label`` contains negative ids.
    """
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    if pred.shape != label.shape:
        raise ValueError("pred and label must have the same size, got %s and %s" % (pred.shape, label.shape))

    label_ids = label.astype(np.int64)
    # Histograms replace a per-pixel Python loop: ground-truth counts and counts of correct hits.
    total = np.bincount(label_ids, minlength=num_classes)[:num_classes]
    correct = np.bincount(label_ids[pred == label], minlength=num_classes)[:num_classes]

    return [1.0 if n_total == 0 else float(n_correct) / float(n_total)
            for n_correct, n_total in zip(correct, total)]


def compute_mean_iou(pred, label):
    """Mean intersection-over-union over the classes present in ``label``.

    Classes that appear only in ``pred`` are not included in the mean.

    Args:
        pred: array of predicted class ids.
        label: array of ground-truth class ids, same shape as ``pred``.

    Returns:
        The mean of intersection/union over each distinct value of ``label``
        (NaN, with a NumPy warning, if ``label`` is empty).
    """
    ious = []
    for class_id in np.unique(label):
        in_pred = pred == class_id
        in_label = label == class_id
        intersection = float(np.sum(np.logical_and(in_label, in_pred)))
        union = float(np.sum(np.logical_or(in_label, in_pred)))
        ious.append(intersection / union)
    return np.mean(ious)


def evaluate_segmentation(pred, label, num_classes, score_averaging="weighted"):
    """Score a predicted class-id map against the ground-truth map.

    Args:
        pred: array of predicted class ids; flattened before scoring.
        label: array of ground-truth class ids; flattened before scoring.
        num_classes: number of entries in the per-class accuracy list.
        score_averaging: accepted for compatibility; not used here.

    Returns:
        ``(global_accuracy, class_accuracies, mean_iou)``.
    """
    flat_pred = pred.flatten()
    flat_label = label.flatten()
    return (
        compute_global_accuracy(flat_pred, flat_label),
        compute_class_accuracies(flat_pred, flat_label, num_classes),
        compute_mean_iou(flat_pred, flat_label),
    )

# Takes a file path and returns the name of the file without the extension
def filepath_to_name(full_name):
    """Return the base name of ``full_name`` with its last extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(full_name))
    return stem

In [0]:
# Conv -> BatchNorm -> ReLU block, used for FCN's conv6/conv7
class Conv1BatchNorm(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU.

    Args:
        kernel_size: convolution kernel size (Keras' default 'valid' padding).
        filters: number of output channels; defaults to 4096.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, kernel_size, filters=4096, **kwargs):
        super(Conv1BatchNorm, self).__init__(**kwargs)
        self.conv = Conv2D(filters, kernel_size)
        self.batchnorm = BatchNormalization()
        self.relu = ReLU()

    def call(self, x, training=None):
        x = self.conv(x)
        # Forward `training` so batch norm normalises with batch statistics (and
        # updates its moving averages) during training, and uses the moving
        # averages at inference.
        x = self.batchnorm(x, training=training)
        return self.relu(x)

In [0]:
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)  # square RGB input at the crop size

# ImageNet-pretrained VGG16 convolutional layers, without the dense classifier head.
base_model = tf.keras.applications.VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=IMG_SHAPE,
)

# Freeze the pretrained backbone weights.
base_model.trainable = False
base_model.summary()

In [0]:
# FCN-8s-style segmentation network on a VGG16 backbone
class FCN_VGG16(tf.keras.Model):
    """FCN-8s-style semantic segmentation model on a Keras VGG16 backbone.

    The backbone layers are applied one by one so that the pool3 and pool4
    feature maps can be tapped. Class scores from pool5 (after conv6/conv7),
    pool4 and pool3 are fused with 2x transposed convolutions and finally
    upsampled 8x. The last layer has no activation, so outputs are raw scores.

    Args:
        num_classes: number of output channels (one score map per class).
        base_model: a ``tf.keras.applications.VGG16(include_top=False)`` model
            with default pooling. Its ``layers`` list is indexed directly:
            layers 1-10 end at block3_pool (pool3), 11-14 at block4_pool
            (pool4), and 15-last at block5_pool (pool5).
        score_activation: activation of the 1x1 scoring convolutions. Defaults
            to ``'relu'`` to keep existing behaviour. NOTE(review): the FCN
            paper uses linear score layers, so consider ``None``.

    Raises:
        ValueError: if ``base_model`` is None; ``call`` cannot run without it.
    """

    def __init__(self, num_classes=21, base_model=None, score_activation='relu'):
        super(FCN_VGG16, self).__init__()
        if base_model is None:
            raise ValueError("FCN_VGG16 requires a VGG16 base_model, got None")
        self.base = base_model
        # 1x1 Conv-BN-ReLU blocks standing in for VGG's fully connected layers.
        self.conv6 = Conv1BatchNorm(1)
        self.conv7 = Conv1BatchNorm(1)

        # 1x1 class-score layers for the conv7 output, pool4 and pool3.
        self.score_fc = Conv2D(num_classes, 1, activation=score_activation)
        self.score_pool4 = Conv2D(num_classes, 1, padding='same', activation=score_activation)
        self.score_pool3 = Conv2D(num_classes, 1, activation=score_activation)

        # 2x learned upsampling steps between the fusion points.
        self.upsampling1 = Conv2DTranspose(num_classes, 4, 2, padding='same')
        self.upsampling2 = Conv2DTranspose(num_classes, 4, 2, padding='same')

        # Final 8x learned upsampling.
        self.upsample8s = Conv2DTranspose(num_classes, 16, 8, padding='same')

    def shadow_model(self, input_content, start_index, end_index):
        """Apply backbone layers ``start_index`` .. ``end_index`` (inclusive) in order."""
        x = input_content
        for layer in self.base.layers[start_index:end_index + 1]:
            x = layer(x)
        return x

    def call(self, x, training=None):
        # NOTE(review): Keras' ImageNet VGG16 weights are documented to expect
        # vgg16.preprocess_input-style inputs, while this notebook feeds RGB / 255.
        # Confirm this is intended.
        ##########################pre trained###########################
        # Run the backbone once, tapping pool3 and pool4 on the way to pool5
        # instead of also calling self.base(x) on the same input.
        pool3 = self.shadow_model(x, 1, 10)
        pool4 = self.shadow_model(pool3, 11, 14)
        pool5 = self.shadow_model(pool4, 15, len(self.base.layers) - 1)
        ################################################################
        x = self.conv6(pool5, training=training)
        x = self.conv7(x, training=training)
        x = self.score_fc(x)

        fusion1 = tf.add(self.upsampling1(x), self.score_pool4(pool4))
        fusion2 = tf.add(self.upsampling2(fusion1), self.score_pool3(pool3))

        return self.upsample8s(fusion2)

In [0]:
dnn_model = FCN_VGG16(num_classes, base_model)

# The model's last layer has no softmax, so the loss is computed from logits.
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.RMSprop()

checkpoint_directory = "./checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpts1")

ckpt = tf.train.Checkpoint(optimizer=optimizer, net=dnn_model)
# Reuse checkpoint_prefix instead of repeating the path literal.
manager = tf.train.CheckpointManager(ckpt, checkpoint_prefix, max_to_keep=50)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

def train_step(images, labels):
    """Run one optimisation step on a batch and update the running train metrics.

    Args:
        images: input batch for ``dnn_model``.
        labels: one-hot targets matching the model's output shape.
    """
    with tf.GradientTape() as tape:
        # training=True so the batch-norm layers use batch statistics and
        # update their moving averages during training.
        predictions = dnn_model(images, training=True)
        loss = loss_object(labels, predictions)

    gradients = tape.gradient(loss, dnn_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, dnn_model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

In [0]:
class_names_list, label_values = get_label_info("CamVid/class_dict.csv")
# Comma-separated class names; used as the per-class column headers of the test-score CSV.
class_names_string = ", ".join(class_names_list)
num_classes = len(label_values)

# dnn_model was built before the class dictionary was read. Its output depth must
# match the one-hot labels, which have len(label_values) channels.
model_num_classes = dnn_model.upsample8s.filters
if model_num_classes != num_classes:
    raise ValueError(
        "Model predicts %d classes but CamVid/class_dict.csv defines %d; "
        "rebuild dnn_model with num_classes=%d." % (model_num_classes, num_classes, num_classes))

print("Loading the data ...", end=' ')
train_input_names, train_output_names, test_input_names, test_output_names = prepare_data("CamVid")
print("Complete!!")

# Images and labels are paired by sorted position, so their counts must agree.
if len(train_input_names) != len(train_output_names) or len(test_input_names) != len(test_output_names):
    raise ValueError(
        "Image/label count mismatch: train %d vs %d, test %d vs %d" % (
            len(train_input_names), len(train_output_names),
            len(test_input_names), len(test_output_names)))

avg_loss_per_epoch = []
avg_scores_per_epoch = []
avg_iou_per_epoch = []

In [0]:
# Evaluate on a fixed random subset of at most 20 test images.
num_tests = min(20, len(test_input_names))

# Seeding Python's global RNG also fixes the random numbers later drawn by
# random_crop and data_augmentation.
random.seed(16)
test_indices = random.sample(range(len(test_input_names)), num_tests)

In [0]:
batch_size = 16
stat_period = 1  # print stats, checkpoint and evaluate every `stat_period` epochs
num_epochs = 50
test_output_dir = "checkpoints_test"
# open()/cv2.imwrite below do not create directories, so make sure it exists.
os.makedirs(test_output_dir, exist_ok=True)


def _load_training_batch(batch_ids):
    """Load, augment and one-hot encode the training samples at `batch_ids`.

    Returns a pair of float32 arrays stacked on a new leading batch axis:
    the images divided by 255, and the one-hot labels (one channel per
    entry of label_values).
    """
    images, labels = [], []
    for sample_id in batch_ids:
        image = load_image(train_input_names[sample_id])
        label = load_image(train_output_names[sample_id])
        image, label = data_augmentation(image, label)
        images.append(np.float32(image) / 255.0)
        labels.append(np.float32(one_hot_it(label=label, label_values=label_values)))
    # np.stack on axis 0 gives (batch, ...) directly; unlike squeeze() it can
    # never drop other size-1 axes.
    return np.stack(images, axis=0), np.stack(labels, axis=0)


def _evaluate_test_set(epoch):
    """Score dnn_model on the fixed `test_indices` subset.

    Writes a per-image score CSV plus colour-coded prediction and ground-truth
    PNGs into `test_output_dir`. Returns (per-image accuracies, per-image class
    accuracies, per-image IoUs).
    """
    scores_list, class_scores_list, iou_list = [], [], []
    csv_path = "%s/%04dtest_scores.csv" % (test_output_dir, epoch)
    # `with` closes the CSV even if evaluation raises part-way through.
    with open(csv_path, 'w') as target:
        target.write("test_name, avg_accuracy, mean iou, %s\n" % (class_names_string))

        for ind in test_indices:
            # Evaluate on the top-left IMG_SIZE x IMG_SIZE window of each test image.
            input_image = np.expand_dims(np.float32(load_image(test_input_names[ind])[:IMG_SIZE, :IMG_SIZE]), axis=0) / 255.0
            gt = load_image(test_output_names[ind])[:IMG_SIZE, :IMG_SIZE]
            gt = reverse_one_hot(one_hot_it(gt, label_values))

            logits = dnn_model(input_image)
            pred = reverse_one_hot(np.array(logits[0, :, :, :]))
            out_vis_image = colour_code_segmentation(pred, label_values)

            accuracy, class_accuracies, iou = evaluate_segmentation(pred=pred, label=gt, num_classes=num_classes)

            file_name = filepath_to_name(test_input_names[ind])
            target.write("%s, %f, %f" % (file_name, accuracy, iou))
            for item in class_accuracies:
                target.write(", %f" % (item))
            target.write("\n")

            scores_list.append(accuracy)
            class_scores_list.append(class_accuracies)
            iou_list.append(iou)

            gt_vis_image = colour_code_segmentation(gt, label_values)
            cv2.imwrite("%s/%04d%s_pred.png" % (test_output_dir, epoch, file_name), cv2.cvtColor(np.uint8(out_vis_image), cv2.COLOR_RGB2BGR))
            cv2.imwrite("%s/%04d%s_gt.png" % (test_output_dir, epoch, file_name), cv2.cvtColor(np.uint8(gt_vis_image), cv2.COLOR_RGB2BGR))

    return scores_list, class_scores_list, iou_list


for epoch in range(num_epochs):
    print("epoch: ", epoch + 1)
    start_time = time.time()

    id_list = np.random.permutation(len(train_input_names))

    # Integer division drops the final partial batch.
    num_iters = len(id_list) // batch_size
    for i in range(num_iters):
        batch_ids = id_list[i * batch_size:(i + 1) * batch_size]
        input_image_batch, output_image_batch = _load_training_batch(batch_ids)
        train_step(input_image_batch, output_image_batch)

    end_time = time.time()

    # `==`, not `is`: identity checks on ints only work by accident of small-int caching.
    if epoch % stat_period == 0:
        template = "\t{} (sec/epoch), Epoch {}, Loss: {}, Accuracy: {} %"
        # start_time is reset every epoch, so the elapsed time is already per epoch.
        print(template.format(end_time - start_time,
                              epoch + 1,
                              train_loss.result(),
                              train_accuracy.result() * 100))
        avg_loss_per_epoch.append(train_loss.result())

        save_path = manager.save()
        print("\tSaved checkpoint for epoch {}: {}".format(epoch + 1, save_path))

        train_loss.reset_states()
        train_accuracy.reset_states()

        print("\tPerforming test")
        scores_list, class_scores_list, iou_list = _evaluate_test_set(epoch)

        avg_score = np.mean(scores_list)
        class_avg_scores = np.mean(class_scores_list, axis=0)
        avg_iou = np.mean(iou_list)
        avg_scores_per_epoch.append(avg_score)
        avg_iou_per_epoch.append(avg_iou)

        print("\nAverage test accuracy for epoch # %04d = %f" % (epoch, avg_score))
        print("Average per class test accuracies for epoch # %04d:" % (epoch))
        for class_name, item in zip(class_names_list, class_avg_scores):
            print("%s = %f" % (class_name, item))
        print("Test IoU score = ", avg_iou)