In [3]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters
latent_dim = 2
input_dim = 28 * 28

# Load dataset
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train.reshape(-1, input_dim)
x_test = x_test.reshape(-1, input_dim)

# Encoder
inputs = Input(shape=(input_dim,))
x = Dense(512, activation='relu')(inputs)
mu = Dense(latent_dim)(x)
log_var = Dense(latent_dim)(x)

def sampling(args):
    mu, log_var = args
    epsilon = tf.keras.backend.random_normal(shape=(tf.shape(mu)[0], latent_dim))
    return mu + tf.exp(0.5 * log_var) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([mu, log_var])

# Decoder
decoder_input = Input(shape=(latent_dim,))
decoder_hidden = Dense(512, activation='relu')(decoder_input)
decoder_output = Dense(input_dim, activation='sigmoid')(decoder_hidden)
decoder = Model(decoder_input, decoder_output)
x_recon = decoder(z)

# VAE Model
vae = Model(inputs, x_recon)

# Loss Function
def vae_loss(x, x_recon):
    recon_loss = mse(x, x_recon) * input_dim
    kl_loss = -0.5 * tf.reduce_sum(1 + log_var - tf.square(mu) - tf.exp(log_var), axis=-1)
    return tf.reduce_mean(recon_loss + kl_loss)

vae.add_loss(vae_loss(inputs, x_recon))
vae.compile(optimizer=tf.keras.optimizers.Adam(1e-3))

# Train Model
vae.fit(x_train, x_train, epochs=10, batch_size=128, validation_data=(x_test, x_test))

# Generate New Samples
def generate_images(grid_size=10):
    grid_x = np.linspace(-3, 3, grid_size)
    grid_y = np.linspace(-3, 3, grid_size)
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(10, 10))

    for i, xi in enumerate(grid_x):
        for j, yj in enumerate(grid_y):
            z_sample = np.array([[xi, yj]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded.reshape(28, 28)
            axs[i, j].imshow(digit, cmap='gray')
            axs[i, j].axis('off')

    plt.show()

generate_images()

ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```
