In [None]:
import time

import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf

In [None]:
def tensor_to_image(tensor):
  """Convert a float image with values in [0, 1] to a PIL image.

  Values are scaled to [0, 255], clipped to that range, rounded to the
  nearest integer and stored as uint8. A leading batch axis of size 1 is
  removed.

  Args:
    tensor: array-like (tf.Tensor, tf.Variable or numpy array), expected
      shape (H, W, C) or (1, H, W, C).

  Returns:
    A `PIL.Image.Image` built from the uint8 pixels.

  Raises:
    AssertionError: if a 4-D input has a batch size other than 1.
  """
  # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
  from PIL import Image

  pixels = np.asarray(tensor, dtype=np.float32) * 255
  # Clip before the cast so out-of-range values saturate instead of wrapping,
  # and round rather than truncate towards zero.
  pixels = np.rint(np.clip(pixels, 0, 255)).astype(np.uint8)
  if pixels.ndim > 3:
    assert pixels.shape[0] == 1, "expected a batch containing exactly one image"
    pixels = pixels[0]
  return Image.fromarray(pixels)

In [None]:
def _load_and_resize(path, max_dim):
  """Read one image file, resize it so its longer side is about `max_dim` px, and add a batch axis.

  Returns a float32 tensor of shape (1, height, width, 3). The image keeps its
  own aspect ratio, up to integer truncation of the new size.
  """
  img = tf.io.read_file(path)
  # expand_animations=False makes GIFs decode to a single 3-D frame like
  # every other format, so the shape slicing below stays valid.
  img = tf.image.decode_image(img, channels=3, expand_animations=False)
  # Integer pixels are rescaled to floats in [0, 1].
  img = tf.image.convert_image_dtype(img, tf.float32)

  height_width = tf.cast(tf.shape(img)[:-1], tf.float32)
  scale = max_dim / tf.reduce_max(height_width)
  new_shape = tf.cast(height_width * scale, tf.int32)

  img = tf.image.resize(img, new_shape)
  return img[tf.newaxis, :]


def load_images(path_to_content, path_to_style, content_max_dim=720, style_max_dim=256):
  """Load a content/style image pair as batched float32 RGB tensors.

  Each image is resized from its OWN dimensions, so neither image is
  distorted by the other's aspect ratio.

  Args:
    path_to_content: path of the content image file.
    path_to_style: path of the style image file.
    content_max_dim: target length in pixels of the content image's longer side.
    style_max_dim: target length in pixels of the style image's longer side.

  Returns:
    (content, style) tensors, each of shape (1, height, width, 3).
  """
  contentimg = _load_and_resize(path_to_content, content_max_dim)
  styleimg = _load_and_resize(path_to_style, style_max_dim)
  return contentimg, styleimg

Create a simple function to display an image:

In [None]:
def imshow(image, title=None):
  """Draw `image` on the current matplotlib axes, with an optional title.

  A 4-D input is treated as a batch and its leading axis is squeezed away
  (which requires that axis to have size 1).
  """
  to_draw = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
  plt.imshow(to_draw)
  if not title:
    return
  plt.title(title)

In [None]:
content_image, style_image = load_images('rads.jpeg', "starrynight.jpg")

# Side-by-side preview: content on the left, style on the right.
panels = [(content_image, 'Content Image'), (style_image, 'Style Image')]
for panel_index, (panel_image, panel_title) in enumerate(panels, start=1):
  plt.subplot(1, 2, panel_index)
  imshow(panel_image, panel_title)

In [None]:
# VGG19 layer whose activations represent the image content.
content_layers = ['block5_conv2']

# First convolution of each of the five VGG19 blocks; their activations are
# reduced to Gram matrices to represent style.
style_layers = ['block{}_conv1'.format(block) for block in range(1, 6)]

num_content_layers, num_style_layers = len(content_layers), len(style_layers)

In [None]:
def vgg_layers(layer_names, shape=None):
  """Create a frozen VGG19 model that returns the outputs of `layer_names`.

  Args:
    layer_names: names of VGG19 layers to expose, in the order their outputs
      should be returned.
    shape: unused; accepted only so existing callers that pass an image
      shape keep working.

  Returns:
    A `tf.keras.Model` from the VGG19 input to the requested layer outputs.
  """
  # ImageNet weights without the classifier head; no input_shape is given,
  # so the input size is not pinned to any particular image.
  vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
  vgg.trainable = False

  outputs = [vgg.get_layer(name).output for name in layer_names]
  return tf.keras.Model([vgg.input], outputs)
# Sanity check: the style-layer extractor exposes one output per entry of
# `style_layers`; list each layer with its (symbolic) output shape.
style_extractor = vgg_layers(style_layers, content_image.shape)
[(name, output.shape) for name, output in zip(style_layers, style_extractor.outputs)]


In [None]:
def gram_matrix(input_tensor):
  """Per-example Gram matrix of channel activations, averaged over spatial positions.

  The einsum subscripts require a rank-4 input, presumably
  (batch, height, width, channels); the result has shape
  (batch, channels, channels).
  """
  dims = tf.shape(input_tensor)
  n_positions = tf.cast(dims[1] * dims[2], tf.float32)
  # Sum, over both spatial axes, the outer product of each channel vector.
  gram = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
  return gram / n_positions

