In [None]:
# encoding: utf-8

import numpy as np
import tensorflow as tf
import sklearn.metrics as skmt
import matplotlib
# matplotlib.use('agg') # so that plt works in command line
import matplotlib.pyplot as plt
import scipy.io as sio
import skimage.io
import h5py
import sys
import os
import gc
import os
import psutil

from optparse import OptionParser

sys.path.append('../Metric/')
sys.path.append('../../Visualization/')
sys.path.append('../../Data_Preprocessing/')
from Metric import *
from Visualization import *
from Data_Extractor import *
from Bilinear_Kernel import *

In [None]:
# ---------------------------------------------------------------------------
# Run configuration.
#
# The options mirror the command-line interface of the training script; inside
# the notebook they are parsed from an explicit argument list (see below)
# instead of sys.argv.  Every option is then copied into a module-level name
# that the following cells read.
# ---------------------------------------------------------------------------
parser = OptionParser()
parser.add_option("--save", dest="save_path")
parser.add_option("--name", dest="model_name")

parser.add_option("--train", dest="path_train_set", default="../../Data/090085/Road_Data/motor_trunk_pri_sec_tert_uncl_track/posneg_seg_coord_split_128_18_train")
parser.add_option("--cv", dest="path_cv_set", default="../../Data/090085/Road_Data/motor_trunk_pri_sec_tert_uncl_track/posneg_seg_coord_split_128_18_cv")

parser.add_option("--norm", default="mean", dest="norm")
parser.add_option("--pos", type="int", default=0, dest="pos_num")
parser.add_option("--size", type="int", default=128, dest="size")
parser.add_option("-e", "--epoch", type="int", default=15, dest="epoch")
parser.add_option("--learning_rate", type="float", default=9e-6, dest="learning_rate")
parser.add_option("--batch", type="int", default=2, dest="batch_size")
parser.add_option("--rand", type="int", default=0, dest="rand_seed")

parser.add_option("--conv", dest="conv_struct")
parser.add_option("--concat_input", dest="concat_input")
parser.add_option("--not_weight", action="store_false", default=True, dest="use_weight")
parser.add_option("--use_batch_norm", action="store_true", default=False, dest="use_batch_norm")

parser.add_option("--gpu", dest="gpu")
parser.add_option("--gpu_max_mem", type="float", default=0.8, dest="gpu_max_mem")

# Notebook run: the argument list is given explicitly rather than read from sys.argv.
(options, args) = parser.parse_args(["--save", "-", 
                                     "--gpu", "0",
                                     "--conv", "16-32-64-128", 
                                     "--concat_input", "0"])

path_train_set = options.path_train_set
path_cv_set = options.path_cv_set
save_path = options.save_path
model_name = options.model_name

norm = options.norm
pos_num = options.pos_num
size = options.size
epoch = options.epoch
batch_size = options.batch_size
learning_rate = options.learning_rate
rand_seed = options.rand_seed

conv_struct = options.conv_struct

use_weight = options.use_weight
use_batch_norm = options.use_batch_norm
concat_input = options.concat_input

gpu = options.gpu
gpu_max_mem = options.gpu_max_mem

# restrict to single gpu
assert gpu in set(['0', '1']), "--gpu must be '0' or '1'"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu

# Any value starting with 'm' / 'G' is accepted and canonicalised.
if norm.startswith('m'): norm = 'mean'
elif norm.startswith('G'): norm = 'Gaussian'
else: 
    print("norm = ", norm, " not in ('mean', 'Gaussian')")
    sys.exit()

# Validate the mandatory options *before* they are used to build the default
# model name: the raw --conv string is concatenated into that name, so a
# missing value must be reported here rather than surface as a TypeError.
if not conv_struct:
    print("must provide structure for conv")
    sys.exit()
if not save_path:
    print("no save path provided")
    sys.exit()

if not model_name:
    # Default name encodes the configuration (raw option strings are used as given).
    model_name = "Unet_"
    model_name += conv_struct + "_"
    model_name += norm[0] + "_"
    if use_weight: model_name += "weight_"
    if use_batch_norm: model_name += "bn_"
    if concat_input: model_name += "concatI" + concat_input + "_"
    model_name += "p" + str(pos_num) + "_"
    model_name += "e" + str(epoch) + "_"
    model_name += "r" + str(rand_seed)

