In [19]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models

In [11]:
# Model / data configuration constants.
BATCH_SIZE = 16
IMAGE_SIZE = (28, 28, 1)  # (height, width, channels); matches the TFDS MNIST image feature
NUM_CLASSES = 10          # matches ClassLabel(num_classes=10) reported by the dataset info
LATENT_DIM = 2

In [15]:
def map_image(image, label):
    """Turn a raw MNIST example into a normalized float image tensor.

    Meant to be used as a ``tf.data.Dataset.map`` callback on a supervised
    ``(image, label)`` dataset. Pixel values are cast to ``float32``, divided
    by 255 (mapping uint8 [0, 255] to [0, 1]), and reshaped to ``IMAGE_SIZE``.

    Args:
        image: integer image tensor (uint8 for TFDS MNIST).
        label: class label; accepted only to match the ``(image, label)``
            signature. It is discarded and only the image is returned.

    Returns:
        A ``float32`` tensor of shape ``IMAGE_SIZE``.
    """
    normalized = tf.cast(image, tf.float32) / 255.0
    return tf.reshape(normalized, shape=IMAGE_SIZE)

def get_dataset(map_func, split='train'):
    """Load an MNIST split from TFDS and apply ``map_func`` to every example.

    Args:
        map_func: callable invoked as ``map_func(image, label)``. The dataset
            is loaded with ``as_supervised=True``, so each element is an
            ``(image, label)`` pair.
        split: TFDS split name to load. Defaults to ``'train'``, the
            previously hard-coded value. Pass e.g. ``'test'`` to get
            held-out data with identical preprocessing.

    Returns:
        A ``tf.data.Dataset`` of whatever ``map_func`` returns. It is not
        shuffled or batched here.
    """
    dataset = tfds.load('mnist', as_supervised=True, split=split)
    # Run the per-example preprocessing in parallel; output order remains
    # deterministic under the default tf.data options.
    dataset = dataset.map(map_func, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset

In [16]:
# Training split passed through map_image (not shuffled or batched here).
train_dataset = get_dataset(map_func=map_image)

In [17]:
# Inspect the MNIST dataset metadata: feature shapes/dtypes, split sizes, and
# label count. Ending the cell with `info` uses Jupyter's rich display instead
# of print().
builder = tfds.builder('mnist')
info = builder.info
info

tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='C:\\Users\\ASUS TUF GAMING\\tensorflow_datasets\\mnist\\3.0.1',
    file_format=tfrecord,
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
 

In [18]:
class Sampling(tf.keras.layers.Layer):
    """Reparameterization-trick sampler: returns ``mu + exp(0.5 * sigma) * eps``.

    Despite its name, ``sigma`` is treated as a log-variance: the standard
    deviation applied to the noise is ``exp(0.5 * sigma)``. ``eps`` is drawn
    from a standard normal distribution with the same shape as ``mu``.
    """

    def call(self, inputs):
        """Draw one latent sample per element of ``mu``.

        Args:
            inputs: a pair ``(mu, sigma)`` of tensors with matching shapes,
                presumably ``(batch, latent_dim)``. Confirm at the call site.

        Returns:
            A tensor with the same shape and dtype as ``mu``.
        """
        mu, sigma = inputs
        # tf.random.normal replaces tf.keras.backend.random_normal, which is
        # not exposed by the Keras 3 backend (the default in TF >= 2.16).
        # Matching mu's dtype avoids a float16/float32 mismatch under mixed
        # precision.
        epsilon = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)

        return mu + tf.exp(0.5 * sigma) * epsilon

In [20]:
def encoder(inputs, latent_dim):
    """Build the convolutional encoder graph on top of ``inputs``.

    Two stride-2 ``same``-padded convolutions (each followed by batch
    normalization) downsample the input. The result is flattened, passed
    through a 20-unit dense layer plus batch normalization, and projected to
    two ``latent_dim``-sized heads.

    Args:
        inputs: Keras tensor of images, presumably shaped ``IMAGE_SIZE``
            (``(28, 28, 1)``). Confirm at the call site.
        latent_dim: number of units in each of the two latent heads.

    Returns:
        Tuple ``(mu, sigma, conv_shape)``:
        - ``mu``: output of the ``latent_mu`` dense layer.
        - ``sigma``: output of the ``latent_sigma`` dense layer.
        - ``conv_shape``: static shape (including the batch dimension) of the
          normalized output of the last conv block. Presumably returned so a
          decoder can reshape back to it.
    """
    x = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding='same',
                      activation='relu', name='encoder_conv1')(inputs)
    x = layers.BatchNormalization(name='encoder_bn1')(x)
    x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same',
                      activation='relu', name='encoder_conv2')(x)

    # Keep a handle on the normalized conv features; their shape is returned.
    batch_2 = layers.BatchNormalization(name='encoder_bn2')(x)

    # The Flatten layer must be *applied* to a tensor. Previously it was only
    # constructed (with an empty name), so the Dense layer below received a
    # Layer object instead of a tensor.
    x = layers.Flatten(name='encoder_flatten')(batch_2)
    x = layers.Dense(20, activation='relu', name='encoder_dense1')(x)
    x = layers.BatchNormalization(name='encoder_bn3')(x)

    mu = layers.Dense(latent_dim, name='latent_mu')(x)
    # Presumably consumed by the Sampling layer, which treats this as a
    # log-variance (it computes exp(0.5 * sigma)).
    sigma = layers.Dense(latent_dim, name='latent_sigma')(x)

    return mu, sigma, batch_2.shape