In [25]:
from __future__ import print_function

import os
import subprocess
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tarfile
from keras import layers, Model
from keras.optimizers import Adam
from keras import backend as K
from keras.layers import (
    Input,
    Dense,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Flatten,
    Reshape,
    Conv2DTranspose,
    LeakyReLU,
    BatchNormalization,
    Activation,
    Dropout,
    Rescaling,
)

from six.moves import cPickle
from six.moves import urllib
from six.moves import xrange

In [26]:
# Resolution passed as `image_size` to the dataset loader.
img_height = 256
img_width = 256

# Number of images per batch emitted by the dataset loader.
batch_size = 32

In [27]:
# Read every image below the anime-face folder into a batched tf.data.Dataset.
# `label_mode=None` makes the dataset yield image batches only (no labels).
anime_train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory="../data/anime_face/images",
    label_mode=None,
    batch_size=batch_size,
    image_size=(img_height, img_width),
)

Found 63565 files belonging to 1 classes.


In [28]:
# `anime_train_ds` is already batched, so `take`/`skip` count *batches*, not
# images. Convert the intended image counts (50,000 train / 10,000 validation)
# into whole batches; otherwise `take(50000)` swallows the entire dataset and
# the validation and test splits end up empty.
num_train_batches = 50000 // batch_size
num_val_batches = 10000 // batch_size

# NOTE: this cell rebinds `anime_train_ds`, so re-running it without first
# re-running the loading cell splits the already-reduced training split again.
# NOTE(review): the loader was created without `shuffle=False`/`seed`; confirm
# that the split boundaries are stable enough across iterations for your use.
anime_val_ds = anime_train_ds.skip(num_train_batches)
anime_train_ds = anime_train_ds.take(num_train_batches)
anime_test_ds = anime_val_ds.skip(num_val_batches)
anime_val_ds = anime_val_ds.take(num_val_batches)

