In [1]:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from sklearn.manifold import TSNE

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the images to [0, 1] range
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Flatten the images (28x28 -> 784)
x_train = x_train.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)

def build_encoder(latent_dim=10):
    encoder_input = layers.Input(shape=(28 * 28,))
    x = layers.Dense(512, activation='relu')(encoder_input)
    x = layers.Dense(256, activation='relu')(x)
    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
    encoder = models.Model(encoder_input, [z_mean, z_log_var], name="encoder")
    return encoder

def build_decoder(latent_dim=10):
    latent_input = layers.Input(shape=(latent_dim,))
    x = layers.Dense(256, activation='relu')(latent_input)
    x = layers.Dense(512, activation='relu')(x)
    decoder_output = layers.Dense(28 * 28, activation='sigmoid')(x)
    decoder = models.Model(latent_input, decoder_output, name="decoder")
    return decoder

class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

def build_beta_vae(latent_dim=10, beta=4):
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    
    inputs = layers.Input(shape=(28 * 28,))
    z_mean, z_log_var = encoder(inputs)
    z = Sampling()([z_mean, z_log_var])
    reconstructed = decoder(z)
    
    vae = models.Model(inputs, reconstructed, name="beta_vae")
    
    # Define loss
    reconstruction_loss = tf.reduce_mean(
        tf.reduce_sum(tf.keras.losses.binary_crossentropy(inputs, reconstructed), axis=-1)
    )
    kl_loss = - 0.5 * tf.reduce_mean(
        tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
    )
    vae_loss = reconstruction_loss + beta * kl_loss
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    
    return vae

def total_correlation_loss(z_mean, z_log_var):
    q_z = tf.distributions.Normal(loc=z_mean, scale=tf.exp(0.5 * z_log_var))
    log_q_z = q_z.log_prob(z_mean)
    return tf.reduce_mean(log_q_z)

def build_factorvae(latent_dim=10, gamma=10):
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    
    inputs = layers.Input(shape=(28 * 28,))
    z_mean, z_log_var = encoder(inputs)
    z = Sampling()([z_mean, z_log_var])
    reconstructed = decoder(z)
    
    factorvae = models.Model(inputs, reconstructed, name="factorvae")
    
    # Loss function
    reconstruction_loss = tf.reduce_mean(
        tf.reduce_sum(tf.keras.losses.binary_crossentropy(inputs, reconstructed), axis=-1)
    )
    kl_loss = - 0.5 * tf.reduce_mean(
        tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
    )
    tc_loss = total_correlation_loss(z_mean, z_log_var)
    factorvae_loss = reconstruction_loss + kl_loss + gamma * tc_loss
    factorvae.add_loss(factorvae_loss)
    factorvae.compile(optimizer='adam')
    
    return factorvae

# Build and train the \(eta\)-VAE model
beta_vae = build_beta_vae(latent_dim=10, beta=4)
beta_vae.fit(x_train, epochs=20, batch_size=128, validation_data=(x_test, x_test))

# Build and train the FactorVAE model
factorvae = build_factorvae(latent_dim=10, gamma=10)
factorvae.fit(x_train, epochs=20, batch_size=128, validation_data=(x_test, x_test))





ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```