In [None]:
class StyleContentModel(tf.keras.models.Model):
  """Maps an image to its style (Gram matrices) and content (feature maps) representations.

  A single VGG19 extractor exposes the style layers followed by the content
  layers, so the outputs can be split by the number of style layers.
  """

  def __init__(self, style_layers, content_layers,shape):
    super().__init__()
    # Style layers first: `call` relies on this order to split the outputs.
    self.vgg = vgg_layers(style_layers + content_layers, shape)
    self.style_layers = style_layers
    self.content_layers = content_layers
    self.num_style_layers = len(style_layers)
    self.vgg.trainable = False

  def call(self, inputs):
    """Return {'content': {layer: features}, 'style': {layer: gram matrix}}.

    Expects float input in [0,1]; it is multiplied by 255 before the VGG19
    preprocessing is applied.
    """
    split = self.num_style_layers
    features = self.vgg(
        tf.keras.applications.vgg19.preprocess_input(inputs * 255.0))
    grams = [gram_matrix(feature) for feature in features[:split]]
    return {
        'content': dict(zip(self.content_layers, features[split:])),
        'style': dict(zip(self.style_layers, grams)),
    }

In [None]:
# The image being optimised starts as a copy of the content image; train_step
# updates this variable in place.
image = tf.Variable(content_image)
extractor = StyleContentModel(style_layers, content_layers, image.shape)

# Fixed targets for the loss: Gram matrices of the style image and feature
# maps of the content image. They are computed once, before training.
style_targets = extractor(style_image)['style']
content_targets = extractor(content_image)['content']
def clip_0_1(image):
  """Clamp every value of `image` into the closed interval [0, 1]."""
  lower, upper = 0.0, 1.0
  return tf.clip_by_value(image, lower, upper)
# Weights of the three loss terms: style and content (style_content_loss)
# and the total-variation penalty (train_step).
style_weight = 0.1
content_weight = 1e4
total_variation_weight = 20

# Learning rate starts at 0.1 and is multiplied by 0.8 after every 100 steps
# (staircase=True makes the decay stepwise rather than continuous).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=100, decay_rate=0.8, staircase=True)
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.99, epsilon=1e-1)

In [None]:
def style_content_loss(outputs):
    """Weighted style + content loss for one extractor output.

    `outputs` is the dict produced by the extractor, with 'style' and
    'content' sub-dicts keyed by layer name. Each term is the per-layer mean
    squared error against `style_targets` / `content_targets`, summed over
    layers and scaled by its weight divided by the number of layers.
    """
    def summed_layer_mse(predicted, targets):
        # One mean-squared-error term per layer, added together.
        return tf.add_n([tf.reduce_mean((predicted[layer] - targets[layer]) ** 2)
                         for layer in predicted.keys()])

    style_term = summed_layer_mse(outputs['style'], style_targets)
    style_term = style_term * (style_weight / num_style_layers)

    content_term = summed_layer_mse(outputs['content'], content_targets)
    content_term = content_term * (content_weight / num_content_layers)

    return content_term + style_term

In [None]:
@tf.function()
def train_step(image):
  """Run one optimisation step on the tf.Variable `image`, in place.

  Returns the loss that was differentiated: the weighted style/content loss
  plus the weighted total-variation penalty of `image`.
  """
  with tf.GradientTape() as tape:
    total = (style_content_loss(extractor(image))
             + total_variation_weight * tf.image.total_variation(image))

  gradient = tape.gradient(total, image)
  opt.apply_gradients([(gradient, image)])
  # Keep pixels in [0, 1], the range StyleContentModel.call expects.
  image.assign(clip_0_1(image))
  return total

In [None]:
# Optimise `image` for up to `epochs` x `steps_per_epoch` steps, displaying the
# current result after every epoch.
#
# Early stopping: after each epoch, the latest loss is compared with the losses
# of the steps just before it. `_mean_improvement` is mean(earlier - latest);
# a negative value means the latest loss is above the recent ones, so training
# stops.
epochs = 50
steps_per_epoch = 100
short_window = 20  # steps (including the latest) used for the stopping test
long_window = 50   # steps (including the latest) reported for comparison


def _mean_improvement(losses, window):
  """Return the mean of (earlier loss - latest loss) over the last `window` steps.

  Only the up-to `window - 1` losses preceding the latest are compared (the
  latest against itself would always contribute 0). Returns 0.0 when there
  is no earlier loss yet.
  """
  earlier = losses[-window:-1]
  if not earlier:
    return tf.constant(0.0)
  return tf.reduce_mean(tf.stack(earlier) - losses[-1])


start = time.time()
lossarr = []   # loss returned by every train_step call
tdiffarr = []  # short-window mean improvement after every step
tdif = tf.constant(0.0)  # defined even if steps_per_epoch == 0
step = 0
for epoch in range(epochs):
  for _ in range(steps_per_epoch):
    step += 1
    lossarr.append(train_step(image))
    print(".", end='', flush=True)
    tdif = _mean_improvement(lossarr, short_window)
    tdiffarr.append(tdif)
  display.clear_output(wait=True)
  display.display(tensor_to_image(image))

  if tdif < 0:
    print("EARLY STOPPING")
    break
  long_tdif = _mean_improvement(lossarr, long_window)
  print("Mean improvement over last {} steps: {:.4g} (over last {} steps: {:.4g})".format(
      short_window - 1, float(tdif), long_window - 1, float(long_tdif)))
  print("Train step: {}".format(step))

end = time.time()
print("Total time: {:.1f}".format(end-start))

In [None]:
fig, ax = plt.subplots()
# Each recorded loss may be a length-1 tensor (tf.image.total_variation
# returns one value per image in the batch), so flatten to one value per step.
ax.plot(np.asarray(lossarr, dtype=np.float64).ravel())
ax.set(title="Style transfer training loss", xlabel="Training step", ylabel="Total loss")
plt.show()
