# Training the VAE

## Setup

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
# import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras
from keras import layers, ops
from data import vae_data  # hurray for modularity!

## Create a custom sampling layer and VAE model

In [None]:
class Sampling(layers.Layer):
    """Sample a latent vector z from (z_mean, z_log_var) and register the KL loss.

    Applies the reparameterization trick,
    ``z = z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` is drawn
    from a standard normal distribution with the same shape as ``z_mean``.

    Args:
        kl_loss_factor: Multiplier applied to the KL divergence term before it
            is registered through ``add_loss``.
        seed: Seed passed to the layer's ``keras.random.SeedGenerator``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, kl_loss_factor=1, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(seed=seed)
        self.kl_loss_factor = kl_loss_factor

    def call(self, inputs):
        """Register the KL loss and return a sampled latent vector.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of tensors with equal shapes.

        Returns:
            A tensor shaped like ``z_mean`` holding the sampled latent values.
        """
        z_mean, z_log_var = inputs

        # Element-wise KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1).
        kl_per_unit = -0.5 * (
            1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var)
        )
        # Reduce to a scalar before registering it: sum over the latent axis,
        # then average over the remaining axes, so the term does not grow with
        # the batch size.
        kl_loss = ops.mean(ops.sum(kl_per_unit, axis=-1))
        self.add_loss(kl_loss * self.kl_loss_factor)

        # Reparameterization trick: noise has exactly the shape of z_mean.
        epsilon = keras.random.normal(
            shape=ops.shape(z_mean), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


class VAE(keras.Model):
    """End-to-end variational autoencoder that chains an encoder into a decoder."""

    def __init__(self, encoder, decoder, **kwargs):
        # TODO: revisit -- wiring the functional graph by hand before calling
        # the base constructor feels hacky.
        # Reuse the encoder's input shape minus its leading dimension.
        feature_shape = encoder.input.shape[1:]
        model_input = keras.Input(shape=feature_shape, name="vae_inputs")
        latent = encoder(model_input)
        reconstruction = decoder(latent)
        super().__init__(inputs=model_input, outputs=reconstruction, **kwargs)
        # Keep handles on both halves so they can be used on their own.
        self.encoder = encoder
        self.decoder = decoder

## Build the models

In [None]:
# Shape of one input sample, as passed to keras.Input for the encoder.
input_shape: tuple = (6,)
# Number of latent dimensions produced by the encoder heads.
latent_dim: int = 2

In [None]:
# Encoder: one hidden layer feeding two parallel heads (latent mean and
# latent log-variance), followed by the sampling layer.
encoder_inputs = keras.Input(name="encoder_inputs", shape=input_shape)

x = layers.Dense(units=5, activation="relu")(encoder_inputs)
z_mean = layers.Dense(units=latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(units=latent_dim, name="z_log_var")(x)

# The sampling step arguably belongs to the decoder; it stays here for now
# while the nan loss is being tracked down.
sampling = Sampling(kl_loss_factor=1, seed=489, name="encoder_outputs")
encoder_outputs = sampling([z_mean, z_log_var])

encoder = keras.Model(
    name="encoder",
    inputs=encoder_inputs,
    outputs=encoder_outputs,
)
encoder.summary()

In [None]:
# Decoder: maps a latent vector back to a vector as wide as the encoder input.
decoder_inputs = keras.Input(shape=(latent_dim,), name="decoder_inputs")

x = layers.Dense(5, activation="relu")(decoder_inputs)
# Derive the output width from input_shape instead of hard-coding 6, so the
# reconstruction keeps matching the encoder input if input_shape changes.
decoder_outputs = layers.Dense(input_shape[0], name="decoder_outputs")(x)

decoder = keras.Model(inputs=decoder_inputs, outputs=decoder_outputs, name="decoder")
decoder.summary()

In [None]:
# Chain the encoder and decoder into the single model that gets trained.
vae = VAE(encoder, decoder, name="vae")
vae.summary()

## Train the VAE

In [None]:
# vae_data unpacks into exactly three items, used below for fitting,
# validation and final evaluation respectively.
(train, valid, test) = vae_data

In [None]:
# Adam with a small learning rate and gradient-norm clipping.
adam = keras.optimizers.Adam(clipnorm=1.0, learning_rate=1e-5)
vae.compile(optimizer=adam, loss="mean_squared_error")

# NOTE: the reported loss is still nan with these settings.
history = vae.fit(train, validation_data=valid, epochs=30)

In [None]:
# Evaluate on the test split; return_dict=True requests the results as a dict.
test_stats = vae.evaluate(x=test, return_dict=True)