In [None]:
from __future__ import division, print_function
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets
import tensorflow as tf
import math
from tensorflow.python.framework.ops import reset_default_graph
import datetime
import os
import glob
import data_utils
from random import randint

from sklearn.utils import shuffle

## Data loader and Batch generator

In [None]:
# ==== CAMVID ==== #
# The CamVid images are taken from the SegNet-Tutorial GitHub repository.
# The code expects the downloaded archive to unpack into "<repo_name>-master/".
dir_path = os.path.join('data', 'CamVid')
repo_name = 'SegNet-Tutorial'
images_path = os.path.join(dir_path, repo_name + '-master', 'CamVid')

# Download and unpack the archive only when the images are not on disk yet.
if not os.path.exists(images_path):
    url = "https://github.com/alexgkendall/" + repo_name + "/archive/master.zip"
    import requests, zipfile
    from io import BytesIO
    request = requests.get(url)
    # Fail with the HTTP error itself instead of handing an error page to
    # ZipFile, which would only report "File is not a zip file".
    request.raise_for_status()
    # The context manager closes the archive once extraction is done.
    with zipfile.ZipFile(BytesIO(request.content)) as archive:
        archive.extractall(dir_path)

# Dataset constants and loading.
NUM_CLASSES = 12  # 11 classes and unlabelled
# Indexed further down as (width, height, channels).
IMAGE_SHAPE = (480, 352, 3)
data = data_utils.load_data(images_path=images_path)

# Sanity check of every split: number of images, shape of the first image
# and mean value of the first image.
print("@@@Shape checking of data sets@@@")
for split_name in ("train", "valid", "test"):
    print(split_name.upper())
    split = getattr(data, split_name)
    print("%d\timages\t%s\t%f" % (len(split), split[0].shape, split[0].mean()))


# Batch generation: a throw-away generator with batch_size=1, used to look at
# the shape of one batch per split.
dummy_batch_gen = data_utils.batch_generator(
    data, batch_size=1, num_classes=NUM_CLASSES, num_iterations=5e3, seed=42)

# gen_train yields (images, labels); gen_valid / gen_test yield a leading
# extra element that is ignored here.
train_batch, y_train_batch = next(dummy_batch_gen.gen_train())
_, valid_batch, y_valid_batch = next(dummy_batch_gen.gen_valid())
_, test_batch, y_test_batch = next(dummy_batch_gen.gen_test())

for position, (label, batch) in enumerate(
        (("TRAIN", train_batch), ("VALID", valid_batch), ("TEST", test_batch))):
    if position:
        # blank line between two consecutive splits
        print()
    print(label)
    print("\timages,", batch.shape)

## Model definition

In [None]:
from tensorflow import layers
from tensorflow.contrib.layers import fully_connected, convolution2d, convolution2d_transpose, batch_norm, max_pool2d, dropout
from tensorflow.python.ops.nn import relu, elu, relu6, sigmoid, tanh, softmax, softplus, depthwise_conv2d, conv2d

# Start from an empty default graph so re-running this cell does not stack
# a second copy of the model on top of the first one.
reset_default_graph()

# -- THE MODEL --#
# IMAGE_SHAPE is stored as (width, height, channels); channels are RGB.
width, height, num_channels = IMAGE_SHAPE
num_classes = NUM_CLASSES

k = 12  # number of feature maps produced by every layer of a dense block
dropout_prob = 0.2  # drop probability; layers keep with 1 - dropout_prob
layers_architecture = [4, 4, 4, 4, 4]  # Number of layers in denseblocks
layers_bottleneck = 4  # number of layers in the bottleneck dense block

# Running counter used to give every separable-convolution filter variable
# a distinct name.
ident_layers = 0

