<a href="https://colab.research.google.com/github/Jiaweihu08/Real-Time-NST/blob/master/NST_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install --upgrade tensorflow==2.0.0.-rc1

In [0]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras.preprocessing import image as kp_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
from IPython import display

# Confirm which TensorFlow version the kernel actually loaded.
print(tf.__version__)

2.0.0-rc1


In [0]:

def load_image(img_path, size=(288, 288)):
    """Load an image file as a float32 batch containing one RGB image.

    Args:
        img_path: Path to an image file readable by PIL.
        size: (width, height) to resize to. The default matches the
            288x288 input the transformation network is built for.

    Returns:
        A float32 numpy array of shape (1, height, width, 3) with pixel
        values in [0, 255].
    """
    # The context manager closes the file handle; Image.open is lazy and
    # would otherwise keep the file open until garbage collection.
    with Image.open(img_path) as img:
        # Force 3 channels: grayscale, palette or RGBA files would otherwise
        # yield 1- or 4-channel arrays that the models cannot consume.
        img = img.convert('RGB')
        # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
        img = img.resize(size, Image.LANCZOS)
        array = np.asarray(img, dtype=np.float32)
    return np.expand_dims(array, axis=0)

def tensor_to_image(tensor):
    """Convert an image tensor/array to a PIL image.

    Args:
        tensor: Array-like of shape (H, W, C) or (1, H, W, C). Values are
            assumed to be on the 0-255 scale (the transformation network
            outputs sigmoid(...) * 255) and are clipped to that range.

    Returns:
        A PIL image built from uint8 pixel data.
    """
    array = np.array(tensor)
    if array.ndim > 3:
        # Only a batch of exactly one image maps to a single picture.
        assert array.shape[0] == 1, 'expected a batch of one image'
        array = array[0]
    # PIL cannot build an RGB image from float data; clip then cast so that
    # out-of-range values saturate instead of wrapping around.
    array = np.clip(array, 0, 255).astype(np.uint8)
    return Image.fromarray(array)

def imshow(img, title=None):
    """Display an image with matplotlib.

    matplotlib.pyplot.imshow expects RGB data either as floats in [0, 1] or
    as integers in [0, 255]. Images in this notebook are on the 0-255 scale
    but stored as floats, so they are clipped and cast to uint8 first.

    Args:
        img: Array-like of shape (1, H, W, 3) or (H, W, 3) on the 0-255
            scale (e.g. a numpy array or an eager tensor).
        title: Optional title for the axes.
    """
    out = np.asarray(img)
    # Drop the leading batch dimension when present (must be of size 1).
    if out.ndim == 4:
        out = np.squeeze(out, axis=0)
    # Clip before casting: a bare astype('uint8') wraps out-of-range values
    # (e.g. 256 -> 0, -1 -> 255) and produces speckle artifacts.
    out = np.clip(out, 0, 255).astype('uint8')
    plt.imshow(out)
    if title is not None:
        plt.title(title)

In [0]:
def get_model(style_layer_names=None, content_layer_names=None):
    """Build a frozen VGG19 feature extractor.

    Args:
        style_layer_names: VGG19 layer names whose activations feed the style
            loss. Defaults to the notebook-level ``style_layers``.
        content_layer_names: VGG19 layer names whose activations feed the
            content loss. Defaults to the notebook-level ``content_layers``.

    Returns:
        A keras.Model mapping a VGG19 input to a list of activations: the
        style layers first (in order), followed by the content layers.
    """
    # Fall back to the notebook globals at call time, so the existing
    # `get_model()` call keeps working once the config cell has run.
    if style_layer_names is None:
        style_layer_names = style_layers
    if content_layer_names is None:
        content_layer_names = content_layers

    vgg = keras.applications.vgg19.VGG19(weights='imagenet', include_top=False)
    # Only used to compute losses; its weights must never be updated.
    vgg.trainable = False

    model_output = [vgg.get_layer(name).output
                    for name in list(style_layer_names) + list(content_layer_names)]
    return keras.Model(vgg.input, model_output)

In [0]:
def gram_matrix(style_outputs):
    """Channel-by-channel Gram matrix of a feature map, averaged over positions.

    Every axis except the last (batch included) is flattened into one
    "position" axis. The result is therefore a single (C, C) matrix. For a
    batch larger than one, the images are pooled together rather than
    producing one Gram matrix per image.

    Args:
        style_outputs: Feature tensor whose last axis is the channel axis.

    Returns:
        A (C, C) tensor equal to F^T F / N, where F is the (N, C) matrix of
        flattened features.
    """
    num_channels = int(style_outputs.shape[-1])
    features = tf.reshape(style_outputs, [-1, num_channels])
    num_positions = tf.cast(tf.shape(features)[0], tf.float32)
    return tf.matmul(features, features, transpose_a=True) / num_positions

