<a href="https://colab.research.google.com/github/IverMartinsen/colab_notebooks/blob/main/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time

# Fetch MNIST; each split comes back as an (images, labels) pair of numpy
# arrays with images shaped (n, x, y). Only the images are kept.
train_images, test_images = (
    images for images, _labels in tf.keras.datasets.mnist.load_data())

def preprocess_images(images, threshold=.5):
  """Reshape, normalize and binarize a stack of single-channel images.

  Args:
    images: Array of shape (n, x, y). Presumably holds pixel values in
      [0, 255] (the division by 255 assumes this), e.g. the MNIST arrays.
    threshold: Cut-off applied after scaling by 1/255. Pixels strictly
      greater than it become 1.0 and all others become 0.0. Defaults to 0.5,
      the value that was previously hard-coded.

  Returns:
    float32 array of shape (n, x, y, 1) containing only 0.0 and 1.0.
  """
  n, x, y = images.shape
  # Append a trailing channel axis (x, y) -> (x, y, 1) so the images have
  # the channels-last layout that Conv2D layers consume.
  images = images.reshape(n, x, y, 1) / 255.
  return np.where(images > threshold, 1.0, .0).astype('float32')

# Binarize both splits and give them a trailing channel axis.
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

train_size = train_images.shape[0]
# Number of test images; read from `test_images` (there is no `test` object).
test_size = test_images.shape[0]
batch_size = 32

# Create datasets: shuffle each split with a buffer as large as the split
# itself, then group consecutive examples into batches of `batch_size`.
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))

class CVAE(tf.keras.Model):
  """Convolutional variational autoencoder for (28, 28, 1) images.

  The encoder maps an image to the mean and log-variance of a diagonal
  Gaussian over a `latent_dim`-dimensional latent space. The decoder maps a
  latent vector back to per-pixel logits of shape (28, 28, 1).

  Args:
    latent_dim: Dimensionality of the latent space.
  """

  def __init__(self, latent_dim):
    super(CVAE, self).__init__()
    self.latent_dim = latent_dim
    # Encoder: two 3x3 stride-2 'valid' convolutions shrink 28x28 -> 13x13
    # -> 6x6. The flattened features are projected to 2 * latent_dim values,
    # which `encode` splits into (mean, logvar).
    self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),
            tf.keras.layers.Conv2D(
                filters = 32, kernel_size = 3, strides = (2, 2), activation = 'relu'),
            tf.keras.layers.Conv2D(
                filters = 64, kernel_size = 3, strides = (2, 2), activation = 'relu'),
            tf.keras.layers.Flatten(),
            # Linear output: mean and log-variance are unbounded reals.
            tf.keras.layers.Dense(latent_dim + latent_dim),
        ]
    )

    # Decoder: project to a 7x7x32 feature map, then upsample with two
    # stride-2 'same' transposed convolutions (7 -> 14 -> 28) and map to a
    # single output channel.
    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape = (latent_dim, )),
            tf.keras.layers.Dense(units = 7*7*32, activation = tf.nn.relu),
            tf.keras.layers.Reshape(target_shape = (7, 7, 32)),
            tf.keras.layers.Conv2DTranspose(
                filters = 64, kernel_size = 3, strides = 2, padding = 'same',
                activation = 'relu'),
            tf.keras.layers.Conv2DTranspose(
                filters = 32, kernel_size = 3, strides = 2, padding = 'same',
                activation = 'relu'),
            # Linear output: logits. `decode` applies the sigmoid on request.
            tf.keras.layers.Conv2DTranspose(
                filters = 1, kernel_size = 3, strides = 1, padding = 'same'),
        ]
    )

  @tf.function
  def sample(self, eps = None):
    """Decode latent vectors into pixel probabilities.

    Args:
      eps: Latent vectors of shape (batch, latent_dim). If None, 100 vectors
        are drawn from a standard normal distribution.

    Returns:
      Sigmoid-activated decoder output (probabilities in (0, 1)).
    """
    if eps is None:
      eps = tf.random.normal(shape = (100, self.latent_dim))
    return self.decode(eps, apply_sigmoid = True)

  def encode(self, x):
    """Return the (mean, logvar) pair, each of width latent_dim, for images x."""
    mean, logvar = tf.split(self.encoder(x), num_or_size_splits = 2, axis = 1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    """Sample z = mean + exp(logvar / 2) * eps with eps ~ N(0, I).

    This is the reparameterization trick: the randomness lives in `eps`, so
    gradients flow through `mean` and `logvar`.
    """
    # Use the dynamic tf.shape rather than the static mean.shape. The static
    # shape may contain None (unknown batch size) when traced inside a
    # tf.function, which tf.random.normal cannot accept.
    eps = tf.random.normal(shape = tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid = False):
    """Decode latent vectors z to logits, or to probabilities if apply_sigmoid."""
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits

RuntimeError: ignored

In [None]:
# Display the dimensions of the test image array.
np.shape(test_images)

(10000, 28, 28)