<a href="https://colab.research.google.com/github/Quan-Sun/Learning_Tensorflow/blob/master/style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A tensorflow implementation of style transfer described in the papers **[Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)**

Most code in this file was borrowed from https://github.com/hwalsuklee/tensorflow-style-transfer

In [0]:
args_content = 'picture.jpeg'  # choose the image that you want transfer
args_style = 'portrait.jpg' # choose the image that you love its style

from google.colab import files
files.upload()     # choose the file- args_content and args_style

In [5]:
import tensorflow as tf 
import numpy as np 
import scipy.io 
from six.moves import urllib
import os

source_url = 'http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat'
data_dir = './pre_trained_model'
filename = 'imagenet-vgg-verydeep-19.mat'
def maybe_download(filename):
    if not tf.gfile.Exists(data_dir):
        tf.gfile.MakeDirs(data_dir)
    file_path = os.path.join(data_dir, filename)
    
    if not tf.gfile.Exists(file_path):
        file_path, _ = urllib.request.urlretrieve(source_url, file_path)
        
        with tf.gfile.GFile(file_path) as f:
            size = f.size()
        print('Successfully download', filename, size, 'bytes.')
    return file_path

model_filename = maybe_download(filename)

def _conv_layer(input, weights, bias,padding='SAME'):
    conv = tf.nn.conv2d(input,tf.constant(weights),strides=[1,1,1,1],padding= padding)
    h_conv = conv + bias

    return h_conv

def _pool_layer(input, padding='SAME'):
    h_pool = tf.nn.max_pool(input, ksize=[1,2,2,1], strides=[1,2,2,1],padding= padding)

    return h_pool

def preprocess(image, mean_pixel):
    return image - mean_pixel

def unpreprocess(image, mean_pixel):
    return image + mean_pixel

class VGG19:
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4'
        )

    def __init__(self, model_filename):
        model = scipy.io.loadmat(model_filename)

        self.mean_pixel = np.array([123.68, 116.779, 103.939]) #np.mean(model['normalization'][0][0][0], axis=(0,1)) 

        self.weights = model['layers'][0]

    def preprocess(self, image):
        return np.float32(image - self.mean_pixel)

    def unpreprocess(self, image):
        return np.float32(image + self.mean_pixel)

    def feed_forward(self, input_image, scope=None):
        current_network = {}
        current_layer = input_image

        with tf.variable_scope(scope):
            for num, name in enumerate(self.layers):
                type_layer = name[:4]
                if type_layer == 'conv':
                    kernels = self.weights[num][0][0][2][0][0]
                    bias = self.weights[num][0][0][2][0][1]

                    # vgg19: shape of weights is [width, height, in_channels, out_channels]
                    # tensorflow: shape of weights is [height, width, in_channels, out_channels]

                    kernels = np.transpose(kernels, [1,0,2,3])
                    bias = bias.reshape(-1)
                    current_layer = _conv_layer(current_layer, kernels, bias)

                elif type_layer == 'relu':
                    current_layer = tf.nn.relu(current_layer)

                elif type_layer == 'pool':
                    current_layer = _pool_layer(current_layer)

                current_network[name] = current_layer
        assert len(current_network) == len(self.layers)
        return current_network

Successfully download imagenet-vgg-verydeep-19.mat 534904783 bytes.


In [0]:
import matplotlib.pyplot as plt 
import PIL

def load_image(filename, assign_shape=None, max_size=None):
    image = PIL.Image.open(filename)

    if max_size:
        proportion = float(max_size)/np.max(image.size)
        size = np.array(image.size) * proportion

        # PIL manipulation needs the size to be integers
        size = size.astype(int)

        # Resize the image
        image = image.resize(size, PIL.Image.LANCZOS) # PIL.Image.LANCZOS is a resampling filter

    if assign_shape:
        image = image.resize(assign_shape, PIL.Image.LANCZOS)

    image = np.float32(image)
    return image

# Save images as files of *.jpeg
def save_image(image,filename):
    image = np.clip(image, 0.0, 255.0)

    image = image.astype(np.uint8) # convert float to bytes

    with open(filename, 'wb') as f:
        PIL.Image.fromarray(image).save(f, 'jpeg')


# DRAW the content-, mixed-, style-images
def draw_images(content_image, style_image, mixed_image):
    fig,axes = plt.subplots(1,3,figsize=(10,10))

    fig.subplots_adjust(hspace=0.1,wspace=0.1)

    ax = axes.flat[0]
    ax.imshow(content_image/255.0, interpolation='sinc')
    ax.set_xlabel('Content')
    ax.set_xticks([])
    ax.set_yticks([])

    ax = axes.flat[1]
    ax.imshow(mixed_image/255.0, interpolation='sinc')
    ax.set_xlabel('Output')
    ax.set_xticks([])
    ax.set_yticks([])

    ax = axes.flat[2]
    ax.imshow(style_image/255.0, interpolation='sinc')
    ax.set_xlabel('Style')
    ax.set_xticks([])
    ax.set_yticks([])

    plt.show()

