In [None]:
# Plotting, numerics, image I/O and timing utilities used throughout the notebook.
import matplotlib.pyplot as plt
import matplotlib as mpl
# Default figure size for every plot, and no grid lines drawn over images.
mpl.rcParams['figure.figsize'] = (10,10)
mpl.rcParams['axes.grid'] = False

import numpy as np
from PIL import Image
import time
import functools

In [None]:
%tensorflow_version 1.x
import tensorflow as tf

from tensorflow.python.keras import models 
from tensorflow.python.keras import losses
from tensorflow.python.keras import layers
from tensorflow.python.keras import backend as K

In [None]:
# The rest of the notebook relies on eager tensors (.numpy(), GradientTape),
# so switch TF 1.x into eager mode before any model is built, then confirm it.
tf.enable_eager_execution()
print(f"Eager execution: {tf.executing_eagerly()}")

Please upload the 4 images from the `img` folder to the `sample_data` folder of the Colab runtime (`/content/sample_data/`); the paths below expect them there.

In [None]:
# Set up global values.
# All input images live in one upload folder; change sample_dir to point the
# whole notebook at a different location.
sample_dir = '/content/sample_data/'
content_path = sample_dir + 'Butler_Library_-_1000px_-_AC.jpg'
style_path = sample_dir + '1280px-A_Sunday_on_La_Grande_Jatte,_Georges_Seurat,_1884.jpg'
content_path1 = sample_dir + 'Almamater.jpg'
style_path1 = sample_dir + '1024px-Monet_-_Impression,_Sunrise.jpg'

## Image preprocessing and visualization functions

In [None]:
import matplotlib.image as mpimg
def load_img(path, size=(340, 512)):
  """Read an image file and return it as a float32 batch of one.

  Args:
    path: image file path, read with matplotlib.image.imread.
    size: (height, width) passed to tf.image.resize; defaults to the
      notebook's original 340x512.

  Returns:
    float32 numpy array of shape (1, height, width, 3) with pixel values on
    the 0-255 scale, which is what VGG19's preprocess_input expects.
  """
  img = mpimg.imread(path)
  if img.ndim == 2:
    # Grayscale file: replicate the single channel so the image has 3 channels.
    img = np.stack([img] * 3, axis=-1)
  # Drop an alpha channel if present (e.g. RGBA PNGs).
  img = img[..., :3]
  if np.issubdtype(img.dtype, np.floating):
    # imread returns PNGs as floats in [0, 1] (other formats as ints in
    # [0, 255]); bring every input onto the same 0-255 scale.
    img = img * 255.0
  # The default (bilinear) resize returns float32 and keeps the 0-255 values,
  # so no further dtype conversion is needed.
  img = tf.image.resize(img, size)
  img = img[tf.newaxis, :]
  return np.array(img)

In [None]:
def imshow(img_array, title=None):
  """Display an image array on the current matplotlib axes.

  Args:
    img_array: array of shape [1, H, W, C] (as returned by load_img) or
      [H, W, C]. Values are clipped to [0, 255] before the uint8 cast, so float
      arrays containing negative values (such as the noise images) render
      deterministically instead of going through an undefined/wrapping cast.
    title: optional title for the current axes.
  """
  temp = np.asarray(img_array)
  if temp.ndim == 4:
    # Drop the leading batch axis (raises if the batch size is not 1).
    temp = np.squeeze(temp, axis=0)
  temp = np.clip(temp, 0, 255).astype('uint8')
  plt.imshow(temp)
  if title is not None:
    plt.title(title)

In [None]:
def generate_noise_image(content_image, seed=50, noise_ratio=0.5):
    """Blend uniform noise into ``content_image`` to build an initial image.

    Args:
      content_image: numpy array of pixel values; the noise is drawn with the
        same shape.
      seed: seed for a private ``np.random.RandomState``. Using a private
        generator leaves the global NumPy RNG state untouched; with the default
        seed the noise is identical to a ``np.random.seed(50)`` draw.
      noise_ratio: weight given to the noise; the content gets
        ``1 - noise_ratio``. The default 0.5 is an equal-weight average.

    Returns:
      ``noise * noise_ratio + content_image * (1 - noise_ratio)``, where
      ``noise`` is float32 uniform in [-20, 20).
    """
    rng = np.random.RandomState(seed)
    noise_image = rng.uniform(-20, 20, content_image.shape).astype('float32')
    return noise_image * noise_ratio + content_image * (1 - noise_ratio)

