In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
# Run the whole notebook in TF1-style graph mode.
# NOTE(review): presumably required because ELBO_Layer.call and kl_loss read
# symbolic tensors (x, z_mean, z_log_var) that are created at module level,
# outside the layer that uses them -- confirm before re-enabling eager mode.
tf.compat.v1.disable_eager_execution()
# tf.config.experimental_run_functions_eagerly(True)

In [2]:
# --- Optimisation settings -------------------------------------------------
# NOTE(review): `epoch` and `batch_size` are not consumed by any code visible
# in this part of the notebook; presumably they feed a later fit() call.
epoch = 500
batch_size = 128
lr = 3e-4          # learning rate handed to the Adam optimizer
droprate = 0.2     # used for both Dropout layers and LSTM recurrent_dropout
kl_weight = 1      # multiplier applied to the KL term of the loss

# --- Network widths --------------------------------------------------------
intr_dim = 128     # LSTM units (the second decoder LSTM uses 2 * intr_dim)
latent_dim = 64    # size of z_mean / z_log_var / z

# --- Input / vocabulary shape ----------------------------------------------
max_length = 50    # time steps per (padded) sequence
emb_dim = 300      # embedding vector size
vocab_size = 24888 # embedding rows and decoder logit width

In [3]:
####################################################################################

class Sampling(tf.keras.layers.Layer):
    """Reparameterisation-trick layer.

    Given ``[z_mean, z_log_var]`` it returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon`` drawn from a
    standard normal distribution.

    The noise tensor takes its shape and dtype from ``z_mean`` itself, so the
    layer no longer depends on the module-level ``latent_dim`` constant and
    works for any latent size.
    """

    def __init__(self, **kwargs):
        # **kwargs lets callers pass standard Layer options such as `name`.
        super(Sampling, self).__init__(**kwargs)
        # Let an incoming Keras mask pass through this layer.
        self.supports_masking = True

    def call(self, inputs):
        """Sample z from the Gaussian parameterised by ``inputs``.

        Args:
            inputs: a pair ``(z_mean, z_log_var)`` of same-shaped tensors.

        Returns:
            A tensor with the same shape as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

class custom_lstm(tf.keras.layers.Layer):
    """Encoder block: bidirectional LSTM followed by dropout.

    The LSTM is built with ``return_sequences=False`` and
    ``merge_mode='concat'``, so the output holds the concatenated final
    forward/backward outputs, i.e. ``2 * intr_dim`` features per example.

    Args:
        intr_dim: number of units of each LSTM direction.
        droprate: rate used for both the LSTM ``recurrent_dropout`` and the
            trailing ``Dropout`` layer.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, intr_dim, droprate, **kwargs):
        # Initialise the base Layer first so that the sub-layers assigned
        # below are registered through Keras' attribute tracking.
        super(custom_lstm, self).__init__(**kwargs)
        self.bi_lstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(intr_dim, recurrent_dropout=droprate,
                                 return_sequences=False),
            merge_mode='concat')
        self.drop_layer = tf.keras.layers.Dropout(droprate)

    def call(self, inputs, mask=None, training=None):
        """Encode a sequence batch.

        Args:
            inputs: embedded sequence tensor fed to the bidirectional LSTM.
            mask: optional timestep mask; forwarded explicitly so the LSTM
                skips padded steps instead of relying on implicit
                propagation.
            training: Keras training flag; forwarded so both dropout variants
                follow the caller's mode.

        Returns:
            The dropped-out bidirectional LSTM output.
        """
        h = self.bi_lstm(inputs, mask=mask, training=training)
        h = self.drop_layer(h, training=training)
        return h

    def compute_mask(self, inputs, mask=None):
        # The timestep mask is deliberately passed through unchanged even
        # though the time axis has been collapsed: ELBO_Layer.call further
        # down the graph needs it to weight the reconstruction loss.
        return mask
    
# Encoder graph. NOTE: `x`, `z_mean` and `z_log_var` are module-level tensors
# that ELBO_Layer.call and kl_loss read directly as globals, so these names
# must not be renamed or rebound.
x = tf.keras.layers.Input(shape=(max_length,))
# mask_zero=True makes the embedding emit a timestep mask that is False where
# the input id is 0. The layer is frozen (trainable=False).
# NOTE(review): no weights are loaded in the visible code -- confirm that
# pretrained vectors are set elsewhere, otherwise the frozen embedding keeps
# its random initialisation.
embed_layer = tf.keras.layers.Embedding(vocab_size, emb_dim, input_length=max_length, trainable=False, mask_zero=True)
encoder_layer = custom_lstm(intr_dim, droprate)

h = embed_layer(x)
h = encoder_layer(h)
# Two linear heads parameterise the approximate posterior.
z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(h)
z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(h)
# Sampled latent code (reparameterisation trick).
z = Sampling()([z_mean, z_log_var])

####################################################################################

