In [None]:
import numpy as np
import matplotlib.pylab as plt
%matplotlib inline
import tensorflow as tf
import datetime
from functools import partial
import scipy.io

In [None]:
# --- network architecture ------------------------------------------------
py_conv_scale = 3              # spatial size of every convolution kernel (py_conv_scale x py_conv_scale)
py_num_blocks = 4              # number of res_and_dist blocks stacked in the graph
py_num_conv_in_block = 2       # conv -> BN -> ReLU layers at the start of each block
py_input_image_channels = 1    # channels of the input image
py_hidden_channels = 128       # feature channels between blocks; split into quarters, keep divisible by 4

# --- targets / loss ------------------------------------------------------
py_num_classes = 9             # label regions; there are py_num_classes - 1 layer boundaries
py_loss_alpha = 1e-4           # weight of the layer-position loss added to the pixel loss

# --- optimisation --------------------------------------------------------
py_learning_rate = 0.01        # initial learning rate for exponential_decay
py_decay_rate = 0.95           # multiplicative decay factor ...
py_decay_step = 30             # ... applied every py_decay_step global steps
py_momentum = 0.1              # momentum for an optional MomentumOptimizer; unused with Adagrad


In [None]:
def nan_helper(y):
    """Locate NaNs in a 1-D array and return a mask-to-index converter.

    Parameters
    ----------
    y : 1-D numpy array that may contain NaNs.

    Returns
    -------
    nans : boolean numpy array, True where ``y`` is NaN.
    index : callable taking a boolean mask and returning the integer
        positions where the mask is True.

    Example (linear interpolation across the NaNs)::

        nans, x = nan_helper(y)
        y[nans] = np.interp(x(nans), x(~nans), y[~nans])
    """
    def _mask_to_positions(mask):
        # nonzero() returns one index array per dimension; y is 1-D.
        return mask.nonzero()[0]

    nan_mask = np.isnan(y)
    return nan_mask, _mask_to_positions

In [None]:
def get_random_training(data_dir="data/2015_BOE_Chiu", margin=20):
    """Load one randomly chosen annotated frame and build its training targets.

    A file ``<data_dir>/Subject_XX.mat`` (XX drawn uniformly from 1..10) is
    loaded. A random frame that has at least one column where every boundary
    is annotated is selected. The frame is cropped as follows:

    * horizontally to the span of fully annotated columns (inclusive);
    * vertically to the annotated depth range, plus ``margin`` rows on each
      side, clamped to the image.

    Boundary values <= 10, >= 1000 or infinite are treated as missing. Missing
    values inside the column span are linearly interpolated per boundary.

    Parameters
    ----------
    data_dir : directory holding the ``Subject_%02d.mat`` files.
    margin : extra rows kept above the topmost and below the bottommost boundary.

    Returns
    -------
    img : 2-D array (rows, cols), the cropped image.
    labels : float array (rows, cols, py_num_classes), one-hot region labels.
        Class 0 lies above boundary 0. Class i+1 lies between boundaries i and
        i+1. The last class lies below the last boundary.
    layers : int array (cols, n_boundaries), boundary row per column,
        relative to the crop.
    """
    training_case = scipy.io.loadmat("%s/Subject_%02d.mat" % (data_dir, np.random.randint(1, 11)))

    # Work on a private float copy: cleaning must not write back into the loaded data.
    # Layout assumed (boundary, column, frame) from how it is indexed here -- TODO confirm.
    manual = np.array(training_case['manualLayers1'], dtype=float)
    with np.errstate(invalid='ignore'):  # comparisons against NaN are expected
        manual[(manual <= 10) | (manual >= 1000) | np.isinf(manual)] = np.nan

    # A column is usable only if every boundary is annotated there.
    complete = ~np.isnan(manual).any(axis=0)           # (column, frame)
    annotated = np.flatnonzero(complete.any(axis=0))
    if annotated.size == 0:
        raise ValueError("no frame with a fully annotated column in %s" % data_dir)
    frame = np.random.choice(annotated)

    complete_cols = np.flatnonzero(complete[:, frame])
    x_min, x_max = complete_cols[0], complete_cols[-1] + 1  # half-open, keeps the last column

    layers = manual[:, x_min:x_max, frame].T.copy()     # (column, boundary)
    for li in range(layers.shape[1]):
        nans, idx = nan_helper(layers[:, li])
        if nans.any():
            layers[nans, li] = np.interp(idx(nans), idx(~nans), layers[~nans, li])
    layers = layers.astype(int)  # truncating cast, as before; np.int no longer exists

    images = training_case['images']
    # Clamp so a boundary near the top never yields a negative (wrap-around) slice start.
    z_min = max(int(layers.min()) - margin, 0)
    z_max = min(int(layers.max()) + margin, images.shape[0])
    layers -= z_min

    img = images[z_min:z_max, x_min:x_max, frame]
    labels = np.zeros(img.shape + (py_num_classes,))
    for x in range(img.shape[1]):
        labels[:layers[x, 0], x, 0] = 1
        labels[layers[x, -1]:, x, py_num_classes - 1] = 1
        for i in range(py_num_classes - 2):
            labels[layers[x, i]:layers[x, i + 1], x, i + 1] = 1

    return img, labels, layers

