In [1]:
import sys
sys.path.append("..")
from VAE.vae import *

In [2]:
config_space = ConfigurationSpace(
                {'input_dropout': 0.1, 'intermediate_activation': "relu", 'intermediate_dimension': 10,
                'intermediate_layers': 2, 'latent_dimension': 1, 'learning_rate': 0.001,
                'original_dim': 200, 'solver': 'nadam'}
            )
config = config_space.get_default_configuration()

In [16]:
@keras.saving.register_keras_serializable(package="FIA_VAE")
class Sampling(layers.Layer):
        """
        Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.
        """
        def call(self, inputs):
            z_mean, z_log_var = inputs
            z_mean_shape = ops.shape(z_mean)
            batch   = z_mean_shape[0]
            dim     = z_mean_shape[1]
            epsilon = keras.random.normal(shape=(batch,dim))
            return ops.multiply(ops.add(z_mean, ops.exp(0.5 * z_log_var)), epsilon)


@keras.saving.register_keras_serializable(package="FIA_VAE")
def kl_reconstruction_loss(y_true, y_pred, sigma, mu):
    """
    Loss function for Kullback-Leibler + Reconstruction loss

    Args:
        true: True values
        pred: Predicted values
    Returns:
        Loss = Kullback-Leibler + Reconstruction loss
    """
    reconstruction_loss = losses.mean_absolute_error(y_true, y_pred)
    kl_loss = -0.5 * ops.sum( 1.0 + sigma - ops.square(mu) - ops.exp(sigma) )
    loss = reconstruction_loss + kl_loss
    
    return {"reconstruction_loss": reconstruction_loss, "kl_loss": kl_loss, "loss": loss}

@keras.saving.register_keras_serializable(package="FIA_VAE")
class FIA_VAE(Model):
    """
    A variational autoencoder for flow injection analysis
    """
    def __init__(self, config:Union[Configuration, dict]):
        super().__init__()
        self.config             = config
        intermediate_dims       = [i for i in range(config["intermediate_layers"]) 
                                    if config["intermediate_dimension"] // 2**i > config["latent_dimension"]]
        activation_function     = get_activation_function( config["intermediate_activation"] )

        # Encoder (with sucessive halfing of intermediate dimension)
        self.dropout            = Dropout( config["input_dropout"] , name="dropout")        
        self.intermediate_enc   = Sequential ( [ Input(shape=(config["original_dim"],), name='encoder_input') ] +
                                               [ Dense( config["intermediate_dimension"] // 2**i,
                                                        activation=activation_function ) 
                                                for i in intermediate_dims] , name="encoder_intermediate")

        self.mu_encoder         = Dense( config["latent_dimension"], name='latent_mu' )
        self.sigma_encoder      = Dense( config["latent_dimension"], name='latent_sigma' )
        self.z_encoder          = Sampling(name="latent_reparametrization") 

        # Decoder
        self.decoder            = Sequential( [ Input(shape=(config["latent_dimension"], ), name='decoder_input') ] +
                                              [ Dense( config["intermediate_dimension"] // 2**i,
                                                       activation=activation_function )
                                               for i in reversed(intermediate_dims) ] +
                                              [ Dense(config["original_dim"], activation="relu") ] , name="Decoder")

        # Loss trackers
        self.reconstruction_loss    = metrics.Mean(name="reconstruction_loss")
        self.kl_loss                = metrics.Mean(name="kl_loss")
        self.loss_tracker           = metrics.Mean(name="loss")

        # Define optimizer
        self.optimizer = get_solver( config["solver"] )( config["learning_rate"] )

        # Compile VAE
        self.compile(optimizer=self.optimizer, loss=kl_reconstruction_loss)

    @property
    def metrics(self):
        return [self.loss_tracker, self.reconstruction_loss, self.kl_loss]
    
    def get_config(self):
        return {"config": dict(self.config)}

    def call(self, data, training=False):
        x = self.dropout(data, training=training)
        return self.decode(self.encode(x))

    def encode(self, data):
        x = self.intermediate_enc(data)
        self.mu = self.mu_encoder(x)
        self.sigma = self.sigma_encoder(x)
        self.z = self.z_encoder( [self.mu, self.sigma] )
        return self.z
    
    def encode_mu(self, data):
        x = self.intermediate_enc(data)
        return self.mu(x)
    
    def decode(self, x):
        return self.decoder(x)
    
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = kl_reconstruction_loss(y, y_pred, self.sigma, self.mu)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.reconstruction_loss.update_state( loss["reconstruction_loss"] )
        self.kl_loss.update_state( loss["kl_loss"] )
        self.loss_tracker.update_state( loss["loss"] )
        return loss
    
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        loss = kl_reconstruction_loss(y, y_pred, self.sigma, self.mu)
        self.reconstruction_loss.update_state( loss["reconstruction_loss"] )
        self.kl_loss.update_state( loss["kl_loss"] )
        self.loss_tracker.update_state( loss["loss"] )
        return loss


In [17]:
model = FIA_VAE(config)

In [18]:
model.summary()

## Saving

In [19]:
keras.saving.save_model(model, "../../runs/VAE/training/test.keras")

  saving_lib.save_model(model, filepath)


In [20]:
model.save_weights("../../runs/VAE/training/test.weights.h5")

## Loading

In [21]:
model2 = keras.saving.load_model("../../runs/VAE/training/test.keras")

In [22]:
model2.load_weights("../../runs/VAE/training/test.weights.h5")

## Testing

In [23]:
X = np.random.normal(0.0, 1.0, size=(5, 200))

In [24]:
model2.fit(X, X, epochs=10)

Epoch 1/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - kl_loss: 8.5074 - loss: 9.3195 - reconstruction_loss: 0.8122
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - kl_loss: 7.5106 - loss: 8.3222 - reconstruction_loss: 0.8116
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - kl_loss: 9.0302 - loss: 9.8400 - reconstruction_loss: 0.8098
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - kl_loss: 5.4824 - loss: 6.2958 - reconstruction_loss: 0.8134
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - kl_loss: 5.0127 - loss: 5.8234 - reconstruction_loss: 0.8108
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - kl_loss: 4.4637 - loss: 5.2724 - reconstruction_loss: 0.8087
Epoch 7/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - kl_loss: 3.3990 - loss: 4.2082 - reconstr

<keras.src.callbacks.history.History at 0x7f724040d750>

In [26]:
loss, kl_loss, recon_loss = model2.evaluate(X, X)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - kl_loss: 1.8405 - loss: 2.6484 - reconstruction_loss: 0.8078


In [14]:
model2.predict(X)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step


array([[0.00000000e+00, 3.55275522e-04, 0.00000000e+00, 0.00000000e+00,
        2.10498832e-03, 0.00000000e+00, 2.08790787e-03, 0.00000000e+00,
        1.23358611e-03, 0.00000000e+00, 1.41797820e-03, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 2.00977339e-03, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.97784999e-04,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.31553533e-04,
        2.06541247e-03, 0.00000000e+00, 1.94932602e-03, 0.00000000e+00,
        1.30252074e-03, 1.66016701e-03, 0.00000000e+00, 0.00000000e+00,
        1.36951555e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.93478819e-03, 0.00000000e+00, 2.02537701e-03, 2.06639245e-03,
        5.01542818e-04, 1.57206971e-03, 0.00000000e+00, 1.85412681e-03,
        2.47976277e-06, 0.00000000e+00, 1.42496079e-03, 2.088987