In [5]:
# Set up TensorFlow
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation and Dropout masks are reproducible on Restart & Run All.
# (`tf.keras.utils.set_random_seed` exists from TF 2.7 onward.)
# NOTE(review): some GPU kernels can still be nondeterministic; if bit-exact
# reruns are required, also consider tf.config.experimental.enable_op_determinism().
SEED = 42
tf.keras.utils.set_random_seed(SEED)


TensorFlow version: 2.9.0


In [6]:
# Load a dataset: MNIST handwritten digits (images plus integer class labels).
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# The model below declares input_shape=(28, 28); fail fast if the data disagrees.
assert x_train.shape[1:] == x_test.shape[1:] == (28, 28)

# Scale pixel intensities from [0, 255] to [0, 1]. Casting to float32 first
# avoids the float64 arrays that `x / 255.0` would produce (twice the memory),
# and matches the float32 dtype the model computes in, so Keras does not need
# to convert the inputs again on every batch.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [7]:
# Build a machine learning model, one layer at a time:
#   Flatten  -> turn each 28x28 image into a flat vector
#   Dense    -> 128 hidden units with ReLU
#   Dropout  -> randomly zero 20% of activations during training
#   Dense    -> 10 outputs, one per digit class
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation="relu"))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10))  # no activation: these outputs are logits


In [8]:
# Run one training image through the not-yet-trained model. The last layer has
# no activation, so each of the 10 values is a raw logit rather than a probability.
predictions = model(x_train[0:1]).numpy()
predictions

array([[ 0.5297747 , -0.1898254 , -0.25597614, -0.33622277, -0.59584725,
        -0.8587326 , -0.03447479, -0.14085019, -0.20321476, -0.10999278]],
      dtype=float32)

In [9]:
# Convert the logits into per-class probabilities (softmax over the class axis).
tf.nn.softmax(logits=predictions, axis=-1).numpy()


array([[0.19899249, 0.09689879, 0.09069627, 0.08370256, 0.06456323,
        0.04963815, 0.11318431, 0.10176255, 0.09561002, 0.10495164]],
      dtype=float32)

In [10]:
# Cross-entropy over integer class labels. The model emits logits (no softmax
# layer), so the loss is told to apply the softmax itself via from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)


In [11]:
# Loss of the untrained model on the single example predicted above.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()


3.0029955

In [12]:
# Attach the optimizer, the logits-based loss defined above, and an accuracy
# metric that is reported during fit/evaluate.
model.compile(
    optimizer="adam",
    loss=loss_fn,
    metrics=["accuracy"],
)


In [13]:
# Train the model (evaluation on the held-out test set is done separately
# with `model.evaluate`).
EPOCHS = 5

# Keep the returned History so per-epoch metrics can be inspected or plotted
# later; binding it also stops the cell from echoing an opaque `<History at 0x…>`.
history = model.fit(x_train, y_train, epochs=EPOCHS)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1aadfbb4a60>

In [14]:
# Score the trained model on the held-out test split; returns [loss, accuracy].
model.evaluate(x=x_test, y=y_test, verbose=2)


313/313 - 0s - loss: 0.0752 - accuracy: 0.9767 - 338ms/epoch - 1ms/step


[0.0751894935965538, 0.9767000079154968]

In [15]:
# Wrap the trained model with a Softmax layer so inference yields class
# probabilities instead of logits. The same `model` object is reused, so its
# trained weights are shared rather than copied.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())


In [16]:
# Class probabilities for the first five test images (one row per image).
probability_model(x_test[0:5])


<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[6.72846880e-08, 7.95644262e-09, 7.36777247e-06, 2.79866043e-04,
        5.66056126e-11, 1.66004753e-07, 9.08023548e-13, 9.99706805e-01,
        8.64239610e-07, 4.96764869e-06],
       [1.54728639e-07, 2.98459818e-05, 9.99965549e-01, 3.04247874e-06,
        2.23931896e-14, 3.24682858e-08, 1.78613746e-07, 1.05788132e-12,
        1.31122465e-06, 1.31067532e-14],
       [1.53982214e-07, 9.98914599e-01, 6.60991282e-05, 5.61030029e-05,
        1.78481332e-05, 1.23643258e-05, 1.25791985e-05, 4.37386101e-04,
        4.81683237e-04, 1.21217840e-06],
       [9.99892592e-01, 3.38999229e-09, 4.30321006e-06, 2.97530311e-09,
        5.07726497e-07, 9.11879567e-08, 9.72121707e-05, 2.58021396e-06,
        1.36998963e-08, 2.76767719e-06],
       [5.67164534e-07, 1.12040177e-09, 8.30363490e-07, 1.28423885e-08,
        9.99687791e-01, 1.01426359e-07, 9.32312287e-06, 6.29652059e-05,
        8.65457082e-07, 2.37615968e-04]], dtype=float32)>