# Supervised Adversarial Networks for Depth Map Generation
Following the 2017 paper by Pan et al.: *Supervised Adversarial Networks for Image Saliency Detection*

In [None]:
import tensorflow as tf
import numpy as np
from scipy.misc import imresize
from matplotlib import pyplot as plt

In [None]:
# Make3D data
from Make3D import train_pairs, test_pairs

# NYU data
# from NYU import nyu_data
# train_pairs, test_pairs = nyu_data()

## Preprocessing

In [None]:
# Spatial size (rows, cols) that every input image and depth map is
# resized to before training.
shape = 256, 256

# Split the (input, target) pairs of each split into two parallel tuples.
train_data, train_targets = zip(*train_pairs)
test_data, test_targets = zip(*test_pairs)


def rgb2gray(rgb):
    """Convert an RGB(A) image to grayscale.

    Applies the ITU-R BT.601 luma weights (0.299, 0.587, 0.114) to the
    first three entries of the last axis, so any alpha channel is ignored.
    A 2-D input is treated as already grayscale and returned unchanged
    (as float64), instead of being misread as rows of RGB pixels.

    Args:
        rgb (array_like): Image shaped (H, W, C) with C >= 3, or (H, W).

    Returns:
        np.ndarray: float64 grayscale image, the channel axis removed.
    """
    if np.ndim(rgb) == 2:
        return np.asarray(rgb, dtype=np.float64)
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])


def resize(img):
    """Resize ``img`` to the module-level ``shape`` (rows, cols).

    Delegates to ``scipy.misc.imresize`` with its default settings.
    NOTE(review): imresize presumably returns uint8 and rescales non-uint8
    input to the 0..255 range -- confirm, since rgb2gray's float output is
    passed through here.
    """
    target_size = shape
    return imresize(img, target_size)


# Inputs become resized grayscale images; targets are only resized.
train_data = [resize(rgb2gray(img)) for img in train_data]
train_targets = [resize(img) for img in train_targets]
test_data = [resize(rgb2gray(img)) for img in test_data]
test_targets = [resize(img) for img in test_targets]

# Stack into arrays of shape (N, rows, cols) for batching.
train_x, train_t = np.asarray(train_data), np.asarray(train_targets)
test_x, test_t = np.asarray(test_data), np.asarray(test_targets)

print('train input/target shapes', train_data[0].shape, train_targets[0].shape)
print('train input min/max/ptp', np.min(train_x),
      np.max(train_x), np.ptp(train_x))
print('train target min/max/ptp', np.min(train_t),
      np.max(train_t), np.ptp(train_t))

# Preview the first five (input, depth map) training pairs side by side.
fig, axis = plt.subplots(5, 2, figsize=(10, 20))
plt.gray()  # default colormap for the imshow calls below
for (gray, depth), (ax1, ax2) in zip(zip(train_x, train_t), axis):
    ax1.axis('off')
    ax2.axis('off')
    ax1.imshow(gray)
    # Targets were already resized to `shape`; no second imresize needed.
    ax2.imshow(depth)
# Lay out only once the axes are populated, so the layout accounts for them.
plt.tight_layout()
# plt.show()

In [None]:
from tensorflow.contrib.framework import get_variables


def lrelu(x, alpha=0.2):
    """Leaky rectifier: elementwise max of ``alpha * x`` and ``x``.

    For 0 <= alpha <= 1 this passes positive values through and scales
    negative values by ``alpha``.
    """
    leaked = alpha * x
    return tf.maximum(leaked, x)


def conv2d(inputs, num_outputs, kernel_size=(4, 4), strides=2,
           padding='SAME', activation=lrelu, norm=False, training=False):
    """Convolution, optional batch normalization, then an activation.

    Thin wrapper around ``tf.layers.conv2d`` whose defaults give a 4x4,
    stride-2, SAME-padded convolution followed by ``lrelu``.

    Args:
        inputs: Input tensor.
        num_outputs (int): Number of output filters.
        kernel_size: Convolution window size.
        strides: Convolution stride.
        padding (str): 'SAME' or 'VALID'.
        activation (callable): Applied to the final result.
        norm (bool): Insert ``tf.layers.batch_normalization`` when True.
        training (bool or tf.Tensor): Passed to batch normalization.

    Returns:
        tf.Tensor: The activated output.
    """
    outputs = tf.layers.conv2d(inputs, filters=num_outputs,
                               kernel_size=kernel_size, strides=strides,
                               padding=padding)
    if not norm:
        return activation(outputs)
    normalized = tf.layers.batch_normalization(outputs, training=training)
    return activation(normalized)


