In [0]:
# Getting Dataset
import os
from urllib.request import urlretrieve
data_set_url = 'http://images.cocodataset.org/zips/train2014.zip'
data_set_zip = '/content/data_set.zip'
# Skip the download when the archive is already on disk, so re-running this
# cell does not fetch the whole train2014 archive again.
if not os.path.exists(data_set_zip):
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file under the final name.
    partial_path = data_set_zip + '.part'
    urlretrieve(data_set_url, partial_path)
    os.replace(partial_path, data_set_zip)

In [0]:
# Extract dataset folder
import zipfile
# The context manager closes the archive's file handle even if extraction fails.
with zipfile.ZipFile('/content/data_set.zip','r') as zip_ref:
    zip_ref.extractall('/content/training_data')

In [0]:
# Mount Google Drive at /content/my_drive. Later cells read the VGG16 weights,
# the style image and checkpoints from paths under this mount point, and write
# new checkpoints back to it.
from google.colab import drive
drive.mount('/content/my_drive')

In [0]:
# Import libraries
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import layers
from keras import backend as K
from keras.models import Model
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.engine.topology import Layer
from keras.regularizers import Regularizer
from keras.layers.normalization import BatchNormalization
from keras.preprocessing.image import array_to_img, load_img, img_to_array
from matplotlib import pyplot as plt
from IPython.display import clear_output

Using TensorFlow backend.


In [0]:
# Spatial size used both for the transformation network's input and for
# resizing every training / style image before it is fed to the model.
img_width = 512
img_height = 512

In [0]:
def preprocess_image(image_path,rows,cols):
    """Load an image file as a one-item batch array.

    Args:
        image_path: path of the image file to load.
        rows, cols: passed to ``load_img`` as ``target_size=(rows, cols)``,
            i.e. the size the image is resized to on load.

    Returns:
        The array produced by ``img_to_array`` with an extra leading axis of
        length 1 (the batch axis).
    """
    pixels = img_to_array(load_img(image_path, target_size=(rows, cols)))
    # Prepend the batch axis so the result can be fed straight to a model.
    return np.expand_dims(pixels, axis=0)

In [0]:
# Layer used to rescale the output values from [-1,1] to [0,255]
class Denormalize(Layer):
    """Weight-free layer mapping values in [-1, 1] onto the pixel range [0, 255].

    It is applied to the tanh output of the transformation network and
    computes ``(x + 1) * 127.5`` element-wise.
    """
    def __init__(self, **kwargs):
        super(Denormalize, self).__init__(**kwargs)

    def build(self, input_shape):
        # Nothing to build: the layer has no weights.
        pass

    def call(self, x, mask=None):
        # The scale is 127.5 (it used to be 127.) so that +1 maps to exactly
        # 255, matching the range documented above; with 127 the maximum
        # reachable value was 254. Outputs of checkpoints trained with the old
        # scale become brighter by a factor of 127.5/127 (~0.4%).
        return (x + 1) * 127.5

#---------------------------------------------------------------------------------------------
# Layer used to normalize input image pixel values by dividing them by 255.0
class InputNormalize(Layer):
    """Weight-free layer that divides its input element-wise by 255."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Nothing to create: the layer holds no weights."""

    def call(self, x, mask=None):
        scaled = x / 255.
        return scaled

#---------------------------------------------------------------------------------------------      
# Layer used to normalize the VGG16 model input
class VGGNormalize(Layer):
    """Weight-free layer that applies the VGG16 ``preprocess_input`` to its input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Nothing to create: the layer holds no weights."""

    def call(self, x, mask=None):
        # Delegate to keras.applications.vgg16.preprocess_input.
        return preprocess_input(x)

In [0]:
# Transformation Network
def _conv_bn_act(x, conv_layer, filters, kernel, stride, activation='relu'):
  """Apply ``conv_layer`` -> BatchNormalization -> Activation to ``x`` ('same' padding)."""
  x = conv_layer(filters, kernel_size = (kernel,kernel), strides = (stride,stride), padding = 'same')(x)
  x = layers.BatchNormalization()(x)
  return layers.Activation(activation)(x)


