In [1]:
import tensorflow as tf
import numpy as np
import os, sys
import time

In [2]:
def weight_variable(shape, name=None):
    # initialize weighted variables.
    initial = tf.truncated_normal(shape, stddev=0.001)
    return tf.Variable(initial, name=name)

def conv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
    # set convolution layers.
    assert isinstance(x, tf.Tensor)
    return tf.nn.conv2d(x, W, strides=strides, padding=p, name=name)

def batch_norm(x):
    assert isinstance(x, tf.Tensor)
    # reduce dimension 1, 2, 3, which would produce batch mean and batch variance.
    mean, var = tf.nn.moments(x, axes=[1, 2, 3])
    return tf.nn.batch_normalization(x, mean, var, 0, 1, 1e-5)

def relu(x):
    assert isinstance(x, tf.Tensor)
    return tf.nn.relu(x)

def deconv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
    assert isinstance(x, tf.Tensor)
    _, _, c, _ = W.get_shape().as_list()
    b, h, w, _ = x.get_shape().as_list()
    return tf.nn.conv2d_transpose(x, W, [b, strides[1] * h, strides[1] * w, c], strides=strides, padding=p, name=name)

def max_pool_2x2(x):
    assert isinstance(x, tf.Tensor)
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


class ResidualBlock():
    def __init__(self, idx, ksize=3, train=False, data_dict=None):
        self.W1 = weight_variable([ksize, ksize, 128, 128], name='R'+ str(idx) + '_conv1_w')
        self.W2 = weight_variable([ksize, ksize, 128, 128], name='R'+ str(idx) + '_conv2_w')
    def __call__(self, x, idx, strides=[1, 1, 1, 1]):
        h = relu(batch_norm(conv2d(x, self.W1, strides, name='R' + str(idx) + '_conv1')))
        h = batch_norm(conv2d(h, self.W2, name='R' + str(idx) + '_conv2'))
        return x + h


class FastStyleNet():
    def __init__(self, train=True, data_dict=None):
        self.c1 = weight_variable([9, 9, 3, 32], name='t_conv1_w')
        self.c2 = weight_variable([4, 4, 32, 64], name='t_conv2_w')
        self.c3 = weight_variable([4, 4, 64, 128], name='t_conv3_w')
        self.r1 = ResidualBlock(1, train=train)
        self.r2 = ResidualBlock(2, train=train)
        self.r3 = ResidualBlock(3, train=train)
        self.r4 = ResidualBlock(4, train=train)
        self.r5 = ResidualBlock(5, train=train)
        self.d1 = weight_variable([4, 4, 64, 128], name='t_dconv1_w')
        self.d2 = weight_variable([4, 4, 32, 64], name='t_dconv2_w')
        self.d3 = weight_variable([9, 9, 3, 32], name='t_dconv3_w')            
    def __call__(self, h):
        h = batch_norm(relu(conv2d(h, self.c1, name='t_conv1')))
        h = batch_norm(relu(conv2d(h, self.c2, strides=[1, 2, 2, 1], name='t_conv2')))
        h = batch_norm(relu(conv2d(h, self.c3, strides=[1, 2, 2, 1], name='t_conv3')))

        h = self.r1(h, 1)
        h = self.r2(h, 2)
        h = self.r3(h, 3)
        h = self.r4(h, 4)
        h = self.r5(h, 5)

        h = batch_norm(relu(deconv2d(h, self.d1, strides=[1, 2, 2, 1], name='t_deconv1')))
        h = batch_norm(relu(deconv2d(h, self.d2, strides=[1, 2, 2, 1], name='t_deconv2')))
        y = deconv2d(h, self.d3, name='t_deconv3')
        return tf.multiply((tf.tanh(y) + 1), tf.constant(127.5, tf.float32, shape=y.get_shape()), name='output')

In [3]:
net = FastStyleNet()
image = tf.placeholder(tf.float32, shape=[1, 224, 224, 3], name='input')
outputs = net(image)

In [5]:
init_op = tf.global_variables_initializer()
image_init = np.random.rand(1, 224, 224, 3)

with tf.Session() as sess:
    sess.run(init_op)

    feed_dict = {image: image_init}
    
    start_time = time.time()
    sess.run(outputs, feed_dict=feed_dict)
    print("--- %s seconds ---" % (time.time() - start_time))

--- 0.34183812141418457 seconds ---


In [2]:
def instance_norm(x):
    epsilon = 1e-9
    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

In [3]:
def conv_block(net, N):
    net = tf.layers.conv2d(net, N, 3, padding='same')
    net = instance_norm(net)
    net = tf.nn.relu(net)

    net = tf.layers.conv2d(net, N, 3, padding='same')
    net = instance_norm(net)
    net = tf.nn.relu(net)

    net = tf.layers.conv2d(net, N, 1, padding='same')
    net = instance_norm(net)
    net = tf.nn.relu(net)
    return net

In [4]:
def join_block(lower, higher):
    lower = tf.image.resize_nearest_neighbor(lower, higher.get_shape().as_list()[1:3])
    lower = instance_norm(lower)
    higher = instance_norm(higher)
    higher = tf.concat([lower, higher], axis=3)
    return higher

In [5]:
def transfer_network(content_image_8, 
                     content_image_16,
                     content_image_32,
                     content_image_64,
                     content_image_128) :
    
    conv_block_32_8 = conv_block(content_image_8, 32)
    conv_block_32_16 = conv_block(content_image_16, 32)
    conv_block_32_32 = conv_block(content_image_32, 32)
    conv_block_32_64 = conv_block(content_image_64, 32)
    conv_block_32_128 = conv_block(content_image_128, 32)

    net = join_block(conv_block_32_8, conv_block_32_16)
    net = conv_block(net, 64)
    net = join_block(net, conv_block_32_32)
    net = conv_block(net, 96)
    net = join_block(net, conv_block_32_64)
    net = conv_block(net, 128)
    net = join_block(net, conv_block_32_128)
    net = conv_block(net, 160)

    net = tf.layers.conv2d(net, 3, 1, padding='same')
    return net

In [6]:
content_image_8 = tf.placeholder(tf.float32, shape=[1, 8, 8, 3])
content_image_16 = tf.placeholder(tf.float32, shape=[1, 16, 16, 3])
content_image_32 = tf.placeholder(tf.float32, shape=[1, 32, 32, 3])
content_image_64 = tf.placeholder(tf.float32, shape=[1, 64, 64, 3])
content_image_128 = tf.placeholder(tf.float32, shape=[1, 128, 128, 3])

ts_net_output = transfer_network(content_image_8, content_image_16, content_image_32, content_image_64, content_image_128)

In [7]:
init_op = tf.global_variables_initializer()
content_image_8_init = np.random.rand(1, 8, 8, 3)
content_image_16_init = np.random.rand(1, 16, 16, 3)
content_image_32_init = np.random.rand(1, 32, 32, 3)
content_image_64_init = np.random.rand(1, 64, 64, 3)
content_image_128_init = np.random.rand(1, 128, 128, 3)

with tf.Session() as sess:
    sess.run(init_op)

    feed_dict = {content_image_8: content_image_8_init, content_image_16: content_image_16_init, content_image_32: content_image_32_init, content_image_64: content_image_64_init, content_image_128: content_image_128_init}
    
    start_time = time.time()
    sess.run(ts_net_output, feed_dict=feed_dict)
    print("--- %s seconds ---" % (time.time() - start_time))

--- 0.45665407180786133 seconds ---
