In [None]:
import tensorflow as tf
import tensorflow_datasets as tf_ds
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import (Reshape, Conv2DTranspose, Add, Conv2D, MaxPool2D, Dense,
                                     Flatten, InputLayer, BatchNormalization, Input, )
from tensorflow.keras.optimizers import Adam

In [None]:
# Fetch Fashion-MNIST; with no `split` given, tfds returns a dict of splits whose
# elements are (image, label) pairs because of `as_supervised=True`.
data = tf_ds.load("fashion_mnist", as_supervised=True)

In [None]:
# Unpack the two splits from the dict returned by tfds.load.
train_data, test_data = data["train"], data["test"]
def rescaling(img, lab):
  """Discard the label and return the image as float32 divided by 255.

  Args:
    img: image tensor yielded by the dataset.
    lab: class label; ignored because the VAE reconstructs its own input.

  Returns:
    `img` cast to float32 and divided by 255.
  """
  img_f32 = tf.cast(img, dtype=tf.float32)
  return img_f32 / 255.
# Training pipeline: scale, shuffle (1024-element buffer), batch, prefetch.
train_data = (train_data
              .map(rescaling, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(1024)
              .batch(64)
              .prefetch(tf.data.AUTOTUNE))
# The test pipeline is deliberately NOT shuffled: evaluation and the prediction /
# reconstruction cells should see the test images in a stable, repeatable order
# (a shuffled split yields a different order on every pass).
test_data = (test_data
             .map(rescaling, num_parallel_calls=tf.data.AUTOTUNE)
             .batch(64)
             .prefetch(tf.data.AUTOTUNE))

In [None]:
class Sampling(Layer):
  """Reparameterization-trick sampler for a diagonal Gaussian.

  Given `[mean, log_var]`, returns `mean + exp(0.5 * log_var) * eps` with
  `eps ~ N(0, I)`. The randomness is confined to `eps`, so gradients reach
  `mean` and `log_var` through a differentiable expression.
  """

  def call(self, inputs):
    """Draw one sample per element of `mean`.

    Args:
      inputs: pair `(mean, log_var)` of tensors with identical shapes.

    Returns:
      A tensor with the same shape and dtype as `mean`.
    """
    mean, log_var = inputs
    # Take the full dynamic shape and the dtype from `mean`. This works for any
    # latent rank (not just 2-D). It also works under a float16 compute policy,
    # where a float32 `eps` could not be multiplied with a float16 `mean`.
    eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
    return mean + tf.math.exp(0.5 * log_var) * eps

In [None]:
# Encoder: maps a 28x28x1 image to the parameters (mean, log-variance) of a
# LATENT_DIM-dimensional Gaussian and draws one sample z via the Sampling layer.
LATENT_DIM=2
encoder_inputs = Input(shape=(28,28,1))

# Two stride-2 'same' convolutions halve the spatial size twice:
# 28x28x1 -> 14x14x32 -> 7x7x64 (as shown by encoder_model.summary()).
x = Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs)
x = Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)

# Flatten the 7x7x64 feature map (3136 values) into a 16-unit dense bottleneck.
x = Flatten()(x)
x = Dense(16, activation='relu')(x)

# Two linear heads. The second is treated as a log-variance both by Sampling
# (exp(0.5*log_var)) and by the KL term in custom_loss.
mean = Dense(LATENT_DIM,)(x)
log_var = Dense(LATENT_DIM,)(x)

# Reparameterized sample: z = mean + exp(0.5*log_var) * eps.
z = Sampling()([mean,log_var])

# Output order is [z, mean, log_var]; the vae cell and training_block unpack in this order.
encoder_model = Model(encoder_inputs,[z,mean,log_var], name='encoder')
encoder_model.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d_2 (Conv2D)              (None, 14, 14, 32)   320         ['input_3[0][0]']                
                                                                                                  
 conv2d_3 (Conv2D)              (None, 7, 7, 64)     18496       ['conv2d_2[0][0]']               
                                                                                                  
 flatten_1 (Flatten)            (None, 3136)         0           ['conv2d_3[0][0]']               
                                                                                            

In [None]:
# Decoder: maps a LATENT_DIM-dimensional latent vector back to a 28x28x1 image.
latent_inputs = Input(shape=(LATENT_DIM,))