def transform_net(img_width,img_height):
  """Build the image transformation network.

  Pipeline: Input -> InputNormalize -> three Conv2D stages (32, 64, 128
  filters; the last two with stride 2) -> five residual blocks -> three
  Conv2DTranspose stages (64, 32, 3 filters; the first two with stride 2,
  the last one followed by tanh instead of relu) -> Denormalize.

  Args:
    img_width, img_height: first and second spatial dimensions of the
      3-channel input tensor.

  Returns:
    A Keras ``Model`` from the raw input tensor to the Denormalize output.
  """
  image_in = layers.Input(shape=(img_width,img_height,3))
  x = InputNormalize()(image_in)

  # Convolutional stages: (filters, kernel size, stride).
  for filters, kernel, stride in ((32, 9, 1), (64, 3, 2), (128, 3, 2)):
    x = _conv_bn_act(x, layers.Conv2D, filters, kernel, stride)

  for _ in range(5):
    x = residual_block(x)

  # Transposed-convolution stages: (filters, kernel size, stride, activation).
  for filters, kernel, stride, activation in ((64, 3, 2, 'relu'),
                                              (32, 3, 2, 'relu'),
                                              (3, 9, 1, 'tanh')):
    x = _conv_bn_act(x, layers.Conv2DTranspose, filters, kernel, stride, activation)

  image_out = Denormalize()(x)
  return Model(inputs = image_in,outputs = image_out)

def residual_block(x):
  """Residual unit with 128 filters.

  Applies Conv2D(3x3) -> BatchNormalization -> relu -> Conv2D(3x3) ->
  BatchNormalization to ``x`` and adds the untouched input back to the result.

  Args:
    x: input tensor of the block.

  Returns:
    The element-wise sum of the transformed tensor and ``x``.
  """
  shortcut = x
  out = layers.Conv2D(128,kernel_size = (3,3),strides = (1,1),padding = 'same')(shortcut)
  out = layers.BatchNormalization()(out)
  out = layers.Activation('relu')(out)

  out = layers.Conv2D(128,kernel_size = (3,3),strides = (1,1),padding = 'same')(out)
  out = layers.BatchNormalization()(out)
  # No activation after the second stage; the input is added back as-is.
  return layers.merge.add([out, shortcut])

In [0]:
def dummy_loss(y_true, y_pred ):
    """Placeholder loss for ``compile``: a new backend variable holding 0.0.

    Both arguments are ignored; the losses that actually drive training are
    the ones attached to layers with ``add_loss``.
    """
    zero = K.variable(0.0)
    return zero

def gram_matrix(x):
    """Normalised Gram matrix (channel-by-channel correlations) of a feature map.

    Args:
        x: 3-D tensor with the channel axis last, i.e. indexed
           (axis 0, axis 1, channel), as produced by taking one batch item of
           a channels-last layer output.

    Returns:
        A (C, C) tensor ``F . F^T / (d0 * d1 * C)`` where ``F`` is ``x`` with
        channels moved first and both spatial axes flattened, and
        ``d0, d1, C`` are the three dimensions of ``x``.
    """
    # Channels first, then flatten the two spatial axes: (d0, d1, C) -> (C, d0*d1).
    features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))

    shape = K.shape(x)
    # The channel count is the LAST axis of x. The previous code read it from
    # shape[0] and reshaped `features` to (shape[0], -1), which scrambled the
    # per-channel rows and produced a (d0, d0) matrix instead of (C, C).
    dim0, dim1, channels = (shape[0], shape[1], shape[2])

    gram = K.dot(features, K.transpose(features)) / K.cast(dim0 * dim1 * channels, dtype='float32')
    return gram

#---------------------------------------------------------------------------------------------              
class StyleReconstructionRegularizer(Regularizer):
    """Style loss term: weighted squared distance between two Gram matrices.

    The Gram matrix of ``style_feature_target`` is computed once in the
    constructor. Calling the instance on a layer compares it with the Gram
    matrix of the first batch item of that layer's output.
    """

    def __init__(self, style_feature_target, weight=1.0):
        self.weight = weight
        self.style_feature_target = style_feature_target
        self.uses_learning_phase = False
        super().__init__()

        # Fixed target, computed once from the style image's features.
        self.style_gram = gram_matrix(style_feature_target)

    def __call__(self, x):
        # Batch item 0 of the layer output holds the features of the image
        # generated by the transformation network.
        generated_gram = gram_matrix(x.output[0])
        gram_diff = self.style_gram - generated_gram
        return self.weight * K.sum(K.mean(K.square(gram_diff)))

