快速图像风格迁移，实现原理在图像迁移的算法分析和提高部分做了简要介绍。

简单来讲，就是把原方法中通过迭代优化对原始图片（raw picture）所做的变换过程（可以看作一个函数，即某种非线性变换）交给一个 transfer net 来完成。

相当于用这个transfer net来记录了进行某一个特定style transfer所做的所有关键变换方法。

而在推理（infer）时，只需让图片经过这个 transfer net 做一次前向计算，而不必像原方法那样针对每张图片重新进行迭代优化。

模型分为两个部分：  
- transfer net
- loss train net

另一个不同点在于：训练 transfer net 需要大量的 content picture，才能拟合出合适的参数。

### Code

reference：
1. https://github.com/lengstrom/fast-style-transfer
2. https://github.com/hzy46/fast-neural-style-tensorflow

dataset：http://cocodataset.org/#download   2014 Train images [83K/13GB]

In [1]:
import tensorflow as tf
import numpy as np
import cv2
import scipy.io
import os
import glob
import matplotlib.pyplot as plt
%matplotlib inline

from imageio import imread, imsave
from tqdm import tqdm_notebook

#### preprocess

In [3]:
def resize_and_crop(image, image_size):
    """Center-crop *image* to a square and resize it to image_size x image_size.

    The crop is the largest centered square (side = min(height, width)), so
    it is exactly square even when the shorter side is odd. The previous
    ``h//2 - w//2 : h//2 + w//2`` slice lost one row/column in that case and
    the following resize then distorted the picture slightly.

    Args:
        image: array of shape (H, W) or (H, W, C).
        image_size: side length of the returned square image.

    Returns:
        A new array of shape (image_size, image_size) or
        (image_size, image_size, C).
    """
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    image = image[top:top + side, left:left + side]

    # Resizing to the size the crop already has would only copy the pixels.
    if side == image_size:
        return image.copy()
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (image_size, image_size))

In [None]:
# Load the content (training) pictures: center-crop and resize each image to
# image_size x image_size and stack them into one array.
# NOTE: ~83K images x 256*256*3 bytes is roughly 16 GB as uint8, and
# np.array() below makes one more copy of the list contents.
X_data = []
image_size = 256
paths = glob.glob('train2014/*.jpg')

for path in tqdm_notebook(paths):
    image = imread(path)
    # Keep only 3-channel pictures: grayscale (2-D) images, or images with any
    # other channel count, do not fit the (N, H, W, 3) input placeholder.
    if image.ndim != 3 or image.shape[2] != 3:
        continue
    X_data.append(resize_and_crop(image, image_size))

X_data = np.array(X_data)
print(X_data.shape)

#### feature map extraction VGG19

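Details of the VGG19 model (indices into `vgg['layers']` of the MatConvNet `.mat` file; note that the code below replaces every maxpool with an average pool):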
	- 0 is conv1_1 (3, 3, 3, 64)
	- 1 is relu
	- 2 is conv1_2 (3, 3, 64, 64)
	- 3 is relu    
	- 4 is maxpool
	- 5 is conv2_1 (3, 3, 64, 128)
	- 6 is relu
	- 7 is conv2_2 (3, 3, 128, 128)
	- 8 is relu
	- 9 is maxpool
	- 10 is conv3_1 (3, 3, 128, 256)
	- 11 is relu
	- 12 is conv3_2 (3, 3, 256, 256)
	- 13 is relu
	- 14 is conv3_3 (3, 3, 256, 256)
	- 15 is relu
	- 16 is conv3_4 (3, 3, 256, 256)
	- 17 is relu
	- 18 is maxpool
	- 19 is conv4_1 (3, 3, 256, 512)
	- 20 is relu
	- 21 is conv4_2 (3, 3, 512, 512)
	- 22 is relu
	- 23 is conv4_3 (3, 3, 512, 512)
	- 24 is relu
	- 25 is conv4_4 (3, 3, 512, 512)
	- 26 is relu
	- 27 is maxpool
	- 28 is conv5_1 (3, 3, 512, 512)
	- 29 is relu
	- 30 is conv5_2 (3, 3, 512, 512)
	- 31 is relu
	- 32 is conv5_3 (3, 3, 512, 512)
	- 33 is relu
	- 34 is conv5_4 (3, 3, 512, 512)
	- 35 is relu
	- 36 is maxpool
	- 37 is fullyconnected (7, 7, 512, 4096)
	- 38 is relu
	- 39 is fullyconnected (1, 1, 4096, 4096)
	- 40 is relu
	- 41 is fullyconnected (1, 1, 4096, 1000)
	- 42 is softmax

In [4]:
# Pre-trained VGG19 weights in MatConvNet format, downloadable from
# http://www.vlfeat.org/matconvnet/pretrained/
vgg = scipy.io.loadmat('imagenet-vgg-verydeep-19.mat')
vgg_layers = vgg['layers']

# Layers built by vgg_endpoints, in network order, as
# (output name, index of the conv layer in vgg_layers). A None index marks a
# 2x2 / stride-2 average pool, used here in place of VGG19's max pools.
_VGG19_LAYERS = (
    ('conv1_1', 0), ('conv1_2', 2), ('avgpool1', None),
    ('conv2_1', 5), ('conv2_2', 7), ('avgpool2', None),
    ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14), ('conv3_4', 16),
    ('avgpool3', None),
    ('conv4_1', 19), ('conv4_2', 21), ('conv4_3', 23), ('conv4_4', 25),
    ('avgpool4', None),
    ('conv5_1', 28), ('conv5_2', 30), ('conv5_3', 32), ('conv5_4', 34),
    ('avgpool5', None),
)


def _vgg_weights(layer, expected_layer_name):
    """Return the (kernel, bias) numpy arrays stored for vgg_layers[0][layer].

    Raises:
        ValueError: if the layer name stored in the .mat file differs from
            *expected_layer_name* (wrong or differently ordered weight file).
            An explicit check is used instead of ``assert`` so that it is not
            stripped when Python runs with ``-O``.
    """
    entry = vgg_layers[0][layer][0][0]
    layer_name = entry[0][0]
    if layer_name != expected_layer_name:
        raise ValueError('VGG layer %d is %r, expected %r'
                         % (layer, layer_name, expected_layer_name))
    W = entry[2][0][0]
    b = entry[2][0][1]
    return W, b


def _vgg_conv2d_relu(prev_layer, layer, layer_name):
    """Stride-1 'SAME' convolution + bias + ReLU using frozen VGG weights."""
    W, b = _vgg_weights(layer, layer_name)
    W = tf.constant(W)
    b = tf.constant(b.reshape(-1))  # flatten the stored bias to 1-D
    conv = tf.nn.conv2d(prev_layer, filter=W, strides=[1, 1, 1, 1],
                        padding='SAME')
    return tf.nn.relu(conv + b)


def _vgg_avgpool(prev_layer):
    """2x2, stride-2 'SAME' average pooling."""
    return tf.nn.avg_pool(prev_layer,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')


def vgg_endpoints(inputs, reuse=None):
    """Build the VGG19 feature extractor on *inputs* and return its layer outputs.

    The weights are read from the module-level ``vgg_layers`` and embedded as
    ``tf.constant`` values, so the network is frozen (nothing trainable is
    created). Every max pool of VGG19 is replaced by an average pool.

    Args:
        inputs: image tensor; callers in this notebook pass mean-subtracted
            RGB images (``X - MEAN_VALUES``).
        reuse: forwarded to the enclosing ``tf.variable_scope('endpoints')``.

    Returns:
        dict mapping each layer name ('conv1_1' ... 'conv5_4',
        'avgpool1' ... 'avgpool5') to its output tensor.

    Raises:
        ValueError: if the .mat file's layer names do not match VGG19.
    """
    with tf.variable_scope('endpoints', reuse=reuse):
        graph = {}
        prev = inputs
        for name, index in _VGG19_LAYERS:
            if index is None:
                prev = _vgg_avgpool(prev)
            else:
                prev = _vgg_conv2d_relu(prev, index, name)
            graph[name] = prev
        return graph

##### style gram matrix

In [6]:
style_images = glob.glob('styles/*.jpg')
# params
style_index = 0  # index of the style picture within style_images
image_size = 256
STYLE_LAYERS = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']  # VGG layers whose Gram matrices define the style


def gram_matix(style_index, style_images, image_size, STYLE_LAYERS):
    """Compute the Gram matrix of one style picture at each layer in STYLE_LAYERS.

    The picture ``style_images[style_index]`` is center-cropped and resized to
    image_size x image_size, mean-subtracted, and passed through VGG19 once.

    Returns:
        dict layer_name -> (C, C) numpy Gram matrix of that layer's feature
        map, divided by the feature map's element count (H * W * C).
    """
    X_style_data = resize_and_crop(imread(style_images[style_index]),
                                   image_size)
    X_style_data = np.expand_dims(X_style_data, 0)  # add the batch axis

    # Per-channel RGB means subtracted before feeding VGG (the original note
    # describes them as statistics over a large image collection).
    MEAN_VALUES = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

    X_style = tf.placeholder(dtype=tf.float32,
                             shape=X_style_data.shape,
                             name='X_style')
    style_endpoints = vgg_endpoints(X_style - MEAN_VALUES)

    # Fetch every requested layer in a single pass, and close the session
    # afterwards (it was previously created and never closed).
    with tf.Session() as sess:
        feature_maps = sess.run(
            {name: style_endpoints[name] for name in STYLE_LAYERS},
            feed_dict={X_style: X_style_data})

    style_features = {}
    for layer_name in STYLE_LAYERS:
        features = feature_maps[layer_name]
        # (1, H, W, C) -> (H*W, C): one row per spatial position.
        features = np.reshape(features, (-1, features.shape[3]))
        # Channel-by-channel dot products measure how channels co-activate;
        # normalize by the total number of elements of the feature map.
        gram = np.matmul(features.T, features) / features.size
        style_features[layer_name] = gram

    return style_features


style_features = gram_matix(style_index, style_images, image_size, STYLE_LAYERS)

#### transfer net

In [10]:
# params
# Per-channel RGB means, shaped to broadcast over NHWC image batches.
MEAN_VALUES = np.array([123.68, 116.779, 103.939]).reshape(1, 1, 1, 3)
batch_size = 4

# Content images fed to the transfer net: any batch size, height and width,
# 3 channels of raw pixel values.
X = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3], name='X')
# Kernel initializer shared by every conv layer of the transfer net.
k_initializer = tf.truncated_normal_initializer(mean=0, stddev=0.1)

In [8]:
def relu(x):
    """Element-wise ReLU; thin wrapper around ``tf.nn.relu``."""
    return tf.nn.relu(x)


def conv2d(inputs, filters, kernel_size, strides):
    """Convolution with reflection padding instead of zero padding.

    Height and width are first padded by ``kernel_size // 2`` pixels on each
    side with ``mode='reflect'`` (border pixels are mirrored into the padding),
    then a 'valid' ``tf.layers.conv2d`` is applied. For odd kernel sizes this
    keeps the spatial size unchanged at stride 1. Kernels are initialized
    with the module-level ``k_initializer``.

    Args:
        inputs: NHWC tensor.
        filters: number of output channels.
        kernel_size: square kernel size.
        strides: convolution stride.
    """
    pad = kernel_size // 2
    padded = tf.pad(inputs,
                    [[0, 0], [pad, pad], [pad, pad], [0, 0]],
                    mode='reflect')
    return tf.layers.conv2d(inputs=padded,
                            filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='valid',
                            kernel_initializer=k_initializer)


def deconv2d(inputs, filters, kernel_size, strides):
    """Upsampling layer built as nearest-neighbour resize + ``conv2d``.

    This is used instead of tf's transposed convolution, which produced
    clearly visible grid artifacts in this model.

    The input is resized to ``strides * 2`` times its height and width (4x for
    ``strides=2``, not 2x) and then convolved with stride ``strides``. For odd
    kernel sizes the net result is therefore exactly 2x the input
    height/width.
    """
    dims = tf.shape(inputs)
    height, width = dims[1], dims[2]
    scale = strides * 2
    resized = tf.image.resize_images(inputs,
                                     [height * scale, width * scale],
                                     tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return conv2d(resized, filters, kernel_size, strides)


def instance_norm(inputs):
    """Normalize each image on its own (``tf.contrib.layers.instance_norm``)."""
    return tf.contrib.layers.instance_norm(inputs)


def residual(inputs, filters, kernel_size):
    """Residual block: ``inputs + conv(relu(conv(inputs)))``, both convs stride 1.

    *filters* must equal the channel count of *inputs* so the addition is
    shape-compatible.
    """
    h0 = relu(conv2d(inputs, filters, kernel_size, 1))
    h0 = conv2d(h0, filters, kernel_size, 1)
    return tf.add(inputs, h0)

In [11]:
# The 'transformer' variable scope groups every transfer-net variable, so they
# can be selected by name for training (vars_t) and looked up at inference.
# NOTE(review): tf.layers presumably auto-names variables in creation order, so
# reordering the layers below would likely break restoring old checkpoints.
with tf.variable_scope('transformer', reuse=None):
    # Reflect-pad 10 px on every side here and crop it off again at the end,
    # to reduce artifacts along the image borders.
    h0 = tf.pad(X - MEAN_VALUES, 
                       [[0, 0], [10, 10], [10, 10], [0, 0]], 
                       mode='reflect')
    h0 = relu(instance_norm(conv2d(h0, 32, 9, 1)))
    h0 = relu(instance_norm(conv2d(h0, 64, 3, 2)))
    h0 = relu(instance_norm(conv2d(h0, 128, 3, 2)))  # ~1/4 of the padded height/width

    for i in range(5):
        h0 = residual(h0, 128, 3)

    h0 = relu(instance_norm(deconv2d(h0, 64, 3, 2)))  # 2x upsampling
    h0 = relu(instance_norm(deconv2d(h0, 32, 3, 2)))  # 2x again: back to ~the padded size
    h0 = tf.nn.tanh(instance_norm(conv2d(h0, 3, 9, 1)))
    
    # map the tanh output from [-1, 1] to pixel values in [0, 255]
    h0 = (h0 + 1) / 2 * 255.
    
    # tf.slice(input, begin, size): begin=[0, 10, 10, 0] and
    # size=[-1, H - 20, W - 20, -1] (-1 = everything remaining) remove the
    # 10-px padding from each side. The stride-2 convs round odd sizes up, so
    # if the padded size is not a multiple of 4 the result can be a few pixels
    # larger than X. The output tensor is named 'transformer/g', which is the
    # name looked up at inference time.
    shape = tf.shape(h0)
    g = tf.slice(h0, 
                      [0, 10, 10, 0], 
                      [-1, shape[1] - 20, shape[2] - 20, -1], 
                      name='g')

#### content loss

In [12]:
CONTENT_LAYER = 'conv3_3'  # VGG layer compared by the content loss

# Run the content images (X) and the generated images (g) through the same
# frozen VGG19, mean-subtracted, reusing the 'endpoints' scope.
content_endpoints, g_endpoints = (
    vgg_endpoints(images - MEAN_VALUES, True) for images in (X, g)
)

In [13]:
# Loss terms
def get_content_loss(endpoints_x, endpoints_y, layer_name):
    """Mean squared difference between two feature maps at *layer_name*.

    ``tf.nn.l2_loss`` returns sum(d ** 2) / 2, so doubling it and dividing by
    the element count yields the mean of the squared differences.
    """
    feat_x = endpoints_x[layer_name]
    diff = feat_x - endpoints_y[layer_name]
    num_elements = tf.cast(tf.size(feat_x), tf.float32)
    return 2 * tf.nn.l2_loss(diff) / num_elements


content_loss = get_content_loss(content_endpoints, g_endpoints, CONTENT_LAYER)

#### style loss

In [14]:
STYLE_LAYERS = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']  # VGG layers used for the style loss


def _batch_gram(feature_map):
    """Per-image Gram matrix of an NHWC tensor, divided by H * W * C.

    Returns a (batch, C, C) tensor, normalized the same way as the precomputed
    style Gram matrices in ``style_features``.
    """
    dims = tf.shape(feature_map)
    n, h, w, c = dims[0], dims[1], dims[2], dims[3]
    flat = tf.reshape(feature_map, (n, h * w, c))
    gram = tf.matmul(tf.transpose(flat, (0, 2, 1)), flat)
    return gram / tf.to_float(h * w * c)


style_loss = []
for layer_name in STYLE_LAYERS:
    generated = g_endpoints[layer_name]
    # style_features[layer_name] is a (C, C) numpy array from the style
    # picture; it broadcasts over the batch axis of the generated Gram matrices.
    gram_diff = _batch_gram(generated) - style_features[layer_name]
    # Squared Gram difference, normalized by the element count of the generated
    # feature map (not of the Gram matrix).
    style_loss.append(2 * tf.nn.l2_loss(gram_diff) / tf.to_float(tf.size(generated)))

style_loss = tf.reduce_sum(style_loss)

#### 全变差正则

In [15]:
# params: weights of the content, style and total-variation loss terms
content_weight, style_weight, total_variation_weight = 1, 250, 0.01

In [16]:
def get_total_variation_loss(inputs):
    """Total-variation smoothness penalty for an NHWC image tensor.

    Differences are taken between vertically adjacent pixels and between
    horizontally adjacent pixels. For each, ``l2_loss(diff) / diff.size``
    (half the mean squared difference) is computed, and the two are added.
    Large jumps between neighbours cost more, which favours smooth images.
    """
    vertical = inputs[:, :-1, :, :] - inputs[:, 1:, :, :]    # along height
    horizontal = inputs[:, :, :-1, :] - inputs[:, :, 1:, :]  # along width

    def _half_mse(diff):
        return tf.nn.l2_loss(diff) / tf.to_float(tf.size(diff))

    return _half_mse(vertical) + _half_mse(horizontal)

In [17]:
total_variation_loss = get_total_variation_loss(g)

# Training objective: weighted sum of the three loss terms.
loss = (content_weight * content_loss
        + style_weight * style_loss
        + total_variation_weight * total_variation_loss)

#### train

In [None]:
# params
# Picture re-stylized at the end of every epoch to monitor progress.
X_sample = imread('sample.jpg')
# Number of passes over X_data.
epochs = 2

In [18]:
# Train only the transfer net: every variable lives in the 'transformer'
# scope (VGG is built from constants and has no variables).
vars_t = [v for v in tf.trainable_variables() if v.name.startswith('transformer')]

# Adam optimizer restricted to the transfer-net variables.
adam = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer = adam.minimize(loss, var_list=vars_t)

##### tf.summary

In [20]:
# Derive the style name portably ('styles/wave.jpg' -> 'wave'). The previous
# style_name[style_name.find('\\') + 1:].rstrip('.jpg') only handled Windows
# backslash paths (on POSIX it kept 'styles/', so os.mkdir then failed on the
# missing 'samples_styles' parent). Also, rstrip removes any trailing '.', 'j',
# 'p', 'g' characters (e.g. 'dog.jpg' -> 'do') rather than the extension.
style_name = os.path.splitext(os.path.basename(style_images[style_index]))[0]
OUTPUT_DIR = 'samples_%s' % style_name
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Scalar summaries: the raw losses and the weighted terms that make up `loss`.
tf.summary.scalar('losses/content_loss', content_loss)
tf.summary.scalar('losses/style_loss', style_loss)
tf.summary.scalar('losses/total_variation_loss', total_variation_loss)
tf.summary.scalar('losses/loss', loss)
tf.summary.scalar('weighted_losses/weighted_content_loss',
                  content_weight * content_loss)
tf.summary.scalar('weighted_losses/weighted_style_loss',
                  style_weight * style_loss)
tf.summary.scalar('weighted_losses/weighted_total_variation_loss',
                  total_variation_weight * total_variation_loss)

# Image summaries: generated images alongside the content inputs.
tf.summary.image('transformed', g)
tf.summary.image('origin', X)

summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(OUTPUT_DIR)

In [None]:
# The session opened inside gram_matix() is local to that function, so the
# training session has to be created here (previously `sess` was undefined at
# this point and sess.run raised NameError).
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Checkpoints are written to OUTPUT_DIR ('samples_<style>'), the directory
# the Test section restores from via tf.train.latest_checkpoint. The Test
# section imports 'fast_style_transfer.meta' from the working directory;
# copy it from OUTPUT_DIR.
saver = tf.train.Saver()
losses = []

h_sample, w_sample = X_sample.shape[:2]
steps_per_epoch = X_data.shape[0] // batch_size

for e in range(epochs):
    # Shuffle an index array instead of X_data itself: fancy-indexing the whole
    # dataset (X_data = X_data[data_index]) copied all of it every epoch.
    data_index = np.random.permutation(X_data.shape[0])

    for i in tqdm_notebook(range(steps_per_epoch)):
        X_batch = X_data[data_index[i * batch_size:(i + 1) * batch_size]]

        ls_, _ = sess.run([loss, optimizer], feed_dict={X: X_batch})
        losses.append(ls_)

        if i > 0 and i % 20 == 0:
            writer.add_summary(sess.run(summary, feed_dict={X: X_batch}),
                               e * steps_per_epoch + i)
            writer.flush()

    print('Epoch %d Loss %f' % (e, np.mean(losses)))
    losses = []  # reset the per-epoch losses

    # The meta graph embeds the VGG constants and is large; write it only once.
    saver.save(sess, os.path.join(OUTPUT_DIR, 'fast_style_transfer'),
               write_meta_graph=(e == 0))

    # Re-stylize the monitoring picture with the current transfer net.
    gen_img = sess.run(g, feed_dict={X: [X_sample]})[0]
    gen_img = np.clip(gen_img, 0, 255)

    result = np.zeros((h_sample, w_sample * 2, 3))
    # original on the left
    result[:, :w_sample, :] = X_sample / 255.
    # generated on the right; the net's output can be a few pixels larger than
    # its input, hence the crop
    result[:, w_sample:, :] = gen_img[:h_sample, :w_sample, :] / 255.

    plt.axis('off')
    plt.imshow(result)
    plt.show()

    # Convert explicitly to uint8 for JPEG instead of relying on imageio's
    # lossy float conversion.
    imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % e),
           np.round(result * 255).astype(np.uint8))

### Test

In [1]:
import tensorflow as tf
import numpy as np
from imageio import imread, imsave
import os
import time

In [2]:
def the_current_time():
    """Print the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(stamp)

In [None]:
sess = tf.Session()
# Rebuild the training graph (including the 'X' placeholder and the
# 'transformer/g' output) from the exported meta graph. The initializer that
# used to run before this import acted on an empty graph and has been dropped.
saver = tf.train.import_meta_graph('fast_style_transfer.meta')

# get the input/output tensors by the names given when the graph was built
graph = tf.get_default_graph()
X = graph.get_tensor_by_name('X:0')
g = graph.get_tensor_by_name('transformer/g:0')

# Initialize everything first; saver.restore below then overwrites the saved
# variables with the trained values.
sess.run(tf.global_variables_initializer())

# ['wave', 'rain', 'starry', 'scream', 'mosaic', 'muse']
style = 'rain'

# load params
model = 'samples_%s' % style
checkpoint = tf.train.latest_checkpoint(model)
if checkpoint is None:
    raise FileNotFoundError('no checkpoint found in %r' % model)
saver.restore(sess, checkpoint)

content_image = '144924.png'
result_image = '144924_%s.png' % style
X_image = imread(content_image)
# The transfer net takes (H, W, 3) input: replicate grayscale pictures into
# 3 channels and drop a PNG alpha channel if present.
if X_image.ndim == 2:
    X_image = np.stack([X_image] * 3, axis=-1)
X_image = X_image[:, :, :3]
h_image, w_image = X_image.shape[:2]

the_current_time()

# generate; the output can be a few pixels larger than the input, so crop it
gen_img = sess.run(g, feed_dict={X: [X_image]})[0]
gen_img = np.clip(gen_img[:h_image, :w_image], 0, 255)

imsave(result_image, np.round(gen_img).astype(np.uint8))

the_current_time()