# Only trailing slashes are removed so that an absolute save path keeps its
# leading '/' and stays absolute.
save_path = save_path.rstrip('/') + '/' + model_name + '/'
# Creating the 'Analysis' sub-directory also creates save_path itself.
os.makedirs(save_path+'Analysis', exist_ok=True)

print("Train set:", path_train_set)
print("CV set:", path_cv_set)

print("will be saved as ", model_name)
print("will be saved into ", save_path)

# "a-b-c-d" -> [a, b, c, d]: output channels of conv1, conv2, conv3 and the bridge.
conv_struct = [int(width) for width in conv_struct.split('-')]
assert len(conv_struct) == 4, "--conv must list exactly 4 channel counts, e.g. 16-32-64-128"

# parse concat_input options (if not None): e.g. 3-16;5-8;1-32 
# => concat[ 3x3 out_channel=16, 5x5 out_channel=8, 1x1 out_channel=32] followed by 1x1 conv out_channel = classoutput
# concat_input = 0 => concat the raw input before the calculation of logits
if concat_input:
    concat_input = [[int(field) for field in branch.split('-')] for branch in concat_input.split(';')]

# monitor mem usage
process = psutil.Process(os.getpid())
print('mem usage before data loaded:', process.memory_info().rss / 1024/1024, 'MB')
print()



In [None]:
''' Data preparation '''


def _load_split(path):
    """Read one HDF5 data split completely into memory.

    The file is opened through a context manager so the handle is released
    even when one of the expected datasets is missing.

    Args:
        path: path of the HDF5 file to read.

    Returns:
        Tuple ``(positive_example, negative_example, raw_image, road_mask)``
        of numpy arrays copied out of the file.
    """
    with h5py.File(path, 'r') as h5_split:
        return (np.array(h5_split['positive_example']),
                np.array(h5_split['negative_example']),
                np.array(h5_split['raw_image']),
                np.array(h5_split['road_mask']))


# set random seed
np.random.seed(rand_seed)

# Load training set
(train_pos_topleft_coord, train_neg_topleft_coord,
 train_raw_image, train_road_mask) = _load_split(path_train_set)

# Load cross-validation set
(CV_pos_topleft_coord, CV_neg_topleft_coord,
 CV_raw_image, CV_road_mask) = _load_split(path_cv_set)

# Patch extractors over the full images; `size` is the patch edge length and
# `norm` the normalization mode selected in the configuration cell.
Train_Data = FCN_Data_Extractor (train_raw_image, train_road_mask, size,
                                 pos_topleft_coord = train_pos_topleft_coord,
                                 neg_topleft_coord = train_neg_topleft_coord,
                                 normalization = norm)

CV_Data = FCN_Data_Extractor (CV_raw_image, CV_road_mask, size,
                              pos_topleft_coord = CV_pos_topleft_coord,
                              neg_topleft_coord = CV_neg_topleft_coord,
                              normalization = norm)
# run garbage collector
gc.collect()

# monitor mem usage
process = psutil.Process(os.getpid())
print('mem usage after data loaded:', process.memory_info().rss / 1024/1024, 'MB')
print()

In [None]:
''' Create model '''



# Builds the TF1 graph of a 3-level U-Net style network:
#   input -> conv1/pool1 -> conv2/pool2 -> conv3/pool3 -> bridge
#         -> conv3_T (+conv3) -> conv2_T (+conv2) -> conv1_T (+conv1 [+input_map]) -> logits
# and defines the weighted softmax cross-entropy loss plus the Adam train step.
# Names created here (x, y, weight, is_training, logits, prob_out,
# cross_entropy, train_step, iteration) are read by the following cells.

# general model parameter
band = 7  # channel count of the input placeholder `x`

class_output = 2 # number of possible classifications for the problem
if use_weight:
    # Only printed in this cell; the per-pixel weights fed to the loss come from the data extractor.
    # NOTE(review): the list is built as [pos fraction, neg fraction] but labelled '[neg, pos]'
    # in the print below -- confirm which reading is intended.
    class_weight = [Train_Data.pos_size/Train_Data.size, Train_Data.neg_size/Train_Data.size]
    print(class_weight, '[neg, pos]')

# number of training batches drawn per epoch
iteration = int(Train_Data.size/batch_size) + 1

