In [193]:
import numpy as np
from datetime import datetime
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Flatten, Reshape, Dense, Layer
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras import Input
import tensorflow.keras as keras
from matplotlib import pyplot as plt

In [194]:
(xtrain, ytrain), (xtest, ytest) = mnist.load_data()

### Encoder

In [195]:
def sampling(args):
    print("here")
    z_mean, z_logsig = args
    epsilon = K.random_normal(shape=z_mean.shape[1:])
    return z_mean + K.exp(z_logsig) * epsilon

In [196]:
input_shape = xtrain[0].shape
hidden_dim = 10

image = Input(shape=input_shape + (1,), name="image")
en = Conv2D(filters=16, kernel_size=5, padding="valid", activation="relu", name="en_conv1")(image)
en = Conv2D(filters=16, kernel_size=3, padding="valid", activation="relu", name="en_conv2")(en)
shape_before_flattening = en.shape[1:]
en = Flatten(name="en_flatten")(en)
en = Dense(units=32, activation="sigmoid", name="en_dense")(en)
z_mean = Dense(units=hidden_dim, activation="linear", name="z_mean")(en)
z_logsig = Dense(units=hidden_dim, activation="linear", name="z_logsig")(en)
encoded_image = sampling([z_mean, z_logsig])
encoder = Model(image, [encoded_image, z_mean, z_logsig], name="encoder")
encoder.summary()

here
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
en_conv1 (Conv2D)               (None, 24, 24, 16)   416         image[0][0]                      
__________________________________________________________________________________________________
en_conv2 (Conv2D)               (None, 22, 22, 16)   2320        en_conv1[0][0]                   
__________________________________________________________________________________________________
en_flatten (Flatten)            (None, 7744)         0           en_conv2[0][0]                   
_______________________________________________________________________________________

In [197]:
image_code = Input(encoded_image.shape[1:])
de = Dense(units=np.prod(shape_before_flattening), activation="sigmoid", name="de_dense")(image_code)
de = Reshape(target_shape=shape_before_flattening, name="de_flat")(de)
de = Conv2DTranspose(filters=16, kernel_size=3, padding="valid", activation="relu", name="de_conv1")(de)
decoded_image = Conv2DTranspose(filters=1, kernel_size=5, padding="valid", activation="sigmoid", name="de_conv2")(de)
decoder = Model(image_code, decoded_image, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_43 (InputLayer)        [(None, 10)]              0         
_________________________________________________________________
de_dense (Dense)             (None, 7744)              85184     
_________________________________________________________________
de_flat (Reshape)            (None, 22, 22, 16)        0         
_________________________________________________________________
de_conv1 (Conv2DTranspose)   (None, 24, 24, 16)        2320      
_________________________________________________________________
de_conv2 (Conv2DTranspose)   (None, 28, 28, 1)         401       
Total params: 87,905
Trainable params: 87,905
Non-trainable params: 0
_________________________________________________________________


In [198]:
class CustomVariationalLayer(Layer):
    def vae_loss(self, image, decoded_image, code_moments):
        code_log_var, code_mean = code_moments
        image = K.flatten(image)
        decoded_image = K.flatten(decoded_image)
        xent_loss = keras.metrics.binary_crossentropy(image, decoded_image)
        kl_loss = -5e-4 * K.mean(
            1 + code_log_var - K.square(code_mean) - K.exp(code_log_var)
        , axis=-1)
        return K.mean(xent_loss + kl_loss)
    
    def call(self, inputs):
        image = inputs[0]
        code_moments = inputs[1][1:]
        decoded_image = inputs[2]
        loss = self.vae_loss(image, decoded_image, code_moments)
        self.add_loss(loss, inputs=inputs)
        return image

In [199]:
loss = CustomVariationalLayer(name="loss")([image, encoder(image), decoder(encoder(image))])
vae = Model(image, loss, name="vae")
vae.compile(optimizer="rmsprop", loss=None)
vae.summary()

W1027 19:38:54.966152 4657624512 training_utils.py:1348] Output loss missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to loss.


Model: "vae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder (Model)                 [(None, 10), (None,  251236      image[0][0]                      
                                                                 image[0][0]                      
__________________________________________________________________________________________________
decoder (Model)                 (None, 28, 28, 1)    87905       encoder[2][0]                    
                                                                 encoder[2][1]                    
                                                                 encoder[2][2]                  

In [200]:
xtrain = xtrain.astype('float32') / 255.
xtrain = xtrain.reshape(xtrain.shape + (1,)) 
xtest = xtest.astype('float32') / 255.
xtest = xtest.reshape(xtest.shape + (1,))

In [201]:
logdir = "tflogs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
callbacks = [
 keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1, embeddings_freq=1, update_freq='batch')
]

In [202]:
vae.fit(x=xtrain, y=None, shuffle=True, epochs=10, batch_size=32, validation_data=(xtest, None), callbacks=callbacks)

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10


KeyboardInterrupt: 

### Sample from latent space

In [None]:
def sample_from_decoder(num_samples):
    sample_codes = np.array([nu * np.random.normal(0, 1, size=hidden_dim) for nu in np.linspace(-10,10,num_samples)])
    return decoder(sample_codes)

num_rows = 10
f, axarr = plt.subplots(num_rows, num_samples, figsize=(25,25))
for row in range(num_rows):
    for col, dec in enumerate(sample_from_decoder(num_samples)):
        axarr[row][col].imshow(dec.numpy()[:,:,0])

In [None]:
plt.imshow(xtrain[0][:,:,0])

In [None]:
code_5 = encoder(np.reshape(xtrain[0], (1, 28, 28, 1)))
decoded_5 = decoder(code_5)
plt.imshow(decoded_5[0].numpy()[:,:,0])

In [None]:
f, axarr = plt.subplots(1,10, figsize=(20,10))
code_5 = encoder(np.reshape(xtrain[0], (1, 28, 28, 1)))
code_5 = code_5[0].numpy()[0,:]
for i in range(10):
    perturbed_code_5 = code_5 + np.random.normal(scale=1.5, size=hidden_dim)
    axarr[i].imshow(decoder(np.reshape(perturbed_code_5, (1,) + perturbed_code_5.shape)).numpy()[0,:,:,0])