In [46]:
import os
import time
import cv2
import tensorflow as tf
from keras import layers, Model
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.layers import (
    Input,
    Dense,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Flatten,
    Reshape,
    Conv2DTranspose,
    LeakyReLU,
    BatchNormalization,
    Activation,
    Dropout,
    Rescaling,
    Concatenate,
)
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.losses import MeanSquaredError

In [32]:
# Report which GPUs TensorFlow can see, then reset any Keras graph/session state
# left over from earlier runs of this kernel.
gpu_devices = tf.config.list_physical_devices("GPU")
print(gpu_devices)
K.clear_session()

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [33]:
# Restrict TensorFlow to the first GPU.
# NOTE: CUDA_VISIBLE_DEVICES is consulted when the CUDA devices are enumerated;
# TensorFlow was already imported and queried for GPUs in an earlier cell, so this
# assignment alone does not change what this kernel sees. It is kept so that any
# subprocess launched later inherits it.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# In-process equivalent. It must run before the GPU runtime is initialized
# (before any op executes); otherwise TensorFlow raises RuntimeError, which is
# reported rather than silently ignored.
_gpus = tf.config.list_physical_devices("GPU")
if _gpus:
    try:
        tf.config.set_visible_devices(_gpus[0], "GPU")
    except RuntimeError as err:
        print(f"Could not restrict visible GPUs (runtime already initialized): {err}")

In [34]:
# Square input resolution fed to both encoder and decoder, and images per batch.
img_height = img_width = 256
batch_size = 32

In [35]:
# Unlabelled anime faces, resized to the model resolution and batched.
# label_mode=None makes the dataset yield image batches only.
anime_train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory="../data/anime_face/images",
    label_mode=None,
    batch_size=batch_size,
    image_size=(img_height, img_width),
)

Found 63565 files belonging to 1 classes.


In [36]:
# The cartoon set is split across five sub-directories (0..4); load each one as
# an unlabelled, batched image dataset and chain them end to end.
cartoon_parts = [
    tf.keras.preprocessing.image_dataset_from_directory(
        f"../data/cartoonset100k_jpg/{i}",
        image_size=(img_height, img_width),
        batch_size=batch_size,
        label_mode=None,
    )
    for i in range(5)
]

cartoon_train_ds = cartoon_parts[0]
for part in cartoon_parts[1:]:
    cartoon_train_ds = cartoon_train_ds.concatenate(part)

Found 10000 files belonging to 1 classes.
Found 10000 files belonging to 1 classes.
Found 10000 files belonging to 1 classes.
Found 10000 files belonging to 1 classes.
Found 10000 files belonging to 1 classes.


In [37]:
# One-hot domain labels for a full batch: class 0 = anime, class 1 = cartoon.
anime_label, cartoon_label = (
    to_categorical([class_index] * batch_size, num_classes=2)
    for class_index in (0, 1)
)

In [38]:
def _to_cvae_inputs(images, class_index, num_classes=2):
    """Pair a batch of images with per-example one-hot domain labels.

    ``image_dataset_from_directory`` yields float32 pixels in [0, 255]. They are
    rescaled to [0, 1] to match the decoder's sigmoid output, which the
    reconstruction loss compares against.

    The label batch is sized from the actual image batch. The final batch of each
    dataset is smaller than ``batch_size``, and a fixed-size label array would not
    line up with it.

    Returns ``((images, labels),)``, a 1-tuple, so Keras treats both tensors as
    model inputs. The CVAE takes ``[image, label]`` and its loss is attached via
    ``add_loss``, so there is no separate target.
    """
    images = tf.cast(images, tf.float32) / 255.0
    n_images = tf.shape(images)[0]
    labels = tf.one_hot(tf.fill([n_images], class_index), depth=num_classes)
    return ((images, labels),)


# Class 0 = anime, class 1 = cartoon (same encoding as the to_categorical labels above).
anime_train_ds = anime_train_ds.map(lambda images: _to_cvae_inputs(images, 0))
cartoon_train_ds = cartoon_train_ds.map(lambda images: _to_cvae_inputs(images, 1))

