In [1]:
import tensorflow as tf
# import tensorflow.contrib.eager as tfe
# tfe.enable_eager_execution()

In [12]:
def create_conv2d(inputs, filters, name, strides=1, reuse=None, activation=tf.nn.leaky_relu):
    """Build a 3x3 convolution with 'same' padding via ``tf.layers.conv2d``.

    Args:
        inputs: input tensor, forwarded unchanged to ``tf.layers.conv2d``.
        filters: number of output channels.
        name: layer (and variable-scope) name. Building a second layer with
            the same name and ``reuse=True`` shares the first layer's weights.
        strides: convolution stride (a single int applies to every spatial dim).
        reuse: forwarded to ``tf.layers.conv2d``.
        activation: activation function; leaky ReLU by default.

    Returns:
        The output tensor of the convolution layer.
    """
    layer_config = dict(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=strides,
        padding='same',
        activation=activation,
        reuse=reuse,
        name=name,
    )
    return tf.layers.conv2d(**layer_config)

In [13]:
def create_dense(inputs, units, name, activation=tf.nn.leaky_relu):
    """Build a fully connected layer via ``tf.layers.dense``.

    ``tf.layers.dense`` contracts only the LAST axis of ``inputs``: an input
    of rank > 2 is not flattened, so a (batch, H, W, C) map produces a
    (batch, H, W, units) output. Flatten the input first when a single
    vector per example is wanted.

    Args:
        inputs: input tensor.
        units: number of output units.
        name: layer (and variable-scope) name.
        activation: activation function. Defaults to leaky ReLU, matching
            ``create_conv2d``; pass ``None`` for a linear layer.

    Returns:
        The output tensor of the dense layer.
    """
    return tf.layers.dense(inputs=inputs,
                           units=units,
                           activation=activation,
                           name=name)

In [14]:
def create_upsample(layer, factor=2):
    """Spatially upsample a rank-4 tensor by an integer factor.

    Uses ``tf.image.resize_images`` with its default method.

    Args:
        layer: rank-4 tensor laid out as (batch, height, width, channels).
            Its static height and width must be known.
        factor: integer scale applied to both height and width (default 2).

    Returns:
        The resized tensor of shape (batch, height * factor, width * factor, channels).

    Raises:
        ValueError: if ``layer`` is not rank 4 or its static height/width is unknown.
    """
    # Unpacking enforces rank 4; batch and channel sizes are not needed.
    _, height, width, _ = layer.get_shape().as_list()
    if height is None or width is None:
        raise ValueError(
            'create_upsample needs a static height and width, got shape %s'
            % (layer.get_shape(),))
    return tf.image.resize_images(layer, [height * factor, width * factor])

In [15]:
# Build the colorization graph on a fresh default graph: a shared low-level
# feature extractor feeds (a) a global-feature path on a fixed-size copy of
# the input and (b) a mid-level path on the full-resolution input. The two
# are fused and decoded into a 2-channel output at input resolution.
tf.reset_default_graph()

h, w = 128, 128      # full-resolution grayscale input size
global_size = 112    # fixed size seen by the global-feature path
img = tf.placeholder(dtype=tf.float32, shape=[None, h, w, 1])
# Fixed-size copy, so the flattened global feature size below is static.
img_scaled = tf.image.resize_images(img, [global_size, global_size])

# -----------------------------LOW-LEVEL FEATURES: FIXED-----------------------------
fconv2d_11 = create_conv2d(inputs=img_scaled, filters=64, strides=2, reuse=None, name='l_conv2d_11')
fconv2d_12 = create_conv2d(inputs=fconv2d_11, filters=128, name='l_conv2d_12')

fconv2d_21 = create_conv2d(inputs=fconv2d_12, filters=128, strides=2, reuse=None, name='l_conv2d_21')
fconv2d_22 = create_conv2d(inputs=fconv2d_21, filters=256, name='l_conv2d_22')

fconv2d_31 = create_conv2d(inputs=fconv2d_22, filters=256, strides=2, reuse=None, name='l_conv2d_31')
fconv2d_32 = create_conv2d(inputs=fconv2d_31, filters=512, name='l_conv2d_32')

