<a href="https://colab.research.google.com/github/atsukoba/neural-sketches/blob/master/style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Style Transfer Sample

In [1]:
import tensorflow as tf
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (14, 14)
mpl.rcParams['axes.grid'] = False

import numpy as np
import PIL.Image
import time
import functools


def tensor_to_image(tensor):
    """Convert a float image tensor with values in [0, 1] to a PIL image.

    Args:
        tensor: Image data convertible to a NumPy array (e.g. an eager
            TensorFlow tensor or an ``np.ndarray``). A leading batch axis of
            size 1 is removed when the input has more than three dimensions.
            Values are expected to lie in ``[0, 1]``.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit pixel data.

    Raises:
        AssertionError: If a batched input holds more than one image.
    """
    pixels = np.asarray(tensor) * 255
    # Clip before the uint8 cast: values slightly outside [0, 1] (common in
    # model outputs) would otherwise wrap around or be undefined when cast.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        # Internal invariant: callers pass a single-image batch.
        assert pixels.shape[0] == 1
        pixels = pixels[0]
    return PIL.Image.fromarray(pixels)


def load_img(path_to_img, max_dim=512):
    """Load an image file as a batched float32 tensor resized to ``max_dim``.

    The image is decoded to 3 channels and converted to float32 in ``[0, 1]``.
    It is then resized, preserving aspect ratio, so that its longer side
    equals ``max_dim``. Smaller images are upscaled.

    Args:
        path_to_img: Path of the image file to read.
        max_dim: Target length, in pixels, of the image's longer side.

    Returns:
        A ``tf.float32`` tensor of shape ``(1, height, width, 3)``.
    """
    raw = tf.io.read_file(path_to_img)
    # expand_animations=False makes GIFs decode to a single 3-D frame instead
    # of a 4-D stack of frames, which the shape arithmetic below assumes.
    img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)
    # (height, width) as floats so the scale factor can be applied.
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    new_shape = tf.cast(shape * scale, tf.int32)
    img = tf.image.resize(img, new_shape)
    # Add the leading batch axis expected by the stylization model.
    return img[tf.newaxis, :]


def imshow(image, title=None):
    """Plot an image with matplotlib, dropping a leading batch axis if present.

    Args:
        image: Array-like image. If it has more than three dimensions, the
            first axis is squeezed away before plotting.
        title: Optional plot title; skipped when falsy.
    """
    has_batch_axis = len(image.shape) > 3
    to_plot = tf.squeeze(image, axis=0) if has_batch_axis else image
    plt.imshow(to_plot)
    if not title:
        return
    plt.title(title)


## Use a Pretrained Model via TensorFlow Hub

In [3]:
import os
from google.colab import files
import tensorflow_hub as hub


# Magenta arbitrary image stylization model from TF Hub. It is loaded once at
# module level; transfer_style calls it with (content, style) image tensors.
_STYLIZATION_MODEL_URL = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/1'
hub_module = hub.load(_STYLIZATION_MODEL_URL)


def transfer_style(path_to_image: str, path_to_style: str, save_file=None):
    """Render the content image in the style of the style image.

    Args:
        path_to_image: Path of the content image.
        path_to_style: Path of the style image.
        save_file: Optional path; when given, the stylized image is also
            written there via PIL's ``Image.save``.

    Returns:
        The stylized image as a ``PIL.Image.Image``.

    Raises:
        FileNotFoundError: If either input path does not exist.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would replace a clear message with a later TF read error.
    if not os.path.exists(path_to_image):
        raise FileNotFoundError(f"Image File Not Found: {path_to_image}")
    if not os.path.exists(path_to_style):
        raise FileNotFoundError(f"Style File Not Found: {path_to_style}")
    print(f"Load image: {path_to_image}")
    content_image = load_img(path_to_image)
    print(f"Load image: {path_to_style}")
    style_image = load_img(path_to_style)
    print("Transfering style...")
    # The first element of the hub module's output is the stylized image batch.
    stylized_image = hub_module(tf.constant(content_image), tf.constant(style_image))[0]
    img = tensor_to_image(stylized_image)
    if save_file is not None:
        img.save(save_file)
    return img


## Images for Content and Style

In [25]:
import os
import requests
import cv2
import numpy as np
from IPython.display import display, Javascript
from google.colab.output import eval_js
from base64 import b64decode
from IPython.display import Image


def upload_image(select_texture=True):
    """Prompt the user to upload a file in Colab and return its file name.

    Args:
        select_texture: Unused; kept so existing callers that pass it still work.

    Returns:
        The name of the first uploaded file. If several files are selected,
        only the first is returned.

    Raises:
        RuntimeError: If the upload dialog finishes without any file.
    """
    print("Upload Your Photo !")
    uploaded = files.upload()
    if not uploaded:
        # Without this check the caller would see an opaque IndexError.
        raise RuntimeError("No file was uploaded.")
    return next(iter(uploaded))


# Browser-side helper displayed into the Colab output by take_photo.
# takePhoto(quality) does the following:
#   - shows a live webcam preview and waits for the user to click it;
#   - draws that frame onto a canvas;
#   - stops the camera track and removes the video element;
#   - returns the frame as a JPEG data URL ("data:image/jpeg;base64,...")
#     encoded at the given quality.
VIDEO_JS = Javascript('''
async function takePhoto(quality) {
  // Create a video and play it.
  const video = document.createElement('video')
  document.body.appendChild(video)
  video.srcObject = await navigator.mediaDevices.getUserMedia({video: true})
  await video.play()
  // Resize the output to fit the video element.
  google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true)
  // Wait for video to be clicked.
  await new Promise((resolve) => video.onclick = resolve)
  const canvas = document.createElement('canvas')
  canvas.width = video.videoWidth
  canvas.height = video.videoHeight
  canvas.getContext('2d').drawImage(video, 0, 0)
  video.srcObject.getVideoTracks()[0].stop()
  video.remove()
  return canvas.toDataURL('image/jpeg', quality)
}
''')

def take_photo(filename='photo.jpg', quality=0.8):
    """Capture a webcam frame in the browser and save it as a JPEG file.

    Args:
        filename: Destination path for the JPEG bytes.
        quality: JPEG quality forwarded to the browser-side ``takePhoto``.

    Returns:
        ``filename``, so the result can be passed straight to ``transfer_style``.
    """
    display(VIDEO_JS)
    data_url = eval_js(f'takePhoto({quality})')
    # The data URL has the form "data:image/jpeg;base64,<payload>".
    pieces = data_url.split(',')
    jpeg_bytes = b64decode(pieces[1])
    with open(filename, 'wb') as out:
        out.write(jpeg_bytes)
    return filename


def url_to_image(url, timeout=30):
    """Download an image over HTTP and decode it with OpenCV.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up; passed to
            ``requests.get`` so a stalled server cannot hang the notebook.

    Returns:
        The decoded image as a NumPy array (OpenCV ``IMREAD_COLOR``, i.e. BGR
        channel order).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If the response body is not a decodable image.
    """
    resp = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to decode an error page.
    resp.raise_for_status()
    buffer = np.frombuffer(resp.content, dtype=np.uint8)
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    # cv2.imdecode signals a decode failure by returning None, not by raising.
    if image is None:
        raise ValueError(f"Could not decode an image from {url}")
    return image


## Try it

In [27]:
# Capture the content photo from the webcam first, then upload the style image.
# The call is the cell's last expression, so the resulting image is displayed.
content_path = take_photo()  # Take Your Photo
style_path = upload_image()  # Upload Image of Style
transfer_style(content_path, style_path)