In [25]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras import layers, Model
from keras.optimizers import Adam
from keras import backend as K
from keras.layers import (
    Input,
    Dense,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Flatten,
    Reshape,
    Conv2DTranspose,
    LeakyReLU,
    BatchNormalization,
    Activation,
    Dropout,
    Rescaling,
)
from tqdm import tqdm
import sonnet as snt

In [26]:
# Spatial size every image is resized to on load, and the batch size shared
# by the loader and the training / validation pipelines.
img_height = img_width = 32
batch_size = 32

In [27]:
# Load every image once, resized to (img_height, img_width) and grouped into
# batches of `batch_size` (label_mode=None: images only, no labels).
#
# shuffle=False keeps the batch order fixed. With the default shuffle=True the
# loader reshuffles locally on every pass over the dataset, so a take()/skip()
# train/val/test split made from it would draw different images each epoch and
# leak images between the splits. Training data is shuffled later, in the
# training input pipeline.
anime_train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "../data/anime_face/images",
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode=None,
    shuffle=False,
)

Found 63565 files belonging to 1 classes.


In [28]:
# The loader already batched the data, so take()/skip() count *batches*, not
# images. Convert the intended image counts (50k train, 10k validation, the
# rest test) into batch counts. Taking 50000 batches would swallow the whole
# dataset and leave validation and test empty.
# NOTE: this cell rebinds `anime_train_ds`; re-run the loading cell before
# re-running it, otherwise the already-split training set is split again.
n_train_batches = 50000 // batch_size
n_val_batches = 10000 // batch_size

anime_train_ds, anime_val_ds = (
    anime_train_ds.take(n_train_batches),
    anime_train_ds.skip(n_train_batches),
)
anime_val_ds, anime_test_ds = (
    anime_val_ds.take(n_val_batches),
    anime_val_ds.skip(n_val_batches),
)

In [31]:
def cast_and_normalise_images(data):
    """Cast images to float32 and divide by 255.

    Pixel values in [0, 255] therefore end up in [0, 1].

    Args:
        data: image tensor (or batch of images) of any numeric dtype.

    Returns:
        A float32 tensor of the same shape, equal to ``data / 255``.
    """
    return tf.cast(data, tf.float32) / 255.0

In [33]:
def residual_stack(h, num_hiddens, num_residual_layers, num_residual_hiddens):
    """Apply ReLU -> 3x3 conv -> ReLU -> 1x1 conv residual blocks, then a final ReLU.

    Each of the ``num_residual_layers`` blocks adds its output back onto ``h``.

    Args:
        h: 4-D feature map whose channel count must equal ``num_hiddens`` so
            the skip connection ``h + h_i`` lines up (presumably NHWC, the
            Keras default data format -- confirm for your setup).
        num_hiddens: output channels of each block's 1x1 conv.
        num_residual_layers: number of residual blocks to apply.
        num_residual_hiddens: channels of each block's intermediate 3x3 conv.

    Returns:
        ReLU of ``h`` after all residual additions; same shape as ``h``.

    NOTE(review): the Conv2D layers are created anew on every call. Each call
    therefore uses freshly initialised weights that no model tracks or trains.
    For a trainable network, build these layers once (e.g. in a Layer/Model
    ``__init__``) and reuse them.
    """
    for i in range(num_residual_layers):
        h_i = tf.nn.relu(h)

        # Keras spells the Sonnet-style output_channels/kernel_shape as
        # filters/kernel_size. "same" padding keeps H and W unchanged so the
        # residual add lines up; Keras defaults to "valid", which would shrink
        # the 3x3 conv output.
        h_i = Conv2D(
            filters=num_residual_hiddens,
            kernel_size=(3, 3),
            padding="same",
            name="res3x3_%d" % i,
        )(h_i)
        h_i = tf.nn.relu(h_i)

        h_i = Conv2D(
            filters=num_hiddens,
            kernel_size=(1, 1),
            padding="same",
            name="res1x1_%d" % i,
        )(h_i)

        h = h + h_i

    return tf.nn.relu(h)

