In [46]:
import tensorflow as tf
import numpy as np
import PIL.Image as img
from imageio import mimsave

In [2]:
def prep_VGG19():
    """Build a frozen, ImageNet-pretrained VGG19 feature extractor.

    The returned model produces a list with one entry per element of
    ``vgg.layers`` (VGG19 without its dense top). Entry ``i`` is that layer's
    activation with the channel axis moved in front of the spatial axes and the
    spatial axes flattened, i.e. shape ``(1, channels, height * width)``: each
    row is one channel's feature map. The leading dimension is hard-coded to 1,
    so the extractor handles a single image per call.

    Returns:
        ``tf.keras.Model`` mapping ``vgg.input`` to that list of tensors.
    """
    base = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    base.trainable = False

    def _channel_rows(activation):
        # NHWC -> NCHW, then collapse H and W so each channel becomes one row.
        n_channels = activation.shape[-1]
        channels_first = tf.transpose(activation, perm=[0, 3, 1, 2])
        return tf.reshape(channels_first, [1, n_channels, -1])

    outputs = []
    for layer in base.layers:
        outputs.append(_channel_rows(layer.output))

    return tf.keras.Model(inputs=base.input, outputs=outputs)

In [3]:
def content_loss(Fp, Fc):
    """Content loss: half the summed squared difference of two feature tensors.

    Args:
        Fp: features of the image being optimised.
        Fc: features of the content image (broadcast-compatible with ``Fp``).

    Returns:
        Scalar tensor ``0.5 * sum((Fp - Fc) ** 2)``.
    """
    difference = Fp - Fc
    summed_square = tf.reduce_sum(tf.square(difference))
    return 0.5 * summed_square

In [4]:
def gram(F):
    """Gram matrix ``F @ F^T`` computed over the last two axes of ``F``.

    For features laid out as ``(batch, channels, positions)`` the result is
    ``(batch, channels, channels)``: entry ``[i, j]`` is the inner product of
    channel ``i``'s and channel ``j``'s flattened feature maps.
    """
    return tf.linalg.matmul(a=F, b=F, transpose_b=True)

def layer_gram_loss(Fp, Fc):
    """Style loss for a single layer, computed from Gram matrices.

    Computes ``sum((G(Fp) - G(Fc)) ** 2) / (4 * N**2 * M**2)``, where ``N`` is the
    number of channels and ``M`` the number of spatial positions of ``Fp``
    (the normalisation used by Gatys et al.).

    Args:
        Fp: rank-3 features ``(batch, channels, positions)`` of the image being
            optimised; the shape must be statically known.
        Fc: style-image features with the same channel count as ``Fp``.

    Returns:
        Scalar tensor.
    """
    _, n_channels, n_positions = Fp.shape
    squared_gram_diff = tf.reduce_sum(tf.square(gram(Fp) - gram(Fc)))
    normaliser = 4 * n_channels ** 2 * n_positions ** 2
    return squared_gram_diff / normaliser

In [5]:
def gram_loss(W, features_product, features_style):
    """Total style loss: equally weighted average of per-layer Gram losses.

    Args:
        W: collection (supporting ``len``) of layer indices to use.
        features_product: per-layer features of the image being optimised.
        features_style: per-layer features of the style image.

    Returns:
        Scalar tensor; each listed layer contributes with weight ``1 / len(W)``.
        An empty ``W`` yields 0 (the weight is never evaluated in that case).
    """
    weighted_terms = []
    for layer_index in W:
        per_layer = layer_gram_loss(features_product[layer_index], features_style[layer_index])
        weighted_terms.append((1 / len(W)) * per_layer)
    return tf.reduce_sum(weighted_terms)

In [7]:
def train_step(grad_image, feature_model, optimizer,
               content_features, style_features,
               content_layer, style_layers,
               content_weight, style_weight, tvl_weight):
    """Apply one optimisation step to ``grad_image`` and return the loss.

    The loss is the sum of a weighted content loss at ``content_layer``, a
    weighted Gram/style loss over ``style_layers``, and a weighted total-variation
    term on the raw image. After the optimizer update the image is clipped back
    into ``[0, 1]``.

    Args:
        grad_image: ``tf.Variable`` holding the image being optimised (it is
            differentiated, updated by ``optimizer`` and re-assigned in place).
        feature_model: maps a preprocessed image to an indexable sequence of
            per-layer features.
        optimizer: a ``tf.keras`` optimizer.
        content_features, style_features: ``feature_model`` outputs for the
            content and style images.
        content_layer: index of the content layer; must be exactly ``int``.
        style_layers: indices of the style layers; must be exactly ``list``.
        content_weight, style_weight, tvl_weight: loss weights.

    Returns:
        The loss tensor computed before the update. ``tf.image.total_variation``
        returns one value per image for a 4-D batch, so for a single-image
        batch this has shape ``(1,)`` — assumes ``grad_image`` is 4-D.

    Raises:
        AssertionError: if the index arguments have the wrong exact type (only
            when assertions are enabled).
    """
    assert type(content_layer) is int
    assert type(style_layers) is list

    with tf.GradientTape() as tape:
        features = feature_model(preprocess_image(grad_image))

        weighted_content = content_weight * content_loss(
            features[content_layer], content_features[content_layer])
        weighted_style = style_weight * gram_loss(style_layers, features, style_features)
        weighted_tv = tvl_weight * tf.image.total_variation(grad_image)

        loss = weighted_content + weighted_style + weighted_tv

    gradient = tape.gradient(loss, grad_image)
    optimizer.apply_gradients([(gradient, grad_image)])
    # Keep pixel values in the valid [0, 1] range after every update.
    grad_image.assign(tf.clip_by_value(grad_image, 0., 1.))

    return loss