# Layer definitions
def layer(x, units):
    """Apply one BN -> ReLU -> separable 3x3 convolution -> dropout unit.

    The separable convolution is a 3x3 depthwise filter (channel multiplier 1)
    followed by a 1x1 pointwise convolution mapping to ``units`` channels.
    Both run with stride 1 and 'SAME' padding.

    Args:
        x: input feature map; its last dimension is used as channel count.
        units: number of output channels of the pointwise convolution.

    Returns:
        The dropout output. Dropout is switched by the ``is_training_pl``
        placeholder and keeps activations with probability
        ``1 - dropout_prob``.
    """
    global ident_layers
    with tf.name_scope('layer_{}'.format(units)):
        activated = relu(batch_norm(x))
        in_channels = activated.shape[-1]

        # tf.get_variable needs unique names, hence the global counter suffix.
        depthwise_kernel = tf.get_variable(
            "depth_conv_w_{}".format(ident_layers), [3, 3, in_channels, 1])
        pointwise_kernel = tf.get_variable(
            "point_conv_w_{}".format(ident_layers), [1, 1, in_channels, units])
        ident_layers += 1

        unit_strides = [1, 1, 1, 1]
        spatial = depthwise_conv2d(activated, depthwise_kernel,
                                   padding='SAME', strides=unit_strides)
        mixed = conv2d(spatial, pointwise_kernel,
                       padding='SAME', strides=unit_strides)

        return dropout(mixed, is_training=is_training_pl, keep_prob=1-dropout_prob)
    
def dense_block(x, num_layers):
    """Build a dense block of ``num_layers`` layers with growth rate ``k``.

    Every layer receives the block input concatenated with the outputs of all
    previous layers of the block (channel axis).

    Args:
        x: input feature map.
        num_layers: number of layers in the block, must be at least 1.

    Returns:
        The outputs of all layers of the block concatenated along the last
        axis (``num_layers * k`` channels). The block input itself is NOT
        part of the result.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1. Previously this case
            failed with an UnboundLocalError on the result variable.
    """
    if num_layers < 1:
        raise ValueError(
            "dense_block needs at least one layer, got num_layers={}".format(num_layers))

    with tf.name_scope('dense_' + str(num_layers)):
        new_features = []
        for _ in range(num_layers):
            layer_output = layer(x, k)
            new_features.append(layer_output)
            # Next layer sees everything produced so far.
            x = tf.concat([x, layer_output], axis=-1)
        # One concat over all new feature maps instead of growing the result
        # pairwise inside the loop.
        return tf.concat(new_features, axis=-1)
    

def transition_up(x, units):
    """Upsample ``x`` with a 3x3 transposed convolution of stride 2.

    Args:
        x: input feature map.
        units: number of output channels.

    Returns:
        The output of ``convolution2d_transpose`` (all other arguments are
        left at the library defaults).
    """
    upsampled = convolution2d_transpose(
        x,
        num_outputs=units,
        stride=2,
        kernel_size=(3, 3),
    )
    return upsampled
    
    
def transition_down(x, units):
    """BN -> ReLU -> separable 3x3 convolution -> dropout -> 2x2 max pooling.

    Args:
        x: input feature map; its last dimension is used as channel count.
        units: number of output channels of the pointwise convolution.

    Returns:
        The max-pooled feature map with ``units`` channels.
    """
    global ident_layers
    with tf.name_scope('transition_down_{}'.format(units)):
        activated = relu(batch_norm(x))
        in_channels = activated.shape[-1]

        # Filter variables share the naming counter with `layer`, so names
        # stay unique across both kinds of units.
        depthwise_kernel = tf.get_variable(
            "depth_conv_w_{}".format(ident_layers), [3, 3, in_channels, 1])
        pointwise_kernel = tf.get_variable(
            "point_conv_w_{}".format(ident_layers), [1, 1, in_channels, units])
        ident_layers += 1

        unit_strides = [1, 1, 1, 1]
        spatial = depthwise_conv2d(activated, depthwise_kernel,
                                   padding='SAME', strides=unit_strides)
        mixed = conv2d(spatial, pointwise_kernel,
                       padding='SAME', strides=unit_strides)

        regularized = dropout(mixed, is_training=is_training_pl, keep_prob=1-dropout_prob)
        return max_pool2d(regularized, kernel_size=(2, 2))