In [29]:
def unpickle(filename):
    """Load and return the object stored in the pickle file `filename`.

    The file is opened in binary mode and decoded with ``encoding="latin1"``.

    Caution: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(filename, "rb") as pickle_file:
        payload = cPickle.load(pickle_file, encoding="latin1")
    return payload

In [30]:
def reshape_flattened_image_batch(flat_image_batch):
    """Turn a batch of flattened images into an NHWC image batch.

    The input is reshaped to (N, 3, 32, 32) (channels first) and the axes are
    then permuted so the channel axis comes last, giving shape (N, 32, 32, 3).
    """
    channels_first = flat_image_batch.reshape(-1, 3, 32, 32)
    # NCHW -> NHWC
    return channels_first.transpose(0, 2, 3, 1)

In [31]:
def cast_and_normalise_images(data_dict):
    """Convert ``data_dict["images"]`` to float32 and rescale it.

    The images are cast to ``tf.float32``, divided by 255 and shifted by -0.5,
    so a value of 0 maps to -0.5 and a value of 255 maps to 0.5.

    The mapping passed in is updated in place and the same object is returned.
    """
    as_float = tf.cast(data_dict["images"], tf.float32)
    data_dict["images"] = (as_float / 255.0) - 0.5
    return data_dict

In [32]:
def _downscale_to_model_input(images):
    """Resize a batch of images to the 32x32 resolution the model is built for."""
    return tf.image.resize(images, (32, 32))


# Resize every split, not only the training one: otherwise validation and test
# images stay at the 256x256 loading resolution and no longer match the inputs
# the model is trained on.
anime_train_ds = anime_train_ds.map(_downscale_to_model_input)
anime_val_ds = anime_val_ds.map(_downscale_to_model_input)
anime_test_ds = anime_test_ds.map(_downscale_to_model_input)

In [33]:
def residual_stack(h, num_hiddens, num_residual_layers, num_residual_hiddens):
  """Apply a stack of residual blocks followed by a final ReLU.

  Each block computes ReLU -> 3x3 conv -> ReLU -> 1x1 conv and adds the result
  to its input.

  Args:
    h: Input feature map. It is expected to have `num_hiddens` channels so the
      residual addition is shape-compatible.
    num_hiddens: Number of output channels of every block (1x1 conv filters).
    num_residual_layers: Number of residual blocks to apply.
    num_residual_hiddens: Number of filters of the 3x3 conv inside each block.

  Returns:
    The feature map after all residual blocks and a final ReLU.

  Note:
    New `Conv2D` layers are constructed on every invocation, so each call uses
    freshly initialised weights. Call it once per stack (or hold the layers in
    a model's `__init__`) if the weights must be reused and trained.
  """
  for i in range(num_residual_layers):
    h_i = tf.nn.relu(h)

    # Keras spells these arguments filters/kernel_size/strides; the previous
    # output_channels/kernel_shape/stride names are rejected by keras Conv2D.
    # padding="same" keeps the spatial size so `h + h_i` below is valid.
    h_i = Conv2D(
        filters=num_residual_hiddens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="same",
        name="res3x3_%d" % i)(h_i)
    h_i = tf.nn.relu(h_i)

    h_i = Conv2D(
        filters=num_hiddens,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="same",
        name="res1x1_%d" % i)(h_i)
    # Out-of-place add: never mutates the tensor/array passed in by the caller.
    h = h + h_i
  return tf.nn.relu(h)

In [34]:
class Encoder(Model):
  """Convolutional encoder: two stride-2 convs, one 3x3 conv, residual stack.

  With "same" padding each stride-2 convolution halves height and width, so
  the output feature map is 4x smaller than the input along each spatial axis
  and has `num_hiddens` channels.

  Args:
    num_hiddens: Channels of the main trunk (the first conv uses half of it).
    num_residual_layers: Number of residual blocks after the trunk.
    num_residual_hiddens: Filters of the 3x3 conv inside each residual block.
    name: Model name.
  """

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name='encoder'):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    # Layers are created once here (not inside the forward pass) so their
    # weights are tracked by the model and reused on every call.
    # Keras argument names are filters/kernel_size/strides; `//` keeps the
    # filter count an integer.
    self._enc_1 = Conv2D(
        filters=num_hiddens // 2,
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same",
        name="enc_1")
    self._enc_2 = Conv2D(
        filters=num_hiddens,
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same",
        name="enc_2")
    self._enc_3 = Conv2D(
        filters=num_hiddens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="same",
        name="enc_3")

    # Residual blocks: one 3x3 and one 1x1 convolution per block.
    self._res3x3 = [
        Conv2D(
            filters=num_residual_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same",
            name="res3x3_%d" % i)
        for i in range(num_residual_layers)
    ]
    self._res1x1 = [
        Conv2D(
            filters=num_hiddens,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="same",
            name="res1x1_%d" % i)
        for i in range(num_residual_layers)
    ]

  def call(self, x):
    """Encode the image batch `x` into a feature map with `num_hiddens` channels."""
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = self._enc_3(h)

    # Residual stack: h <- h + conv1x1(relu(conv3x3(relu(h)))), then final ReLU.
    for conv3x3, conv1x1 in zip(self._res3x3, self._res1x1):
      h_i = conv3x3(tf.nn.relu(h))
      h_i = conv1x1(tf.nn.relu(h_i))
      h = h + h_i
    return tf.nn.relu(h)

  def _build(self, x):
    """Alias of `call`, kept for code that still uses the old method name."""
    return self.call(x)

In [35]:
class Decoder(Model):
  """Convolutional decoder: 3x3 conv, residual stack, two stride-2 transposed convs.

  With "same" padding each stride-2 transposed convolution doubles height and
  width, so the reconstruction is 4x larger than the input along each spatial
  axis and has 3 channels.

  Args:
    num_hiddens: Channels of the main trunk (the first upsampling layer uses
      half of it).
    num_residual_layers: Number of residual blocks after the first conv.
    num_residual_hiddens: Filters of the 3x3 conv inside each residual block.
    name: Model name.
  """

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name='decoder'):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    # Layers are created once here (not inside the forward pass) so their
    # weights are tracked by the model and reused on every call.
    # Keras argument names are filters/kernel_size/strides, and Keras
    # Conv2DTranspose has no `output_shape` argument.
    self._dec_1 = Conv2D(
        filters=num_hiddens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="same",
        name="dec_1")

    # Residual blocks: one 3x3 and one 1x1 convolution per block.
    self._res3x3 = [
        Conv2D(
            filters=num_residual_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same",
            name="res3x3_%d" % i)
        for i in range(num_residual_layers)
    ]
    self._res1x1 = [
        Conv2D(
            filters=num_hiddens,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="same",
            name="res1x1_%d" % i)
        for i in range(num_residual_layers)
    ]

    self._dec_2 = Conv2DTranspose(
        filters=num_hiddens // 2,
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same",
        name="dec_2")
    self._dec_3 = Conv2DTranspose(
        filters=3,
        kernel_size=(4, 4),
        strides=(2, 2),
        padding="same",
        name="dec_3")

  def call(self, x):
    """Decode the feature map `x` into a 3-channel reconstruction."""
    h = self._dec_1(x)

    # Residual stack: h <- h + conv1x1(relu(conv3x3(relu(h)))), then final ReLU.
    for conv3x3, conv1x1 in zip(self._res3x3, self._res1x1):
      h_i = conv3x3(tf.nn.relu(h))
      h_i = conv1x1(tf.nn.relu(h_i))
      h = h + h_i
    h = tf.nn.relu(h)

    h = tf.nn.relu(self._dec_2(h))
    # No activation on the output layer.
    x_recon = self._dec_3(h)
    return x_recon

  def _build(self, x):
    """Alias of `call`, kept for code that still uses the old method name."""
    return self.call(x)

In [37]:
import sonnet as snt


In [39]:
# --- Data ---------------------------------------------------------------
batch_size = 32   # images per training batch
image_size = 32   # side length (pixels) of the square model input

# --- Optimisation -------------------------------------------------------
num_training_updates = 50000   # number of iterations of the training loop
learning_rate = 3e-4           # step size handed to the optimizer

# --- Encoder / decoder --------------------------------------------------
num_hiddens = 128            # channels of the encoder/decoder trunk
num_residual_layers = 2      # residual blocks per stack
num_residual_hiddens = 32    # filters of the 3x3 conv inside a residual block

# --- Vector quantiser ---------------------------------------------------
embedding_dim = 64        # channels produced by the pre-VQ 1x1 conv
num_embeddings = 512      # passed to the vector quantiser as `num_embeddings`
commitment_cost = 0.25    # passed to the vector quantiser as `commitment_cost`

# Not read by any of the cells shown in this part of the notebook.
vq_use_ema = False
decay = 0.99

In [None]:
train_dataset_iterator = (
    tf.data.Dataset.from_tensor_slices(anime_train_ds)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)
    .batch(batch_size)).make_one_shot_iterator()

train_dataset_batch = train_dataset_iterator.get_next()


valid_dataset_iterator = (
    tf.data.Dataset.from_tensor_slices(anime_val_ds)
    .map(cast_and_normalise_images)
    .repeat(1)
    .batch(batch_size)).make_initializable_iterator()

train_dataset_batch = train_dataset_iterator.get_next()
valid_dataset_batch = valid_dataset_iterator.get_next()

valid_dataset_batch = valid_dataset_iterator.get_next()

In [None]:
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = Conv2D(
    output_channels=embedding_dim, kernel_shape=(1, 1), stride=(1, 1), name="to_vq"
)

vq_vae = snt.nets.VectorQuantizer(
    embedding_dim=embedding_dim,
    num_embeddings=num_embeddings,
    commitment_cost=commitment_cost,
)

In [None]:
x = tf.placeholder(tf.float32, shape=(None, image_size, image_size, 3))
z = pre_vq_conv1(encoder(x))

In [None]:
vq_output_train = vq_vae(z, is_training=True)
x_recon = decoder(vq_output_train["quantize"])
recon_error = tf.reduce_mean((x_recon - x) ** 2)
loss = recon_error + vq_output_train["loss"]

vq_output_eval = vq_vae(z, is_training=False)
x_recon_eval = decoder(vq_output_eval["quantize"])

perplexity = vq_output_train["perplexity"]

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss)
sess = tf.train.SingularMonitoredSession()

train_res_recon_error = []
train_res_perplexity = []

for i in xrange(num_training_updates):
    feed_dict = {x: sess.run(train_dataset_batch)}
    results = sess.run([train_op, recon_error, perplexity], feed_dict)
    train_res_recon_error.append(results[1])
    train_res_perplexity.append(results[2])

    if i % 100 == 0:
        print("%d iterations" % i)
        print("recon_error: %.3f" % np.mean(train_res_recon_error[-100:]))
        print("perplexity: %.3f" % np.mean(train_res_perplexity[-100:]))
        print()

In [None]:
# Training curves: reconstruction error (log scale) and quantiser perplexity.
f, (ax_recon, ax_perplexity) = plt.subplots(1, 2, figsize=(16, 8))

ax_recon.plot(train_res_recon_error)
ax_recon.set_yscale("log")
# The logged value is a plain mean squared error (it is not divided by the
# data variance), so it is labelled MSE rather than NMSE.
ax_recon.set_title("MSE.")
ax_recon.set_xlabel("Training update")
ax_recon.set_ylabel("Reconstruction error")

ax_perplexity.plot(train_res_perplexity)
ax_perplexity.set_title("Average codebook usage (perplexity).")
ax_perplexity.set_xlabel("Training update")
ax_perplexity.set_ylabel("Perplexity");