In [2]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Activation, Add, Conv2D, Dense, Layer, Dropout, Conv2DTranspose, LeakyReLU, Reshape, Flatten, GlobalMaxPool2D
from tensorflow.keras.models import Model

# Variational Autoencoder (VAE)

In [2]:
# Regularisation hyperparameters shared by the layer-building helpers.
l2_reg = 1e-3    # coefficient handed to tf.keras.regularizers.l2 for kernels and biases
droprate = 1e-1  # rate handed to every Dropout layer

In [3]:
class Sampling(Layer):
    """Reparameterisation layer that samples a latent vector.

    Given ``[z_mean, z_log_var]`` it returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is drawn
    from a standard normal distribution.
    """

    def call(self, inputs):
        """Sample ``z`` from the distribution described by ``inputs``.

        Args:
            inputs: a pair ``(z_mean, z_log_var)`` of tensors with the same
                shape.

        Returns:
            A tensor with the shape and dtype of ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Draw the noise with the full dynamic shape and the dtype of z_mean,
        # so the layer is not restricted to rank-2 float32 inputs.
        epsilon = tf.random.normal(tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    
def conv_block(x, channels, kernel_size = 3, padding = 'same'):
    """Apply a stride-1 Conv2D followed by a LeakyReLU.

    Args:
        x: input tensor.
        channels: number of convolution filters.
        kernel_size: convolution kernel size.
        padding: padding mode forwarded to Conv2D.

    Returns:
        The activated output tensor.

    Kernel and bias both carry an L2 penalty weighted by the module-level
    ``l2_reg``.
    """
    conv = Conv2D(
        filters=channels,
        kernel_size=kernel_size,
        padding=padding,
        kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
        bias_regularizer=tf.keras.regularizers.l2(l2_reg),
    )
    return LeakyReLU()(conv(x))

def res_block(x, channels, kernel_size = 3):
    """Residual block: ``x + LeakyReLU(Conv2D(x))``.

    Args:
        x: input tensor; it is added to the convolution branch, so its
            channel count has to be compatible with ``channels``.
        channels: number of convolution filters.
        kernel_size: convolution kernel size.

    Returns:
        The sum of the input and the activated convolution output.

    Kernel and bias both carry an L2 penalty weighted by the module-level
    ``l2_reg``.
    """
    shortcut = x
    conv = Conv2D(
        filters=channels,
        kernel_size=kernel_size,
        padding='same',
        kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
        bias_regularizer=tf.keras.regularizers.l2(l2_reg),
    )
    residual = LeakyReLU()(conv(shortcut))
    return Add()([shortcut, residual])

def downsampling_conv_block(x, channels, kernel_size = 4):
    """Apply a stride-2 Conv2D ('same' padding) followed by a LeakyReLU.

    Args:
        x: input tensor.
        channels: number of convolution filters.
        kernel_size: convolution kernel size.

    Returns:
        The activated, spatially downsampled tensor.

    Kernel and bias both carry an L2 penalty weighted by the module-level
    ``l2_reg``.
    """
    strided_conv = Conv2D(
        filters=channels,
        kernel_size=kernel_size,
        strides=(2, 2),
        padding="same",
        kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
        bias_regularizer=tf.keras.regularizers.l2(l2_reg),
    )
    return LeakyReLU()(strided_conv(x))

def upsampling_conv_block(x, channels, kernel_size = 3):
    """Apply a stride-2 Conv2DTranspose ('same' padding) and a LeakyReLU.

    Args:
        x: input tensor.
        channels: number of transposed-convolution filters.
        kernel_size: transposed-convolution kernel size.

    Returns:
        The activated, spatially upsampled tensor.

    Kernel and bias both carry an L2 penalty weighted by the module-level
    ``l2_reg``.
    """
    transposed_conv = Conv2DTranspose(
        filters=channels,
        kernel_size=kernel_size,
        strides=2,
        padding="same",
        kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
        bias_regularizer=tf.keras.regularizers.l2(l2_reg),
    )
    return LeakyReLU()(transposed_conv(x))

def create_encoder(latent_dim, num_layer, channel_multiplier, input_shape=None):
    """Build the convolutional VAE encoder.

    The network is a stem (conv block + residual block + dropout) followed by
    ``num_layer`` stages of stride-2 downsampling, residual block and dropout.
    The channel count starts at ``channel_multiplier`` and doubles at every
    stage. The final feature map is global-max-pooled and projected by two
    Dense heads to ``z_mean`` and ``z_log_var``, from which ``z`` is sampled.

    Intermediate tensor shapes and the model summary are printed.

    Args:
        latent_dim: size of the latent vector.
        num_layer: number of downsampling stages.
        channel_multiplier: number of channels in the stem.
        input_shape: shape of one input image (without the batch axis).
            When ``None`` the module-level ``image_shape`` is used, which
            keeps the original three-argument call working.

    Returns:
        A tuple ``(model, last_conv_shape)`` where ``model`` maps an image to
        ``[z_mean, z_log_var, z]`` and ``last_conv_shape`` is the static shape
        (including the batch axis) of the feature map before pooling.
    """
    if input_shape is None:
        # Backward-compatible fallback to the notebook-level global.
        input_shape = image_shape

    encoder_input = Input(shape=input_shape, name='image')
    channels = channel_multiplier
    x = conv_block(encoder_input, channels, kernel_size = 4)
    x = res_block(x, channels)
    x = Dropout(rate=droprate)(x)

    print(K.int_shape(x))
    for _ in range(num_layer):
        channels *= 2
        x = downsampling_conv_block(x, channels)
        print(K.int_shape(x))
        x = res_block(x, channels)
        x = Dropout(rate=droprate)(x)

    # Remember the pre-pooling shape so a decoder can mirror it.
    last_conv_shape = K.int_shape(x)
    x = GlobalMaxPool2D()(x)

    z_mean = Dense(latent_dim, name='z_mean',
                   kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                   bias_regularizer=tf.keras.regularizers.l2(l2_reg))(x)
    z_log_var = Dense(latent_dim, name='z_log_var',
                      kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                      bias_regularizer=tf.keras.regularizers.l2(l2_reg))(x)
    z = Sampling()([z_mean, z_log_var])

    model = Model(encoder_input, [z_mean, z_log_var, z], name='encoder')
    model.summary()
    return model, last_conv_shape

def create_decoder(latent_dim, first_conv_shape, num_layer, out_channels=3):
    """Build the convolutional VAE decoder.

    A Dense layer expands the latent vector to a feature map of shape
    ``first_conv_shape[1:4]``. Then ``num_layer`` stages of residual block,
    stride-2 transposed convolution and dropout follow, halving the channel
    count at every stage. A final residual block and a 3x3 convolution with a
    sigmoid activation produce the output image.

    Intermediate tensor shapes and the model summary are printed.

    Args:
        latent_dim: size of the latent vector.
        first_conv_shape: 4-element shape ``(batch, height, width, channels)``
            of the first feature map; index 0 is ignored.
        num_layer: number of upsampling stages.
        out_channels: number of channels of the reconstructed image
            (defaults to 3, the original hard-coded value).

    Returns:
        A Keras ``Model`` named ``'decoder'`` mapping a latent vector to the
        sigmoid-activated reconstruction.

    Raises:
        ValueError: if halving the channel count ``num_layer`` times would
            leave a layer with no filters.
    """
    height, width, channels = first_conv_shape[1:4]

    # Fail early with a clear message instead of building a zero-filter layer.
    if num_layer > 0 and channels // (2 ** num_layer) < 1:
        raise ValueError(
            "first_conv_shape has %d channels, which cannot be halved %d times"
            % (channels, num_layer)
        )

    decoder_input = Input(shape=(latent_dim,), name='latent_z')
    x = Dense(height * width * channels,
              kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
              bias_regularizer=tf.keras.regularizers.l2(l2_reg))(decoder_input)
    x = Reshape((height, width, channels))(x)

    print(K.int_shape(x))

    for _ in range(num_layer):
        x = res_block(x, channels)
        channels //= 2
        x = upsampling_conv_block(x, channels)
        print(K.int_shape(x))
        x = Dropout(rate=droprate)(x)

    x = res_block(x, channels)
    x = Conv2D(out_channels, 3, padding='same',
               kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
               bias_regularizer=tf.keras.regularizers.l2(l2_reg))(x)
    x = Activation('sigmoid', name='rec_image')(x)
    print(K.int_shape(x))
    model = Model(decoder_input, x, name='decoder')
    model.summary()
    return model

In [4]:
# Architecture hyperparameters.
image_shape = (128, 128, 3)  # looked up as a global inside create_encoder
latent_dim = 128
num_layer = 5
channel_multiplier = 8

# Build the encoder first: the decoder needs its last feature-map shape.
model_encoder, last_conv_shape = create_encoder(
    latent_dim=latent_dim,
    num_layer=num_layer,
    channel_multiplier=channel_multiplier,
)
model_decoder = create_decoder(
    latent_dim=latent_dim,
    first_conv_shape=last_conv_shape,
    num_layer=num_layer,
)

(None, 128, 128, 8)
(None, 64, 64, 16)
(None, 32, 32, 32)
(None, 16, 16, 64)
(None, 8, 8, 128)
(None, 4, 4, 256)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 8)  392         image[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 128, 128, 8)  0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 8)  584         leaky_re_lu[0