In [None]:
import os

import cv2
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

import vgg

### Input image dimensions

In [None]:
# Spatial size and channel count of the network input; the real image is
# resized to this size before being fed to VGG-19.
height, width, channels = 224, 224, 3

### Set up input node and feature extractor

In [None]:
# Variable holding the image that is fed to (and later optimised through) VGG-19.
# Its random starting contents are only a placeholder: before training, the
# variable is overwritten via tf.assign (real image first, then white noise).
input_var_initial_value = np.random.rand(1, height, width, channels)
input_var = tf.Variable(
    initial_value=input_var_initial_value,
    dtype=tf.float32,
    name='input_var',
)

# Build the VGG-19 inference graph on top of input_var.
# NOTE(review): input_var is used as-is, with no per-channel mean subtraction
# applied here. Confirm this matches how the checkpoint was trained.
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_19(
        input_var,
        num_classes=1000,
        is_training=False,
    )

### Set up restoring of the pretrained feature-extractor weights

In [None]:
# Prepare to restore the vgg19 nodes
# Skip trying to restore the input variable since it's new
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
saver = tf.train.Saver(var_list=all_variables[1:])

### Load the real image and create the image to be optimised

In [None]:
# Photograph whose layer response the optimisation tries to reproduce
REAL_IMAGE_PATH = './coastal_scene.jpg'

# Construct the real image tensor
# and the graph operation which assigns it to input_var.
real_image = cv2.imread(REAL_IMAGE_PATH)
if real_image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as a confusing cv2.resize error.
    raise IOError('Could not read image {}'.format(REAL_IMAGE_PATH))
# cv2.resize takes the target size as (width, height), not (height, width)
real_image = cv2.resize(real_image, (width, height))
# OpenCV loads images in BGR order; convert to RGB
real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
real_image_batch = np.expand_dims(real_image, axis=0)
real_image_batch = np.asarray(real_image_batch, dtype=np.float32)
assert real_image_batch.shape == (1, height, width, channels), real_image_batch.shape
real_image_tensor = tf.Variable(real_image_batch, dtype=tf.float32, name='real_image')

assign_real_image = tf.assign(input_var, real_image_tensor, name='assign_real_image')

# Construct the white noise tensor (uniform in [0, 255), the same value range
# as the real image) and the graph operation which assigns it to input_var.
white_noise = np.random.rand(height, width, channels) * 255.
white_noise_batch = np.expand_dims(white_noise, axis=0)
white_noise_batch = np.asarray(white_noise_batch, dtype=np.float32)
white_noise_tensor = tf.Variable(white_noise_batch, dtype=tf.float32, name='white_noise')

assign_white_noise = tf.assign(input_var, white_noise_tensor, name='assign_white_noise')

### Set up remaining graph nodes

In [None]:
# Name of the VGG-19 end point whose response is used to reconstruct the
# original image. Change it here to reconstruct from a different layer.
CONTENT_LAYER = 'vgg_19/conv1/conv1_1'
# Adam step size used when optimising the input image
LEARNING_RATE = 1e-1

predictions = end_points[CONTENT_LAYER]
# Static shape of the chosen layer's output as plain Python ints
# (np.shape on a symbolic tensor only worked incidentally via Tensor.shape).
batch_size_res, height_res, width_res, channels_res = predictions.get_shape().as_list()

# This placeholder will hold the response from the layer we are interested in
# given the real image
desired_response = tf.placeholder(tf.float32,
                                  shape=[batch_size_res, height_res, width_res, channels_res],
                                  name='desired_response')

# Loss and optimizer. Only input_var is in var_list, so train_op updates the
# image and leaves the VGG weights untouched.
loss = tf.losses.mean_squared_error(labels=desired_response, predictions=predictions)
optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
train_op = optimizer.minimize(loss, var_list=[input_var])

# Tensorboard summaries
loss_summary = tf.summary.scalar('loss', loss)
# tf.summary.image re-normalises float images one image at a time. Clip, round
# and cast to uint8 (as convert_to_proper does) so TensorBoard shows the same
# pixels that are saved to disk.
image_for_summary = tf.cast(tf.round(tf.clip_by_value(input_var, 0., 255.)), tf.uint8)
image_summary = tf.summary.image('image', image_for_summary)
merged_summary = tf.summary.merge_all()

