# In [None]:
# Variational autoencoder (VAE) on MNIST,
# trained with a custom loss: reconstruction error (MSE) + KL divergence.

# https://towardsdatascience.com/variational-autoencoders-as-generative-models-with-keras-e0c79415a7eb

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, MaxPooling2D, Conv2D, Flatten, \
    Dense, Lambda, Reshape, UpSampling2D, Conv2DTranspose
from tensorflow import shape, exp, reduce_mean, square
from tensorflow.keras.backend import random_normal
from tensorflow.keras import Model, losses


# Suppress every Python warning so the notebook output stays uncluttered.
# NOTE(review): a blanket 'ignore' also hides genuinely useful warnings;
# consider filtering specific categories/modules instead.
import warnings
warnings.filterwarnings('ignore')
# IPython/Jupyter magic: render matplotlib figures inline in the notebook.
# This line is not valid syntax in a plain .py script.
%matplotlib inline


# Download MNIST: train/test images with their digit labels.
(trainX, trainy), (testX, testy) = mnist.load_data()

print('Training data shapes: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Testing data shapes: X=%s, y=%s' % (testX.shape, testy.shape))

# Optional preview of a few random training digits with their labels.
# Off by default; replaces a dead triple-quoted string that held this code.
SHOW_SAMPLES = False
if SHOW_SAMPLES:
    for j in range(5):
        # Sample from the whole training set, not just the first 10000 images.
        i = np.random.randint(0, len(trainX))
        # One row of five panels (the old 3-digit spec 55x meant a 5x5 grid).
        plt.subplot(1, 5, j + 1)
        plt.imshow(trainX[i], cmap='gray')
        plt.title(trainy[i])
    plt.show()


# Preprocessing: scale pixel values from [0, 255] to [0, 1].
train_data = trainX.astype('float32') / 255
test_data = testX.astype('float32') / 255
# Append a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), as Conv2D
# expects. expand_dims works for any N instead of hard-coding 60000/10000.
train_data = np.expand_dims(train_data, axis=-1)
test_data = np.expand_dims(test_data, axis=-1)
print(train_data.shape, test_data.shape)


# Encoder: three (Conv2D -> 2x2 MaxPooling) stages with the default 'valid'
# padding shrink the 28x28x1 input spatially 28 -> 24 -> 12 -> 10 -> 5 -> 3 -> 1,
# leaving 1x1x32 maps that are flattened and projected to 16 features.
input_data = Input(shape=(28, 28, 1))
encoder = input_data
for _filters, _kernel in ((64, (5, 5)), (64, (3, 3)), (32, (3, 3))):
    encoder = Conv2D(_filters, _kernel, activation='relu')(encoder)
    encoder = MaxPooling2D((2, 2))(encoder)
encoder = Flatten()(encoder)
encoder = Dense(16)(encoder)

# encoder latent features: reparameterized sampling
def sample_latent_features(distribution):
    """Draw one latent sample per row using the reparameterization trick.

    Args:
        distribution: pair ``(mean, log_variance)`` of 2-D tensors shaped
            ``(batch, latent_dim)``. The second element is a *log*-variance
            even though callers name it ``distribution_variance``.

    Returns:
        ``mean + exp(0.5 * log_variance) * eps`` with ``eps`` drawn from a
        standard normal of the same shape. The noise is isolated in ``eps``,
        so gradients flow through ``mean`` and ``log_variance``.
    """
    mean, log_variance = distribution
    dims = shape(log_variance)
    eps = random_normal(shape=(dims[0], dims[1]))
    std = exp(0.5 * log_variance)
    return mean + std * eps

# Two parallel heads on the 16-d encoder features parameterize a 2-d diagonal
# Gaussian: its mean and its log-variance (exponentiated during sampling).
distribution_mean = Dense(2, name='mean')(encoder)
distribution_variance = Dense(2, name='log_variance')(encoder)
sampling_layer = Lambda(sample_latent_features)
latent_encoding = sampling_layer([distribution_mean, distribution_variance])

# Encoder model: image -> one stochastic latent sample.
encoder_model = Model(inputs=input_data, outputs=latent_encoding)
encoder_model.summary()


