In [1]:

import time
import numpy as np
import scipy as sp

from typing import Any, List, Tuple, Dict


import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
from tensorflow.keras import layers, Sequential


In [2]:

class Sampling(layers.Layer):
    """Reparameterization-trick sampling layer.

    Draws ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with
    ``epsilon ~ N(0, I)``. Because the randomness lives in ``epsilon``, the
    sample stays differentiable with respect to ``z_mean`` and ``z_log_var``.

    Call args:
        inputs: A pair ``(z_mean, z_log_var)`` of tensors with the same shape
            and dtype.

    Returns:
        A tensor of samples from the latent space, with the same shape and
        dtype as ``z_mean``.
    """


# ==============================================================================
# ==============================================================================

    def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
        z_mean, z_log_var = inputs
        # Noise follows z_mean's full dynamic shape and dtype, so inputs of any
        # rank (not only (batch, dim)) and non-float32 dtypes are supported.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    

In [3]:

class Encoder(layers.Layer):
    """Encoder model.

    The encoder holds the network that projects the input into the latent
    space (``z_mean``, ``z_log_var`` and a sampled ``z``). It also owns the
    heads of the auxiliary regression model (``reg_mean`` / ``reg_log_var``),
    which share the same ``projection`` trunk but are not used by ``call``.

    Args:
        original_dim: Width of the input vectors.
        latent_dim: Size of the latent space.
        name: Keras layer name.
        **kwargs: Forwarded to ``layers.Layer``.
    """

# ==============================================================================
# ==============================================================================

    def __init__(self, original_dim, latent_dim=2, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)

        self.original_dim  = original_dim
        self.latent_dim    = latent_dim
        self.sampling      = Sampling()
        # The posterior heads must be linear. A ReLU would clamp the mean to be
        # >= 0 and the log-variance to be >= 0 (i.e. variance >= 1), which
        # prevents the approximate posterior from ever tightening.
        self.dense_mean    = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        # Regression heads; read by MLP_VAE.regression, not by call().
        self.reg_mean      = layers.Dense(1)
        self.reg_log_var   = layers.Dense(1)
        # Shared trunk: original_dim -> 512 -> 256.
        self.projection    = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(512, activation='relu', input_shape=(self.original_dim,)),
                tf.keras.layers.Dense(256, activation='relu' ),
            ]
        )


# ==============================================================================
# ==============================================================================

    def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Encode ``inputs`` into the latent space.

        Returns:
            ``(z_mean, z_log_var, z)``, where ``z`` is drawn with the
            reparameterization trick from ``N(z_mean, exp(z_log_var))``.
        """
        x         = self.projection(inputs)
        z_mean    = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z         = self.sampling((z_mean, z_log_var))

        return z_mean, z_log_var, z
    

In [4]:

class Decoder(layers.Layer):
    """Decoder half of the VAE.

    Maps a latent code ``z`` back to the (FP) input space. The code passes
    through two ReLU hidden layers (256 then 512 units) and then a sigmoid
    output layer of width ``original_dim``, so every output component lies
    in (0, 1).

    Args:
        original_dim: Width of the reconstructed output vector.
        latent_dim: Recorded on the layer only; the hidden ``Dense`` layers
            infer their input width from ``z`` on first call.
        name: Keras layer name.
        **kwargs: Forwarded to ``layers.Layer``.
    """

# ==============================================================================
# ==============================================================================

    def __init__(self, original_dim, latent_dim=2, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)

        self.original_dim = original_dim
        self.latent_dim = latent_dim

        hidden_widths = (256, 512)
        self.amplify = tf.keras.Sequential(
            [layers.Dense(width, activation="relu") for width in hidden_widths]
        )
        # Sigmoid head: one value in (0, 1) per reconstructed feature.
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

# ==============================================================================
# ==============================================================================

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Decode latent codes ``inputs`` into reconstructions in (0, 1)."""
        hidden = self.amplify(inputs)
        reconstruction = self.dense_output(hidden)
        return reconstruction
    