# - Tiramisu Architecture - #
# Graph inputs: image batch, one-hot label maps and the train/eval switch.
# The batch dimension is left open (None).
x_pl = tf.placeholder(dtype=tf.float32,
                      shape=(None, height, width, num_channels),
                      name='x_pl')
y_pl = tf.placeholder(dtype=tf.float32,
                      shape=(None, height, width, num_classes),
                      name='y_pl')
is_training_pl = tf.placeholder(dtype=tf.bool, name="is-training_pl")

for _tag, _placeholder in (('x_pl', x_pl), ('y_pl', y_pl)):
    print(_tag, _placeholder.shape)

def upsample(x, skip, num_dense, skip_up=False):
    """Transition up, merge with the skip connection, then run a dense block.

    Args:
        x: feature map coming from the coarser level.
        skip: feature map stored on the way down, concatenated after the
            transition up.
        num_dense: number of layers of the dense block.
        skip_up: when true, the dense block input is concatenated to the
            dense block output; otherwise only the dense block output is
            returned.

    Returns:
        The resulting feature map (its shape is also printed).
    """
    # The transposed convolution keeps the channel count of its input.
    in_channels = x.shape[-1].value
    merged = tf.concat([transition_up(x, in_channels), skip], axis=-1)
    dense_out = dense_block(merged, num_dense)

    result = tf.concat([merged, dense_out], axis=-1) if skip_up else dense_out
    print('DB ({} layers) + TU'.format(num_dense), '\t', result.shape)
    return result

def downsample(x, num_dense):
    """Run a dense block, keep its output as skip connection, transition down.

    Args:
        x: input feature map.
        num_dense: number of layers of the dense block.

    Returns:
        A pair ``(down, skip)``: the transitioned-down feature map and the
        concatenation of ``x`` with the dense block output taken before the
        transition down.
    """
    dense_out = dense_block(x, num_dense)
    skip = tf.concat([x, dense_out], axis=-1)

    # The transition down is asked for as many channels as the input plus
    # k new feature maps per dense layer.
    out_channels = num_dense*k + x.shape[-1].value
    down = transition_down(skip, out_channels)

    print('DB ({} layers) + TD'.format(num_dense), '\t', down.shape)
    return down, skip

# Assemble the network: pre-convolution, down path (dense block + transition
# down per entry of layers_architecture), bottleneck, mirrored up path and a
# 1x1 convolution producing one score per class.
with tf.name_scope('tiramisu'):
    # DOWN SAMPLING
    # `.value` turns the Dimension into a plain int for num_outputs.
    x = convolution2d(x_pl, num_outputs=x_pl.shape[-1].value * k, kernel_size=(3, 3),
                      stride=1, scope="pre-convolution")
    print('pre_conv', '\t\t', x.shape)

    # skip[i] holds the feature map saved before the i-th transition down.
    skip = []
    for num_layers in layers_architecture:
        x, skipTmp = downsample(x, num_layers)
        skip.append(skipTmp)

    # BOTTLENECK
    x = dense_block(x, layers_bottleneck)
    bottleneck_ext = x  # exposed for inspection from the training loop
    print('Bottleneck ({} layers)'.format(layers_bottleneck), '\t', x.shape)

    # UPSAMPLING
    # Walk the down path backwards; only the last (finest) level keeps the
    # dense block input in its output.
    for index in reversed(range(len(layers_architecture))):
        x = upsample(x, skip[index], layers_architecture[index], skip_up=(index == 0))

    upsampl_ext = x  # exposed for inspection from the training loop

    # Output layers
    # contrib's convolution2d applies ReLU unless told otherwise. The class
    # scores that go into the softmax must stay unconstrained, so the
    # activation is switched off explicitly.
    x = convolution2d(x, num_outputs=num_classes, kernel_size=(1, 1),
                      stride=1, activation_fn=None, scope="post-convolution")

    post_conv = x  # raw per-class scores (logits)
    print('post-convolution', '\t', x.shape)
    y = softmax(x)
    print('SoftMax output', '\t\t', y.shape)

