In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import mnist

2024-05-31 23:36:16.229352: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [22]:
def VAE(input_shape=(784,),
        n_components_encoder=2048,
        n_components_decoder=2048,
        n_hidden=2,
        debug=False):
    """Build a fully connected variational autoencoder and its loss.

    The encoder and decoder are each three softplus ``Dense`` layers. The
    encoder emits the mean and log standard deviation of a diagonal Gaussian
    posterior, a latent code is drawn with the reparameterization trick, and
    the decoder maps it back to per-feature values in (0, 1).

    Parameters
    ----------
    input_shape : tuple
        Per-example shape passed to ``tf.keras.Input``; the last entry is
        the number of features. Ignored when ``debug`` is True.
    n_components_encoder : int
        Width of each encoder hidden layer.
    n_components_decoder : int
        Width of each decoder hidden layer.
    n_hidden : int
        Dimensionality of the latent code ``z``.
    debug : bool
        If True, feed a fixed all-zero ``(50, 784)`` ``tf.Variable`` instead
        of a symbolic input.

    Returns
    -------
    dict
        ``'cost'`` (scalar mean of reconstruction loss plus KL divergence),
        ``'x'`` (the input), ``'z'`` (sampled latent code) and ``'y'``
        (reconstruction).
    """
    # Input placeholder
    if debug:
        input_shape = (50, 784)
        x = tf.Variable(np.zeros(input_shape, dtype=np.float32))
    else:
        x = tf.keras.Input(shape=input_shape)

    activation = tf.nn.softplus

    # The feature axis is the last one. In debug mode input_shape[0] is the
    # batch size (50), so indexing with 0 would size the output layer wrongly.
    n_features = input_shape[-1]

    # Encoder
    h_enc1 = layers.Dense(n_components_encoder, activation=activation)(x)
    h_enc2 = layers.Dense(n_components_encoder, activation=activation)(h_enc1)
    h_enc3 = layers.Dense(n_components_encoder, activation=activation)(h_enc2)

    z_mu = layers.Dense(n_hidden)(h_enc3)
    # The Dense output is treated as log-variance; halving gives log-sigma.
    z_log_sigma = 0.5 * layers.Dense(n_hidden)(h_enc3)

    # Sample from noise distribution p(eps) ~ N(0, 1).
    # The noise must have exactly the shape of z_mu. Adding a Python tuple to
    # a shape tensor performs element-wise addition, not concatenation, so
    # the shape is taken from z_mu directly.
    epsilon = tf.random.normal(tf.shape(z_mu))

    # Sample from posterior (reparameterization trick)
    z = z_mu + tf.exp(z_log_sigma) * epsilon

    # Decoder
    h_dec1 = layers.Dense(n_components_decoder, activation=activation)(z)
    h_dec2 = layers.Dense(n_components_decoder, activation=activation)(h_dec1)
    h_dec3 = layers.Dense(n_components_decoder, activation=activation)(h_dec2)

    # Sigmoid keeps y in (0, 1), which the Bernoulli log-likelihood below
    # requires; tanh can go negative and would make log(y) NaN.
    y = layers.Dense(n_features, activation='sigmoid')(h_dec3)

    # Reconstruction loss: negative Bernoulli log-likelihood per example.
    # The 1e-10 offset guards against log(0).
    reconstruction_loss = -tf.reduce_sum(
        x * tf.math.log(y + 1e-10) +
        (1 - x) * tf.math.log(1 - y + 1e-10), axis=-1)

    # KL divergence between N(z_mu, sigma^2) and the N(0, 1) prior
    kl_div = -0.5 * tf.reduce_sum(
        1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma),
        axis=-1)

    # Total loss
    loss = tf.reduce_mean(reconstruction_loss + kl_div)

    return {'cost': loss, 'x': x, 'z': z, 'y': y}

In [23]:
def preprocess_data(x, y):
    """Flatten and rescale the inputs and one-hot encode the labels.

    ``x`` is reshaped to rows of 784 values, cast to float32 and divided by
    255. ``y`` is converted to a 10-class one-hot matrix.

    Returns the pair ``(scaled_inputs, one_hot_labels)``.
    """
    flattened = x.reshape((-1, 784))
    scaled_inputs = flattened.astype('float32') / 255.0
    one_hot_labels = tf.keras.utils.to_categorical(y, 10)
    return scaled_inputs, one_hot_labels

In [24]:
def test_mnist():
    """Train the VAE on MNIST and print training and validation losses.

    Loads MNIST, preprocesses both splits, optimizes the VAE cost with Adam
    in mini-batches, prints the last batch loss of every epoch, and finally
    prints the example-weighted mean cost over the test split.
    """
    # Load MNIST data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, y_train = preprocess_data(x_train, y_train)
    x_test, y_test = preprocess_data(x_test, y_test)

    ae = VAE()
    # VAE() returns a dict of symbolic tensors, which has no trainable
    # variables and cannot be called. Wrapping input -> cost in a Model gives
    # a callable that recomputes the cost for each batch and exposes weights.
    cost_model = Model(inputs=ae['x'], outputs=ae['cost'])
    optimizer = tf.keras.optimizers.Adam()

    # Training loop
    batch_size = 100
    n_epochs = 50
    n_batches = len(x_train) // batch_size
    for epoch in range(n_epochs):
        print('Epoch', epoch)
        for batch_i in range(n_batches):
            batch_xs = x_train[batch_i * batch_size:(batch_i + 1) * batch_size]

            # The forward pass must run inside the tape so that gradients
            # can be traced back to the weights.
            with tf.GradientTape() as tape:
                loss = cost_model(batch_xs, training=True)
            gradients = tape.gradient(loss, cost_model.trainable_variables)
            optimizer.apply_gradients(
                zip(gradients, cost_model.trainable_variables))

        print('Train Loss:', loss.numpy())

    # Validation: evaluate in batches and weight each batch mean by its size
    # so that a short final batch does not skew the average.
    batch_losses = []
    batch_weights = []
    for start in range(0, len(x_test), batch_size):
        batch_xs = x_test[start:start + batch_size]
        batch_losses.append(float(cost_model(batch_xs, training=False)))
        batch_weights.append(len(batch_xs))
    valid_loss = np.average(batch_losses, weights=batch_weights)
    print('Validation Loss:', valid_loss)

In [None]:
# Run the MNIST training/validation routine only when this code is executed
# as the main module, not when it is imported.
if __name__ == '__main__':
    test_mnist()