In [1]:
import tensorflow as tf
from tensorflow import keras
# Record the TensorFlow release this notebook ran against, so the outputs
# below can be tied to a specific library version.
print("TensorFlow version:", tf.version.VERSION)

TensorFlow version: 2.5.0


In [2]:
# Query the GPU device name once. Per the TensorFlow documentation,
# tf.test.gpu_device_name() returns an empty string when no GPU is available,
# so a truthiness test is the availability check; this avoids a second device
# query and a brittle comparison against one hard-coded device string.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
  print('SUCCESS: Found GPU: {}'.format(gpu_name))
else:
  print('WARNING: GPU device not found.')

SUCCESS: Found GPU: /device:GPU:0


In [3]:
# Handle to the MNIST dataset module bundled with Keras.
mnist = keras.datasets.mnist

# load_data() yields a (train, test) pair, each an (inputs, labels) tuple.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale both input arrays by 1/255 (presumably mapping 8-bit pixel
# intensities into [0, 1] -- confirm against the dataset documentation).
x_train = x_train / 255.0
x_test = x_test / 255.0

In [4]:
# Small fully-connected classifier. The last Dense layer has no activation, so
# the model emits raw logits; softmax is applied separately where
# probabilities are needed, and the loss is built with from_logits=True.
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),    # 28x28 input -> flat vector
    keras.layers.Dense(128, activation='relu'),    # hidden layer, 128 units
    keras.layers.Dropout(0.2),                     # dropout with rate 0.2
    keras.layers.Dense(10),                        # 10 output logits
])

In [5]:
# Run the still-untrained model on a batch holding only the first training
# sample and convert the resulting tensor to a NumPy array for display.
predictions = model(x_train[0:1]).numpy()
predictions

array([[ 1.4407283 ,  0.28669316, -0.60342443, -0.5758071 ,  0.1326196 ,
         0.43802   ,  0.14441395, -0.15539102, -0.91915077, -0.5191233 ]],
      dtype=float32)

In [6]:
# Turn the logits into a probability distribution over the last axis
# (axis=-1 is the library default, spelled out here for clarity).
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.3416802 , 0.10775284, 0.04424412, 0.04548306, 0.0923667 ,
        0.12535715, 0.09346256, 0.06925227, 0.03226542, 0.04813567]],
      dtype=float32)

In [7]:
# Cross-entropy loss over integer class labels. from_logits=True because the
# model's final layer has no softmax activation.
loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [8]:
# Loss of the untrained model's logits against the first training label.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

2.0765884

In [9]:
# Configure training: Adam optimizer, the logits-based cross-entropy loss
# defined above, and accuracy reported as the tracked metric.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [11]:
# Train on the full training split for 5 epochs; the returned History object
# is the cell's displayed value.
model.fit(x=x_train, y=y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x23de7174d30>

In [12]:
# Score the trained model on the held-out test split. The result is
# [loss, accuracy], matching the loss and metric passed to compile().
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 1s - loss: 0.0693 - accuracy: 0.9804


[0.06927861273288727, 0.980400025844574]

In [13]:
# Wrap the trained logits model with a Softmax layer so that calling the
# wrapper returns class probabilities instead of raw logits.
probability_model = keras.Sequential([
    model,
    keras.layers.Softmax(),
])

In [14]:
# Class probabilities for the first five test samples (one row per sample,
# one column per class).
probability_model(x_test[0:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[5.55125368e-08, 1.79840978e-08, 1.28440362e-07, 5.68630567e-05,
        1.08736406e-14, 7.53111564e-08, 2.50061643e-14, 9.99939442e-01,
        8.37094660e-09, 3.30467651e-06],
       [1.43656602e-14, 2.93477065e-08, 1.00000000e+00, 8.94813446e-10,
        1.97261327e-23, 3.38960943e-11, 8.44339224e-16, 1.15972852e-22,
        3.53017338e-11, 7.08013499e-18],
       [6.72353995e-09, 9.99861479e-01, 4.56884954e-05, 3.34809698e-07,
        2.98556688e-06, 1.13763015e-07, 1.85441266e-07, 7.77290479e-05,
        1.14669601e-05, 3.19253317e-08],
       [9.99879122e-01, 2.39044687e-11, 3.50926075e-06, 2.93216007e-09,
        1.09447615e-06, 5.60057316e-08, 7.12479903e-07, 7.28689656e-07,
        2.91612490e-09, 1.14759314e-04],
       [2.34891884e-10, 1.39387048e-12, 2.80130781e-08, 3.49286780e-11,
        9.98399198e-01, 1.11193825e-10, 2.49217841e-10, 4.32246225e-06,
        1.70629587e-11, 1.59649050e-03]], dtype=float32)>