# MNIST autoencoder vs variational autoencoder

### Aims of the notebook

Demonstrate how to build a variational autoencoder using TensorFlow Probability on the MNIST digits dataset, and compare it with a plain autoencoder.


Conda environment definition:
```
name: vae
channels:
  - defaults
dependencies:
  - python=3.8
  - tensorflow=2.7
  - matplotlib=3.4
  - pandas=1.3

```

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

RuntimeError: jaxlib is version 0.3.14, but this version of jax requires version >= 0.3.15.

In [None]:


# Define an autoencoder
class Autoencoder(keras.Model):
    """Deterministic fully connected autoencoder for 28x28 images.

    The encoder flattens its input and narrows it through ReLU layers of
    256 -> 128 -> ``latent_dim`` units. The decoder widens the code back
    through 128 -> 256 ReLU units and a 784-unit sigmoid layer. Its output is
    reshaped to a (28, 28) image with values in (0, 1).

    Args:
        latent_dim: Width of the bottleneck (latent code) layer.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Layers are created in the same order as a literal list would create
        # them, so Keras' auto-generated layer names are unchanged.
        encoder_stack = [layers.Flatten()]
        encoder_stack.extend(
            layers.Dense(width, activation='relu')
            for width in (256, 128, latent_dim)
        )
        self.encoder = keras.Sequential(encoder_stack)

        decoder_stack = [
            layers.Dense(width, activation='relu') for width in (128, 256)
        ]
        decoder_stack.append(layers.Dense(784, activation='sigmoid'))
        decoder_stack.append(layers.Reshape((28, 28)))
        self.decoder = keras.Sequential(decoder_stack)

    def call(self, x):
        """Compress ``x`` to its latent code and return the reconstruction."""
        code = self.encoder(x)
        return self.decoder(code)

# Define a variational autoencoder
class VariationalAutoencoder(keras.Model):
    """Fully connected variational autoencoder for 28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional latent space. The decoder maps
    a latent vector back to a (28, 28) image. Because its last dense layer
    uses a sigmoid activation, the decoder's output is already in (0, 1).

    Args:
        latent_dim: Size of the latent space.
    """

    def __init__(self, latent_dim):
        super(VariationalAutoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = keras.Sequential([
            layers.Flatten(),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            # Linear output: the first half is the mean and the second half
            # the log-variance (see `encode`).
            layers.Dense(latent_dim + latent_dim)
        ])
        self.decoder = keras.Sequential([
            layers.Dense(128, activation='relu'),
            layers.Dense(256, activation='relu'),
            # Sigmoid: the decoder emits values in (0, 1), not logits.
            layers.Dense(784, activation='sigmoid'),
            layers.Reshape((28, 28))
        ])

    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior q(z | x).

        Both tensors have shape ``(batch, latent_dim)``.
        """
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample z = mean + exp(0.5 * logvar) * eps, with eps ~ N(0, I).

        The noise shape comes from ``tf.shape`` (the dynamic shape) rather
        than ``mean.shape``. Inside ``tf.function`` / ``Model.fit`` the batch
        dimension is unknown (``None``) at trace time, and a partially known
        static shape cannot be passed to ``tf.random.normal``.
        """
        eps = tf.random.normal(shape=tf.shape(mean))
        # exp(0.5 * logvar) is the standard deviation.
        return eps * tf.exp(logvar * 0.5) + mean

    def call(self, x):
        """Encode ``x``, draw one latent sample, and return its decoding."""
        mean, logvar = self.encode(x)
        encoded = self.reparameterize(mean, logvar)
        decoded = self.decoder(encoded)
        return decoded

    def sample(self, eps=None):
        """Decode latent vectors ``eps``.

        If ``eps`` is None, 100 vectors are drawn from N(0, I).
        """
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(eps)

    def decode(self, z, apply_sigmoid=False):
        """Decode latent vectors ``z`` into images with values in (0, 1).

        Args:
            z: Latent vectors, presumably shaped ``(batch, latent_dim)``.
            apply_sigmoid: Kept for backward compatibility only. The decoder's
                final layer already applies a sigmoid, so its output already
                consists of probabilities. Applying ``tf.sigmoid`` a second
                time squashed them into (0.5, 0.731), so the flag is now a
                no-op.

        Returns:
            The decoder output, shaped (batch, 28, 28).
        """
        del apply_sigmoid  # intentionally unused; see docstring
        return self.decoder(z)

# Load MNIST (labels are discarded), convert pixels to float32 scaled by
# 1/255, and append a trailing axis of size 1 to every image.
# NOTE(review): both models' decoders reshape to (28, 28), i.e. without this
# trailing axis -- losses and plots must reconcile the two shapes.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train, x_test = (
    tf.expand_dims(images.astype('float32') / 255., -1)
    for images in (x_train, x_test)
)

# Build both models with the same latent size. Each model gets its own Adam
# instance with default settings. Construction order (autoencoder first) is
# kept so that Keras' auto-generated layer names do not change.
latent_dim = 32
autoencoder, vae = Autoencoder(latent_dim), VariationalAutoencoder(latent_dim)
autoencoder_optimizer, vae_optimizer = (
    keras.optimizers.Adam(),
    keras.optimizers.Adam(),
)

# Define loss functions
def autoencoder_loss(x, x_decoded):
    """Mean squared reconstruction error for the plain autoencoder.

    ``x_decoded`` is first reshaped to ``x``'s shape. In this notebook the
    decoders produce (28, 28) images while the prepared MNIST tensors are
    (28, 28, 1). Without the reshape, ``mean_squared_error`` would broadcast
    the two shapes against each other (raising, or silently comparing the
    wrong pixels) instead of comparing them element-wise. When the shapes
    already agree the reshape is a no-op.

    Args:
        x: Target images.
        x_decoded: Reconstructions with the same number of elements as ``x``.

    Returns:
        The squared error averaged over the last axis of ``x``, following
        ``keras.losses.mean_squared_error``. It is not reduced to a scalar.
    """
    x_decoded = tf.reshape(x_decoded, tf.shape(x))
    return keras.losses.mean_squared_error(x, x_decoded)

def vae_loss(x, x_decoded, mean, logvar):
    reconstruction_loss = keras.losses.binary_crossentropy(x, x_decoded)
    reconstruction_loss = tf.reduce_mean(reconstruction_loss)
    kl_divergence = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(log
