In [None]:
import tensorflow as tf 
import os, sys
slim = tf.contrib.slim

In [None]:
sys.path.append("/home/shou/network/tf-models/research/slim")
sys.path.append("/home/shou/network/fcn/tf-image-segmentation")
from tf_image_segmentation.utils.augmentation import (distort_randomly_image_color,
                                                      flip_randomly_left_right_image_with_annotation,
                                                      scale_randomly_image_with_annotation_with_fixed_size_output)
from tf_image_segmentation.utils.pascal_voc import pascal_segmentation_lut

from tf_image_segmentation.utils.training import get_valid_logits_and_labels

In [None]:
def u_net(image, num_classes=2, is_training=True):
    """Build a U-Net encoder/decoder and return per-pixel class predictions.

    Encoder: four stages of (two 3x3 convs -> 2x2 stride-2 max-pool ->
    batch-norm), then a two-conv bottleneck. Decoder: four stages of
    (2x2 stride-2 transposed conv -> concat with the encoder map of the same
    resolution -> two 3x3 convs -> batch-norm), followed by a linear 1x1
    score layer. The score map therefore has the input's height and width.

    Parameters
    ----------
    image : float tensor
        Input batch in NHWC layout (presumably (batch, height, width, 3) --
        TODO confirm). Height and width must be divisible by 16, otherwise
        the upsampled decoder maps and the encoder skip tensors differ in size
        and ``tf.concat`` fails.
    num_classes : int, optional
        Number of channels in the score map (default 2). It must cover every
        label value later passed to ``sparse_softmax_cross_entropy``.
    is_training : bool or bool tensor, optional
        Forwarded to every ``slim.batch_norm``: True normalises with batch
        statistics, False uses the stored moving averages (inference).

    Returns
    -------
    annotation_pred : int64 tensor, shape (batch, height, width)
        Arg-max of ``logits`` over the class axis.
    logits : float tensor, shape (batch, height, width, num_classes)
        Unnormalised per-pixel class scores.
    """
    bn_kwargs = {'decay': 0.9, 'epsilon': 1e-5, 'is_training': is_training}
    with tf.variable_scope("u_net", reuse=None):
        # stride=1 is for the regular convolutions only. Pooling and
        # transposed convolutions need stride 2: with stride 1 a SAME-padded
        # 2x2 max-pool does not downsample and the transposed convs do not
        # upsample, so the network degenerates into a full-resolution stack.
        with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'), \
                slim.arg_scope([slim.max_pool2d, slim.avg_pool2d, slim.conv2d_transpose],
                               stride=2, padding='SAME'):
            # Encoder (slim.conv2d defaults to a ReLU activation).
            conv0 = slim.repeat(image, 2, slim.conv2d, 32, [3, 3], scope='conv0')  # 1/1
            pool0 = slim.max_pool2d(conv0, [2, 2], scope='pool0')                 # 1/2
            bn0 = slim.batch_norm(pool0, scope='bn0', **bn_kwargs)

            conv1 = slim.repeat(bn0, 2, slim.conv2d, 64, [3, 3], scope='conv1')   # 1/2
            pool1 = slim.max_pool2d(conv1, [2, 2], scope='pool1')                 # 1/4
            bn1 = slim.batch_norm(pool1, scope='bn1', **bn_kwargs)

            conv2 = slim.repeat(bn1, 2, slim.conv2d, 128, [3, 3], scope='conv2')  # 1/4
            pool2 = slim.max_pool2d(conv2, [2, 2], scope='pool2')                 # 1/8
            bn2 = slim.batch_norm(pool2, scope='bn2', **bn_kwargs)

            conv3 = slim.repeat(bn2, 2, slim.conv2d, 256, [3, 3], scope='conv3')  # 1/8
            pool3 = slim.max_pool2d(conv3, [2, 2], scope='pool3')                 # 1/16
            bn3 = slim.batch_norm(pool3, scope='bn3', **bn_kwargs)

            # Bottleneck. No fifth pooling: the decoder has exactly four
            # upsampling stages, one per skip connection (conv3 .. conv0).
            conv4 = slim.repeat(bn3, 2, slim.conv2d, 512, [3, 3], scope='conv4')  # 1/16
            bn4 = slim.batch_norm(conv4, scope='bn4', **bn_kwargs)

            # Decoder: each transposed conv doubles H and W, then the encoder
            # map at that resolution is concatenated on the channel axis.
            conv_t1 = slim.conv2d_transpose(bn4, 256, [2, 2], scope='conv_t1')  # 1/8, + conv3
            merge1 = tf.concat([conv_t1, conv3], 3)
            conv5 = slim.stack(merge1, slim.conv2d, [(512, [3, 3]), (256, [3, 3])], scope='conv5')
            bn5 = slim.batch_norm(conv5, scope='bn5', **bn_kwargs)

            conv_t2 = slim.conv2d_transpose(bn5, 128, [2, 2], scope='conv_t2')  # 1/4, + conv2
            merge2 = tf.concat([conv_t2, conv2], 3)
            conv6 = slim.stack(merge2, slim.conv2d, [(256, [3, 3]), (128, [3, 3])], scope='conv6')
            bn6 = slim.batch_norm(conv6, scope='bn6', **bn_kwargs)

            conv_t3 = slim.conv2d_transpose(bn6, 64, [2, 2], scope='conv_t3')   # 1/2, + conv1
            merge3 = tf.concat([conv_t3, conv1], 3)
            conv7 = slim.stack(merge3, slim.conv2d, [(128, [3, 3]), (64, [3, 3])], scope='conv7')
            bn7 = slim.batch_norm(conv7, scope='bn7', **bn_kwargs)

            conv_t4 = slim.conv2d_transpose(bn7, 32, [2, 2], scope='convt4')    # 1/1, + conv0
            merge4 = tf.concat([conv_t4, conv0], 3)
            conv8 = slim.stack(merge4, slim.conv2d, [(64, [3, 3]), (32, [3, 3])], scope='conv8')
            bn8 = slim.batch_norm(conv8, scope='bn8', **bn_kwargs)

            # 1x1 score map with one channel per class. Linear output (no
            # ReLU) so that logits can take negative values.
            conv9 = slim.conv2d(bn8, num_classes, [1, 1], activation_fn=None, scope='scoreMap')
            annotation_pred = tf.argmax(conv9, axis=3, name='prediction')
            return annotation_pred, conv9
            

