In [None]:
import keras
import matplotlib.pyplot as plt
from keras import backend as K


In [None]:
# We have defined 4 modules:
# Encoder: [input_img] --> [mu, log_var]
# Sampling: [mu, log_var, eps] --> [z]
# Decoder: [z] --> [output_img]
# LossFunctionLayer: [input_img, output_img, mu, log_var] --> [output_img]

In [None]:
from keras.layers import Input,Dense,Conv2D,MaxPooling2D,UpSampling2D,Flatten
from keras import models
shape_z = 2  # dimensionality of the latent space (length of mu, log_var, eps and z)

In [None]:
# 1. The Encoder Network

In [None]:
# Convolutional feature extractor: stride-1 convs around a single stride-2 conv
# (enc_conv2) that halves the spatial resolution, 28x28 -> 14x14.
enc_conv1 = Conv2D(32, 3, activation='relu', padding='same', name='enc_conv1')
enc_conv2 = Conv2D(64, 3, activation='relu', padding='same', strides=(2, 2),
                   name='enc_conv2')
enc_conv3 = Conv2D(64, 3, activation='relu', padding='same', name='enc_conv3')
enc_conv4 = Conv2D(64, 3, activation='relu', padding='same', name='enc_conv4')

# Wire the stack into a graph fed by a 28x28 single-channel image.
input_img = Input(shape=(28, 28, 1), name='input_img')
enc_conv_out1 = enc_conv1(input_img)
enc_conv_out2 = enc_conv2(enc_conv_out1)
enc_conv_out3 = enc_conv3(enc_conv_out2)
enc_conv_out4 = enc_conv4(enc_conv_out3)


In [None]:
# Dense head: flatten the conv features, compress them to 32 units, then apply
# two independent linear projections that produce the mean and log-variance
# vectors consumed by the sampling step.
enc_flat = Flatten(name='enc_flat')
enc_dense = Dense(32, activation='relu', name='enc_dense')
enc_mu = Dense(shape_z, name='mu')
enc_log_var = Dense(shape_z, name='enc_log_var')

enc_flat_out = enc_flat(enc_conv_out4)
enc_dense_out = enc_dense(enc_flat_out)
mu, log_var = enc_mu(enc_dense_out), enc_log_var(enc_dense_out)

In [None]:
# Wrap the encoder graph as a standalone model: image -> [mu, log_var].
encoder = models.Model(
    inputs=input_img,
    outputs=[mu, log_var],
    name='encoder',
)

