In [19]:
"""
Manish Aradwad / MTech AI / 19494
Pratyush Gauri / MTech AI / 20227
"""

import functools
import os
import time

import IPython.display as display
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import tensorflow as tf
import tensorflow_hub as hub

mpl.rcParams['figure.figsize'] = (12, 12)
mpl.rcParams['axes.grid'] = False
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'

In [20]:
# Download the content image and the three style images. get_file stores each
# download locally and returns the path of the local copy.
# NOTE(review): the content image and style 1 use the very same URL (Starry
# Night), so the "content" being stylized is also one of the blended styles.
# Confirm this is intended and not a copy-paste slip.
starry_night_url = 'https://www.worldatlas.com/r/w1200/upload/1f/e7/fd/1280px-van-gogh-starry-night-google-art-project.jpg'
content_path = tf.keras.utils.get_file('content.jpg', starry_night_url)
style1_path = tf.keras.utils.get_file('style1.jpg', starry_night_url)
style2_path = tf.keras.utils.get_file(
  'style2.jpg',
  'https://cdn.kastatic.org/ka-perseus-images/61a0415e6b2126c71c446100069d01440a13bf2f.jpg')
style3_path = tf.keras.utils.get_file(
  'style3.jpg',
  'https://cdn.theculturetrip.com/wp-content/uploads/2012/01/hokusai.jpg')

In [21]:
def tensor_to_image(tensor):
  """Convert a float image tensor with values nominally in [0, 1] to a PIL image.

  Args:
    tensor: array-like (NumPy array or eager TF tensor) of rank 3, or rank 4
      with a leading batch axis of size 1. Presumably (H, W, 3). TODO confirm
      against callers.

  Returns:
    A `PIL.Image.Image` built from the uint8 pixel array.

  Raises:
    AssertionError: if a rank>3 input has a batch size other than 1.
  """
  # Scale to [0, 255] and clip *before* the uint8 cast. Any value outside
  # [0, 1] would otherwise wrap around (or be undefined) in the float -> uint8
  # conversion instead of saturating to black/white.
  pixels = np.clip(np.asarray(tensor, dtype=np.float32) * 255, 0, 255).astype(np.uint8)
  if pixels.ndim > 3:
    # Only a single-image batch can become one PIL image.
    assert pixels.shape[0] == 1
    pixels = pixels[0]
  return PIL.Image.fromarray(pixels)

def load_img(path_to_img, max_dim=512, keep_aspect=False):
  """Read an image file and return it as a float32 batch of one.

  Args:
    path_to_img: path to an image file in a format `tf.image.decode_image`
      supports.
    max_dim: target size in pixels. With the default `keep_aspect=False` the
      image is resized to exactly `max_dim x max_dim`.
    keep_aspect: if True, scale the image so its longer side equals `max_dim`
      while preserving the aspect ratio, instead of squashing it to a square.

  Returns:
    A float32 tensor of shape (1, height, width, 3) with values in [0, 1].
  """
  img = tf.io.read_file(path_to_img)
  # expand_animations=False makes GIFs decode to a single rank-3 frame, so the
  # shape logic below always sees (H, W, 3).
  img = tf.image.decode_image(img, channels=3, expand_animations=False)
  img = tf.image.convert_image_dtype(img, tf.float32)

  if keep_aspect:
    hw = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(hw)
    target_size = tf.cast(hw * scale, tf.int32)
  else:
    # Default: a fixed square size. The notebook blends style images
    # elementwise, which requires every loaded image to have the same shape.
    target_size = (max_dim, max_dim)

  img = tf.image.resize(img, target_size)
  return img[tf.newaxis, :]

def imshow(image, title=None):
  """Draw `image` on the current matplotlib axes, optionally with a title.

  A rank>3 input is treated as batched and its leading axis is squeezed away
  before plotting. A falsy `title` leaves the axes untitled.
  """
  is_batched = len(image.shape) > 3
  if is_batched:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if not title:
    return
  plt.title(title)

In [22]:
content_image = load_img(content_path)
style1_image = load_img(style1_path)
style2_image = load_img(style2_path)
style3_image = load_img(style3_path)

# Blend the three styles pixel-wise. This relies on load_img returning
# identically shaped tensors for every image.
# NOTE(review): the weights sum to 0.9, not 1.0, so every pixel of the blend
# is 90% of a convex mix (i.e. 10% darker). Confirm this is intentional.
style_image = 0.35*style1_image + 0.1*style2_image + 0.45*style3_image

# Show the inputs side by side. The fifth panel (previously allocated but left
# empty) shows the blended style that is actually fed to the model.
panels = [
  (content_image, 'Content Image'),
  (style1_image, 'Style 1 Image'),
  (style2_image, 'Style 2 Image'),
  (style3_image, 'Style 3 Image'),
  (style_image, 'Blended Style Image'),
]
plt.figure(figsize=(18, 18))
for panel_idx, (panel_image, panel_title) in enumerate(panels, start=1):
  plt.subplot(1, len(panels), panel_idx)
  imshow(panel_image, panel_title)
  plt.axis("off")
plt.show()

In [23]:
hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

# The model's documentation recommends ~256x256 style images (the size used
# to train the style network). The content image can stay at its loaded size.
# A new name is used so `style_image` keeps the blended 512px tensor.
style_image_256 = tf.image.resize(style_image, (256, 256))

# The model returns a sequence of outputs; element 0 is the stylized batch.
stylized_image = hub_model(tf.constant(content_image), tf.constant(style_image_256))[0]
imshow(stylized_image, 'Generated Image')
plt.axis("off")
plt.show()