tf.reset_default_graph()
with tf.variable_scope('input'):
    x = tf.placeholder(tf.float32, shape=[None, size, size, band], name='x')
    # presumably one-hot labels over `class_output` classes -- TODO confirm against the data extractor
    y = tf.placeholder(tf.float32, shape=[None, size, size, class_output], name='y')
    weight      = tf.placeholder(tf.float32, shape=[None, size, size], name='weight')
    # NOTE(review): `is_training` is fed by later cells but no op in this cell consumes it,
    # and `use_batch_norm` is not used here either.
    is_training = tf.placeholder(tf.bool, name='is_training') # batch norm

with tf.variable_scope('input_bridge'):
    # `input_map` is only defined when concat_input is set; it is concatenated
    # to the last feature map right before the logits (see 'concat_input' below).
    if concat_input:
        if concat_input == [[0]]:
            # "--concat_input 0": use the raw input as-is
            input_map = x
        else:
            # one conv branch per [kernel_size, out_channels] pair, concatenated along channels
            input_map = tf.concat([tf.contrib.layers.conv2d(inputs=x, num_outputs=cfg[1], kernel_size=cfg[0], 
                                                            stride=1, padding='SAME', scope=str(cfg[0])+'-'+str(cfg[1])) 
                                   for cfg in concat_input], 
                                  axis=-1)

        

# Each level: 3x3 SAME conv (kept for the skip connection) followed by 2x2 stride-2 max pooling.
# NOTE(review): with three stride-2 poolings and three stride-2 transposed convs, `size`
# presumably has to be a multiple of 8 for the skip concatenations to line up -- confirm.
with tf.variable_scope('down_sampling'):
    # Convolutional Layer 1
    net = tf.contrib.layers.conv2d(inputs=x, num_outputs=conv_struct[0], kernel_size=3, 
                                   stride=1, padding='SAME', scope='conv1')

    conv1 = net
    net = tf.contrib.layers.max_pool2d(inputs=net, kernel_size=2, stride=2, padding='VALID', scope='pool1')
    
    # Convolutional Layer 2
    net = tf.contrib.layers.conv2d(inputs=net, num_outputs=conv_struct[1], kernel_size=3, 
                                   stride=1, padding='SAME', scope='conv2')
    conv2 = net
    net = tf.contrib.layers.max_pool2d(inputs=net, kernel_size=2, stride=2, padding='VALID', scope='pool2')
    
    # Convolutional Layer 3
    net = tf.contrib.layers.conv2d(inputs=net, num_outputs=conv_struct[2], kernel_size=3, 
                                   stride=1, padding='SAME', scope='conv3')
    conv3 = net
    net = tf.contrib.layers.max_pool2d(inputs=net, kernel_size=2, stride=2, padding='VALID', scope='pool3')


# bottleneck between the down- and up-sampling paths
net = tf.contrib.layers.conv2d(inputs=net, num_outputs=conv_struct[3], kernel_size=3, 
                               stride=1, padding='SAME', scope='bridge')


# Each step: stride-2 transposed conv, then channel-wise concat with the matching
# down-sampling feature map. The transposed-conv kernels are initialised from
# get_bilinear_upsample_weights (imported helper; presumably bilinear interpolation weights).
with tf.variable_scope('up_sampling'):
    kernel_size = get_kernel_size(2)
    net = tf.contrib.layers.conv2d_transpose(inputs=net, num_outputs=conv_struct[2], kernel_size=kernel_size, stride=2, 
                                             weights_initializer=tf.constant_initializer(get_bilinear_upsample_weights(2, conv_struct[3], conv_struct[2])), 
                                             scope='conv3_T')
    with tf.variable_scope('concat3'):
        net = tf.concat([net, conv3], axis=-1)

    net = tf.contrib.layers.conv2d_transpose(inputs=net, num_outputs=conv_struct[1], kernel_size=kernel_size, stride=2, 
                                             weights_initializer=tf.constant_initializer(get_bilinear_upsample_weights(2, conv_struct[2], conv_struct[1])), 
                                             scope='conv2_T')
    with tf.variable_scope('concat2'):
        net = tf.concat([net, conv2], axis=-1)
    
    net = tf.contrib.layers.conv2d_transpose(inputs=net, num_outputs=conv_struct[0], kernel_size=kernel_size, stride=2, 
                                             weights_initializer=tf.constant_initializer(get_bilinear_upsample_weights(2, conv_struct[1], conv_struct[0])), 
                                             scope='conv1_T')

    with tf.variable_scope('concat1'):
        net = tf.concat([net, conv1], axis=-1)

        # optional extra skip connection from the input (raw or convolved, see 'input_bridge')
        if concat_input:
            with tf.variable_scope('concat_input'):
                net = tf.concat([net, input_map], axis=-1)