In [6]:
class MLP_VAE(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    Besides the VAE objective (reconstruction + KL), the model trains an
    auxiliary regression head that shares the encoder's ``projection`` trunk.

    Args:
        original_dim: Width of the input / reconstructed vectors.
        latent_dim: Size of the latent space.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

# ==============================================================================
# ==============================================================================

    def __init__(
        self,
        original_dim,
        latent_dim=2,
        name="vae",
        **kwargs
        ):
        super(MLP_VAE, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.latent_dim   = latent_dim
        self.encoding = Encoder(original_dim, latent_dim=latent_dim)
        # Forward latent_dim so the decoder's recorded latent size matches the encoder's.
        self.decoding = Decoder(original_dim, latent_dim=latent_dim)

        # Loss trackers
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.regression_loss_tracker = keras.metrics.MeanSquaredError(name="regression_loss")


# ==============================================================================
# ==============================================================================

    @property
    def metrics(self):
        """Trackers that Keras resets at the start of every epoch."""
        return [self.total_loss_tracker,
                self.reconstruction_loss_tracker,
                self.kl_loss_tracker,
                self.regression_loss_tracker,
                ]


# ==============================================================================
# ==============================================================================

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Encode ``inputs``, sample ``z`` and return the decoder's reconstruction."""
        _, _, z = self.encoding(inputs)
        return self.decoding(z)


# ==============================================================================
# ==============================================================================

    def regression(self, inputs: tf.Tensor) -> tf.Tensor:
        """Predict the regression target from ``inputs``.

        The shared ``projection`` trunk feeds the ``reg_mean`` / ``reg_log_var``
        heads. The returned value is a reparameterized *sample* from
        ``N(reg_mean, exp(reg_log_var))``, so repeated calls are stochastic.

        Returns:
            A tensor of shape ``(batch, 1)``.
        """
        x           = self.encoding.projection(inputs)
        reg_mean    = self.encoding.reg_mean(x)
        reg_log_var = self.encoding.reg_log_var(x)
        return self.encoding.sampling((reg_mean, reg_log_var))


# ==============================================================================
# ==============================================================================

    # No @tf.function here: Keras already compiles train_step inside fit(), and
    # an explicit decorator would force graph mode even with run_eagerly=True.
    def train_step(self, data) -> Dict[str, tf.Tensor]:
        """Custom train step function.

        Computes the reconstruction, KL and regression losses, applies one
        optimizer update on their sum, and updates the loss trackers.

        Args:
            data: An ``(x_train, y_train)`` pair as yielded by ``fit``. It is
                assumed that ``y_train`` holds one regression target per sample,
                with shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            A dictionary with the running values of all trackers, e.g.
                {"VAE Loss": 10, "Reconstruction Loss": 1,
                 "KL Loss": 5, "Regression Loss": 4}
        """
        x_train, y_train = data

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoding(x_train)
            y_pred = self.regression(x_train)
            # Align the target with the (batch, 1) prediction. A (batch,)
            # target would otherwise broadcast y_pred - y_true to (batch, batch).
            y_true = tf.reshape(tf.cast(y_train, y_pred.dtype), tf.shape(y_pred))

            # Reconstruction loss. The decoder ends in a sigmoid, so it emits
            # probabilities, not logits.
            # NOTE(review): binary_crossentropy averages over features, so this
            # term is ~1/original_dim the scale of the KL term, which is summed
            # over latent dims. Consider summing over features if the posterior
            # collapses.
            reconstructed       = self.decoding(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(x_train, reconstructed, from_logits=False)
            )

            # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
            # summed over latent dims and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            # Regression loss
            regression_loss = tf.reduce_mean(keras.losses.MSE(y_true, y_pred))

            # Total loss
            total_loss = reconstruction_loss + kl_loss + regression_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update the loss trackers.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.regression_loss_tracker.update_state(y_true, y_pred)

        return {
            "VAE Loss":            self.total_loss_tracker.result(),
            "Reconstruction Loss": self.reconstruction_loss_tracker.result(),
            "KL Loss":             self.kl_loss_tracker.result(),
            "Regression Loss":     self.regression_loss_tracker.result(),
            }
    