In [None]:
plt.figure(figsize=(10, 10))

content, style = (load_img(p).astype('uint8') for p in (content_path, style_path))
np.random.seed(10)
noise_image = np.random.uniform(-20, 20, content.shape).astype('float32')
input_image = generate_noise_image(content)

# Preview the four images in a 2x2 grid, filled row by row.
preview_panels = [
    (content, 'Content Image'),
    (style, 'Style Image'),
    (noise_image, 'noise Image'),
    (input_image, 'Input Image'),
]
for panel_idx, (panel_img, panel_title) in enumerate(preview_panels, start=1):
  plt.subplot(2, 2, panel_idx)
  imshow(panel_img, panel_title)

plt.show()

In [None]:
def load_and_process_img(path_to_img):
  """Load an image with load_img and apply VGG19's preprocess_input to it."""
  vgg_preprocess = tf.keras.applications.vgg19.preprocess_input
  return vgg_preprocess(load_img(path_to_img))

In [None]:
def deprocess_img(processed_img):
  """Invert VGG19 preprocessing and return a displayable uint8 RGB image.

  Args:
    processed_img: numpy array of shape [1, height, width, channel] or
      [height, width, channel] in preprocess_input's space (BGR channel order,
      per-channel means subtracted). It is copied, not modified.

  Returns:
    uint8 array of shape [height, width, channel], channels reversed back to
    RGB and clipped to [0, 255].

  Raises:
    ValueError: if the input is not rank 3 after dropping a leading batch axis.
  """
  x = processed_img.copy()
  if len(x.shape) == 4:
    x = np.squeeze(x, 0)
  # A real exception rather than `assert`, so the check survives `python -O`.
  if len(x.shape) != 3:
    raise ValueError("Input to deprocess image must be an image of "
                     "dimension [1, height, width, channel] or [height, width, channel]")

  # Add back the per-channel means subtracted by preprocess_input (BGR order).
  x[:, :, 0] += 103.939
  x[:, :, 1] += 116.779
  x[:, :, 2] += 123.68
  x = x[:, :, ::-1]  # BGR -> RGB

  return np.clip(x, 0, 255).astype('uint8')

## Loading the model

In [None]:
# VGG19 layer names tapped by get_model(): one content layer, and the first
# convolution of each of the five blocks for style.
content_layers = ['block5_conv2']
style_layers = ['block%d_conv1' % block for block in range(1, 6)]

num_content_layers, num_style_layers = len(content_layers), len(style_layers)

In [None]:
def get_model():
  """Build a frozen VGG19 (ImageNet weights, no classifier head) feature extractor.

  Returns:
    A Keras Model mapping an image batch to a list of activations: one per
    name in `style_layers` (in order), followed by one per name in
    `content_layers`.
  """
  base = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')
  base.trainable = False
  tapped_layers = style_layers + content_layers
  outputs = [base.get_layer(layer_name).output for layer_name in tapped_layers]
  return models.Model(base.input, outputs)

## Loss functions

In [None]:
def get_content_loss(base_content, target):
  """Mean squared error between two (broadcast-compatible) activation tensors."""
  residual = base_content - target
  return tf.reduce_mean(tf.square(residual))

In [None]:
def gram_matrix(input_tensor):
  """Channel-by-channel Gram matrix, averaged over all non-channel positions.

  All leading axes are flattened into rows (one row per position) and the
  last axis gives the columns. The (C, C) result F^T F is divided by the
  number of rows.
  """
  num_channels = int(input_tensor.shape[-1])
  features = tf.reshape(input_tensor, [-1, num_channels])
  num_positions = tf.cast(tf.shape(features)[0], tf.float32)
  gram = tf.matmul(features, features, transpose_a=True)
  return gram / num_positions

def get_style_loss(base_style, gram_target):
  """Mean squared difference between the Gram matrix of `base_style` and `gram_target`.

  Args:
    base_style: rank-3 activation tensor for a single image (compute_loss
      passes `output[0]`, i.e. with the batch axis removed).
    gram_target: precomputed Gram matrix of the style image's activations for
      the same layer.

  Raises:
    ValueError: if `base_style` is not rank 3.
  """
  if base_style.shape.ndims != 3:
    raise ValueError('base_style must be rank 3, got shape {}'.format(base_style.shape))
  gram_style = gram_matrix(base_style)
  return tf.reduce_mean(tf.square(gram_style - gram_target))