# per-pixel class scores
logits = tf.contrib.layers.conv2d(inputs=net, num_outputs=class_output, kernel_size=3, stride=1, padding='SAME', scope='logits')

with tf.variable_scope('prob_out'):
    # per-pixel class probabilities (softmax over the last axis)
    prob_out = tf.nn.softmax(logits, name='prob_out')

with tf.variable_scope('cross_entropy'):
    # flatten to one row per pixel so that `weight` supplies one loss weight per pixel
    flat_logits = tf.reshape(logits, (-1, class_output), name='flat_logits')
    flat_labels = tf.to_float(tf.reshape(y, (-1, class_output)), name='flat_labels')
    flat_weight = tf.reshape(weight, [-1], name='flat_weight')

    cross_entropy = tf.losses.softmax_cross_entropy(flat_labels, flat_logits, weights=flat_weight)

# Ensures that we execute the update_ops before performing the train_step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

# monitor mem usage
process = psutil.Process(os.getpid())
print('mem usage after model created:', process.memory_info().rss / 1024/1024, 'MB')
print()
sys.stdout.flush()

In [None]:
# Render the graph built in the previous cell inline.
# NOTE(review): `show_graph` is defined in the last cell of this notebook; unless
# one of the star-imported project modules also provides it, this cell fails on a
# fresh "Restart & Run All" -- confirm, or run the defining cell first.
default_graph_def = tf.get_default_graph().as_graph_def()
show_graph(default_graph_def)

In [None]:
''' Train & monitor '''

# Trains the network for `epoch` epochs of `iteration` batches each.  After
# every epoch the whole CV set is scored and balanced accuracy, ROC AUC,
# average precision and mean cross-entropy are recorded.  `sess` and the
# last-epoch `cv_metric` are reused by the following cells.

saver = tf.train.Saver()

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = gpu_max_mem
# InteractiveSession installs itself as the default session, which is what
# `train_step.run(...)` here and `logits.eval(...)` in later cells rely on.
sess = tf.InteractiveSession(config=config)
sess.run(tf.global_variables_initializer())

balanced_acc_curve = []
AUC_curve = []
avg_precision_curve = []
cross_entropy_curve = []
for epoch_num in range(epoch):
    for iter_num in range(iteration):

        batch_x, batch_y, batch_w = Train_Data.get_patches(batch_size=batch_size, positive_num=pos_num, norm=True, weighted=use_weight)
        # move axis 1 to the end to match placeholder `x`
        # (presumably channels-first -> channels-last; confirm against the extractor)
        batch_x = batch_x.transpose((0, 2, 3, 1))

        train_step.run(feed_dict={x: batch_x, y: batch_y, weight: batch_w, is_training: True})

    # snap shot on CV set
    cv_metric = Metric_Record()
    cv_cross_entropy_list = []
    for batch_x, batch_y, batch_w in CV_Data.iterate_data(norm=True, weighted=use_weight):
        batch_x = batch_x.transpose((0, 2, 3, 1))

        # Fetch the softmax output, not the raw logits: the values are
        # thresholded at 0.5 and used as scores below, which is only
        # meaningful for probabilities (same as in the evaluation cell).
        pred_prob, cross_entropy_cost = sess.run([prob_out, cross_entropy], feed_dict={x: batch_x, y: batch_y, weight: batch_w, is_training: False})

        # column 1 of the flattened (pixel, class) arrays is the second class
        true_pos_col = batch_y.reshape(-1, class_output)[:, 1]
        pred_pos_col = pred_prob.reshape(-1, class_output)[:, 1]
        cv_metric.accumulate(Y         = np.array(true_pos_col > 0.5, dtype=int),
                             pred      = np.array(pred_pos_col > 0.5, dtype=int),
                             pred_prob = pred_pos_col)
        cv_cross_entropy_list.append(cross_entropy_cost)


    # calculate value
    cv_true_flat = np.array(cv_metric.y_true).flatten()
    cv_prob_flat = np.array(cv_metric.pred_prob).flatten()

    balanced_acc = cv_metric.get_balanced_acc()
    AUC_score = skmt.roc_auc_score(cv_true_flat, cv_prob_flat)
    avg_precision_score = skmt.average_precision_score(cv_true_flat, cv_prob_flat)
    mean_cross_entropy = sum(cv_cross_entropy_list)/len(cv_cross_entropy_list)

    balanced_acc_curve.append(balanced_acc)
    AUC_curve.append(AUC_score)
    avg_precision_curve.append(avg_precision_score)
    cross_entropy_curve.append(mean_cross_entropy)

    print("mean_cross_entropy = ", mean_cross_entropy, "balanced_acc = ", balanced_acc, "AUC = ", AUC_score, "avg_precision = ", avg_precision_score)
    sys.stdout.flush()
