In [66]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist

In [67]:
# Fetch MNIST as two (images, labels) pairs: the training split and the test split.
train_split, test_split = mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

# Sanity check: display the (images, labels) array shapes of each split.
tuple((images.shape, labels.shape) for images, labels in (train_split, test_split))

(((60000, 28, 28), (60000,)), ((10000, 28, 28), (10000,)))

In [68]:
# Slice the arrays along their first axis so the dataset yields one
# (image, label) pair per training example.
train_pairs = (X_train, y_train)
example_dataset = tf.data.Dataset.from_tensor_slices(train_pairs)

In [69]:
# Inspect only the first (image, label) element: print the image shape and its label.
for img, label in example_dataset.take(1):
    print(img.numpy().shape, label.numpy())
    

(28, 28) 5


In [70]:
def normalize_img(image, label):
    """Scale 8-bit pixel intensities to float32 values in [0, 1].

    Args:
        image: Image tensor holding raw pixel values in the range 0..255.
        label: Label associated with the image; passed through unchanged.

    Returns:
        A ``(image, label)`` tuple where ``image`` has been cast to
        ``tf.float32`` and divided by 255.
    """
    # The maximum 8-bit pixel value is 255 (not 244); dividing by it keeps
    # every normalized intensity inside [0, 1].
    return (tf.cast(image, tf.float32) / 255.0, label)

# Apply the normalization to every element; AUTOTUNE lets tf.data choose
# how many elements to process in parallel.
example_dataset = example_dataset.map(
    normalize_img,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Display the dataset so its element_spec (shapes and dtypes) can be checked.
example_dataset

<_ParallelMapDataset element_spec=(TensorSpec(shape=(28, 28), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.uint8, name=None))>

In [71]:
# Cache the normalized elements so the map step is not recomputed on later
# passes over the dataset. Done before shuffling so the shuffle order is not frozen.
example_dataset = example_dataset.cache()

In [72]:
# A shuffle buffer as large as the dataset gives a full uniform shuffle.
n_examples = len(example_dataset)
example_dataset = example_dataset.shuffle(buffer_size=n_examples)


In [73]:
# Group elements into mini-batches of 64, then peek at a single batch:
# print the batched image shape and the labels it contains.
example_dataset = example_dataset.batch(64)
for img, label in example_dataset.take(1):
    print(img.numpy().shape, label.numpy())

(64, 28, 28) [0 7 0 4 2 4 5 6 1 9 0 1 8 4 7 4 6 4 1 7 9 4 8 7 8 0 6 0 1 4 2 1 9 3 0 8 0
 8 9 6 1 0 9 0 6 3 2 0 7 4 9 4 7 7 0 7 1 6 6 7 8 7 4 3]


In [74]:
# Prepare upcoming batches in the background while the current one is consumed;
# AUTOTUNE lets tf.data pick the buffer size.
example_dataset = example_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Display the final pipeline's element_spec.
example_dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))>

In [75]:
# Training input pipeline, written as one chain:
# normalize -> cache -> full shuffle -> batches of 64 -> prefetch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    # One dataset element per training image, so len(X_train) is the dataset length.
    .shuffle(len(X_train))
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

train_dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))>

In [76]:
# Evaluation input pipeline: normalize -> batches of 64 -> cache -> prefetch.
# No shuffle step, so the batches themselves can be cached.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, y_test))
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))>

In [77]:
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

# Small CNN classifier: 28x28 grayscale image in, 10 class probabilities out.
model = Sequential([
    # `Input(shape=...)` replaces `InputLayer(input_shape=...)`: the `input_shape`
    # argument is deprecated in Keras 3, while `Input(shape=...)` works in both
    # Keras 2 and Keras 3.
    layers.Input(shape=(28, 28)),
    # Conv2D expects (height, width, channels) per sample, so add an explicit
    # single-channel axis to the grayscale image.
    layers.Reshape((28, 28, 1)),
    # Convolutional feature extractor: five identical 3x3 conv layers with 8 filters each.
    *[
        layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), activation='relu')
        for _ in range(5)
    ],
    # Average each feature map down to one value -> a vector of 8 features per image.
    layers.GlobalAveragePooling2D(),
    # Fully connected classifier head.
    *[layers.Dense(32, activation='relu') for _ in range(3)],
    # Softmax so the 10 outputs are class probabilities.
    layers.Dense(10, activation='softmax'),
])

model.summary()





In [78]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

# Labels are integer class ids and the model outputs probabilities, so use the
# sparse categorical cross-entropy with its default (non-logit) configuration.
loss_fn = SparseCategoricalCrossentropy()
optimizer = Adam(learning_rate=0.001)

model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=['accuracy'],
)


In [79]:
from tensorflow.keras.callbacks import EarlyStopping

# Stop training once the validation loss has not improved for 5 consecutive epochs.
# restore_best_weights=True rolls the model back to the best epoch; without it the
# model keeps the weights from the last epoch, i.e. `patience` epochs past the best.
es = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

In [80]:
# Train for at most 100 epochs; the early-stopping callback may end training sooner.
# NOTE(review): the test set is used as validation data here, so early stopping is
# tuned on the test set — consider holding out a separate validation split.
history = model.fit(
    x=train_dataset,
    validation_data=test_dataset,
    epochs=100,
    callbacks=[es],
)
history

Epoch 1/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 7ms/step - accuracy: 0.3106 - loss: 1.8350 - val_accuracy: 0.6413 - val_loss: 1.0371
Epoch 2/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 8ms/step - accuracy: 0.7069 - loss: 0.8849 - val_accuracy: 0.8138 - val_loss: 0.6098
Epoch 3/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 12ms/step - accuracy: 0.8050 - loss: 0.6041 - val_accuracy: 0.8356 - val_loss: 0.5012
Epoch 4/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 11ms/step - accuracy: 0.8465 - loss: 0.4856 - val_accuracy: 0.8617 - val_loss: 0.4404
Epoch 5/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 11ms/step - accuracy: 0.8772 - loss: 0.3960 - val_accuracy: 0.8831 - val_loss: 0.3824
Epoch 6/100
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 11ms/step - accuracy: 0.8941 - loss: 0.3471 - val_accuracy: 0.9133 - val_loss: 0.2846
Epoch 7/100


<keras.src.callbacks.history.History at 0x1f3574c61b0>