In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, BatchNormalization, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
import numpy as np
import matplotlib.pyplot as plt

# Define Hyperparameters
batch_size = 128
epochs = 30
learning_rate = 1e-3
beta = 4  # For Beta-VAE
latent_dim = 20

# Load Dataset (MNIST as Example)
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)

# Encoder
inputs = Input(shape=(28*28,))
x = Dense(512)(inputs)
x = BatchNormalization()(x)
x = tf.keras.activations.relu(x)
x = Dropout(0.3)(x)
x = Dense(256, activation='relu')(x)
mu = Dense(latent_dim)(x)
logvar = Dense(latent_dim)(x)

# Reparameterization Trick
def sampling(args):
    mu, logvar = args
    batch = tf.shape(mu)[0]
    dim = tf.shape(mu)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return mu + tf.exp(0.5 * logvar) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([mu, logvar])

# Decoder
decoder_input = Input(shape=(latent_dim,))
decoder_hidden = Dense(256, activation='relu')(decoder_input)
decoder_hidden = Dense(512, activation='relu')(decoder_hidden)
decoder_output = Dense(28*28, activation='sigmoid')(decoder_hidden)
decoder = Model(decoder_input, decoder_output)
x_recon = decoder(z)

# VAE Model
vae = Model(inputs, x_recon)

# Loss Function
def vae_loss(x, x_recon):
    recon_loss = mse(x, x_recon) * 28 * 28
    kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=-1)
    return tf.reduce_mean(recon_loss + beta * kl_loss)

vae.add_loss(vae_loss(inputs, x_recon))
vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))

# Train the Model
vae.fit(x_train, x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, x_test))

# Save Model
vae.save("vae_model.h5")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```