def conv2d_transpose(inputs, num_outputs, kernel_size=(4, 4), strides=2,
                     padding='SAME', activation=lrelu, norm=False,
                     training=False):
    """Transposed convolution, optional batch normalization, activation.

    Thin wrapper around ``tf.layers.conv2d_transpose``; with the defaults
    it is a 4x4, stride-2, SAME-padded layer followed by ``lrelu``.

    Args:
        inputs: Input tensor.
        num_outputs (int): Number of output filters.
        kernel_size: Kernel window size.
        strides: Stride of the transposed convolution.
        padding (str): 'SAME' or 'VALID'.
        activation (callable): Applied to the final result.
        norm (bool): Insert ``tf.layers.batch_normalization`` when True.
        training (bool or tf.Tensor): Passed to batch normalization.

    Returns:
        tf.Tensor: The activated output.
    """
    outputs = tf.layers.conv2d_transpose(inputs, filters=num_outputs,
                                         kernel_size=kernel_size,
                                         strides=strides, padding=padding)
    if not norm:
        return activation(outputs)
    normalized = tf.layers.batch_normalization(outputs, training=training)
    return activation(normalized)

In [None]:
def make_discriminator(images, is_training):
    """Discriminator scoring (input, depth map) pairs.

    Args:
        images (tf.Tensor[tf.float32]): Input image concatenated with a
            real or generated depth map along the channel axis.
            Shape (BATCH, 256, 256, (input_c + output_c)).
        is_training (tf.Tensor[tf.bool]): True when in training, False
            when in testing step; controls batch normalization.

    Returns:
        (tf.Tensor): Sigmoid of the logits, values in [0, 1].
        (tf.Tensor): Linear logits. Every conv2d below uses stride 2, so a
            256x256 input yields a grid of scores, shape (BATCH, 8, 8, 1),
            rather than a single scalar; the dense layer acts on the last
            (channel) axis.
        (List[tf.Operation]): Batch normalization update operations.
    """
    with tf.variable_scope('discriminator') as scope:
        features = conv2d(images, 64)
        for filters in (128, 256, 512):
            features = conv2d(features, filters, norm=True,
                              training=is_training)
        # NOTE(review): this 1x1 conv keeps conv2d's default stride of 2,
        # so it also halves the spatial size (16x16 -> 8x8); confirm intent.
        features = conv2d(features, 1, (1, 1))
        logits = tf.layers.dense(features, 1, tf.identity)
        probabilities = tf.nn.sigmoid(logits)
        update_ops = get_variables(scope, collection=tf.GraphKeys.UPDATE_OPS)
        return probabilities, logits, update_ops

In [None]:
def make_generator(images, is_training):
    """Generator: encoder-decoder mapping input images to a depth map.

    Args:
        images (tf.Tensor[tf.float32]): Input images, shape
            (BATCH, 256, 256, C). The eight stride-2 encoder layers reduce
            256x256 to 1x1 and the eight decoder layers restore 256x256.
        is_training (tf.Tensor[tf.bool]): True when in training, False
            when in testing step. Controls batch normalization and dropout.

    Returns:
        (tf.Tensor): Tanh network output in [-1, 1], single channel, same
            spatial size as the input.
        (List[tf.Operation]): Batch normalization update operations.

    TODO: Implement skipping between encoder and decoder.
    """

    with tf.variable_scope('generator') as scope:
        with tf.variable_scope('encoder'):
            # Eight stride-2 convolutions; only the first is unnormalized.
            enc = conv2d(images, 64)
            for filters in (128, 256, 512, 512, 512, 512, 512):
                enc = conv2d(enc, filters, norm=True, training=is_training)
        with tf.variable_scope('decoder'):
            # Eight stride-2 transposed convolutions; the first three are
            # followed by dropout.
            decoder_filters = (512, 512, 512, 512, 512, 256, 128, 64)
            dec = enc
            for i, filters in enumerate(decoder_filters):
                dec = conv2d_transpose(dec, filters, norm=True,
                                       training=is_training)
                if i < 3:
                    # tf.layers.dropout is a no-op unless `training` is
                    # true, so the flag must be passed explicitly.
                    dec = tf.layers.dropout(dec, .5, training=is_training)
            out = conv2d(dec, 1, (1, 1), 1, activation=tf.nn.tanh)
        ops = get_variables(scope, collection=tf.GraphKeys.UPDATE_OPS)
        return out, ops