#---------------------------------------------------------------------------------------------
class FeatureReconstructionRegularizer(Regularizer):
    """Content loss term: weighted mean squared difference between the features
    of the first two batch items of a layer's output."""

    def __init__(self, weight=1.0):
        self.uses_learning_phase = False
        self.weight = weight
        super().__init__()

    def __call__(self, x):
        features = x.output
        # Item 0: image generated by the transformation network.
        # Item 1: the original input image.
        generated, content = features[0], features[1]
        squared_error = K.square(content - generated)
        return self.weight * K.sum(K.mean(squared_error))
      
#---------------------------------------------------------------------------------------------
class TVRegularizer(Regularizer):
    """Total variation loss term on a layer's 4-D output.

    For every position that has a successor along both axis 1 and axis 2, the
    squared differences to those two successors are added, raised to the
    power 1.25, summed over the whole tensor and scaled by ``weight``.
    """

    def __init__(self, weight=1.0):
        self.uses_learning_phase = False
        self.weight = weight
        super().__init__()

    def __call__(self, x):
        img = x.output

        # Drop the last index along axes 1 and 2 so that every remaining
        # position has a successor in both directions.
        core = img[:, :-1, :-1, :]
        next_axis1 = img[:, 1:, :-1, :]
        next_axis2 = img[:, :-1, 1:, :]

        sq_diff_axis1 = K.square(core - next_axis1)
        sq_diff_axis2 = K.square(core - next_axis2)
        return self.weight * K.sum(K.pow(sq_diff_axis1 + sq_diff_axis2, 1.25))

In [0]:
def add_style_loss(vgg,style_image_path,vgg_layers,vgg_output_dict,img_width, img_height,weight):
    """Attach a style reconstruction loss to five VGG convolution layers.

    The style image is loaded, run through the network once to obtain its
    activations at the style layers, and each activation becomes the fixed
    target of a ``StyleReconstructionRegularizer`` added to the matching layer.

    Args:
        vgg: model whose ``layers[-19].input`` is used as the feed point for
            the style image (presumably the VGG normalisation layer's input;
            verify against how the model is built).
        style_image_path: path of the style image file.
        vgg_layers: dict mapping layer name -> layer object.
        vgg_output_dict: dict mapping layer name -> layer output tensor.
        img_width, img_height: size the style image is resized to.
        weight: multiplier applied to each of the five style losses.
    """
    style_img = preprocess_image(style_image_path, img_width, img_height)
    print('Getting style features from VGG network.')
    style_layers = ('block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1')

    # Evaluate all style-layer outputs for the style image in one call.
    style_layer_outputs = [vgg_output_dict[name] for name in style_layers]
    vgg_style_func = K.function([vgg.layers[-19].input], style_layer_outputs)
    style_features = vgg_style_func([style_img])

    # One regularizer per style layer, each with its own fixed target.
    for layer_name, features in zip(style_layers, style_features):
        target_layer = vgg_layers[layer_name]
        # features[0]: the single batch item, without the batch axis.
        feature_var = K.variable(value=features[0])
        regularizer = StyleReconstructionRegularizer(
            style_feature_target=feature_var,
            weight=weight)
        target_layer.add_loss(regularizer(target_layer))
        
#---------------------------------------------------------------------------------------------        
def add_content_loss(vgg_layers,vgg_output_dict,weight):
    """Attach the feature reconstruction (content) loss to layer 'block4_conv2'.

    Args:
        vgg_layers: dict mapping layer name -> layer object.
        vgg_output_dict: not used by this function.
        weight: multiplier passed to ``FeatureReconstructionRegularizer``.
    """
    target_layer = vgg_layers['block4_conv2']
    content_loss = FeatureReconstructionRegularizer(weight)(target_layer)
    target_layer.add_loss(content_loss)
    
#---------------------------------------------------------------------------------------------        
def add_total_variation_loss(transform_output_layer,weight):
    """Attach a total variation loss to the given layer.

    Args:
        transform_output_layer: layer whose output is regularised (the caller
            passes the transformation network's last layer).
        weight: multiplier passed to ``TVRegularizer``.
    """
    tv_loss = TVRegularizer(weight)(transform_output_layer)
    transform_output_layer.add_loss(tv_loss)

