"""

This notebook runs through Section 2.1: Content Reconstructions.
The images generated are similar to those in Figure 1. 

The Directory Paths block should be edited to suit local directory structure. 
The Chosen Parameters block can be changed to try out different experimental settings. 

"""

In [None]:
import cv2
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import os
import shutil
import tensorflow as tf
import tensorflow.contrib.slim as slim

import vgg

### Directory paths

In [None]:
# Image whose content representation will be reconstructed
real_image_path = './coastal_scene.jpg'

# Pre-trained vgg19 checkpoint (not bundled, download it separately)
checkpoint_path = './vgg_19.ckpt'

# Directory that receives the tensorboard summaries
tensorboard_dir = './train/'

# Directory that receives the images saved while optimising
debug_dir = './debug/'

# When True, the tensorboard summaries left over from earlier runs are
# deleted before this run starts
reset_saves = True

# Make sure tensorboard is not running while its directory is being removed
if reset_saves is True and os.path.exists(tensorboard_dir):
    shutil.rmtree(tensorboard_dir)

# Always start from an empty debug directory: drop any existing one and
# recreate it. The tensorboard directory is made automatically if missing.
if os.path.exists(debug_dir):
    shutil.rmtree(debug_dir)
os.makedirs(debug_dir)

### Chosen parameters

In [None]:
# VGG19 end point whose activations are used as the features
feature_layer = 'vgg_19/conv2/conv2_2'

# Step size handed to the optimizer
learning_rate = 0.1

# Total number of optimisation steps, followed by the interval (in steps)
# between "validation" checks. Here a validation check means examining the
# currently optimised image, saving it, and reporting the loss.
training_steps = 100000
validation_steps = 1000

# debug_online : show the intermediate images inside the notebook using plt
# debug_offline: save the intermediate images to a folder using plt
debug_offline = True
debug_online = True

### Set up input node and feature extractor

In [None]:
# Input size (height, width, channels) required by vgg19
height, width, channels = 224, 224, 3

# Seed numpy and tensorflow so that runs are consistent with each other
np.random.seed(0)
tf.set_random_seed(0)

# The graph's input node: a batch holding one randomly initialised image.
# It is a variable so that the optimizer can later update it directly.
input_var_initial_value = np.random.rand(1, height, width, channels)
input_var = tf.Variable(initial_value=input_var_initial_value,
                        dtype=tf.float32,
                        name='input_var')

# Build vgg19 on top of the input node, inside the model's own arg scope.
# end_points maps layer names to the corresponding activations.
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_19(input_var, num_classes=1000, is_training=False)

### Set up restoring feature extractor weights

In [None]:
# Prepare to restore the vgg19 nodes.
# Only variables whose name starts with 'vgg_19/' are handed to the saver, so
# the new input variable (and any other variable that is not part of the
# feature extractor) is skipped. Selecting by name instead of by position
# keeps this cell correct when it is re-run after later cells have added
# more variables to the graph.
# NOTE(review): assumes the model's variables share the 'vgg_19/' prefix seen
# in its end point names -- confirm against the vgg module.
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
vgg_variables = [var for var in all_variables if var.name.startswith('vgg_19/')]
saver = tf.train.Saver(var_list=vgg_variables)

### Load real image and create to-be optimised image

In [None]:
# Construct the real image tensor
# And the graph operation which assigns it to input_var
real_image = plt.imread(real_image_path)
# cv2.resize takes the target size as (width, height), not (height, width).
# The two only coincide while the input is square.
real_image = cv2.resize(real_image, (width, height))
# Add the batch dimension and convert to float32 to match input_var
real_image_batch = np.expand_dims(real_image, axis=0)
real_image_batch = np.asarray(real_image_batch, dtype=np.float32)
real_image_tensor = tf.Variable(real_image_batch, dtype=tf.float32, name='real_image')

assign_real_image = tf.assign(input_var, real_image_tensor, name='assign_real_image')

# Construct the white noise tensor (uniform values scaled to [0, 255))
# And the graph operation which assigns it to input_var
white_noise = np.random.rand(height, width, channels) * 255.
white_noise_batch = np.expand_dims(white_noise, axis=0)
white_noise_batch = np.asarray(white_noise_batch, dtype=np.float32)
white_noise_tensor = tf.Variable(white_noise_batch, dtype=tf.float32, name='white_noise')

assign_white_noise = tf.assign(input_var, white_noise_tensor, name='assign_white_noise')

### Set up remaining graph nodes

