## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print(tf.__version__)

# Abort early unless TensorFlow reports a GPU; tf.test.gpu_device_name()
# returns the name of the first visible GPU device.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print(f'Found GPU at: {device_name}')
else:
    raise SystemError('GPU device not found')

## Create a sampling layer

In [None]:

class Sampling(layers.Layer):
    """Sample latent vectors with the reparameterization trick.

    ``inputs`` is the pair ``(z_mean, z_log_var)`` of equally shaped 2-D
    tensors (batch, latent_dim). Returns ``z_mean + exp(0.5 * z_log_var) * eps``
    with ``eps`` drawn from a standard normal, i.e. a sample from
    N(z_mean, exp(z_log_var)) that stays differentiable w.r.t. both inputs.
    """

    def call(self, inputs):
        mean, log_var = inputs
        dynamic_shape = tf.shape(mean)
        # Noise matches the (batch, latent_dim) shape known only at run time.
        noise = tf.keras.backend.random_normal(
            shape=(dynamic_shape[0], dynamic_shape[1])
        )
        std = tf.exp(0.5 * log_var)
        return mean + std * noise

## Build the encoder

In [None]:
latent_dim = 25

encoder_inputs = keras.Input(shape=(32, 32, 1))

# Three stride-2 downsampling stages (32x32 -> 16x16 -> 8x8 -> 4x4), each
# Conv2D -> BatchNormalization -> ReLU with a growing number of filters.
x = encoder_inputs
for filters in (32, 64, 128):
    x = layers.Conv2D(filters, 4, strides=2, padding="same")(x)
    x = layers.BatchNormalization(axis=-1)(x)
    x = layers.Activation("relu")(x)

x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)

# Two linear heads give the mean and log-variance of the latent Gaussian;
# Sampling draws z from them.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

## Build the decoder

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))

# Project the latent vector to a 4x4x256 feature map.
x = layers.Dense(4 * 4 * 256, activation="relu")(latent_inputs)
x = layers.Reshape((4, 4, 256))(x)

# Three stride-2 upsampling stages (4x4 -> 8x8 -> 16x16 -> 32x32), each
# Conv2DTranspose -> BatchNormalization -> ReLU with a shrinking filter count.
for filters in (128, 64, 32):
    x = layers.Conv2DTranspose(filters, 4, strides=2, padding="same")(x)
    x = layers.BatchNormalization(axis=-1)(x)
    x = layers.Activation("relu")(x)

# 1x1 transposed conv down to one channel; sigmoid keeps outputs in (0, 1).
decoder_outputs = layers.Conv2DTranspose(
    1, 1, strides=1, activation="sigmoid", padding="valid"
)(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## Define the VAE

In [None]:

class VAE(keras.Model):
    """Variational autoencoder wrapping a separately built encoder and decoder.

    The encoder must return ``(z_mean, z_log_var, z)`` for an input batch and
    the decoder must map ``z`` back to an image batch of the input's shape.
    The training objective is the sum of a per-image binary cross-entropy
    reconstruction loss and the KL divergence between N(z_mean, exp(z_log_var))
    and a standard normal, both averaged over the batch.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of each loss term, reported by fit()/evaluate().
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers exposed to Keras so they are reset between epochs/evaluations."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None):
        """Encode ``inputs``, sample ``z`` and return the decoded reconstruction.

        Without this method ``vae(x)`` / ``vae.predict(x)`` cannot run.
        """
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    @staticmethod
    def _unpack_inputs(data):
        """Return only the image batch from whatever Keras passes to a step.

        ``fit(x)`` passes the batch itself; ``fit(x, y)`` or a dataset of
        tuples passes ``(x, ...)``. The VAE only needs ``x``.
        """
        if isinstance(data, tuple):
            return data[0]
        return data

    def _compute_losses(self, data, training):
        """Return ``(total_loss, reconstruction_loss, kl_loss)`` for one batch."""
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)
        # binary_crossentropy reduces the last (channel) axis; summing the two
        # spatial axes yields one value per image, then average over the batch.
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent dimension,
        # summed over the latent axis and averaged over the batch.
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _update_metrics(self, total_loss, reconstruction_loss, kl_loss):
        """Feed one batch's losses into the trackers and return their running means."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one gradient update on a batch and return the running losses."""
        data = self._unpack_inputs(data)
        with tf.GradientTape() as tape:
            # training=True is required: when it is left unset in a custom
            # train_step, Keras runs BatchNormalization in inference mode, so
            # it would never use batch statistics or update its moving averages.
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                data, training=True
            )
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._update_metrics(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Compute the losses on a batch without updating weights (evaluate/validation)."""
        data = self._unpack_inputs(data)
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(
            data, training=False
        )
        return self._update_metrics(total_loss, reconstruction_loss, kl_loss)

## Train the VAE

In [None]:
(x_train, _), (_, _) = keras.datasets.fashion_mnist.load_data()
# Zero-pad each image by 2 pixels per side so it matches the encoder's
# 32x32 input, which three stride-2 stages reduce evenly to 4x4.
x_train = np.pad(x_train, pad_width=((0, 0), (2, 2), (2, 2)), mode="constant")
# Append a channel axis, then scale pixel values into [0, 1].
x_train = x_train[..., np.newaxis].astype("float32") / 255

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(x_train, batch_size=32, epochs=100)

## Save model

In [None]:
#path = "/content/drive/MyDrive/"
#tf.keras.models.save_model(vae.encoder, path+"VAE_encoder_fashion_mnist")
#tf.keras.models.save_model(vae.decoder, path+"VAE_decoder_fashion_mnist")