### Importing Libraries

In [1]:
import tensorflow as tf 
import keras 
import numpy as np
from keras.models import Model
from keras.layers import Layer
from keras.layers import Reshape, Conv2DTranspose, Add, Conv2D, MaxPool2D,Flatten, InputLayer, BatchNormalization, Input, Dense 
from keras.optimizers import Adam 

### Data Preparation

In [2]:
# Only the images are needed for an autoencoder, so both label arrays are discarded.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

# Pool the train and test images into a single array, append a trailing
# channel axis, and convert to float32 scaled by 1/255.
mnist_digits = np.concatenate((x_train, x_test), axis=0)
mnist_digits = mnist_digits[..., np.newaxis].astype(np.float32) / 255

In [3]:
# Wrap the image array in a tf.data.Dataset; slicing along the first axis
# makes every element a single image.
dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)

In [4]:
# Sanity check: number of elements (individual images) in the unbatched dataset.
len(dataset)

70000

In [5]:
# Number of images grouped into each training batch.
batch_size = 128
# Size of the latent code: used as the unit count of the encoder's mean and
# log-variance heads and as the width of the decoder's input.
latent_dim = 2

In [6]:
# Input pipeline, built one stage at a time:
#   1. shuffle through a 1024-element buffer, reshuffled on every iteration,
#   2. group into batches of `batch_size`,
#   3. prefetch so upcoming batches are prepared while the current one is used.
train_dataset = dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

In [7]:
# Rebuild the pipeline from the *unbatched* `dataset` instead of calling
# .batch() on `train_dataset` again. Batching an already-batched dataset nests
# the batches, producing (128, 128, 28, 28, 1) elements that do not match the
# encoder's (28, 28, 1) input, and every re-run of the cell would nest one
# level deeper. Starting from `dataset` keeps the cell idempotent.
# drop_remainder=True discards the final short batch so every batch holds
# exactly `batch_size` images.
train_dataset = (
    dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)


### Modeling



#### Sampling

In [8]:
class Sampling(tf.keras.layers.Layer):
    """Draw a latent sample z = mean + exp(0.5 * log_var) * noise.

    `inputs` is a pair ``(mean, log_var)``; `noise` is drawn from
    ``tf.random.normal`` with a shape taken from the first two dimensions
    of `mean`.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # Read the shape at run time so the layer works for any batch size.
        runtime_shape = tf.shape(mean)
        n_rows, n_latent = runtime_shape[0], runtime_shape[1]
        noise = tf.random.normal(shape=(n_rows, n_latent))
        # exp(0.5 * log_var) turns a log-variance into a standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise


#### Encoder

In [9]:
# Encoder: image -> (sampled latent z, mean, log_var).
encoder_inputs = Input(shape=(28, 28, 1))

# Two stride-2 convolutions with 'same' padding: 28x28 -> 14x14 -> 7x7.
x = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(encoder_inputs)
x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(x)

# Collapse the feature map and pass it through a small dense bottleneck.
x = Flatten()(x)
x = Dense(units=16, activation='relu')(x)

# Two parallel linear heads (no activation) share the same features.
mean = Dense(units=latent_dim)(x)
log_var = Dense(units=latent_dim)(x)

# Sample the latent vector from the two heads.
z = Sampling()([mean, log_var])

# mean and log_var are exposed alongside z because the loss needs them.
encoder_models = Model(
    inputs=encoder_inputs,
    outputs=[z, mean, log_var],
    name='encoder',
)
encoder_models.summary()




#### Decoder

In [10]:
# Decoder: latent vector -> reconstructed image.
latent_inputs = Input(shape=(latent_dim,))

# Project the latent vector to 7*7*64 units and fold it into a 7x7x64 map.
x = Dense(units=7 * 7 * 64, activation='relu')(latent_inputs)
x = Reshape(target_shape=(7, 7, 64))(x)

# Two stride-2 transposed convolutions with 'same' padding: 7x7 -> 14x14 -> 28x28.
x = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(x)
x = Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(x)

# Single-channel output; the sigmoid keeps every pixel in (0, 1).
decoder_output = Conv2DTranspose(filters=1, kernel_size=3, padding='same', activation='sigmoid')(x)
decoder_model = Model(inputs=latent_inputs, outputs=decoder_output, name='decoder')
decoder_model.summary()

#### Overall VAE Model

In [11]:
# End-to-end model: encoder followed by decoder.
vae_input = Input(shape=(28, 28, 1), name='vae_input')

# Only the sampled latent vector is fed to the decoder here; the encoder's
# mean and log_var outputs are not used in this graph.
z, _, _ = encoder_models(vae_input)
output = decoder_model(z)

vae = Model(inputs=vae_input, outputs=output, name='vae')
vae.summary()

### Training

In [12]:
# Adam optimizer used by the training step to update the VAE weights.
optimizer = Adam(learning_rate=0.001)
# Number of passes over the training data; presumably the value passed to
# neuralearn(epochs) -- the call is not shown here.
epochs = 30

In [13]:
def custom_loss(y_true, y_pred, mean, log_var):
    """Return the VAE objective: reconstruction loss plus KL regulariser.

    Args:
        y_true: target images.
        y_pred: reconstructed images.
        mean: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        Scalar sum of the batch-averaged reconstruction and KL terms.
    """
    # Reconstruction term: binary cross-entropy per pixel, summed over
    # axes 1 and 2 of the result, then averaged over the batch.
    pixel_bce = keras.losses.binary_crossentropy(y_true, y_pred)
    recons_per_sample = tf.reduce_sum(pixel_bce, axis=(1, 2))
    loss_recons = tf.reduce_mean(recons_per_sample)

    # Regulariser: closed-form KL divergence between N(mean, exp(log_var))
    # and a standard normal, summed over axis 1, then averaged over the batch.
    kl_elementwise = -0.5 * (log_var + 1 - tf.math.square(mean) - tf.math.exp(log_var))
    kl_per_sample = tf.reduce_sum(kl_elementwise, axis=1)
    loss_reg = tf.reduce_mean(kl_per_sample)

    return loss_recons + loss_reg

In [17]:
def training_block(x_batch):
    with tf.GradientTape() as recorder:
        z, mean, log_var = encoder_models(x_batch)
        y_pred = decoder_model(z)
        y_true = x_batch
        loss = custom_loss(y_true, y_pred, mean, log_var)
    
    partial_derivatives = recorder.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(partial_derivatives, vae.trainable_weights))
    return loss

In [18]:
def neuralearn(epochs):
    """Train for `epochs` passes over `train_dataset`.

    Every batch of an epoch is fed to `training_block`; after the epoch the
    loss of the last batch processed is printed.

    Args:
        epochs: number of epochs to run; epochs are numbered from 1.
    """
    for epoch in range(1, epochs + 1):
        print('training starts for epoch number {}'.format(epoch))

        # Reset per epoch: if the dataset yields no batches the report below
        # shows None instead of failing on an unbound local variable.
        loss = None
        for x_batch in train_dataset:
            loss = training_block(x_batch)
        print('Training Loss is: ', loss)
    print('Training Complete !!! ')