print("Model built")

## Number of parameters 

In [None]:
def num_params():
    """Count the scalar parameters held by all trainable variables.

    Returns:
        The sum, over ``tf.trainable_variables()``, of the product of each
        variable's dimensions (a variable without dimensions counts as 1).
    """
    def _size(variable):
        # get_shape() iterates over Dimension objects; `.value` is the int.
        count = 1
        for dim in variable.get_shape():
            count *= dim.value
        return count

    return sum(_size(variable) for variable in tf.trainable_variables())

# Report the number of trainable parameters of the graph built so far.
total_params = num_params()
print("Number of parameters\t", total_params)

In [None]:
with tf.variable_scope('loss'):
    # Per-pixel weight: 0 where the one-hot label selects the last class,
    # 1 everywhere else (the last class is presumably "unlabelled" - confirm
    # against data_utils), so those pixels do not contribute to the loss.
    weights = 1 - y_pl[:,:,:,-1]
    # softmax_cross_entropy applies the softmax itself and therefore needs the
    # raw class scores. Feeding the already softmaxed `y` applied the softmax
    # twice, which flattens the predicted distribution and the gradients.
    cross_entropy = tf.losses.softmax_cross_entropy(y_pl, post_conv, weights=weights)

    
with tf.variable_scope('training'):
    # Adam with a fixed learning rate; minimize() builds the gradient
    # computation and the parameter update in one op.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(cross_entropy)


with tf.variable_scope('performance'):
    # Most likely class per pixel, for the prediction and for the target.
    predicted_class = tf.argmax(y, axis=-1)
    target_class = tf.argmax(y_pl, axis=-1)

    # Boolean map: True where the predicted class equals the target class.
    correct_prediction = tf.equal(predicted_class, target_class)

    # Fraction of matching pixels (every pixel counts, whatever its class).
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Scalars picked up later by tf.summary.merge_all().
for _tag, _tensor in (('Evaluation/loss', cross_entropy),
                      ('Evaluation/accuracy', accuracy)):
    tf.summary.scalar(_tag, _tensor)

# Memory limitation: let the GPU allocation grow on demand.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

## Testing the forward pass

In [None]:
# Smoke test: push one training batch through the freshly initialised network
# and check that the prediction has the same shape as the labels.
x_batch, y_batch = next(dummy_batch_gen.gen_train())

with tf.Session(config=config) as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    forward_feed = {x_pl: x_batch, is_training_pl: True}
    y_pred = sess.run(y, feed_dict=forward_feed)

assert y_batch.shape == y_pred.shape, (
    "ERROR the output shape is not as expected! Output shape should be {} but was {}"
    .format(y_batch.shape, y_pred.shape))

print('Forward pass successful!')

## Launch TensorBoard

In [None]:
# setup and write summaries
# Every run logs into its own timestamped sub-directory of 'logs'.
timestr = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
logdir = os.path.join('logs', timestr)
summaries = tf.summary.merge_all()

import subprocess
# TensorBoard is pointed at the parent 'logs' directory so that all runs show
# up side by side. The handle is kept so the process can be stopped later with
# tensorboard_process.terminate().
try:
    tensorboard_process = subprocess.Popen(
        ["tensorboard", "--logdir=" + os.path.split(logdir)[0]])
except OSError as error:
    # TensorBoard is only a monitoring aid: a missing executable must not
    # stop the notebook before training.
    tensorboard_process = None
    print("Could not launch TensorBoard:", error)

## Training the model

In [None]:
#Training Loop
batch_size = 3
max_epochs = 100
seed = 42
LOG_FREQ = 10 # iterations between two training-summary writes
VALID_FREQ = 100 # iterations between two validation passes (and checkpoints)
VALIDATION_SIZE = 0.1 # 0.1 is ~ 100 samples for validation

