In [None]:
import tensorflow as tf
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Dense, Flatten, Conv2D, Input, LeakyReLU, Reshape, Conv2DTranspose
import numpy as np
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from pathlib import Path
import os

In [None]:
# Load MNIST, rescale pixel intensities from [0, 255] to [-1, 1] (the range of the
# decoder's tanh output), and add a trailing channel axis so Conv2D accepts the arrays.
mnist = tf.keras.datasets.mnist
(X, y), (X_test, y_test) = mnist.load_data()
X = (X / 255.0 * 2 - 1).reshape(-1, 28, 28, 1)
X_test = (X_test / 255.0 * 2 - 1).reshape(-1, 28, 28, 1)

In [3]:
BATCH_SIZE = 32
# A buffer as large as the training set gives a uniform shuffle.
SHUFFLE_BUFFER = X.shape[0]

# Training batches are reshuffled every epoch so SGD does not see one fixed ordering.
# The test set keeps its original order so the images picked for plotting are stable.
dataset = (
    tf.data.Dataset.from_tensor_slices(X.astype('float32'))
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
)
test_dataset = tf.data.Dataset.from_tensor_slices(X_test.astype('float32')).batch(BATCH_SIZE)

In [4]:
# Size of the latent code z. The encoder emits 2 * latent_dim units (mean and
# log-variance halves); the decoder consumes a latent_dim-sized vector.
latent_dim = 4

In [None]:
# X / X_test already received their channel axis in the data-loading cell, and the
# tf.data pipelines were built from them before this point. Verify the invariant
# here instead of silently reshaping a second time.
assert X.shape[1:] == (28, 28, 1), X.shape
assert X_test.shape[1:] == (28, 28, 1), X_test.shape

In [61]:
def get_encoder(input_shape, latent_dim):
    """Build the convolutional encoder q(z | x).

    Args:
        input_shape: shape of a single input image, e.g. ``(28, 28, 1)``.
        latent_dim: size of the latent code z.

    Returns:
        A ``Model`` mapping a batch of images to a ``(batch, 2 * latent_dim)``
        tensor. The first half is used as the posterior mean and the second half
        as the log-variance (callers split it with ``tf.split``).
    """
    inputs = Input(shape=input_shape)

    # Three 3x3 convolutions with the default 'valid' padding; each one shrinks
    # the spatial dims by 2.
    x = Conv2D(8, 3, activation=LeakyReLU(alpha=0.2))(inputs)
    x = Conv2D(32, 3, activation=LeakyReLU(alpha=0.2))(x)
    x = Conv2D(64, 3, activation=LeakyReLU(alpha=0.2))(x)

    x = Flatten()(x)

    # Linear head: the mean must be able to go negative and the log-variance below 0
    # (variance < 1). A ReLU here clamped both to >= 0.
    params = Dense(2 * latent_dim)(x)

    return Model(inputs, params)

In [34]:
def get_decoder(input_shape, latent_dim):
    """Build the transposed-convolution decoder p(x | z).

    ``latent_dim`` is accepted for symmetry with ``get_encoder`` but is unused;
    the latent size comes from ``input_shape``.

    Architecture:
        - The latent vector is projected to 22 * 22 = 484 units and reshaped to a
          22x22x1 map.
        - Three 3x3 transposed convolutions follow (stride 1, 'valid' padding).
          Each adds 2 pixels per spatial dim, giving 28x28x1.
        - The final ``tanh`` bounds outputs to (-1, 1).
    """
    z_in = Input(shape=input_shape)

    h = Dense(22 * 22, activation='relu')(z_in)
    h = Reshape((22, 22, 1))(h)

    for n_filters in (128, 64):
        h = Conv2DTranspose(n_filters, 3, activation=LeakyReLU(alpha=0.2))(h)

    image = Conv2DTranspose(1, 3, activation='tanh')(h)
    return Model(z_in, image)

In [107]:
from tensorflow.keras import backend as K

