In [None]:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Flatten, Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy #, MeanSquaredError

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Fix the random seeds so that runs with different settings can be compared.
# `tf.random.set_seed` alone only seeds TensorFlow's global RNG; Keras can also
# draw seeds (weight initialisation, shuffling) from Python's `random` and
# NumPy. `tf.keras.utils.set_random_seed` seeds Python, NumPy and the backend
# in one call (TF >= 2.7).
# NOTE: some GPU kernels are still nondeterministic; if bit-exact runs are
# required, see `tf.config.experimental.enable_op_determinism()`.
import tensorflow as tf

SEED = 42
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Load MNIST, one-hot encode the labels for CategoricalCrossentropy, and scale
# the pixel values from [0, 255] to [0, 1].
NUM_CLASSES = 10  # must match the width of the model's final Dense layer

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Pin the class count explicitly instead of letting to_categorical infer it
# from the largest label present in each split.
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

# Cast before scaling: dividing the uint8 arrays by 255.0 directly yields
# float64, doubling memory, while Keras' default float type is float32.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Sanity checks: images must match the model's Input(shape=(28, 28)), and the
# one-hot labels must line up with the 10-way softmax output.
assert train_images.shape[1:] == test_images.shape[1:] == (28, 28)
assert train_labels.shape[1] == test_labels.shape[1] == NUM_CLASSES
assert len(train_images) == len(train_labels)
assert len(test_images) == len(test_labels)

In [None]:
# Fully connected classifier: flatten each 28x28 image into a vector, pass it
# through one 128-unit ReLU hidden layer, and finish with a 10-way softmax.
model = Sequential([
    Input(shape=(28, 28)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])

In [None]:
# Adam with a 1e-3 step size; categorical cross-entropy pairs with the one-hot
# labels and the softmax output layer. Accuracy is reported alongside the loss.
optimizer = Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss=CategoricalCrossentropy(), metrics=['accuracy'])

In [None]:
# Training hyperparameters, consumed by the fit and plotting cells.
BATCH_SIZE, EPOCHS = 64, 20

In [None]:
# Train the model. `validation_split=0.2` holds out the last 20% of the
# training arrays (selected before any shuffling) as a validation set; the
# per-epoch train/validation metrics are kept in `hist.history` for plotting.
hist = model.fit(
    x=train_images,
    y=train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=0.2,
)

In [None]:
# Plot per-epoch loss (top) and accuracy (bottom) for the training and
# validation splits recorded by `model.fit`.
history = hist.history

# Derive the epoch axis from what was actually recorded rather than from
# EPOCHS, so the x-values always match the history length (e.g. if EPOCHS was
# edited after the fit cell ran, or a callback stopped training early).
n_epochs = len(history['loss'])
x_range = range(1, n_epochs + 1)

fig, ax = plt.subplots(2, 1, figsize=(8, 6))

ax[0].plot(x_range, history['loss'],     color='r', label="train loss")
ax[0].plot(x_range, history['val_loss'], color='b', label="validation loss")
# Keep the original 0-0.4 window, but widen it if any value would be clipped.
ax[0].set_ylim([0, max(0.4, *history['loss'], *history['val_loss'])])
ax[0].set_ylabel("loss")

ax[1].plot(x_range, history['accuracy'],     color='r', label="train acc")
ax[1].plot(x_range, history['val_accuracy'], color='b', label="validation acc")
# Keep the original 0.8-1 window, but lower the floor if any value is below it.
ax[1].set_ylim([min(0.8, *history['accuracy'], *history['val_accuracy']), 1])
ax[1].set_ylabel("accuracy")
ax[1].set_xlabel("epoch")

# Settings shared by both panels, applied identically to each.
for axis in ax:
    axis.legend(loc='best', shadow=True)
    axis.set_xticks(x_range)
    if n_epochs > 1:  # set_xlim with left == right triggers a warning
        axis.set_xlim([1, n_epochs])

fig.suptitle("Training history")
fig.tight_layout()
plt.show()

In [None]:
# Final evaluation on the held-out test set. With one compiled metric,
# `evaluate` returns [loss, accuracy]; also expose the two values by name.
test_loss, test_accuracy = results = model.evaluate(test_images, test_labels, verbose=2)
results