Basis: https://colab.research.google.com/github/jrzaurin/infinitoml/blob/master/_notebooks/2020-05-15-mult-vae.ipynb

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as kl
from functools import partial

In [None]:
class MultVAE(tf.keras.Model):

    '''
    Simple MultVAE

    An encoder maps the input to a hidden representation, two dense heads
    turn that into a latent mean (`mu`) and log-variance (`lv`), a latent
    sample is drawn with the reparameterization trick, and a decoder maps
    the sample back to the input's last-axis size.
    '''

    def __init__(self, latent_size, input_shape, encoder=None, decoder=None):
        '''
        latent_size : width of the latent mean / log-variance heads
        input_shape : shape of the input; only its last entry is used
                      (as the decoder output width)
        encoder     : optional encoder; defaults to get_default_encoder()
        decoder     : optional decoder; defaults to get_default_decoder()
        '''
        super().__init__()
        self.latent_size = latent_size
        self.output_size = input_shape[-1]

        self.encoder = encoder if encoder is not None else MultVAE.get_default_encoder()
        self.decoder = decoder if decoder is not None else MultVAE.get_default_decoder(self.output_size)

        self.mu_layer = kl.Dense(self.latent_size, name='mu_layer')
        self.lv_layer = kl.Dense(self.latent_size, name='lv_layer')


    def call(self, x):
        '''
        Encode x, sample a latent variate and decode it.

        Returns a tuple (dz, mu, lv): the decoding reshaped to the shape
        of x, the latent mean and the latent log-variance.
        '''
        ## Encode x into a learned mean/logvar analogue
        ex = self.encoder(x)
        mu = self.mu_layer(ex)
        lv = self.lv_layer(ex)

        ## Reparameterize step to get latent z variate:
        ## z = mu + sigma * eps, with sigma = exp(0.5 * logvar), eps ~ N(0, 1).
        ## The noise is scaled by the std-dev; it must not be exponentiated.
        eps = tf.random.normal(tf.shape(mu))
        zu = mu + eps * tf.exp(0.5 * lv)

        ## Return decoding of generated variate in shape of input.
        ## tf.shape(x) is the dynamic shape, so this also works when the
        ## static batch dimension of x is unknown.
        dz = tf.reshape(self.decoder(zu), tf.shape(x))
        return dz, mu, lv


    @staticmethod
    def get_default_encoder():
        '''
        Default encoder: three tanh Dense layers of width 256, 128 and 32.

        NOTE: Authors implement dropout (i.e. 50%, training-only)
        and col-wise l2 normalization (preprocessing) at beginning of encoder.
        Consider adding regularization (author uses 0.01 per-layer L2)
        '''
        return tf.keras.Sequential(
            [kl.Dense(2**n, activation='tanh') for n in (8, 7, 5)],
            name='encoder')


    @staticmethod
    def get_default_decoder(output_size):
        '''
        Default decoder: three tanh Dense layers of width 32, 128 and 256,
        followed by a linear Dense layer of width output_size.
        '''
        return tf.keras.Sequential(
            [kl.Dense(2**n, activation='tanh') for n in (5, 7, 8)] +
            [kl.Dense(output_size)],
            name='decoder')
        
################################################################################

class MultELBOLoss(tf.keras.losses.Loss):

    '''
    Negative multinomial log-likelihood plus a beta-weighted KL term.
    '''

    def __init__(self, beta=0.5):
        super().__init__()
        self.beta = beta

    def call(self, y_true, y_pred, mu, lv):
        '''
        y_true  : ground truth
        y_pred  : model prediction
        mu      : distribution mean
        lv      : distribution log-variance

        The weight applied to the KL term is self.beta
        (annealing parameter for the KL-RC ratio).
        '''
        ## Multinomial log-likelihood, summed over the last axis
        log_probs = tf.nn.log_softmax(y_pred)
        row_ll = tf.reduce_sum(log_probs * y_true, axis=-1)

        ## KL-divergence terms, summed over the last axis
        row_kl = tf.reduce_sum(1 + lv - tf.square(mu) - tf.exp(lv), axis=-1)

        ## Negate and average both parts
        recon = -tf.reduce_mean(row_ll)
        kl_div = -tf.reduce_mean(row_kl) / 2
        return recon + self.beta * kl_div

################################################################################
################################################################################

## Instantiate the default MultVAE, compile it, build it and print a summary
input_shape = (20000, 2000)
vae_model = MultVAE(latent_size=20, input_shape=input_shape)

## NOTE(review): MultELBOLoss.call takes mu and lv in addition to
## y_true / y_pred -- confirm how they would be supplied during fit().
vae_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=MultELBOLoss(beta=0.5),
    metrics=[tf.keras.metrics.MSE],
)

