In [None]:
import keras
import numpy
import matplotlib.pyplot as plt
import os
import random
import tensorflow

from keras import backend as K
from keras.datasets import mnist
from keras.engine.topology import Layer
from keras.layers import Activation, Dense, Input, Lambda
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.losses import binary_crossentropy, mean_squared_error
from keras.models import Model
from keras.utils import plot_model
from PIL import Image
from numpy.linalg import inv, det
from sklearn.mixture import GaussianMixture
from scipy.stats import mode

numpy.random.seed(42)

# Network parameters
batch_size = 128
num_epochs = 60
kernel_size = 4
latent_dims = [32]
strides = 2
layer_filters = [32, 64]
here = ''

# mnist dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_train = y_train.flatten()
y_test = y_test.flatten()

def preprocess(data):
    if data.ndim == 3:
        data = numpy.asarray([data]).transpose((1, 2, 3, 0))
    maxima = data.max(axis=tuple(range(1, data.ndim))).reshape((len(data),) + (1,) * (data.ndim - 1))
    return data.astype('float32') / maxima, maxima

image_size = x_train.shape[1]
x_train, train_decode = preprocess(x_train)
input_shape = x_train.shape[1:]
num_channels = x_train.shape[-1]
x_test, test_decode = preprocess(x_test)

x_train = numpy.clip(x_train, 0., 1.)
x_test = numpy.clip(x_test, 0., 1.)