In [None]:
def relu(x):
    """Rectified linear unit, ``max(x, 0)``, applied element-wise.

    - Python / NumPy scalars keep the original behaviour: ``x`` if positive,
      else the int ``0``.
    - NumPy arrays go through ``np.maximum``.
    - Anything else (e.g. a TensorFlow tensor or Keras layer output) goes
      through ``tf.nn.relu``. That keeps it differentiable and avoids the
      truth-value error that ``if x > 0`` raises on multi-element tensors.

    Args:
        x: a scalar, a NumPy array, or a tensor.

    Returns:
        The rectified value, of the same kind as ``x``.
    """
    if isinstance(x, (int, float, np.number)):
        return x if x > 0 else 0
    if isinstance(x, np.ndarray):
        return np.maximum(x, 0)
    return tf.nn.relu(x)

In [None]:
class Encoder(Model):
    """VQ-VAE encoder.

    Structure: two stride-2 convs, a 3x3 conv, then a residual stack.

    With "same" padding the two stride-2 convolutions downsample H and W by 4
    (e.g. 32x32 -> 8x8). The output has ``num_hiddens`` channels.

    Args:
        num_hiddens: channels of the second and third convs and of the residual
            stack; the first conv uses ``num_hiddens // 2``.
        num_residual_layers: number of residual blocks.
        num_residual_hiddens: channels of each residual block's 3x3 conv.
        name: Keras model name.
    """

    def __init__(
        self, num_hiddens, num_residual_layers, num_residual_hiddens, name="encoder"
    ):
        super(Encoder, self).__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Build the layers once so every forward pass reuses, and trains, the
        # same weights. Creating them inside the forward pass would draw fresh
        # random weights on each call. Keras uses filters/kernel_size/strides,
        # not Sonnet's output_channels/kernel_shape/stride, and needs an
        # integer filter count (hence //).
        self._conv_1 = Conv2D(
            filters=num_hiddens // 2,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            name="enc_conv_1",
        )
        self._conv_2 = Conv2D(
            filters=num_hiddens,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            name="enc_conv_2",
        )
        self._conv_3 = Conv2D(
            filters=num_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same",
            name="enc_conv_3",
        )
        # One (3x3, 1x1) conv pair per residual block; "same" padding keeps the
        # spatial size so the skip connection can be added.
        self._res_3x3 = [
            Conv2D(
                filters=num_residual_hiddens,
                kernel_size=(3, 3),
                padding="same",
                name="enc_res3x3_%d" % i,
            )
            for i in range(num_residual_layers)
        ]
        self._res_1x1 = [
            Conv2D(
                filters=num_hiddens,
                kernel_size=(1, 1),
                padding="same",
                name="enc_res1x1_%d" % i,
            )
            for i in range(num_residual_layers)
        ]

    def call(self, x):
        """Encode a batch of images ``x`` into ``num_hiddens``-channel features."""
        h = tf.nn.relu(self._conv_1(x))
        h = tf.nn.relu(self._conv_2(h))
        h = self._conv_3(h)
        return self._residual_stack(h)

    def _residual_stack(self, h):
        # Residual blocks (relu -> 3x3 -> relu -> 1x1, added to h), final ReLU.
        for conv_3x3, conv_1x1 in zip(self._res_3x3, self._res_1x1):
            h_i = conv_1x1(tf.nn.relu(conv_3x3(tf.nn.relu(h))))
            h = h + h_i
        return tf.nn.relu(h)

    def _build(self, x):
        # Backward-compatible alias for the Sonnet-1 style entry point.
        return self.call(x)