vae_model.build(input_shape)
vae_model.summary()

Model: "mult_vae_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Sequential)        (20000, 32)               549280    
                                                                 
 decoder (Sequential)        (20000, 2000)             551920    
                                                                 
 mu_layer (Dense)            multiple                  660       
                                                                 
 lv_layer (Dense)            multiple                  660       
                                                                 
Total params: 1,102,520
Trainable params: 1,102,520
Non-trainable params: 0
_________________________________________________________________


In [None]:
## Inspect the first two encoder layers; as bare expressions, only the
## last one is echoed by the notebook.
vae_model.encoder.layers[0]
vae_model.encoder.layers[1]

<keras.layers.core.dense.Dense at 0x7f194b33aa50>

In [None]:
## Re-create the default MultVAE from scratch, then compile, build and summarize
input_shape = (20000, 2000)
vae_model = MultVAE(latent_size=20, input_shape=input_shape)

## NOTE(review): MultELBOLoss.call takes mu and lv in addition to
## y_true / y_pred -- confirm how they would be supplied during fit().
vae_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=MultELBOLoss(beta=0.5),
    metrics=[tf.keras.metrics.MSE],
)

vae_model.build(input_shape)
vae_model.summary()

In [None]:
## Wrap one encoder layer (index 1) in a functional model.
## kl.Input's shape excludes the batch axis, so the input is (None, 20000, 2000).
inps = kl.Input((20000, 2000))

encoder_slice = tf.keras.Model(
    inputs=[inps],
    outputs=vae_model.encoder.layers[1](inps),
    name='encoder_slice',
)

encoder_slice.summary()

Model: "encoder_slice"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 20000, 2000)]     0         
                                                                 
 dense_15 (Dense)            multiple                  32896     
                                                                 
Total params: 32,896
Trainable params: 32,896
Non-trainable params: 0
_________________________________________________________________


In [None]:
## Splice encoder and decoder layers into a single stack:
## every encoder layer except the last, then decoder layers from index 2 on.
inps = kl.Input((20000, 2000))

output_model = tf.keras.Sequential(
    vae_model.encoder.layers[:-1] + vae_model.decoder.layers[2:]
)

## Wrap the spliced stack in a functional model
encoder_slice = tf.keras.Model(
    inputs=[inps],
    outputs=output_model(inps),
    name='encoder_slice',
)

## NOTE(review): MultELBOLoss.call also expects mu and lv, which this
## single-output model does not produce -- confirm before training.
encoder_slice.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=MultELBOLoss(beta=0.5),
    metrics=[tf.keras.metrics.MSE],
)

encoder_slice.summary()

Model: "encoder_slice"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_12 (InputLayer)       [(None, 20000, 2000)]     0         
                                                                 
 sequential_2 (Sequential)   (None, 20000, 2000)       1092176   
                                                                 
Total params: 1,092,176
Trainable params: 1,092,176
Non-trainable params: 0
_________________________________________________________________


In [None]:
## List the individual layers of the spliced encoder/decoder stack
output_model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_14 (Dense)            multiple                  512256    
                                                                 
 dense_15 (Dense)            multiple                  32896     
                                                                 
 dense_19 (Dense)            multiple                  33024     
                                                                 
 dense_20 (Dense)            multiple                  514000    
                                                                 
Total params: 1,092,176
Trainable params: 1,092,176
Non-trainable params: 0
_________________________________________________________________


In [None]:
## Toy encoder: leaky-relu Dense layers narrowing 8 -> 4 -> 2
encoder = tf.keras.Sequential(
    [kl.Dense(units=width, activation='leaky_relu') for width in (8, 4, 2)],
    name='model_encoder',
)

## Toy decoder: leaky-relu Dense layers widening 4 -> 8, then a sigmoid layer of 20
decoder = tf.keras.Sequential(
    [kl.Dense(units=width, activation='leaky_relu') for width in (4, 8)]
    + [kl.Dense(units=20, activation='sigmoid')],
    name='model_decoder',
)

## Recall the syntax for model chaining: call one model on the output of the other
inps = kl.Input((10, 20))
autoencoder = tf.keras.Model(
    inputs=[inps],
    outputs=decoder(encoder(inps)),
    name='autoencoder',
)

autoencoder.summary()

Model: "autoencoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_9 (InputLayer)        [(None, 10, 20)]          0         
                                                                 
 model_encoder (Sequential)  (None, 10, 2)             214       
                                                                 
 model_decoder (Sequential)  (None, 10, 20)            232       
                                                                 
Total params: 446
Trainable params: 446
Non-trainable params: 0
_________________________________________________________________
