# This short introduction uses Keras to:
1. Build a neural network that classifies images.
2. Train this neural network.
3. And, finally, evaluate the accuracy of the model.

In [1]:
import tensorflow as tf

In [2]:
# Load and prepare the MNIST dataset.
# The samples are converted from integers to floating-point numbers by
# dividing every value by 255.0.

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale the training and test samples separately; the labels are left as loaded.
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
# Build the tf.keras.Sequential model by stacking layers.
#
# The input shape is declared with an explicit tf.keras.Input entry instead of
# an `input_shape=` argument on the first layer: newer Keras releases warn
# against passing `input_shape` to a layer inside a Sequential model, while the
# explicit Input form is accepted by older releases as well.

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),                 # each sample is a 28x28 array
    tf.keras.layers.Flatten(),                      # 28x28 -> flat vector of 784 values
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),                      # no activation: the outputs are logits
])

In [8]:
# losses.SparseCategoricalCrossentropy takes a vector of logits and the true
# class index and returns a scalar loss for each example.
# from_logits=True is used because the model's last Dense layer has no
# activation, so its outputs are logits rather than probabilities.

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    optimizer="adam",
    loss=loss_fn,
    metrics=["accuracy"],
)

In [9]:
# Model.fit adjusts the model parameters to minimize the loss.
# Training runs for 5 epochs over the training samples and labels.
model.fit(x=x_train, y=y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x25fcaae3670>

In [11]:
# Model.evaluate checks the model's performance, usually on a
# "validation set" or "test set"; here it is run on the test samples.
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 0s - loss: 0.0737 - accuracy: 0.9768


[0.07366747409105301, 0.9768000245094299]

In [13]:
# To make the model return probabilities, wrap the trained model in a new
# Sequential model and append a Softmax layer after it.
probability_model = tf.keras.Sequential()
probability_model.add(model)                      # the trained model, producing logits
probability_model.add(tf.keras.layers.Softmax())  # turns the logits into probabilities

In [14]:
# Run the probability model on the first five test samples; the result has
# one row per sample and one column per class.
first_five_samples = x_test[:5]
probability_model(first_five_samples)

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[6.6812447e-07, 3.5303760e-08, 6.0976226e-06, 7.0643798e-04,
        5.3280325e-11, 5.3329740e-07, 2.4906601e-13, 9.9927992e-01,
        6.3111054e-07, 5.7080310e-06],
       [6.5376482e-09, 1.7262628e-06, 9.9994993e-01, 3.9675921e-05,
        1.0849511e-15, 8.2270799e-07, 3.1612327e-08, 1.2087458e-13,
        7.7005116e-06, 4.6209070e-12],
       [6.8249918e-08, 9.9933845e-01, 1.2885789e-04, 7.2261805e-06,
        2.5609483e-05, 7.9640762e-07, 6.2639742e-06, 1.9269402e-04,
        2.9888813e-04, 8.2650052e-07],
       [9.9983668e-01, 6.3110332e-07, 4.9696424e-05, 1.9897004e-06,
        2.0389670e-07, 7.1348968e-06, 5.3113730e-05, 3.8762344e-05,
        1.3751875e-07, 1.1548567e-05],
       [1.1479000e-05, 2.3331545e-08, 4.1098565e-06, 2.6841170e-07,
        9.9548233e-01, 7.7048717e-06, 7.4981506e-07, 2.3682113e-04,
        1.0260678e-06, 4.2554042e-03]], dtype=float32)>