In [1]:
import numpy as np
import keras
from keras import ops
from keras import layers

In [10]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.

    The layer computes ``z = z_mean + exp(0.5 * z_log_var) * epsilon``,
    where ``epsilon`` is noise drawn with ``keras.random.normal``.

    Args:
        seed: Value handed to ``keras.random.SeedGenerator`` to drive the
            noise draws. Defaults to 1337, the value this layer previously
            hard-coded, so ``Sampling()`` behaves as before.
        **kwargs: Forwarded unchanged to ``layers.Layer.__init__``.
    """

    def __init__(self, seed=1337, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw seed around for inspection alongside the generator.
        self.seed = seed
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        """Draw one sample of z for every row of ``z_mean``.

        Args:
            inputs: A pair ``(z_mean, z_log_var)``. The noise is generated
                with shape ``(shape[0], shape[1])`` of ``z_mean``, so both
                tensors are expected to be two-dimensional.

        Returns:
            ``z_mean + exp(0.5 * z_log_var) * epsilon``.
        """
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        # Passing the SeedGenerator (rather than a plain int) lets each call
        # draw from the generator's evolving state.
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        # exp(0.5 * log_var) turns the log-variance into a standard deviation.
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon

In [11]:
class Encoder(layers.Layer):
    """Encode an input batch into the triplet (z_mean, z_log_var, z).

    Args:
        latent_dim: Number of units in the ``z_mean`` and ``z_log_var``
            projections.
        intermediate_dim: Number of units in the hidden ReLU projection.
        name: Layer name, ``"encoder"`` by default.
        **kwargs: Forwarded to ``layers.Layer.__init__``.
    """

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Shared hidden projection feeding both latent heads.
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        # Two linear heads of identical width: mean and log-variance.
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        """Return ``(z_mean, z_log_var, z)`` for ``inputs``."""
        hidden = self.dense_proj(inputs)
        z_mean = self.dense_mean(hidden)
        z_log_var = self.dense_log_var(hidden)
        # z is sampled from the (z_mean, z_log_var) pair by the Sampling layer.
        return z_mean, z_log_var, self.sampling((z_mean, z_log_var))

In [12]:
class Decoder(layers.Layer):
    """Decode a latent vector z back into a vector of size ``original_dim``.

    Args:
        original_dim: Number of units in the sigmoid output projection.
        intermediate_dim: Number of units in the hidden ReLU projection.
        name: Layer name, ``"decoder"`` by default.
        **kwargs: Forwarded to ``layers.Layer.__init__``.
    """

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        # Sigmoid keeps every output value in (0, 1).
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        """Apply the hidden projection, then the sigmoid output projection."""
        return self.dense_output(self.dense_proj(inputs))

In [13]:
class VariationalAutoEncoder(keras.Model):
    """End-to-end trainable model chaining an ``Encoder`` and a ``Decoder``.

    Args:
        original_dim: Size of the decoder output; also stored on the model.
        intermediate_dim: Hidden width used by both encoder and decoder.
        latent_dim: Width of the encoder's latent outputs.
        name: Model name, ``"autoencoder"`` by default.
        **kwargs: Forwarded to ``keras.Model.__init__``.
    """

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(
            latent_dim=latent_dim,
            intermediate_dim=intermediate_dim,
        )
        self.decoder = Decoder(
            original_dim,
            intermediate_dim=intermediate_dim,
        )

    def call(self, inputs):
        """Reconstruct ``inputs`` and register the KL regularization loss.

        Returns:
            The decoder output for the sampled latent vector.
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        # KL divergence regularizer: averaged over every element, scaled by -0.5.
        kl_terms = z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1
        self.add_loss(-0.5 * ops.mean(kl_terms))
        return self.decoder(z)

In [9]:
# Size of one flattened MNIST image; used for both the reshape and the model
# so the two cannot drift apart.
original_dim = 784

# Only the training images are needed: labels and the test split are dropped.
(x_train, _), _ = keras.datasets.mnist.load_data()
# Flatten every image to a vector and scale the pixel values by 1/255.
x_train = x_train.reshape(-1, original_dim).astype("float32") / 255

# Keyword arguments make the hidden/latent widths impossible to swap silently.
vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())

# The input doubles as the target: the model is trained to reconstruct it.
# Keeping the History object makes the per-epoch losses available afterwards
# and avoids echoing its repr as the cell output.
history = vae.fit(x_train, x_train, epochs=20, batch_size=64, verbose=2)

Epoch 1/20
938/938 - 4s - 4ms/step - loss: 0.0744
Epoch 2/20
938/938 - 3s - 3ms/step - loss: 0.0676
Epoch 3/20
938/938 - 3s - 3ms/step - loss: 0.0676
Epoch 4/20
938/938 - 2s - 2ms/step - loss: 0.0675
Epoch 5/20
938/938 - 2s - 3ms/step - loss: 0.0675
Epoch 6/20
938/938 - 3s - 3ms/step - loss: 0.0675
Epoch 7/20
938/938 - 3s - 3ms/step - loss: 0.0674
Epoch 8/20
938/938 - 4s - 4ms/step - loss: 0.0674
Epoch 9/20
938/938 - 5s - 5ms/step - loss: 0.0674
Epoch 10/20
938/938 - 8s - 8ms/step - loss: 0.0674
Epoch 11/20
938/938 - 4s - 5ms/step - loss: 0.0674
Epoch 12/20
938/938 - 5s - 5ms/step - loss: 0.0674
Epoch 13/20
938/938 - 5s - 5ms/step - loss: 0.0674
Epoch 14/20
938/938 - 4s - 4ms/step - loss: 0.0673
Epoch 15/20
938/938 - 3s - 3ms/step - loss: 0.0673
Epoch 16/20
938/938 - 3s - 3ms/step - loss: 0.0673
Epoch 17/20
938/938 - 2s - 3ms/step - loss: 0.0673
Epoch 18/20
938/938 - 2s - 2ms/step - loss: 0.0673
Epoch 19/20
938/938 - 2s - 3ms/step - loss: 0.0673
Epoch 20/20
938/938 - 2s - 3ms/step - lo

<keras.src.callbacks.history.History at 0x19110f175c0>