<a href="https://colab.research.google.com/github/MaxHuerlimann/AdaIN-Style-tf2/blob/master/AdaIN_Tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tensorflow implementation of *AdaIN-Style* network

In [0]:
# Install torchfile to extract torch model
# The package is built from a fresh clone: enter the checkout, run its
# setup.py, then return to the previous working directory.
!git clone https://github.com/bshillingford/python-torchfile.git
%cd python-torchfile/
!python setup.py install
%cd ..

In [0]:
import os
import math
import time
from pathlib import Path
from PIL import Image
import datetime

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torchfile
import cv2

In [0]:
%tensorflow_version 2.x
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.python.keras.preprocessing import image as kp_image
from tensorflow.python.keras import models 
from tensorflow.python.keras import losses
from tensorflow.python.keras import layers
from tensorflow.python.keras import backend as K

%load_ext tensorboard

# Tensorboard

In [0]:
# Each run writes to its own timestamped folder below logs/gradient_tape
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = f'logs/gradient_tape/{current_time}/train'
summary_writer = tf.summary.create_file_writer(train_log_dir)

Setup Datasets

Get kaggle configurations

In [0]:
# Mount Google Drive so the kaggle.json file stored there can be copied
from google.colab import drive
drive.mount('/content/gdrive')
!pip3 install -q kaggle
!mkdir -p ~/.kaggle
# Copy the credentials file from Drive into ~/.kaggle
!cp /content/gdrive/'My Drive'/Kaggle/kaggle.json ~/.kaggle/
!ls ~/.kaggle
!chmod 600 /root/.kaggle/kaggle.json  # set permission
# Drive is not referenced again in this cell, so flush and unmount it
drive.flush_and_unmount()

Download the data from kaggle

In [0]:
# Name of the Kaggle archive to fetch and the folder receiving all images
data_package = 'train_1'
img_dir = Path('/content', 'data', 'train_images')

In [0]:
if not os.path.exists(img_dir):
  os.makedirs(img_dir)
!kaggle competitions download painter-by-numbers -f {data_package}.zip -p {img_dir.as_posix()}

In [0]:
# Quietly extract the downloaded archive into the image folder.
# NOTE(review): the archive path is given without the .zip suffix; this relies
# on unzip retrying with ".zip" appended - confirm.
!unzip -q {(img_dir / data_package).as_posix()} -d {img_dir.as_posix()}