print("finish")

# monitor mem usage
process = psutil.Process(os.getpid())
print('mem usage after model trained:', process.memory_info().rss / 1024/1024, 'MB')
print()

# plot training curve
# (plt.figure(figsize=...) actually sizes the figure; assigning plt.figsize has no effect)
plt.figure(figsize=(9, 5))
plt.plot(balanced_acc_curve, label='balanced_acc')
plt.plot(AUC_curve, label='AUC')
plt.plot(avg_precision_curve, label='avg_precision')
plt.legend()
plt.title('learning_curve_on_cross_validation')
# plt.savefig(save_path+'Analysis/'+'cv_learning_curve.png', bbox_inches='tight')
plt.show()
plt.close()

plt.figure(figsize=(9, 5))
plt.plot(cross_entropy_curve)
# the title must be set before show(), otherwise it is never rendered
plt.title('cv_cross_entropy_curve')
# plt.savefig(save_path+'Analysis/'+'cv_cross_entropy_curve.png', bbox_inches='tight')
plt.show()
plt.close()

# save model
saver.save(sess, save_path + model_name)

# run garbage collection
gc.collect()

In [None]:
''' Evaluate model '''

# Scores the trained network on the training set and reports the CV metrics
# accumulated during the last training epoch (`cv_metric` from the training
# cell), then plots a ROC curve for each.

# train set eva
print("On training set: ")
train_metric = Metric_Record()
train_cross_entropy_list = []
# iterate the *training* extractor here (the CV set is reported further below)
for batch_x, batch_y, batch_w in Train_Data.iterate_data(norm=True, weighted=use_weight):
    # move axis 1 to the end to match placeholder `x`
    batch_x = batch_x.transpose((0, 2, 3, 1))

    pred_prob, cross_entropy_cost = sess.run([prob_out, cross_entropy], feed_dict={x: batch_x, y: batch_y, weight: batch_w, is_training: False})

    # column 1 of the flattened (pixel, class) arrays is the second class
    true_pos_col = batch_y.reshape(-1, class_output)[:, 1]
    pred_pos_col = pred_prob.reshape(-1, class_output)[:, 1]
    train_metric.accumulate(Y         = np.array(true_pos_col > 0.5, dtype=int),
                            pred      = np.array(pred_pos_col > 0.5, dtype=int),
                            pred_prob = pred_pos_col)
    train_cross_entropy_list.append(cross_entropy_cost)

train_metric.print_info()
train_true_flat = np.array(train_metric.y_true).flatten()
train_prob_flat = np.array(train_metric.pred_prob).flatten()

# computed from the training metrics so the printed value is not the stale
# CV value left over from the training loop
balanced_acc = train_metric.get_balanced_acc()
AUC_score = skmt.roc_auc_score(train_true_flat, train_prob_flat)
avg_precision_score = skmt.average_precision_score(train_true_flat, train_prob_flat)
mean_cross_entropy = sum(train_cross_entropy_list)/len(train_cross_entropy_list)
print("mean_cross_entropy = ", mean_cross_entropy, "balanced_acc = ", balanced_acc, "AUC = ", AUC_score, "avg_precision = ", avg_precision_score)

# plot ROC curve
# uses the same flattened label/score vectors as the AUC above (and as the CV ROC below)
fpr, tpr, thr = skmt.roc_curve(train_true_flat, train_prob_flat)
plt.plot(fpr, tpr)
# plt.savefig(save_path+'Analysis/'+'train_ROC_curve.png', bbox_inches='tight')
plt.show()
plt.close()

# cross validation eva
print("On CV set:")
cv_metric.print_info()