In [None]:
def get_feature_representations(model, content_path, style_path):
  """Compute the target style and content activations.

  Args:
    model: model built by get_model(); its outputs are the style-layer
      activations followed by the content-layer activations.
    content_path: path of the content image.
    style_path: path of the style image.

  Returns:
    (style_features, content_features): per-layer activations with the batch
    axis dropped. Style features come from the style image; content features
    come from the content image.
  """
  content_image = load_and_process_img(content_path)
  style_image = load_and_process_img(style_path)

  style_outputs = model(style_image)
  content_outputs = model(content_image)

  style_features = [layer_out[0] for layer_out in style_outputs[:num_style_layers]]
  content_features = [layer_out[0] for layer_out in content_outputs[num_style_layers:]]
  return style_features, content_features

In [None]:
def compute_loss(model, loss_weights, init_image, gram_style_features, content_features):
  """Weighted style + content loss of `init_image`.

  Args:
    model: feature extractor from get_model() (style outputs first, then content).
    loss_weights: (style_weight, content_weight) multipliers.
    init_image: image batch being optimized.
    gram_style_features: target Gram matrices, one per style layer.
    content_features: target activations, one per content layer.

  Returns:
    (total_loss, weighted_style_loss, weighted_content_loss). Each layer
    contributes equally within its group.
  """
  style_weight, content_weight = loss_weights

  outputs = model(init_image)
  style_outputs = outputs[:num_style_layers]
  content_outputs = outputs[num_style_layers:]

  style_layer_weight = 1.0 / float(num_style_layers)
  style_score = sum(
      style_layer_weight * get_style_loss(generated[0], target)
      for target, generated in zip(gram_style_features, style_outputs))

  content_layer_weight = 1.0 / float(num_content_layers)
  content_score = sum(
      content_layer_weight * get_content_loss(generated[0], target)
      for target, generated in zip(content_features, content_outputs))

  style_score = style_score * style_weight
  content_score = content_score * content_weight
  return style_score + content_score, style_score, content_score

## Optimization

In [None]:
def compute_grads(cfg):
  """Evaluate compute_loss(**cfg) under a GradientTape.

  Returns:
    (gradient of the total loss w.r.t. cfg['init_image'],
     the (total, style, content) tuple from compute_loss)
  """
  image_var = cfg['init_image']
  with tf.GradientTape() as tape:
    loss_terms = compute_loss(**cfg)
  return tape.gradient(loss_terms[0], image_var), loss_terms

In [None]:
import IPython.display