In [0]:
class StyleContentModel(keras.models.Model):
    """Map an image to VGG19 style (Gram matrix) and content features.

    NOTE(review): the underlying extractor comes from ``get_model()``, which
    reads the notebook-level ``style_layers`` / ``content_layers`` lists. The
    constructor arguments are only used to split and label its outputs, so
    they must match those lists (the notebook passes exactly those).
    """

    def __init__(self, style_layers, content_layers):
        super().__init__()
        self.vgg = get_model()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        """Return {'style': {layer: gram}, 'content': {layer: activation}}.

        ``inputs`` should hold RGB pixel values on the 0-255 scale, which is
        what ``keras.applications.vgg19.preprocess_input`` expects.
        """
        preprocessed_inputs = keras.applications.vgg19.preprocess_input(inputs)
        outputs = self.vgg(preprocessed_inputs)
        # get_model() orders outputs as [style layers..., content layers...].
        # Split with this instance's own count/names rather than notebook
        # globals, so the model is self-contained.
        style_outputs = outputs[:self.num_style_layers]
        content_outputs = outputs[self.num_style_layers:]

        style_dict = {style_name: gram_matrix(value)
                      for style_name, value in zip(self.style_layers, style_outputs)}
        content_dict = {content_name: value
                        for content_name, value in zip(self.content_layers, content_outputs)}

        return {'style': style_dict, 'content': content_dict}

In [0]:
def style_content_loss(outputs):
    """Weighted sum of the style and content losses.

    Args:
        outputs: Dict returned by StyleContentModel, of the form
            {'style': {layer: gram}, 'content': {layer: activation}}.

    Returns:
        Scalar tensor
        style_weight / num_style_layers * sum(per-layer style MSE)
        + content_weight / num_content_layers * sum(per-layer content MSE),
        measured against the notebook-level ``style_target`` and
        ``content_target``.
    """
    style_outputs = outputs['style']
    content_outputs = outputs['content']

    # tf.square takes a single tensor (its second positional argument is the
    # op `name`), so the difference has to be formed explicitly.
    style_loss = tf.add_n([tf.reduce_mean(tf.square(style_outputs[name] - style_target[name]))
                           for name in style_outputs.keys()])
    style_loss *= style_weight / num_style_layers

    content_loss = tf.add_n([tf.reduce_mean(tf.square(content_outputs[name] - content_target[name]))
                             for name in content_outputs.keys()])
    # The content term is scaled by its own weight.
    content_loss *= content_weight / num_content_layers

    return style_loss + content_loss

In [0]:
class MyInstanceNorm(keras.layers.Layer):
    """Instance normalization with a learned per-channel scale and shift.

    Each example is normalized over axes 1 and 2 independently, then
    ``scale * x_hat + shift`` is applied. Assumes channels-last input of
    shape (batch, height, width, channels).

    Args:
        epsilon: Small constant added to the variance before the square
            root, for numerical stability.
    """

    def __init__(self, epsilon=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, batch_input_shape):
        self.scale = self.add_weight(name='scale', shape=[batch_input_shape[-1]],
                                     initializer='ones', dtype=tf.float32)
        self.shift = self.add_weight(name='shift', shape=[batch_input_shape[-1]],
                                     initializer='zeros', dtype=tf.float32)
        super().build(batch_input_shape)

    def call(self, X, training=True):
        # Instance norm uses per-example statistics only (no running averages),
        # so training and inference behave identically; `training` is kept for
        # API compatibility but does not change the computation.
        mean, variance = tf.nn.moments(X, axes=[1, 2], keepdims=True)
        # Epsilon goes inside the root: sqrt(variance) has an infinite gradient
        # when a feature map is constant (variance == 0), which yields NaNs.
        X_hat = (X - mean) * tf.math.rsqrt(variance + self.epsilon)
        # Scale/shift the *normalized* tensor.
        return self.scale * X_hat + self.shift

    def get_config(self):
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config

def conv_layer(net, filters, kernel_size, strides,
               padding='SAME', relu=True,
               transpose=False, input_shape=None):
    """Convolution (or transposed convolution) -> MyInstanceNorm -> optional ReLU.

    Args:
        net: Input Keras tensor.
        filters, kernel_size, strides, padding: Passed to the conv layer.
        relu: Apply a ReLU after normalization when True.
        transpose: Use Conv2DTranspose (upsampling) instead of Conv2D.
        input_shape: Forwarded to Conv2D only when truthy; ignored for the
            transposed variant.

    Returns:
        The output Keras tensor.
    """
    # NOTE(review): TruncatedNormal(mean=0, stddev=1, seed=1) is a much wider
    # init than is typical for conv kernels; confirm it is intended.
    layer_kwargs = dict(filters=filters, kernel_size=kernel_size,
                        strides=strides, padding=padding,
                        kernel_initializer=keras.initializers.TruncatedNormal(0, 1, 1))
    if transpose:
        layer_cls = keras.layers.Conv2DTranspose
    else:
        layer_cls = keras.layers.Conv2D
        # Keras rejects input_shape=None, so only pass it when it is set.
        if input_shape:
            layer_kwargs['input_shape'] = input_shape

    x = layer_cls(**layer_kwargs)(net)
    x = MyInstanceNorm()(x)
    return keras.activations.relu(x) if relu else x

