### Import statements

In [1]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.


### Utility functions

In [2]:
def get_conv2d(inputs, filters, name, kernel=3, strides=1, reuse=None, activation=tf.nn.leaky_relu):
    """Apply a 'same'-padded 2-D convolution to `inputs` via `tf.layers.conv2d`.

    Args:
        inputs: Input tensor handed to the layer.
        filters: Forwarded as the layer's `filters` argument.
        name: Layer name; together with `reuse` it identifies the layer's variables.
        kernel: Forwarded as `kernel_size` (default 3).
        strides: Forwarded as `strides` (default 1).
        reuse: Forwarded as `reuse` (default None).
        activation: Activation function applied by the layer
            (default `tf.nn.leaky_relu`).

    Returns:
        Whatever `tf.layers.conv2d` returns for these settings.
    """
    layer_options = dict(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel,
        strides=strides,
        padding='same',
        reuse=reuse,
        activation=activation,
        name=name,
    )
    return tf.layers.conv2d(**layer_options)

In [3]:
def get_dense(inputs, units, name):
    """Apply a fully connected layer with leaky-ReLU activation.

    Args:
        inputs: Input tensor handed to `tf.layers.dense`.
        units: Forwarded as the layer's `units` argument.
        name: Layer name.

    Returns:
        Whatever `tf.layers.dense` returns for these settings.
    """
    leaky_relu = tf.nn.leaky_relu
    return tf.layers.dense(inputs, units, activation=leaky_relu, name=name)

In [4]:
def get_upsample(layer):
    """Double the height and width of a rank-4 tensor with `tf.image.resize_images`.

    Args:
        layer: Rank-4 tensor whose shape is unpacked as
            (batch, height, width, channels).

    Returns:
        The result of `tf.image.resize_images` for a target size of
        (2 * height, 2 * width).

    Raises:
        ValueError: If the static shape does not unpack into exactly four
            dimensions (raised by the unpacking / `as_list()` itself).
    """
    _, height, width, _ = layer.get_shape().as_list()
    if height is None or width is None:
        # The static spatial size is unknown, so `None * 2` would raise a
        # TypeError. Compute the target size from the runtime shape instead;
        # `resize_images` accepts a 1-D int32 tensor for `size`.
        target_size = tf.shape(layer)[1:3] * 2
    else:
        target_size = [height * 2, width * 2]
    return tf.image.resize_images(layer, target_size)

### Generator model

In [5]:
tf.reset_default_graph()

h, w = 256, 256
# img = tf.placeholder(dtype=tf.float32, shape=[None, h, w, 1])
batch_size = 2
# Random single-channel batch used as a stand-in input while wiring up the model.
img = tfe.Variable(tf.random_uniform(shape=[batch_size, h, w, 1]))
# Fixed-size 224x224 copy of the input, consumed by the global feature branch.
img_scaled = tf.image.resize_images(img, [224, 224])


def _low_level_features(inputs, reuse=None):
    """Run the six-layer low-level feature stack on `inputs`.

    The same layer names are used on every call so that the fixed-size and the
    variable-size passes refer to the same layers; pass `reuse=True` on the
    second and later calls.

    Args:
        inputs: Image batch tensor fed to the first convolution.
        reuse: Forwarded to every `get_conv2d` call.

    Returns:
        Tuple of the six intermediate tensors, in layer order
        (11, 12, 21, 22, 31, 32).
    """
    conv_11 = get_conv2d(inputs=inputs, filters=64, strides=2, reuse=reuse, name='l_conv2d_11')
    conv_12 = get_conv2d(inputs=conv_11, filters=128, reuse=reuse, name='l_conv2d_12')

    conv_21 = get_conv2d(inputs=conv_12, filters=128, strides=2, reuse=reuse, name='l_conv2d_21')
    conv_22 = get_conv2d(inputs=conv_21, filters=256, reuse=reuse, name='l_conv2d_22')

    conv_31 = get_conv2d(inputs=conv_22, filters=256, strides=2, reuse=reuse, name='l_conv2d_31')
    conv_32 = get_conv2d(inputs=conv_31, filters=512, reuse=reuse, name='l_conv2d_32')
    return conv_11, conv_12, conv_21, conv_22, conv_31, conv_32


