In [None]:
def define_model(images, edgemaps, data_dict):

    """
    Build the five VGG conv blocks, attach a side branch (1x1 conv + deconv)
    after each block, and fuse the five side maps with a 1x1 convolution.

    images    : input tensor fed to the first conv layer
    edgemaps  : accepted for interface compatibility; not used while building
    data_dict : accepted for interface compatibility; not consulted here.
                The VGG layers obtain their weights through get_conv_filter /
                get_bias, which look up the module-level name `data_dict`.

    Returns the list [side_1, side_2, side_3, side_4, side_5, fuse]. No
    sigmoid is applied to these tensors inside this function.
    """

    start_time = time.time()

    conv1_1 = conv_layer_vgg(images, "conv1_1")
    conv1_2 = conv_layer_vgg(conv1_1, "conv1_2")
    side_1 = side_layer(conv1_2, "side_1", 1)
    pool1 = max_pool(conv1_2, 'pool1')

    print('Added CONV-BLOCK-1+SIDE-1')

    conv2_1 = conv_layer_vgg(pool1, "conv2_1")
    conv2_2 = conv_layer_vgg(conv2_1, "conv2_2")
    side_2 = side_layer(conv2_2, "side_2", 2)
    pool2 = max_pool(conv2_2, 'pool2')

    print('Added CONV-BLOCK-2+SIDE-2')

    conv3_1 = conv_layer_vgg(pool2, "conv3_1")
    conv3_2 = conv_layer_vgg(conv3_1, "conv3_2")
    conv3_3 = conv_layer_vgg(conv3_2, "conv3_3")
    side_3 = side_layer(conv3_3, "side_3", 4)
    pool3 = max_pool(conv3_3, 'pool3')

    print('Added CONV-BLOCK-3+SIDE-3')

    conv4_1 = conv_layer_vgg(pool3, "conv4_1")
    conv4_2 = conv_layer_vgg(conv4_1, "conv4_2")
    conv4_3 = conv_layer_vgg(conv4_2, "conv4_3")
    side_4 = side_layer(conv4_3, "side_4", 8)
    pool4 = max_pool(conv4_3, 'pool4')

    print('Added CONV-BLOCK-4+SIDE-4')

    # The last block has no pooling layer after it: nothing consumes a pool5.
    conv5_1 = conv_layer_vgg(pool4, "conv5_1")
    conv5_2 = conv_layer_vgg(conv5_1, "conv5_2")
    conv5_3 = conv_layer_vgg(conv5_2, "conv5_3")
    side_5 = side_layer(conv5_3, "side_5", 16)

    print('Added CONV-BLOCK-5+SIDE-5')

    side_outputs = [side_1, side_2, side_3, side_4, side_5]

    # 1x1 convolution over the channel-wise concatenation of the side maps:
    # one input channel per side output, a single output channel, no bias,
    # every weight starting at 0.2.
    w_shape = [1, 1, len(side_outputs), 1]
    fuse = conv_layer(tf.concat(side_outputs, axis=3),
                                w_shape, name='fuse_1', use_bias=False,
                                w_init=tf.constant_initializer(0.2))

    print('Added FUSE layer')

    # complete output maps from side layer and fuse layers
    outputs = side_outputs + [fuse]

    print("Build model finished: {:.4f}s".format(time.time() - start_time))

    # Hand the built tensors back to the caller; without this the graph
    # outputs would be unreachable from outside the function.
    return outputs

def max_pool(bottom, name):
    """
        Max-pool `bottom` with a window of 2 and a stride of 2 along the two
        middle dimensions, using SAME padding. The op is named `name`.
    """
    pool_window = [1, 2, 2, 1]
    pool_stride = [1, 2, 2, 1]
    return tf.nn.max_pool(bottom, ksize=pool_window, strides=pool_stride,
                          padding='SAME', name=name)