In [8]:
def get_image(path, size=(224, 224)):
    """Load an image file as a float32 batch holding one RGB image in [0, 1].

    Args:
        path: anything ``PIL.Image.open`` accepts (path or file object).
        size: ``(width, height)`` passed to ``PIL.Image.resize``. The default
            matches the 224x224 resolution used elsewhere in this notebook.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)``, dtype float32, with
        values in ``[0, 1]``.
    """
    # `with` closes the underlying file once the pixels are read (the handle
    # used to leak). convert('RGB') normalises grayscale, palette, RGBA and CMYK
    # files to exactly three channels, which VGG19 requires; for files that are
    # already RGB it is a plain copy, so their pixels are unchanged.
    with img.open(path) as source:
        rgb = source.convert('RGB').resize(size)
    pixels = np.asarray(rgb, dtype=np.float32)
    return pixels[np.newaxis] / 255.

In [9]:
def preprocess_image(image):
    """Convert a ``[0, 1]`` RGB image batch to VGG19's expected input.

    ``tf.keras.applications.vgg19.preprocess_input`` expects ``[0, 255]`` values
    (it converts RGB to BGR and subtracts the ImageNet channel means), so the
    image is rescaled first. The rescaling creates a new array/tensor, so the
    caller's ``image`` is never modified, even though ``preprocess_input`` may
    work in place on NumPy input.
    """
    rescaled = image * 255
    return tf.keras.applications.vgg19.preprocess_input(rescaled)

In [62]:
def NST(content_path,
        style_path,
        epochs=200,
        steps_per_epoch=5,
        start_from_content=False,
        content_layer=18,
        style_layers=[1, 4, 7, 12, 17],
        content_weight=1e4,
        style_weight=1e-2,
        learning_rate=0.02,
        total_loss_variation_weight=30,
        ):
    """Run neural style transfer; return per-epoch snapshots and the result.

    The image being optimised starts from uniform random noise, or from a copy
    of the content image, and is updated by ``train_step`` for
    ``epochs * steps_per_epoch`` steps. After every epoch the loss is printed
    and a snapshot of the image is stored.

    Args:
        content_path: image file whose content should be preserved.
        style_path: image file whose style (Gram statistics) should be matched.
        epochs: number of snapshots / progress lines.
        steps_per_epoch: optimisation steps between snapshots; must be >= 1.
        start_from_content: start from the content image instead of noise.
        content_layer: index into the extractor outputs (one per VGG19 layer)
            used for the content loss.
        style_layers: indices of the extractor outputs used for the style loss.
            Any sequence of ints is accepted. It is copied to a list and never
            mutated, so the shared default list is safe.
        content_weight, style_weight, total_loss_variation_weight: loss weights.
        learning_rate: Adam learning rate.

    Returns:
        ``(gif, final_image)``. ``gif`` is a float32 array of shape
        ``(epochs, H, W, C)``, matching the loaded content image, with one
        snapshot per completed epoch; rows for epochs that never ran stay
        all-zero. ``final_image`` is the most recent snapshot, shape
        ``(H, W, C)``, with values in ``[0, 1]``.

    Raises:
        ValueError: if ``steps_per_epoch < 1``, because no loss would be
            available to report.

    Interrupting the kernel (``KeyboardInterrupt``) stops the run early and
    still returns the frames produced so far. Any other exception propagates.
    """
    if steps_per_epoch < 1:
        raise ValueError('steps_per_epoch must be >= 1, got {}'.format(steps_per_epoch))

    content_image = get_image(content_path)
    style_image = get_image(style_path)

    extractor = prep_VGG19()

    if start_from_content:
        grad_image = tf.Variable(content_image)
    else:
        # Noise with the same batch/spatial/channel shape as the content image.
        grad_image = tf.Variable(np.random.rand(*content_image.shape).astype(np.float32))

    content_features = extractor(preprocess_image(content_image))
    style_features = extractor(preprocess_image(style_image))

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.99, epsilon=1e-1)

    # train_step requires an actual list. Copying it once also guarantees the
    # shared default list is never touched.
    style_layers = list(style_layers)

    # One snapshot per epoch. float32 matches grad_image and halves the memory
    # of NumPy's float64 default (~120 MB instead of ~240 MB for 200 frames).
    gif = np.zeros((epochs,) + content_image.shape[1:], dtype=np.float32)

    completed = 0  # epochs whose snapshot has been stored in `gif`
    try:
        for epoch in range(epochs):
            for _ in range(steps_per_epoch):
                loss = train_step(grad_image,
                                  extractor,
                                  optimizer,
                                  content_features,
                                  style_features,
                                  content_layer,
                                  style_layers,
                                  content_weight,
                                  style_weight,
                                  total_loss_variation_weight)

            print('Epoch: {}, Loss: {}'.format(str(epoch), str(loss.numpy()[0])))

            gif[epoch] = np.squeeze(grad_image.numpy())
            completed = epoch + 1
    except KeyboardInterrupt:
        # Interrupting the kernel is the intended way to stop a long run early.
        # Only this exception is caught: the former `return` inside `finally`
        # discarded every error silently.
        print('Interrupted after {} completed epoch(s).'.format(completed))

    if completed:
        final_image = gif[completed - 1]
    else:
        # No snapshot exists (epochs == 0, or interrupted during the first
        # epoch). Fall back to the current state, clipped in case the interrupt
        # landed between the optimizer update and train_step's clip.
        final_image = np.clip(np.squeeze(grad_image.numpy()), 0., 1.)

    return gif, final_image

