In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.

X_train = X_train.reshape(60000, 28*28)
X_test = X_test.reshape(10000, 28*28)

X_train.shape, X_test.shape

((60000, 784), (10000, 784))

In [4]:
from keras.layers import Input, Dense

input_img = Input(shape=(784,))

encoded = Dense(256, activation='elu')(input_img)
encoded = Dense(128, activation='elu')(encoded)

mean = Dense(2, name='mean')(encoded)
log_var = Dense(2, name='var')(encoded)

In [5]:
from keras import backend as K
from keras.layers import Lambda

def sampling(args):
  mean, log_var = args
  epsilon = K.random_normal(shape=(100,2), mean=0., stddev=1.0)
  return mean + K.exp(log_var)*epsilon

z = Lambda(sampling, output_shape=(2,))([mean, log_var])

In [6]:
from keras.models import Model

encoder = Model(input_img, mean)
encoder.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense (Dense)                (None, 256)               200960    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
mean (Dense)                 (None, 2)                 258       
Total params: 234,114
Trainable params: 234,114
Non-trainable params: 0
_________________________________________________________________


In [8]:
decoder_1 = Dense(128, activation='elu')
decoder_2 = Dense(256, activation='elu')
decoder_3 = Dense(784, activation='sigmoid')

z_sample = decoder_1(z)
z_sample = decoder_2(z_sample)
z_sample = decoder_3(z_sample)
z_sample.shape

TensorShape([100, 784])

In [11]:
z_sampe = Dense(128,activation='elu')(z)
z_sampe = Dense(256, activation='elu')(z_sampe)
z_sampe = Dense(784, activation='sigmoid')(z_sampe)
z_sampe.shape

TensorShape([100, 784])

In [13]:
decoder_input = Input(shape = (2,))

y_gen = decoder_1(decoder_input)
y_gen = decoder_2(y_gen)
y_gen = decoder_3(y_gen)

generator = Model(decoder_input, y_gen)
generator.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_5 (Dense)              multiple                  384       
_________________________________________________________________
dense_6 (Dense)              multiple                  33024     
_________________________________________________________________
dense_7 (Dense)              multiple                  201488    
Total params: 234,896
Trainable params: 234,896
Non-trainable params: 0
_________________________________________________________________


In [14]:
vae = Model(input_img, z_sample)
vae.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 784)]        0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 256)          200960      input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          32896       dense[0][0]                      
__________________________________________________________________________________________________
mean (Dense)                    (None, 2)            258         dense_1[0][0]                    
____________________________________________________________________________________________

In [15]:
from keras import objectives

reconstruction_loss = objectives.binary_crossentropy(input_img, z_sample)
kl_loss = 0.0005 * K.mean(K.square(mean) + K.exp(log_var) - log_var - 1, axis = -1)
vae_loss = reconstruction_loss + kl_loss

In [16]:
vae.add_loss(vae_loss)

In [17]:
vae.compile(optimizer = 'adam')

In [None]:
%%time
vae.fit(X_train, shuffle = True,  epochs = 300, batch_size = 100, validation_data = (X_test, None))

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78

In [None]:
X_test_latent= encoder.predict(X_test, batch_size=100)

plt.figure(figsize=(12,10))
plt.scatter(X_test_latent[:,0], X_test_latent[:,1], c=y_test)
plr.colorbar()
plt.show()

In [None]:
from scipy.stats import norm

n=20
digit_size=28
figure = np.zeros((digit_size*n, digit_size*n))

grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
  for j, xi in enumerate(grid_y):
    z_sample = np.array([[xi, yi]])
    x_decoded = generator.predict(z_sample)
    digit = x_decoded[0].reshape(digit_size, digit_size)
    figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize = (10, 10))
plt.imshow(figure, cmap = 'Greys_r')
plt.show()