# Project to 7*7*64 = 3136 units and reshape to a 7x7x64 feature map, mirroring
# the shape of the encoder's last convolution output.
x = Dense(7*7*64, activation='relu')(latent_inputs)
x = Reshape((7,7,64))(x)

# Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
x = Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same')(x)
x = Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)

# Stride-1, single-channel output with a sigmoid, so every pixel lies in (0, 1).
# This matches the /255-scaled inputs from `rescaling` and the binary
# cross-entropy used in custom_loss.
decoder_output = Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(x)
decoder_model = Model(latent_inputs,decoder_output,name='decoder')
decoder_model.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_7 (Dense)             (None, 3136)              9408      
                                                                 
 reshape_1 (Reshape)         (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose_3 (Conv2DT  (None, 14, 14, 64)       36928     
 ranspose)                                                       
                                                                 
 conv2d_transpose_4 (Conv2DT  (None, 28, 28, 32)       18464     
 ranspose)                                                       
                                                                 
 conv2d_transpose_5 (Conv2DT  (None, 28, 28, 1)        289 

In [None]:
# End-to-end VAE: image -> encoder -> sampled z -> decoder -> reconstruction.
vae_input = Input(shape=(28,28,1), name="vae_input")
# Only the sampled z feeds the decoder; mean and log_var are discarded here.
# NOTE(review): this rebinds the module-level `z` from the encoder cell. It also
# binds `_`, which shadows IPython's last-output variable in this kernel.
z,_,_ = encoder_model(vae_input)
output = decoder_model(z)
# vae wraps both sub-models, so vae.trainable_weights covers encoder + decoder;
# training_block differentiates with respect to exactly that list.
vae = Model(vae_input, output, name="vae")
vae.summary()

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vae_input (InputLayer)      [(None, 28, 28, 1)]       0         
                                                                 
 encoder (Functional)        [(None, 2),               69076     
                              (None, 2),                         
                              (None, 2)]                         
                                                                 
 decoder (Functional)        (None, 28, 28, 1)         65089     
                                                                 
Total params: 134,165
Trainable params: 134,165
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Training configuration.
EPOCHS = 20
# Plain SGD with a fixed step size; training_block applies gradients with it.
# NOTE(review): `Adam` is imported in the first cell but is not used here.
OPTIMIZER = tf.keras.optimizers.SGD(learning_rate=0.01)

In [None]:
def custom_loss(y_true, y_pred, mean, log_var):
  """Negative ELBO of the VAE, averaged over the batch.

  Reconstruction: Keras binary cross-entropy (which averages over the last
  axis) summed over axes 1 and 2, then averaged over the batch. For the
  (batch, 28, 28, 1) images built in this notebook, axes 1 and 2 are height
  and width.
  Regularization: closed-form KL( N(mean, exp(log_var)) || N(0, I) ), summed
  over axis 1 of the latent tensors, then averaged over the batch.

  Args:
    y_true: target images (the input batch itself).
    y_pred: decoder output, same shape as y_true.
    mean: encoder posterior mean.
    log_var: encoder posterior log-variance.

  Returns:
    Scalar tensor: mean reconstruction term + mean KL term.
  """
  per_pixel_bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
  reconstruction_term = tf.reduce_mean(tf.reduce_sum(per_pixel_bce, axis=(1, 2)))

  kl_elementwise = -0.5 * (1 + log_var - tf.square(mean) - tf.exp(log_var))
  kl_term = tf.reduce_mean(tf.reduce_sum(kl_elementwise, axis=1))

  return reconstruction_term + kl_term

In [None]:
@tf.function
def training_block(x_batch):
  """Run one optimization step on `x_batch` and return the batch loss.

  The batch serves as its own reconstruction target. Gradients are taken with
  respect to `vae.trainable_weights` (encoder and decoder together) and
  applied with the global OPTIMIZER.
  """
  with tf.GradientTape() as tape:
    sampled_z, z_mean, z_log_var = encoder_model(x_batch)
    reconstruction = decoder_model(sampled_z)
    # Autoencoder objective: the input doubles as the target.
    batch_loss = custom_loss(x_batch, reconstruction, z_mean, z_log_var)

  params = vae.trainable_weights
  grads = tape.gradient(batch_loss, params)
  OPTIMIZER.apply_gradients(zip(grads, params))
  return batch_loss

In [None]:
import os
def custom_model_ckpt(epoch, ckpt_path, model=vae, save_weights_only=True):
  """Write an HDF5 checkpoint of `model` to `<ckpt_path>/epoch_<epoch>.h5`.

  Args:
    epoch: epoch number, used in the file name.
    ckpt_path: directory for checkpoints; created if it does not exist.
    model: model to save. Defaults to the `vae` bound when this cell ran.
    save_weights_only: if True (the default), save only the weights via
      `model.save_weights`; if False, save the full model via `model.save`.
      Previously this flag was accepted but ignored.

  Returns:
    The path of the file that was written.
  """
  os.makedirs(ckpt_path, exist_ok=True)
  fp = os.path.join(ckpt_path, f'epoch_{epoch}.h5')
  if save_weights_only:
    model.save_weights(fp)
  else:
    model.save(fp)
  return fp

In [None]:
def neuralearn(epochs):
  """Train the VAE for `epochs` passes over `train_data`, checkpointing each epoch.

  Each batch goes through `training_block`. After every epoch the weights are
  saved under "ckpt" via `custom_model_ckpt`, and the epoch's mean loss is
  printed and recorded.

  Args:
    epochs: number of epochs to run (epochs are numbered from 1).

  Returns:
    dict mapping epoch number -> mean training loss over that epoch's batches
    (a scalar tensor; 0.0 if `train_data` yields no batches).
  """
  total_loss = {}
  for epoch in range(1, epochs + 1):
    print('Training starts for epoch number {}'.format(epoch))

    # Average the per-batch losses. Reporting only the last batch's loss, as
    # before, gives a noisy number that does not represent the epoch. It also
    # failed with NameError when the dataset was empty.
    epoch_loss = tf.keras.metrics.Mean()
    for x_batch in train_data:
      epoch_loss.update_state(training_block(x_batch))
    loss = epoch_loss.result()

    custom_model_ckpt(epoch, "ckpt")
    total_loss[epoch] = loss
    print('Training Loss is: ', loss)
  print('Training Complete!!!')
  return total_loss

In [None]:
# Train for the configured number of epochs (EPOCHS, set in the config cell)
# rather than a separate hard-coded count.
total_loss = neuralearn(EPOCHS)

Training starts for epoch number 1
Training Loss is:  tf.Tensor(469.8823, shape=(), dtype=float32)
Training starts for epoch number 2
Training Loss is:  tf.Tensor(493.89154, shape=(), dtype=float32)
Training starts for epoch number 3
Training Loss is:  tf.Tensor(473.3352, shape=(), dtype=float32)
Training starts for epoch number 4
Training Loss is:  tf.Tensor(465.63242, shape=(), dtype=float32)
Training starts for epoch number 5
Training Loss is:  tf.Tensor(451.50946, shape=(), dtype=float32)
Training starts for epoch number 6
Training Loss is:  tf.Tensor(477.50662, shape=(), dtype=float32)
Training starts for epoch number 7
Training Loss is:  tf.Tensor(473.8729, shape=(), dtype=float32)
Training starts for epoch number 8
Training Loss is:  tf.Tensor(496.61157, shape=(), dtype=float32)
Training starts for epoch number 9
Training Loss is:  tf.Tensor(478.51462, shape=(), dtype=float32)
Training starts for epoch number 10
Training Loss is:  tf.Tensor(474.89404, shape=(), dtype=float32)
Tr

In [None]:
# Decode one hand-picked latent point. The decoder returns a (1, 28, 28, 1) batch:
# take the single image, drop the channel axis, and render it in grayscale.
latent_point_img = decoder_model.predict(np.array([[1.0, 1.0]]))[0]
plt.imshow(np.squeeze(latent_point_img, axis=-1), cmap='gray')
plt.title('Decoder output for latent point (1, 1)')
plt.axis('off');

In [None]:
# Encode the whole test split. vae.layers[1] is `encoder_model`, so this returns
# [z, mean, log_var], each concatenated over all test batches.
y_out = encoder_model.predict(test_data)



In [None]:
# Decode the sampled latents (y_out[0] is z). vae.layers[-1] is `decoder_model`.
all_img = decoder_model.predict(y_out[0])



In [None]:
# Show one reconstruction: drop the trailing channel axis and use a grayscale
# colormap, since the images are single-channel intensities.
plt.imshow(np.squeeze(all_img[10], axis=-1), cmap='gray')
plt.title('Reconstruction at index 10 of the test predictions')
plt.axis('off');