In [39]:
# Draw each element (a whole batch) from the anime or cartoon dataset with equal
# probability. prefetch overlaps preparing the next batch with training on the
# current one, so the accelerator is not left waiting on image decoding.
combined_ds = tf.data.Dataset.sample_from_datasets(
    [anime_train_ds, cartoon_train_ds], weights=[0.5, 0.5]
).prefetch(tf.data.AUTOTUNE)

In [40]:
# Model hyperparameters:
#   label_dim  - length of the one-hot domain condition (anime / cartoon)
#   latent_dim - size of the latent vector z
#   filters    - base conv width; deeper layers use multiples of it
label_dim, latent_dim, filters = 2, 128, 32

In [41]:
def encoder(input_shape, label_dim, filters, latent_dim):
    """Build the conditional encoder q(z | x, y).

    A Dense layer projects the one-hot label ``y`` to an image-sized tensor. That
    tensor is stacked onto the image along the channel axis. Four stride-2 ReLU
    convolutions then downsample the result, and two Dense heads predict the
    posterior mean and log-variance.

    Args:
        input_shape: Image shape; its first three entries are used as
            (height, width, channels) for the label projection.
        label_dim: Length of the one-hot condition vector.
        filters: Base filter count; the four conv layers use 1x, 2x, 4x and 8x it.
        latent_dim: Size of each of the two outputs.

    Returns:
        ``Model([x, y], [z_mean, z_log_var], name="encoder")``.
    """
    height, width, channels = input_shape[0], input_shape[1], input_shape[2]

    image_in = Input(shape=input_shape)
    label_in = Input(shape=(label_dim,))

    label_map = Dense(height * width * channels)(label_in)
    label_map = Reshape((height, width, channels))(label_map)
    features = Concatenate(axis=-1)([image_in, label_map])

    for multiplier in (1, 2, 4, 8):
        features = Conv2D(
            filters * multiplier,
            kernel_size=3,
            strides=2,
            activation="relu",
            padding="same",
        )(features)

    flat = Flatten()(features)
    z_mean = Dense(latent_dim)(flat)
    z_log_var = Dense(latent_dim)(flat)

    return Model([image_in, label_in], [z_mean, z_log_var], name="encoder")

In [42]:
def decoder(input_shape, label_dim, filters, latent_dim):
    """Build the conditional decoder p(x | z, y).

    The latent code ``z`` is concatenated with the one-hot label ``y`` and
    projected to a feature map at 1/16 of the target resolution. Four stride-2
    transposed convolutions then upsample it back to ``input_shape``.

    Args:
        input_shape: ``(height, width, channels)`` of the images to produce.
            Height and width must be divisible by 16, because there are four 2x
            upsampling steps.
        label_dim: Length of the one-hot condition vector.
        filters: Base filter count; the first feature map has ``filters * 8`` channels.
        latent_dim: Size of the latent vector ``z``.

    Returns:
        ``Model([z, y], x_decoded, name="decoder")``, whose output passes through a
        sigmoid and so lies in [0, 1].

    Raises:
        ValueError: If height or width is not divisible by 16.
    """
    height, width, channels = input_shape[0], input_shape[1], input_shape[2]
    upsample_factor = 2 ** 4  # four stride-2 Conv2DTranspose layers below
    if height % upsample_factor or width % upsample_factor:
        raise ValueError(
            f"decoder input_shape {tuple(input_shape)} must have height and width "
            f"divisible by {upsample_factor}"
        )
    # Previously hard-coded as 16x16, which only matched 256x256 images.
    base_h, base_w = height // upsample_factor, width // upsample_factor

    z = Input(shape=(latent_dim,))
    y = Input(shape=(label_dim,))

    inputs = Concatenate()([z, y])

    hidden = Dense(base_h * base_w * filters * 8, activation="relu")(inputs)
    reshaped = Reshape((base_h, base_w, filters * 8))(hidden)

    deconv1 = Conv2DTranspose(filters * 8, kernel_size=3, strides=2, activation="relu", padding="same")(reshaped)
    deconv2 = Conv2DTranspose(filters * 4, kernel_size=3, strides=2, activation="relu", padding="same")(deconv1)
    deconv3 = Conv2DTranspose(filters * 2, kernel_size=3, strides=2, activation="relu", padding="same")(deconv2)
    deconv4 = Conv2DTranspose(filters, kernel_size=3, strides=2, activation="relu", padding="same")(deconv3)
    # Output channel count follows input_shape (previously fixed at 3).
    x_decoded = Conv2DTranspose(channels, kernel_size=3, activation="sigmoid", padding="same")(deconv4)

    return Model([z, y], x_decoded, name="decoder")

