### import relevant packages

In [7]:
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

### Data

In [8]:
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

num_validation_samples = 0.1*mnist_info.splits['train'].num_examples
num_validation_samples = tf.cast(num_validation_samples, tf.int64)

num_test_samples = mnist_info.splits['test'].num_examples


def scale(image,label):
    image = tf.cast(image, tf.float32)
    image /= 255.
    return image, label

scaled_train_and_validation_data = mnist_train.map(scale)
scaled_test_data = mnist_test.map(scale)

BUFFER_SIZE = 10000
shuffled_train_and_validation_data = scaled_train_and_validation_data.shuffle(BUFFER_SIZE)

validation_data = shuffled_train_and_validation_data.take(num_validation_samples)
train_data = shuffled_train_and_validation_data.skip(num_validation_samples)

BATCH_SIZE = 100

train_data = train_data.batch(BATCH_SIZE)
validation_data = validation_data.batch(num_validation_samples)
test_data = scaled_test_data.batch(num_test_samples)

validation_inputs, validation_targets = next(iter(validation_data))

2022-03-14 11:58:40.001119: W tensorflow/core/kernels/data/cache_dataset_ops.cc:820] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


## Model

### Outline the model

In [9]:
input_size = 784
output_size = 10
hidden_layer_size = 100

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28,1)),
    tf.keras.layers.Dense(hidden_layer_size, activation = 'relu'),
    tf.keras.layers.Dense(hidden_layer_size, activation = 'relu'),
    tf.keras.layers.Dense(output_size, activation = 'softmax')
])

### choose the optimizer and loss function

In [13]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

### training

In [14]:
NUM_EPOCHS = 5

model.fit(train_data, epochs = NUM_EPOCHS, validation_data=(validation_inputs, validation_targets), validation_steps = 1, verbose=2)

Epoch 1/5


2022-03-14 12:04:00.776430: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


540/540 - 24s - loss: 0.0548 - accuracy: 0.9834 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/5
540/540 - 22s - loss: 0.0424 - accuracy: 0.9872 - val_loss: 0.0522 - val_accuracy: 0.9835


2022-03-14 12:04:23.524981: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 3/5
540/540 - 22s - loss: 0.0375 - accuracy: 0.9875 - val_loss: 0.0464 - val_accuracy: 0.9860


2022-03-14 12:04:45.672681: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 4/5
540/540 - 24s - loss: 0.0301 - accuracy: 0.9905 - val_loss: 0.0439 - val_accuracy: 0.9880


2022-03-14 12:05:10.086227: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 5/5
540/540 - 22s - loss: 0.0275 - accuracy: 0.9915 - val_loss: 0.0388 - val_accuracy: 0.9880


2022-03-14 12:05:31.707901: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


<tensorflow.python.keras.callbacks.History at 0x15f463fd0>

### test the model

In [15]:
test_loss, test_accuracy = model.evaluate(test_data)



2022-03-14 12:05:36.183308: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


In [16]:
print('Test loss: (0:.2f). Test accuracy: (1:.2f)%'.format(test_loss, test_accuracy*100.))


Test loss: (0:.2f). Test accuracy: (1:.2f)%