def conv_layer_vgg(bottom, name):
    """
        Build one conv stage inside variable scope `name`:
        stride-1 SAME convolution -> bias add -> ReLU.
        The filter and bias values are fetched by layer name through
        get_conv_filter and get_bias.
    """
    with tf.variable_scope(name):
        kernel = get_conv_filter(name)
        convolved = tf.nn.conv2d(bottom, kernel, [1, 1, 1, 1], padding='SAME')

        offsets = get_bias(name)
        return tf.nn.relu(tf.nn.bias_add(convolved, offsets))

def conv_layer(x, W_shape, b_shape=None, name=None,
               padding='SAME', use_bias=True, w_init=None, b_init=None):
    """
        Stride-1 2-D convolution with newly created weights and an optional bias.

        x        : input tensor passed to tf.nn.conv2d
        W_shape  : filter shape handed to weight_variable
        b_shape  : length of the bias vector. When omitted, it defaults to the
                   last entry of W_shape (the output-channel count of a
                   tf.nn.conv2d filter), instead of requesting a bias of
                   shape [None].
        name     : used only as a suffix for the histogram summary names
        padding  : padding mode forwarded to tf.nn.conv2d
        use_bias : when False, no bias variable is created or added
        w_init   : called with the weight shape by weight_variable, so it
                   must be a callable
        b_init   : called with the bias shape by bias_variable when use_bias
                   is True, so it must be a callable in that case

        Returns the convolution output, plus the bias when use_bias is True.
    """

    W = weight_variable(W_shape, w_init)
    tf.summary.histogram('weights_{}'.format(name), W)

    b = None
    if use_bias:
        if b_shape is None:
            # One bias value per output channel of the filter.
            b_shape = W_shape[-1]
        b = bias_variable([b_shape], b_init)
        tf.summary.histogram('biases_{}'.format(name), b)

    conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding)

    return conv if b is None else conv + b

def deconv_layer(x, upscale, name, padding='SAME', w_init=None):
    """
        Transposed convolution that enlarges dimensions 1 and 2 of `x` by
        `upscale`.

        The filter has shape [2 * upscale, 2 * upscale, C, 1], where C is the
        static size of the last dimension of `x`, and the stride is `upscale`
        along dimensions 1 and 2. The requested output shape is the dynamic
        shape of the first three dimensions of `x` plus C, multiplied
        element-wise by the strides.
        `name` is used only for the weight histogram summary. `w_init` is
        called with the filter shape by weight_variable.
    """

    dynamic_shape = tf.shape(x)
    channels = x.shape.as_list()[-1]

    kernel_size = upscale * 2
    filter_shape = [kernel_size, kernel_size, channels, 1]
    step = [1, upscale, upscale, 1]

    kernel = weight_variable(filter_shape, w_init)
    tf.summary.histogram('weights_{}'.format(name), kernel)

    base_dims = tf.stack([dynamic_shape[0], dynamic_shape[1], dynamic_shape[2], filter_shape[2]])
    target_shape = base_dims * tf.constant(step, tf.int32)

    return tf.nn.conv2d_transpose(x, kernel, target_shape, strides=step, padding=padding)

def side_layer(inputs, name, upscale):
    """
        https://github.com/s9xie/hed/blob/9e74dd710773d8d8a469ad905c76f4a7fa08f945/examples/hed/train_val.prototxt#L122
        1x1 conv followed by a deconvolution layer that upscales the result.

        inputs  : feature tensor of a conv block; its last static dimension
                  sets the input-channel count of the 1x1 filter
        name    : variable scope for the branch, also used to name summaries
        upscale : enlargement factor forwarded to deconv_layer

        Returns the upscaled single-channel map produced by deconv_layer.

        This is a module-level function, so it takes no `self`: callers pass
        (inputs, name, upscale) positionally.
    """
    with tf.variable_scope(name):

        in_shape = inputs.shape.as_list()
        # 1x1 filter reducing all input channels to a single output channel.
        w_shape = [1, 1, in_shape[-1], 1]

        classifier = conv_layer(inputs, w_shape, b_shape=1,
                                w_init=tf.constant_initializer(),
                                b_init=tf.constant_initializer(),
                                name=name + '_reduction')

        classifier = deconv_layer(classifier, upscale=upscale,
                                  name='{}_deconv_{}'.format(name, upscale),
                                  w_init=tf.truncated_normal_initializer(stddev=0.1))

        return classifier