def run_tests(model_name):

    encoder_layers = []
    for filters in layer_filters:
        encoder_layers.append(Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   activation='relu',
                   padding='same'))

    decoder_layers = []
    for filters in [num_channels] + layer_filters[:-1]:
        decoder_layers.append(Conv2DTranspose(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   activation='relu',
                   padding='same'))
    decoder_layers[0].activation = Activation('sigmoid')

    # Encoder
    common_input = Input(shape=input_shape, name='encoder_input')
    x = common_input
    for layer in encoder_layers:
        x = layer(x)

    conv_shape = K.int_shape(x)[1:]
    x = Flatten()(x)
    flat_shape = K.int_shape(x)[1:]

    # Latent Layer
    for latent_dim in latent_dims[:-1]:
        layer = Dense(latent_dim, activation='relu')
        x = layer(x)
    layer.activation = Activation(None)

    def sampling(args):
        '''Reparameterization trick by sampling fr an isotropic unit Gaussian.

        # Arguments:
            args (tensor): mean and log of variance of Q(z|X)

        # Returns:
            z (tensor): sampled latent vector
        '''

        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        # by default, random_normal has mean=0 and std=1.0
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    latent_dim = latent_dims[-1]
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that 'output_shape' isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
    # z = Dense(latent_dim, name='latent_layer')(x)
    # layer.activation=Activation(None)

    encoder = Model(common_input, [z, z_mean, z_log_var], name='encoder')
    encoder_mean = Model(common_input, z_mean, name='encoder_mean')

    # Decoder
    decoder_input = Input(shape=(latent_dim,), name='decoder_input')
    x = decoder_input
    for i, latent_dim in enumerate(latent_dims[-2::-1] + [numpy.prod(encoder_layers[-1].output_shape[1:])]):
        layer = Dense(latent_dim, activation='relu')
        if i == 0:
            layer.activation = Activation(None)
        x = layer(x)
    x = Reshape(conv_shape)(x)

    for layer in decoder_layers[::-1]:
        x = layer(x)
    decoder = Model(decoder_input, x, name='decoder')

    def elbo_loss(yTrue, yPred):
        sample_mean = K.mean(z_mean, 0)
        # large batch size ~> unbiased estimator
        sample_log_var = K.log(K.mean(K.exp(z_log_var), 0))
        kl_loss = K.sum((-z_log_var + K.square(z_mean) + K.exp(z_log_var)) / 2, axis=-1)
        reconstruction_loss = mean_squared_error(K.flatten(yTrue), K.flatten(yPred)) * numpy.prod(x_train.shape[1:])
        return K.mean(reconstruction_loss + kl_loss)

    def elbo_loss_sample(yTrue, yPred):
        sample_mean = K.mean(z_mean, 0)
        # large batch size ~> unbiased estimator
        sample_log_var = K.log(K.mean(K.exp(z_log_var), 0))
        kl_loss = K.sum((-sample_log_var + K.square(sample_mean) + K.exp(sample_log_var)) / 2, axis=-1)
        reconstruction_loss = mean_squared_error(K.flatten(yTrue), K.flatten(yPred)) * numpy.prod(x_train.shape[1:])
        return K.mean(reconstruction_loss) + kl_loss

    # Autoencoder
    if model_name == 'variational':
        autoencoder_output = decoder(encoder(common_input)[0])
        autoencoder = Model(common_input, autoencoder_output, name='autoencoder')
        autoencoder.compile(loss=elbo_loss, optimizer='adam')
        autoencoder.fit(x_train,
                x_train,
                validation_data=(x_test, x_test),
                epochs=num_epochs,
                batch_size=batch_size)
    elif model_name == 'variational_sample':
        autoencoder_output = decoder(encoder(common_input)[0])
        autoencoder = Model(common_input, autoencoder_output, name='autoencoder')
        autoencoder.compile(loss=elbo_loss_sample, optimizer='adam')
        autoencoder.fit(x_train,
                x_train,
                validation_data=(x_test, x_test),
                epochs=num_epochs,
                batch_size=batch_size)
    elif model_name == 'vanilla':
        autoencoder_output = decoder(encoder_mean(common_input))
        autoencoder = Model(common_input, autoencoder_output, name='autoencoder')
        autoencoder.compile(loss='mean_squared_error', optimizer='adam')
        autoencoder.fit(x_train,
                x_train,
                validation_data=(x_test, x_test),
                epochs=num_epochs,
                batch_size=batch_size)
    elif model_name == 'double':
        autoencoder_output = decoder(encoder_mean(common_input))
        autoencoder = Model(common_input, autoencoder_output, name='autoencoder')
        autoencoder.compile(loss='mean_squared_error', optimizer='adam')

        autodecoder_output = encoder_mean(decoder(decoder_input))
        autodecoder = Model(decoder_input, autodecoder_output, name='autodecoder')
        autodecoder.compile(loss='mean_squared_error', optimizer='adam')
        n = 10
        for _ in range(n):
            train = numpy.random.normal(size=(len(x_train), latent_dims[-1]))
            test = numpy.random.normal(size=(len(x_test), latent_dims[-1]))
            autodecoder.fit(train,
                    train,
                    validation_data=(test, test),
                    epochs=num_epochs // (2 * n),
                    batch_size=batch_size)

            autoencoder.fit(x_train,
                    x_train,
                    validation_data=(x_test, x_test),
                    epochs=num_epochs // (2 * n),
                    batch_size=batch_size)

    # model_name reconstruction
    autoencoder_output = decoder(encoder_mean(common_input))
    autoencoder_final = Model(common_input, autoencoder_output, name='autoencoder_final')
    autoencoder_final.compile(loss='mean_squared_error', optimizer='adam')
    x_decoded = autoencoder_final.predict(x_test)
    
    n = 8
    topn = numpy.argsort(((x_decoded - x_test) ** 2).mean((1, 2, 3)))[-1:-n:-1]
    i1 = x_test[topn]
    i2 = x_decoded[topn]

    if x_test.shape[-1] < 3:
        i1 = numpy.tile(i1, (1,) * (i1.ndim - 1) + (3,))
        i2 = numpy.tile(i2, (1,) * (i2.ndim - 1) + (3,))

    i1 = numpy.concatenate(i1, axis=1)
    i2 = numpy.concatenate(i2, axis=1)

    reconstruction = numpy.concatenate((i1, i2), axis=0)
    
    model_dir = os.path.join(here, model_name)
    os.makedirs(model_dir, exist_ok=True)
    if latent_dims[-1] == 2:
        filename = os.path.join(model_dir, 'digits_over_latent.png')
        # display a 30x30 2D manifold of digits
        n = 30
        digit_size = x_test.shape[1]
        figure = numpy.zeros((digit_size * n, digit_size * n))
        # linearly spaced coordinates corresponding to the 2D plot
        # of digit classes in the latent space
        grid_x = numpy.linspace(-4, 4, n)
        grid_y = numpy.linspace(-4, 4, n)[::-1]

        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z_sample = numpy.array([[xi, yi]])
                x_decoded = decoder.predict(z_sample)
                digit = x_decoded[0].reshape(x_train.shape[1:])
                if digit.ndim == 3:
                    digit = digit.mean(2)
                figure[i * digit_size: (i + 1) * digit_size,
                       j * digit_size: (j + 1) * digit_size] = digit

        plt.clf()
        plt.tight_layout()
        plt.figure(figsize=(8, 8))
        start_range = digit_size // 2
        end_range = n * digit_size + start_range + 1
        pixel_range = numpy.arange(start_range, end_range, digit_size)
        sample_range_x = numpy.round(grid_x, 1)
        sample_range_y = numpy.round(grid_y, 1)
        plt.xticks(pixel_range, sample_range_x)
        plt.yticks(pixel_range, sample_range_y)
        plt.xlabel('z[0]')
        plt.ylabel('z[1]')
        plt.imshow(figure, cmap='Greys_r')
        plt.savefig(filename)    
    else:
        filename = None
        
    latent_representation_train = encoder_mean.predict(x_train)
    labels = frozenset(y_train)
    means = []
    precisions = []
    dets = []
    for i in labels:
        data = latent_representation_train[y_train == i]
        mean = data.mean(0)
        means.append(mean)
        precision = inv(numpy.cov(data.T))
        precisions.append(precision)
        dets.append(det(precision))
    means = numpy.asarray(means)
    precisions = numpy.asarray(precisions)
    dets = numpy.asarray(dets)
    
    latent_representation_test = encoder_mean.predict(x_test)
    centered = latent_representation_test - means[:, numpy.newaxis]
    a = numpy.asarray([numpy.tensordot(centered0, precision0, axes=([1], [0])) for centered0, precision0 in zip(centered, precisions)])
    exponent = -(a ** 2).sum(2) / 2
    k_means_predictions = numpy.argmin((centered ** 2).sum(2), axis=0)
    gmm_predictions = numpy.argmax(exponent + 0.5 * dets[:, numpy.newaxis], axis=0)
    plot_model(encoder_mean, to_file=os.path.join(model_dir, 'encoder_{0}.png'.format(str(latent_dims[-1]))), show_shapes=True)
    return reconstruction, filename, (k_means_predictions == y_test).mean() * 100, (gmm_predictions == y_test).mean() * 100

imgs = []
models = ['variational_sample']
generative_files = []
k_means = []
gmm = []
for model in models:
    reconstruction, generative, k_means_correct, gmm_correct = run_tests(model)
    generative_files.append(generative)
    k_means.append(k_means_correct)
    gmm.append(gmm_correct)
    imgs.append(reconstruction)
    
imgs = numpy.concatenate(imgs, axis=0)

Train on 60000 samples, validate on 10000 samples
Epoch 1/60
Epoch 2/60

In [None]:
plt.clf()
plt.rcParams['figure.figsize'] = [8, 6]
plt.figure(dpi=800)
plt.title(', '.join(models))
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.axis('off')
plt.savefig('reconstructed_{0}.png'.format(latent_dims[-1]))

In [5]:
k_means

[61.71, 55.269999999999996, 17.130000000000003, 49.78]

In [6]:
gmm

[60.099999999999994, 55.82, 16.03, 49.45]

In [3]:
gmm

[61.67, 54.669999999999995, 11.360000000000001, 9.8]

In [3]:
gmm

[39.47]

In [None]:
gmm