In [None]:
def read_tfrecord_and_decode_into_image_annotation_pair_tensors(tf_filenames_queue):
    """Read one serialized example from a TFRecord filename queue and decode it.

    Each record is parsed for the int64 features ``height`` and ``width`` and
    the byte-string features ``image_raw`` and ``mask_raw``. Both byte strings
    are decoded as uint8 and reshaped using the stored height/width.

    Parameters
    ----------
    tf_filenames_queue : queue of TFRecord filenames
        Typically created with ``tf.train.string_input_producer()`` (which
        also controls the number of epochs).

    Returns
    -------
    image, annotation : tuple of tf.uint8 tensors
        Image of shape (height, width, 3) and annotation mask of shape
        (height, width, 1).
    """
    feature_spec = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'mask_raw': tf.FixedLenFeature([], tf.string),
    }
    _, record = tf.TFRecordReader().read(tf_filenames_queue)
    parsed = tf.parse_single_example(record, features=feature_spec)

    rows = tf.cast(parsed['height'], tf.int32)
    cols = tf.cast(parsed['width'], tf.int32)

    def _decode_plane(key, channels):
        # Raw bytes -> flat uint8 vector -> (rows, cols, channels).
        flat = tf.decode_raw(parsed[key], tf.uint8)
        return tf.reshape(flat, tf.stack([rows, cols, channels]))

    return _decode_plane('image_raw', 3), _decode_plane('mask_raw', 1)

In [None]:
# Training configuration.
image_train_size = [384, 384]
number_of_classes = 2
pascal_voc_lut = pascal_segmentation_lut(number_of_classes)
class_labels = pascal_voc_lut.keys()