In [71]:
# Optimise from random noise (start_from_content=False) for 200 epochs of
# 5 steps each, with a style weight ten times NST's default of 1e-2.
giffy, final_image = NST(
    content_path='./content_images/stonearch_bridge.jpg',
    style_path='./style_images/starry_night.jpg',
    epochs=200,
    steps_per_epoch=5,
    start_from_content=False,
    style_weight=1e-1,
)

Epoch: 0, Loss: 3337373400000.0
Epoch: 1, Loss: 2251820400000.0
Epoch: 2, Loss: 1676206700000.0
Epoch: 3, Loss: 1348027700000.0
Epoch: 4, Loss: 1152686400000.0
Epoch: 5, Loss: 1016218700000.0
Epoch: 6, Loss: 913706000000.0
Epoch: 7, Loss: 840514000000.0
Epoch: 8, Loss: 779989600000.0
Epoch: 9, Loss: 731284800000.0
Epoch: 10, Loss: 687869400000.0
Epoch: 11, Loss: 652885160000.0
Epoch: 12, Loss: 621538250000.0
Epoch: 13, Loss: 592897200000.0
Epoch: 14, Loss: 566322900000.0
Epoch: 15, Loss: 539295000000.0
Epoch: 16, Loss: 515028680000.0
Epoch: 17, Loss: 494045430000.0
Epoch: 18, Loss: 474870900000.0
Epoch: 19, Loss: 455578200000.0
Epoch: 20, Loss: 438300050000.0
Epoch: 21, Loss: 421900880000.0
Epoch: 22, Loss: 406771100000.0
Epoch: 23, Loss: 392561070000.0
Epoch: 24, Loss: 379049150000.0
Epoch: 25, Loss: 366642500000.0
Epoch: 26, Loss: 354526660000.0
Epoch: 27, Loss: 343143200000.0
Epoch: 28, Loss: 332646220000.0
Epoch: 29, Loss: 322161300000.0
Epoch: 30, Loss: 312532730000.0
Epoch: 31, L

In [72]:
# Convert the [0, 1] float result to an 8-bit RGB image and save it.
# A new name keeps this cell idempotent: the old version rebound `final_image`
# to a PIL Image, so re-running the cell raised on `final_image * 255`.
# Clipping first guarantees no out-of-range value wraps around in the uint8 cast.
# Values already in [0, 1] convert exactly as before.
final_pil = img.fromarray((np.clip(final_image, 0., 1.) * 255).astype('uint8'), 'RGB')
final_pil.save('./output_image.jpg')

In [73]:
# Turn each recorded snapshot into an 8-bit RGB PIL image. Frames whose pixel
# sum is not positive are skipped: NST leaves all-zero rows in `giffy` for
# epochs that never ran.
images = list(map(
    lambda frame: img.fromarray((frame * 255).astype('uint8'), 'RGB'),
    filter(lambda frame: np.sum(frame) > 0., giffy),
))

In [74]:
# Write the optimisation as an animated GIF: one frame per recorded epoch,
# then the last frame repeated for a quarter of the run's length so the result
# lingers before looping. duration is milliseconds per frame; loop=0 loops forever.
if not images:
    raise ValueError('No non-empty frames in `giffy`; nothing to write to ./evolution.gif')
images[0].save(
    './evolution.gif',
    save_all=True,
    append_images=images[1:] + [images[-1]] * (len(images) // 4),
    optimize=False,
    duration=40,
    loop=0,
)