def residual_block(net):
    """Two 3x3, 128-filter, stride-1 conv layers plus an identity skip.

    The second conv has no ReLU. ``net`` must have 128 channels so that the
    skip addition lines up with the conv output.
    """
    shortcut = net
    hidden = conv_layer(net, 128, 3, 1)
    hidden = conv_layer(hidden, 128, 3, 1, relu=False)
    return shortcut + hidden

def NST_model(init_image, num_residual_blocks=5):
    """Build the image transformation network on a Keras input tensor.

    Layout:
    - three convs (9x9/32 stride 1, 3x3/64 stride 2, 3x3/128 stride 2);
    - a chain of 128-filter residual blocks;
    - two stride-2 transposed convs;
    - a final 9x9 conv to 3 channels, squashed with a sigmoid scaled by 255.

    Args:
        init_image: Keras input tensor of shape (batch, H, W, 3).
        num_residual_blocks: Number of residual blocks in the bottleneck.

    Returns:
        Output Keras tensor with values in (0, 255).
    """
    conv1 = conv_layer(init_image, 32, 9, 1)
    conv2 = conv_layer(conv1, 64, 3, 2)
    conv3 = conv_layer(conv2, 128, 3, 2)

    # Chain all residual blocks. The decoder must consume the LAST block;
    # otherwise the later blocks are built but disconnected from the output.
    net = conv3
    for _ in range(num_residual_blocks):
        net = residual_block(net)

    conv_t1 = conv_layer(net, 64, 3, 2, transpose=True)
    conv_t2 = conv_layer(conv_t1, 32, 3, 2, transpose=True)
    conv_t3 = conv_layer(conv_t2, 3, 9, 1, relu=False)
    # Map to the 0-255 pixel scale used throughout the notebook.
    return keras.activations.sigmoid(conv_t3) * 255


In [0]:
def train_step(image, itn_model, optimizer=None, feature_extractor=None):
    """Run one optimization step of the image transformation network.

    Args:
        image: Batch of input images fed to ``itn_model``.
        itn_model: The Keras model being trained.
        optimizer: Optimizer used to apply gradients. Defaults to the
            notebook-level ``opt``.
        feature_extractor: Model producing style/content features for the
            loss. Defaults to the notebook-level ``extractor``.

    Returns:
        The scalar loss tensor for this step (computed before the update).
    """
    if optimizer is None:
        optimizer = opt
    if feature_extractor is None:
        feature_extractor = extractor

    with tf.GradientTape() as tape:
        output_image = itn_model(image, training=True)
        outputs = feature_extractor(output_image)
        loss = style_content_loss(outputs)
    grads = tape.gradient(loss, itn_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, itn_model.trainable_variables))
    return loss

In [0]:
# Symbolic 288x288 RGB input; matches the size load_image resizes to.
init_image = keras.layers.Input(shape=(288, 288, 3))
output_image = NST_model(init_image)

# Wrap the functional graph into a trainable Keras model.
ITN_model = keras.Model(inputs=init_image, outputs=output_image)

In [0]:
import os

img_dir = '/images/'
if not os.path.exists(img_dir):
    os.mkdir(img_dir)

!wget --quiet -P /images/ https://upload.wikimedia.org/wikipedia/commons/d/d7/Green_Sea_Turtle_grazing_seagrass.jpg
!wget --quiet -P /images/ https://upload.wikimedia.org/wikipedia/commons/0/0a/The_Great_Wave_off_Kanagawa.jpg

content_path = '/images/Green_Sea_Turtle_grazing_seagrass.jpg'
style_path = '/images/The_Great_Wave_off_Kanagawa.jpg'

In [0]:
# VGG19 layer whose activations define the content loss.
content_layers = ['block5_conv2']

# VGG19 layers whose Gram matrices define the style loss:
# the first conv layer of each of the five blocks.
style_layers = ['block{}_conv1'.format(block) for block in range(1, 6)]

num_content_layers = len(content_layers)
num_style_layers = len(style_layers)

In [0]:
style_weight = 1e-2
content_weight = 1e4
# NOTE(review): not referenced by style_content_loss as written; a total
# variation term would have to be added to the loss for this to matter.
total_variation_weight = 30
iterations = 1000

style_image = load_image(style_path)
content_image = load_image(content_path)
# The network is trained to stylize this single content image.
init_image = content_image

extractor = StyleContentModel(style_layers, content_layers)
# Fixed targets: style features of the style image, content features of the content image.
style_target = extractor(style_image)['style']
content_target = extractor(content_image)['content']

opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)

# Render progress every `display_interval` steps rather than every step,
# which would flood the output with 1000 prints and figures.
display_interval = 10

start_time = time.time()

for iteration in range(1, iterations + 1):
    train_step(init_image, ITN_model)

    if iteration % display_interval == 0 or iteration == iterations:
        display.clear_output(wait=True)
        elapsed = time.time() - start_time
        print("\nIteration: {}/{} - Time: {:.4f}s".format(iteration,
                                                          iterations,
                                                          elapsed))
        # The module-level `output_image` is the symbolic Keras output from the
        # model-building cell and holds no pixels; run the network on the real
        # input to obtain an image to display.
        stylized_image = ITN_model(init_image, training=False)
        imshow(np.asarray(stylized_image))
        plt.show()
