In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
from tensorflow.keras.datasets import fashion_mnist

# Notebook-wide configuration consumed by the encoder/decoder cells.
# IMG_SIZE is 32 because the decoder's three stride-2 transposed convolutions
# upsample a 4x4 feature map to 32x32, so the 28x28 inputs are padded to match.
IMG_SIZE = 32
IMG_DIM = 1        # Fashion-MNIST images are single-channel (grayscale)
EMBEDDING_DIM = 2  # size of the latent space


def preprocess(imgs):
  """Turn raw image arrays into model-ready float tensors.

  Args:
    imgs: integer array of shape (N, H, W) with pixel values in [0, 255].

  Returns:
    float32 array of shape (N, IMG_SIZE, IMG_SIZE, 1) with values in [0, 1]:
    scaled, zero-padded symmetrically up to IMG_SIZE, with a channel axis.
  """
  imgs = imgs.astype("float32") / 255.0
  # Binary cross-entropy against a sigmoid output needs targets in [0, 1];
  # zero padding keeps the added border as "background".
  pad = (IMG_SIZE - imgs.shape[1]) // 2
  imgs = np.pad(imgs, ((0, 0), (pad, pad), (pad, pad)), constant_values=0.0)
  return np.expand_dims(imgs, -1)


(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train = preprocess(x_train)
x_test = preprocess(x_test)

In [None]:
class Sampling(layers.Layer):
  """Reparameterization-trick layer for a VAE.

  Given ``[z_mean, z_log_var]`` it returns
  ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon`` drawn from a
  standard normal, so the sample stays differentiable with respect to both
  inputs.

  Implemented as a ``Layer`` with ``call`` (rather than overriding
  ``__call__``) so Keras' own ``__call__`` wrapper runs around it.
  """

  def call(self, inputs):
    """Sample a latent vector.

    Args:
      inputs: pair ``(z_mean, z_log_var)`` of tensors with identical shape.

    Returns:
      Tensor with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = inputs
    # Use the dynamic shape so the layer works for any batch size.
    epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    # exp(0.5 * log(sigma^2)) == sigma
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
# Encoder: image -> (z_mean, z_log_var, sampled z).
encoder_input = layers.Input(
    shape=(IMG_SIZE, IMG_SIZE, IMG_DIM),
    name="encoder_input",
)

# Stack of stride-2 convolutions with a growing number of filters.
x = encoder_input
for n_filters in (32, 64, 128):
  x = layers.Conv2D(
      n_filters, (3, 3), strides=2, activation="relu", padding="same"
  )(x)

# Keep the feature-map shape (batch axis dropped); the decoder cell uses it
# to size its first Dense layer and to reshape back to a feature map.
shape_before_flatten = K.int_shape(x)[1:]

x = layers.Flatten()(x)

# Two parallel linear heads parameterize the latent distribution.
z_mean = layers.Dense(EMBEDDING_DIM, name="z_mean")(x)
z_log_var = layers.Dense(EMBEDDING_DIM, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = models.Model(
    encoder_input,
    [z_mean, z_log_var, z],
    name="encoder",
)
encoder.summary()

In [None]:
# Decoder: latent vector -> reconstructed image.
decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name="decoder_input")

# Project the latent vector up to the size of the encoder's last feature map,
# then restore its spatial layout.
x = layers.Dense(np.prod(shape_before_flatten))(decoder_input)
x = layers.Reshape(shape_before_flatten)(x)

# Stack of stride-2 transposed convolutions, mirroring the encoder's filters.
for n_filters in (128, 64, 32):
  x = layers.Conv2DTranspose(
      n_filters, (3, 3), strides=2, activation="relu", padding="same"
  )(x)

# Final stride-1 convolution maps to IMG_DIM channels; sigmoid keeps every
# output pixel in (0, 1).
decoder_output = layers.Conv2D(
    IMG_DIM,
    (3, 3),
    strides=1,
    activation="sigmoid",
    padding="same",
    name="decoder_output",
)(x)

decoder = models.Model(decoder_input, decoder_output)
decoder.summary()



$$KLloss = -\frac{1}{2} \sum(1 + \log{(\sigma^2)} - \mu^2 - \sigma^2)$$

* Here we backpropagate the KL loss together with the reconstruction loss.

* So, total_loss = reconstruction_loss + kl_loss

* reconstruction_loss is calculated with binary cross-entropy, written $L(\cdot, \cdot)$ below.

$$ReconstructionLoss = \mathrm{mean}\left( 500 \cdot L(X, \mathrm{Decoder}(z)) \right), \quad z \sim \mathrm{Encoder}(X)$$


$$L(y, p) = -y \log(p) - (1-y)\log(1-p)$$

In [None]:
class VAE(tf.keras.models.Model):
  """Variational autoencoder that wires an encoder and a decoder together.

  The training objective is
  ``reconstruction_weight * binary_crossentropy + KL divergence`` and is
  computed inside ``train_step``, so ``compile`` only needs an optimizer.

  Args:
    encoder: model mapping images to ``(z_mean, z_log_var, z)``.
    decoder: model mapping a latent vector ``z`` to a reconstructed image.
    reconstruction_weight: factor applied to the reconstruction term before
      it is added to the KL term. Defaults to 500.
    **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, encoder, decoder, reconstruction_weight=500, **kwargs):
    super().__init__(**kwargs)
    self.encoder = encoder
    self.decoder = decoder
    self.reconstruction_weight = reconstruction_weight

    # Running means reported in the training logs.
    self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
    self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
    self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

  @property
  def metrics(self):
    """Trackers owned by this model, in the order they are reported."""
    return [
        self.total_loss_tracker,
        self.reconstruction_loss_tracker,
        self.kl_loss_tracker,
    ]

  def call(self, inputs, training=None):
    """Encode ``inputs``, sample a latent vector and decode it.

    Returns:
      Tuple ``(z_mean, z_log_var, reconstruction)``.
    """
    # Use the sub-models stored on this instance, not same-named globals.
    z_mean, z_log_var, z = self.encoder(inputs, training=training)
    reconstruction = self.decoder(z, training=training)
    return z_mean, z_log_var, reconstruction

  def train_step(self, data):
    """Run one optimization step and return the current metric values.

    Args:
      data: a batch of images; if a tuple is received (e.g. ``(x, y)``),
        only its first element is used.

    Returns:
      Dict mapping each tracker's name to its running mean.
    """
    if isinstance(data, tuple):
      data = data[0]

    # Only the forward pass and the loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
      z_mean, z_log_var, reconstruction = self(data, training=True)
      # Cross-entropy is reduced over the height, width and channel axes,
      # giving one value per image, then averaged over the batch.
      reconstruction_loss = tf.reduce_mean(
          self.reconstruction_weight
          * tf.keras.losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
      )
      # KL term: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), summed over
      # the latent dimensions and averaged over the batch.
      kl_loss = tf.reduce_mean(
          tf.reduce_sum(
              -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
              axis=1,
          )
      )
      total_loss = reconstruction_loss + kl_loss

    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {m.name: m.result() for m in self.metrics}


In [None]:
# Combine the two sub-networks into one trainable model. The loss is computed
# inside VAE.train_step, so compile() is given only an optimizer.
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer="adam")

In [None]:
# Train on the images alone -- no labels are passed to fit().
vae.fit(x=x_train, epochs=20)

In [None]:
# Draw one latent vector from a standard normal and decode it into an image.
# The latent size is read from the decoder itself so this cell stays in sync
# with whatever embedding dimension the model was built with.
latent_dim = decoder.input_shape[-1]
x_sample = np.random.normal(size=(1, latent_dim))

generate = decoder(x_sample)

plt.imshow(generate[0], cmap="gray")
plt.title("Decoded sample from the latent space")
plt.axis("off")
plt.show()