# Preview 10 random samples: the image with annotated boundaries, then one panel per class mask.
for e in range(10):
    img, labels, layers = get_random_training()
    fig, axes = plt.subplots(1, py_num_classes + 1, figsize=(20, 2), squeeze=False)
    panels = axes[0]
    for i, ax in enumerate(panels):
        if i > 0:
            ax.imshow(labels[:, :, i - 1])
        else:
            ax.imshow(img[:, :], cmap=plt.cm.gray)
            ax.plot(layers[:, :])
        ax.axis('off')

In [None]:
def res_and_dist(tf_data, tf_skip_data_in, block_name):
    """Build one convolution + directional-distance block.

    Inside ``tf.variable_scope(block_name)`` the block does four things:

    1. Applies ``py_num_conv_in_block`` conv -> batch-norm -> ReLU layers,
       keeping ``py_hidden_channels`` channels.
    2. Reduces to ``py_hidden_channels // 4`` channels with a "bottleneck"
       conv -> BN -> ReLU.
    3. Runs a max-with-learnable-decay scan (``tf.scan``) over the row axis
       of the bottleneck output in both directions: "up" starts at row 0,
       "down" starts at the last row.
    4. Concatenates [bottleneck, incoming skip, up, down] along channels.

    Args:
        tf_data: tensor [1, rows, cols, py_hidden_channels]. The batch axis
            must be 1 because it is squeezed away for ``tf.scan``.
        tf_skip_data_in: tensor [1, rows, cols, py_hidden_channels // 4].
        block_name: variable-scope name, e.g. "block00".

    Returns:
        (tf_data_out, tf_skip_data_out). ``tf_data_out`` has
        ``4 * (py_hidden_channels // 4)`` channels. ``tf_skip_data_out`` is
        the bottleneck output, meant as the next block's skip input.
    """
    # Integer division: shapes and slice bounds must be ints (floats under Python 3 break TF).
    quarter = py_hidden_channels // 4

    def conv_bn_relu(tf_in, out_channels):
        """SAME/stride-1 conv -> batch norm -> ReLU in the current variable scope."""
        tf_filter = tf.get_variable("filter",
                                    shape=[py_conv_scale, py_conv_scale, py_hidden_channels, out_channels],
                                    initializer=tf.contrib.layers.xavier_initializer_conv2d())
        tf_out = tf.nn.conv2d(tf_in, tf_filter, strides=[1, 1, 1, 1], padding="SAME")
        tf_out = tf.contrib.layers.batch_norm(tf_out,
                                              is_training=True,
                                              decay=0.999,
                                              center=True,
                                              scale=True,
                                              activation_fn=tf.nn.relu,
                                              updates_collections=None)
        return tf_out, tf_filter

    def distance_func(prev_dist, current, alpha):
        # max(current, prev - alpha * (1 - current)): the carried value decays by
        # alpha * (1 - current) per step unless the current activation is larger.
        return tf.maximum(current,
                          tf.subtract(prev_dist,
                                      tf.multiply(tf.subtract(1.0, current), alpha)))

    def directional_scan(tf_rows, reverse):
        """Scan distance_func over axis 0 of [rows, cols, ch]; returns [1, rows, cols, ch]."""
        tf_alpha = tf.Variable(1.0, trainable=True, name="alpha")
        if reverse:
            tf_rows = tf.reverse(tf_rows, [0])
        tf_init = tf.slice(tf_rows, [0, 0, 0], [1, -1, -1])
        tf_scanned = tf.scan(partial(distance_func, alpha=tf_alpha), tf_rows,
                             initializer=tf_init, back_prop=True)
        if reverse:
            tf_scanned = tf.reverse(tf_scanned, [0])
        tf_scanned = tf.squeeze(tf_scanned, axis=1)   # [rows, 1, cols, ch] -> [rows, cols, ch]
        return tf.expand_dims(tf_scanned, axis=0)

    with tf.variable_scope(block_name):
        for py_conv_in_block in range(py_num_conv_in_block):
            with tf.variable_scope("conv%d" % py_conv_in_block):
                tf_data, tf_filter = conv_bn_relu(tf_data, py_hidden_channels)
                if py_conv_in_block == 0:
                    # For blocks fed by a previous block, the input channels are that block's
                    # concat [direct, skip, up, down], one quarter each.
                    for k, part in enumerate(("direct", "skip", "up", "down")):
                        tf.summary.histogram("%s_%s" % (block_name, part),
                                             tf_filter[:, :, k * quarter:(k + 1) * quarter, :])

        with tf.variable_scope("bottleneck"):
            tf_skip_data_out, _ = conv_bn_relu(tf_data, quarter)

        with tf.variable_scope("dist"):
            tf_data_sq = tf.squeeze(tf_skip_data_out, axis=0)  # requires batch size 1
            with tf.variable_scope("up"):
                tf_data_up = directional_scan(tf_data_sq, reverse=False)
            with tf.variable_scope("down"):
                tf_data_down = directional_scan(tf_data_sq, reverse=True)

        with tf.variable_scope("combined"):
            tf_data = tf.concat([tf_skip_data_out, tf_skip_data_in, tf_data_up, tf_data_down], axis=3)

        return tf_data, tf_skip_data_out


