In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

In [None]:
# Load MNIST; the labels are discarded here because only the images are used.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
# Stack the train and test images, append a trailing channel axis and rescale
# the pixel values by 1/255 as float32.
mnist_digits = (
    np.concatenate((x_train, x_test), axis=0)[..., np.newaxis].astype("float32") / 255
)

In [None]:
# Wrap the image array in a tf.data.Dataset, sliced along the first axis
# (one element per image).
dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)

In [None]:
# Sanity check: number of elements (images) in the un-batched dataset.
len(dataset)

In [None]:
# Number of images per batch produced by the input pipeline.
BATCH_SIZE = 128
# Width of the latent space: size of the encoder's mean/log-variance heads and
# of the decoder's input.
LATENT_DIM = 2

In [None]:
# Training input pipeline: shuffle through a 1024-element buffer (reshuffled
# every pass), group into batches of BATCH_SIZE, and prefetch with an
# automatically tuned buffer.
train_dataset = (
    dataset
    .shuffle(buffer_size=1024, reshuffle_each_iteration=True)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Echo the pipeline object so its repr is shown as the cell output.
train_dataset

In [None]:
class Sampling(Layer):
  """Reparameterisation layer: draws a latent sample from a mean and a log-variance."""

  def call(self, inputs):
    """Return `mean + exp(0.5 * log_var) * eps` with `eps` drawn from a standard normal.

    `inputs` is a pair `(mean, log_var)`. The noise tensor is created with the
    first two dimensions of `mean`, so `mean` is expected to be 2-D.
    """
    mean, log_var = inputs
    # exp(0.5 * log_var) converts a log-variance into a standard deviation.
    std_dev = tf.math.exp(0.5 * log_var)
    batch = tf.shape(mean)[0]
    dim = tf.shape(mean)[1]
    noise = tf.random.normal(shape=(batch, dim))
    return mean + std_dev * noise

In [None]:
# Encoder: two stride-2 convolutions, a small dense layer, then two
# LATENT_DIM-wide heads (mean and log-variance) feeding the Sampling layer.
encoder_inputs = Input(shape=(28, 28, 1))

x = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(encoder_inputs)
x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(x)

x = Flatten()(x)
x = Dense(units=16, activation='relu')(x)

# Both heads are plain linear layers (no activation).
mean = Dense(units=LATENT_DIM)(x)
log_var = Dense(units=LATENT_DIM)(x)

z = Sampling()([mean, log_var])

# Output order is (z, mean, log_var); code that unpacks the encoder's outputs
# relies on this order.
encoder_model = Model(inputs=encoder_inputs, outputs=[z, mean, log_var], name='encoder')
encoder_model.summary()

In [None]:
# Decoder: project the latent vector to a 7x7x64 feature map, upsample it with
# two stride-2 transposed convolutions, and finish with a single-channel
# sigmoid layer.
latent_inputs = Input(shape=(LATENT_DIM,))

x = Dense(units=7 * 7 * 64, activation='relu')(latent_inputs)
x = Reshape(target_shape=(7, 7, 64))(x)

x = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(x)
x = Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(x)

# The sigmoid keeps every output value in (0, 1).
decoder_output = Conv2DTranspose(filters=1, kernel_size=3, padding='same', activation='sigmoid')(x)
decoder_model = Model(inputs=latent_inputs, outputs=decoder_output, name='decoder')
decoder_model.summary()

In [None]:
# End-to-end model: image -> encoder -> sampled latent z -> decoder -> image.
# Only the sampled z (the encoder's first output) is passed to the decoder.
vae_input = Input(shape=(28, 28, 1), name="vae_input")
z, _, _ = encoder_model(vae_input)
output = decoder_model(z)
vae = Model(inputs=vae_input, outputs=output, name="vae")
vae.summary()

In [None]:
# Adam optimizer instance used by the manual training step and also passed to
# `model.compile` for the subclassed model.
OPTIMIZER = Adam(learning_rate=1e-3)
# Number of epochs for the manual training loop (`model_learn`).
EPOCH = 30

In [None]:
def custom_loss(y_true,y_pred,mean,log_var):
  """VAE loss: reconstruction term plus KL regulariser, averaged over the batch.

  Args:
    y_true: target images.
    y_pred: reconstructed images, same shape as `y_true`.
    mean: latent means, one row per sample.
    log_var: latent log-variances, same shape as `mean`.

  Returns:
    Scalar tensor: batch mean of (summed binary cross-entropy + KL term).
  """
  # Binary cross-entropy is reduced over the last axis by Keras; the remaining
  # axes 1 and 2 are summed so each sample contributes a total, not a mean.
  bce_map = tf.keras.losses.binary_crossentropy(y_true, y_pred)
  reconstruction_term = tf.reduce_mean(tf.reduce_sum(bce_map, axis=(1, 2)))

  # Closed-form KL divergence between N(mean, exp(log_var)) and N(0, 1),
  # per latent dimension; summed over dimensions, averaged over the batch.
  kl_per_dim = -0.5 * (1 + log_var - tf.square(mean) - tf.exp(log_var))
  kl_term = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

  return reconstruction_term + kl_term

In [None]:
@tf.function
def training_block(x_batch):
  """Run one optimisation step on `x_batch` and return that batch's loss.

  The batch is encoded, decoded and scored with `custom_loss`, using the input
  itself as the reconstruction target. Gradients are taken with respect to
  `vae.trainable_weights` and applied with the shared `OPTIMIZER`.
  """
  with tf.GradientTape() as tape:
    latent, latent_mean, latent_log_var = encoder_model(x_batch)
    reconstruction = decoder_model(latent)
    loss = custom_loss(x_batch, reconstruction, latent_mean, latent_log_var)

  weights = vae.trainable_weights
  gradients = tape.gradient(loss, weights)
  OPTIMIZER.apply_gradients(zip(gradients, weights))
  return loss

In [None]:
def model_learn(epochs):
  """Train with the manual loop for `epochs` passes over `train_dataset`.

  After each epoch the mean of the per-batch losses is printed, instead of
  only the loss of the final (smaller and noisier) batch.

  Args:
    epochs: number of full passes over `train_dataset`.

  Raises:
    ValueError: if `train_dataset` yields no batches.
  """
  for epoch in range(1,epochs+1):
    print('Training starts for epoch number {}'.format(epoch))

    loss_sum = 0.0
    num_batches = 0
    for x_batch in train_dataset:
      # Accumulate as a tensor; converting to float once per epoch avoids
      # forcing a device sync on every step.
      loss_sum += training_block(x_batch)
      num_batches += 1

    if num_batches == 0:
      # Previously this surfaced as an unbound `loss` variable at print time.
      raise ValueError('train_dataset produced no batches; nothing to train on')

    print('Training Loss is: ', float(loss_sum) / num_batches)
  print('Training Complete!!!')

In [None]:
# Run the manual training loop for EPOCH epochs.
model_learn(EPOCH)

In [None]:
class VAE(tf.keras.Model):
  """Keras model wrapping an encoder and a decoder with a custom training step.

  The encoder must return `(z, mean, log_var)` and the decoder must map `z`
  back to an image. Training goes through `compile(optimizer=...)` and `fit`.
  """

  def __init__(self, encoder_model, decoder_model):
    """Store the two sub-models and create the running-mean loss metric."""
    super().__init__()
    self.encoder = encoder_model
    self.decoder = decoder_model
    self.loss_tracker = tf.keras.metrics.Mean(name="loss")

  @property
  def metrics(self):
    """Metrics tracked by this model (only the loss tracker)."""
    return [self.loss_tracker]

  def train_step(self,x_batch):
    """Run one optimisation step; `fit` calls this once per batch.

    Uses the sub-models stored on this instance and the optimizer given to
    `compile`, so the model does not depend on notebook-level globals.

    Returns:
      Dict with the running mean of the loss under the key "loss".
    """
    with tf.GradientTape() as recorder:
      z,mean,log_var = self.encoder(x_batch)
      y_pred = self.decoder(z)
      # The input batch is its own reconstruction target.
      loss = custom_loss(x_batch,y_pred, mean, log_var)

    partial_derivatives = recorder.gradient(loss,self.trainable_weights)
    self.optimizer.apply_gradients(zip(partial_derivatives, self.trainable_weights))

    self.loss_tracker.update_state(loss)

    return {"loss":self.loss_tracker.result()}

In [None]:
# Wrap the existing encoder/decoder in the subclassed model and train through
# `fit`. The sub-models keep whatever weights they already have, so this
# continues from any earlier training run in this notebook.
model = VAE(encoder_model, decoder_model)
model.compile(optimizer=OPTIMIZER)
# `train_dataset` is already batched, so `batch_size` must not be passed to
# `fit`. Keeping the History object also stops its repr being echoed.
history = model.fit(train_dataset, epochs=20)

In [None]:
# Half-width of the latent-space square to sample: each axis spans [-scale, scale].
scale = 1
# Number of sample points per axis (the plot below is an n x n grid).
n = 16

In [None]:
# n evenly spaced latent coordinates per axis, from -scale to +scale inclusive.
grid_x = np.linspace(start=-scale, stop=scale, num=n)
grid_y = np.linspace(start=-scale, stop=scale, num=n)

In [None]:
# Echo both coordinate arrays to inspect the sampled latent values.
grid_x, grid_y

In [None]:
# Decode an n x n grid of 2-D latent points and show the generated digits.
# Rows follow grid_x (first latent coordinate) and columns follow grid_y
# (second latent coordinate).
latent_grid = np.array(
    [[z0, z1] for z0 in grid_x for z1 in grid_y], dtype="float32"
)
# One batched predict call instead of one call (and one progress bar) per point.
decoded_digits = model.decoder.predict(latent_grid, verbose=0)[..., 0]

# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(n, n, figsize=(5, 5), squeeze=False)
for ax, digit in zip(axes.flat, decoded_digits):
  ax.imshow(digit, cmap='Greys_r')
  ax.axis('off')
plt.show()

In [None]:
# Reload the training split, this time keeping the labels for colouring the
# latent-space plot. Note that this rebinds `mnist_digits` to the training
# images only, prepared the same way as before (channel axis added, divided
# by 255, float32).
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
mnist_digits = x_train[..., np.newaxis].astype("float32") / 255

In [None]:
# Encode the normalised training images. The encoder was trained on float32
# inputs in [0, 1] with a channel axis, so the raw uint8 `x_train` must not be
# used here. `encoder_model` is the same layer that `vae.layers[1]` pointed to.
z,_,_ = encoder_model.predict(mnist_digits)

# Scatter the sampled 2-D latent codes, coloured by digit label.
plt.figure(figsize=(12,12))
plt.scatter(z[:,0],z[:,1], c=y_train)
plt.colorbar(label='digit label')
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.show()