################
# Decoder: maps a 2-d latent vector back to a 28x28x1 image.
# With the default stride 1 / 'valid' padding, each 3x3 Conv2DTranspose grows
# the map by 2 and the final 5x5 one by 4:
#   1x1 -> 3x3 -> 5x5 -> (up) 10x10 -> 12x12 -> (up) 24x24 -> 28x28
# `shape` must be a tuple: (2) is just the int 2, (2,) is the 1-tuple.
decoder_input = Input(shape=(2,))
decoder = Dense(64)(decoder_input)
decoder = Reshape((1, 1, 64))(decoder)
decoder = Conv2DTranspose(64, (3, 3), activation='relu')(decoder)

decoder = Conv2DTranspose(64, (3, 3), activation='relu')(decoder)
decoder = UpSampling2D((2, 2))(decoder)

decoder = Conv2DTranspose(64, (3, 3), activation='relu')(decoder)
decoder = UpSampling2D((2, 2))(decoder)

# NOTE(review): targets are scaled to [0, 1]. 'sigmoid' would bound the output
# to that range; 'relu' (kept to preserve the original model) bounds it only
# from below.
decoder_output = Conv2DTranspose(1, (5, 5), activation='relu')(decoder)

decoder_model = Model(decoder_input, decoder_output)
decoder_model.summary()



# Assemble the end-to-end VAE: image -> encoder sample -> decoder reconstruction.
encoded = encoder_model(input_data)
decoded = decoder_model(encoded)
autoencoder = Model(inputs=input_data, outputs=decoded)
autoencoder.summary()


# custom loss (Reconstruction Loss + KL loss)
def get_loss(distribution_mean, distribution_variance):
    """Build the VAE training loss: reconstruction error + KL divergence.

    Args:
        distribution_mean: encoder tensor holding the posterior means.
        distribution_variance: encoder tensor holding the posterior
            *log*-variances (the 'log_variance' head), despite the name.

    Returns:
        A Keras-compatible ``total_loss(y_true, y_pred)`` closure.
    """
    def get_reconstruction_loss(y_true, y_pred):
        # losses.mse averages over the last axis and reduce_mean over the rest.
        # Scaling by 28*28 turns the per-pixel mean into a per-image sum, so it
        # is not dwarfed by the KL term.
        reconstruction_loss = losses.mse(y_true, y_pred)
        reconstruction_loss_batch = reduce_mean(reconstruction_loss)
        return reconstruction_loss_batch*28*28

    def get_kl_loss(distribution_mean, distribution_variance):
        # Per latent dimension:
        #   KL(N(mu, sigma^2) || N(0, 1))
        #     = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2)
        # The bracketed term is always <= 0 (since e^x >= 1 + x). Without the
        # -0.5 factor, minimizing the loss would *maximize* the divergence.
        kl_loss = 1 + distribution_variance - square(distribution_mean) - exp(distribution_variance)
        # NOTE(review): the textbook ELBO sums over latent dims. Averaging
        # (kept here) weights the KL term by 1/latent_dim.
        kl_loss_batch = reduce_mean(kl_loss)
        return -0.5 * kl_loss_batch

    def total_loss(y_true, y_pred):
        reconstruction_loss_batch = get_reconstruction_loss(y_true, y_pred)
        kl_loss_batch = get_kl_loss(distribution_mean, distribution_variance)
        return reconstruction_loss_batch + kl_loss_batch

    return total_loss


# Compile. The loss closure captures the encoder's mean / log-variance tensors.
# NOTE(review): closing over intermediate symbolic Keras tensors in a loss
# function is reported to fail on newer TF 2.x releases ("symbolic Keras
# input/output" errors). If so, move the KL term into autoencoder.add_loss(...)
# or a custom train_step. Confirm against the installed TF version.
autoencoder.compile(
    optimizer="adam",
    loss=get_loss(distribution_mean, distribution_variance),
)

# Train. An autoencoder reconstructs its own input, so x and y are the same array.
autoencoder.fit(
    x=train_data,
    y=train_data,
    batch_size=64,
    epochs=20,
    validation_data=(test_data, test_data),
)