<a href="https://colab.research.google.com/github/PeterPirog/tf-autoencoders/blob/main/01_mnist_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#VAE CODER
import tensorflow.keras as keras

from keras import layers
from keras import backend as K 
from keras.models import Model
import numpy as np

img_shape=(28,28,1)
batch_size=16
latent_dim =2

input_img=keras.Input(shape=img_shape)

x=layers.Conv2D(32,3,padding='same',activation='relu')(input_img)
x=layers.Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)
x=layers.Conv2D(64,3,padding='same',activation='relu')(x)
x=layers.Conv2D(64,3,padding='same',activation='relu')(x)

shape_before_flattening=K.int_shape(x)

x=layers.Flatten()(x)
x=layers.Dense(32,activation='relu')(x)

z_mean=layers.Dense(latent_dim)(x)
z_log_var=layers.Dense(latent_dim)(x)



In [2]:

# Sampling layer
def sampling(args):
  z_mean, z_log_var=args
  epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0,stddev=1.)
  return z_mean + K.exp(z_log_var)*epsilon

z=layers.Lambda(sampling)([z_mean, z_log_var])

In [3]:
# Layer to map latent vector to image
decoder_input =layers.Input(K.int_shape(z)[1:])

x=layers.Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decoder_input)
x=layers.Reshape(shape_before_flattening[1:])(x)
x=layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=(2,2))(x)
x=layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)

decoder = Model(decoder_input,x)
z_decoded=decoder(z)
decoder.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_3 (Dense)             (None, 12544)             37632     
                                                                 
 reshape (Reshape)           (None, 14, 14, 64)        0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 28, 28, 32)       18464     
 nspose)                                                         
                                                                 
 conv2d_4 (Conv2D)           (None, 28, 28, 1)         289       
                                                                 
Total params: 56,385
Trainable params: 56,385
Non-trainable params: 0
_________________________________________________________

In [5]:
class CustomVariationalLayer(keras.layers.Layer):
  def vae_loss(self,x,z_decoded):
    x=K.flatten(x)
    z_decoded=K.flatten(z_decoded)
    xent_loss=keras.metrics.binary_crossentropy(x,z_decoded)
    kl_loss=-5e-4*K.mean(1+z_log_var-K.square(z_mean)-K.exp(z_log_var),axis=-1)
    return K.mean(xent_loss+kl_loss)
  
  def call(self, inputs):
    x=inputs[0]
    z_decoded=inputs[1]
    loss=self.vae_loss(x,z_decoded)
    self.add_loss(loss, inputs=inputs)

y=CustomVariationalLayer()([input_img,z_decoded])

In [7]:
from keras.datasets import mnist

vae=Model(input_img,y)
vae.compile(optimizer='rmsprop',loss=None)
vae.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 28, 28, 32)   320         ['input_1[0][0]']                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 14, 14, 64)   18496       ['conv2d[0][0]']                 
                                                                                                  
 conv2d_2 (Conv2D)              (None, 14, 14, 64)   36928       ['conv2d_1[0][0]']               
                                                                                            

In [11]:
(x_train,_),(x_test,_)=mnist.load_data()
x_train=x_train.astype('float32')/255
x_train=x_train.reshape(x_train.shape+(1,))

x_test=x_test.astype('float32')/255
x_test=x_test.reshape(x_test.shape+(1,))

In [13]:
vae.fit(x=x_train,y=None,
        shuffle=True,
        epochs=10,
        batch_size=batch_size,
        validation_data=(x_test, None))

Epoch 1/10


TypeError: ignored