In [1]:
import tensorflow as tf

In [2]:
# Show the TensorFlow version this notebook is running against.
tf.__version__

'2.9.1'

In [3]:
# Load dataset: bind the Keras MNIST dataset module to a short name.
# The data itself is only fetched when mnist.load_data() is called in the next cell.
mnist = tf.keras.datasets.mnist

In [4]:
# load_data() returns the train and test partitions as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the pixel intensities by 1/255 (presumably 0..255 integers, giving
# values in [0, 1] — confirm against the dataset docs).
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [6]:
# Build the classifier.
#
# Seed Python's `random`, NumPy and TensorFlow before any layer is created so
# that the initial weights (and the other random draws made later in the
# notebook) are repeatable under Restart Kernel -> Run All.
# NOTE: some GPU kernels may still be non-deterministic unless op determinism
# is also enabled.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Sequential stack:
#   Flatten  - turns each 28x28 image into a flat vector
#   Dense    - 128 hidden units with ReLU
#   Dropout  - rate 0.2
#   Dense    - 10 outputs with no activation, i.e. raw logits
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

In [8]:
# Run the (still untrained) model on the first training image and convert the
# result to a NumPy array. The slice [:1] keeps the leading batch axis, so the
# result holds one row of 10 raw scores.
predictions = model(x_train[:1]).numpy()

In [9]:
# Display the raw scores (logits) for that single example.
predictions

array([[-0.4076549 , -0.12041321,  0.2624552 , -0.00846532,  0.15242735,
        -0.6345658 , -0.1381886 ,  0.17314243, -0.32316408,  0.47094408]],
      dtype=float32)

In [10]:
# Convert the logits into per-class probabilities with tf.nn.softmax.
tf.nn.softmax(logits=predictions).numpy()

array([[0.06703293, 0.0893379 , 0.1310127 , 0.09992037, 0.11736237,
        0.0534247 , 0.08776391, 0.1198189 , 0.07294275, 0.16138342]],
      dtype=float32)

In [11]:
# Define the training loss with losses.SparseCategoricalCrossentropy.
# from_logits=True because the model's last Dense layer has no activation and
# therefore outputs raw logits rather than probabilities.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [12]:
# Loss of the untrained model on the first training example.
# For reference, a uniform guess over 10 classes would give -ln(1/10) ~= 2.30.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

2.929482

In [13]:
# Configure training: Adam optimizer, the logits-based loss defined above,
# and accuracy as the reported metric.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [14]:
# Train on the full training set for 5 epochs.
model.fit(x=x_train, y=y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x2b202aa39d0>

In [15]:
# Score the trained model on the held-out test set; returns [loss, accuracy].
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 0s - loss: 0.0750 - accuracy: 0.9769 - 422ms/epoch - 1ms/step


[0.07498075813055038, 0.9768999814987183]

In [16]:
# Wrap the trained model with a Softmax layer so that calling the wrapper
# yields probabilities instead of logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [17]:
# Class probabilities for the first five test images (one row per image,
# one column per class).
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[3.4584279e-08, 6.2244099e-10, 1.1848533e-07, 3.4135999e-05,
        8.0517043e-12, 2.4560333e-07, 1.7374273e-13, 9.9994481e-01,
        2.3953808e-07, 2.0533436e-05],
       [8.7207519e-07, 1.1384645e-03, 9.9840719e-01, 4.4779843e-04,
        3.3045192e-13, 5.4102265e-06, 2.2713960e-08, 3.7291750e-12,
        3.0937701e-07, 5.9113409e-12],
       [2.3121302e-06, 9.9884135e-01, 3.1950622e-05, 2.7227956e-05,
        1.4283159e-04, 3.0437730e-05, 1.7379264e-04, 2.5441538e-04,
        4.9242278e-04, 3.2126875e-06],
       [9.9997318e-01, 1.9217089e-12, 1.5084209e-05, 6.9395427e-09,
        5.1740312e-07, 2.1154353e-06, 7.2421024e-08, 2.5362551e-06,
        7.9591350e-10, 6.5011277e-06],
       [1.6150854e-07, 8.3546041e-09, 1.3719856e-05, 2.2858138e-08,
        9.9902320e-01, 2.5337116e-07, 1.6953477e-07, 3.2507617e-06,
        1.6269403e-06, 9.5750624e-04]], dtype=float32)>