In [None]:
def _encoder_conv_group(x, filters):
    """Apply one encoder stage: 5x5 conv (`filters`) then two 3x3 convs (`2 * filters`), each followed by LeakyReLU.

    All convolutions are stride 1 with Keras' default 'valid' padding, so each stage
    trims 4 + 2 + 2 = 8 pixels from every spatial dimension.
    """
    x = Conv2D(filters, kernel_size=(5, 5), strides=1)(x)
    x = LeakyReLU()(x)
    x = Conv2D(2 * filters, kernel_size=(3, 3), strides=1)(x)
    x = LeakyReLU()(x)
    x = Conv2D(2 * filters, kernel_size=(3, 3), strides=1)(x)
    x = LeakyReLU()(x)
    return x


def build_vae_encoder(input_dim, output_dim):
    """Build the convolutional VAE encoder that also exposes three skip feature maps.

    Side effect: calls ``K.clear_session()`` first, resetting Keras global state
    (e.g. auto-generated layer-name counters).

    Args:
        input_dim: shape of a single input image, passed to ``Input(shape=...)``.
        output_dim: latent dimensionality, i.e. the size of ``mu``, ``log_var`` and ``z``.

    Returns:
        A 7-tuple:
        ``(encoder_input, [z, conv3, conv2, conv1], [w1, w2, w3], mean_mu, log_var,
        shape_before_flattening, encoder_model)``, where ``w1..w3`` are the static
        sizes along axis 2 of conv1..conv3 (the width, assuming channels-last data),
        ``shape_before_flattening`` is the static shape of conv3 without the batch
        axis, and ``encoder_model`` outputs ``[mean_mu, log_var, z, conv3, conv2, conv1]``.
    """
    K.clear_session()

    encoder_input = Input(shape=input_dim, name='encoder_input')

    # Three stages with 8/16, 16/32, 32/64 filters; each output is also kept as a skip map.
    conv1 = _encoder_conv_group(encoder_input, 8)
    conv2 = _encoder_conv_group(conv1, 16)
    conv3 = _encoder_conv_group(conv2, 32)

    # The decoder reshapes its dense projection back to this shape.
    shape_before_flattening = K.int_shape(conv3)[1:]
    x = Flatten()(conv3)

    # Latent size follows output_dim (it was previously hard-coded to 64, ignoring the argument).
    mean_mu = Dense(output_dim, name='mu')(x)
    log_var = Dense(output_dim, name='log_var')(x)

    # Sampling is defined elsewhere; presumably the reparameterisation trick — TODO confirm.
    z = Sampling()([mean_mu, log_var])

    # Static ints instead of tf.shape(...) on symbolic tensors: tf.shape would create ops that
    # are not part of the returned Model (and is rejected by Keras 3).
    cshape1 = K.int_shape(conv1)[2]
    cshape2 = K.int_shape(conv2)[2]
    cshape3 = K.int_shape(conv3)[2]

    encoder = Model(inputs=encoder_input, outputs=[mean_mu, log_var, z, conv3, conv2, conv1])

    return (
        encoder_input,
        [z, conv3, conv2, conv1],
        [cshape1, cshape2, cshape3],
        mean_mu,
        log_var,
        shape_before_flattening,
        encoder,
    )

In [None]:
# Build the encoder and print its layer summary.
(
    vae_encoder_input,
    vae_encoder_output,
    cshape,
    mean_mu,
    log_var,
    vae_shape_before_flattening,
    encoder,
) = build_vae_encoder(input_dim=INPUT_DIM, output_dim=Z_DIM)

encoder.summary()

In [None]:
def _decoder_deconv(x, filters, kernel):
    """Stride-1 'valid' Conv2DTranspose followed by LeakyReLU; grows each spatial dim by ``kernel - 1``."""
    x = Conv2DTranspose(filters, kernel_size=(kernel, kernel), strides=1)(x)
    return LeakyReLU()(x)