class custom_decoder(tf.keras.layers.Layer):
    """Decoder block mapping a latent vector to per-timestep vocabulary logits.

    Pipeline: ``RepeatVector(max_length)`` -> ``LSTM(intr_dim)`` -> dropout
    -> ``LSTM(2 * intr_dim)`` -> dropout -> time-distributed linear
    ``Dense(vocab_size)``.

    Args:
        vocab_size: width of the output logits.
        intr_dim: units of the first LSTM (the second uses twice as many).
        max_length: number of timesteps the latent vector is repeated for.
        droprate: rate for the ``Dropout`` layers and LSTM
            ``recurrent_dropout``.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, vocab_size, intr_dim, max_length, droprate, **kwargs):
        # Initialise the base Layer first so that the sub-layers assigned
        # below are registered through Keras' attribute tracking.
        super(custom_decoder, self).__init__(**kwargs)
        self.rpv = tf.keras.layers.RepeatVector(max_length)
        self.lstm_layer_1 = tf.keras.layers.LSTM(
            intr_dim, return_sequences=True, recurrent_dropout=droprate)
        self.droplayer_2 = tf.keras.layers.Dropout(droprate)
        self.lstm_layer_2 = tf.keras.layers.LSTM(
            intr_dim * 2, return_sequences=True, recurrent_dropout=droprate)
        self.droplayer_3 = tf.keras.layers.Dropout(droprate)
        # activation='linear' -> the layer outputs raw logits, not
        # probabilities.
        self.decoded_logits = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(vocab_size, activation='linear'))

    def call(self, inputs, training=None):
        """Decode latent vectors into a sequence of logits.

        Args:
            inputs: latent tensor; it is repeated ``max_length`` times along
                a new time axis before entering the LSTM stack.
            training: Keras training flag, forwarded to the LSTM and dropout
                sub-layers.

        Returns:
            The time-distributed logits tensor.
        """
        h = self.rpv(inputs)
        h = self.lstm_layer_1(h, training=training)
        h = self.droplayer_2(h, training=training)
        h = self.lstm_layer_2(h, training=training)
        h = self.droplayer_3(h, training=training)
        decoded = self.decoded_logits(h)
        return decoded

    def compute_mask(self, inputs, mask=None):
        # Hand the incoming mask on unchanged; ELBO_Layer.call consumes it as
        # per-timestep loss weights.
        return mask
    
# A single decoder instance is created here and applied to the sampled latent
# `z`; the same instance is applied again to a fresh latent Input when the
# generator model is built, so both paths use the same decoder weights.
decoder_layer = custom_decoder(vocab_size, intr_dim, max_length, droprate)
decoded_logits = decoder_layer(z)

####################################################################################

class ELBO_Layer(tf.keras.layers.Layer):
    """Attaches the VAE objective (reconstruction + weighted KL) to the model.

    The loss is registered with ``add_loss``; the tensor returned by ``call``
    is only a placeholder of ones shaped like the model input.

    NOTE: ``call`` reads the module-level tensors ``x``, ``z_mean`` and
    ``z_log_var`` and the constant ``kl_weight`` directly, so the layer is
    only usable inside the graph built in this notebook.
    """

    def __init__(self, **kwargs):
        super(ELBO_Layer, self).__init__(**kwargs)

    def call(self, inputs, mask=None):
        """Register the ELBO loss for a batch of decoder logits.

        Args:
            inputs: decoder logits passed as ``logits`` to
                ``tfa.seq2seq.sequence_loss``.
            mask: timestep mask used as per-token loss weights. Required.

        Returns:
            ``tf.ones_like(x)`` -- a dummy output; the real signal is the
            loss added via ``add_loss``.

        Raises:
            ValueError: if no mask reaches the layer.
        """
        if mask is None:
            # Previously this fell through to tf.cast(None, ...) and failed
            # with an unhelpful "None values not supported" error.
            raise ValueError(
                "ELBO_Layer requires a Keras mask on its input (expected to "
                "originate from the Embedding layer's mask_zero=True); "
                "got mask=None.")
        broadcast_float_mask = tf.cast(mask, "float32")
        labels = tf.cast(x, tf.int32)
        # With both averaging flags off, sequence_loss keeps one weighted
        # loss per (example, timestep); summing axis 1 gives a per-example
        # reconstruction loss.
        reconstruction_loss = tf.reduce_sum(
            tfa.seq2seq.sequence_loss(inputs, labels,
                                      weights=broadcast_float_mask,
                                      average_across_timesteps=False,
                                      average_across_batch=False),
            axis=1)

        # Per-example KL term computed from z_mean / z_log_var.
        kl_loss = - 0.5 * tf.reduce_sum(
            1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var),
            axis=1)
        total_loss = tf.reduce_mean(reconstruction_loss + kl_weight * kl_loss)
        self.add_loss(total_loss, inputs=[x, inputs])
        return tf.ones_like(x)

    def compute_mask(self, inputs, mask=None):
        # Pass the mask through unchanged.
        return mask
        
# Applying the ELBO layer registers the training loss via add_loss; the tensor
# it returns is a placeholder of ones (see ELBO_Layer.call), hence "fake".
elbo_layer = ELBO_Layer()
fake_decoded_prob = elbo_layer(decoded_logits)

####################################################################################

def zero_loss(y_true, y_pred):
    """Placeholder Keras loss that is zero everywhere.

    ``y_true`` is ignored. The actual objective is attached to the model by
    ``ELBO_Layer`` through ``add_loss``, so the compiled loss contributes
    nothing.
    """
    nothing = tf.zeros_like(y_pred)
    return nothing

def kl_loss(x, fake_decoded_prob):
    """Metric reporting the batch-mean weighted KL term.

    Both arguments are ignored; they exist only to satisfy the Keras metric
    signature. The value is computed from the module-level tensors ``z_mean``
    and ``z_log_var`` and scaled by ``kl_weight``.
    """
    summand = 1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var)
    per_example = - 0.5 * tf.reduce_sum(summand, axis=1)
    weighted = kl_weight * per_example
    return tf.reduce_mean(weighted)

# End-to-end training model: token ids in, dummy ELBO-layer output out. The
# real objective is the loss that ELBO_Layer registers with add_loss;
# `zero_loss` keeps the compiled loss at zero and `kl_loss` is tracked as a
# metric.
vae = tf.keras.models.Model(x, fake_decoded_prob, name='VAE')
# `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
opt = tf.keras.optimizers.Adam(learning_rate=lr)
vae.compile(optimizer=opt, loss=[zero_loss], metrics=[kl_loss])
vae.summary()

# Debug aid: show which Keras mask enters and leaves every layer of the VAE.
for idx, layer in enumerate(vae.layers):
    print(f'layer {idx}: {layer}', flush=True)
    print(f'has input mask: {layer.input_mask}', flush=True)
    print(f'has output mask: {layer.output_mask}', flush=True)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "VAE"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 50)]         0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 50, 300)      7466400     input_1[0][0]                    
__________________________________________________________________________________________________
custom_lstm (custom_lstm)       (None, 256)          439296      embedding[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 64)           16448       custom_lstm[0][0]                
_______________

In [4]:
# logits = tf.random.normal([batch_size, max_length, vocab_size], dtype=tf.float32)
# targets = tf.random.uniform([batch_size, max_length], minval=0, maxval=vocab_size, dtype=tf.int32)
# proj_w = tf.random.normal([vocab_size, vocab_size], dtype=tf.float32)
# proj_b = tf.zeros(vocab_size, dtype=tf.float32)

# print(logits.shape, targets.shape, proj_w.shape, proj_b.shape)

In [5]:
# Encoder model: maps token-id sequences to the latent space.
# Its output is `z`, the *sampled* latent produced by the Sampling layer, so
# repeated predictions on the same sentence differ.
# NOTE(review): consider exposing z_mean if deterministic encodings are wanted.
encoder = tf.keras.Model(inputs=x, outputs=z)

# Debug aid: show which Keras mask enters and leaves every encoder layer.
for idx, layer in enumerate(encoder.layers):
    print(f'layer {idx}: {layer}', flush=True)
    print(f'has input mask: {layer.input_mask}', flush=True)
    print(f'has output mask: {layer.output_mask}', flush=True)

layer 0: <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x2b57d45d2a10>
has input mask: None
has output mask: None
layer 1: <tensorflow.python.keras.layers.embeddings.Embedding object at 0x2b57d4649d90>
has input mask: None
has output mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
layer 2: <__main__.custom_lstm object at 0x2b57d4649f50>
has input mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
has output mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
layer 3: <tensorflow.python.keras.layers.core.Dense object at 0x2b57d472f390>
has input mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
has output mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
layer 4: <tensorflow.python.keras.layers.core.Dense object at 0x2b57d46e5f50>
has input mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
has output mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
la

In [6]:
# Generator model: decodes latent vectors into per-timestep vocabulary logits.
# It reuses the existing `decoder_layer` instance, i.e. the same decoder
# weights that the VAE uses.
ins = tf.keras.Input(shape=(latent_dim,))
x_logits = decoder_layer(ins)
generator = tf.keras.Model(inputs=ins, outputs=x_logits)

# Debug aid: show which Keras mask enters and leaves every generator layer.
for idx, layer in enumerate(generator.layers):
    print(f'layer {idx}: {layer}', flush=True)
    print(f'has input mask: {layer.input_mask}', flush=True)
    print(f'has output mask: {layer.output_mask}', flush=True)

layer 0: <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x2b5805f57b90>
has input mask: None
has output mask: None
layer 1: <__main__.custom_decoder object at 0x2b58057e46d0>
has input mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
has output mask: Tensor("embedding/NotEqual:0", shape=(None, 50), dtype=bool)