In [35]:
class Decoder(Model):
    """VQ-VAE decoder.

    Structure: a 3x3 conv, a residual stack, then two stride-2 transposed
    convs.

    With "same" padding the two stride-2 transposed convolutions upsample H
    and W by 4. This mirrors the encoder's two stride-2 convolutions, so the
    reconstruction has the input's spatial size. The output has 3 channels
    and no final activation.

    Args:
        num_hiddens: channels of the 3x3 conv and of the residual stack; the
            first transposed conv uses ``num_hiddens // 2``.
        num_residual_layers: number of residual blocks.
        num_residual_hiddens: channels of each residual block's 3x3 conv.
        name: Keras model name.
    """

    def __init__(
        self, num_hiddens, num_residual_layers, num_residual_hiddens, name="decoder"
    ):
        super(Decoder, self).__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Build the layers once so every forward pass reuses, and trains, the
        # same weights. Keras uses filters/kernel_size/strides, not Sonnet's
        # output_channels/kernel_shape/stride/output_shape.
        self._conv_in = Conv2D(
            filters=num_hiddens,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="same",
            name="dec_conv_in",
        )
        self._res_3x3 = [
            Conv2D(
                filters=num_residual_hiddens,
                kernel_size=(3, 3),
                padding="same",
                name="dec_res3x3_%d" % i,
            )
            for i in range(num_residual_layers)
        ]
        self._res_1x1 = [
            Conv2D(
                filters=num_hiddens,
                kernel_size=(1, 1),
                padding="same",
                name="dec_res1x1_%d" % i,
            )
            for i in range(num_residual_layers)
        ]
        self._conv_t_1 = Conv2DTranspose(
            filters=num_hiddens // 2,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            name="dec_conv_t_1",
        )
        # The last transposed conv must also use stride 2. With stride 1 the
        # output would be half the input size (16x16 for 32x32 images), and
        # the reconstruction error could not be computed.
        self._conv_t_2 = Conv2DTranspose(
            filters=3,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="same",
            name="dec_conv_t_2",
        )

    def call(self, x):
        """Decode (quantised) latents ``x`` into a 3-channel image batch."""
        h = self._conv_in(x)
        h = self._residual_stack(h)
        h = tf.nn.relu(self._conv_t_1(h))
        return self._conv_t_2(h)

    def _residual_stack(self, h):
        # Residual blocks (relu -> 3x3 -> relu -> 1x1, added to h), final ReLU.
        for conv_3x3, conv_1x1 in zip(self._res_3x3, self._res_1x1):
            h_i = conv_1x1(tf.nn.relu(conv_3x3(tf.nn.relu(h))))
            h = h + h_i
        return tf.nn.relu(h)

    def _build(self, x):
        # Backward-compatible alias for the Sonnet-1 style entry point.
        return self.call(x)

In [39]:
# Number of optimisation steps run by the training loop.
num_training_updates = 15000

# Encoder / decoder widths and depth of their residual stacks.
num_hiddens, num_residual_hiddens, num_residual_layers = 32, 32, 2

# Vector-quantiser settings (passed to snt.nets.VectorQuantizer).
embedding_dim, num_embeddings = 32, 128
commitment_cost = 0.25