def run_style_transfer(content_path, style_path, num_iterations=1000,
                       content_weight=1e3, style_weight=1e-3):
  """Optimize a noisy copy of the content image toward the style image's style.

  Args:
    content_path: path of the content image.
    style_path: path of the style image.
    num_iterations: number of Adam steps.
    content_weight: multiplier applied to the content loss.
    style_weight: multiplier applied to the style loss.

  Returns:
    (best_img, best_loss, loss_list, style_score_list, content_score_list).
    best_img is the deprocessed uint8 image that produced best_loss, the
    lowest total loss seen. The lists hold per-iteration loss values.

  Side effects: displays intermediate images and loss lines while running,
  then draws a 2x5 grid of the snapshots.
  """
  model = get_model()
  for layer in model.layers:
    layer.trainable = False

  style_features, content_features = get_feature_representations(model, content_path, style_path)
  gram_style_features = [gram_matrix(style_feature) for style_feature in style_features]

  content = load_and_process_img(content_path)
  init_image = generate_noise_image(content)
  init_image = tf.Variable(init_image, dtype=tf.float32)

  opt = tf.train.AdamOptimizer(learning_rate=5, beta1=0.99, epsilon=1e-1)

  best_loss, best_img = float('inf'), None

  loss_weights = (style_weight, content_weight)
  cfg = {
      'model': model,
      'loss_weights': loss_weights,
      'init_image': init_image,
      'gram_style_features': gram_style_features,
      'content_features': content_features
  }

  num_rows = 2
  num_cols = 5
  num_panels = num_rows * num_cols
  # Integer ceiling division: at most num_panels snapshots are taken (so the
  # final subplot grid never overflows), and `i % display_interval` is exact
  # even when num_iterations is not a multiple of num_panels.
  display_interval = max(1, -(-num_iterations // num_panels))
  start_time = time.time()
  global_start = time.time()

  # preprocess_input subtracts these per-channel means, so valid pixels in the
  # optimized (preprocessed) space lie in [-mean, 255 - mean].
  norm_means = np.array([103.939, 116.779, 123.68])
  min_vals = -norm_means
  max_vals = 255 - norm_means

  imgs = []
  loss_list = []
  style_score_list = []
  content_score_list = []
  for i in range(num_iterations):
    grads, all_loss = compute_grads(cfg)
    loss, style_score, content_score = all_loss
    loss_list.append(loss.numpy())
    style_score_list.append(style_score.numpy())
    content_score_list.append(content_score.numpy())

    # `loss` was computed for the image *before* this step's update, so take
    # the snapshot now; that way best_img is the image that achieved best_loss.
    if loss < best_loss:
      best_loss = loss
      best_img = deprocess_img(init_image.numpy())

    opt.apply_gradients([(grads, init_image)])
    clipped = tf.clip_by_value(init_image, min_vals, max_vals)
    init_image.assign(clipped)

    if i % display_interval == 0:
      # Measure before resetting the timer so the printed value is the time
      # since the previous snapshot.
      elapsed = time.time() - start_time
      start_time = time.time()

      plot_img = deprocess_img(init_image.numpy())
      imgs.append(plot_img)
      IPython.display.clear_output(wait=True)
      IPython.display.display_png(Image.fromarray(plot_img))
      print('Iteration: {}'.format(i))
      print('Total loss: {:.4e}, '
            'style loss: {:.4e}, '
            'content loss: {:.4e}, '
            'time: {:.4f}s'.format(loss, style_score, content_score, elapsed))

  # Clear the progress output *before* printing the total time; clearing
  # afterwards (wait=True) would erase it as soon as the grid below is drawn.
  IPython.display.clear_output(wait=True)
  print('Total time: {:.4f}s'.format(time.time() - global_start))
  plt.figure(figsize=(14, 4))
  for i, img in enumerate(imgs):
    plt.subplot(num_rows, num_cols, i + 1)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

  return best_img, best_loss, loss_list, style_score_list, content_score_list

In [None]:
# Run the optimization; keep the best image and the per-iteration loss histories.
(best, best_loss, loss_list,
 style_score_list, content_score_list) = run_style_transfer(
    content_path, style_path, num_iterations=1000)

In [None]:
# Render the best stylized result inline via the PIL image's rich display.
best_preview = Image.fromarray(best)
best_preview

In [None]:
def show_results(best_img, content_path, style_path, show_large_final=True):
  """Show the content and style images side by side.

  Args:
    best_img: image array to show as the output (e.g. from run_style_transfer).
    content_path: path of the content image.
    style_path: path of the style image.
    show_large_final: when true, also draw best_img in its own 10x10 figure
      and call plt.show().
  """
  plt.figure(figsize=(10, 5))
  panels = ((load_img(content_path), 'Content Image'),
            (load_img(style_path), 'Style Image'))
  for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    imshow(picture, caption)

  if not show_large_final:
    return
  plt.figure(figsize=(10, 10))
  plt.imshow(best_img)
  plt.title('Output Image')
  plt.show()

In [None]:
show_results(best, content_path, style_path)

## Training on different combinations

In [None]:
# Loss curves for the run above. The x-axis is derived from the recorded
# history, so it always matches the num_iterations that run_style_transfer used.
num_iterations = len(loss_list)
iteration = range(num_iterations)
from matplotlib.pyplot import figure

fig = plt.figure(figsize=(9, 6))

plt.subplots_adjust(wspace=0.25, hspace=0.5)

# Top-left cell of a 2x2 grid: weighted content loss.
sub1 = fig.add_subplot(2, 2, 1)
sub1.plot(iteration, content_score_list, color='green')
sub1.set_xlabel('Number of Iterations')
sub1.set_ylabel('content_score_list')
sub1.set_title('Weighted Content Loss Plot')

# Top-right cell: weighted style loss (default line color).
sub2 = fig.add_subplot(2, 2, 2)
sub2.plot(iteration, style_score_list)
sub2.set_xlabel('Number of Iterations')
sub2.set_ylabel('style_score_list')
sub2.set_title('Weighted Style Loss Plot')

# Bottom row: cells 3 and 4 merged into one axes for the total loss.
sub3 = fig.add_subplot(2, 2, (3, 4))
sub3.plot(iteration, loss_list, color='red')
sub3.set_xlabel('Number of Iterations')
sub3.set_ylabel('Total Loss')
sub3.set_title('Total Loss Plot');