def get_conv_filter(name):
    """
        Wrap entry 0 of `data_dict[name]` in a tf.constant named "filter".
        `data_dict` is resolved as a module-level name at call time.
    """
    layer_params = data_dict[name]
    return tf.constant(layer_params[0], name="filter")

def get_bias(name):
    """
        Wrap entry 1 of `data_dict[name]` in a tf.constant named "biases".
        `data_dict` is resolved as a module-level name at call time.
    """
    layer_params = data_dict[name]
    return tf.constant(layer_params[1], name="biases")

def weight_variable(shape, initial):
    """
        Create a tf.Variable whose starting value is `initial(shape)`.
        `initial` must be callable with the shape as its only argument.
    """
    return tf.Variable(initial(shape))

def bias_variable(shape, initial):
    """
        Create a tf.Variable for a bias, starting from `initial(shape)`.
        `initial` must be callable with the shape as its only argument.
    """
    start_value = initial(shape)
    variable = tf.Variable(start_value)
    return variable

def setup_testing(session):

    """
        Build the prediction ops: a sigmoid over every entry of `self.outputs`
        (op names 'output_0', 'output_1', ...), stored in `self.predictions`.
    """

    # NOTE(review): `self` is not a parameter of this function (setup_training
    # takes it explicitly), so it is looked up as a global name at call time,
    # and `session` is never used. Confirm the intended signature.
    self.predictions = []

    for position, raw_map in enumerate(self.outputs):
        op_name = 'output_{}'.format(position)
        self.predictions.append(tf.nn.sigmoid(raw_map, name=op_name))

def setup_training(self, session):

    """
        Build the training ops on `self`:
          - predictions : sigmoid of every side output followed by the fuse output
          - loss        : weighted fuse cost, plus the weighted side costs when
                          cfgs['deep_supervision'] is truthy
          - error       : mean pixel mismatch between the fuse prediction
                          thresholded at 0.5 and the edge maps
        Also registers scalar summaries for loss and error, merges all
        summaries, and opens train/val summary writers under cfgs['save_dir'].
        Only the train writer is given `session.graph`.
    """

    self.predictions = []
    self.loss = 0

    deep_supervision = self.cfgs['deep_supervision']
    self.io.print_warning('Deep supervision application set to {}'.format(deep_supervision))

    # Side branches: always predicted, but only added to the loss under deep supervision.
    for idx, side_map in enumerate(self.side_outputs):
        side_prob = tf.nn.sigmoid(side_map, name='output_{}'.format(idx))
        side_cost = sigmoid_cross_entropy_balanced(side_map, self.edgemaps,
                                                   name='cross_entropy{}'.format(idx))

        self.predictions.append(side_prob)
        if deep_supervision:
            self.loss = self.loss + (self.cfgs['loss_weights'] * side_cost)

    # Fuse branch: always contributes to the loss.
    fuse_output = tf.nn.sigmoid(self.fuse, name='fuse')
    fuse_cost = sigmoid_cross_entropy_balanced(self.fuse, self.edgemaps, name='cross_entropy_fuse')

    self.predictions.append(fuse_output)
    self.loss = self.loss + (self.cfgs['loss_weights'] * fuse_cost)

    # Pixel error of the thresholded fuse prediction, tracked as a rough metric.
    binary_pred = tf.cast(tf.greater(fuse_output, 0.5), tf.int32, name='predictions')
    binary_target = tf.cast(self.edgemaps, tf.int32)
    mismatch = tf.cast(tf.not_equal(binary_pred, binary_target), tf.float32)
    self.error = tf.reduce_mean(mismatch, name='pixel_error')

    tf.summary.scalar('loss', self.loss)
    tf.summary.scalar('error', self.error)

    self.merged_summary = tf.summary.merge_all()

    save_dir = self.cfgs['save_dir']
    self.train_writer = tf.summary.FileWriter(save_dir + '/train', session.graph)
    self.val_writer = tf.summary.FileWriter(save_dir + '/val')