In [1]:
from tensorflow import keras
from tensorflow.keras import layers 
from tensorflow.keras.datasets import mnist

import numpy as np
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 60,000 training and 10,000 test samples, each a 28x28 image.
# Each one corresponds to a label 0-9.

# Fixed seed so the shuffle -- and therefore the train/validation split -- is
# identical on every Restart & Run All. The legacy global np.random.permutation
# was unseeded, so each run validated on a different 6,000 images.
SEED = 42
# Number of (shuffled) training samples held out for validation.
VALIDATION_SIZE = 6000

# Randomize the order of the data so the validation slice is a representative
# sample rather than whatever happens to come first in the file.
rng = np.random.default_rng(SEED)
random_indices = rng.permutation(len(train_images))
train_images = train_images[random_indices]
train_labels = train_labels[random_indices]

# Flatten each 28x28 image into a 784-element row and rescale pixel values
# from 0-255 to 0-1. Using len() instead of hard-coded 60000/10000 keeps this
# correct if a subset of the data is loaded.
train_images = train_images.reshape((len(train_images), 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((len(test_images), 28 * 28))
test_images = test_images.astype("float32") / 255

# Hold out the first VALIDATION_SIZE shuffled samples for validation.
val_images = train_images[:VALIDATION_SIZE]
val_labels = train_labels[:VALIDATION_SIZE]

# The remaining samples form the training set.
train_images = train_images[VALIDATION_SIZE:]
train_labels = train_labels[VALIDATION_SIZE:]

# model
def build_model():
    """Build and compile a small fully-connected digit classifier.

    Architecture: a 512-unit ReLU hidden layer followed by a 10-unit softmax
    output (one unit per digit class). The model is compiled with the RMSprop
    optimizer and sparse categorical cross-entropy, which expects integer
    class labels rather than one-hot vectors, and tracks accuracy.

    Returns:
        A freshly compiled ``keras.Sequential`` model (untrained).
    """
    stack = [
        layers.Dense(512, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ]
    classifier = keras.Sequential(stack)
    classifier.compile(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

model = build_model()

# Train for 5 epochs in mini-batches of 128; after every epoch Keras also
# reports loss/accuracy on the held-out validation split.
history = model.fit(
    train_images,
    train_labels,
    validation_data=(val_images, val_labels),
    batch_size=128,
    epochs=5,
)

# Per-epoch validation accuracy recorded by fit().
accuracy_history = history.history["val_accuracy"]

# Score on the test set, which was not used for training or validation.
loss_of_model, accuracy_of_model = model.evaluate(test_images, test_labels)
print(f"accuracy_of_model: {accuracy_of_model}")

# Softmax outputs for the first ten test digits; display the first row
# (one probability per class 0-9).
test_digits = test_images[:10]
predictions = model.predict(test_digits)
predictions[0]

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
accuracy_of_model: 0.9790999889373779


array([9.1130964e-10, 1.6728835e-10, 2.1876302e-07, 6.3326333e-06,
       2.6465668e-12, 9.9856354e-09, 7.5618934e-15, 9.9999321e-01,
       1.0974314e-08, 2.0927955e-07], dtype=float32)