In [None]:
class VAE(keras.Model):
    """Variational autoencoder whose decoder also consumes encoder skip features.

    The encoder must return the six-tuple
    ``(z_mean, z_log_var, z, conv3, conv2, conv1)`` and the decoder must accept
    the list ``[z, conv3, conv2, conv1]``.

    Args:
        encoder: Model mapping the inputs to the six-tuple described above.
        decoder: Model mapping ``[z, conv3, conv2, conv1]`` to a reconstruction.
        reconstruction_weight: Factor applied to the mean-squared reconstruction
            error before the KL term is added. Defaults to 10000.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, reconstruction_weight=10000, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.reconstruction_weight = reconstruction_weight
        # Running means reported in the training / evaluation logs.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.r_accuracy_tracker = keras.metrics.Mean(name="r_accuracy")
        # `r_accuracy` is a module-level callable defined outside this class;
        # it is invoked as r_accuracy(y_true, y_pred).
        self.r_accuracy = r_accuracy

    def call(self, x, training=None):
        """Return the reconstruction of `x`.

        Args:
            x: Input batch fed to the encoder.
            training: Optional mode flag forwarded to encoder and decoder;
                ``None`` lets Keras resolve it from the calling context.
        """
        z_mean, z_log_var, z, conv3, conv2, conv1 = self.encoder(
            x, training=training
        )
        return self.decoder([z, conv3, conv2, conv1], training=training)

    @property
    def metrics(self):
        """Trackers Keras resets at the start of every epoch / evaluation."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.r_accuracy_tracker,
        ]

    def _compute_losses(self, x, y, training):
        """Run the forward pass and return (reconstruction, total, recon, kl)."""
        z_mean, z_log_var, z, conv3, conv2, conv1 = self.encoder(
            x, training=training
        )
        reconstruction = self.decoder(
            [z, conv3, conv2, conv1], training=training
        )
        # Per-sample MSE over the three non-batch axes, then averaged over the
        # batch so the loss is a scalar. GradientTape sums a non-scalar target,
        # which would make the gradient magnitude grow with the batch size.
        per_sample_mse = tf.reduce_mean(
            tf.math.square(y - reconstruction), axis=[1, 2, 3]
        )
        reconstruction_loss = tf.reduce_mean(per_sample_mse)
        # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and
        # N(0, I): summed over the latent axis, averaged over the batch.
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var),
            axis=1,
        )
        kl_loss = tf.reduce_mean(kl_loss)
        # NOTE: an extra correlation penalty on z (corr_loss) is currently
        # disabled and not part of the objective.
        total_loss = self.reconstruction_weight * reconstruction_loss + kl_loss
        return reconstruction, total_loss, reconstruction_loss, kl_loss

    def _update_and_report(self, y, reconstruction, total_loss,
                           reconstruction_loss, kl_loss):
        """Update every tracker and return the current log dictionary."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.r_accuracy_tracker.update_state(
            self.r_accuracy(y, reconstruction)
        )
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "r_accuracy": self.r_accuracy_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimisation step on an ``(x, y)`` batch and return the logs."""
        x, y = data
        with tf.GradientTape() as tape:
            # training=True is passed explicitly: inside a custom train_step
            # the sub-models do not infer the mode, so mode-dependent layers
            # (dropout, batch normalisation, if present) would otherwise run
            # in inference mode.
            reconstruction, total_loss, reconstruction_loss, kl_loss = (
                self._compute_losses(x, y, training=True)
            )

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # The accuracy metric is evaluated outside the tape; it does not
        # contribute to the gradient and need not be recorded.
        return self._update_and_report(
            y, reconstruction, total_loss, reconstruction_loss, kl_loss
        )

    def test_step(self, data):
        """Evaluate one ``(x, y)`` batch without updating weights.

        Enables ``evaluate()`` and ``fit(..., validation_data=...)``, which
        cannot work through the default step because no loss is compiled.
        """
        x, y = data
        reconstruction, total_loss, reconstruction_loss, kl_loss = (
            self._compute_losses(x, y, training=False)
        )
        return self._update_and_report(
            y, reconstruction, total_loss, reconstruction_loss, kl_loss
        )
    
    

    
# Assemble the VAE from the previously built encoder / decoder and attach an
# Adam optimiser; the loss itself is computed inside VAE.train_step, so none
# is passed to compile().
model1 = VAE(encoder, decoder)
model1.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr),
)

# Train for 600 epochs with mini-batches of 16; `history` keeps the per-epoch
# logs returned by fit().
history = model1.fit(
    x=train_x,
    y=train_y,
    batch_size=16,
    epochs=600,
    verbose=1,
)