##### Copyright 2019 The TensorFlow Hub Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Fast Style Transfer for Arbitrary Styles


Based on the model code in [magenta](https://github.com/tensorflow/magenta/tree/master/magenta/models/arbitrary_image_stylization) and the publication:

[Exploring the structure of a real-time, arbitrary neural artistic stylization
network](https://arxiv.org/abs/1705.06830).
*Golnaz Ghiasi, Honglak Lee,
Manjunath Kudlur, Vincent Dumoulin, Jonathon Shlens*,
Proceedings of the British Machine Vision Conference (BMVC), 2017.


## Setup

Import TF2 and all relevant dependencies:

In [None]:
import functools
import os

from matplotlib import gridspec
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# Report the runtime environment (library versions, eager mode, visible GPUs)
# so the outputs below can be tied to the setup that produced them.
# An empty GPU list means no GPU is visible to TensorFlow.
print("TF Version: ", tf.__version__)
print("TF Hub version: ", hub.__version__)
print("Eager mode enabled: ", tf.executing_eagerly())
print("GPU available: ", tf.config.list_physical_devices('GPU'))

In [None]:
# @title Define image loading and visualization functions  { display-mode: "form" }

def crop_center(image):
  """Center-crops a batched image to a square.

  Args:
    image: Image tensor, expected to be laid out as
      (batch, height, width, channels); callers in this notebook add the
      batch axis with ``[tf.newaxis, ...]`` before calling.

  Returns:
    The central min(height, width) x min(height, width) region of ``image``.
  """
  height, width = image.shape[1], image.shape[2]
  side = min(height, width)
  # Only the longer spatial axis gets a non-zero offset; the shorter stays at 0.
  top = max(height - width, 0) // 2
  left = max(width - height, 0) // 2
  return tf.image.crop_to_bounding_box(image, top, left, side, side)

@functools.lru_cache(maxsize=None)
def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
  """Downloads, decodes, center-crops and resizes an image from a URL.

  The downloaded file is cached on disk by ``tf.keras.utils.get_file`` and the
  whole result is memoized per argument tuple by ``functools.lru_cache``, so
  every argument (including ``image_size``) must be hashable -- pass a tuple.

  Args:
    image_url: URL of the image to fetch.
    image_size: Target (height, width) passed to ``tf.image.resize``.
    preserve_aspect_ratio: Forwarded to ``tf.image.resize``. The image is
      center-cropped to a square first, so this only has an effect when
      ``image_size`` itself is not square.

  Returns:
    A float32 tensor of shape (1, height, width, 3) with values in [0, 1].
  """
  # Cache the file locally, keeping at most the last 128 characters of the
  # URL's basename as the cache file name.
  image_path = tf.keras.utils.get_file(os.path.basename(image_url)[-128:], image_url)
  # Decode as a 3-channel float32 image (uint8 pixels are rescaled to [0, 1]).
  # expand_animations=False keeps GIFs 3-D (first frame) so that adding the
  # batch axis below always yields a 4-D tensor for crop_center.
  img = tf.io.decode_image(
      tf.io.read_file(image_path),
      channels=3, dtype=tf.float32, expand_animations=False)[tf.newaxis, ...]
  img = crop_center(img)
  # Honor the caller's preserve_aspect_ratio instead of always forcing True.
  img = tf.image.resize(img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
  return img

def show_n(images, titles=('',)):
  """Plots batched images side by side in a single row and shows the figure.

  Args:
    images: Sequence of images with a leading batch dimension, assumed laid out
      as (batch, height, width, channels); ``image[0]`` is drawn and
      ``image.shape[1]`` sets the image's relative column width.
    titles: Optional per-image titles; positions without a title get ''.
  """
  n = len(images)
  if n == 0:
    return  # Nothing to draw (and GridSpec rejects zero columns).
  image_sizes = [image.shape[1] for image in images]
  # Roughly 6 inches per 320 px of the first image, but at least 1 inch:
  # the integer division yields 0 for images under 54 px, and matplotlib
  # rejects a zero-size figure.
  w = max(1, (image_sizes[0] * 6) // 320)
  plt.figure(figsize=(w * n, w))
  gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
  for i, image in enumerate(images):
    plt.subplot(gs[i])
    plt.imshow(image[0], aspect='equal')
    plt.axis('off')
    plt.title(titles[i] if i < len(titles) else '')
  plt.show()


## Import TF Hub module

In [None]:
# Load TF Hub module.

# TF Hub handle of the Magenta arbitrary image stylization model.
hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
# The loaded module is called later as hub_module(content, style); element [0]
# of its result is used as the stylized image.
hub_module = hub.load(hub_handle)

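## Run style transfer on custom datasets

The next three cells download and unzip three image datasets from Google Drive (`reality2minecraft`, `drawing2horse`, `mask2nomask`) into `./datasets/`. Only the dataset named in the loading cell below is used. Its `testA` folder supplies the content images and its `testB` folder supplies the style images.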

In [None]:
!mkdir datasets
!gdown --id 1K1wrFBFtUQp7plhnPkg6NV_15vbpZKvV --output ./datasets/reality2minecraft.zip
!mkdir ./datasets/reality2minecraft/
!unzip ./datasets/reality2minecraft.zip -d ./datasets/
!rm ./datasets/reality2minecraft.zip

In [None]:
!mkdir datasets
!gdown --id 1D2sVZo0bjQwzjajk44jtO96fzVyCsZho --output ./datasets/drawing2horse.zip
!mkdir ./datasets/drawing2horse/
!unzip ./datasets/drawing2horse.zip -d ./datasets/
!rm ./datasets/drawing2horse.zip

In [None]:
!mkdir datasets
!gdown --id 1PQtRu8lIVcGToYLmcQg1SuCL_rrKQpmr --output ./datasets/mask2nomask.zip
!mkdir ./datasets/mask2nomask/
!unzip ./datasets/mask2nomask.zip -d ./datasets/
!rm ./datasets/mask2nomask.zip

In [None]:
from PIL import Image
imsize = 256

def image_loader(image_name):
  """Loads a local image file as a batched, square float32 tensor.

  The image is resized to (imsize, imsize) with PIL, fully transparent pixels
  (alpha == 0) are painted white, and the uint8 result is converted to float32
  in [0, 1] exactly as ``tf.io.decode_image(..., dtype=tf.float32)`` would.
  All processing happens in memory: the file at ``image_name`` is only read,
  never rewritten, so the dataset stays intact across re-runs.

  Args:
    image_name: Path to an image file readable by PIL.

  Returns:
    A float32 tensor of shape (1, imsize, imsize, 3).
  """
  # Context manager closes the file handle that Image.open keeps open.
  with Image.open(image_name) as source:
    # Resize before converting to RGBA, matching PIL's per-mode resampling.
    rgba = np.asarray(source.resize((imsize, imsize)).convert('RGBA'))
  # np.asarray on a PIL image may be read-only, so copy before editing.
  rgb = rgba[..., :3].copy()
  # Fully transparent pixels carry arbitrary RGB values; paint them white.
  rgb[rgba[..., 3] == 0] = 255
  # Rescale uint8 [0, 255] -> float32 [0, 1] and add the fake batch dimension
  # required to fit the network's input dimensions.
  img = tf.image.convert_image_dtype(tf.constant(rgb), tf.float32)[tf.newaxis, ...]
  img = crop_center(img)
  img = tf.image.resize(img, (imsize, imsize), preserve_aspect_ratio=True)
  return img

In [None]:
dataset = "./datasets/mask2nomask/"
max_images_per_split = 100  # cap on images loaded from each of testA / testB


def _load_split(split_dir, limit=max_images_per_split):
  """Loads up to `limit` images from `split_dir`, keyed by file name.

  File names are sorted because os.listdir returns them in arbitrary order;
  sorting makes both the selection and the content/style pairing used later
  reproducible across machines and runs.
  """
  names = sorted(os.listdir(split_dir))[:limit]
  return {name: image_loader(os.path.join(split_dir, name)) for name in names}


content_images = _load_split(os.path.join(dataset, "testA"))
style_images = _load_split(os.path.join(dataset, "testB"))

In [None]:
#@title Stylize paired content and style images.  { display-mode: "form" }

num_pairs = 10  # how many (content, style) pairs to render

# Pair the i-th content image with the i-th style image (dict insertion order).
# zip stops at the shorter collection and the slice caps the count, so having
# fewer than `num_pairs` images no longer raises IndexError.
pairs = list(zip(content_images.values(), style_images.values()))[:num_pairs]
for content_image, style_image in pairs:
  # Element [0] of the hub module's result is the stylized image.
  stylized_image = hub_module(tf.constant(content_image),
                              tf.constant(style_image))[0]

  show_n([content_image, style_image, stylized_image],
         titles=['Original content image', 'Style image', 'Stylized image'])