In [1]:
import tensorflow as tf
import numpy as np
from IPython.display import Image

In [2]:
# Load the MNIST train/test splits through the Keras dataset helper
# (the first call downloads mnist.npz).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale both image arrays by 1/255 (assumes raw pixel values span 0-255,
# which would map them into [0, 1] -- TODO confirm against the dataset docs).
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Print the shape of the training images, then the test images, one per line.
for images in (x_train, x_test):
    print(images.shape)

(60000, 28, 28)
(10000, 28, 28)


In [4]:
# Build the classifier one layer at a time:
#   Flatten      - turns each 28x28 image into a single vector
#   Dense(128)   - hidden layer with ReLU activation
#   Dropout(0.2) - drops 20% of the hidden activations during training
#   Dense(10)    - one output per class, no activation, so the model emits logits
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10))

In [5]:
# Run the (not yet trained) model on the first training image and convert the
# resulting tensor of logits to a NumPy array for display.
first_logits = model(x_train[:1])
predictions = first_logits.numpy()
predictions

array([[-0.22688316, -0.6046939 , -0.07600784, -0.43407953, -0.26864272,
         0.5557027 ,  0.58586365, -1.5440036 ,  0.04276763,  0.07108124]],
      dtype=float32)

In [6]:
# tf.nn.softmax converts the logits above into probabilities; the
# normalization runs over the last axis (the class axis), which is the default.
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.08343129, 0.05718049, 0.09701821, 0.06781796, 0.08001898,
        0.18247429, 0.18806173, 0.02235171, 0.10925387, 0.11239145]],
      dtype=float32)

In [7]:
# losses.SparseCategoricalCrossentropy takes a vector of logits and the index
# of the true class, and returns a scalar loss for each example.
# from_logits=True because the model's final Dense(10) layer has no softmax.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [8]:
# Loss of the untrained model on the first training example.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

1.7011459

In [9]:
# Configure training: Adam optimizer, the logits-based cross-entropy loss
# defined above, and accuracy reported as the tracked metric.
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy'],
)

In [10]:
# Train on the full training split for 10 epochs; the call returns a History object.
model.fit(x=x_train, y=y_train, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f4ea86bfe10>

In [11]:
# Score the trained model on the held-out test split. The result is
# [loss, accuracy]; verbose=2 prints a single summary line.
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 0s - loss: 0.0693 - accuracy: 0.9818


[0.06928431987762451, 0.9818000197410583]

In [12]:
# Wrap the trained model and append a Softmax layer so that calling the
# wrapper yields class probabilities instead of raw logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [13]:
# Class probabilities for the first five test images (one row per image,
# one column per class).
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[2.8706013e-09, 1.6977364e-11, 3.7204948e-08, 1.7462480e-05,
        1.0249460e-15, 2.9425737e-09, 4.4779435e-18, 9.9997783e-01,
        3.0435778e-08, 4.7933054e-06],
       [9.2895559e-13, 3.5554217e-06, 9.9999642e-01, 1.3577404e-09,
        3.2161559e-22, 2.9811682e-09, 3.9467738e-11, 1.4847796e-18,
        4.8411401e-12, 4.9376391e-20],
       [2.7063558e-09, 9.9991810e-01, 9.2088439e-06, 8.5328111e-08,
        1.8727447e-06, 4.5256417e-07, 8.8064053e-07, 6.3379310e-05,
        6.0745770e-06, 5.5032945e-08],
       [9.9999440e-01, 2.1400104e-12, 3.7544389e-06, 1.0358105e-09,
        8.0599264e-08, 4.9496339e-08, 1.1130312e-06, 3.2164129e-07,
        2.9789355e-08, 1.6666871e-07],
       [4.6204963e-08, 5.5001373e-15, 3.5274461e-08, 1.7044995e-12,
        9.9859339e-01, 3.6310882e-10, 4.2404201e-07, 2.3409990e-05,
        1.8585782e-08, 1.3826523e-03]], dtype=float32)>