In [1]:
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_probability import distributions as tfd
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Debug aid: from here on, TensorFlow checks the floating-point outputs of ops
# and raises as soon as a NaN/Inf appears. It instruments every op and adds
# per-op overhead, so consider disabling it (or calling
# tf.debugging.disable_check_numerics()) for full-length training runs.
tf.debugging.enable_check_numerics()

INFO:tensorflow:Enabled check-numerics callback in thread MainThread


In [3]:
SEED = 42
# Seed the Python, NumPy and TensorFlow global RNGs (in that order) so reruns
# draw the same random numbers.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)
del _seed_fn

In [4]:
#### HELPERS ####

def coalesce_none(*args):
    """Return the first positional argument that is not ``None``.

    If every argument is ``None`` the result is ``None``.

    Raises:
        ValueError: when called with no arguments at all.
    """
    if not args:
        raise ValueError('coalesce_none expects one or more positional arguments')
    return next((candidate for candidate in args if candidate is not None), None)

In [5]:
#### VAE ####

def compute_loss(model, data, latent_vars=None):
    """Single-sample negative ELBO, averaged over the batch.

    Args:
        model: object exposing ``encode(data) -> (mean, logvar)`` and
            ``decode(latent_vars) -> logits`` (e.g. ``VariationalAutoEncoder``).
        data: batch of targets in [0, 1] (binarized images in this notebook);
            the first axis indexes examples.
        latent_vars: optional latent samples for ``data``. When omitted they
            are drawn here with the reparameterization trick from the same
            ``mean``/``logvar`` used for the KL term, so the whole computation
            is recorded by an enclosing ``GradientTape``.

    Returns:
        Scalar tensor ``mean_batch(KL(q(z|x) || N(0, I)) - log p(x|z))``.
    """
    mean, logvar = model.encode(data)

    if latent_vars is None:
        eps = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        latent_vars = mean + tf.math.exp(0.5 * logvar) * eps

    # Analytic *negative* KL between N(mean, exp(logvar)) and N(0, I), per example.
    neg_kl = 0.5 * tf.reduce_sum(
        1.0 + logvar - tf.math.square(mean) - tf.math.exp(logvar),
        axis=-1)

    data_logits = model.decode(latent_vars)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=data_logits, labels=data)
    # Sum over every non-batch axis ([1, 2, 3] for NHWC images).
    event_axes = list(range(1, len(cross_ent.shape)))
    logpx_z = -tf.reduce_sum(cross_ent, axis=event_axes)

    # ELBO = log p(x|z) - KL; minimize its negation.
    return -tf.reduce_mean(neg_kl + logpx_z)


@tf.function
def train_step(model, optimizer, data):
    """Run one optimizer step on ``data`` and return the batch loss.

    The posterior sample is drawn *inside* the tape. The reparameterized
    ``z = mean + sigma * eps`` must be recorded so that the reconstruction
    term back-propagates into the encoder. If ``z`` were sampled outside the
    tape it would be a constant, and the encoder would only ever receive
    gradients from the KL term.
    """
    with tf.GradientTape() as tape:
        latent_vars = model.sample_latent_posterior(data)
        loss = compute_loss(model, data, latent_vars)
        # Feed the checked tensor forward so the check sits on the data path.
        loss = tf.debugging.check_numerics(loss, 'loss')
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


def train_vae(model, dataset, num_epoch, optimizer):
    """Train ``model`` for ``num_epoch`` passes over ``dataset``.

    Args:
        model: model passed through to ``train_step``.
        dataset: iterable of data batches, re-iterated once per epoch.
        num_epoch: number of passes over ``dataset``.
        optimizer: Keras optimizer used by ``train_step``.

    Returns:
        One dict per completed epoch, ``{'loss_mean': value}``, where the value
        is the mean of that epoch's per-batch losses. Logs are only returned
        when the loop finishes; an interrupt discards them.
    """
    epoch_logs = []
    for _ in tqdm(range(num_epoch)):
        # Fresh metric each epoch: reusing one metric across epochs would
        # report the running mean over *all* batches seen so far.
        loss_metric = keras.metrics.Mean('loss', dtype=tf.float32)
        for batch_data in dataset:
            loss_metric(train_step(model, optimizer, batch_data))
        epoch_logs.append({'loss_mean': loss_metric.result().numpy()})
    return epoch_logs