In [None]:
# Print the encoder's layer-by-layer architecture and parameter counts.
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_img (InputLayer)          [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
enc_conv1 (Conv2D)              (None, 28, 28, 32)   320         input_img[0][0]                  
__________________________________________________________________________________________________
enc_conv2 (Conv2D)              (None, 14, 14, 64)   18496       enc_conv1[0][0]                  
__________________________________________________________________________________________________
enc_conv3 (Conv2D)              (None, 14, 14, 64)   36928       enc_conv2[0][0]                  
____________________________________________________________________________________________

In [None]:
# 2. The Sampling Network

In [None]:
from keras.layers import Lambda, Multiply, Add
from keras import backend as K

In [None]:
# Placeholder inputs for the standalone sampling sub-model, each a batch of
# shape_z-dimensional vectors.
# NOTE(review): this rebinds the globals mu/log_var that the encoder cell
# assigned to its output tensors.
mu, log_var, eps = (Input(shape=(shape_z,), name=input_name)
                    for input_name in ('mu', 'log_var', 'eps'))

In [None]:
# Reparameterization trick: z = mu + exp(0.5 * log_var) * eps. The randomness
# enters only through eps, so mu and log_var stay on a differentiable path.
sigma = Lambda(lambda lv: K.exp(0.5 * lv), name='sigma')(log_var)
V = Multiply(name='v')([sigma, eps])
z = Add(name='z')([mu, V])

In [None]:
# Package the reparameterization graph as a model: [mu, log_var, eps] -> z.
sampling = models.Model(
    inputs=[mu, log_var, eps],
    outputs=z,
    name='sampling',
)

In [None]:
# Print the sampling sub-model's layers (it has no trainable parameters).
sampling.summary()

Model: "sampling"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
log_var (InputLayer)            [(None, 2)]          0                                            
__________________________________________________________________________________________________
sigma (Lambda)                  (None, 2)            0           log_var[0][0]                    
__________________________________________________________________________________________________
eps (InputLayer)                [(None, 2)]          0                                            
__________________________________________________________________________________________________
mu (InputLayer)                 [(None, 2)]          0                                            
___________________________________________________________________________________________

In [None]:
# 3. The Decoder Network
import numpy as np

# Shape of the encoder's last feature map with the batch axis dropped; the
# decoder's Reshape layer restores this layout.
shape_before_flattening = K.int_shape(enc_conv_out4)[1:]
# Total number of features in that map, i.e. the width of the flattened vector.
shape_after_flattening = np.prod(shape_before_flattening)


In [None]:
from keras.layers import Dense, Reshape, Conv2D, Conv2DTranspose
# Decoder front end: widen the latent code back to the size of the encoder's
# last feature map, then restore that feature map's layout.
dec_dense1 = Dense(32, activation='relu', name='dec_dense1')
dec_dense2 = Dense(shape_after_flattening, activation='relu', name='dec_dense2')
dec_reshape = Reshape(shape_before_flattening)

# Latent-code input of the standalone decoder.
# NOTE(review): rebinds the global `z` that previously held the sampling output.
z = Input(shape=(shape_z,), name='z')
dec_dense_out1 = dec_dense1(z)
dec_dense_out2 = dec_dense2(dec_dense_out1)
dec_reshape_out = dec_reshape(dec_dense_out2)

In [None]:
# Decoder back end. The stride-2 transposed conv doubles the spatial resolution
# of the reshaped feature map, and the final conv maps it to one image channel.
dec_conv1 = Conv2DTranspose(32, 3, padding='same',
                            activation='relu',
                            strides=(2, 2),
                            name='dec_conv1')
# Sigmoid output so generated pixels lie in (0, 1). The reconstruction term in
# LossFunctionLayer is binary cross-entropy, which is only meaningful for
# predictions in that range. A ReLU here gives unbounded outputs (and a flat
# zero region) that BCE silently clips, stalling training.
dec_conv2 = Conv2D(1, 3, padding='same',
                   activation='sigmoid',
                   name='gen_img')

dec_conv_out1 = dec_conv1(dec_reshape_out)
gen_img = dec_conv2(dec_conv_out1)

In [None]:
# Wrap the decoder graph as a standalone model: latent code z -> generated image.
decoder = models.Model(
    inputs=z,
    outputs=gen_img,
    name='decoder',
)

In [None]:
# Print the decoder's layer-by-layer architecture and parameter counts.
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z (InputLayer)               [(None, 2)]               0         
_________________________________________________________________
dec_dense1 (Dense)           (None, 32)                96        
_________________________________________________________________
dec_dense2 (Dense)           (None, 12544)             413952    
_________________________________________________________________
reshape (Reshape)            (None, 14, 14, 64)        0         
_________________________________________________________________
dec_conv1 (Conv2DTranspose)  (None, 28, 28, 32)        18464     
_________________________________________________________________
gen_img (Conv2D)             (None, 28, 28, 1)         289       
Total params: 432,801
Trainable params: 432,801
Non-trainable params: 0
_____________________________________________________

In [None]:
class LossFunctionLayer(keras.layers.Layer):
  """Pass-through layer that registers the VAE training loss.

  The layer returns its reconstructed-image input unchanged. Its only effect is
  to call ``add_loss`` with (reconstruction loss + weighted KL term), which is
  why the model can be compiled without an explicit ``loss``.
  """

  # Weight of the KL term relative to the reconstruction term.
  param = 1E-3

  def kl_loss(self, mu, log_var):
    """Return ``param`` times the batch mean of the per-sample KL term.

    Per sample this is the closed-form Gaussian KL against a standard normal,
    -0.5 * (1 + log_var - mu**2 - exp(log_var)), averaged (not summed) over
    the latent dimensions.
    """
    kl_per_sample = -0.5 * K.mean(
        1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
    return self.param * K.mean(kl_per_sample)

  def gen_loss(self, input_img, output_img):
    """Return the mean binary cross-entropy between input and reconstruction."""
    return K.mean(keras.metrics.binary_crossentropy(input_img, output_img))

  def call(self, inputs):
    """Attach the combined loss and return the reconstructed image unchanged.

    Args:
      inputs: sequence ``[input_img, output_img, mu, log_var]``.
    """
    input_img, output_img, mu, log_var = inputs
    self.add_loss(
        self.gen_loss(input_img, output_img) + self.kl_loss(mu, log_var))
    return output_img

In [None]:
# Assemble the end-to-end VAE: image -> encoder -> sampling -> decoder -> loss.
input_img = Input(shape=(28, 28, 1), name='input_img')

# Draw the reparameterization noise eps ~ N(0, I) inside the graph, one row per
# example in the batch. Generating it with a Lambda keeps the image as the
# model's only input, which is what `model.fit(x_train, None)` supplies.
# (Previously the noise went to a misspelled `esp` via `Input(tensor=...)`, so
# the model was silently wired to the stale `eps` placeholder from the sampling
# cell and demanded a second, never-fed input.)
eps = Lambda(lambda t: K.random_normal(shape=(K.shape(t)[0], shape_z)),
             name='eps')(input_img)

mu, log_var = encoder(input_img)
z = sampling([mu, log_var, eps])
output_img = decoder(z)
# LossFunctionLayer returns output_img unchanged and registers the VAE loss.
output_img = LossFunctionLayer(name='loss')([input_img, output_img, mu, log_var])
model = models.Model(inputs=input_img, outputs=output_img)

In [None]:
# Print the assembled VAE, showing the nested encoder/sampling/decoder models.
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_img (InputLayer)          [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder (Functional)            [(None, 2), (None, 2 494244      input_img[0][0]                  
__________________________________________________________________________________________________
eps (InputLayer)                [(None, 2)]          0                                            
__________________________________________________________________________________________________
sampling (Functional)           (None, 2)            0           encoder[0][0]                    
                                                                 encoder[0][1]         

In [None]:
from  keras.datasets import mnist


In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# mnist.load_data() returns uint8 images in 0-255 with shape (N, 28, 28).
# The binary cross-entropy reconstruction loss needs targets in [0, 1], and the
# model input is declared as (28, 28, 1). So rescale to float32 and add a
# trailing channel axis. This is done in the same cell as the load, so
# re-running the cell stays idempotent.
x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., np.newaxis]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# No `loss=` argument: the VAE loss is registered by LossFunctionLayer via add_loss.
model.compile(optimizer='rmsprop')

In [None]:
# Unsupervised training: no targets are passed (y defaults to None) because the
# loss is computed inside the model by LossFunctionLayer from the input images.
history = model.fit(x_train, shuffle=True, epochs=50, batch_size=128)

In [None]:
# Decode a single hand-picked latent point and display the generated digit.
z_sample = np.array([[0.1, 0.2]])  # one latent vector, shape (1, 2)
x_decoded = decoder.predict(z_sample)

# predict returns a batch of shape (1, 28, 28, 1); imshow needs a 2-D array,
# so take the first (only) image and its single channel.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(x_decoded[0, :, :, 0], cmap='gray')
ax.set_title('Decoded image for z = (0.1, 0.2)')
ax.axis('off')
plt.show()