In [1]:
import tensorflow as tf

Load and prepare the MNIST dataset. Convert the samples from integers to floating-point numbers:

In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
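To make the scaling step concrete, here is a minimal sketch in plain Python. It assumes only that MNIST pixels are stored as 8-bit integers in the range 0 to 255; the tiny `pixels` grid is a hypothetical stand-in for one image, not data from the dataset.

```python
# Hypothetical 2x3 "image" of 8-bit pixel values, standing in for one MNIST sample.
pixels = [[0, 51, 102], [153, 204, 255]]

# Same scaling as x_train / 255.0: map integers in [0, 255] to floats in [0.0, 1.0].
scaled = [[value / 255.0 for value in row] for row in pixels]

# Flatten the grid to inspect the value range after scaling.
flat = [value for row in scaled for value in row]
print(min(flat), max(flat))
```

Keeping inputs in a small, fixed range like this generally makes training better behaved than feeding raw 0 to 255 integers.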


In [4]:
x_train

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]])

Build the tf.keras.Sequential model by stacking layers. Choose an optimizer and loss function for training:

In [3]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])
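As a rough check on what this stack of layers contains, the sketch below counts the trainable parameters by hand. It assumes the standard Dense layout (one weight per input-unit pair plus one bias per unit); the variable names are illustrative only.

```python
# Layer sizes taken from the Sequential model above.
inputs = 28 * 28   # Flatten turns each 28x28 image into a 784-element vector
hidden = 128       # units in the first Dense layer
classes = 10       # units in the final Dense layer (one logit per digit)

# A Dense layer has one weight per (input, unit) pair plus one bias per unit.
dense_1_params = inputs * hidden + hidden
dense_2_params = hidden * classes + classes

# Flatten and Dropout have no trainable parameters.
total_params = dense_1_params + dense_2_params
print(dense_1_params, dense_2_params, total_params)
```

Calling model.summary() in the notebook should report the same per-layer counts.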

For each example, the model returns a vector of "logits" or "log-odds" scores, one for each class.

In [6]:
predictions = model(x_train[:1]).numpy()
predictions

array([[ 0.93133265, -0.04678848,  0.04231543, -0.25063103, -0.57805395,
        -0.53832513, -0.2790822 ,  0.17230403, -0.1525774 , -0.11047502]],
      dtype=float32)

The tf.nn.softmax function converts these logits to "probabilities" for each class:

In [8]:
tf.nn.softmax(predictions).numpy()

array([[0.24986987, 0.09395529, 0.10271139, 0.07662908, 0.05523262,
        0.05747112, 0.07447961, 0.11696932, 0.08452356, 0.08815818]],
      dtype=float32)
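The conversion can be reproduced by hand. The sketch below is a plain-Python softmax applied to the logits printed above; it is meant to show the arithmetic, not how tf.nn.softmax is implemented internally.

```python
import math

# Logits copied from the printed predictions above.
logits = [0.93133265, -0.04678848, 0.04231543, -0.25063103, -0.57805395,
          -0.53832513, -0.2790822, 0.17230403, -0.1525774, -0.11047502]


def softmax(values):
    """Exponentiate each logit and normalize so the results sum to 1."""
    shift = max(values)  # subtracting the max avoids overflow; the result is unchanged
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = softmax(logits)
print([round(p, 4) for p in probabilities])
```

The largest logit (index 0) gets the largest probability, and the ten values sum to 1.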

The tf.keras.losses.SparseCategoricalCrossentropy loss takes a vector of logits and the index of the true class, and returns a scalar loss for each example.

In [9]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

This loss is equal to the negative log probability of the true class: it approaches zero as the model's probability for the correct class approaches 1.

This untrained model gives probabilities close to random (1/10 for each class), so the initial loss should be close to -tf.math.log(1/10) ~= 2.3.

In [11]:
loss_fn(y_train[:1], predictions).numpy()

2.856473
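The printed loss can be checked against the logits shown earlier. This sketch assumes the label of the first training image is 5, which is the class consistent with the loss value above; it computes the negative log probability directly in plain Python rather than through the Keras loss object.

```python
import math

# Logits copied from the printed predictions above.
logits = [0.93133265, -0.04678848, 0.04231543, -0.25063103, -0.57805395,
          -0.53832513, -0.2790822, 0.17230403, -0.1525774, -0.11047502]
true_class = 5  # assumption: label of the first training example

# Cross-entropy from logits: log(sum(exp(logits))) minus the true-class logit,
# which equals -log(softmax(logits)[true_class]).
log_sum_exp = math.log(sum(math.exp(v) for v in logits))
loss = log_sum_exp - logits[true_class]

# Loss of a model that assigns probability 1/10 to every class.
baseline = -math.log(1 / 10)
print(round(loss, 4), round(baseline, 4))
```

The loss is above the 2.3 baseline here because the untrained model happens to give the true class a probability of about 0.057, which is below 1/10.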

In [12]:
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

The Model.fit method adjusts the model parameters to minimize the loss:

In [13]:
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fb63a953410>
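To illustrate what "adjusting parameters to minimize the loss" means, here is a toy sketch. It uses plain gradient descent rather than the Adam optimizer chosen above, and it treats three hypothetical logits as the trainable parameters instead of real layer weights; the learning rate is an illustrative value.

```python
import math


def softmax(values):
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, true_class):
    """Negative log probability of the true class."""
    return -math.log(softmax(logits)[true_class])


# Hypothetical toy setup: the logits themselves play the role of parameters.
logits = [0.5, -0.2, 0.1]
true_class = 2
learning_rate = 0.5  # illustrative value

losses = []
for step in range(5):
    losses.append(cross_entropy(logits, true_class))
    probs = softmax(logits)
    # Gradient of the loss with respect to logit i: p_i - 1 for the true class, p_i otherwise.
    grads = [p - (1.0 if i == true_class else 0.0) for i, p in enumerate(probs)]
    # Step each parameter against its gradient.
    logits = [v - learning_rate * g for v, g in zip(logits, grads)]

print([round(value, 3) for value in losses])
```

Each step lowers the loss. Model.fit applies the same idea to all of the layer weights, over mini-batches of training examples.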

The Model.evaluate method checks the model's performance, usually on a validation set or test set.

In [14]:
model.evaluate(x_test, y_test, verbose=2)

313/313 - 0s - loss: 0.0798 - accuracy: 0.9740


[0.07982804626226425, 0.9739999771118164]
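The accuracy figure is the fraction of examples whose highest-scoring class matches the label. A minimal sketch of that calculation, using hypothetical logits and labels rather than real model output:

```python
# Hypothetical logits for four examples over three classes, with their true labels.
logits_batch = [
    [2.0, 0.1, -1.0],
    [0.3, 1.5, 0.2],
    [0.0, 0.4, 0.9],
    [1.2, 1.1, 0.3],
]
labels = [0, 1, 2, 1]

# The predicted class is the index of the largest logit in each row.
predicted = [row.index(max(row)) for row in logits_batch]

# Accuracy is the share of predictions that match the labels.
correct = sum(1 for p, y in zip(predicted, labels) if p == y)
accuracy = correct / len(labels)
print(predicted, accuracy)
```

On the 10,000-image MNIST test set, an accuracy of 0.9740 corresponds to 9,740 correctly classified images.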

The image classifier is now trained to ~97% accuracy on this dataset. To learn more, read the TensorFlow tutorials.

If you want your model to return probabilities, you can wrap the trained model and attach a softmax layer to it:

In [15]:
probability_model = tf.keras.Sequential([
    model,
    tf.keras.layers.Softmax()
])

In [17]:
probability_model(x_test[:5]).numpy()

array([[2.96996827e-09, 1.48973385e-07, 3.23550830e-06, 1.67640173e-04,
        2.37378541e-12, 1.89445757e-08, 1.24220495e-15, 9.99827862e-01,
        8.01792410e-08, 1.00510977e-06],
       [2.45825005e-09, 2.28013250e-05, 9.99975324e-01, 1.40465170e-06,
        6.71113945e-15, 8.63548077e-09, 4.51670745e-09, 3.68924565e-15,
        4.41780429e-07, 2.08129243e-14],
       [2.23282541e-07, 9.98872459e-01, 1.02483071e-04, 6.31672265e-06,
        7.22444747e-05, 3.29978729e-06, 3.16129663e-05, 8.37985834e-04,
        7.12481851e-05, 1.92638322e-06],
       [9.99971747e-01, 3.83746313e-10, 2.43978593e-05, 2.48656362e-09,
        1.52952833e-08, 2.46121800e-07, 2.74772651e-06, 1.72722903e-08,
        3.25523697e-09, 8.16092324e-07],
       [2.64381515e-06, 1.04723235e-08, 3.92004813e-06, 5.67005642e-09,
        9.96870100e-01, 1.10116851e-07, 1.45465810e-06, 4.79112896e-05,
        5.89502690e-07, 3.07336333e-03]], dtype=float32)
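To turn these probabilities into digit predictions, take the index of the largest value in each row. The sketch below does this in plain Python on the values printed above; in the notebook, an argmax over the last axis of the model output would presumably give the same result.

```python
# Probabilities copied from the printed output above (one row per test image).
probabilities = [
    [2.96996827e-09, 1.48973385e-07, 3.23550830e-06, 1.67640173e-04, 2.37378541e-12,
     1.89445757e-08, 1.24220495e-15, 9.99827862e-01, 8.01792410e-08, 1.00510977e-06],
    [2.45825005e-09, 2.28013250e-05, 9.99975324e-01, 1.40465170e-06, 6.71113945e-15,
     8.63548077e-09, 4.51670745e-09, 3.68924565e-15, 4.41780429e-07, 2.08129243e-14],
    [2.23282541e-07, 9.98872459e-01, 1.02483071e-04, 6.31672265e-06, 7.22444747e-05,
     3.29978729e-06, 3.16129663e-05, 8.37985834e-04, 7.12481851e-05, 1.92638322e-06],
    [9.99971747e-01, 3.83746313e-10, 2.43978593e-05, 2.48656362e-09, 1.52952833e-08,
     2.46121800e-07, 2.74772651e-06, 1.72722903e-08, 3.25523697e-09, 8.16092324e-07],
    [2.64381515e-06, 1.04723235e-08, 3.92004813e-06, 5.67005642e-09, 9.96870100e-01,
     1.10116851e-07, 1.45465810e-06, 4.79112896e-05, 5.89502690e-07, 3.07336333e-03],
]

# Predicted digit = position of the largest probability in each row.
predicted_digits = [row.index(max(row)) for row in probabilities]

# The largest probability itself is the model's confidence in that prediction.
confidences = [max(row) for row in probabilities]
print(predicted_digits)
print([round(c, 4) for c in confidences])
```

For these five test images the model predicts 7, 2, 1, 0 and 4, each with a probability above 0.99.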