# NOTE(review): max_epochs is passed as `num_iterations`, while the dummy
# generator above passes 5e3 for the same argument - confirm in data_utils
# whether it counts epochs or iterations.
batch_gen = data_utils.batch_generator(data, batch_size=batch_size, num_classes=num_classes,
                            num_iterations=max_epochs, seed=seed, val_size=VALIDATION_SIZE)

# One entry per validation step.
valid_loss, valid_accuracy = [], []
train_loss, train_accuracy = [], []

#To save the trained network
saver = tf.train.Saver()
checkpoint_dir = os.path.join(os.getcwd(), 'trained_nets')
checkpoint_path = os.path.join(
    checkpoint_dir,
    'trained_network_l' + ','.join(str(e) for e in layers_architecture) + '_b' + str(layers_bottleneck))
# Saver.save does not create the parent directory of the checkpoint.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

with tf.Session(config=config) as sess:
    # The graph is written to the parent log directory and to the per-run
    # train / valid directories.
    summary_writer = tf.summary.FileWriter(os.path.split(logdir)[0], graph=sess.graph)

    summarywriter_train = tf.summary.FileWriter(os.path.join(logdir, 'train'), sess.graph)
    summarywriter_valid = tf.summary.FileWriter(os.path.join(logdir, 'valid'), sess.graph)

    sess.run(tf.global_variables_initializer())
    print('Begin training loop')

    # NOTE(review): y, bottleneck_ext, upsampl_ext and post_conv are fetched on
    # every step but not used in this cell; kept in case later cells inspect
    # output_t / bot_out / up_out / pconv_out.
    fetches_train = [train_op, cross_entropy, accuracy, summaries, y, bottleneck_ext, upsampl_ext, post_conv]
    fetches_valid = [cross_entropy, accuracy, summaries, y]

    # Loss / accuracy of every training batch since the last validation step.
    # (They used to be reset at every iteration, so the reported training
    # figures were those of a single batch.)
    _train_loss, _train_accuracy = [], []

    try:
        for num, batch_train in enumerate(batch_gen.gen_train()):
            ## Run train op
            x_batch = batch_train[0]
            y_batch = batch_train[1]
            feed_dict_train = {x_pl: x_batch, y_pl: y_batch, is_training_pl: True}
            _, _loss, _acc, sum_train, output_t, bot_out, up_out, pconv_out = sess.run(fetches_train, feed_dict_train)

            _train_loss.append(_loss)
            _train_accuracy.append(_acc)

            if num % LOG_FREQ == 0:
                summarywriter_train.add_summary(sum_train, num) # save the train summary
                print("seen", num*batch_size)

            ## Compute validation loss and accuracy
            if num % VALID_FREQ == 0 \
                    and num >= batch_size:
                ## Save the network at each validation step for backup
                saver.save(sess, checkpoint_path)

                cur_acc = 0
                cur_loss = 0
                tot_num = 0
                iou_v = np.zeros((num_classes))
                # batch validation
                num_batch = len(batch_gen._idcs_valid)//batch_size
                # randint includes both bounds, so the drawn index may lie past
                # the last batch; the fallback after the loop covers that case.
                example_index = randint(0, num_batch)
                example_image, example_labels = None, None
                num_valid_batches = 0
                for i, (numval, x_valid, y_valid) in enumerate(batch_gen.gen_valid()):
                    feed_dict_valid = {x_pl: x_valid, y_pl: y_valid, is_training_pl: False}
                    _loss, _acc, sum_valid, output_v = sess.run(fetches_valid, feed_dict_valid)
                    if i == 0:
                        summarywriter_valid.add_summary(sum_valid, num) # save the valid summary

                    iou_v = iou_v + data_utils.compute_iou(output_v, y_valid, num_classes, width, height)
                    # Weight by the number of samples reported for the batch.
                    cur_acc += _acc*numval
                    cur_loss += _loss*numval
                    tot_num += numval
                    num_valid_batches += 1

                    if example_index == i:
                        example_image = output_v
                        example_labels = y_valid

                if num_valid_batches == 0:
                    # Nothing was evaluated: keep accumulating training
                    # statistics and try again at the next validation step.
                    print('No validation batch available, validation skipped')
                    continue

                if example_image is None:
                    # The drawn index was never reached: show the last batch.
                    example_image, example_labels = output_v, y_valid

                train_loss.append(np.mean(_train_loss))
                train_accuracy.append(np.mean(_train_accuracy))
                _train_loss, _train_accuracy = [], []

                data_utils.print_image(example_image, width, height, num_classes, example_labels)

                # Mean over the batches that were actually evaluated.
                iou_v = iou_v/num_valid_batches
                valid_loss.append(cur_loss / float(tot_num))
                valid_accuracy.append(cur_acc / float(tot_num))
                print('IoU\'s :')
                print(iou_v)
                print("Training examples {} : Train Loss {:6.3f}, Train acc {:6.3f},  Valid loss {:6.3f},  Valid acc {:6.3f}, Valid mean IoU {:6.3f}".format(
                    num*batch_size, train_loss[-1], train_accuracy[-1], valid_loss[-1], valid_accuracy[-1], np.mean(iou_v)))

                #Add mean IoU to tensorboard
                summary = tf.Summary(value=[tf.Summary.Value(tag='IoU', simple_value=np.mean(iou_v))])
                summarywriter_valid.add_summary(summary, num)

        saver.save(sess, checkpoint_path)
        print('Network Saved !')
    except KeyboardInterrupt:
        # Interrupting the kernel is the intended way to stop training early;
        # the last checkpoint written at a validation step stays on disk.
        pass
    finally:
        # Flush pending events and release the event files.
        for writer in (summary_writer, summarywriter_train, summarywriter_valid):
            writer.close()