In [43]:
def sampling(args):
    """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of tensors with the same shape.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon ~ N(0, I)``
        has the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from z_mean itself rather than the global ``latent_dim``.
    # The function is then correct for any latent size and does not depend on a
    # module-level variable defined in another cell.
    epsilon = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [44]:
image_shape = (img_height, img_width, 3)

# Sub-networks: q(z | x, y) and p(x | z, y).
enc = encoder(image_shape, label_dim, filters, latent_dim)
dec = decoder(image_shape, label_dim, filters, latent_dim)

# End-to-end graph: image + label -> posterior params -> sampled z -> reconstruction.
# x, z_mean, z_log_var and x_decoded are reused by the loss cell below.
x = Input(shape=image_shape)
y = Input(shape=(label_dim,))
z_mean, z_log_var = enc([x, y])
z = layers.Lambda(sampling)([z_mean, z_log_var])
x_decoded = dec([z, y])

cvae = Model(inputs=[x, y], outputs=x_decoded, name="vae")

In [47]:
# Negative ELBO, computed per example and then averaged over the batch:
#   reconstruction = sum over pixels and channels of the squared error
#   KL(q(z|x,y) || N(0, I)) = -0.5 * sum over latent dims of (1 + log_var - mu^2 - exp(log_var))
# Averaging both terms over elements, as before, divides the reconstruction term
# by H*W*C but the KL term only by latent_dim. That over-weights KL by roughly
# H*W*C / latent_dim (about 1536x here), which pushes the model toward ignoring z.
squared_error = tf.square(x - x_decoded)
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(squared_error, axis=[1, 2, 3]))
kl_loss = tf.reduce_mean(
    -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
)
# Both terms are already scalars; no further reduction is needed.
cvae_loss = reconstruction_loss + kl_loss

cvae.add_loss(cvae_loss)
cvae.compile(optimizer=Adam(learning_rate=0.001))

In [None]:
# Train on the mixed anime/cartoon stream. The loss comes from add_loss, so the
# dataset supplies inputs only.
cvae.fit(x=combined_ds, epochs=10)

In [None]:
# Generate an anime-style face. z is drawn from the N(0, I) prior that the KL
# term regularizes the encoder toward; np.random.rand would sample U[0, 1),
# which is off that prior. The decoder is conditioned on class 0 (anime).
# Distinct names keep the latent, label and image apart, and leave the earlier
# batch-sized anime_label untouched.
anime_z = np.random.normal(size=(1, latent_dim))
anime_onehot = to_categorical([0], num_classes=2)
anime_sample = dec.predict([anime_z, anime_onehot])

fig, ax = plt.subplots()
ax.imshow(anime_sample[0])
ax.set_title("Decoded sample (anime label)")
ax.axis("off")
plt.show()

In [None]:
# Generate a cartoon-style face. z is drawn from the N(0, I) prior that the KL
# term regularizes the encoder toward; np.random.rand would sample U[0, 1),
# which is off that prior. The decoder is conditioned on class 1 (cartoon).
# Distinct names keep the latent, label and image apart, and leave the earlier
# batch-sized cartoon_label untouched.
cartoon_z = np.random.normal(size=(1, latent_dim))
cartoon_onehot = to_categorical([1], num_classes=2)
cartoon_sample = dec.predict([cartoon_z, cartoon_onehot])

fig, ax = plt.subplots()
ax.imshow(cartoon_sample[0])
ax.set_title("Decoded sample (cartoon label)")
ax.axis("off")
plt.show()