In [1]:
from __future__ import print_function, division
from builtins import range, input

In [2]:
from keras.layers import Input, Lambda, Dense, Flatten, AveragePooling2D, MaxPooling2D
from keras.layers.convolutional import Conv2D
from keras.models import Model, Sequential
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing import image
import keras.backend as backend

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
import os
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fmin_l_bfgs_b

In [4]:
def vgg16_avg_pool(shape):
    """Build an ImageNet-pretrained VGG16 (no classifier head) using average pooling.

    Every ``MaxPooling2D`` layer of the stock model is replaced by a fresh
    ``AveragePooling2D()``. All other layers are the *same layer objects* as in
    the stock VGG16, so their pretrained weights are shared, not copied.

    shape: input shape passed to ``VGG16(input_shape=...)``.
    Returns a ``Sequential`` model.
    """
    base = VGG16(input_shape=shape, weights='imagenet', include_top=False)
    avg_model = Sequential()
    for lyr in base.layers:
        is_max_pool = lyr.__class__ == MaxPooling2D
        # Only construct a new pooling layer when we are actually replacing one.
        avg_model.add(AveragePooling2D() if is_max_pool else lyr)
    return avg_model

In [5]:
def vgg16_avg_pool_cutoff(shape, num_convs):
    """Average-pool VGG16 truncated right after its ``num_convs``-th Conv2D layer.

    shape: input shape forwarded to ``vgg16_avg_pool``.
    num_convs: how many convolution layers to keep. The accepted range is 1..13
        (VGG16 has 13 conv layers). Outside that range, prints "Input Error!!"
        and returns None.
    Returns a ``Sequential`` model sharing layer objects with the full model.
    """
    if num_convs < 1 or num_convs > 13:
        print("Input Error!!")
        return None

    full_model = vgg16_avg_pool(shape)
    truncated = Sequential()
    convs_seen = 0
    for lyr in full_model.layers:
        if lyr.__class__ == Conv2D:
            convs_seen += 1
        truncated.add(lyr)
        # Stop immediately after adding the target conv layer itself.
        if convs_seen >= num_convs:
            break
    return truncated

In [6]:
def unpreprocess(img):
    """Undo Keras VGG16 ``preprocess_input`` ('caffe' mode) so an image can be displayed.

    ``preprocess_input`` flips RGB to BGR and subtracts the per-channel ImageNet
    means (B=103.939, G=116.779, R=123.68). This function adds the means back and
    flips the channels back to RGB.

    Parameters
    ----------
    img : array_like
        Preprocessed data whose LAST axis holds the 3 BGR channels. Works for a
        single image ``(H, W, 3)`` or a batch ``(N, H, W, 3)``.

    Returns
    -------
    numpy.ndarray
        A new float64 array in RGB order. ``img`` itself is not modified.
    """
    # Copy (and promote to float) so the caller's array is left untouched.
    out = np.array(img, dtype=np.float64)
    # Ellipsis indexing always targets the channel axis, whatever the rank.
    out[..., 0] += 103.939
    out[..., 1] += 116.779
    out[..., 2] += 123.68
    return out[..., ::-1]  # BGR -> RGB

In [7]:
def scale_image(x):
    """Min-max rescale an array into [0, 1], e.g. for ``plt.imshow``.

    Parameters
    ----------
    x : numpy.ndarray
        Any non-empty numeric array.

    Returns
    -------
    numpy.ndarray
        ``(x - min) / (max - min)``, a new array. A constant array, which has no
        range to normalise by, maps to all zeros instead of dividing by zero.
    """
    x_min = x.min()
    value_range = x.max() - x_min
    if value_range == 0:
        return np.zeros_like(x, dtype=float)
    return (x - x_min) / value_range

In [8]:
def gram_matrix(img):
    """Normalised Gram matrix of a single feature map.

    img: rank-3 symbolic tensor, channel axis last. The permute ``(2, 0, 1)``
        moves axis 2 to the front. Its static shape must be fully known for
        ``num_elements()``.
    Returns a ``(C, C)`` tensor of channel-by-channel inner products, divided
    by the total number of elements in ``img``.
    """
    channels_first = backend.permute_dimensions(img, (2, 0, 1))
    features = backend.batch_flatten(channels_first)  # one row per channel
    n_elements = img.get_shape().num_elements()
    return backend.dot(features, backend.transpose(features)) / n_elements

In [9]:
def style_loss(y, t):
    """Mean squared difference between the Gram matrices of ``y`` and ``t``.

    y: feature map of the image being optimised (symbolic).
    t: target feature map from the style image.
    """
    gram_diff = gram_matrix(y) - gram_matrix(t)
    return backend.mean(backend.square(gram_diff))

In [10]:
def minimize(fn, epochs, batch_shape):
    """Optimise an image with repeated L-BFGS-B rounds and return it de-processed.

    Parameters
    ----------
    fn : callable
        Maps the flattened image vector to ``(loss, grad)``. It is passed to
        ``fmin_l_bfgs_b`` as ``func`` with no ``fprime``, so it must return both.
    epochs : int
        Number of L-BFGS-B rounds. Each round allows at most 20 function
        evaluations and restarts from the clipped result of the previous round.
    batch_shape : tuple
        Shape of the model input batch. Only the first image is returned.

    Returns
    -------
    numpy.ndarray
        The first image of the optimised batch after ``unpreprocess``.

    Side effects: prints the loss for each round and the total duration, and
    shows a plot of the per-round losses.
    """
    t0 = datetime.now()
    losses = []
    # Random starting point (not seeded here).
    x = np.random.randn(np.prod(batch_shape))
    for i in range(epochs):
        x, loss_value, _ = fmin_l_bfgs_b(func=fn, x0=x, maxfun=20)
        # Bound the values between rounds. NOTE(review): presumably chosen to stay
        # near the mean-subtracted (preprocessed) pixel range.
        x = np.clip(x, -127, 127)
        print('iter: ', i, 'loss: ', loss_value)
        losses.append(loss_value)
    print('duration: ', datetime.now() - t0)
    plt.plot(losses)
    plt.show()
    # Hand unpreprocess a single (H, W, channels) image rather than the 4-D
    # batch. Positional indexing such as img[:, :, k] would otherwise address
    # the width axis of the batch instead of the colour channels.
    first_img = x.reshape(*batch_shape)[0]
    return unpreprocess(first_img)