In [0]:
# Delete every .zip file in the image folder
!rm {img_dir.as_posix()}/*.zip

Download MSCOCO dataset

In [0]:
# Download the MSCOCO train2017 archive.
# NOTE(review): `-p` is wget's --page-requisites option, not an output path, so
# the ms_coco.zip argument is treated as a second URL. The file presumably ends
# up under ./images.cocodataset.org/zips/, where the following cell picks it
# up - confirm; `-O <file>` would write to the intended location directly.
!wget -p {(img_dir / 'ms_coco.zip').as_posix()} http://images.cocodataset.org/zips/train2017.zip

In [0]:
# Move the downloaded archive into the image folder under the name ms_coco.zip
!mv /content/images.cocodataset.org/zips/train2017.zip /content/data/train_images/ms_coco.zip

In [0]:
# Quietly extract the MSCOCO archive into the image folder.
# NOTE(review): the archive path is given without the .zip suffix; this relies
# on unzip retrying with ".zip" appended - confirm.
!unzip -q {(img_dir / 'ms_coco').as_posix()} -d {img_dir.as_posix()}

In [0]:
# Delete every .zip file in the image folder
!rm {img_dir.as_posix()}/*.zip

Download vgg weights

In [0]:
# Fetch the torch (.t7) VGG weights; -c continues a partially downloaded file
!wget -c https://s3.amazonaws.com/xunhuang-public/adain/vgg_normalised.t7

Visualize Images

In [0]:
def load_img(path_to_img):
  """Load an image for display with its shorter side scaled to 512 pixels.

  Args:
    path_to_img: Path of the image file, handed to `Image.open`.

  Returns:
    The array produced by `kp_image.img_to_array`, with an extra leading
    batch axis of size one.
  """
  target_short_side = 512
  # The context manager releases the file handle once the pixels are resized.
  with Image.open(path_to_img) as source:
    # min() picks the shorter side, so the longer side ends up above 512.
    short_side = min(source.size)
    scale = target_short_side / short_side
    new_size = (round(source.size[0] * scale), round(source.size[1] * scale))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img = source.resize(new_size, Image.LANCZOS)

  img = kp_image.img_to_array(img)

  # We need to broadcast the image array such that it has a batch dimension
  img = np.expand_dims(img, axis=0)
  return img

In [0]:
  def imshow(img, title=None):
    """Draw a batched image on the current matplotlib axes.

    Args:
      img: Array with a leading batch axis of size one.
      title: Optional title placed above the image.
    """
    # Remove the batch dimension
    out = np.squeeze(img, axis=0)
    # Cast to 8-bit integers for display (values are not rescaled)
    out = out.astype('uint8')
    # The image is drawn once; the original repeated this call after the title.
    plt.imshow(out)
    if title is not None:
      plt.title(title)

In [0]:
# One content and one style image used for the preview
content_path = img_dir.joinpath('train2017', '000000000009.jpg')
style_path = img_dir.joinpath(data_package, '1007.jpg')

In [0]:
plt.figure(figsize=(10,10))

# Load both preview images as 8-bit arrays
content = load_img(content_path.as_posix()).astype('uint8')
style = load_img(style_path.as_posix()).astype('uint8')

# Show them side by side in a 1x2 grid
panels = ((content, 'Content Image'), (style, 'Style Image'))
for slot, (picture, caption) in enumerate(panels, start=1):
  plt.subplot(1, 2, slot)
  imshow(picture, caption)
plt.show()

Data Preprocessing

In [0]:
# Clean datasets
corrupted_files = []
def check_jpegs(path_to_img_dir):
  """Check if the jpegs can be loaded, otherwise list names.

  Args:
    path_to_img_dir: `pathlib.Path` of the folder whose '*.jpg' files are
      checked.

  Returns:
    List of paths that `cv2.imread` could not read. The same paths are also
    appended to the module-level `corrupted_files` list.
  """
  unreadable = []
  for f in path_to_img_dir.glob('*.jpg'):
    # cv2.imread signals failure by returning None instead of raising, so
    # accessing `.size` unconditionally would crash on the first bad file.
    if cv2.imread(f.as_posix()) is None:
      print(f)
      unreadable.append(f)
  corrupted_files.extend(unreadable)
  return unreadable


In [0]:
# check_jpegs((img_dir / 'train2017'))
# check_jpegs((img_dir / data_package))

In [0]:
# Workaround for keras application bug only being able to be initialized once
# Add before any TF calls
# Initialize the keras global outside of any tf.functions
# The dummy batch only exists to trigger one eager preprocess_input call.
temp = tf.random.uniform([4, 32, 32, 3])  # Or tf.zeros
tf.keras.applications.vgg16.preprocess_input(temp)
# Printed once the call above has returned
print("Noice")

In [0]:
def load_and_process_img(path_to_img):
  """Read a JPEG, scale its shorter side to 512 px and divide pixels by 255.

  Args:
    path_to_img: Path of the image file, as accepted by `tf.io.read_file`.

  Returns:
    The resized three-channel RGB image divided by 255, or None if decoding
    raises `tf.errors.InvalidArgumentError`. That error can only be caught
    here when the function runs eagerly; inside a traced graph (e.g.
    `Dataset.map`) decoding errors surface when the graph is executed.
  """
  new_min_dim = 512
  img = tf.io.read_file(path_to_img)
  # This creates RGB image
  try:
    img = tf.image.decode_jpeg(img, channels=3)
  except tf.errors.InvalidArgumentError:
    # Only decoding failures are swallowed; a bare `except` would also hide
    # programming errors and KeyboardInterrupt.
    return None

  # Scale minimum dimension to 512px
  height = tf.cast(tf.shape(img)[0], tf.float32)
  width = tf.cast(tf.shape(img)[1], tf.float32)
  min_dim = tf.minimum(height, width)
  scale = new_min_dim / min_dim
  # tf.image.resize expects an int32 size. Round explicitly so float error
  # (e.g. 511.99997) is not truncated to one pixel less than intended.
  new_size = tf.cast(tf.round(tf.stack([scale * height, scale * width])),
                     tf.int32)
  img = tf.image.resize(img, new_size)

  # Map pixel values from [0, 255] to [0, 1]
  img /= 255.

  return img


def preprocess_img(img):
  """Cut a random 256x256 patch with three channels out of the image."""
  crop_shape = (256, 256, 3)
  return tf.image.random_crop(img, crop_shape)


def process_path(path_to_img):
  """Load the image stored at the path and return a random crop of it."""
  return preprocess_img(load_and_process_img(path_to_img))

Create tensorflow Dataset

In [0]:
# Number of images per batch
BATCH_SIZE = 8
# tf.data constant used below as the prefetch buffer size
AUTOTUNE = tf.data.experimental.AUTOTUNE

def prepare_dataset(data):
  """Turn a dataset of image paths into batches of random image crops.

  Args:
    data: `tf.data.Dataset` yielding image file paths.

  Returns:
    A dataset yielding batches of exactly `BATCH_SIZE` crops.
  """
  # Decode and crop several images in parallel instead of one at a time.
  data = data.map(process_path, num_parallel_calls=AUTOTUNE)
  # Skip elements whose processing raised (e.g. undecodable files) rather than
  # aborting the whole iteration.
  data = data.apply(tf.data.experimental.ignore_errors())
  # Drop the ragged final batch: content and style batches are combined
  # element-wise, so their batch sizes must match.
  data = data.batch(BATCH_SIZE, drop_remainder=True)
  data = data.prefetch(AUTOTUNE)

  return data

In [0]:
# Get number of train data
NUM_STYLE_IMAGES = sum(1 for _ in (img_dir / data_package).glob('*.jpg'))
NUM_CONTENT_IMAGES = sum(1 for _ in (img_dir / 'train2017').glob('*.jpg'))
# Batches needed to cover every image, counting a partial final batch
NUM_STYLE_BATCHES = math.ceil(NUM_STYLE_IMAGES / BATCH_SIZE)
NUM_CONTENT_BATCHES = math.ceil(NUM_CONTENT_IMAGES / BATCH_SIZE)

In [0]:
# Style images
style_files = tf.data.Dataset.list_files(str(img_dir / data_package / '*.jpg'))
style_dataset = prepare_dataset(style_files)
# Content images
content_files = tf.data.Dataset.list_files(str(img_dir / 'train2017' / '*.jpg'))
content_dataset = prepare_dataset(content_files)

Model definition

In [0]:
# Layer whose feature map is used for the content representation
content_layers = ['conv4_1']

# Layers whose feature maps are used for the style representation:
# the first convolution of blocks one to four
style_layers = ['conv{}_1'.format(block) for block in range(1, 5)]

num_content_layers = len(content_layers)
num_style_layers = len(style_layers)

In [0]:
# Import torch model into tensorflow
def get_encoder_from_torch(target_layer='relu4_1',
                           weights_path='/content/vgg_normalised.t7'):
  """Load a model from t7 and translate it to tensorflow.

  The torch modules are translated one by one until the module named
  `target_layer` has been handled. Convolution weights are copied and frozen.

  Args:
    target_layer: Name of the torch module after which translation stops.
    weights_path: Location of the serialized torch model.

  Returns:
    A Keras model whose outputs are the feature maps of the layers listed in
    `style_layers`, followed by those listed in `content_layers`.

  Raises:
    NotImplementedError: If the file contains an unsupported module type.
  """
  t7 = torchfile.load(weights_path, force_8bytes_long=True)
  torch_modules = t7.modules

  inputs = tf.keras.Input((None, None, 3), name="vgg_input")

  x = inputs

  style_outputs = []
  content_outputs = []
  for idx, module in enumerate(torch_modules):
    name = module.name.decode() if module.name is not None else None

    if idx == 0:
      name = 'preprocess'  # VGG 1st layer preprocesses with a 1x1 conv to multiply by 255 and subtract BGR mean as bias

    if module._typename == b'nn.SpatialReflectionPadding':
      x = tf.keras.layers.Lambda(
          lambda t: tf.pad(t, [[0, 0], [1, 1], [1, 1], [0, 0]],
          mode='REFLECT'))(x)
    elif module._typename == b'nn.SpatialConvolution':
      # Fuse a ReLU into the convolution only when the torch graph really has
      # one directly behind it. Applying it unconditionally would also clip
      # the output of a convolution that is not followed by a ReLU.
      followed_by_relu = (
          idx + 1 < len(torch_modules)
          and torch_modules[idx + 1]._typename == b'nn.ReLU')
      # Torch stores kernels as (out, in, kH, kW); Keras wants (kH, kW, in, out)
      weight = module.weight.transpose([2,3,1,0])
      x = layers.Conv2D(module.nOutputPlane, module.kH, padding='valid',
                    activation='relu' if followed_by_relu else None,
                    name=name,
                    kernel_initializer=tf.constant_initializer(weight),
                    bias_initializer=tf.constant_initializer(module.bias),
                    trainable=False)(x)
      if name in style_layers:
        style_outputs.append(x)
      if name in content_layers:
        content_outputs.append(x)
    elif module._typename == b'nn.ReLU':
      pass  # already fused into the preceding convolution
    elif module._typename == b'nn.SpatialMaxPooling':
      x = layers.MaxPooling2D(padding='same', name=name)(x)
    else:
      raise NotImplementedError(module._typename)

    if name == target_layer:
      break

  model_outputs = style_outputs + content_outputs

  return models.Model(inputs=inputs, outputs=model_outputs)

In [0]:
def get_encoder():
  """ Creates encoder from VGG19 model.

  This function will load the VGG19 model and access the intermediate layers.
  These layers will then be used to create a new model that will take input image
  and return the outputs from these intermediate layers from the VGG model.

  Returns:
    returns a keras model that takes image inputs and outputs the style and
      content intermediate layers.
  """
  def to_keras_name(layer_name):
    # Keras VGG19 names its layers 'block4_conv1' where this notebook uses
    # 'conv4_1'; looking up the short form directly raises in get_layer().
    if layer_name.startswith('block'):
      return layer_name
    block, index = layer_name[len('conv'):].split('_')
    return 'block{}_conv{}'.format(block, index)

  # Load our model. We load pretrained VGG, trained on imagenet data
  vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')
  vgg.trainable = False
  # Get output layers corresponding to style and content layers
  style_outputs = [vgg.get_layer(to_keras_name(name)).output
                   for name in style_layers]
  content_outputs = [vgg.get_layer(to_keras_name(name)).output
                     for name in content_layers]
  model_outputs = style_outputs + content_outputs
  # Build model
  return models.Model(vgg.input, model_outputs)

In [0]:
def get_decoder(encoder):
  """Creates a trainable decoder, that mirrors the encoder.

  Pooling layers are replaced with nearest up-sampling layers and reflection
  padding is used to avoid border artifacts.

  Args:
    encoder: Keras model whose last layer exposes `.filters`; its layers from
      index 4 onwards are mirrored in reverse order.

  Returns:
    A Keras model mapping encoder features to a three-channel image.
  """
  def reflect_pad(tensor):
    # Wrapped in a Lambda layer, matching how the torch-based encoder pads.
    return layers.Lambda(
        lambda t: tf.pad(t, [[0, 0], [1, 1], [1, 1], [0, 0]],
                         mode='REFLECT'))(tensor)

  inputs = tf.keras.Input((None, None, encoder.layers[-1].filters))
  # Mirror the encoder
  # NOTE(review): layers 0-3 are skipped; presumably the input, preprocessing
  # conv, first padding and first conv - confirm with encoder.summary().
  x = inputs
  for i in reversed(range(4, len(encoder.layers))):
    layer = encoder.layers[i]
    if isinstance(layer, layers.MaxPooling2D):
      x = layers.UpSampling2D()(x)
    elif isinstance(layer, layers.Conv2D):
      # Output as many channels as the mirrored encoder conv takes as input
      x = layers.Conv2D(
          layer.get_weights()[0].shape[2],
          layer.kernel_size,
          activation=tf.keras.activations.relu)(reflect_pad(x))

  # Finally reduce number of channels to three (no activation on the output)
  outputs = tf.keras.layers.Conv2D(3, 3)(reflect_pad(x))

  return models.Model(inputs, outputs)

In [0]:
def adaptive_instance_normalization(x, y):
  """Re-normalise x so that its mean and standard deviation match those of y.

  Moments are taken over axes 1 and 2 with kept dimensions, and the result is
  `y_std * (x - x_mean) / sqrt(x_var + eps) + y_mean`.

  Args:
    x: Feature tensor to be normalised (the content features).
    y: Feature tensor providing the target statistics (the style features).

  Returns:
    Tensor with the shape of `x`.
  """
  eps = 1e-4
  x_mean, x_var = tf.nn.moments(x, [1,2], keepdims=True)
  y_mean, y_var = tf.nn.moments(y, [1,2], keepdims=True)
  # Epsilon inside the square root keeps its gradient finite when a feature
  # map of y is constant (zero variance).
  y_std = tf.math.sqrt(y_var + eps)
  # batch_normalization(x, mean, variance, offset, scale, eps) evaluates
  # scale * (x - mean) / sqrt(variance + eps) + offset
  return tf.nn.batch_normalization(x, x_mean, x_var, y_mean, y_std, eps)

In [0]:
# Build the encoder from the torch weights, then a decoder derived from it
encoder = get_encoder_from_torch()
decoder = get_decoder(encoder)

In [0]:
# print(encoder.summary())

In [0]:
# print(decoder.summary())

# Define costs

In [0]:
def get_content_loss(adain_output, target_encoded):
  """Mean squared difference between the AdaIN output and target features."""
  residual = adain_output - target_encoded
  squared_residual = tf.square(residual)
  return tf.reduce_mean(squared_residual)

In [0]:
def get_style_loss(base_style_encoded, target_encoded):
  """Distance between the feature statistics of two encoded batches.

  Sums the squared differences of the means and of the standard deviations
  (both taken over axes 1 and 2) and divides by the number of batch elements.

  Args:
    base_style_encoded: Encoder features of the style images.
    target_encoded: Encoder features of the generated images.

  Returns:
    Scalar tensor: mean term plus standard-deviation term.
  """
  eps = 1e-5

  base_style_mean, base_style_var = tf.nn.moments(base_style_encoded,
                                                  axes=[1,2])
  # Add epsilon for numerical stability for gradients close to zero
  base_style_std = tf.math.sqrt(base_style_var + eps)

  target_mean, target_var = tf.nn.moments(target_encoded,
                                          axes=[1,2])
  # Add epsilon for numerical stability for gradients close to zero
  target_std = tf.math.sqrt(target_var + eps)

  # Normalise by the size of the batch actually received, so the value is
  # also correct for batches smaller than the configured batch size.
  batch_size = tf.cast(tf.shape(base_style_encoded)[0], base_style_mean.dtype)

  mean_diff = tf.reduce_sum(tf.square(base_style_mean - target_mean)) / batch_size
  std_diff = tf.reduce_sum(tf.square(base_style_std - target_std)) / batch_size
  return mean_diff + std_diff

In [0]:
STYLE_LOSS_WEIGHT = 1

def get_loss(adain_output, base_style_encoded, target_encoded):
  """Total loss: content loss plus the weighted sum of style losses.

  The content term compares `adain_output` with the last entry of
  `target_encoded`; the style term adds up `get_style_loss` over the first
  `num_style_layers` entries of both feature lists.
  """
  # Style loss, accumulated layer by layer
  style_loss = 0
  layer_pairs = zip(base_style_encoded[:num_style_layers],
                    target_encoded[:num_style_layers])
  for style_features, generated_features in layer_pairs:
    style_loss += get_style_loss(style_features, generated_features)

  # Content loss
  content_loss = get_content_loss(adain_output, target_encoded[-1])

  return content_loss + STYLE_LOSS_WEIGHT * style_loss

# Train

In [0]:
def decode_img(img, reverse_channels=False):
  """Decodes preprocessed images.

  Args:
    img: Image tensor scaled so that 1.0 corresponds to pixel value 255.
    reverse_channels: If True, reverse the order of the last axis.

  Returns:
    uint8 tensor with values in [0, 255].
  """
  # perform the inverse of the preprocessing step
  img = img * 255.
  if reverse_channels:
    img = img[..., ::-1]

  # Clip first: values outside [0, 255] do not convert to uint8 reliably,
  # which would show up as speckle artifacts in the logged images.
  img = tf.clip_by_value(img, 0., 255.)
  return tf.cast(img, dtype=tf.uint8)

In [0]:
# Adam with its default hyper-parameters
optimizer = tf.keras.optimizers.Adam()

# Metric averaging the loss values passed to it until it is reset
train_loss = tf.keras.metrics.Mean(name='train_loss')

In [0]:
@tf.function
def train_step(content_img, style_img, step):
  """Run one optimisation step of the decoder.

  Encodes both batches, applies AdaIN to the last encoder outputs, decodes the
  result and updates the decoder variables on the combined loss. Every 50th
  step the decoded image and the loss are written to `summary_writer`.

  Args:
    content_img: Batch of content images.
    style_img: Batch of style images.
    step: Step counter used for the summaries and the logging interval.
  """
  with tf.GradientTape() as tape:
    encoded_content_img = encoder(content_img)
    encoded_style_img = encoder(style_img)
    # No tape.watch() needed: gradients are only taken with respect to the
    # decoder's trainable variables, which the tape tracks automatically.

    # Only the last encoder output is fed to AdaIN
    adain_output = adaptive_instance_normalization(encoded_content_img[-1],
                                        encoded_style_img[-1])

    target_img = decoder(adain_output)

    loss = get_loss(adain_output, encoded_style_img, encoder(target_img))

  # Logging sits outside the tape so its ops are not recorded for gradients.
  if step % 50 == 0:
    with summary_writer.as_default():
      # NOTE(review): channels are reversed here only; confirm the decoder
      # output really uses the opposite channel order to the input images.
      tf.summary.image("Target image", decode_img(
          target_img,
          reverse_channels=True),
          step=step)
      tf.summary.scalar("Loss", loss, step=step)

  gradients = tape.gradient(loss, decoder.trainable_variables)
  optimizer.apply_gradients(zip(gradients, decoder.trainable_variables))

  train_loss(loss)

In [0]:
@tf.function
def test_step(content_image, style_image):
  """Compute the loss for one batch pair without updating any weights.

  Args:
    content_image: Batch of content images.
    style_image: Batch of style images.

  Returns:
    The scalar loss returned by `get_loss`.
  """
  # Use the function's own parameters; the body previously referred to
  # `content_img` / `style_img`, which are not defined in this scope.
  encoded_content_img = encoder(content_image)
  encoded_style_img = encoder(style_image)
  # Only feed the last layer to AdaIN
  t = adaptive_instance_normalization(encoded_content_img[-1],
                                      encoded_style_img[-1])
  target_img = decoder(t)
  # get_loss expects the AdaIN output, the encoded style features and the
  # encoded generated image - the same arguments the training step passes.
  loss = get_loss(t, encoded_style_img, encoder(target_img))
  return loss

In [0]:
EPOCHS = 5
STEPS_PER_EPOCH = NUM_STYLE_BATCHES*NUM_CONTENT_BATCHES
# Interval of the input-image summaries; train_step logs every 50 steps too
SUMMARY_INTERVAL = 50

# Counts steps across all epochs so summary steps never repeat
global_step = 0

for epoch in range(EPOCHS):
  # Reset the metrics at the start of the next epoch
  train_loss.reset_states()
  # Fresh progress bar per epoch, because `step` restarts from zero
  PROGBAR = tf.keras.utils.Progbar(STEPS_PER_EPOCH)

  step = 0
  # Every content batch is paired with every style batch. The datasets yield
  # tensors only, never None, so no per-batch loading check is needed here.
  for content_images in content_dataset:
    for style_images in style_dataset:
      # Using the file writer, log the input images.
      if global_step % SUMMARY_INTERVAL == 0:
        with summary_writer.as_default():
          tf.summary.image("Style data", decode_img(style_images), step=global_step)
          tf.summary.image("Content data", decode_img(content_images), step=global_step)

      start_time = time.perf_counter()
      train_step(content_images, style_images,
                 tf.constant(global_step, dtype=tf.int64))
      step_seconds = time.perf_counter() - start_time

      step += 1
      global_step += 1
      # Report the step time through the progress bar instead of printing a
      # line per step, which would flood the cell output.
      PROGBAR.update(step, values=[('train_step_s', step_seconds)])

  template = 'Epoch {}, Loss: {}'
  print(template.format(epoch+1,
                        train_loss.result()))

Evaluation

In [0]:
%tensorboard --logdir logs