In [1]:
# Import TensorFlow for model building and training
import tensorflow as tf
from tensorflow.keras import layers


In [2]:
class HierarchicalVAE(tf.keras.Model):
    # Initialize model with input and latent dimensions
    def __init__(self, input_dim, latent_dim_1, latent_dim_2):
        """Build the two encoder stacks and the two decoder stacks.

        Args:
            input_dim: Number of input features, as an int. It is used both
                as the flat input shape of the first encoder and as the unit
                count of the last decoder layer.
            latent_dim_1: Size of the first-level latent vector.
            latent_dim_2: Size of the second-level latent vector.
        """
        super(HierarchicalVAE, self).__init__()

        # Level-1 encoder: input -> concatenated [mean, log-variance] for z1.
        # The shape has to be a tuple; passing the bare int makes InputLayer
        # fail, and the same value must stay an int for the Dense layer below.
        self.encoder_level1 = tf.keras.Sequential([
            layers.InputLayer(input_shape=(input_dim,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(latent_dim_1 * 2),  # twice the latent size: mean + log-variance
        ])

        # Level-2 encoder: z1 -> concatenated [mean, log-variance] for z2.
        self.encoder_level2 = tf.keras.Sequential([
            layers.InputLayer(input_shape=(latent_dim_1,)),
            layers.Dense(64, activation='relu'),
            layers.Dense(latent_dim_2 * 2),  # twice the latent size: mean + log-variance
        ])

        # Level-2 decoder: z2 -> vector of size latent_dim_1.
        self.decoder_level2 = tf.keras.Sequential([
            layers.InputLayer(input_shape=(latent_dim_2,)),
            layers.Dense(64, activation='relu'),
            layers.Dense(latent_dim_1),
        ])

        # Level-1 decoder: vector of size latent_dim_1 -> reconstruction.
        # The sigmoid keeps every output in (0, 1).
        self.decoder_level1 = tf.keras.Sequential([
            layers.InputLayer(input_shape=(latent_dim_1,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(256, activation='relu'),
            layers.Dense(input_dim, activation='sigmoid'),
        ])


In [3]:
    def sample(self, mean, log_var):
        """Draw a sample via the reparameterization trick.

        Standard-normal noise with the same shape as ``mean`` is scaled by
        the standard deviation ``exp(0.5 * log_var)`` and shifted by ``mean``.
        """
        noise = tf.random.normal(shape=tf.shape(mean))
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise


In [4]:
    def encode(self, x):
        """Run both encoder levels and sample a latent at each one.

        Returns:
            Tuple ``(mean1, log_var1, z1, mean2, log_var2, z2)``.
        """
        # First level: the encoder output is halved along axis 1 into the
        # mean and the log-variance, from which z1 is sampled.
        level1_params = self.encoder_level1(x)
        mean1, log_var1 = tf.split(level1_params, num_or_size_splits=2, axis=1)
        z1 = self.sample(mean1, log_var1)

        # Second level: same procedure, fed with the sampled z1.
        level2_params = self.encoder_level2(z1)
        mean2, log_var2 = tf.split(level2_params, num_or_size_splits=2, axis=1)
        z2 = self.sample(mean2, log_var2)

        return mean1, log_var1, z1, mean2, log_var2, z2


In [5]:
    def decode(self, z1, z2):
        """Map the second-level latent back to the input space.

        ``z2`` goes through the level-2 decoder and its output is fed to the
        level-1 decoder. ``z1`` is accepted but not read by this method.
        """
        z1_from_z2 = self.decoder_level2(z2)
        return self.decoder_level1(z1_from_z2)


In [6]:
    def call(self, x):
        """Encode ``x``, decode the sampled latents, and return the results.

        Returns:
            Tuple ``(x_reconstructed, mean1, log_var1, mean2, log_var2)``.
        """
        encoded = self.encode(x)
        mean1, log_var1, z1, mean2, log_var2, z2 = encoded
        return self.decode(z1, z2), mean1, log_var1, mean2, log_var2


In [7]:
def compute_loss(model, x):
    """Return the scalar VAE loss for one batch.

    The loss is the per-example reconstruction cross-entropy plus the KL
    divergence of both latent levels from a standard normal, averaged over
    the batch.

    Args:
        model: Callable that returns
            ``(x_reconstructed, mean1, log_var1, mean2, log_var2)`` for ``x``.
        x: Batch of inputs. Assumed to be ``(batch, features)`` with values
            in [0, 1] so that binary cross-entropy applies -- TODO confirm
            against the data pipeline.

    Returns:
        Scalar tensor holding reconstruction loss plus KL loss.
    """
    x_reconstructed, mean1, log_var1, mean2, log_var2 = model(x)

    # binary_crossentropy already averages over the last (feature) axis and
    # yields one value per example, so a further reduce_sum(axis=1) would
    # address an axis that no longer exists. Multiply the mean by the feature
    # count instead to get the summed per-example cross-entropy.
    bce_mean = tf.keras.losses.binary_crossentropy(x, x_reconstructed)
    n_features = tf.cast(tf.shape(x_reconstructed)[-1], bce_mean.dtype)
    reconstruction_loss = tf.reduce_mean(bce_mean * n_features)

    # Closed-form KL(N(mean, exp(log_var)) || N(0, 1)), summed over the
    # latent dimensions of each level.
    kl_div1 = -0.5 * tf.reduce_sum(
        1 + log_var1 - tf.square(mean1) - tf.exp(log_var1), axis=1
    )
    kl_div2 = -0.5 * tf.reduce_sum(
        1 + log_var2 - tf.square(mean2) - tf.exp(log_var2), axis=1
    )
    kl_loss = tf.reduce_mean(kl_div1 + kl_div2)

    return reconstruction_loss + kl_loss


In [8]:
# One Adam optimizer instance, shared by every call to train_step.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(model, x):
    """Apply one optimizer update to ``model`` using batch ``x``.

    The loss from ``compute_loss`` is recorded on a gradient tape, its
    gradients with respect to the model's trainable variables are applied
    through the module-level ``optimizer``, and the loss is returned.
    """
    with tf.GradientTape() as tape:
        batch_loss = compute_loss(model, x)
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return batch_loss


In [9]:
def train_model(model, dataset, epochs=10):
    """Run ``train_step`` over every batch of ``dataset`` for ``epochs`` passes.

    After each epoch the loss of the last processed batch is printed.

    Args:
        model: Model forwarded unchanged to ``train_step``.
        dataset: Iterable of input batches; it is iterated once per epoch,
            so it must yield batches again on every pass.
        epochs: Number of passes over ``dataset``.

    Raises:
        ValueError: If a pass over ``dataset`` yields no batches. Without
            this check the epoch would have no loss of its own to report
            (an unbound name on the first epoch, a stale value afterwards).
    """
    for epoch in range(epochs):
        loss = None  # reset so each epoch reports only its own batches
        for x_batch in dataset:
            loss = train_step(model, x_batch)
        if loss is None:
            raise ValueError(
                f'dataset yielded no batches in epoch {epoch + 1}'
            )
        print(f'Epoch {epoch + 1}, Loss: {loss.numpy()}')