# -----------------------------GLOBAL FEATURE NETWORK-----------------------------
gconv2d_1 = create_conv2d(inputs=fconv2d_32, filters=512, strides=2, reuse=None, name='g_conv2d_1')
gconv2d_2 = create_conv2d(inputs=gconv2d_1, filters=512, name='g_conv2d_2')
gconv2d_3 = create_conv2d(inputs=gconv2d_2, filters=512, strides=2, reuse=None, name='g_conv2d_3')
gconv2d_4 = create_conv2d(inputs=gconv2d_3, filters=512, name='g_conv2d_4')

# Flatten the conv map so the FC stack yields ONE global vector per image.
# tf.layers.dense only contracts the last axis, so feeding the 4-D map
# directly would produce an independent "global" vector per spatial cell.
_, g_h, g_w, g_c = gconv2d_4.get_shape().as_list()
g_flat = tf.reshape(gconv2d_4, [-1, g_h * g_w * g_c], name='g_flatten')

gfc_1 = create_dense(inputs=g_flat, units=1024, name='g_fc_1')
gfc_2 = create_dense(inputs=gfc_1, units=512, name='g_fc_2')
gfc_3 = create_dense(inputs=gfc_2, units=256, name='g_fc_3')

# -----------------------------LOW-LEVEL FEATURES: VARIABLE-----------------------------
# Same layer names with reuse=True: weights are shared with the fixed-size path.
vconv2d_11 = create_conv2d(inputs=img, filters=64, strides=2, reuse=True, name='l_conv2d_11')
vconv2d_12 = create_conv2d(inputs=vconv2d_11, filters=128, reuse=True, name='l_conv2d_12')

vconv2d_21 = create_conv2d(inputs=vconv2d_12, filters=128, strides=2, reuse=True, name='l_conv2d_21')
vconv2d_22 = create_conv2d(inputs=vconv2d_21, filters=256, reuse=True, name='l_conv2d_22')

vconv2d_31 = create_conv2d(inputs=vconv2d_22, filters=256, strides=2, reuse=True, name='l_conv2d_31')
vconv2d_32 = create_conv2d(inputs=vconv2d_31, filters=512, reuse=True, name='l_conv2d_32')

# -----------------------------MID-LEVEL FEATURE NETWORK-----------------------------
mconv2d_1 = create_conv2d(inputs=vconv2d_32, filters=512, name='m_conv2d_1')
mconv2d_2 = create_conv2d(inputs=mconv2d_1, filters=256, name='m_conv2d_2')

# -----------------------------COLORIZATION NETWORK-----------------------------
# Fusion: replicate the per-image global vector at every mid-level position,
# sized from the mid-level map itself rather than a hard-coded multiple.
mid_h, mid_w = mconv2d_2.get_shape().as_list()[1:3]
g_units = gfc_3.get_shape().as_list()[-1]
g_tiled = tf.tile(tf.reshape(gfc_3, [-1, 1, 1, g_units]), multiples=[1, mid_h, mid_w, 1])
c_fusion = tf.concat([mconv2d_2, g_tiled], axis=3, name='c_fusion')

cconv2d_1 = create_conv2d(inputs=c_fusion, filters=128, name='c_conv2d_1')
c_upsample_1 = create_upsample(layer=cconv2d_1)
cconv2d_2 = create_conv2d(inputs=c_upsample_1, filters=64, name='c_conv2d_2')
cconv2d_3 = create_conv2d(inputs=cconv2d_2, filters=64, name='c_conv2d_3')
c_upsample_2 = create_upsample(layer=cconv2d_3)
cconv2d_4 = create_conv2d(inputs=c_upsample_2, filters=32, name='c_conv2d_4')
c_output = create_conv2d(inputs=cconv2d_4, filters=2, activation=tf.sigmoid, name='c_output')
c_upsample_o = create_upsample(layer=c_output)

# The decoder must come back to the input resolution with 2 output channels.
assert c_upsample_o.get_shape().as_list()[1:] == [h, w, 2], c_upsample_o.get_shape()

In [16]:
# Initialise the variables and export the graph definition for TensorBoard
# (inspect with `tensorboard --logdir logs`).
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    file_writer = tf.summary.FileWriter('logs', sess.graph)
    # FileWriter writes events asynchronously; flush so the graph is on disk
    # as soon as this cell finishes. Call file_writer.close() once it is no
    # longer needed to release the event file.
    file_writer.flush()