class VariationalAutoEncoder(tf.Module):
    """Convolutional VAE with a diagonal-Gaussian posterior and Bernoulli decoder.

    The encoder maps an input batch to ``2 * num_latent_dims`` values, split
    into a posterior mean and log-variance. The decoder maps latent vectors to
    per-pixel logits. ``_make_decoder`` hard-codes a 7x7 -> 28x28 upsampling
    path, so it produces 28x28x1 outputs regardless of ``data_shape``.
    """

    def __init__(self, data_shape, num_latent_dims):
        """Build (and print summaries of) the encoder and decoder networks.

        Args:
            data_shape: per-example encoder input shape, e.g. [28, 28, 1].
            num_latent_dims: dimensionality of the latent space.
        """
        super().__init__()
        self._num_latent_dims = num_latent_dims
        self._encoder = self._make_encoder(data_shape, num_latent_dims)
        self._encoder.summary()
        self._decoder = self._make_decoder(num_latent_dims)
        self._decoder.summary()

        # Standard-normal prior p(z); sampled by `_sample_latent_prior`.
        self._prior = tfd.MultivariateNormalDiag(
            loc=tf.zeros(num_latent_dims),
            scale_diag=tf.ones(num_latent_dims))

    def decode(self, latent_vars):
        """Map latent vectors to per-pixel Bernoulli logits (no sigmoid applied)."""
        return self._decoder(latent_vars)

    def encode(self, data):
        """Return posterior ``(mean, logvar)``, each with ``num_latent_dims`` features."""
        params = self._encoder(data)
        mean = params[..., :self._num_latent_dims]
        logvar = params[..., self._num_latent_dims:]
        return mean, logvar

    def generate_data(self, batch_size):
        """Decode ``batch_size`` prior samples; see ``generate_data_from_latent``."""
        latent_vars = self.sample_latent_prior(batch_size)
        return self.generate_data_from_latent(latent_vars)

    def generate_data_from_latent(self, latent_vars):
        """Decode ``latent_vars`` and draw one Bernoulli sample per pixel.

        Returns:
            ``(fake_data, probs)``: a boolean sample and the per-pixel
            probabilities ``sigmoid(logits)`` it was drawn from.
        """
        logits = self.decode(latent_vars)
        probs = tf.math.sigmoid(logits)
        # Dynamic shape keeps this valid when the batch size is unknown at trace time.
        unif_vals = tf.random.uniform(tf.shape(logits), dtype=probs.dtype)
        # P(u <= p) == p for u ~ U[0, 1).
        fake_data = unif_vals <= probs
        return fake_data, probs

    def sample_latent_prior(self, batch_size):
        """Draw ``batch_size`` latent vectors from the standard-normal prior."""
        return tf.random.normal([batch_size, self._num_latent_dims])

    def sample_latent_posterior(self, data):
        """Reparameterized sample ``z = mean + exp(logvar / 2) * eps`` from q(z|x)."""
        mean, logvar = self.encode(data)
        # tf.shape (not .shape) so a None batch dim inside tf.function is fine.
        noise = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        # exp(0.5 * logvar) == sqrt(exp(logvar)), but avoids overflowing exp
        # for large logvar and the infinite sqrt gradient (-> NaN) when
        # exp(logvar) underflows to 0.
        return mean + tf.math.exp(0.5 * logvar) * noise

    def _make_decoder(self, num_latent_dims):
        """Dense -> 7x7x32 -> two stride-2 transposed convs -> 28x28x1 logits."""
        model = keras.Sequential([
            layers.Input(shape=[num_latent_dims]),
            layers.Dense(units=7*7*32, activation=tf.nn.relu),
            layers.Reshape(target_shape=(7, 7, 32)),
            layers.Conv2DTranspose(
                filters=64,
                kernel_size=3,
                strides=(2, 2),
                padding='SAME',
                activation='relu'),
            layers.Conv2DTranspose(
                filters=32,
                kernel_size=3,
                strides=(2, 2),
                padding='SAME',
                activation='relu'),
            # No activation: outputs are logits.
            layers.Conv2DTranspose(
                filters=1, kernel_size=3, strides=(1, 1), padding='SAME'),
        ])
        return model

    def _make_encoder(self, data_shape, num_latent_dims):
        """Two stride-2 convs -> flatten -> Dense(2 * num_latent_dims) posterior params."""
        model = keras.Sequential([
            layers.Input(shape=data_shape, dtype='float32'),
            layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Flatten(),
            # No activation: first half is the mean, second half the log-variance.
            layers.Dense(2 * num_latent_dims)
        ])
        return model

    def _sample_latent_prior(self, batch_size):
        """Prior sample via the `tfd` distribution (equivalent in distribution to
        `sample_latent_prior`; not called anywhere in this notebook)."""
        return self._prior.sample(batch_size)
        

In [6]:
# --- training ---
NUM_EPOCH = 100
BATCH_SIZE = 64
LEARNING_RATE = 1e-2
LATENT_DIM = 50

# --- data ---
TRAIN_BUF = 20_000  # cap on training examples kept (full MNIST train split: 60,000)
TEST_BUF = 10_000   # cap on test examples kept
ONLY_INCLUDE_DIGITS = [1, 3, 8]  # digit classes to keep; None or [] keeps all
DATA_SHAPE = [28, 28, 1]  # a list on purpose: prefixed with a batch dim via `+`