## Testing the model

In [None]:
# Evaluate the saved network on the test split: sample-weighted loss and
# accuracy, and per-class IoU averaged over the test batches.
with tf.Session(config=config) as sess:
    try:
        #Restoring network
        network_to_restore = os.path.join(os.getcwd(), 'trained_nets', 'trained_network_l'+','.join(str(e) for e in layers_architecture)+'_b'+str(layers_bottleneck))
        saver.restore(sess, network_to_restore)
        print('Network Restored !')
        cur_acc = 0
        cur_loss = 0
        tot_num = 0
        iou_t = np.zeros((num_classes))
        num_test_batches = 0
        fetches_test = [cross_entropy, accuracy, y]
        # batch test
        for numtest, x_test, y_test in batch_gen.gen_test():
            feed_dict_test = {x_pl: x_test, y_pl: y_test, is_training_pl: False}
            t_loss, t_acc, output_v = sess.run(fetches_test, feed_dict_test)

            iou_t = iou_t + data_utils.compute_iou(output_v, y_test, num_classes, width, height)
            # Weight by the number of samples reported for the batch.
            cur_acc += t_acc*numtest
            cur_loss += t_loss*numtest
            tot_num += numtest
            num_test_batches += 1

        if num_test_batches == 0:
            # Without this guard the code below fails on the undefined
            # `output_v` and on a division by zero.
            print('No test batch available, nothing to evaluate')
        else:
            # Show the prediction for the last test batch.
            data_utils.print_image(output_v, width, height, num_classes, y_test)

            # Mean over the batches that were actually evaluated, instead of
            # the floor estimate len(indices) // batch_size, which is off by
            # one whenever the generator also yields a final partial batch.
            iou_t = iou_t/num_test_batches
            test_loss = (cur_loss / float(tot_num))
            test_accuracy = (cur_acc / float(tot_num))
            print("Testing : Test Loss {:6.3f}, Test acc {:6.3f}, Mean IoU {:6.3f}".format(test_loss, test_accuracy, np.mean(iou_t)))

    except KeyboardInterrupt:
        print('KeyboardInterrupt')