In [0]:
import collections

class StyleTransfer:

    def __init__(self, content_layer, style_layer, init_image, content_image, style_image,
                session, model_selection, num_iter, loss_ratio, content_loss_norm_type):

        self.model_selection = model_selection
        self.sess = session

        self.CONTENT_LAYERS = collections.OrderedDict(sorted(content_layer.items()))
        self.STYLE_LAYERS = collections.OrderedDict(sorted(style_layer.items()))

        # Preprocess
        self.content_image_preprocess = self.model_selection.preprocess(content_image)
        self.style_image_preprocess = self.model_selection.preprocess(style_image)
        self.init_image_preprocess = self.model_selection.preprocess(init_image)

        # Parameters for optimization
        self.content_loss_norm_type = content_loss_norm_type
        self.num_iter = num_iter
        self.loss_ratio = loss_ratio
        self._build_graph()


    def _gram_matrix(self, tensor):
        shape = tensor.get_shape()
        num_channels = int(shape[3])
        matrix = tf.reshape(tensor, shape=[-1,num_channels])
        gram = tf.matmul(tf.transpose(matrix), matrix)
        return gram


    def _build_graph(self):
        self.init_image_variable = tf.Variable(self.init_image_preprocess, trainable=True, dtype=tf.float32)

        self.input_content_image = tf.placeholder(tf.float32, shape=self.content_image_preprocess.shape, name='content')
        self.output_style_image = tf.placeholder(tf.float32, shape=self.style_image_preprocess.shape, name='style')

        content_layers = self.model_selection.feed_forward(self.input_content_image, scope='content')
        self.content_features = {}
        for layer in self.CONTENT_LAYERS:
            self.content_features[layer] = content_layers[layer]

        style_layers = self.model_selection.feed_forward(self.output_style_image, scope='style')
        self.style_features = {}
        for layer in self.STYLE_LAYERS:
            self.style_features[layer] = self._gram_matrix(style_layers[layer])

        self.init_featues = self.model_selection.feed_forward(self.init_image_variable, scope='mixed')

        Loss_content = 0
        Loss_style = 0
        for layer in self.init_featues:
            if layer in self.CONTENT_LAYERS:
                init_featues_value = self.init_featues[layer] 
                content_features_value = self.content_features[layer] 

                _, heighgt, width, num_filters = init_featues_value.get_shape()
                N = heighgt.value * width.value
                M = num_filters.value # number of filters

                W = self.CONTENT_LAYERS[layer]

                if self.content_loss_norm_type==1:
                    Loss_content += w * tf.reduce_sum(tf.pow((init_featues_value - content_features_value),2))/2
                elif self.content_loss_norm_type==2:
                    Loss_content += w * tf.reduce_sum(tf.pow((init_featues_value - content_features_value),2))/(N*M)

                elif self.content_loss_norm_type==3:
                    Loss_content += w * (1. / (2. * np.sqrt(M) * np.sqrt(N))) * tf.reduce_sum(tf.pow((init_featues_value - content_features_value),2))

            elif layer in self.STYLE_LAYERS:
                init_featues_value = self.init_featues[layer]

                _,h,w,d = init_featues_value.get_shape()
                N = h.value * w.value
                M = d.value

                w = self.STYLE_LAYERS[layer]
                G = self._gram_matrix(init_featues_value)
                A = self.style_features[layer]

                Loss_style += w * (1. / (4. * N ** 2 * M ** 2)) * tf.reduce_sum(tf.pow((G-A),2))

        alpha = self.loss_ratio
        beta = 1

        self.Loss_content = Loss_content
        self.Loss_style = Loss_style
        self.Loss_total = alpha*Loss_content + beta*Loss_style


    def optimize(self):
        # define optimizer L-BFGS
        global iteration
        iteration = 0
        def callback(total_loss, content_loss, style_loss):
            global iteration
            print('iteration: %4d, '%iteration, 'Loss_total: %g, Loss_content: %g, Loss_style: %g' % (total_loss, content_loss, style_loss))
            iteration += 1

        optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.Loss_total, method='L-BFGS-B', options={'maxiter':self.num_iter})

        self.sess.run(tf.global_variables_initializer())

        optimizer.minimize(self.sess, feed_dict={self.output_style_image:self.style_image_preprocess, self.input_content_image:self.content_image_preprocess},
            fetches=[self.Loss_total, self.Loss_content, self.Loss_style], loss_callback=callback)

        final_image = self.sess.run(self.init_image_variable)
        final_image = np.clip(self.model_selection.unpreprocess(final_image), 0.0, 255.0)

        return final_image