# NOTE(review): absolute, user-specific path -- consider making it configurable.
tfrecord_filename = '/home/shou/network/dataset/pascal_augmented_train_island.tfrecords'
# The filename queue stops producing after 5 passes over the file.
filename_queue = tf.train.string_input_producer([tfrecord_filename], num_epochs=5)

# Decode one (image, mask) pair, then apply the same random flip and the same
# random rescale to both so they stay pixel-aligned.
image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)
image, annotation = flip_randomly_left_right_image_with_annotation(image, annotation)
resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(
    image, annotation, image_train_size)
# Drop size-1 axes of the mask (presumably (H, W, 1) -> (H, W); TODO confirm).
resized_annotation = tf.squeeze(resized_annotation)

image_batch, annotation_batch = tf.train.shuffle_batch(
    tensors=[resized_image, resized_annotation],
    batch_size=1,
    num_threads=2,
    capacity=3000,
    min_after_dequeue=1000,
)


In [None]:
# Batch axis left as None so the placeholders accept whatever
# tf.train.shuffle_batch produces (batch_size=1 in the input cell). A
# hard-coded 486 rejects every feed of a 1-image batch.
image_for_train = tf.placeholder(tf.float32, shape=[None] + image_train_size + [3],
                                 name="input_image")
# Per-pixel integer class ids, shaped like the squeezed annotation batch
# (batch, height, width). sparse_softmax_cross_entropy needs integer labels.
annotation_for_train = tf.placeholder(tf.int32, shape=[None] + image_train_size,
                                      name="annotation")

In [None]:
# Build the network on the image placeholder: first output is the per-pixel
# arg-max prediction, second is the raw score map used by the loss.
pred_annotation, logits = u_net(image=image_for_train)

In [None]:
# Labels must come from the same fed batch as the images (the placeholder).
# Reading `annotation_batch` here would dequeue a fresh, unrelated batch on
# every sess.run. Reshape to the logits' (batch, height, width) instead of an
# axis-less squeeze, which would also drop the batch axis when batch_size=1.
# Cast to int32 because sparse_softmax_cross_entropy requires integer labels.
annotation_labels = tf.reshape(tf.cast(annotation_for_train, tf.int32),
                               tf.shape(logits)[:-1])
cross_entropies = tf.losses.sparse_softmax_cross_entropy(logits=logits,
                                                         labels=annotation_labels)
# tf.losses.sparse_softmax_cross_entropy already reduces to a scalar mean.
cross_entropy_sum = cross_entropies

In [None]:
# slim.batch_norm puts its moving-mean/variance updates into the UPDATE_OPS
# collection instead of running them automatically. Make each training step
# depend on them, otherwise the moving averages used at inference never change.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.variable_scope("adam_vars"), tf.control_dependencies(update_ops):
    train_step = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(cross_entropies)

global_vars_init_op = tf.global_variables_initializer()
# Local variables include the epoch counter of string_input_producer(num_epochs=...).
local_vars_init_op = tf.local_variables_initializer()
combined_op = tf.group(global_vars_init_op, local_vars_init_op)

In [None]:
with tf.Session() as sess:
    # Initialise first: the filename queue's epoch counter is a local variable.
    sess.run(combined_op)
    coord = tf.train.Coordinator()
    # The queue runners fill the input queues. Dequeuing before they are
    # started blocks forever.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # NOTE(review): 486 looks like the number of training images
        # (5 epochs at batch_size=1) -- confirm.
        for i in range(5 * 486):
            # Fetch a fresh batch each step so training sees the whole dataset
            # and the fed images and labels come from the same dequeue.
            image_batch_float, image_annotation_float = sess.run([image_batch, annotation_batch])
            feed_dict1 = {image_for_train: image_batch_float,
                          annotation_for_train: image_annotation_float}
            cross_entropy, _ = sess.run([cross_entropy_sum, train_step], feed_dict=feed_dict1)
            print(str(i) + " Current loss: " + str(cross_entropy))
    except tf.errors.OutOfRangeError:
        # Raised once string_input_producer has exhausted num_epochs.
        print("Input queue exhausted; stopping training.")
    finally:
        coord.request_stop()
        coord.join(threads)