# -----------------------------LOW-LEVEL FEATURES: FIXED-----------------------------
# First pass: creates the low-level layers and applies them to the 224x224 input.
(fconv2d_11, fconv2d_12,
 fconv2d_21, fconv2d_22,
 fconv2d_31, fconv2d_32) = _low_level_features(img_scaled, reuse=None)

# -----------------------------GLOBAL FEATURE NETWORK-----------------------------
gconv2d_1 = get_conv2d(inputs=fconv2d_32, filters=512, strides=2, reuse=None, name='g_conv2d_1')
gconv2d_2 = get_conv2d(inputs=gconv2d_1, filters=512, name='g_conv2d_2')
gconv2d_3 = get_conv2d(inputs=gconv2d_2, filters=512, strides=2, reuse=None, name='g_conv2d_3')
gconv2d_4 = get_conv2d(inputs=gconv2d_3, filters=512, name='g_conv2d_4')
g_flatten = tf.contrib.layers.flatten(gconv2d_4)
gfc_1 = get_dense(inputs=g_flatten, units=1024, name='g_fc_1')
gfc_2 = get_dense(inputs=gfc_1, units=512, name='g_fc_2')
gfc_3 = get_dense(inputs=gfc_2, units=256, name='g_fc_3')
# -1 lets the batch dimension follow the input instead of hard-coding
# `batch_size`, so the model also works with an unknown batch size.
gfc_reshape = tf.reshape(gfc_3, [-1, 1, 1, 256])

# -----------------------------LOW-LEVEL FEATURES: VARIABLE-----------------------------
# Second pass over the full-resolution input, requesting the layers created above.
# NOTE(review): sharing here relies on name-based `reuse=True`. Under eager
# execution variables are not kept in global collections, so this may need a
# `tfe.EagerVariableStore` to actually share weights -- confirm.
(vconv2d_11, vconv2d_12,
 vconv2d_21, vconv2d_22,
 vconv2d_31, vconv2d_32) = _low_level_features(img, reuse=True)

# -----------------------------MID-LEVEL FEATURE NETWORK-----------------------------
mconv2d_1 = get_conv2d(inputs=vconv2d_32, filters=512, name='m_conv2d_1')
mconv2d_2 = get_conv2d(inputs=mconv2d_1, filters=256, name='m_conv2d_2')

# -----------------------------FUSION NETWORK-----------------------------
# Tile the global feature vector over the mid-level map's actual spatial size.
# Reading it from the tensor (rather than computing int(h/8), int(w/8)) keeps
# the concat valid when h or w is not a multiple of 8.
fu_h, fu_w = mconv2d_2.get_shape().as_list()[1:3]
fu_tiled = tf.tile(gfc_reshape, multiples=[1, fu_h, fu_w, 1])
fu_concat = tf.concat([mconv2d_2, fu_tiled], axis=3, name='f_concat')
fu_fusion = tf.layers.conv2d(fu_concat, filters=256, kernel_size=1, strides=1, activation=tf.sigmoid, name='fu_fusion')

# -----------------------------COLORIZATION NETWORK-----------------------------
# Three 2x upsampling steps interleaved with convolutions; the last
# convolution emits 2 channels through a sigmoid.

cconv2d_1 = get_conv2d(inputs=fu_fusion, filters=128, name='c_conv2d_1')
c_upsample_1 = get_upsample(layer=cconv2d_1)
cconv2d_2 = get_conv2d(inputs=c_upsample_1, filters=64, name='c_conv2d_2')
cconv2d_3 = get_conv2d(inputs=cconv2d_2, filters=64, name='c_conv2d_3')
c_upsample_2 = get_upsample(layer=cconv2d_3)
cconv2d_4 = get_conv2d(inputs=c_upsample_2, filters=32, name='c_conv2d_4')
c_output = get_conv2d(inputs=cconv2d_4, filters=2, activation=tf.sigmoid, name='c_output')
c_upsample_o = get_upsample(layer=c_output)

In [6]:
# with tf.Session() as sess:
#     tf.global_variables_initializer().run()
#     file_writer = tf.summary.FileWriter('logs', sess.graph)