In [11]:
def load_img_and_preprocess(path, shape=None):
    """Load an image file as a one-image batch ready to feed to VGG16.

    path: path of the image file.
    shape: optional ``(height, width)`` to resize to, passed as ``target_size``.
        ``None`` keeps the file's own size.
    Returns ``preprocess_input`` applied to an array with a leading batch axis
    of size 1. Presumably ``(1, H, W, C)`` under channels_last; confirm the
    Keras data format.
    """
    pil_img = image.load_img(path, target_size=shape)
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    return preprocess_input(batch)

In [13]:
# Input images. The folder defaults to the original author's machine; set the
# STYLE_TRANSFER_IMG_DIR environment variable instead of editing this cell.
IMG_DIR = os.environ.get(
    'STYLE_TRANSFER_IMG_DIR',
    'C:\\Users\\AAKASH\\Desktop\\Temporary\\CN ML attachments',
)
CONTENT_IMG_PATH = os.path.join(IMG_DIR, 'baby1.jpg')
STYLE_IMG_PATH = os.path.join(IMG_DIR, 'artScene.jpg')

# The content image keeps its native size. The style image is resized to match
# it, so both produce feature maps of the same spatial size.
content_img = load_img_and_preprocess(CONTENT_IMG_PATH)
height, width = content_img.shape[1:3]
style_img = load_img_and_preprocess(STYLE_IMG_PATH, (height, width))

In [14]:
# Full input batch shape (leading batch axis) and the per-image shape
# that the VGG model is built for.
batch_shape = content_img.shape
shape = batch_shape[1:]

In [15]:
# Pretrained VGG16 with average pooling, sized to the content image.
vgg = vgg16_avg_pool(shape=shape)

In [23]:
# Content representation: activations of vgg.layers[12].
# NOTE(review): magic index; check vgg.layers[12].name to confirm which layer.
# get_output_at(1) takes the layer's output at node 1. vgg16_avg_pool re-adds
# VGG16's own layer objects to a new Sequential, so node 0 belongs to the
# original VGG16 graph.
content_model = Model(inputs=vgg.input,
                      outputs=vgg.layers[12].get_output_at(1))
# Fixed target activations computed once from the content image.
content_target = backend.variable(content_model.predict(content_img))

In [17]:
# Style representation: the first conv layer of every VGG16 block
# (layer names ending in 'conv1'), taken at node 1 of the rebuilt model.
symbolic_conv_outputs = [
    conv_layer.get_output_at(1)
    for conv_layer in vgg.layers
    if conv_layer.name.endswith('conv1')
]

In [19]:
# Multi-output model that returns every style-layer activation at once.
style_model = Model(inputs=vgg.input, outputs=symbolic_conv_outputs)

In [22]:
# Fixed style targets: each style-layer activation of the style image,
# wrapped as a backend variable.
style_layer_outputs = list(map(backend.variable, style_model.predict(style_img)))

In [24]:
# One weight per style layer, increasing with depth: 1, 2, 3, 4, 5.
style_weights = list(range(1, 6))

In [25]:
# Content loss: mean squared difference between the generated image's
# content-layer activations and the fixed content target.
loss = backend.mean(
    backend.square(content_model.output - content_target)
)

In [26]:
# Add one weighted Gram-matrix term per style layer to the content loss.
# [0] drops the leading batch axis (a batch of one image).
# zip() would silently drop terms if these lists disagreed, so check first.
assert len(style_weights) == len(symbolic_conv_outputs) == len(style_layer_outputs), \
    'need exactly one style weight per style layer'
# NOTE: this updates `loss` in place. Re-run the content-loss cell before
# re-running this one, or the style terms are added twice.
for w, symbolic, actual in zip(style_weights, symbolic_conv_outputs, style_layer_outputs):
    loss += w * style_loss(symbolic[0], actual[0])

In [27]:
# Symbolic gradient(s) of the total loss with respect to the input image (a list).
grads = backend.gradients(loss=loss, variables=vgg.input)

In [28]:
# Compiled callable: [input batch] -> [loss value, input gradient].
get_loss_and_grads = backend.function([vgg.input], [loss] + grads)

In [29]:
def get_loss_and_grads_wrapper(x_vec):
    """Adapt the compiled Keras function to scipy's L-BFGS-B interface.

    x_vec: flat vector from the optimiser, reshaped to the global ``batch_shape``.
    Returns ``(loss, gradient)``, where the loss is converted to float64 and the
    gradient is flattened and converted to float64.
    """
    image_batch = x_vec.reshape(*batch_shape)
    loss_value, grad_value = get_loss_and_grads([image_batch])
    return loss_value.astype(np.float64), grad_value.astype(np.float64).flatten()

In [None]:
# Run 10 L-BFGS-B rounds, then show the result rescaled to [0, 1] for imshow.
final_img = minimize(get_loss_and_grads_wrapper, 10, batch_shape)
plt.imshow(scale_image(final_img))  # was scale_img: undefined name (NameError)
plt.show()