In [None]:

import tensorflow as tf
import keras
from keras import layers
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

In [None]:

# Define the model: flatten each 28x28 image, apply one 128-unit ReLU hidden
# layer, then a 10-unit Dense layer whose outputs the final Softmax layer turns
# into per-class probabilities.
# Note: because of the Softmax head, model outputs are probabilities, not logits.
model = keras.Sequential([
    # Declaring the input with keras.Input replaces the `input_shape=` layer
    # argument, which is deprecated for Sequential models in Keras 3.
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10),
    keras.layers.Softmax(),
])


In [None]:
# Compile the Keras model.
# The model ends in a Softmax layer, so its outputs are already probabilities,
# not logits. from_logits must therefore be False. With True, the loss would
# apply a second softmax to values in [0, 1], squashing them toward a uniform
# distribution and weakening the training signal.
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])



In [None]:
# Extract the MNIST images and labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

img_rows, img_cols, channels = 28, 28, 1
num_classes = 10

# Scale pixel intensities from [0, 255] to [0, 1]. Casting to float32 first
# matches Keras' default float dtype. Plain `uint8 / 255` would produce
# float64 arrays, which take twice the memory for the same values.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Sanity checks: images have the expected size, pixels are scaled into [0, 1],
# and every label is a valid class index.
assert x_train.shape[1:] == x_test.shape[1:] == (img_rows, img_cols)
assert x_train.min() >= 0 and x_train.max() <= 1
assert y_train.min() >= 0 and y_train.max() < num_classes

print("Data shapes", x_test.shape, y_test.shape, x_train.shape, y_train.shape)


In [None]:
# Train for 5 epochs in mini-batches of 32. The test split is passed only so
# Keras reports validation loss/accuracy after each epoch; no callback consumes it.
# NOTE(review): reusing the test split for validation means it is no longer an
# untouched hold-out if these numbers guide tuning. Consider validation_split.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=5,
    batch_size=32,
)

# Score the trained model on each split (verbose=2 prints a one-line summary).
train_loss, train_acc = model.evaluate(x_train, y_train, verbose=2)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

print('\nTrain accuracy:', train_acc)
print('\nTest accuracy:', test_acc)

In [None]:

# Run inference on the test split first, then the training split; each result
# holds one row of model outputs per image.
test_predictions, train_predictions = (
    model.predict(split_images) for split_images in (x_test, x_train)
)

In [None]:
# Human-readable name for each digit label: class_names[d] names digit d.
class_names = 'Zero One Two Three Four Five Six Seven Eight Nine'.split()


def display_images(images, predicted_labels, true_labels, n=10):
    '''
    Show the first `n` images in one row. Each image is captioned with its
    predicted class, the model's score for that class as a percentage, and
    the true class in parentheses. The caption is blue when the prediction
    is correct and red otherwise.
    :images: the input images; each is reshaped to 28x28 for display
    :predicted_labels: per-image class scores from the model (assumed to be
                       probabilities, e.g. Softmax outputs of model.predict);
                       argmax gives the predicted class
    :true_labels: the correct integer labels, used to index class_names
    :n: maximum number of images to show (default 10); clamped to len(images)
    '''
    n = min(n, len(images))
    if n == 0:
        return  # nothing to draw; avoid creating an empty figure

    # A single row of n axes (the old 2-row grid left the second row empty).
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2.5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # Set the colormap per image rather than via plt.gray(), which would
        # change the global default colormap for every later plot.
        ax.imshow(images[i].reshape(28, 28), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])

        predicted_label = np.argmax(predicted_labels[i])
        color = 'blue' if predicted_label == true_labels[i] else 'red'

        ax.set_xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                               100 * np.max(predicted_labels[i]),
                                               class_names[true_labels[i]]),
                      color=color)

    fig.tight_layout()
    plt.show()

# Visual spot-check: first a row of test images, then a row of training images.
display_images(images=x_test, predicted_labels=test_predictions, true_labels=y_test)
display_images(images=x_train, predicted_labels=train_predictions, true_labels=y_train)