def build_decoder(input_dim, shape_before_flattening):
    """Build the VAE decoder, which takes a latent vector plus three encoder skip maps.

    Args:
        input_dim: latent dimensionality (length of the ``decoder_input`` vector).
        shape_before_flattening: ``(H, W, C)`` shape of the encoder's deepest feature
            map. The dense projection of the latent is reshaped to it.

    Returns:
        ``([decoder_input, conv3, conv2, conv1], decoder_output, decoder_model)``.
        The output has 2 channels with ReLU activation.
    """
    height, width = shape_before_flattening[0], shape_before_flattening[1]

    decoder_input = Input(shape=(input_dim,), name='decoder_input')
    # Skip-map inputs. Their sizes are derived from shape_before_flattening rather than
    # hard-coded (formerly 200/208/216). This assumes the matching encoder trims 8 pixels per
    # conv stage, so conv2 is 8 px larger than conv3 and conv1 is 16 px larger.
    conv3 = Input(shape=(height, width, 64), name='conv3')
    conv2 = Input(shape=(height + 8, width + 8, 32), name='conv2')
    conv1 = Input(shape=(height + 16, width + 16, 16), name='conv1')

    x = Dense(np.prod(shape_before_flattening))(decoder_input)
    x = Reshape(shape_before_flattening)(x)

    # Each stage concatenates a skip map and then grows the spatial size by 2 + 2 + 4 = 8 px,
    # so its output lines up with the next (larger) skip map.
    x = Concatenate()([x, conv3])
    x = _decoder_deconv(x, 64, 3)
    x = _decoder_deconv(x, 64, 3)
    x = _decoder_deconv(x, 32, 5)

    x = Concatenate()([x, conv2])
    x = _decoder_deconv(x, 32, 3)
    x = _decoder_deconv(x, 16, 3)
    x = _decoder_deconv(x, 16, 5)

    x = Concatenate()([x, conv1])
    x = _decoder_deconv(x, 16, 3)
    x = _decoder_deconv(x, 8, 3)
    decoder_output = Conv2DTranspose(2, kernel_size=(5, 5), strides=1, activation='relu')(x)

    decoder_inputs = [decoder_input, conv3, conv2, conv1]
    return decoder_inputs, decoder_output, Model(inputs=decoder_inputs, outputs=decoder_output)

In [None]:
# Build the decoder; its dense projection is reshaped to the encoder's pre-flatten shape.
vae_decoder_input, vae_decoder_output, decoder = build_decoder(
    input_dim=Z_DIM,
    shape_before_flattening=vae_shape_before_flattening,
)
decoder.summary()

In [None]:
# Learning rate. It is not used in this cell — presumably passed to the optimizer when the
# model is compiled later (TODO confirm).
lr = 0.00005


def r_accuracy(img_original, img_reconstructed, pixel_max=1.0):
    """Return the PSNR (in dB) between two image tensors.

    Despite the name, this is peak signal-to-noise ratio, not an accuracy:
    ``20 * log10(pixel_max / sqrt(MSE))``. The MSE is averaged over every element of
    both tensors, giving one scalar for the whole batch rather than one per image.
    Identical inputs give MSE == 0 and therefore ``+inf``.

    Args:
        img_original: reference images.
        img_reconstructed: reconstructed images, same shape/dtype as ``img_original``.
        pixel_max: maximum possible pixel value (1.0 for images scaled to [0, 1]).

    Returns:
        Scalar tensor in the dtype of the inputs.
    """
    mse = tf.reduce_mean((img_original - img_reconstructed) ** 2)
    # Match log(10)'s dtype to mse. A bare tf.math.log(10.0) is float32 and fails to divide a
    # float16 (mixed precision) or float64 tensor.
    ln10 = tf.math.log(tf.constant(10.0, dtype=mse.dtype))
    return 20 * tf.math.log(pixel_max / tf.math.sqrt(mse)) / ln10

In [None]:
class VAE(keras.Model):
    """Variational autoencoder combining an encoder and a skip-connected decoder.

    The encoder must return ``[z_mean, z_log_var, z, conv3, conv2, conv1]``. The decoder
    takes ``[z, conv3, conv2, conv1]`` and returns the reconstruction.

    Args:
        encoder: Keras model with the outputs described above.
        decoder: Keras model with the inputs described above.
        reconstruction_weight: multiplier on the reconstruction MSE in the total loss
            (default 10000, the previously hard-coded value).
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, reconstruction_weight=10000, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.reconstruction_weight = reconstruction_weight
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.r_accuracy_tracker = keras.metrics.Mean(name="r_accuracy")
        # PSNR function defined in an earlier cell.
        self.r_accuracy = r_accuracy

    def call(self, x):
        """Encode ``x`` and decode the sampled latent together with the skip maps."""
        z_mean, z_log_var, z, conv3, conv2, conv1 = self.encoder(x)
        reconstruction = self.decoder([z, conv3, conv2, conv1])
        return reconstruction

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.r_accuracy_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch ``(x, y)``, where ``y`` is the reconstruction target."""
        x, y = data
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z, conv3, conv2, conv1 = self.encoder(x)
            reconstruction = self.decoder([z, conv3, conv2, conv1])
            # Scalar MSE over the whole batch. It was previously left per-sample
            # (shape (batch,)), which made total_loss a vector; tape.gradient then
            # implicitly summed it, scaling the gradients by the batch size.
            reconstruction_loss = tf.reduce_mean(tf.math.square(y - reconstruction))
            # Closed-form KL divergence of N(z_mean, exp(z_log_var)) from N(0, I), averaged over the batch.
            kl_loss = -0.5 * tf.reduce_sum(
                1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var), axis=1
            )
            kl_loss = tf.reduce_mean(kl_loss)
            total_loss = self.reconstruction_weight * reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # PSNR is a metric only, so compute it outside the tape.
        r_accuracy = self.r_accuracy(y, reconstruction)

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.r_accuracy_tracker.update_state(r_accuracy)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "r_accuracy": self.r_accuracy_tracker.result(),
        }