In [None]:
# `anime_train_ds` / `anime_val_ds` come from image_dataset_from_directory and
# are already batched Datasets. tf.data.Dataset.from_tensor_slices() cannot
# take a Dataset, and batching again would give 5-D tensors. So normalise,
# unbatch to single images, then shuffle and re-batch.
train_dataset = (
    anime_train_ds.map(cast_and_normalise_images, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .shuffle(10000)
    .repeat(-1)  # infinite stream; the training loop decides when to stop
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# TF2 iterates Datasets eagerly: iter() replaces the TF1-only
# make_one_shot_iterator(), and next() yields a concrete batch.
train_dataset_iterator = iter(train_dataset)
train_dataset_batch = next(train_dataset_iterator)

# One pass over the validation images, keeping the loader's batching. Iterate
# `valid_dataset` itself for a full evaluation: `valid_dataset_iterator` has
# already yielded its first batch into `valid_dataset_batch`.
valid_dataset = anime_val_ds.map(
    cast_and_normalise_images, num_parallel_calls=tf.data.AUTOTUNE
)
valid_dataset_iterator = iter(valid_dataset)
valid_dataset_batch = next(valid_dataset_iterator)

In [None]:
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)

# 1x1 conv projecting the encoder features to the codebook width
# (embedding_dim). Keras' Conv2D takes filters/kernel_size/strides and rejects
# the Sonnet-style output_channels/kernel_shape/stride keywords.
pre_vq_conv1 = Conv2D(
    filters=embedding_dim,
    kernel_size=(1, 1),
    strides=(1, 1),
    name="to_vq",
)
vq_vae = snt.nets.VectorQuantizer(
    embedding_dim=embedding_dim,
    num_embeddings=num_embeddings,
    commitment_cost=commitment_cost,
)

In [None]:
# TF2 executes eagerly and has no placeholders (tf.placeholder is TF1-only).
# Push one real training batch through the encoder instead: this builds the
# encoder / projection weights and lets us check the latent shape.
x = train_dataset_batch
z = pre_vq_conv1(encoder(x))
assert z.shape[-1] == embedding_dim, "latent channels must match the codebook width"

In [None]:
# Eager forward pass and losses for the batch `x`. These are concrete values,
# not graph nodes, and serve as a sanity check before training.
vq_output_train = vq_vae(z, is_training=True)
x_recon = decoder(vq_output_train["quantize"])
recon_error = tf.reduce_mean((x_recon - x) ** 2)
loss = recon_error + vq_output_train["loss"]

vq_output_eval = vq_vae(z, is_training=False)
x_recon_eval = decoder(vq_output_eval["quantize"])

# tf.train.AdamOptimizer, optimizer.minimize() on a graph tensor and
# tf.train.SingularMonitoredSession exist only in TF1. In TF2, gradients are
# taken with tf.GradientTape around the forward pass, so only the optimizer is
# created here (Keras Adam, already imported at the top of the notebook).
optimizer = Adam(learning_rate=1e-3)

In [None]:
train_res_recon_error = []


def train_step(images):
    """Run one optimisation step on a batch and return its reconstruction MSE.

    This is the TF2 replacement for the TF1
    ``sess.run([train_op, recon_error], feed_dict)``. The forward pass is
    recorded on a GradientTape, and ``optimizer`` (created in the model/loss
    cell) applies the gradients to every trainable variable of the encoder,
    the 1x1 projection, the vector quantiser and the decoder. It runs eagerly;
    wrap it in ``tf.function`` for speed once it works.
    """
    with tf.GradientTape() as tape:
        z_batch = pre_vq_conv1(encoder(images))
        vq_out = vq_vae(z_batch, is_training=True)
        recon = decoder(vq_out["quantize"])
        batch_recon_error = tf.reduce_mean((recon - images) ** 2)
        batch_loss = batch_recon_error + vq_out["loss"]

    # Collected after the forward pass so lazily built Keras layers already
    # own their weights.
    variables = (
        list(encoder.trainable_variables)
        + list(pre_vq_conv1.trainable_variables)
        + list(vq_vae.trainable_variables)
        + list(decoder.trainable_variables)
    )
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return batch_recon_error


for i in tqdm(range(num_training_updates)):
    batch = next(train_dataset_iterator)
    train_res_recon_error.append(float(train_step(batch)))

    if i % 100 == 0:
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write("%d iterations" % i)
        tqdm.write("recon_error: %.3f" % np.mean(train_res_recon_error[-100:]))
        tqdm.write("")

In [None]:
# Per-update training reconstruction error. A log-scale y axis keeps the
# later, much smaller values readable next to the large initial ones.
fig, ax = plt.subplots()
ax.plot(train_res_recon_error)
ax.set_yscale("log")
ax.set(
    title="VQ-VAE training reconstruction error",
    xlabel="Training update",
    ylabel="Reconstruction MSE (log scale)",
)
plt.show()