In [8]:

args_output = 'mixed_image.jpg'
args_loss_ratio = 1e-3
args_content_layers = ['conv4_2']
args_style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
args_content_layer_weights = [1.0]
args_style_layer_weights = [.2,.2,.2,.2,.2]
args_initial_type = 'content' # choices=['random','content','style']
args_max_size = 1024
args_content_loss_norm_type = 3 #choices=[1,2,3]
args_num_iter = 1000



try:
    assert len(args_content_layers) == len(args_content_layer_weights)
except:
    print('Content layers info and weights info must be matched')

try:
    assert len(args_style_layers) == len(args_style_layer_weights)
except:
    print('Style layers info and weight info must be matched')


try:
    assert args_max_size>100

except:
    print('Too small size')


model_file_path = model_filename

try:
    assert os.path.exists(model_file_path)
except:
    print('There is no %s' % model_file_path)


try :
    size_in_KB = os.path.getsize(model_file_path)
    assert abs(size_in_KB - 534904783) < 10
except:
    print("Check file size of 'imagenet-vgg-verydeep-19.mat' ")
    print('There are some files with the same name')
    print('pre_trained_model used here can be download from below')
    print('http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat')


try:
    assert os.path.exists(args_content)
except:
    print('There is no %s'%args_content)


try:
    assert os.path.exists(args_style)
except:
    print('There is no %s'%args_style)


# VGG19 requires input dimension to be [batch, height, width, channel]


def add_one_dim(image):
    shape = (1,) + image.shape
    return np.reshape(image,shape)

model_file_path = model_filename
vgg_net = VGG19(model_file_path)

content_image = load_image(args_content, max_size=args_max_size)
style_image = load_image(args_style, assign_shape=[content_image.shape[1],content_image.shape[0]])

if args_initial_type == 'content':
    initial_image = content_image
elif args_initial_type == 'style':
    initial_image = style_image
elif args_initial_type == 'random':
    initial_image = np.ranodm.normal(size=content_image.shape, scale=np.std(content_image))

CONTENT_LAYERS = {}
for layer, weight in zip(args_content_layers, args_content_layer_weights):
    CONTENT_LAYERS[layer] = weight

STYLE_LAYERS = {}
for layer, weight in zip(args_style_layers, args_style_layer_weights):
    STYLE_LAYERS[layer] = weight

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

style_teansfer = StyleTransfer(session=sess,
                content_layer=CONTENT_LAYERS,
                style_layer=STYLE_LAYERS,
                init_image = add_one_dim(initial_image),
                content_image = add_one_dim(content_image),
                style_image = add_one_dim(style_image),
                model_selection=vgg_net,
                num_iter = args_num_iter,
                loss_ratio = args_loss_ratio,
                content_loss_norm_type = args_content_loss_norm_type,
                )

result_image = style_teansfer.optimize()
sess.close()

shape = result_image.shape
result_image = np.reshape(result_image, shape[1:])

save_image(result_image, args_output)

iteration:    0,  Loss_total: 1.05573e+09, Loss_content: 0, Loss_style: 1.05573e+09
iteration:    1,  Loss_total: 1.0554e+09, Loss_content: 22.6252, Loss_style: 1.0554e+09
iteration:    2,  Loss_total: 1.05406e+09, Loss_content: 567.662, Loss_style: 1.05406e+09
iteration:    3,  Loss_total: 1.04874e+09, Loss_content: 10122.1, Loss_style: 1.04874e+09
iteration:    4,  Loss_total: 1.02757e+09, Loss_content: 169139, Loss_style: 1.02757e+09
iteration:    5,  Loss_total: 9.50359e+08, Loss_content: 2.44104e+06, Loss_style: 9.50357e+08
iteration:    6,  Loss_total: 5.76628e+08, Loss_content: 5.4265e+07, Loss_style: 5.76573e+08
iteration:    7,  Loss_total: 3.12687e+08, Loss_content: 9.85602e+07, Loss_style: 3.12588e+08
iteration:    8,  Loss_total: 1.82608e+08, Loss_content: 1.4817e+08, Loss_style: 1.8246e+08
iteration:    9,  Loss_total: 1.3821e+08, Loss_content: 1.40796e+08, Loss_style: 1.38069e+08
iteration:   10,  Loss_total: 1.10629e+08, Loss_content: 1.42601e+08, Loss_style: 1.10487e+08

In [9]:
!ls

mixed_portait.jpg  picture.jpeg  portrait.jpg  pre_trained_model  sample_data


In [0]:
files.download(args_output)  # download the mixed_image