In [None]:
import keras
import numpy as np
from keras import backend as k
from keras import optimizers
from keras.layers import *
from keras.models import Model
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.models import load_model
from keras.engine.topology import Layer
from keras.datasets import mnist
from keras.losses import binary_crossentropy

In [None]:
# Number of images per training batch (used to slice x_train and to derive
# the number of steps per epoch).
batchsize = 32
# Width of the latent space: the encoder's mu and log_sigma heads each have
# n_z units.
n_z = 20

In [None]:
def sample_z(args):
    """Draw a latent sample ``z = mu + sigma * eps`` (reparameterization trick).

    Args:
        args: a two-element sequence ``(mu, log_sigma)`` of backend tensors.
            Both are expected to be rank-2, ``(batch, latent_dim)``: axis 0 is
            read dynamically and axis 1 statically. ``log_sigma`` is used as
            the log of the *variance*, i.e. the standard deviation is
            ``exp(log_sigma / 2)``.

    Returns:
        A tensor with the same shape as ``mu`` holding one random sample per
        row, with ``eps`` drawn from a standard normal distribution.
    """
    # Uses the `k` alias this notebook imports explicitly
    # (`from keras import backend as k`); an uppercase `K` is never imported
    # by name and would only exist if a star-import happened to leak it.
    mu, log_sigma = args
    # The batch size is only known at run time, so take it from the dynamic
    # shape; the latent width is fixed and comes from the static shape.
    batch = k.shape(mu)[0]
    dim = k.int_shape(mu)[1]
    eps = k.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    return mu + k.exp(log_sigma / 2) * eps

In [None]:
# Load the MNIST train/test splits through keras.datasets (the first call may
# download the data).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Number of full batches per epoch; any remainder smaller than `batchsize`
# is not counted.
batches = x_train.shape[0] // batchsize

In [None]:
# Length of axis 1 of the training images (presumably 28 for MNIST — confirm);
# the loss multiplies image_size * image_size to get the per-image pixel count.
image_size = x_train.shape[1]

In [None]:
def batch_generator(images=None, batch_size=None):
    """Endlessly yield ``(inputs, targets)`` batches for autoencoder training.

    The same array is used as both input and target because the model is
    trained to reconstruct its own input.

    Args:
        images: array-like of images whose first axis indexes samples; each
            batch is reshaped to ``(n, 28, 28, 1)``. Defaults to the
            module-level ``x_train``.
        batch_size: number of images per batch. Defaults to the module-level
            ``batchsize``.

    Yields:
        ``(batch, batch)`` where ``batch`` has shape
        ``(batch_size, 28, 28, 1)``. Integer pixel data is converted to
        float32 and scaled to [0, 1]; floating-point data is assumed to be
        scaled already and is passed through unchanged.

    Raises:
        ValueError: if ``batch_size`` is not positive or there are fewer
            images than one full batch (the generator could never yield).
    """
    if images is None:
        images = x_train
    if batch_size is None:
        batch_size = batchsize
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Only full batches are produced; a trailing partial batch is dropped.
    n_batches = len(images) // batch_size
    if n_batches == 0:
        raise ValueError("need at least batch_size images to build a batch")

    while True:
        for batch in range(n_batches):
            chunk = np.asarray(images[batch * batch_size:(batch + 1) * batch_size])
            # Add the trailing channel axis for the whole batch at once.
            chunk = chunk.reshape((-1, 28, 28, 1))
            # The decoder ends in a sigmoid and the loss is binary
            # cross-entropy, so targets must lie in [0, 1]; raw 8-bit pixel
            # values (0..255) have to be rescaled.
            if np.issubdtype(chunk.dtype, np.integer):
                chunk = chunk.astype('float32') / 255.
            yield (chunk, chunk)  # the data is also the target

In [None]:
# Encoder: maps a (28, 28, 1) image to [mu, log_sigma, sampled latent vector].
inputs1 = Input(shape = (28, 28, 1))

# Two stride-2 convolutions, each followed by a ReLU. With 'same' padding each
# one should halve the spatial size (28 -> 14 -> 7).
encode = Conv2D(32, (2, 2), strides=2, padding = 'same')(inputs1)

encode = Activation('relu')(encode)

encode = Conv2D(32, (2, 2), strides=2, padding = 'same')(encode)

encode = Activation('relu')(encode)

# Remember the static shape of the conv output before flattening; the decoder
# cell indexes out_shape[1], [2] and [3] to size its first Dense and Reshape.
out_shape = list(k.int_shape(encode))

encode = Flatten()(encode)

encode = Dense(420, activation = 'relu')(encode)

encode = Dropout(0.5)(encode)