In [None]:
# Start from an empty default graph so re-running this cell does not
# add duplicate ops to the previous graph.
tf.reset_default_graph()

# Boolean switch with unknown static shape; False unless fed (True is fed
# for training steps).
is_training = tf.placeholder_with_default(False, shape=None)

# Raw inputs (uint8) and depth-map targets (float32), both (batch, 256, 256).
images = tf.placeholder(tf.uint8, shape=(None, 256, 256),
                        name='inputs')
depthmaps = tf.placeholder(tf.float32, shape=(None, 256, 256),
                           name='targets')


def scale(img):
    """Cast ``img`` to float32 and map 0..255 linearly onto [-1, 1]."""
    as_float = tf.cast(img, tf.float32)
    unit = as_float / 255
    return (unit - .5) * 2


# Scale to [-1, 1] and append a channel axis: (batch, 256, 256, 1).
nhwc_images = scale(images)
images_ = tf.reshape(nhwc_images, [-1, 256, 256, 1])
nhwc_depthmaps = scale(depthmaps)
depthmaps_ = tf.reshape(nhwc_depthmaps, [-1, 256, 256, 1])

# Build the generator. TODO: a weight-sharing sampler for inference could be
# created with tf.make_template('generator', make_generator).
generator, g_ops = make_generator(images_, is_training)


def soft_labels_like(like, real=False):
    """Random float32 "soft" GAN labels shaped like ``like``.

    Real labels are drawn uniformly from [0.7, 1.2); fake labels from
    [0.0, 0.3).
    """
    low, high = (0.7, 1.2) if real else (0., 0.3)
    return tf.random_uniform(tf.shape(like), low, high, tf.float32)


# Create the two discriminator graphs.
real = tf.concat([images_, depthmaps_], axis=-1)
fake = tf.concat([images_, generator], axis=-1)
remake_discriminator = tf.make_template('discriminator', make_discriminator)
d_real, d_logits_real, d_ops = remake_discriminator(real, is_training)
d_fake, d_logits_fake, _ = remake_discriminator(fake, is_training)

# Create discriminator losses.
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=soft_labels_like(d_fake, real=True)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=soft_labels_like(d_fake, real=False)))
d_loss = d_loss_real + d_loss_fake

LAMBDA = 0

# Create generator loss.
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=soft_labels_like(d_fake, real=True))) \
    + LAMBDA * tf.reduce_mean(tf.abs(depthmaps - generator))

with tf.control_dependencies(g_ops):
    g_op = tf.train.AdamOptimizer().minimize(g_loss)

with tf.control_dependencies(d_ops):
    d_op = tf.train.AdamOptimizer().minimize(d_loss)

## Training

In [None]:
def batches(x, y, batchsize=32):
    """Yield shuffled mini-batches covering every sample exactly once.

    Args:
        x (np.ndarray): Inputs, indexed along axis 0.
        y (np.ndarray): Targets, same length as ``x`` along axis 0.
        batchsize (int): Maximum samples per batch; the last may be smaller.

    Yields:
        (np.ndarray, np.ndarray): Matching slices of ``x`` and ``y``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """
    if len(x) != len(y):
        raise ValueError('x and y must have the same length')
    permute = np.random.permutation(len(x))
    # `start` already advances in steps of `batchsize`; slice directly.
    for start in range(0, len(x), batchsize):
        indices = permute[start:start + batchsize]
        yield x[indices], y[indices]

In [None]:
EPOCHS = 10

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

batch = -1
for epoch in range(1, EPOCHS + 1):
    for batch_x, batch_t in batches(train_x, train_t, 32):
        batch += 1
        # Targets are fed on every step: d_loss needs them for the real
        # branch, and g_loss reads them through its L1 term.
        feed = {images: batch_x, depthmaps: batch_t, is_training: True}
        # Alternate updates: even steps train D, odd steps train G.
        step_op = d_op if batch % 2 == 0 else g_op
        sess.run(step_op, feed)
        print(batch)

    # Evaluate on the whole test set; is_training falls back to False.
    dloss, gloss = sess.run([d_loss, g_loss],
                            {images: test_x, depthmaps: test_t})

    print('{}: {:.2f}, {:.2f}'.format(epoch, dloss, gloss))