In [0]:
# Declare and initialize the model for training phase
trans = transform_net(img_width,img_height)
# Stack the generated image and the original input along the batch axis so one
# pass through VGG16 yields features for both (item 0: generated, item 1: original).
tensor1 = layers.merge.concatenate([trans.output,trans.input],axis = 0)
tensor2 = VGGNormalize(name="vgg_normalize")(tensor1)
# weights = None: VGG16 is created without pretrained weights here; they are
# loaded from a file by the next cell.
vgg1 = VGG16(include_top = False,input_tensor = tensor2,weights = None)

In [0]:
# Loading the VGG16 weights
# by_name = True: weights are matched to layers by layer name, so only layers
# whose names appear in the weight file are loaded.
vgg1.load_weights(filepath ='/content/my_drive/My Drive/Project/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',by_name = True)

In [0]:
# Start the transformation network from a previously saved checkpoint.
# NOTE(review): the file name says "after_40000iteration", but the training loop
# starts at j = 0 - adjust j when resuming, or confirm that restarting is intended.
trans.load_weights('/content/my_drive/My Drive/Project/style_weights/picasso/epoch1/picasso_weights_after_40000iteration_1st_epoch.h5')

In [0]:
# Select the path of the style image
style_path = '/content/my_drive/My Drive/Project/style_weights/picasso/picasso.jpg'

# Name -> output tensor and name -> layer lookups for the last 18 layers of vgg1.
vgg_output_dict = {layer.name: layer.output for layer in vgg1.layers[-18:]}
vgg_layers = {layer.name: layer for layer in vgg1.layers[-18:]}

# Attach the three loss terms (style, content, total variation) to their layers.
add_style_loss(vgg1, style_path, vgg_layers, vgg_output_dict, img_width, img_height, 3.0)
add_content_loss(vgg_layers, vgg_output_dict, 2.0)
add_total_variation_loss(trans.layers[-1], 5e-5)

# Freeze all VGG16 layers
for layer in vgg1.layers[-19:]:
  layer.trainable = False


Getting style features from VGG network.


In [0]:
# Select the optimizer for the model
from keras.optimizers import Adam

# Adam with its default settings. The compiled loss is only a placeholder:
# the terms that drive training were attached to the layers with add_loss.
optimizer = Adam()
vgg1.compile(optimizer = optimizer, loss = dummy_loss)

In [0]:
# os.listdir returns entries in arbitrary order; sort them so that index i
# refers to the same image on every run (the training loop addresses images by
# index, which matters when resuming from a given iteration).
files = sorted(os.listdir('/content/training_data/train2014'))

In [0]:
# Training the model
# Dummy target: the compiled loss ignores it, it only has to be a one-item batch.
y = np.zeros((1,img_width,img_height,3),dtype='float32')
j = 0  # index of the first image to train on
history = []
# Take the image count from the listing instead of hard-coding it, so the loop
# neither overruns a smaller folder nor silently skips part of a larger one.
num_images = len(files)
checkpoint_dir = '/content/my_drive/My Drive/Project/style_weights/picasso/epoch1'
for i in range(j,num_images):
  v = preprocess_image('/content/training_data/train2014/{}'.format(files[i]),img_width,img_height)
  h = vgg1.train_on_batch(v,y)
  history.append(h)

  show_preview = i % 100 == 0
  save_checkpoint = (i % 10000 == 0 and i > 0) or i == num_images - 1

  if show_preview or save_checkpoint:
    # Render the preview from the CURRENT step. Previously the final checkpoint
    # (whose index is not a multiple of 100) saved images left over from the last
    # preview step, or failed with a NameError if no preview had run yet.
    res = trans.predict(v)
    resImg = array_to_img(res[0])
    inputImg = array_to_img(v[0])

  if show_preview:
    clear_output()
    print('loss: {} at iteration: {} '.format(h,i))
    plt.imshow(resImg)
    plt.figure()
    plt.imshow(inputImg)
    plt.show()

  if save_checkpoint:
    trans.save_weights(checkpoint_dir + '/picasso_weights_after_{}iteration_1st_epoch.h5'.format(i))
    inputImg.save(checkpoint_dir + '/inputImage_after_{}iteration_1st_epoch.jpg'.format(i))
    resImg.save(checkpoint_dir + '/outputImage_after_{}iteration_1st_epoch.jpg'.format(i))
    pd.DataFrame(history).to_csv(checkpoint_dir + '/picasso_history_of_losses_after_{}iteration_1st_epoch.csv'.format(i))
      