In [None]:
tf_graph = tf.Graph()

with tf_graph.as_default():
    tf_global_step = tf.Variable(0, trainable=False, name="global_step")

    # Batch size is fixed to 1 (res_and_dist squeezes the batch axis).
    tf_image = tf.placeholder(tf.float32, shape=[1, None, None, py_input_image_channels], name="input_image")
    tf_labels = tf.placeholder(tf.int64, shape=[1, None, None, py_num_classes], name="input_labels")
    tf_gt_layers = tf.placeholder(tf.int64, shape=[1, None, py_num_classes - 1], name="input_layers")

    # Integer division keeps every shape / slice size an int (also under Python 3).
    py_quarter_channels = py_hidden_channels // 4

    with tf.variable_scope("input_conv"):
        # Outputs hidden + quarter channels: the first quarter becomes the initial
        # skip tensor, the remaining py_hidden_channels feed block 0.
        tf_filter = tf.get_variable("filter",
                                    shape=[py_conv_scale, py_conv_scale, py_input_image_channels,
                                           py_hidden_channels + py_quarter_channels],
                                    initializer=tf.contrib.layers.xavier_initializer_conv2d())

        tf_data_in = tf.nn.conv2d(tf_image,
                                  tf_filter,
                                  strides=[1, 1, 1, 1],
                                  padding="SAME")

        tf_data_in = tf.contrib.layers.batch_norm(tf_data_in,
                                                  is_training=True,
                                                  decay=0.999,
                                                  center=True,
                                                  scale=True,
                                                  activation_fn=tf.nn.relu,
                                                  updates_collections=None)

    tf_skip_data = tf.slice(tf_data_in, [0, 0, 0, 0], [1, -1, -1, py_quarter_channels])
    tf_data = tf.slice(tf_data_in, [0, 0, 0, py_quarter_channels], [1, -1, -1, py_hidden_channels])

    for j in range(py_num_blocks):
        tf_data, tf_skip_data = res_and_dist(tf_data, tf_skip_data, "block%02d" % j)

    with tf.variable_scope("output_conv"):
        tf_filter = tf.get_variable("filter",
                                    shape=[py_conv_scale, py_conv_scale, py_hidden_channels, py_num_classes],
                                    initializer=tf.contrib.layers.xavier_initializer_conv2d())

        # tf_data now holds the per-pixel class logits.
        tf_data = tf.nn.conv2d(tf_data,
                               tf_filter,
                               strides=[1, 1, 1, 1],
                               padding="SAME")

        tf_pred = tf.nn.softmax(tf_data)

    with tf.variable_scope("layers"):
        # Boundary li separates class li (above) from class li + 1 (below), matching the
        # labels built by get_random_training. Per column, pick the row r maximising
        #   sum_{rows < r} P(class <= li) + sum_{rows >= r} P(class > li),
        # which for a perfect prediction is exactly the first row below the boundary.
        py_layers = []
        for li in range(py_num_classes - 1):
            with tf.variable_scope("layer%02d" % li):
                tf_fg_slice = tf.slice(tf_pred, [0, 0, 0, 0], [1, -1, -1, li + 1])
                tf_fg_slice = tf.reduce_sum(tf_fg_slice, axis=3)
                tf_fg_slice = tf.cumsum(tf_fg_slice, axis=1, exclusive=True)

                tf_bg_slice = tf.slice(tf_pred, [0, 0, 0, li + 1], [1, -1, -1, -1])
                tf_bg_slice = tf.reduce_sum(tf_bg_slice, axis=3)
                tf_bg_slice = tf.cumsum(tf_bg_slice, axis=1, reverse=True)

                tf_layer_pos = tf.argmax(tf_fg_slice + tf_bg_slice, axis=1)
                py_layers.append(tf_layer_pos)

        tf_layers = tf.stack(py_layers, axis=2)  # [1, cols, py_num_classes - 1]

    with tf.variable_scope("loss"):
        # NOTE: tf.argmax has no gradient, so this term changes the reported loss
        # but does not contribute to the weight updates.
        tf_layer_loss = tf.reduce_mean(tf.square(
                                            tf.cast(tf_layers, tf.float32) - tf.cast(tf_gt_layers, tf.float32),
                                            name="layers")) * py_loss_alpha

        tf_pixel_loss = tf.losses.softmax_cross_entropy(tf_labels, tf_data)

        tf_loss = tf_pixel_loss + tf_layer_loss

        tf.summary.scalar("loss_layer", tf_layer_loss)
        tf.summary.scalar("loss_pixel", tf_pixel_loss)
        tf.summary.scalar("loss_total", tf_loss)

    with tf.variable_scope("training"):
        tf_learning_rate = tf.train.exponential_decay(learning_rate=py_learning_rate,
                                                      global_step=tf_global_step,
                                                      decay_steps=py_decay_step,
                                                      decay_rate=py_decay_rate,
                                                      staircase=True)
        tf.summary.scalar("training_learningrate", tf_learning_rate)
        tf_optimizer = tf.train.AdagradOptimizer(learning_rate=tf_learning_rate)
        # Alternatives that were tried:
        #tf_optimizer = tf.train.GradientDescentOptimizer(learning_rate=tf_learning_rate)
        #tf_optimizer = tf.train.MomentumOptimizer(learning_rate=tf_learning_rate, momentum=py_momentum)
        tf_train = tf_optimizer.minimize(tf_loss, global_step=tf_global_step)

    with tf.variable_scope("util"):
        tf_init = tf.global_variables_initializer()
        tf_summary = tf.summary.merge_all()