In [7]:
def preprocess_mnist_images_and_labels(images, labels, max_obs=None):
    """Filter, shuffle, subsample and binarize MNIST images.

    Reads the notebook globals ``ONLY_INCLUDE_DIGITS``, ``DATA_SHAPE`` and
    ``BATCH_SIZE``. The caller's arrays are not modified.

    Args:
        images: array of raw images with examples on the leading axis
            (uint8 pixels in [0, 255] for ``keras.datasets.mnist``).
        labels: array of digit labels aligned with ``images``.
        max_obs: optional cap on the number of examples kept.

    Returns:
        ``(image_dataset, images, labels)``:
        - a shuffled, batched ``tf.data.Dataset`` of images;
        - the binarized float32 images with shape ``[num_obs] + DATA_SHAPE``;
        - the labels, still aligned with those images.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)

    # Apply the digit filter to images AND labels so they stay aligned.
    if ONLY_INCLUDE_DIGITS is not None and len(ONLY_INCLUDE_DIGITS) > 0:
        keep = np.isin(labels, ONLY_INCLUDE_DIGITS)
        images = images[keep]
        labels = labels[keep]

    # One shared permutation: shuffling the two arrays independently would
    # break the image/label pairing. Fancy indexing also returns copies, so
    # the caller's arrays are left untouched.
    order = np.random.permutation(images.shape[0])
    images = images[order]
    labels = labels[order]

    num_obs = images.shape[0] if max_obs is None else int(min(images.shape[0], max_obs))
    images = images[:num_obs]
    labels = labels[:num_obs]

    # Normalize to [0, 1], then binarize at 0.5.
    images = images.reshape([num_obs] + list(DATA_SHAPE)).astype('float32') / 255.0
    images = (images >= 0.5).astype('float32')

    # tf.data requires a positive shuffle buffer even for an empty dataset.
    image_dataset = tf.data.Dataset.from_tensor_slices(images) \
        .shuffle(max(num_obs, 1)) \
        .batch(BATCH_SIZE)

    return image_dataset, images, labels


# MNIST

In [8]:
(train_images_raw, train_labels_raw), (test_images_raw, test_labels_raw) = \
    tf.keras.datasets.mnist.load_data()

mnist_train_dataset, _, _ = preprocess_mnist_images_and_labels(
    train_images_raw, train_labels_raw, max_obs=TRAIN_BUF)
mnist_test_dataset, _, _ = preprocess_mnist_images_and_labels(
    test_images_raw, test_labels_raw, max_obs=TEST_BUF)

# Overlap loading of the next training batch with the current step.
mnist_train_dataset = mnist_train_dataset.prefetch(1)

# Drop references to the raw arrays; only the tf.data pipelines are used below.
del train_images_raw, train_labels_raw, test_images_raw, test_labels_raw

In [9]:
# Building the model prints the encoder and decoder summaries shown below.
mnist_vae = VariationalAutoEncoder(DATA_SHAPE, LATENT_DIM)
mnist_optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 13, 13, 32)        320       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 6, 6, 64)          18496     
_________________________________________________________________
flatten (Flatten)            (None, 2304)              0         
_________________________________________________________________
dense (Dense)                (None, 100)               230500    
Total params: 249,316
Trainable params: 249,316
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 1568)              79968     
____________________________

In [10]:
# The recorded run took ~33 s/epoch and was interrupted after 2 epochs; an
# interrupt leaves `mnist_logs` unassigned, so the plotting cell below fails.
mnist_logs = train_vae(mnist_vae, mnist_train_dataset, NUM_EPOCH, mnist_optimizer)

  2%|▉                                            | 2/100 [01:07<54:01, 33.07s/it]

KeyboardInterrupt: 

In [None]:
# One point per completed epoch: the `loss_mean` value logged by `train_vae`.
fig, ax = plt.subplots()
ax.plot([row['loss_mean'] for row in mnist_logs])
ax.set(title='MNIST VAE training loss', xlabel='epoch', ylabel='loss_mean')
plt.show()

In [None]:
# Decode 10 draws from the N(0, I) prior: `probs` holds per-pixel sigmoid
# probabilities, and `fake_data` one boolean Bernoulli sample drawn from them.
latent_vars = mnist_vae.sample_latent_prior(batch_size=10)
fake_data, probs = mnist_vae.generate_data_from_latent(latent_vars=latent_vars)

In [None]:
# Show the first three decoded probability maps side by side. vmin/vmax pin
# the gray scale to [0, 1] so shades are comparable probabilities rather than
# per-image autoscaled contrast.
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for i, ax in enumerate(axes):
    ax.imshow(np.squeeze(probs[i]), cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title('probs[{}]'.format(i))
    ax.axis('off')
fig.suptitle('Decoder pixel probabilities for prior samples')
plt.show()

In [None]:
# Show the first three binary samples side by side. They are cast from bool to
# float and shown on a fixed [0, 1] scale, so an all-off or all-on sample is
# not autoscaled into a misleading mid-gray.
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for i, ax in enumerate(axes):
    sample_img = np.asarray(np.squeeze(fake_data[i]), dtype='float32')
    ax.imshow(sample_img, cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title('fake_data[{}]'.format(i))
    ax.axis('off')
fig.suptitle('Bernoulli samples from the decoder')
plt.show()