#### Import relevant libraries

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

#### Load MNIST from TensorFlow Datasets, together with its metadata (version, features, and number of samples per split)

In [2]:
# Fetch MNIST (downloaded on first use). `as_supervised=True` yields
# (image, label) tuples; `with_info=True` also returns the DatasetInfo metadata.
mnist_data, mnist_info = tfds.load(
    name='mnist',
    with_info=True,
    as_supervised=True,
)
# The returned mapping is keyed by split name.
mnist_train = mnist_data['train']
mnist_test = mnist_data['test']

#### Define a function that scales pixel values from [0, 255] to [0, 1], and apply it to both splits with `map`

In [3]:
def scale(image, label):
    """Cast `image` to float32 and divide it by 255; pass `label` through unchanged.

    Intended for `Dataset.map` over (image, label) pairs. Assumes pixel values
    lie in [0, 255] (presumably uint8 MNIST images), so the result lies in [0, 1].
    """
    return tf.cast(image, tf.float32) / 255., label

# Lazily apply `scale` to every (image, label) pair of the train and test splits.
scaled_train_validation_data, scaled_test_data = [
    split.map(scale) for split in (mnist_train, mnist_test)
]

#### Use `mnist_info` to size the validation set (10% of the training split) and the test set

In [4]:
# Validation size: 10% of the official training split, truncated to an int.
# Test size: the whole official test split.
num_validation_samples, num_test_samples = (
    int(0.1 * mnist_info.splits['train'].num_examples),
    int(mnist_info.splits['test'].num_examples),
)

#### Shuffle the training data (using a buffer of `BUFFER_SIZE` elements) and split it into validation and training sets

In [5]:
BUFFER_SIZE = 10000
# Fixed seed so the train/validation split is reproducible across kernel restarts.
SHUFFLE_SEED = 42

# Shuffle ONCE, in a fixed order, before splitting. With the default
# reshuffle_each_iteration=True, every pass over this dataset (each training
# epoch, or the next(iter(...)) over the validation set) would draw a fresh
# permutation. take()/skip() would then select different examples on each
# pass, and validation examples would leak into training.
# NOTE(review): a buffer smaller than the split only partially shuffles it.
train_validation_data = scaled_train_validation_data.shuffle(
    BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=False
)
validation_data = train_validation_data.take(num_validation_samples)
# Reshuffle the training portion every epoch. This changes only the order of
# its examples, not which examples belong to it.
train_data = train_validation_data.skip(num_validation_samples).shuffle(
    BUFFER_SIZE, seed=SHUFFLE_SEED
)

#### Batch the datasets: mini-batches of `BATCH_SIZE` for training, and a single full-size batch each for validation and test

In [6]:
BATCH_SIZE = 100

# Training data becomes mini-batches of BATCH_SIZE. Validation and test each
# become one batch containing the whole split. The right-hand side is evaluated
# before any name is rebound, so each .batch() applies to the previous dataset.
train_data, validation_data, test_data = (
    train_data.batch(BATCH_SIZE),
    validation_data.batch(num_validation_samples),
    scaled_test_data.batch(num_test_samples),
)

# Pull the single validation batch out as (inputs, targets) tensors.
validation_inputs, validation_targets = next(iter(validation_data))