In [None]:

sv = tf.train.Supervisor(logdir="first_run_wider",
                         graph=tf_graph,
                         init_op=tf_init,
                         summary_op=None)

py_num_samples = 5000        # random training frames to draw
py_steps_per_sample = 10     # optimisation steps on each drawn frame

# managed_session() either restores the latest checkpoint in `logdir` or runs
# `init_op` itself. Do NOT run tf_init again here: it would wipe restored weights
# and reset global_step (and with it the learning-rate schedule).
with sv.managed_session() as sess:
    for e in range(py_num_samples):
        # Check before loading, so img/labels/layers stay in sync with py_pred.
        if sv.should_stop():
            break
        img, labels, layers = get_random_training()
        py_feed_dict = {tf_image: img[np.newaxis, :, :, np.newaxis],
                        tf_labels: labels[np.newaxis, :, :, :],
                        tf_gt_layers: layers[np.newaxis, :, :]}
        for i in range(py_steps_per_sample):
            if sv.should_stop():
                break
            (py_summary, py_loss, _, py_pred, py_layers) = sess.run(
                [tf_summary, tf_loss, tf_train, tf_pred, tf_layers],
                feed_dict=py_feed_dict)
            sv.summary_computed(sess, py_summary)
            print(py_loss)



In [None]:
# Left column: predicted probability per class with predicted boundaries.
# Right column: ground-truth one-hot mask with annotated boundaries.
fig, axes = plt.subplots(py_num_classes, 2, figsize=(20, 20), squeeze=False)
for i in range(py_num_classes):
    pred_ax, label_ax = axes[i]

    pred_ax.imshow(py_pred[0, :, :, i])
    pred_ax.plot(py_layers[0, :, :])
    pred_ax.axis('off')

    label_ax.imshow(labels[:, :, i])
    label_ax.plot(layers[:, :])
    label_ax.axis('off')