# Initializers
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer(),
                   name='initialize_all')

### Helper function for displaying images 

In [None]:
def convert_to_proper(initial_image):
    """
    Turn an image fetched from the graph into an array plt.imshow can render.

    A 4-D input has its leading (batch) axis removed. The pixel values are then
    limited to [0, 255], rounded to the nearest integer (ties go to the even
    neighbour, as with np.rint), and stored as uint8.

    Arguments:
        initial_image: array-like image, either 3-D or 4-D; a 4-D input is
            assumed to carry a batch of exactly one.
    Returns:
        converted_image: uint8 numpy array suitable for plt.imshow
    """
    image = initial_image
    if np.ndim(image) == 4:
        # np.squeeze with an explicit axis raises if that axis is not of size 1
        image = np.squeeze(image, axis=0)

    return np.asarray(np.rint(np.clip(image, 0, 255)), dtype=np.uint8)

### Training

In [None]:
# Training configuration and output locations
NUM_ITERATIONS = 5000
LOG_EVERY = 100
VGG_CHECKPOINT = './vgg_19.ckpt'
SUMMARY_DIR = './train'
DEBUG_DIR = './debug'


def save_snapshot(image, title, plt_path, raw_path, close_figure=True):
    """
    Save the current reconstruction to disk, once as a titled matplotlib
    figure and once as raw pixels written by OpenCV.

    Both outputs come from the same clipped/rounded uint8 image produced by
    convert_to_proper.

    Arguments:
        image: value of input_var, either (1, H, W, C) or (H, W, C). It is
            presumably in RGB order, since the real image was converted
            BGR->RGB before being assigned.
        title: title drawn on the matplotlib figure.
        plt_path: output path of the matplotlib figure.
        raw_path: output path of the raw image written with cv2.imwrite.
        close_figure: close the figure after saving, so a long run does not
            accumulate dozens of open figures.
    Raises:
        IOError: if cv2.imwrite reports that the file could not be written.
    """
    image_int = convert_to_proper(image)

    fig = plt.figure()
    plt.imshow(image_int)
    plt.title(title)
    fig.savefig(plt_path)
    if close_figure:
        plt.close(fig)

    # cv2 writes in BGR order while the graph works in RGB; imwrite reports
    # failure through its return value instead of raising.
    if not cv2.imwrite(raw_path, cv2.cvtColor(image_int, cv2.COLOR_RGB2BGR)):
        raise IOError('Could not write {}'.format(raw_path))


# savefig / imwrite do not create missing directories
if not os.path.isdir(DEBUG_DIR):
    os.makedirs(DEBUG_DIR)

with tf.Session() as sess:

    # Initialize all variables and then
    # restore weights for feature extractor
    sess.run(init_op)
    saver.restore(sess, VGG_CHECKPOINT)

    # Set up summary writer for tensorboard; closed in `finally` so buffered
    # events are flushed even if training is interrupted.
    train_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
    try:
        # Using the real image, get the response of the chosen layer
        sess.run(assign_real_image)
        desired_response_ = sess.run(predictions)

        # Assign the white noise to the image
        sess.run(assign_white_noise)

        # Begin training
        for i in range(NUM_ITERATIONS):
            summary, loss_, _ = sess.run([merged_summary, loss, train_op],
                                         feed_dict={desired_response: desired_response_})
            train_writer.add_summary(summary, i)

            if i % LOG_EVERY == 0:
                print('loss: {}'.format(loss_))
                # Read input_var in a separate run so the snapshot shows the
                # image after this step's update.
                save_snapshot(sess.run(input_var), str(i),
                              os.path.join(DEBUG_DIR, 'img_plt_{}.png'.format(i)),
                              os.path.join(DEBUG_DIR, 'img_{}.png'.format(i)))

        # Keep the final figure open so it is displayed inline
        save_snapshot(sess.run(input_var), 'final',
                      os.path.join(DEBUG_DIR, 'final_plt.png'),
                      os.path.join(DEBUG_DIR, 'final.png'),
                      close_figure=False)
    finally:
        train_writer.close()