In [None]:
# Choose which representation will be used to
# reconstruct the original image
predictions = end_points[feature_layer]
# Read the static shape from the tensor itself so the four sizes are plain
# Python ints rather than whatever numpy makes of a symbolic tensor
batch_size_res, height_res, width_res, channels_res = predictions.get_shape().as_list()

# This placeholder will hold the response from the layer we are interested in
# given the real image
desired_response = tf.placeholder(tf.float32,
                                  shape=[batch_size_res, height_res, width_res, channels_res],
                                  name='desired_response')

# Loss and optimizer
# The loss compares the chosen layer's response to the current input against
# the desired response; only input_var is updated, the network stays fixed
loss = tf.losses.mean_squared_error(labels=desired_response, predictions=predictions)
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss, var_list=[input_var])

# Tensorboard summaries
loss_summary = tf.summary.scalar('loss', loss)
image_summary = tf.summary.image('image', input_var)

# Initializers
# Grouped into one op so a single run call initialises every variable
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer(),
                   name='initialize_all')

### Helper function for formatting images 

In [None]:
def format_image(initial_image):
    """
    Turn an image taken from a tf node into something plt can display.

    A 4D input has its leading (batch) axis squeezed out. The values are
    then clipped to the range 0..255, rounded to the nearest integer and
    returned as an unsigned 8-bit array.

    Arguments:
        initial_image: The original image from the node
    Returns:
        The image, ready to be shown by plt
    """
    has_batch_axis = np.ndim(initial_image) == 4
    image = np.squeeze(initial_image, axis=0) if has_batch_axis else initial_image

    # Saturate to the displayable range first, then round and cast
    in_range = np.clip(image, 0, 255)
    return np.asarray(np.rint(in_range), dtype=np.uint8)

### Helper function for displaying images

In [None]:
def display_image(initial_image, fig_title):
    """
    Display an image within the notebook.

    The image is converted with format_image, drawn on a new figure without
    axes, and shown straight away. Showing it here, rather than leaving the
    figure open, lets intermediate images appear while a long-running cell
    is still executing and avoids accumulating one open figure per call.

    Arguments:
        initial_image: The original image from the node
        fig_title    : Title for this image 
    Returns:
        None
    """
    converted_image = format_image(initial_image)
    plt.figure()
    plt.imshow(converted_image)
    plt.axis('off')
    plt.title(fig_title)
    # NOTE(review): relies on the notebook's inline backend, where show()
    # renders and returns immediately; a GUI backend would block here.
    plt.show()

### Helper function for saving images

In [None]:
def save_image(initial_image, fig_title, output_dir=None):
    """
    Save an image to disk.

    The image is converted with format_image and written as
    'img_<fig_title>.png' inside the output directory.

    Arguments:
        initial_image: The original image from the node
        fig_title    : Title for this image, used in the file name
        output_dir   : Directory to write to; defaults to the notebook's
                       debug_dir so the configured path is honoured
    Returns:
        None
    """
    if output_dir is None:
        output_dir = debug_dir

    converted_image = format_image(initial_image)
    img_name = os.path.join(output_dir, 'img_{}.png'.format(fig_title))
    plt.imsave(img_name, converted_image)

### Training

In [None]:
with tf.Session() as sess:
    # Initialize all variables and then
    # restore weights for feature extractor
    sess.run(init_op)
    saver.restore(sess, checkpoint_path)

    # Set up summary writer for tensorboard
    train_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

    try:
        # Using the real image, get the response of the chosen layer
        assign_real_image.eval()
        desired_response_ = predictions.eval()

        # Assign the white noise to the image
        assign_white_noise.eval()

        # The target response is the same for every step
        response_feed = {desired_response: desired_response_}

        # Begin training
        for i in range(training_steps):
            summary, _ = sess.run([loss_summary, train_op],
                                  feed_dict=response_feed)
            train_writer.add_summary(summary, i)

            # Every validation_steps steps, record the current image and
            # report the loss
            if i % validation_steps == 0:
                summary, current_image, loss_ = sess.run([image_summary, input_var, loss],
                                                         feed_dict=response_feed)
                train_writer.add_summary(summary, i)

                print('Step: {}, Loss: {}'.format(i, loss_))

                if debug_online:
                    display_image(current_image, i)
                if debug_offline:
                    save_image(current_image, i)
    finally:
        # Close the writer even if training is interrupted, so the
        # summaries gathered so far are written out
        train_writer.close()

    # Display and save the final image
    current_image = input_var.eval()
    display_image(current_image, 'Final')
    save_image(current_image, 'Final')