# plot ROC curve
fpr, tpr, thr = skmt.roc_curve(np.array(cv_metric.y_true).flatten(), np.array(cv_metric.pred_prob).flatten())
plt.plot(fpr, tpr)
# plt.savefig(save_path+'Analysis/'+'cv_ROC_curve.png', bbox_inches='tight')
plt.show()
plt.close()
sys.stdout.flush()

# run garbage collection
# (the metric records are released here, so re-running this cell requires
#  re-running the training cell first)
train_metric = 0
cv_metric = 0
gc.collect()

In [None]:
# Predict road mask


def _accumulate_patch_outputs(extractor, mask_shape):
    """Run the network over every patch of `extractor` and sum the outputs on a full-size canvas.

    Args:
        extractor: data extractor providing iterate_raw_image_patches_with_coord.
        mask_shape: shape of the road mask; the canvas gets this shape plus a
            trailing axis of length 2.

    Returns:
        Array of shape ``mask_shape + (2,)`` in which the network output of
        every patch has been added at the patch's position.
    """
    # NOTE(review): `logits` (not `prob_out`) is evaluated, so the stored values are
    # raw class scores even though the maps are described as "prob" -- confirm intent.
    canvas = np.zeros(list(mask_shape) + [2])
    for topleft, raw_patch in extractor.iterate_raw_image_patches_with_coord(norm=True):
        # move axis 1 to the end to match placeholder `x`
        nhwc_patch = raw_patch.transpose((0, 2, 3, 1))
        row, col = topleft[0], topleft[1]
        patch_out = logits.eval(feed_dict={x: nhwc_patch, is_training: False})[0]
        canvas[row:row+size, col:col+size, :] += patch_out
    return canvas


# Predict road prob masks on train
train_pred_road = _accumulate_patch_outputs(Train_Data, train_road_mask.shape)

# Predict road prob on CV
CV_pred_road = _accumulate_patch_outputs(CV_Data, CV_road_mask.shape)

# save prediction
prediction_name = model_name + '_pred.h5'
h5f_file = h5py.File(save_path + prediction_name, 'w')
for dataset_name, dataset_values in (('train_pred', train_pred_road),
                                     ('CV_pred', CV_pred_road)):
    h5f_file.create_dataset (name=dataset_name, data=dataset_values)
h5f_file.close()

# monitor mem usage
process = psutil.Process(os.getpid())
mem_mb = process.memory_info().rss / 1024/1024
print('mem usage after prediction maps calculated:', mem_mb, 'MB')
print()

In [None]:
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Return a copy of `graph_def` with oversized constant payloads replaced.

    Every node is copied into a fresh ``tf.GraphDef``.  For nodes whose op is
    'Const' and whose ``tensor_content`` is longer than `max_const_size`
    bytes, the content is replaced by the placeholder
    ``b"<stripped N bytes>"``.  The input graph_def is not modified.
    """
    slim_def = tf.GraphDef()
    for src_node in graph_def.node:
        dst_node = slim_def.node.add()
        dst_node.MergeFrom(src_node)
        if dst_node.op != 'Const':
            continue
        payload = dst_node.attr['value'].tensor
        n_bytes = len(payload.tensor_content)
        if n_bytes > max_const_size:
            payload.tensor_content = ("<stripped %d bytes>"%n_bytes).encode('utf-8')
    return slim_def

def show_graph(graph_def, max_const_size=32):
    """Display a TensorFlow graph inline in the notebook.

    Args:
        graph_def: a GraphDef, or any object exposing ``as_graph_def()``.
        max_const_size: forwarded to strip_consts; constants with more bytes
            than this are replaced by a placeholder before rendering.

    The stripped graph is embedded in an HTML snippet that loads the
    TensorBoard graph component, wrapped in an iframe and shown via
    IPython's ``display``.  Nothing is returned.
    """
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    slim_def = strip_consts(graph_def, max_const_size=max_const_size)

    page_template = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """
    graph_text = repr(str(slim_def))
    # Random element id; note this draws one number from numpy's global random state.
    element_id = 'graph'+str(np.random.rand())
    page_html = page_template.format(data=graph_text, id=element_id)

    iframe_template = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """
    # double quotes are escaped so the page can live inside the srcdoc attribute
    escaped_page = page_html.replace('"', '&quot;')
    display(HTML(iframe_template.format(escaped_page)))