In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np 

In [None]:
# Load the training images. The decoder ends in a sigmoid (outputs in [0, 1]), so
# 8-bit pixel data is rescaled to float32 in [0, 1]; other dtypes are left as stored.
# NOTE(review): if the file holds floats in 0-255, they are NOT rescaled here — confirm the range.
images = np.load('anime.npy')
if images.dtype == np.uint8:
    images = images.astype(np.float32) / 255.0
# The encoder's Input is (64, 64, 3), so the stacked array must be (N, 64, 64, 3).
assert images.ndim == 4 and images.shape[1:] == (64, 64, 3), f"unexpected image array shape {images.shape}"

In [None]:
class Sampling(layers.Layer):
    """Reparameterization layer: draws the latent vector ``z`` for an image.

    Given ``(z_mean, z_log_var)``, returns
    ``z_mean + exp(z_log_var / 2) * epsilon`` with ``epsilon`` standard-normal
    noise of shape ``(batch, dim)`` taken from ``z_mean``.
    Assumes ``z_mean`` is 2-D ``(batch, dim)``, as produced by the encoder's Dense heads.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        # exp(0.5 * log variance) is the standard deviation.
        std_dev = tf.exp(log_var * 0.5)
        return mean + std_dev * noise
    
    
class VAE(keras.Model):
    """Variational autoencoder wrapping an encoder/decoder pair with a custom train step.

    Args:
        encoder: model mapping an image batch to ``[z_mean, z_log_var, z]``.
        decoder: model mapping ``z`` back to an image batch shaped like the encoder input.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Reconstruct ``inputs`` by decoding the sampled latent ``z`` (enables ``vae(x)``/``predict``)."""
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step on a batch and return the loss terms.

        Both loss terms are per-sample sums averaged over the batch, so they are
        on the same scale:

        - reconstruction: squared error summed over every non-batch axis
          (for 64x64x3 inputs this equals ``mean(mse) * 64 * 64 * 3``);
        - KL: ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`` summed over latent dims.

        Returns:
            dict with ``"loss"``, ``"reconstruction_loss"`` and ``"kl_loss"`` scalars.
        """
        if isinstance(data, tuple):
            # fit(x, y) yields (x, y); the VAE only needs x since it reconstructs its input.
            data = data[0]
        with tf.GradientTape() as tape:
            # Use the sub-models owned by this instance, not same-named notebook globals.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # keras mse casts y_true to y_pred's dtype; do the same for the manual sum.
            target = tf.cast(data, reconstruction.dtype)
            squared_error = tf.square(target - reconstruction)
            per_sample_error = tf.reshape(squared_error, [tf.shape(squared_error)[0], -1])
            reconstruction_loss = tf.reduce_mean(tf.reduce_sum(per_sample_error, axis=1))
            kl_terms = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            # Sum over latent dimensions (per-sample KL), then average over the batch.
            kl_loss = tf.reduce_mean(-0.5 * tf.reduce_sum(kl_terms, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

In [None]:
# Size of the latent code z: 32 * 32 * 3 = 3072.
latent_dim = 32 * 32 * 3

encoder_inputs = keras.Input(shape=(64, 64, 3))
# Four stride-2 "same" convolutions: spatial size 64 -> 32 -> 16 -> 8 -> 4.
x = encoder_inputs
for n_filters, kernel_size in ((32, 8), (64, 6), (64, 4), (128, 4)):
    x = layers.Conv2D(n_filters, kernel_size, activation="relu", strides=2, padding="same")(x)

x = layers.Flatten()(x)
x = layers.Dense(latent_dim, activation="relu")(x)
# Mean and log-variance heads; Sampling combines them into a sampled z.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))
# The Dense output is reshaped into this seed tensor, so its width must be derived from
# the seed shape (not from latent_dim) — otherwise any other latent_dim breaks the Reshape.
decoder_seed_shape = (32, 32, 3)
x = layers.Dense(int(np.prod(decoder_seed_shape)), activation="relu")(latent_inputs)
x = layers.Reshape(decoder_seed_shape)(x)
# Two stride-2 transposed convs upsample 32x32 -> 64x64 -> 128x128 ...
x = layers.Conv2DTranspose(128, 4, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(64, 8, activation="relu", strides=2, padding="same")(x)
# ... and a stride-2 Conv2D brings it back to 64x64 to match the encoder input size.
x = layers.Conv2D(32, 6, activation="relu", strides=2, padding="same")(x)
# Sigmoid keeps every output channel in [0, 1].
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
# Wrap the encoder/decoder above in the custom-train-step VAE and train it on the images.
optimizer = keras.optimizers.Adam()
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=optimizer)
vae.fit(images, batch_size=128, epochs=30)

In [None]:
# Reconstruct one image. Decode z_mean (the encoder's first output) rather than the
# randomly sampled z (third output) so the figure is identical on every re-run.
sample_image = images[3].reshape(1, 64, 64, 3)
encoded = encoder.predict(sample_image)
decoded = decoder.predict(encoded[0])
plt.imshow(decoded.reshape(64, 64, 3))
plt.title("VAE reconstruction of images[3] (decoded from z_mean)");

In [None]:
# Show the source image for comparison with its reconstruction.
plt.imshow(images[3])
plt.title("Original images[3]");

In [None]:
# Persist both halves as HDF5 files, encoder first.
# NOTE(review): the encoder contains the custom Sampling layer, so reloading it presumably
# needs custom_objects={"Sampling": Sampling} — verify against the Keras version in use.
for trained_model, h5_path in ((encoder, 'vaeenc.h5'), (decoder, 'vaedec.h5')):
    trained_model.save(h5_path)