def loss_fn(inputs, outputs, z_mean, z_log_var, kl_weight=0.1):
    """Weighted VAE loss, averaged over the batch.

    Per example:
        reconstruction = sum over axes 1..3 of (outputs - inputs) ** 2
        kl = -0.5 * sum_k (1 + log_var_k - mean_k ** 2 - exp(log_var_k)),
             i.e. KL( N(mean, exp(log_var)) || N(0, I) )

    Args:
        inputs: original images; rank 4, presumably (batch, H, W, C).
        outputs: decoder reconstructions, same shape as ``inputs``.
        z_mean: posterior means, reduced over the last axis.
        z_log_var: posterior log-variances, same shape as ``z_mean``.
        kl_weight: multiplier on the KL term. Defaults to 0.1, the value that
            was previously hard-coded.

    Returns:
        Scalar tensor: batch mean of ``reconstruction + kl_weight * kl``.
    """
    kl = -0.5 * tf.reduce_sum(
        1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1
    )
    reconstruction = tf.reduce_sum(tf.square(outputs - inputs), axis=[1, 2, 3])
    return tf.reduce_mean(reconstruction + kl_weight * kl)

In [108]:
class VAE(tf.keras.Model):
    """Variational autoencoder: convolutional encoder + transposed-conv decoder.

    Uses the notebook-level ``latent_dim`` for the size of z. The encoder emits
    ``2 * latent_dim`` values per example (mean and log-variance). Callers split
    them before sampling z with ``get_input_for_z``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = get_encoder((28, 28, 1), latent_dim)
        self.decoder = get_decoder((latent_dim,), latent_dim)

    def get_input_for_z(self, mean, logvar):
        """Reparameterization trick: z = mean + exp(logvar / 2) * eps, eps ~ N(0, I).

        Drawing z this way keeps it differentiable with respect to ``mean`` and
        ``logvar``, so gradients flow back into the encoder.
        """
        # tf.shape yields the runtime shape. The static mean.shape may have an
        # unknown (None) batch dimension when traced inside a tf.function (e.g.
        # with a relaxed input signature), which tf.random.normal cannot accept.
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

In [109]:
# Build the VAE and the optimizer that trains it.
model = VAE()
optimizer = Adam(learning_rate=1e-4)  # small fixed learning rate

In [110]:
@tf.function
def train_step(batch):
    """Run one optimization step of the VAE on ``batch`` and return its loss.

    Steps:
        1. Encode the batch and split the output into mean / log-variance.
        2. Sample z with ``model.get_input_for_z``.
        3. Decode, then apply one ``optimizer`` update to all of
           ``model.trainable_variables``.

    Returns:
        Scalar loss tensor for this batch. The original returned nothing, so
        callers that ignore the result are unaffected.
    """
    with tf.GradientTape() as tape:
        encoded = model.encoder(batch)
        mean, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
        z = model.get_input_for_z(mean, logvar)
        reconstruction = model.decoder(z)
        loss = loss_fn(batch, reconstruction, mean, logvar)

    # Compute and apply gradients after leaving the tape context, so the
    # gradient computation itself is not recorded on the tape.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

In [None]:
NUM_EPOCHS = 10  # full passes over the training set

for epoch in range(NUM_EPOCHS):
    for batch in dataset:
        train_step(batch)
    print(f"epoch {epoch + 1}/{NUM_EPOCHS} done")

In [114]:
# Take only the first test batch and keep its first 16 images (enough for a 4x4 grid).
for test_batch in test_dataset.take(1):
    test_sample = test_batch[:16]

In [None]:
# Reconstruct the held-out sample: encode, sample z ~ q(z | x), decode.
# Distinct names avoid clobbering the notebook-level label array `y` from the data cell.
test_encoded = model.encoder(test_sample)
test_mean, test_logvar = tf.split(test_encoded, num_or_size_splits=2, axis=1)
test_z = model.get_input_for_z(test_mean, test_logvar)
test_reconstruction = model.decoder(test_z)

n_images = test_reconstruction.shape[0]
n_cols = 4
n_rows = max(1, -(-n_images // n_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if idx < n_images:
        # Inputs were scaled to [-1, 1]. A fixed colour range keeps images comparable
        # instead of auto-stretching each one.
        ax.imshow(test_reconstruction[idx, :, :, 0], cmap='gray', vmin=-1, vmax=1)
fig.suptitle('VAE reconstructions of test images')
plt.show()