# Two linear heads of width n_z. log_sigma is consumed as a log-variance:
# sample_z uses exp(log_sigma / 2) and the KL term in _loss uses exp(log_sigma).
mu = Dense(n_z, activation = 'linear')(encode)

log_sigma = Dense(n_z, activation = 'linear')(encode)

# Random latent sample drawn by sample_z from (mu, log_sigma).
latent_vector = Lambda(sample_z, output_shape=(n_z,))([mu, log_sigma])

# The encoder exposes all three tensors; index 2 is the sampled latent vector.
model_encoder = Model(inputs = inputs1, outputs = [mu, log_sigma, latent_vector])

In [None]:
# Print the encoder's layer-by-layer summary as a sanity check.
model_encoder.summary()

In [None]:
# Decoder: maps an n_z-dimensional latent vector back to a 1-channel image.
inputs2 = Input(shape = (n_z,))

# Expand the latent vector to as many units as the encoder's last conv output
# had (out_shape[1] * out_shape[2] * out_shape[3]) and restore that 3-D shape.
decode = Dense(out_shape[1] * out_shape[2] * out_shape[3], activation = 'relu')(inputs2)

decode = Reshape((out_shape[1], out_shape[2], out_shape[3]))(decode)

# Two stride-2 transposed convolutions, each followed by a ReLU, mirroring the
# encoder's two stride-2 convolutions.
decode = Conv2DTranspose(32, (2, 2), strides=2, padding = 'same')(decode)

decode = Activation('relu')(decode)

decode = Conv2DTranspose(32, (2, 2), strides=2, padding = 'same')(decode)

decode = Activation('relu')(decode)

# Final single-filter layer (default stride) with a sigmoid, so every output
# value lies in (0, 1).
out = Conv2DTranspose(1, (2, 2), activation = 'sigmoid', padding = 'same')(decode)

model_decoder = Model(inputs = inputs2, outputs = out)

In [None]:
# Print the decoder's layer-by-layer summary as a sanity check.
model_decoder.summary()

In [None]:
# Chain encoder and decoder into the end-to-end VAE. The encoder returns
# [mu, log_sigma, latent_vector]; only the sampled latent vector (index 2)
# is passed on to the decoder.
latent_sample = model_encoder(inputs1)[2]
outputs = model_decoder(latent_sample)
model = Model(inputs=inputs1, outputs=outputs)

In [None]:
def _loss(y_true, y_pred):
    """VAE objective: reconstruction cross-entropy plus KL divergence.

    Args:
        y_true: batch of target images.
        y_pred: batch of reconstructed images produced by the model.

    Returns:
        A scalar tensor: the mean over the batch of
        ``reconstruction_loss + kl_loss``.

    Note:
        The KL term does not depend on ``y_true`` / ``y_pred``. It reads the
        module-level encoder tensors ``mu`` and ``log_sigma``, so this loss is
        only meaningful for the model built from that encoder.
    """
    # Uses the `k` alias this notebook imports explicitly
    # (`from keras import backend as k`); an uppercase `K` is never imported
    # by name.
    #
    # flatten() collapses the whole batch to 1-D, so the cross-entropy here is
    # a single value averaged over every pixel in the batch (per the Keras
    # docs, binary_crossentropy averages over the last axis). Multiplying by
    # image_size ** 2 rescales it to a per-image summed cross-entropy,
    # assuming square images of side image_size.
    reconstruction_loss = image_size * image_size * binary_crossentropy(k.flatten(y_true), k.flatten(y_pred))
    # KL divergence between N(mu, exp(log_sigma)) and a standard normal,
    # summed over the latent dimensions; log_sigma is a log-variance here.
    kl_loss = -0.5 * k.sum(1 + log_sigma - k.square(mu) - k.exp(log_sigma), axis=-1)
    vae_loss = k.mean(reconstruction_loss + kl_loss)

    return vae_loss

In [None]:
# Adam optimizer with learning rate 1e-4; beta_1, beta_2 and epsilon are
# spelled out explicitly.
adam_ = optimizers.Adam(lr = 1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

In [None]:
# Compile the end-to-end model with the custom VAE loss (_loss), which also
# reads the encoder's mu / log_sigma tensors from module scope.
model.compile(optimizer = adam_, loss = _loss)

In [None]:
# Train for 10 epochs, drawing `batches` steps per epoch from batch_generator.
# NOTE(review): per the Keras docs, `shuffle` is only honored for
# keras.utils.Sequence inputs; with this plain generator the batches likely
# arrive in the same order every epoch — confirm, and shuffle inside the
# generator if per-epoch shuffling is wanted.
model.fit_generator(batch_generator(), steps_per_epoch = batches, epochs = 10, shuffle = True)