In [7]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import preprocessing

# Define sampling layer

In [8]:
class Sampling(layers.Layer):
    """Reparameterisation layer: turns (z_mean, z_log_var) into a sampled z.

    The sample is z_mean + exp(0.5 * z_log_var) * noise, where the noise tensor
    is drawn from a normal distribution with one value per (row, column) of
    z_mean.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        n_rows, n_cols = mean_shape[0], mean_shape[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        # exp(0.5 * log-variance) is the standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise

# Define encoder and decoder

In [9]:
# Model dimensions shared by the encoder and the decoder.
latent_dim = 3
inputlen = 9024
units = 1180

# Encoder: two ReLU hidden layers (units, then units // 2) feeding the
# z_mean / z_log_var heads and the sampling layer.
encoder_inputs = keras.Input(shape=(inputlen,))
x = encoder_inputs
for width in (units, units // 2):
    x = layers.Dense(units=width, activation='relu')(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(
    inputs=encoder_inputs,
    outputs=[z_mean, z_log_var, z],
    name="encoder",
)
encoder.summary()

# Decoder: mirrors the encoder (units // 2, then units) and ends in a linear
# layer with one output per input feature.
latent_inputs = keras.Input(shape=(latent_dim,))
x = latent_inputs
for width in (units // 2, units):
    x = layers.Dense(units=width, activation='relu')(x)
decoder_outputs = layers.Dense(inputlen, activation='linear')(x)
decoder = keras.Model(
    inputs=latent_inputs,
    outputs=decoder_outputs,
    name="decoder",
)
decoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 9024)]       0           []                               
                                                                                                  
 dense_5 (Dense)                (None, 1180)         10649500    ['input_3[0][0]']                
                                                                                                  
 dense_6 (Dense)                (None, 590)          696790      ['dense_5[0][0]']                
                                                                                                  
 z_mean (Dense)                 (None, 3)            1773        ['dense_6[0][0]']                
                                                                                            

# Define VAE

In [10]:
class VAE(keras.Model):
    """Variational autoencoder wrapping an encoder and a decoder model.

    The encoder must return ``(z_mean, z_log_var, z)`` and the decoder must map
    ``z`` back to a tensor comparable with the encoder input. Training uses a
    custom ``train_step`` that minimises reconstruction MSE plus the KL term.

    Args:
        encoder: model called as ``encoder(data)`` -> ``(z_mean, z_log_var, z)``.
        decoder: model called as ``decoder(z)`` -> reconstruction.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of each loss component, reported by `fit`.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers exposed to Keras so they are reset between epochs."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch and return the running losses.

        Args:
            data: a batch of inputs, or a tuple whose first element is the
                batch of inputs (any further elements are ignored).

        Returns:
            dict with keys ``"loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"`` holding the current tracker values.
        """
        # When `fit` is given targets too, the batch arrives as a tuple; only
        # the inputs are needed because the model reconstructs its own input.
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # mean_squared_error averages over the last axis only and returns
            # one value per sample; reduce over the batch as well so the total
            # loss is a scalar and the gradient scale does not depend on the
            # batch size.
            # NOTE(review): this is a per-feature mean while the KL term below
            # is summed over latent dims - confirm the relative weighting is
            # intended.
            reconstruction_loss = tf.reduce_mean(
                keras.losses.mean_squared_error(data, reconstruction)
            )

            # KL term: summed over latent dimensions, averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Load dataset

In [11]:
# NOTE(review): absolute, machine-specific path - consider deriving it from a
# configurable data directory so the notebook runs on other machines.
DATA_PATH = "/home/vmh/vmhdocs/Research/Inria/Anl/MetaGenAutoencoder/Data/KO_metaG.norm.txt"
df = pd.read_csv(DATA_PATH, sep="\t")
# Keep everything except the two leading columns (presumably identifier /
# annotation columns - confirm against the file). A plain positive slice is
# also correct for frames with two or fewer columns, where the previous
# `2 - len(df.columns)` start index kept every column of a two-column frame.
df = df.iloc[:, 2:]



# Load VAE

In [12]:
# Wrap the encoder/decoder pair in a VAE and restore its weights from the
# checkpoint; the load status object is left as the cell's displayed value.
checkpoint_path = './vae_checkpoint/my_checkpoint'
vae = VAE(encoder=encoder, decoder=decoder)
vae.load_weights(checkpoint_path)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f82dfe47970>

# Generate new data

In [13]:
# Decode an n x n x n grid of latent points spanning [-scale, scale] per axis.
n = 5
scale = 1.0
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)
grid_z = np.linspace(-scale, scale, n)

# indexing="ij" plus a C-order reshape enumerates the points with x varying
# slowest and z fastest, i.e. the same order as nested x / y / z loops.
z_samples = np.stack(
    np.meshgrid(grid_x, grid_y, grid_z, indexing="ij"), axis=-1
).reshape(-1, 3)

# One batched call instead of n**3 single-row predict() calls.
decoded_all = vae.decoder.predict(z_samples)

for z_row, decoded_row in zip(z_samples, decoded_all):
    # Restore the leading batch axis so each item has the (1, features) layout
    # that a single-sample predict() call returns.
    z_sample = z_row[np.newaxis, :]
    decoded = decoded_row[np.newaxis, :]
    print(decoded)
    print(decoded.shape)

[[ 0.00670468  0.74539596  0.00616014 ... -0.00155882  0.00343195
   0.00311395]]
(1, 9024)
[[6.5008122e-03 7.4407846e-01 5.3775464e-03 ... 6.7161629e-04
  3.1527621e-03 2.7447776e-03]]
(1, 9024)
[[7.0824204e-03 7.4509716e-01 6.3952338e-03 ... 6.7483494e-04
  2.2794446e-03 3.5178484e-03]]
(1, 9024)
[[ 7.1559548e-03  7.3966283e-01  7.4866391e-03 ... -2.3768225e-04
   2.1672761e-03  3.4320527e-03]]
(1, 9024)
[[7.1419328e-03 7.3231757e-01 7.2567277e-03 ... 3.6346901e-04
  2.3179883e-03 4.1658129e-03]]
(1, 9024)
[[ 0.00595354  0.7462716   0.00691528 ... -0.00081549  0.00269981
   0.00406226]]
(1, 9024)
[[0.00603282 0.7464688  0.00564288 ... 0.00088725 0.00215176 0.00364145]]
(1, 9024)
[[7.2771776e-03 7.4557346e-01 6.5052640e-03 ... 6.2601466e-04
  2.3852088e-03 4.2062690e-03]]
(1, 9024)
[[ 8.0481302e-03  7.4238944e-01  6.9480841e-03 ... -3.0865258e-04
   2.1969900e-03  4.0055779e-03]]
(1, 9024)
[[ 7.5908126e-03  7.3600543e-01  7.2135762e-03 ... -9.1124384e-05
   3.3288798e-03  3.5485337e-0